Skip to content

Commit 22ab792

Browse files
committed
Allow advanced auth methods for dbms connection info
1 parent c7a76c7 commit 22ab792

File tree

7 files changed

+72
-16
lines changed

7 files changed

+72
-16
lines changed

changelog.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
## Breaking changes
44

5-
* Drop support for PyArrow 16
5+
- Drop support for PyArrow 16
66

77
## New features
88

@@ -20,5 +20,6 @@
2020
- `DbmsConnectionInfo::from_env`
2121
- Retry internal functions known to be idempotent. Reduces issues such as `SessionExpiredError`.
2222
- Add support for PyArrow 20
23+
- Add support for more advanced authentication in `DbmsConnectionInfo`. Allowing to pass `auth` of type `neo4j.Auth` instead of username + password.
2324

2425
## Other changes

graphdatascience/graph_data_science.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from types import TracebackType
55
from typing import Any, Optional, Type, Union
66

7+
import neo4j
78
from neo4j import Driver
89
from pandas import DataFrame
910

@@ -78,8 +79,10 @@ def __init__(
7879
if isinstance(endpoint, QueryRunner):
7980
self._query_runner = endpoint
8081
else:
82+
if auth:
83+
db_auth = neo4j.basic_auth(*auth)
8184
self._query_runner = Neo4jQueryRunner.create_for_db(
82-
endpoint, auth, aura_ds, database, bookmarks, show_progress
85+
endpoint, db_auth, aura_ds, database, bookmarks, show_progress
8386
)
8487

8588
self._server_version = self._query_runner.server_version()

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class Neo4jQueryRunner(QueryRunner):
3232
@staticmethod
3333
def create_for_db(
3434
endpoint: Union[str, neo4j.Driver],
35-
auth: Optional[tuple[str, str]] = None,
35+
auth: Union[tuple[str, str], neo4j.Auth, None] = None,
3636
aura_ds: bool = False,
3737
database: Optional[str] = None,
3838
bookmarks: Optional[Any] = None,
@@ -79,7 +79,7 @@ def create_for_db(
7979
@staticmethod
8080
def create_for_session(
8181
endpoint: str,
82-
auth: Optional[tuple[str, str]] = None,
82+
auth: Union[tuple[str, str], neo4j.Auth, None] = None,
8383
show_progress: bool = True,
8484
) -> Neo4jQueryRunner:
8585
driver_config: dict[str, Any] = {"user_agent": f"neo4j-graphdatascience-v{__version__}"}
@@ -125,7 +125,7 @@ def __init__(
125125
self,
126126
driver: neo4j.Driver,
127127
protocol: str,
128-
auth: Optional[tuple[str, str]] = None,
128+
auth: Union[tuple[str, str], neo4j.Auth, None] = None,
129129
config: dict[str, Any] = {},
130130
database: Optional[str] = neo4j.DEFAULT_DATABASE,
131131
auto_close: bool = False,

graphdatascience/session/aura_graph_data_science.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def create(
4444
) -> AuraGraphDataScience:
4545
session_bolt_query_runner = Neo4jQueryRunner.create_for_session(
4646
endpoint=session_bolt_connection_info.uri,
47-
auth=session_bolt_connection_info.auth(),
47+
auth=session_bolt_connection_info.get_auth(),
4848
show_progress=show_progress,
4949
)
5050

@@ -75,7 +75,7 @@ def create(
7575
else:
7676
db_bolt_query_runner = Neo4jQueryRunner.create_for_db(
7777
db_endpoint.uri,
78-
db_endpoint.auth(),
78+
db_endpoint.get_auth(),
7979
aura_ds=True,
8080
show_progress=False,
8181
database=db_endpoint.database,

graphdatascience/session/dbms_connection_info.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,40 @@
44
from dataclasses import dataclass
55
from typing import Optional
66

7+
from neo4j import Auth, basic_auth
8+
79

810
@dataclass
911
class DbmsConnectionInfo:
1012
"""
1113
Represents the connection information for a Neo4j DBMS, such as an AuraDB instance.
14+
Supports both username/password as well as the authentication options provided by the Neo4j Python driver.
1215
"""
1316

1417
uri: str
15-
username: str
16-
password: str
18+
username: Optional[str] = None
19+
password: Optional[str] = None
1720
database: Optional[str] = None
18-
19-
def auth(self) -> tuple[str, str]:
21+
# Optional: typed authentication, used instead of username/password. Supports for example a token. See https://neo4j.com/docs/python-manual/current/connect-advanced/#authentication-methods
22+
auth: Optional[Auth] = None
23+
24+
def __post_init__(self) -> None:
25+
# Validate auth fields
26+
if (self.username or self.password) and self.auth:
27+
raise ValueError(
28+
"Cannot provide both username/password and token for authentication. "
29+
"Please provide either a username/password or a token."
30+
)
31+
32+
def get_auth(self) -> Optional[Auth]:
2033
"""
21-
Returns the username and password for authentication.
22-
2334
Returns:
24-
A tuple containing the username and password.
35+
A neo4j.Auth object for authentication.
2536
"""
26-
return self.username, self.password
37+
auth = self.auth
38+
if self.username and self.password:
39+
auth = basic_auth(self.username, self.password)
40+
return auth
2741

2842
@staticmethod
2943
def from_env() -> DbmsConnectionInfo:

graphdatascience/session/dedicated_sessions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _create_db_runner(
111111
) -> Neo4jQueryRunner:
112112
db_runner = Neo4jQueryRunner.create_for_db(
113113
endpoint=db_connection.uri,
114-
auth=db_connection.auth(),
114+
auth=db_connection.get_auth(),
115115
aura_ds=True,
116116
show_progress=False,
117117
database=db_connection.database,
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import neo4j
2+
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
3+
4+
5+
def test_dbms_connection_info_username_password() -> None:
6+
dci = DbmsConnectionInfo(
7+
uri="foo.bar",
8+
username="neo4j",
9+
password="password",
10+
database="neo4j",
11+
)
12+
13+
assert dci.get_auth() == neo4j.basic_auth("neo4j", "password")
14+
15+
16+
def test_dbms_connection_info_advanced_auth() -> None:
17+
advanced_auth = neo4j.kerberos_auth("foo bar")
18+
19+
dci = DbmsConnectionInfo(uri="foo.bar", database="neo4j", auth=advanced_auth)
20+
21+
assert dci.get_auth() == advanced_auth
22+
23+
24+
def test_dbms_connection_info_fail_on_auth_and_username() -> None:
25+
try:
26+
DbmsConnectionInfo(
27+
uri="foo.bar",
28+
username="neo4j",
29+
password="password",
30+
auth=neo4j.basic_auth("other", "other"),
31+
)
32+
except ValueError as e:
33+
assert str(e) == (
34+
"Cannot provide both username/password and token for authentication. "
35+
"Please provide either a username/password or a token."
36+
)
37+
else:
38+
assert False, "Expected ValueError was not raised"

0 commit comments

Comments
 (0)