Skip to content

Commit 03c24af

Browse files
Optionally override arrow address
Co-Authored-By: Paul Horn <paul.horn@neo4j.com>
1 parent 0a6d2a8 commit 03c24af

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

graphdatascience/graph_data_science.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
auth: Optional[Tuple[str, str]] = None,
2929
aura_ds: bool = False,
3030
database: Optional[str] = None,
31-
arrow: bool = True,
31+
arrow: Union[str, bool] = True,
3232
arrow_disable_server_verification: bool = True,
3333
arrow_tls_root_certs: Optional[bytes] = None,
3434
bookmarks: Optional[Any] = None,
@@ -47,9 +47,10 @@ def __init__(
4747
to a Neo4j Aura instance.
4848
database: Optional[str], default None
4949
The Neo4j database to query against.
50-
arrow : bool, default True
51-
A flag that indicates that the client should use Apache Arrow
52-
for data streaming if it is available on the server.
50+
arrow : Union[str, bool], default True
51+
Arrow connection information. Either a flag that indicates whether the client should use Apache Arrow
52+
for data streaming if it is available on the server. True means discover the connection URI from the server.
53+
A connection URI (str) can also be provided.
5354
arrow_disable_server_verification : bool, default True
5455
A flag that indicates that, if the flight client is connecting with
5556
TLS, that it skips server verification. If this is enabled, all
@@ -77,6 +78,7 @@ def __init__(
7778
self._query_runner.encrypted(),
7879
arrow_disable_server_verification,
7980
arrow_tls_root_certs,
81+
None if arrow is True else arrow,
8082
)
8183

8284
super().__init__(self._query_runner, "gds", self._server_version)

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,16 @@ def create(
3131
encrypted: bool = False,
3232
disable_server_verification: bool = False,
3333
tls_root_certs: Optional[bytes] = None,
34+
listen_address_override: Optional[str] = None,
3435
) -> QueryRunner:
3536
arrow_info = (
3637
fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
3738
)
3839
server_version = fallback_query_runner.server_version()
39-
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
40+
if listen_address_override is not None:
41+
listen_address = listen_address_override
42+
else:
43+
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
4044
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
4145

4246
if arrow_info["running"]:

0 commit comments

Comments
 (0)