Skip to content

Commit 40e67e2

Browse files
committed
Read Arrow Server version from debug procedure
1 parent adae7fd commit 40e67e2

File tree

3 files changed

+58
-3
lines changed

3 files changed

+58
-3
lines changed

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import base64
24
import json
35
import time
@@ -14,6 +16,7 @@
1416
from ..server_version.server_version import ServerVersion
1517
from .arrow_graph_constructor import ArrowGraphConstructor
1618
from .graph_constructor import GraphConstructor
19+
from .arrow_version import ArrowVersion, UnsupportedArrowVersion
1720
from .query_runner import QueryRunner
1821
from graphdatascience.server_version.compatible_with import (
1922
IncompatibleServerVersionError,
@@ -32,14 +35,16 @@ def create(
3235
server_version = fallback_query_runner.server_version()
3336

3437
yield_fields = (
35-
["running", "listenAddress"]
38+
["running", "listenAddress", "version"]
3639
if server_version >= ServerVersion(2, 2, 1)
37-
else ["running", "advertisedListenAddress"]
40+
else ["running", "advertisedListenAddress", "version"]
3841
)
39-
arrow_info: "Series[Any]" = fallback_query_runner.call_procedure(
42+
arrow_info: Series[Any] = fallback_query_runner.call_procedure(
4043
endpoint="gds.debug.arrow", yields=yield_fields, custom_error=False
4144
).squeeze()
4245
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) # type: ignore
46+
arrow_version = ArrowQueryRunner.read_arrow_version(arrow_info)
47+
4348
if arrow_info["running"]:
4449
return ArrowQueryRunner(
4550
listen_address,
@@ -49,10 +54,20 @@ def create(
4954
encrypted,
5055
disable_server_verification,
5156
tls_root_certs,
57+
arrow_version,
5258
)
5359
else:
5460
return fallback_query_runner
5561

62+
@staticmethod
63+
def read_arrow_version(arrow_info: Series[Any]) -> ArrowVersion:
64+
arrow_version = arrow_info.get("version", "alpha")
65+
try:
66+
arrow_version = ArrowVersion[arrow_version.upper()]
67+
return arrow_version
68+
except KeyError:
69+
raise UnsupportedArrowVersion(arrow_version)
70+
5671
def __init__(
5772
self,
5873
uri: str,
@@ -62,9 +77,11 @@ def __init__(
6277
encrypted: bool = False,
6378
disable_server_verification: bool = False,
6479
tls_root_certs: Optional[bytes] = None,
80+
arrow_version: ArrowVersion = ArrowVersion.ALPHA,
6581
):
6682
self._fallback_query_runner = fallback_query_runner
6783
self._server_version = server_version
84+
self._arrow_version = arrow_version
6885

6986
host, port_string = uri.split(":")
7087

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from enum import Enum
2+
3+
4+
class ArrowVersion(Enum):
5+
ALPHA = ""
6+
V1 = "v1/"
7+
8+
def name(self):
9+
return self._name_.lower()
10+
11+
def prefix(self):
12+
return self._value_
13+
14+
15+
class UnsupportedArrowVersion(Exception):
16+
def __init__(self, server_version):
17+
super().__init__(self, f"Unsupported Arrow server version: {server_version}")
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
3+
from graphdatascience.graph_data_science import ArrowQueryRunner
4+
from graphdatascience.query_runner.arrow_version import ArrowVersion, UnsupportedArrowVersion
5+
6+
7+
@pytest.mark.parametrize("arrow_version", [ArrowVersion.V1, ArrowVersion.ALPHA])
8+
def test_arrow_version_parsing(arrow_version):
9+
arrow_info = dict()
10+
arrow_info["version"] = arrow_version.name()
11+
actual = ArrowQueryRunner.read_arrow_version(arrow_info)
12+
assert actual == arrow_version
13+
14+
15+
@pytest.mark.parametrize("arrow_version", ["v2", "unsupported"])
16+
def test_arrow_version_parsing_fails(arrow_version):
17+
arrow_info = dict()
18+
arrow_info["version"] = arrow_version
19+
with pytest.raises(UnsupportedArrowVersion) as e:
20+
ArrowQueryRunner.read_arrow_version(arrow_info)
21+
assert arrow_version in str(e.value)

0 commit comments

Comments
 (0)