44import logging
55import os
66import time
7+ from collections import defaultdict
78from dataclasses import dataclass
8- from typing import Any , List , Optional
9+ from typing import Any , List , NamedTuple , Optional , Set
910from urllib .parse import urlparse
1011
1112import requests as req
1415from graphdatascience .version import __version__
1516
1617
17- @dataclass (repr = True )
18+ @dataclass (repr = True , frozen = True )
1819class InstanceDetails :
1920 id : str
2021 name : str
@@ -31,7 +32,7 @@ def fromJson(cls, json: dict[str, Any]) -> InstanceDetails:
3132 )
3233
3334
34- @dataclass (repr = True )
35+ @dataclass (repr = True , frozen = True )
3536class InstanceSpecificDetails (InstanceDetails ):
3637 status : str
3738 connection_url : str
@@ -54,7 +55,7 @@ def fromJson(cls, json: dict[str, Any]) -> InstanceSpecificDetails:
5455 )
5556
5657
57- @dataclass (repr = True )
58+ @dataclass (repr = True , frozen = True )
5859class InstanceCreateDetails :
5960 id : str
6061 username : str
@@ -70,6 +71,51 @@ def from_json(cls, json: dict[str, Any]) -> InstanceCreateDetails:
7071 return cls (** {f .name : json [f .name ] for f in fields })
7172
7273
74+ class WaitResult (NamedTuple ):
75+ connection_url : str
76+ error : str
77+
78+ @classmethod
79+ def from_error (cls , error : str ) -> WaitResult :
80+ return cls (connection_url = "" , error = error )
81+
82+ @classmethod
83+ def from_connection_url (cls , connection_url : str ) -> WaitResult :
84+ return cls (connection_url = connection_url , error = "" )
85+
86+
87+ @dataclass (repr = True , frozen = True )
88+ class TenantDetails :
89+ id : str
90+ ds_type : str
91+ regions_per_provider : dict [str , Set [str ]]
92+
93+ @classmethod
94+ def from_json (cls , json : dict [str , Any ]) -> TenantDetails :
95+ regions_per_provider = defaultdict (set )
96+ instance_types = set ()
97+ ds_type = None
98+
99+ for configs in json ["instance_configurations" ]:
100+ type = configs ["type" ]
101+ if type .split ("-" )[1 ] == "ds" :
102+ regions_per_provider [configs ["cloud_provider" ]].add (configs ["region" ])
103+ ds_type = type
104+ instance_types .add (configs ["type" ])
105+
106+ id = json ["id" ]
107+ if not ds_type :
108+ raise RuntimeError (
109+ f"Tenant with id `{ id } ` cannot create DS instances. Available instances are `{ instance_types } `."
110+ )
111+
112+ return cls (
113+ id = id ,
114+ ds_type = ds_type ,
115+ regions_per_provider = regions_per_provider ,
116+ )
117+
118+
73119class AuraApi :
74120 class AuraAuthToken :
75121 access_token : str
@@ -99,6 +145,7 @@ def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str]
99145 self ._token : Optional [AuraApi .AuraAuthToken ] = None
100146 self ._logger = logging .getLogger ()
101147 self ._tenant_id = tenant_id if tenant_id else self ._get_tenant_id ()
148+ self ._tenant_details : Optional [TenantDetails ] = None
102149
103150 @staticmethod
104151 def extract_id (uri : str ) -> str :
@@ -110,14 +157,14 @@ def extract_id(uri: str) -> str:
110157 return host .split ("." )[0 ].split ("-" )[0 ]
111158
112159 def create_instance (self , name : str , memory : str , cloud_provider : str , region : str ) -> InstanceCreateDetails :
113- # TODO should give more control here
160+ tenant_details = self .tenant_details ()
161+
114162 data = {
115163 "name" : name ,
116164 "memory" : memory ,
117165 "version" : "5" ,
118166 "region" : region ,
119- # TODO should be figured out from the tenant details in the future
120- "type" : self ._instance_type (),
167+ "type" : tenant_details .ds_type ,
121168 "tenant_id" : self ._tenant_id ,
122169 "cloud_provider" : cloud_provider ,
123170 }
@@ -179,16 +226,16 @@ def list_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
179226
180227 def wait_for_instance_running (
181228 self , instance_id : str , sleep_time : float = 0.2 , max_sleep_time : float = 300
182- ) -> Optional [ str ] :
229+ ) -> WaitResult :
183230 waited_time = 0.0
184231 while waited_time <= max_sleep_time :
185232 instance = self .list_instance (instance_id )
186233 if instance is None :
187- return "Instance is not found -- please retry"
234+ return WaitResult . from_error ( "Instance is not found -- please retry" )
188235 elif instance .status in ["deleting" , "destroying" ]:
189- return "Instance is being deleted"
236+ return WaitResult . from_error ( "Instance is being deleted" )
190237 elif instance .status == "running" :
191- return None
238+ return WaitResult . from_connection_url ( instance . connection_url )
192239 else :
193240 self ._logger .debug (
194241 f"Instance `{ instance_id } ` is not yet running. "
@@ -198,7 +245,7 @@ def wait_for_instance_running(
198245 waited_time += sleep_time
199246 time .sleep (sleep_time )
200247
201- return f"Instance is not running after waiting for { waited_time } seconds"
248+ return WaitResult . from_error ( f"Instance is not running after waiting for { waited_time } seconds" )
202249
203250 def _get_tenant_id (self ) -> str :
204251 response = req .get (
@@ -216,6 +263,16 @@ def _get_tenant_id(self) -> str:
216263
217264 return raw_data [0 ]["id" ] # type: ignore
218265
266+ def tenant_details (self ) -> TenantDetails :
267+ if not self ._tenant_details :
268+ response = req .get (
269+ f"{ self ._base_uri } /v1/tenants/{ self ._tenant_id } " ,
270+ headers = self ._build_header (),
271+ )
272+ response .raise_for_status ()
273+ self ._tenant_details = TenantDetails .from_json (response .json ()["data" ])
274+ return self ._tenant_details
275+
219276 def _build_header (self ) -> dict [str , str ]:
220277 return {"Authorization" : f"Bearer { self ._auth_token ()} " , "User-agent" : f"neo4j-graphdatascience-v{ __version__ } " }
221278
0 commit comments