Skip to content

Commit b52d556

Browse files
authored
Merge pull request #607 from adamnsch/gds-construct-correct-db
Use provided `database` for fetching metadata at `GraphDataScience` construction
2 parents d211dfd + 13ddc85 commit b52d556

File tree

5 files changed

+49
-12
lines changed

5 files changed

+49
-12
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
* Fixed an issue where source and target IDs of relationships in heterogeneous OGBL graphs were not parsed correctly.
2121
* Fixed an issue where configuration parameters such as `aggregation` were ignored by `gds.graph.toUndirected`.
22+
* Fixed an issue where the `database` given for the `GraphDataScience` construction was not used for metadata retrieval, causing an exception to be raised if the default "neo4j" database was missing.
2223

2324

2425
## Improvements

doc/modules/ROOT/pages/getting-started.adoc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ assert gds.version()
3838
"2.1.9"
3939
----
4040

41+
Please note that the `GraphDataScience` object needs to communicate with a Neo4j database upon construction, and uses the default "neo4j" database by default.
42+
If there is no such database, you will need to <<specifying-targeted-database, provide a valid database using the `database` keyword parameter>>.
43+
4144

4245
=== AuraDS
4346

@@ -75,6 +78,7 @@ using_enterprise = gds.is_licensed()
7578
----
7679

7780

81+
[[specifying-targeted-database]]
7882
=== Specifying targeted database
7983

8084
If we don't want to use the default database of our DBMS we can provide the `GraphDataScience` constructor with the keyword parameter `database`:
@@ -102,10 +106,10 @@ If Apache Arrow is available on the https://neo4j.com/docs/graph-data-science/cu
102106
[source,python,role=no-test]
103107
----
104108
gds = GraphDataScience(
105-
NEO4J_URI,
106-
auth=(NEO4J_USER, NEO4J_PASSWORD),
107-
arrow=True,
108-
arrow_disable_server_verification=False,
109+
NEO4J_URI,
110+
auth=(NEO4J_USER, NEO4J_PASSWORD),
111+
arrow=True,
112+
arrow_disable_server_verification=False,
109113
arrow_tls_root_certs=CERT
110114
)
111115
----

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,20 @@ def create(
4646
driver = neo4j.GraphDatabase.driver(endpoint, auth=auth, **config)
4747

4848
query_runner = Neo4jQueryRunner(
49-
driver, auto_close=True, bookmarks=bookmarks, config=config, server_version=server_version
49+
driver,
50+
auto_close=True,
51+
bookmarks=bookmarks,
52+
config=config,
53+
server_version=server_version,
54+
database=database,
5055
)
5156

5257
elif isinstance(endpoint, neo4j.Driver):
53-
query_runner = Neo4jQueryRunner(endpoint, auto_close=False, bookmarks=bookmarks)
58+
query_runner = Neo4jQueryRunner(endpoint, auto_close=False, bookmarks=bookmarks, database=database)
5459

5560
else:
5661
raise ValueError(f"Invalid endpoint type: {type(endpoint)}")
5762

58-
if database:
59-
query_runner.set_database(database)
60-
6163
return query_runner
6264

6365
@staticmethod
@@ -97,7 +99,7 @@ def run_cypher(
9799
if database is None:
98100
database = self._database
99101

100-
self._verify_connectivity()
102+
self._verify_connectivity(database=database)
101103

102104
with self._driver.session(database=database, bookmarks=self.bookmarks()) as session:
103105
try:
@@ -303,11 +305,14 @@ def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
303305

304306
raise SyntaxError(generate_suggestive_error_message(requested_endpoint, all_endpoints)) from e
305307

306-
def _verify_connectivity(self) -> None:
308+
def _verify_connectivity(self, database: Optional[str] = None) -> None:
307309
WAIT_TIME = 1
308310
MAX_RETRYS = 10 * 60
309311
WARN_INTERVAL = 10
310312

313+
if database is None:
314+
database = self._database
315+
311316
exception = None
312317
retrys = 0
313318
while retrys < MAX_RETRYS:
@@ -318,7 +323,16 @@ def _verify_connectivity(self) -> None:
318323
category=neo4j.ExperimentalWarning,
319324
message=r"^The configuration may change in the future.$",
320325
)
321-
self._driver.verify_connectivity()
326+
else:
327+
warnings.filterwarnings(
328+
"ignore",
329+
category=neo4j.ExperimentalWarning,
330+
message=(
331+
r"^All configuration key-word arguments to verify_connectivity\(\) are experimental. "
332+
"They might be changed or removed in any future version without prior notice.$"
333+
),
334+
)
335+
self._driver.verify_connectivity(database=database)
322336
break
323337
except neo4j.exceptions.DriverError as e:
324338
exception = e

graphdatascience/tests/integration/test_database_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,23 @@
1313
GRAPH_NAME = "g"
1414

1515

16+
@pytest.mark.skip_on_aura
17+
def test_init_without_neo4j_db(runner: Neo4jQueryRunner) -> None:
18+
default_database = runner.database()
19+
20+
MY_DB_NAME = "bananas"
21+
runner.run_cypher("CREATE DATABASE $dbName WAIT", {"dbName": MY_DB_NAME})
22+
23+
runner.run_cypher("DROP DATABASE $dbName WAIT", {"dbName": default_database})
24+
25+
try:
26+
gds = GraphDataScience(URI, AUTH, database=MY_DB_NAME)
27+
gds.close()
28+
finally:
29+
runner.run_cypher("CREATE DATABASE $dbName WAIT", {"dbName": default_database}, database=MY_DB_NAME)
30+
runner.run_cypher("DROP DATABASE $dbName WAIT", {"dbName": MY_DB_NAME})
31+
32+
1633
@pytest.mark.skip_on_aura
1734
def test_switching_db(runner: Neo4jQueryRunner) -> None:
1835
default_database = runner.database()

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ deps =
6868
-r {toxinidir}/requirements/dev/notebook-ci.txt
6969
commands =
7070
python ./scripts/run_notebooks.py
71+
7172
[testenv:jupyter-notebook-session-ci]
7273
passenv =
7374
AURA_API_CLIENT_ID

0 commit comments

Comments
 (0)