44
55import pyarrow as pa
66import pytest
7- from pyarrow import flight
87from pyarrow ._flight import GeneratorStream
9- from pyarrow .flight import Action , Ticket
8+ from pyarrow .flight import (
9+ Action ,
10+ FlightServerBase ,
11+ FlightServerError ,
12+ FlightTimedOutError ,
13+ FlightUnavailableError ,
14+ Ticket ,
15+ )
1016
1117from graphdatascience .query_runner .gds_arrow_client import AuthMiddleware , GdsArrowClient
1218
1319ActionParam = Union [str , tuple [str , Any ], Action ]
1420
1521
16- class FlightServer (flight . FlightServerBase ): # type: ignore
22+ class FlightServer (FlightServerBase ): # type: ignore
1723 def __init__ (self , location : str = "grpc://0.0.0.0:0" , ** kwargs : dict [str , Any ]) -> None :
1824 super (FlightServer , self ).__init__ (location , ** kwargs )
1925 self ._location : str = location
@@ -49,18 +55,81 @@ def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
4955 return [json .dumps (response ).encode ("utf-8" )]
5056
5157
class FlakyFlightServer(FlightServerBase):  # type: ignore
    """Flight server test double that fails its first requests before succeeding.

    Every ``do_get`` / ``do_action`` call is recorded. While scripted failures
    remain, the call raises the next one instead of producing a response.
    Both handlers consume the same failure list.
    """

    def __init__(self, location: str = "grpc://0.0.0.0:0", **kwargs: dict[str, Any]) -> None:
        super(FlakyFlightServer, self).__init__(location, **kwargs)
        self._location: str = location
        # Every action / ticket received, including the ones answered with a failure.
        self._actions: list[ActionParam] = []
        self._tickets: list[Ticket] = []
        # Consumed from the end via ``pop()``, so the timeout error is raised first.
        self._expected_failures: list[Exception] = [
            FlightUnavailableError("Flight server is unavailable", "some reason"),
            FlightTimedOutError("Time out for some reason", "still timed out"),
        ]
        # One call per scripted failure plus the call that finally succeeds.
        self._expected_retries = len(self._expected_failures) + 1

    def expected_retries(self) -> int:
        """Return the number of calls needed until one succeeds (failures + 1)."""
        return self._expected_retries

    def _raise_pending_failure(self) -> None:
        """Raise the next scripted failure, if any are left."""
        if self._expected_failures:
            raise self._expected_failures.pop()

    def do_get(self, context: Any, ticket: Ticket) -> GeneratorStream:
        """Record ``ticket``; raise a scripted failure or stream a fixed ``ids`` table."""
        self._tickets.append(ticket)
        self._raise_pending_failure()

        table = pa.Table.from_pydict({"ids": [42, 1337, 1234]})
        return GeneratorStream(schema=table.schema, generator=table.to_batches())

    def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
        """Record ``action``; raise a scripted failure or return a canned JSON reply.

        Raises:
            TypeError: if ``action`` is not an ``Action``, a tuple or a ``str``.
        """
        self._actions.append(action)
        self._raise_pending_failure()

        if isinstance(action, Action):
            action_type = action.type
        elif isinstance(action, tuple):
            action_type = action[0]
        elif isinstance(action, str):
            action_type = action
        else:
            # Previously this fell through and failed later with UnboundLocalError.
            raise TypeError(f"Unsupported action type: {type(action).__name__}")

        # The first matching substring wins; unknown actions get an empty reply.
        response: dict[str, Any] = {}
        if "CREATE" in action_type:
            response = {"name": "g"}
        elif "NODE_LOAD_DONE" in action_type:
            response = {"name": "g", "node_count": 42}
        elif "RELATIONSHIP_LOAD_DONE" in action_type:
            response = {"name": "g", "relationship_count": 42}
        elif "TRIPLET_LOAD_DONE" in action_type:
            response = {"name": "g", "node_count": 42, "relationship_count": 1337}
        return [json.dumps(response).encode("utf-8")]
107+
108+
@pytest.fixture()
def flight_server() -> Generator[FlightServer, None, None]:
    """Yield a ``FlightServer`` for one test.

    The server is used as a context manager, so its exit handler runs once the
    test is finished with the fixture.
    """
    # Generator's parameters are (yield, send, return): the server is the yielded value.
    with FlightServer() as server:
        yield server
56113
57114
@pytest.fixture()
def flaky_flight_server() -> Generator[FlakyFlightServer, None, None]:
    """Yield a ``FlakyFlightServer`` for one test.

    The server is used as a context manager, so its exit handler runs once the
    test is finished with the fixture.
    """
    # Generator's parameters are (yield, send, return): the server is the yielded value.
    with FlakyFlightServer() as server:
        yield server
119+
120+
@pytest.fixture()
def flight_client(flight_server: FlightServer) -> Generator[GdsArrowClient, None, None]:
    """Yield a ``GdsArrowClient`` connected to the ``flight_server`` fixture on localhost."""
    server_port = flight_server.port
    with GdsArrowClient("localhost", server_port) as arrow_client:
        yield arrow_client
62125
63126
@pytest.fixture()
def flaky_flight_client(flaky_flight_server: FlakyFlightServer) -> Generator[GdsArrowClient, None, None]:
    """Yield a ``GdsArrowClient`` connected to the ``flaky_flight_server`` fixture on localhost."""
    server_port = flaky_flight_server.port
    with GdsArrowClient("localhost", server_port) as arrow_client:
        yield arrow_client
131+
132+
64133def test_create_graph_with_defaults (flight_server : FlightServer , flight_client : GdsArrowClient ) -> None :
65134 flight_client .create_graph ("g" , "DB" )
66135 actions = flight_server ._actions
@@ -87,6 +156,15 @@ def test_create_graph_with_options(flight_server: FlightServer, flight_client: G
87156 )
88157
89158
def test_create_graph_with_flaky_server(
    flaky_flight_server: FlakyFlightServer, flaky_flight_client: GdsArrowClient
) -> None:
    """create_graph keeps calling the flaky server until the action succeeds."""
    flaky_flight_client.create_graph("g", "DB")

    recorded_actions = flaky_flight_server._actions
    # The server records failed attempts too, so the count covers every call made.
    assert len(recorded_actions) == flaky_flight_server.expected_retries()

    first_action = recorded_actions[0]
    assert_action(first_action, "v1/CREATE_GRAPH", {"name": "g", "database_name": "DB"})
166+
167+
90168def test_create_graph_from_triplets_with_defaults (flight_server : FlightServer , flight_client : GdsArrowClient ) -> None :
91169 flight_client .create_graph_from_triplets ("g" , "DB" )
92170 actions = flight_server ._actions
@@ -202,6 +280,22 @@ def test_get_node_property(flight_server: FlightServer, flight_client: GdsArrowC
202280 )
203281
204282
def test_flakey_get_node_property(flaky_flight_server: FlakyFlightServer, flaky_flight_client: GdsArrowClient) -> None:
    """get_node_properties keeps calling the flaky server until the stream succeeds."""
    flaky_flight_client.get_node_properties("g", "db", "id", ["Person"], concurrency=42)

    recorded_tickets = flaky_flight_server._tickets
    # The server records failed attempts too, so the count covers every call made.
    assert len(recorded_tickets) == flaky_flight_server.expected_retries()

    expected_ticket = {
        "concurrency": 42,
        "configuration": {"list_node_labels": False, "node_labels": ["Person"], "node_property": "id"},
        "database_name": "db",
        "graph_name": "g",
        "procedure_name": "gds.graph.nodeProperty.stream",
    }
    assert_ticket(recorded_tickets[0], expected_ticket)
297+
298+
205299def test_get_node_properties (flight_server : FlightServer , flight_client : GdsArrowClient ) -> None :
206300 flight_client .get_node_properties ("g" , "db" , ["foo" , "bar" ], ["Person" ], list_node_labels = True , concurrency = 42 )
207301 tickets = flight_server ._tickets
@@ -314,21 +408,21 @@ def test_auth_middleware_bad_headers() -> None:
314408
def test_handle_flight_error() -> None:
    """handle_flight_error raises a FlightServerError matching the condensed message."""
    # Each case pairs the pattern expected on the raised error with the raw server message.
    cases = [
        (
            "FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.",
            'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. Detail: Internal',
        ),
        (
            re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"),
            "FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]",
        ),
    ]

    for expected_pattern, raw_message in cases:
        with pytest.raises(FlightServerError, match=expected_pattern):
            GdsArrowClient.handle_flight_error(FlightServerError(raw_message))
0 commit comments