Skip to content

Commit d211dfd

Browse files
authored
Merge pull request #603 from neo4j/robinhood
Support overriding Arrow connection string
2 parents 9d9f531 + 0583bca commit d211dfd

File tree

3 files changed

+53
-6
lines changed

3 files changed

+53
-6
lines changed

graphdatascience/graph_data_science.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
auth: Optional[Tuple[str, str]] = None,
2929
aura_ds: bool = False,
3030
database: Optional[str] = None,
31-
arrow: bool = True,
31+
arrow: Union[str, bool] = True,
3232
arrow_disable_server_verification: bool = True,
3333
arrow_tls_root_certs: Optional[bytes] = None,
3434
bookmarks: Optional[Any] = None,
@@ -47,9 +47,10 @@ def __init__(
4747
to a Neo4j Aura instance.
4848
database: Optional[str], default None
4949
The Neo4j database to query against.
50-
arrow : bool, default True
51-
A flag that indicates that the client should use Apache Arrow
52-
for data streaming if it is available on the server.
50+
arrow : Union[str, bool], default True
51+
Arrow connection information. Either a flag that indicates whether the client should use Apache Arrow
52+
for data streaming if it is available on the server. True means discover the connection URI from the server.
53+
A connection URI (str) can also be provided.
5354
arrow_disable_server_verification : bool, default True
5455
A flag that indicates that, if the flight client is connecting with
5556
TLS, that it skips server verification. If this is enabled, all
@@ -77,6 +78,7 @@ def __init__(
7778
self._query_runner.encrypted(),
7879
arrow_disable_server_verification,
7980
arrow_tls_root_certs,
81+
None if arrow is True else arrow,
8082
)
8183

8284
super().__init__(self._query_runner, "gds", self._server_version)

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,22 @@ def create(
3131
encrypted: bool = False,
3232
disable_server_verification: bool = False,
3333
tls_root_certs: Optional[bytes] = None,
34+
connection_string_override: Optional[str] = None,
3435
) -> QueryRunner:
3536
arrow_info = (
3637
fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
3738
)
3839
server_version = fallback_query_runner.server_version()
39-
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
40+
connection_string: str
41+
if connection_string_override is not None:
42+
connection_string = connection_string_override
43+
else:
44+
connection_string = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
4045
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
4146

4247
if arrow_info["running"]:
4348
return ArrowQueryRunner(
44-
listen_address,
49+
connection_string,
4550
fallback_query_runner,
4651
server_version,
4752
auth,
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pytest
2+
from pandas import DataFrame
3+
from pyarrow.flight import FlightUnavailableError
4+
5+
from .conftest import CollectingQueryRunner
6+
from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
7+
from graphdatascience.server_version.server_version import ServerVersion
8+
9+
10+
@pytest.mark.parametrize("server_version", [ServerVersion(2, 6, 0)])
11+
def test_create(runner: CollectingQueryRunner) -> None:
12+
runner.set__mock_result(DataFrame([{"running": True, "listenAddress": "localhost:1234"}]))
13+
14+
arrow_runner = ArrowQueryRunner.create(runner)
15+
16+
assert isinstance(arrow_runner, ArrowQueryRunner)
17+
18+
with pytest.raises(FlightUnavailableError, match=".+ failed to connect .+ ipv4:127.0.0.1:1234: .+"):
19+
arrow_runner._flight_client.list_actions()
20+
21+
22+
@pytest.mark.parametrize("server_version", [ServerVersion(2, 6, 0)])
23+
def test_return_fallback_when_arrow_is_not_running(runner: CollectingQueryRunner) -> None:
24+
runner.set__mock_result(DataFrame([{"running": False, "listenAddress": "localhost:1234"}]))
25+
26+
arrow_runner = ArrowQueryRunner.create(runner)
27+
28+
assert arrow_runner is runner
29+
30+
31+
@pytest.mark.parametrize("server_version", [ServerVersion(2, 6, 0)])
32+
def test_create_with_provided_connection(runner: CollectingQueryRunner) -> None:
33+
runner.set__mock_result(DataFrame([{"running": True, "listenAddress": "localhost:1234"}]))
34+
35+
arrow_runner = ArrowQueryRunner.create(runner, connection_string_override="localhost:4321")
36+
37+
assert isinstance(arrow_runner, ArrowQueryRunner)
38+
39+
with pytest.raises(FlightUnavailableError, match=".+ failed to connect .+ ipv4:127.0.0.1:4321: .+"):
40+
arrow_runner._flight_client.list_actions()

0 commit comments

Comments
 (0)