Skip to content

Commit e8c00d4

Browse files
s1ckFlorentinD
andcommitted
Simplify calling Arrow debug procedure
Co-Authored-By: Florentin Dörre <florentin.dorre@neotechnology.com>
1 parent 25517b1 commit e8c00d4

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,10 @@ def create(
3232
disable_server_verification: bool = False,
3333
tls_root_certs: Optional[bytes] = None,
3434
) -> "QueryRunner":
35+
arrow_info = ArrowQueryRunner._get_arrow_debug_info(fallback_query_runner)
3536
server_version = fallback_query_runner.server_version()
36-
37-
yield_fields = (
38-
["running", "listenAddress", "version"]
39-
if server_version >= ServerVersion(2, 2, 1)
40-
else ["running", "advertisedListenAddress", "version"]
41-
)
42-
arrow_info: Series[Any] = fallback_query_runner.call_procedure(
43-
endpoint="gds.debug.arrow", yields=yield_fields, custom_error=False
44-
).squeeze()
4537
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) # type: ignore
46-
arrow_version = ArrowQueryRunner.read_arrow_version(arrow_info)
38+
arrow_version = ArrowQueryRunner._read_arrow_version(arrow_info)
4739

4840
if arrow_info["running"]:
4941
return ArrowQueryRunner(
@@ -60,7 +52,11 @@ def create(
6052
return fallback_query_runner
6153

6254
@staticmethod
63-
def read_arrow_version(arrow_info: Series[Any]) -> ArrowVersion:
55+
def _get_arrow_debug_info(query_runner) -> Series[Any]:
56+
return query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze()
57+
58+
@staticmethod
59+
def _read_arrow_version(arrow_info: Series[Any]) -> ArrowVersion:
6460
arrow_version = arrow_info.get("version", "alpha")
6561
try:
6662
arrow_version = ArrowVersion[arrow_version.upper()]

graphdatascience/tests/unit/test_arrow_version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
def test_arrow_version_parsing(arrow_version):
99
arrow_info = dict()
1010
arrow_info["version"] = arrow_version.name()
11-
actual = ArrowQueryRunner.read_arrow_version(arrow_info)
11+
actual = ArrowQueryRunner._read_arrow_version(arrow_info)
1212
assert actual == arrow_version
1313

1414

@@ -17,7 +17,7 @@ def test_arrow_version_parsing_fails(arrow_version):
1717
arrow_info = dict()
1818
arrow_info["version"] = arrow_version
1919
with pytest.raises(UnsupportedArrowVersion) as e:
20-
ArrowQueryRunner.read_arrow_version(arrow_info)
20+
ArrowQueryRunner._read_arrow_version(arrow_info)
2121
assert arrow_version in str(e.value)
2222

2323

0 commit comments

Comments
 (0)