Skip to content

Commit f729e49

Browse files
committed
Add support for reading version lists
1 parent c074f13 commit f729e49

File tree

3 files changed

+49
-24
lines changed

3 files changed

+49
-24
lines changed
Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,35 @@
1+
from __future__ import annotations
2+
13
from enum import Enum
24

35

46
class ArrowEndpointVersion(Enum):
57
ALPHA = ""
68
V1 = "v1/"
79

8-
def name(self):
10+
def name(self) -> str:
911
return self._name_.lower()
1012

11-
def prefix(self):
13+
def prefix(self) -> str:
1214
return self._value_
1315

16+
@staticmethod
17+
def from_arrow_info(arrow_info: Series[Any]) -> ArrowEndpointVersion:
18+
supported_arrow_versions = arrow_info.get("versions", [])
19+
# Fallback for pre 2.6.0 servers that do not support versions
20+
if len(supported_arrow_versions) == 0:
21+
return ArrowEndpointVersion.ALPHA
22+
23+
# If the server supports versioned endpoints, we try v1 first
24+
if ArrowEndpointVersion.V1.name() in supported_arrow_versions:
25+
return ArrowEndpointVersion.V1
26+
27+
if ArrowEndpointVersion.ALPHA.name() in supported_arrow_versions:
28+
return ArrowEndpointVersion.ALPHA
29+
30+
raise UnsupportedArrowEndpointVersion(supported_arrow_versions)
31+
1432

1533
class UnsupportedArrowEndpointVersion(Exception):
1634
def __init__(self, server_version):
17-
super().__init__(self, f"Unsupported Arrow endpoint version: {server_version}")
35+
super().__init__(self, f"Unsupported Arrow endpoint versions: {server_version}")

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from ..server_version.server_version import ServerVersion
1717
from .arrow_endpoint_version import ArrowEndpointVersion
1818
from .arrow_graph_constructor import ArrowGraphConstructor
19-
from .arrow_endpoint_version import ArrowEndpointVersion, UnsupportedArrowEndpointVersion
2019
from .graph_constructor import GraphConstructor
2120
from .query_runner import QueryRunner
2221
from graphdatascience.server_version.compatible_with import (
@@ -36,7 +35,7 @@ def create(
3635
arrow_info = ArrowQueryRunner._get_arrow_debug_info(fallback_query_runner)
3736
server_version = fallback_query_runner.server_version()
3837
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) # type: ignore
39-
arrow_endpoint_version = ArrowQueryRunner._read_arrow_version(arrow_info)
38+
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info)
4039

4140
if arrow_info["running"]:
4241
return ArrowQueryRunner(
@@ -56,15 +55,6 @@ def create(
5655
def _get_arrow_debug_info(query_runner) -> Series[Any]:
5756
return query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze()
5857

59-
@staticmethod
60-
def _read_arrow_version(arrow_info: Series[Any]) -> ArrowEndpointVersion:
61-
arrow_version = arrow_info.get("version", "alpha")
62-
try:
63-
arrow_version = ArrowEndpointVersion[arrow_version.upper()]
64-
return arrow_version
65-
except KeyError:
66-
raise UnsupportedArrowEndpointVersion(arrow_version)
67-
6858
def __init__(
6959
self,
7060
uri: str,
@@ -323,7 +313,12 @@ def create_graph_constructor(
323313
)
324314

325315
return ArrowGraphConstructor(
326-
database, graph_name, self._flight_client, concurrency, self._arrow_endpoint_version, undirected_relationship_types
316+
database,
317+
graph_name,
318+
self._flight_client,
319+
concurrency,
320+
self._arrow_endpoint_version,
321+
undirected_relationship_types,
327322
)
328323

329324
def _sanitize_arrow_table(self, arrow_table: Table) -> Table:

graphdatascience/tests/unit/test_arrow_endpoint_version.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,35 @@
77
)
88

99

10-
@pytest.mark.parametrize("arrow_version", [ArrowEndpointVersion.V1, ArrowEndpointVersion.ALPHA])
11-
def test_arrow_version_parsing(arrow_version):
10+
@pytest.mark.parametrize(
11+
"arrow_versions",
12+
[
13+
(ArrowEndpointVersion.ALPHA, []),
14+
(ArrowEndpointVersion.ALPHA, ["alpha"]),
15+
(ArrowEndpointVersion.V1, ["v1"]),
16+
(ArrowEndpointVersion.V1, ["alpha", "v1"]),
17+
(ArrowEndpointVersion.V1, ["v1", "v2"]),
18+
(ArrowEndpointVersion.ALPHA, ["alpha"]),
19+
(ArrowEndpointVersion.ALPHA, ["alpha", "v2"]),
20+
],
21+
)
22+
def test_from_arrow_info_multiple_versions(arrow_versions):
1223
arrow_info = dict()
13-
arrow_info["version"] = arrow_version.name()
14-
actual = ArrowQueryRunner._read_arrow_version(arrow_info)
15-
assert actual == arrow_version
24+
arrow_info["versions"] = arrow_versions[1]
25+
actual = ArrowEndpointVersion.from_arrow_info(arrow_info)
26+
assert actual == arrow_versions[0]
1627

1728

1829
@pytest.mark.parametrize("arrow_version", ["v2", "unsupported"])
19-
def test_arrow_version_parsing_fails(arrow_version):
30+
def test_from_arrow_info_fails(arrow_version):
2031
arrow_info = dict()
21-
arrow_info["version"] = arrow_version
32+
arrow_info["versions"] = [arrow_version]
2233
with pytest.raises(UnsupportedArrowEndpointVersion) as e:
23-
ArrowQueryRunner._read_arrow_version(arrow_info)
34+
ArrowEndpointVersion.from_arrow_info(arrow_info)
35+
assert "Unsupported" in str(e.value)
2436
assert arrow_version in str(e.value)
2537

2638

27-
def test_arrow_version_prefix():
39+
def test_prefix():
2840
assert ArrowEndpointVersion.ALPHA.prefix() == ""
2941
assert ArrowEndpointVersion.V1.prefix() == "v1/"

0 commit comments

Comments
 (0)