1414
1515from ..call_parameters import CallParameters
1616from ..server_version .server_version import ServerVersion
17+ from .arrow_endpoint_version import ArrowEndpointVersion
1718from .arrow_graph_constructor import ArrowGraphConstructor
18- from .arrow_version import ArrowVersion , UnsupportedArrowVersion
19+ from .arrow_endpoint_version import ArrowEndpointVersion , UnsupportedArrowEndpointVersion
1920from .graph_constructor import GraphConstructor
2021from .query_runner import QueryRunner
2122from graphdatascience .server_version .compatible_with import (
@@ -35,7 +36,7 @@ def create(
3536 arrow_info = ArrowQueryRunner ._get_arrow_debug_info (fallback_query_runner )
3637 server_version = fallback_query_runner .server_version ()
3738 listen_address : str = arrow_info .get ("advertisedListenAddress" , arrow_info ["listenAddress" ]) # type: ignore
38- arrow_version = ArrowQueryRunner ._read_arrow_version (arrow_info )
39+ arrow_endpoint_version = ArrowQueryRunner ._read_arrow_version (arrow_info )
3940
4041 if arrow_info ["running" ]:
4142 return ArrowQueryRunner (
@@ -46,7 +47,7 @@ def create(
4647 encrypted ,
4748 disable_server_verification ,
4849 tls_root_certs ,
49- arrow_version ,
50+ arrow_endpoint_version ,
5051 )
5152 else :
5253 return fallback_query_runner
@@ -56,13 +57,13 @@ def _get_arrow_debug_info(query_runner) -> Series[Any]:
5657 return query_runner .call_procedure (endpoint = "gds.debug.arrow" , custom_error = False ).squeeze ()
5758
5859 @staticmethod
59- def _read_arrow_version (arrow_info : Series [Any ]) -> ArrowVersion :
60+ def _read_arrow_version (arrow_info : Series [Any ]) -> ArrowEndpointVersion :
6061 arrow_version = arrow_info .get ("version" , "alpha" )
6162 try :
62- arrow_version = ArrowVersion [arrow_version .upper ()]
63+ arrow_version = ArrowEndpointVersion [arrow_version .upper ()]
6364 return arrow_version
6465 except KeyError :
65- raise UnsupportedArrowVersion (arrow_version )
66+ raise UnsupportedArrowEndpointVersion (arrow_version )
6667
6768 def __init__ (
6869 self ,
@@ -73,11 +74,11 @@ def __init__(
7374 encrypted : bool = False ,
7475 disable_server_verification : bool = False ,
7576 tls_root_certs : Optional [bytes ] = None ,
76- arrow_version : ArrowVersion = ArrowVersion .ALPHA ,
77+ arrow_endpoint_version : ArrowEndpointVersion = ArrowEndpointVersion .ALPHA ,
7778 ):
7879 self ._fallback_query_runner = fallback_query_runner
7980 self ._server_version = server_version
80- self ._arrow_version = arrow_version
81+ self ._arrow_endpoint_version = arrow_endpoint_version
8182
8283 host , port_string = uri .split (":" )
8384
@@ -286,10 +287,10 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
286287 "configuration" : configuration ,
287288 }
288289
289- if self ._arrow_version == ArrowVersion .V1 :
290+ if self ._arrow_endpoint_version == ArrowEndpointVersion .V1 :
290291 payload = {
291292 "name" : "GET_MESSAGE" ,
292- "version" : ArrowVersion .V1 .name (),
293+ "version" : ArrowEndpointVersion .V1 .name (),
293294 "body" : payload ,
294295 }
295296
@@ -322,7 +323,7 @@ def create_graph_constructor(
322323 )
323324
324325 return ArrowGraphConstructor (
325- database , graph_name , self ._flight_client , concurrency , self ._arrow_version , undirected_relationship_types
326+ database , graph_name , self ._flight_client , concurrency , self ._arrow_endpoint_version , undirected_relationship_types
326327 )
327328
328329 def _sanitize_arrow_table (self , arrow_table : Table ) -> Table :
0 commit comments