Skip to content

Commit ed42fed

Browse files
authored
Merge pull request #583 from s1ck/arrow-version-support
Arrow Server version support
2 parents adae7fd + 4264ca0 commit ed42fed

File tree

5 files changed

+120
-14
lines changed

5 files changed

+120
-14
lines changed

graphdatascience/error/endpoint_suggester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def generate_suggestive_error_message(requested_endpoint: str, all_endpoints: Li
99
MIN_SIMILARITY_FOR_SUGGESTION = 0.9
1010

1111
closest_endpoint = None
12-
curr_max_similarity = 0
12+
curr_max_similarity = 0.0
1313
for ep in all_endpoints:
1414
similarity = textdistance.jaro_winkler(requested_endpoint, ep)
1515
if similarity >= MIN_SIMILARITY_FOR_SUGGESTION:
graphdatascience/query_runner/arrow_endpoint_version.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from enum import Enum
4+
from typing import List
5+
6+
7+
class ArrowEndpointVersion(Enum):
    """Arrow endpoint versions understood by this client.

    The lower-cased member name is the version identifier that is looked up
    in the list of versions reported by the server; the member value is the
    string prefix associated with that version.
    """

    ALPHA = ""
    V1 = "v1/"

    def version(self) -> str:
        """Return the version identifier, i.e. the lower-cased member name."""
        return self.name.lower()

    def prefix(self) -> str:
        """Return the prefix string of this version (the member value)."""
        return self.value

    @staticmethod
    def from_arrow_info(supported_arrow_versions: List[str]) -> "ArrowEndpointVersion":
        """Select the endpoint version to use from the server-reported versions.

        Args:
            supported_arrow_versions: Version identifiers reported by the
                server, e.g. ``["alpha", "v1"]``. An empty (or missing) list
                selects ``ALPHA``.

        Returns:
            ``V1`` if the server lists it, otherwise ``ALPHA`` if listed.

        Raises:
            UnsupportedArrowEndpointVersion: If versions are reported but
                none of them is known to this client.
        """
        # Fallback for pre 2.6.0 servers that do not support versions.
        # A truthiness check also covers a missing (None) value instead of
        # failing inside len().
        if not supported_arrow_versions:
            return ArrowEndpointVersion.ALPHA

        # If the server supports versioned endpoints, we try v1 first and
        # only then fall back to alpha.
        for candidate in (ArrowEndpointVersion.V1, ArrowEndpointVersion.ALPHA):
            if candidate.version() in supported_arrow_versions:
                return candidate

        raise UnsupportedArrowEndpointVersion(supported_arrow_versions)
31+
32+
33+
class UnsupportedArrowEndpointVersion(Exception):
    """Raised when none of the server-reported Arrow versions is supported."""

    def __init__(self, server_version: List[str]) -> None:
        """Build the error from the list of versions the server reported.

        Args:
            server_version: The version identifiers reported by the server;
                they are included verbatim in the error message.
        """
        # Pass only the message. Forwarding ``self`` as well would put the
        # exception into its own ``args`` and make ``str(exc)`` render a
        # tuple repr instead of the message.
        super().__init__(f"Unsupported Arrow endpoint versions: {server_version}")

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pyarrow import Table
1212
from tqdm.auto import tqdm
1313

14+
from .arrow_endpoint_version import ArrowEndpointVersion
1415
from .graph_constructor import GraphConstructor
1516

1617

@@ -21,13 +22,15 @@ def __init__(
2122
graph_name: str,
2223
flight_client: flight.FlightClient,
2324
concurrency: int,
25+
arrow_endpoint_version: ArrowEndpointVersion,
2426
undirected_relationship_types: Optional[List[str]],
2527
chunk_size: int = 10_000,
2628
):
2729
self._database = database
2830
self._concurrency = concurrency
2931
self._graph_name = graph_name
3032
self._client = flight_client
33+
self._arrow_endpoint_version = arrow_endpoint_version
3134
self._undirected_relationship_types = (
3235
[] if undirected_relationship_types is None else undirected_relationship_types
3336
)
@@ -81,6 +84,7 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:
8184
return partitioned_dfs
8285

8386
def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
87+
action_type = self._versioned_action_type(action_type)
8488
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
8589

8690
# Consume result fully to sanity check and avoid cancelled streams
@@ -93,6 +97,7 @@ def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm) -> None:
9397
table = Table.from_pandas(df)
9498
batches = table.to_batches(self._chunk_size)
9599
flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
100+
flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
96101

97102
# Write schema
98103
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
@@ -117,3 +122,17 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
117122
if not future.exception():
118123
continue
119124
raise future.exception() # type: ignore
125+
126+
def _versioned_action_type(self, action_type: str) -> str:
127+
return self._arrow_endpoint_version.prefix() + action_type
128+
129+
def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
    """Adapt a flight descriptor to the configured endpoint version.

    For the alpha version the descriptor is returned unchanged. For any
    other version it is wrapped as the ``body`` of a ``PUT_MESSAGE``
    envelope tagged with the v1 version identifier.
    """
    if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA:
        return flight_descriptor

    return {
        "name": "PUT_MESSAGE",
        "version": ArrowEndpointVersion.V1.version(),
        "body": flight_descriptor,
    }

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1+
from __future__ import annotations
2+
13
import base64
24
import json
35
import time
46
import warnings
57
from typing import Any, Dict, List, Optional, Tuple
68

79
import pyarrow.flight as flight
8-
from pandas import DataFrame, Series
10+
from pandas import DataFrame
911
from pyarrow import ChunkedArray, Table, chunked_array
1012
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
1113
from pyarrow.types import is_dictionary # type: ignore
1214

1315
from ..call_parameters import CallParameters
1416
from ..server_version.server_version import ServerVersion
17+
from .arrow_endpoint_version import ArrowEndpointVersion
1518
from .arrow_graph_constructor import ArrowGraphConstructor
1619
from .graph_constructor import GraphConstructor
1720
from .query_runner import QueryRunner
@@ -28,18 +31,14 @@ def create(
2831
encrypted: bool = False,
2932
disable_server_verification: bool = False,
3033
tls_root_certs: Optional[bytes] = None,
31-
) -> "QueryRunner":
34+
) -> QueryRunner:
35+
arrow_info = (
36+
fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
37+
)
3238
server_version = fallback_query_runner.server_version()
39+
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
40+
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
3341

34-
yield_fields = (
35-
["running", "listenAddress"]
36-
if server_version >= ServerVersion(2, 2, 1)
37-
else ["running", "advertisedListenAddress"]
38-
)
39-
arrow_info: "Series[Any]" = fallback_query_runner.call_procedure(
40-
endpoint="gds.debug.arrow", yields=yield_fields, custom_error=False
41-
).squeeze()
42-
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) # type: ignore
4342
if arrow_info["running"]:
4443
return ArrowQueryRunner(
4544
listen_address,
@@ -49,6 +48,7 @@ def create(
4948
encrypted,
5049
disable_server_verification,
5150
tls_root_certs,
51+
arrow_endpoint_version,
5252
)
5353
else:
5454
return fallback_query_runner
@@ -62,9 +62,11 @@ def __init__(
6262
encrypted: bool = False,
6363
disable_server_verification: bool = False,
6464
tls_root_certs: Optional[bytes] = None,
65+
arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
6566
):
6667
self._fallback_query_runner = fallback_query_runner
6768
self._server_version = server_version
69+
self._arrow_endpoint_version = arrow_endpoint_version
6870

6971
host, port_string = uri.split(":")
7072

@@ -272,8 +274,15 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
272274
"procedure_name": procedure_name,
273275
"configuration": configuration,
274276
}
275-
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
276277

278+
if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
279+
payload = {
280+
"name": "GET_MESSAGE",
281+
"version": ArrowEndpointVersion.V1.version(),
282+
"body": payload,
283+
}
284+
285+
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
277286
get = self._flight_client.do_get(ticket)
278287
arrow_table = get.read_all()
279288

@@ -302,7 +311,12 @@ def create_graph_constructor(
302311
)
303312

304313
return ArrowGraphConstructor(
305-
database, graph_name, self._flight_client, concurrency, undirected_relationship_types
314+
database,
315+
graph_name,
316+
self._flight_client,
317+
concurrency,
318+
self._arrow_endpoint_version,
319+
undirected_relationship_types,
306320
)
307321

308322
def _sanitize_arrow_table(self, arrow_table: Table) -> Table:
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import List, Tuple
2+
3+
import pytest
4+
5+
from graphdatascience.query_runner.arrow_endpoint_version import (
6+
ArrowEndpointVersion,
7+
UnsupportedArrowEndpointVersion,
8+
)
9+
10+
11+
@pytest.mark.parametrize(
    "arrow_versions",
    [
        # No versions reported: falls back to alpha.
        (ArrowEndpointVersion.ALPHA, []),
        (ArrowEndpointVersion.ALPHA, ["alpha"]),
        (ArrowEndpointVersion.V1, ["v1"]),
        # v1 is preferred over alpha; identifiers unknown to the client are ignored.
        (ArrowEndpointVersion.V1, ["alpha", "v1"]),
        (ArrowEndpointVersion.V1, ["v1", "v2"]),
        (ArrowEndpointVersion.ALPHA, ["alpha", "v2"]),
    ],
)
def test_from_arrow_info_multiple_versions(arrow_versions: Tuple[ArrowEndpointVersion, List[str]]) -> None:
    """Check which endpoint version is selected for each reported version list.

    Each case is an ``(expected version, server-reported versions)`` pair.
    """
    expected, supported = arrow_versions
    assert ArrowEndpointVersion.from_arrow_info(supported) == expected
26+
27+
28+
@pytest.mark.parametrize("arrow_version", ["v2", "unsupported"])
def test_from_arrow_info_fails(arrow_version: str) -> None:
    """Check that a version list without any known version is rejected.

    The raised error must mention both "Unsupported" and the offending
    version identifier.
    """
    with pytest.raises(UnsupportedArrowEndpointVersion) as exc_info:
        ArrowEndpointVersion.from_arrow_info([arrow_version])

    message = str(exc_info.value)
    assert "Unsupported" in message
    assert arrow_version in message
34+
35+
36+
def test_prefix() -> None:
    """Pin the prefix string of every endpoint version."""
    expected_prefixes = {
        ArrowEndpointVersion.ALPHA: "",
        ArrowEndpointVersion.V1: "v1/",
    }
    for endpoint_version, expected in expected_prefixes.items():
        assert endpoint_version.prefix() == expected

0 commit comments

Comments
 (0)