Skip to content

Commit c074f13

Browse files
committed
Rename ArrowVersion -> ArrowEndpointVersion
1 parent d65f1dc commit c074f13

File tree

4 files changed

+28
-27
lines changed

4 files changed

+28
-27
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22

33

4-
class ArrowVersion(Enum):
4+
class ArrowEndpointVersion(Enum):
55
ALPHA = ""
66
V1 = "v1/"
77

@@ -12,6 +12,6 @@ def prefix(self):
1212
return self._value_
1313

1414

15-
class UnsupportedArrowVersion(Exception):
15+
class UnsupportedArrowEndpointVersion(Exception):
1616
def __init__(self, server_version):
17-
super().__init__(self, f"Unsupported Arrow server version: {server_version}")
17+
super().__init__(self, f"Unsupported Arrow endpoint version: {server_version}")

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pyarrow import Table
1212
from tqdm.auto import tqdm
1313

14-
from .arrow_version import ArrowVersion
14+
from .arrow_endpoint_version import ArrowEndpointVersion
1515
from .graph_constructor import GraphConstructor
1616

1717

@@ -22,15 +22,15 @@ def __init__(
2222
graph_name: str,
2323
flight_client: flight.FlightClient,
2424
concurrency: int,
25-
arrow_version: ArrowVersion,
25+
arrow_endpoint_version: ArrowEndpointVersion,
2626
undirected_relationship_types: Optional[List[str]],
2727
chunk_size: int = 10_000,
2828
):
2929
self._database = database
3030
self._concurrency = concurrency
3131
self._graph_name = graph_name
3232
self._client = flight_client
33-
self._arrow_version = arrow_version
33+
self._arrow_endpoint_version = arrow_endpoint_version
3434
self._undirected_relationship_types = (
3535
[] if undirected_relationship_types is None else undirected_relationship_types
3636
)
@@ -124,15 +124,15 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
124124
raise future.exception() # type: ignore
125125

126126
def _versioned_action_type(self, action_type: str) -> str:
127-
return self._arrow_version.prefix() + action_type
127+
return self._arrow_endpoint_version.prefix() + action_type
128128

129129
def _versioned_flight_desriptor(self, flight_descriptor: dict) -> dict:
130130
return (
131131
flight_descriptor
132-
if self._arrow_version == ArrowVersion.ALPHA
132+
if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
133133
else {
134134
"name": "PUT_MESSAGE",
135-
"version": ArrowVersion.V1.name(),
135+
"version": ArrowEndpointVersion.V1.name(),
136136
"body": flight_descriptor,
137137
}
138138
)

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
from ..call_parameters import CallParameters
1616
from ..server_version.server_version import ServerVersion
17+
from .arrow_endpoint_version import ArrowEndpointVersion
1718
from .arrow_graph_constructor import ArrowGraphConstructor
18-
from .arrow_version import ArrowVersion, UnsupportedArrowVersion
19+
from .arrow_endpoint_version import ArrowEndpointVersion, UnsupportedArrowEndpointVersion
1920
from .graph_constructor import GraphConstructor
2021
from .query_runner import QueryRunner
2122
from graphdatascience.server_version.compatible_with import (
@@ -35,7 +36,7 @@ def create(
3536
arrow_info = ArrowQueryRunner._get_arrow_debug_info(fallback_query_runner)
3637
server_version = fallback_query_runner.server_version()
3738
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) # type: ignore
38-
arrow_version = ArrowQueryRunner._read_arrow_version(arrow_info)
39+
arrow_endpoint_version = ArrowQueryRunner._read_arrow_version(arrow_info)
3940

4041
if arrow_info["running"]:
4142
return ArrowQueryRunner(
@@ -46,7 +47,7 @@ def create(
4647
encrypted,
4748
disable_server_verification,
4849
tls_root_certs,
49-
arrow_version,
50+
arrow_endpoint_version,
5051
)
5152
else:
5253
return fallback_query_runner
@@ -56,13 +57,13 @@ def _get_arrow_debug_info(query_runner) -> Series[Any]:
5657
return query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze()
5758

5859
@staticmethod
59-
def _read_arrow_version(arrow_info: Series[Any]) -> ArrowVersion:
60+
def _read_arrow_version(arrow_info: Series[Any]) -> ArrowEndpointVersion:
6061
arrow_version = arrow_info.get("version", "alpha")
6162
try:
62-
arrow_version = ArrowVersion[arrow_version.upper()]
63+
arrow_version = ArrowEndpointVersion[arrow_version.upper()]
6364
return arrow_version
6465
except KeyError:
65-
raise UnsupportedArrowVersion(arrow_version)
66+
raise UnsupportedArrowEndpointVersion(arrow_version)
6667

6768
def __init__(
6869
self,
@@ -73,11 +74,11 @@ def __init__(
7374
encrypted: bool = False,
7475
disable_server_verification: bool = False,
7576
tls_root_certs: Optional[bytes] = None,
76-
arrow_version: ArrowVersion = ArrowVersion.ALPHA,
77+
arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
7778
):
7879
self._fallback_query_runner = fallback_query_runner
7980
self._server_version = server_version
80-
self._arrow_version = arrow_version
81+
self._arrow_endpoint_version = arrow_endpoint_version
8182

8283
host, port_string = uri.split(":")
8384

@@ -286,10 +287,10 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
286287
"configuration": configuration,
287288
}
288289

289-
if self._arrow_version == ArrowVersion.V1:
290+
if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
290291
payload = {
291292
"name": "GET_MESSAGE",
292-
"version": ArrowVersion.V1.name(),
293+
"version": ArrowEndpointVersion.V1.name(),
293294
"body": payload,
294295
}
295296

@@ -322,7 +323,7 @@ def create_graph_constructor(
322323
)
323324

324325
return ArrowGraphConstructor(
325-
database, graph_name, self._flight_client, concurrency, self._arrow_version, undirected_relationship_types
326+
database, graph_name, self._flight_client, concurrency, self._arrow_endpoint_version, undirected_relationship_types
326327
)
327328

328329
def _sanitize_arrow_table(self, arrow_table: Table) -> Table:

graphdatascience/tests/unit/test_arrow_version.py renamed to graphdatascience/tests/unit/test_arrow_endpoint_version.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import pytest
22

33
from graphdatascience.graph_data_science import ArrowQueryRunner
4-
from graphdatascience.query_runner.arrow_version import (
5-
ArrowVersion,
6-
UnsupportedArrowVersion,
4+
from graphdatascience.query_runner.arrow_endpoint_version import (
5+
ArrowEndpointVersion,
6+
UnsupportedArrowEndpointVersion,
77
)
88

99

10-
@pytest.mark.parametrize("arrow_version", [ArrowVersion.V1, ArrowVersion.ALPHA])
10+
@pytest.mark.parametrize("arrow_version", [ArrowEndpointVersion.V1, ArrowEndpointVersion.ALPHA])
1111
def test_arrow_version_parsing(arrow_version):
1212
arrow_info = dict()
1313
arrow_info["version"] = arrow_version.name()
@@ -19,11 +19,11 @@ def test_arrow_version_parsing(arrow_version):
1919
def test_arrow_version_parsing_fails(arrow_version):
2020
arrow_info = dict()
2121
arrow_info["version"] = arrow_version
22-
with pytest.raises(UnsupportedArrowVersion) as e:
22+
with pytest.raises(UnsupportedArrowEndpointVersion) as e:
2323
ArrowQueryRunner._read_arrow_version(arrow_info)
2424
assert arrow_version in str(e.value)
2525

2626

2727
def test_arrow_version_prefix():
28-
assert ArrowVersion.ALPHA.prefix() == ""
29-
assert ArrowVersion.V1.prefix() == "v1/"
28+
assert ArrowEndpointVersion.ALPHA.prefix() == ""
29+
assert ArrowEndpointVersion.V1.prefix() == "v1/"

0 commit comments

Comments
 (0)