Skip to content

Commit 4264ca0

Browse files
committed
Fix typing issues
1 parent f729e49 commit 4264ca0

File tree

5 files changed

+24
-29
lines changed

5 files changed

+24
-29
lines changed

graphdatascience/error/endpoint_suggester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def generate_suggestive_error_message(requested_endpoint: str, all_endpoints: Li
99
MIN_SIMILARITY_FOR_SUGGESTION = 0.9
1010

1111
closest_endpoint = None
12-
curr_max_similarity = 0
12+
curr_max_similarity = 0.0
1313
for ep in all_endpoints:
1414
similarity = textdistance.jaro_winkler(requested_endpoint, ep)
1515
if similarity >= MIN_SIMILARITY_FOR_SUGGESTION:
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,35 @@
11
from __future__ import annotations
22

33
from enum import Enum
4+
from typing import List
45

56

67
class ArrowEndpointVersion(Enum):
78
ALPHA = ""
89
V1 = "v1/"
910

10-
def name(self) -> str:
11+
def version(self) -> str:
1112
return self._name_.lower()
1213

1314
def prefix(self) -> str:
1415
return self._value_
1516

1617
@staticmethod
17-
def from_arrow_info(arrow_info: Series[Any]) -> ArrowEndpointVersion:
18-
supported_arrow_versions = arrow_info.get("versions", [])
18+
def from_arrow_info(supported_arrow_versions: List[str]) -> ArrowEndpointVersion:
1919
# Fallback for pre 2.6.0 servers that do not support versions
2020
if len(supported_arrow_versions) == 0:
2121
return ArrowEndpointVersion.ALPHA
2222

2323
# If the server supports versioned endpoints, we try v1 first
24-
if ArrowEndpointVersion.V1.name() in supported_arrow_versions:
24+
if ArrowEndpointVersion.V1.version() in supported_arrow_versions:
2525
return ArrowEndpointVersion.V1
2626

27-
if ArrowEndpointVersion.ALPHA.name() in supported_arrow_versions:
27+
if ArrowEndpointVersion.ALPHA.version() in supported_arrow_versions:
2828
return ArrowEndpointVersion.ALPHA
2929

3030
raise UnsupportedArrowEndpointVersion(supported_arrow_versions)
3131

3232

3333
class UnsupportedArrowEndpointVersion(Exception):
34-
def __init__(self, server_version):
34+
def __init__(self, server_version: List[str]) -> None:
3535
super().__init__(self, f"Unsupported Arrow endpoint versions: {server_version}")

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,13 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
126126
def _versioned_action_type(self, action_type: str) -> str:
127127
return self._arrow_endpoint_version.prefix() + action_type
128128

129-
def _versioned_flight_desriptor(self, flight_descriptor: dict) -> dict:
129+
def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
130130
return (
131131
flight_descriptor
132132
if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
133133
else {
134134
"name": "PUT_MESSAGE",
135-
"version": ArrowEndpointVersion.V1.name(),
135+
"version": ArrowEndpointVersion.V1.version(),
136136
"body": flight_descriptor,
137137
}
138138
)

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, Dict, List, Optional, Tuple
88

99
import pyarrow.flight as flight
10-
from pandas import DataFrame, Series
10+
from pandas import DataFrame
1111
from pyarrow import ChunkedArray, Table, chunked_array
1212
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
1313
from pyarrow.types import is_dictionary # type: ignore
@@ -31,11 +31,13 @@ def create(
3131
encrypted: bool = False,
3232
disable_server_verification: bool = False,
3333
tls_root_certs: Optional[bytes] = None,
34-
) -> "QueryRunner":
35-
arrow_info = ArrowQueryRunner._get_arrow_debug_info(fallback_query_runner)
34+
) -> QueryRunner:
35+
arrow_info = (
36+
fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
37+
)
3638
server_version = fallback_query_runner.server_version()
37-
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) # type: ignore
38-
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info)
39+
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
40+
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
3941

4042
if arrow_info["running"]:
4143
return ArrowQueryRunner(
@@ -51,10 +53,6 @@ def create(
5153
else:
5254
return fallback_query_runner
5355

54-
@staticmethod
55-
def _get_arrow_debug_info(query_runner) -> Series[Any]:
56-
return query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze()
57-
5856
def __init__(
5957
self,
6058
uri: str,
@@ -280,7 +278,7 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
280278
if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
281279
payload = {
282280
"name": "GET_MESSAGE",
283-
"version": ArrowEndpointVersion.V1.name(),
281+
"version": ArrowEndpointVersion.V1.version(),
284282
"body": payload,
285283
}
286284

graphdatascience/tests/unit/test_arrow_endpoint_version.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from typing import List, Tuple
2+
13
import pytest
24

3-
from graphdatascience.graph_data_science import ArrowQueryRunner
45
from graphdatascience.query_runner.arrow_endpoint_version import (
56
ArrowEndpointVersion,
67
UnsupportedArrowEndpointVersion,
@@ -19,23 +20,19 @@
1920
(ArrowEndpointVersion.ALPHA, ["alpha", "v2"]),
2021
],
2122
)
22-
def test_from_arrow_info_multiple_versions(arrow_versions):
23-
arrow_info = dict()
24-
arrow_info["versions"] = arrow_versions[1]
25-
actual = ArrowEndpointVersion.from_arrow_info(arrow_info)
23+
def test_from_arrow_info_multiple_versions(arrow_versions: Tuple[ArrowEndpointVersion, List[str]]) -> None:
24+
actual = ArrowEndpointVersion.from_arrow_info(arrow_versions[1])
2625
assert actual == arrow_versions[0]
2726

2827

2928
@pytest.mark.parametrize("arrow_version", ["v2", "unsupported"])
30-
def test_from_arrow_info_fails(arrow_version):
31-
arrow_info = dict()
32-
arrow_info["versions"] = [arrow_version]
29+
def test_from_arrow_info_fails(arrow_version: str) -> None:
3330
with pytest.raises(UnsupportedArrowEndpointVersion) as e:
34-
ArrowEndpointVersion.from_arrow_info(arrow_info)
31+
ArrowEndpointVersion.from_arrow_info([arrow_version])
3532
assert "Unsupported" in str(e.value)
3633
assert arrow_version in str(e.value)
3734

3835

39-
def test_prefix():
36+
def test_prefix() -> None:
4037
assert ArrowEndpointVersion.ALPHA.prefix() == ""
4138
assert ArrowEndpointVersion.V1.prefix() == "v1/"

0 commit comments

Comments
 (0)