Skip to content

Commit 9e0b74b

Browse files
committed
Validate region and allow custom override
1 parent 1863c4a commit 9e0b74b

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

graphdatascience/gds_session/gds_sessions.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,11 @@ def __init__(self, ds_connection: AuraAPICredentials) -> None:
4545
)
4646

4747
def get_or_create(
48-
self, session_name: str, size: SessionSizeByMemory, db_connection: DbmsConnectionInfo
48+
self,
49+
session_name: str,
50+
size: SessionSizeByMemory,
51+
db_connection: DbmsConnectionInfo,
52+
region: Optional[str] = None,
4953
) -> AuraGraphDataScience:
5054
connected_instance = self._try_connect(session_name, db_connection)
5155
if connected_instance is not None:
@@ -56,8 +60,11 @@ def get_or_create(
5660
if not db_instance:
5761
raise ValueError(f"Could not find Aura instance with the uri `{db_connection.uri}`")
5862

63+
region = region if region else db_instance.region
64+
self._validate_region(region, db_instance)
65+
5966
create_details = self._aura_api.create_instance(
60-
GdsSessions._instance_name(session_name), size.value, db_instance.cloud_provider, db_instance.region
67+
GdsSessions._instance_name(session_name), size.value, db_instance.cloud_provider, region
6168
)
6269
wait_result = self._aura_api.wait_for_instance_running(create_details.id)
6370
if wait_result is not None:
@@ -129,6 +136,15 @@ def _try_connect(self, session_name: str, db_connection: DbmsConnectionInfo) ->
129136

130137
return self._construct_client(session_name=session_name, gds_url=gds_url, db_connection=db_connection)
131138

139+
def _validate_region(self, region: str, db_instance: InstanceSpecificDetails) -> None:
140+
tenant_details = self._aura_api.tenant_details()
141+
available_regions = tenant_details.regions_per_provider[db_instance.cloud_provider]
142+
if region not in available_regions:
143+
raise ValueError(
144+
f"Region `{region}` is not supported by the tenant `{tenant_details.id}`."
145+
f" Supported regions: {available_regions}`"
146+
)
147+
132148
def _construct_client(
133149
self, session_name: str, gds_url: str, db_connection: DbmsConnectionInfo
134150
) -> AuraGraphDataScience:

graphdatascience/tests/unit/test_gds_sessions.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
InstanceCreateDetails,
1313
InstanceDetails,
1414
InstanceSpecificDetails,
15+
TenantDetails,
1516
)
1617
from graphdatascience.gds_session.dbms_connection_info import DbmsConnectionInfo
1718
from graphdatascience.gds_session.gds_sessions import (
@@ -80,6 +81,11 @@ def wait_for_instance_running(
8081
) -> Optional[str]:
8182
return super().wait_for_instance_running(instance_id, sleep_time=0.0001, max_sleep_time=0.001)
8283

84+
def tenant_details(self) -> TenantDetails:
85+
return TenantDetails(
86+
id=self._tenant_id, ds_type="fake-ds", regions_per_provider={"aws": {"leipzig-1", "dresden-2"}}
87+
)
88+
8389

8490
@pytest.fixture
8591
def aura_api() -> AuraApi:
@@ -176,6 +182,31 @@ def test_create_default_session(mocker: MockerFixture, aura_api: AuraApi) -> Non
176182
assert instance_details.memory == "8GB"
177183

178184

185+
def test_create_session_override_region(mocker: MockerFixture, aura_api: AuraApi) -> None:
186+
_setup_db_instance(aura_api)
187+
188+
sessions = GdsSessions(AuraAPICredentials("", "", "placeholder"))
189+
sessions._aura_api = aura_api
190+
191+
mocker.patch(
192+
"graphdatascience.gds_session.gds_sessions.GdsSessions._construct_client", lambda *args, **kwargs: kwargs
193+
)
194+
mocker.patch(
195+
"graphdatascience.gds_session.gds_sessions.GdsSessions._change_initial_pw", lambda *args, **kwargs: kwargs
196+
)
197+
198+
sessions.get_or_create(
199+
"my-session",
200+
SessionSizeByMemory.DEFAULT,
201+
DbmsConnectionInfo("neo4j+ssc://ffff0.databases.neo4j.io", "dbuser", "db_pw"),
202+
region="dresden-2",
203+
)
204+
instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore
205+
assert instance_details.cloud_provider == "aws"
206+
assert instance_details.region == "dresden-2"
207+
assert instance_details.memory == "8GB"
208+
209+
179210
def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None:
180211
_setup_db_instance(aura_api)
181212

@@ -368,5 +399,22 @@ def test_create_waiting_forever() -> None:
368399
)
369400

370401

402+
def test_create_session_invalid_region(aura_api: AuraApi) -> None:
403+
aura_api.create_instance("test", "8GB", "aws", "only-db-region")
404+
405+
sessions = GdsSessions(AuraAPICredentials("", "", "placeholder"))
406+
sessions._aura_api = aura_api
407+
408+
expected_message = (
409+
"Region `only-db-region` is not supported by the tenant `tenant_id`." " Supported regions: {'leipzig-1'}."
410+
)
411+
with pytest.raises(ValueError, match=expected_message):
412+
sessions.get_or_create(
413+
"my-session",
414+
SessionSizes.by_memory().X5L,
415+
DbmsConnectionInfo("neo4j+ssc://ffff0.databases.neo4j.io", "dbuser", "db_pw"),
416+
)
417+
418+
371419
def _setup_db_instance(aura_api: AuraApi) -> None:
372420
aura_api.create_instance("test", "8GB", "aws", "leipzig-1")

0 commit comments

Comments
 (0)