Skip to content

Commit a636ccc

Browse files
authored
Merge pull request #584 from FlorentinD/try-sessions-against-production
Validate create requests against tenants available configurations
2 parents 853c748 + 2937bc8 commit a636ccc

File tree

8 files changed

+243
-41
lines changed

8 files changed

+243
-41
lines changed

examples/dev/gds-sessions.ipynb

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,21 @@
7070
"metadata": {},
7171
"outputs": [],
7272
"source": [
73-
"from graphdatascience.gds_session.dbms_connection_info import DbmsConnectionInfo\n",
73+
"from graphdatascience.gds_session import DbmsConnectionInfo\n",
7474
"import os\n",
7575
"\n",
7676
"from neo4j import GraphDatabase\n",
7777
"\n",
78-
"# We need to tell the GDS client that we are working with a development environment.\n",
79-
"# This does not need to be set in production.\n",
80-
"os.environ[\"AURA_ENV\"] = \"devstrawberryfield\"\n",
78+
"# We can tell the GDS client that we are working with a development environment.\n",
79+
"# os.environ[\"AURA_ENV\"] = \"devstrawberryfield\"\n",
8180
"\n",
82-
"db_connection_info = DbmsConnectionInfo(\n",
83-
" f\"neo4j+s://{db_id}-{os.environ['AURA_ENV']}.databases.neo4j-dev.io\", \"neo4j\", db_password\n",
81+
"uri = (\n",
82+
" f\"neo4j+s://{db_id}-{os.environ['AURA_ENV']}.databases.neo4j-dev.io\"\n",
83+
" if os.environ.get(\"AURA_ENV\")\n",
84+
" else f\"neo4j+s://{db_id}.databases.neo4j.io\"\n",
8485
")\n",
86+
"\n",
87+
"db_connection_info = DbmsConnectionInfo(uri, \"neo4j\", db_password)\n",
8588
"# start a standard Neo4j Python Driver to connect to the AuraDB instance\n",
8689
"driver = GraphDatabase.driver(db_connection_info.uri, auth=db_connection_info.auth())\n",
8790
"\n",
@@ -171,9 +174,11 @@
171174
"metadata": {},
172175
"outputs": [],
173176
"source": [
174-
"# Let's call this function and see what it returns\n",
177+
"# Let's call this function and verify the arrow server is enabled\n",
175178
"with driver.session() as session:\n",
176-
" display(session.run(\"CALL internal.arrow.status\").to_df())"
179+
" arrow_status = session.run(\"CALL internal.arrow.status\").to_df()\n",
180+
" print(arrow_status)\n",
181+
" assert arrow_status[\"enabled\"][0] == True, \"Arrow server on the db needs to be enabled\""
177182
]
178183
},
179184
{
@@ -244,7 +249,7 @@
244249
"outputs": [],
245250
"source": [
246251
"# The new stuff!\n",
247-
"from graphdatascience.gds_session.gds_sessions import GdsSessions, AuraAPICredentials\n",
252+
"from graphdatascience.gds_session import GdsSessions, AuraAPICredentials\n",
248253
"\n",
249254
"# Create a new AuraSessions object\n",
250255
"sessions = GdsSessions(ds_connection=AuraAPICredentials(CLIENT_ID, CLIENT_SECRET))"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .dbms_connection_info import DbmsConnectionInfo
2+
from .gds_sessions import AuraAPICredentials, GdsSessions, SessionInfo
3+
from .session_sizes import SessionSizes
4+
5+
__all__ = [
6+
"GdsSessions",
7+
"SessionInfo",
8+
"DbmsConnectionInfo",
9+
"AuraAPICredentials",
10+
"SessionSizes",
11+
]

graphdatascience/gds_session/aura_api.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import logging
55
import os
66
import time
7+
from collections import defaultdict
78
from dataclasses import dataclass
8-
from typing import Any, List, Optional
9+
from typing import Any, List, NamedTuple, Optional, Set
910
from urllib.parse import urlparse
1011

1112
import requests as req
@@ -14,7 +15,7 @@
1415
from graphdatascience.version import __version__
1516

1617

17-
@dataclass(repr=True)
18+
@dataclass(repr=True, frozen=True)
1819
class InstanceDetails:
1920
id: str
2021
name: str
@@ -31,7 +32,7 @@ def fromJson(cls, json: dict[str, Any]) -> InstanceDetails:
3132
)
3233

3334

34-
@dataclass(repr=True)
35+
@dataclass(repr=True, frozen=True)
3536
class InstanceSpecificDetails(InstanceDetails):
3637
status: str
3738
connection_url: str
@@ -54,7 +55,7 @@ def fromJson(cls, json: dict[str, Any]) -> InstanceSpecificDetails:
5455
)
5556

5657

57-
@dataclass(repr=True)
58+
@dataclass(repr=True, frozen=True)
5859
class InstanceCreateDetails:
5960
id: str
6061
username: str
@@ -70,6 +71,51 @@ def from_json(cls, json: dict[str, Any]) -> InstanceCreateDetails:
7071
return cls(**{f.name: json[f.name] for f in fields})
7172

7273

74+
class WaitResult(NamedTuple):
75+
connection_url: str
76+
error: str
77+
78+
@classmethod
79+
def from_error(cls, error: str) -> WaitResult:
80+
return cls(connection_url="", error=error)
81+
82+
@classmethod
83+
def from_connection_url(cls, connection_url: str) -> WaitResult:
84+
return cls(connection_url=connection_url, error="")
85+
86+
87+
@dataclass(repr=True, frozen=True)
88+
class TenantDetails:
89+
id: str
90+
ds_type: str
91+
regions_per_provider: dict[str, Set[str]]
92+
93+
@classmethod
94+
def from_json(cls, json: dict[str, Any]) -> TenantDetails:
95+
regions_per_provider = defaultdict(set)
96+
instance_types = set()
97+
ds_type = None
98+
99+
for configs in json["instance_configurations"]:
100+
type = configs["type"]
101+
if type.split("-")[1] == "ds":
102+
regions_per_provider[configs["cloud_provider"]].add(configs["region"])
103+
ds_type = type
104+
instance_types.add(configs["type"])
105+
106+
id = json["id"]
107+
if not ds_type:
108+
raise RuntimeError(
109+
f"Tenant with id `{id}` cannot create DS instances. Available instances are `{instance_types}`."
110+
)
111+
112+
return cls(
113+
id=id,
114+
ds_type=ds_type,
115+
regions_per_provider=regions_per_provider,
116+
)
117+
118+
73119
class AuraApi:
74120
class AuraAuthToken:
75121
access_token: str
@@ -99,6 +145,7 @@ def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str]
99145
self._token: Optional[AuraApi.AuraAuthToken] = None
100146
self._logger = logging.getLogger()
101147
self._tenant_id = tenant_id if tenant_id else self._get_tenant_id()
148+
self._tenant_details: Optional[TenantDetails] = None
102149

103150
@staticmethod
104151
def extract_id(uri: str) -> str:
@@ -110,14 +157,14 @@ def extract_id(uri: str) -> str:
110157
return host.split(".")[0].split("-")[0]
111158

112159
def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails:
113-
# TODO should give more control here
160+
tenant_details = self.tenant_details()
161+
114162
data = {
115163
"name": name,
116164
"memory": memory,
117165
"version": "5",
118166
"region": region,
119-
# TODO should be figured out from the tenant details in the future
120-
"type": self._instance_type(),
167+
"type": tenant_details.ds_type,
121168
"tenant_id": self._tenant_id,
122169
"cloud_provider": cloud_provider,
123170
}
@@ -179,16 +226,16 @@ def list_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
179226

180227
def wait_for_instance_running(
181228
self, instance_id: str, sleep_time: float = 0.2, max_sleep_time: float = 300
182-
) -> Optional[str]:
229+
) -> WaitResult:
183230
waited_time = 0.0
184231
while waited_time <= max_sleep_time:
185232
instance = self.list_instance(instance_id)
186233
if instance is None:
187-
return "Instance is not found -- please retry"
234+
return WaitResult.from_error("Instance is not found -- please retry")
188235
elif instance.status in ["deleting", "destroying"]:
189-
return "Instance is being deleted"
236+
return WaitResult.from_error("Instance is being deleted")
190237
elif instance.status == "running":
191-
return None
238+
return WaitResult.from_connection_url(instance.connection_url)
192239
else:
193240
self._logger.debug(
194241
f"Instance `{instance_id}` is not yet running. "
@@ -198,7 +245,7 @@ def wait_for_instance_running(
198245
waited_time += sleep_time
199246
time.sleep(sleep_time)
200247

201-
return f"Instance is not running after waiting for {waited_time} seconds"
248+
return WaitResult.from_error(f"Instance is not running after waiting for {waited_time} seconds")
202249

203250
def _get_tenant_id(self) -> str:
204251
response = req.get(
@@ -216,6 +263,16 @@ def _get_tenant_id(self) -> str:
216263

217264
return raw_data[0]["id"] # type: ignore
218265

266+
def tenant_details(self) -> TenantDetails:
267+
if not self._tenant_details:
268+
response = req.get(
269+
f"{self._base_uri}/v1/tenants/{self._tenant_id}",
270+
headers=self._build_header(),
271+
)
272+
response.raise_for_status()
273+
self._tenant_details = TenantDetails.from_json(response.json()["data"])
274+
return self._tenant_details
275+
219276
def _build_header(self) -> dict[str, str]:
220277
return {"Authorization": f"Bearer {self._auth_token()}", "User-agent": f"neo4j-graphdatascience-v{__version__}"}
221278

graphdatascience/gds_session/gds_sessions.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from graphdatascience.gds_session.aura_graph_data_science import AuraGraphDataScience
1414
from graphdatascience.gds_session.dbms_connection_info import DbmsConnectionInfo
15+
from graphdatascience.gds_session.region_suggester import closest_match
1516
from graphdatascience.gds_session.session_sizes import SessionSizeByMemory
1617

1718

@@ -45,7 +46,10 @@ def __init__(self, ds_connection: AuraAPICredentials) -> None:
4546
)
4647

4748
def get_or_create(
48-
self, session_name: str, size: SessionSizeByMemory, db_connection: DbmsConnectionInfo
49+
self,
50+
session_name: str,
51+
size: SessionSizeByMemory,
52+
db_connection: DbmsConnectionInfo,
4953
) -> AuraGraphDataScience:
5054
connected_instance = self._try_connect(session_name, db_connection)
5155
if connected_instance is not None:
@@ -56,15 +60,17 @@ def get_or_create(
5660
if not db_instance:
5761
raise ValueError(f"Could not find Aura instance with the uri `{db_connection.uri}`")
5862

63+
region = self._ds_region(db_instance.region, db_instance.cloud_provider)
64+
5965
create_details = self._aura_api.create_instance(
60-
GdsSessions._instance_name(session_name), size.value, db_instance.cloud_provider, db_instance.region
66+
GdsSessions._instance_name(session_name), size.value, db_instance.cloud_provider, region
6167
)
6268
wait_result = self._aura_api.wait_for_instance_running(create_details.id)
63-
if wait_result is not None:
64-
raise RuntimeError(f"Failed to create session `{session_name}`: {wait_result}")
69+
if err := wait_result.error:
70+
raise RuntimeError(f"Failed to create session `{session_name}`: {err}")
6571

6672
gds_user = create_details.username
67-
gds_url = create_details.connection_url
73+
gds_url = wait_result.connection_url
6874

6975
self._change_initial_pw(
7076
gds_url=gds_url, gds_user=gds_user, initial_pw=create_details.password, new_pw=db_connection.password
@@ -118,16 +124,24 @@ def _try_connect(self, session_name: str, db_connection: DbmsConnectionInfo) ->
118124
if len(matched_instances) > 1:
119125
self._fail_ambiguous_session(session_name, matched_instances)
120126

121-
instance_details = self._aura_api.list_instance(matched_instances[0].id)
127+
wait_result = self._aura_api.wait_for_instance_running(matched_instances[0].id)
128+
if err := wait_result.error:
129+
raise RuntimeError(f"Failed to connect to session `{session_name}`: {err}")
130+
gds_url = wait_result.connection_url
131+
132+
return self._construct_client(session_name=session_name, gds_url=gds_url, db_connection=db_connection)
122133

123-
if instance_details:
124-
gds_url = instance_details.connection_url
125-
else:
126-
raise RuntimeError(
127-
f"Unable to get connection information for session `{session_name}`. Does it still exist?"
134+
def _ds_region(self, region: str, cloud_provider: str) -> str:
135+
tenant_details = self._aura_api.tenant_details()
136+
available_regions = tenant_details.regions_per_provider[cloud_provider]
137+
138+
match = closest_match(region, available_regions)
139+
if not match:
140+
raise ValueError(
141+
f"Tenant `{tenant_details.id}` cannot create GDS sessions at cloud provider `{cloud_provider}`."
128142
)
129143

130-
return self._construct_client(session_name=session_name, gds_url=gds_url, db_connection=db_connection)
144+
return match
131145

132146
def _construct_client(
133147
self, session_name: str, gds_url: str, db_connection: DbmsConnectionInfo
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Iterable, Optional
2+
3+
import textdistance
4+
5+
6+
# AuraDB and AuraDS regions are not the same, so we need to find the closest match.
7+
def closest_match(db_region: str, ds_regions: Iterable[str]) -> Optional[str]:
8+
curr_max_similarity = 0.0
9+
closest_option = None
10+
11+
for region in ds_regions:
12+
similarity = textdistance.jaro_winkler(db_region, region)
13+
if similarity > curr_max_similarity:
14+
closest_option = region
15+
curr_max_similarity = similarity
16+
17+
return closest_option

0 commit comments

Comments
 (0)