1+ from __future__ import annotations
2+
13import base64
24import json
35import time
1416from ..server_version .server_version import ServerVersion
1517from .arrow_graph_constructor import ArrowGraphConstructor
1618from .graph_constructor import GraphConstructor
19+ from .arrow_version import ArrowVersion , UnsupportedArrowVersion
1720from .query_runner import QueryRunner
1821from graphdatascience .server_version .compatible_with import (
1922 IncompatibleServerVersionError ,
@@ -32,14 +35,16 @@ def create(
3235 server_version = fallback_query_runner .server_version ()
3336
3437 yield_fields = (
35- ["running" , "listenAddress" ]
38+ ["running" , "listenAddress" , "version" ]
3639 if server_version >= ServerVersion (2 , 2 , 1 )
37- else ["running" , "advertisedListenAddress" ]
40+ else ["running" , "advertisedListenAddress" , "version" ]
3841 )
39- arrow_info : " Series[Any]" = fallback_query_runner .call_procedure (
42+ arrow_info : Series [Any ] = fallback_query_runner .call_procedure (
4043 endpoint = "gds.debug.arrow" , yields = yield_fields , custom_error = False
4144 ).squeeze ()
4245 listen_address : str = arrow_info .get ("advertisedListenAddress" , arrow_info ["listenAddress" ]) # type: ignore
46+ arrow_version = ArrowQueryRunner .read_arrow_version (arrow_info )
47+
4348 if arrow_info ["running" ]:
4449 return ArrowQueryRunner (
4550 listen_address ,
@@ -49,10 +54,20 @@ def create(
4954 encrypted ,
5055 disable_server_verification ,
5156 tls_root_certs ,
57+ arrow_version ,
5258 )
5359 else :
5460 return fallback_query_runner
5561
62+ @staticmethod
63+ def read_arrow_version (arrow_info : Series [Any ]) -> ArrowVersion :
64+ arrow_version = arrow_info .get ("version" , "alpha" )
65+ try :
66+ arrow_version = ArrowVersion [arrow_version .upper ()]
67+ return arrow_version
68+ except KeyError :
69+ raise UnsupportedArrowVersion (arrow_version )
70+
5671 def __init__ (
5772 self ,
5873 uri : str ,
@@ -62,9 +77,11 @@ def __init__(
6277 encrypted : bool = False ,
6378 disable_server_verification : bool = False ,
6479 tls_root_certs : Optional [bytes ] = None ,
80+ arrow_version : ArrowVersion = ArrowVersion .ALPHA ,
6581 ):
6682 self ._fallback_query_runner = fallback_query_runner
6783 self ._server_version = server_version
84+ self ._arrow_version = arrow_version
6885
6986 host , port_string = uri .split (":" )
7087
0 commit comments