11from __future__ import annotations
22
33import concurrent
4- import json
54import math
65import warnings
76from concurrent .futures import ThreadPoolExecutor
87from typing import Any , Dict , List , NoReturn , Optional
98
109import numpy
11- import pyarrow .flight as flight
1210from pandas import DataFrame
1311from pyarrow import Table
1412from tqdm .auto import tqdm
1513
16- from .arrow_endpoint_version import ArrowEndpointVersion
14+ from .gds_arrow_client import GdsArrowClient
1715from .graph_constructor import GraphConstructor
1816
1917
@@ -22,17 +20,15 @@ def __init__(
2220 self ,
2321 database : str ,
2422 graph_name : str ,
25- flight_client : flight . FlightClient ,
23+ flight_client : GdsArrowClient ,
2624 concurrency : int ,
27- arrow_endpoint_version : ArrowEndpointVersion ,
2825 undirected_relationship_types : Optional [List [str ]],
2926 chunk_size : int = 10_000 ,
3027 ):
3128 self ._database = database
3229 self ._concurrency = concurrency
3330 self ._graph_name = graph_name
3431 self ._client = flight_client
35- self ._arrow_endpoint_version = arrow_endpoint_version
3632 self ._undirected_relationship_types = (
3733 [] if undirected_relationship_types is None else undirected_relationship_types
3834 )
@@ -49,20 +45,20 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
4945 if self ._undirected_relationship_types :
5046 config ["undirected_relationship_types" ] = self ._undirected_relationship_types
5147
52- self ._send_action (
48+ self ._client . send_action (
5349 "CREATE_GRAPH" ,
5450 config ,
5551 )
5652
5753 self ._send_dfs (node_dfs , "node" )
5854
59- self ._send_action ("NODE_LOAD_DONE" , {"name" : self ._graph_name })
55+ self ._client . send_action ("NODE_LOAD_DONE" , {"name" : self ._graph_name })
6056
6157 self ._send_dfs (relationship_dfs , "relationship" )
6258
63- self ._send_action ("RELATIONSHIP_LOAD_DONE" , {"name" : self ._graph_name })
59+ self ._client . send_action ("RELATIONSHIP_LOAD_DONE" , {"name" : self ._graph_name })
6460 except (Exception , KeyboardInterrupt ) as e :
65- self ._send_action ("ABORT" , {"name" : self ._graph_name })
61+ self ._client . send_action ("ABORT" , {"name" : self ._graph_name })
6662
6763 raise e
6864
@@ -85,25 +81,12 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:
8581
8682 return partitioned_dfs
8783
88- def _send_action (self , action_type : str , meta_data : Dict [str , Any ]) -> None :
89- action_type = self ._versioned_action_type (action_type )
90- result = self ._client .do_action (flight .Action (action_type , json .dumps (meta_data ).encode ("utf-8" )))
91-
92- # Consume result fully to sanity check and avoid cancelled streams
93- collected_result = list (result )
94- assert len (collected_result ) == 1
95-
96- json .loads (collected_result [0 ].body .to_pybytes ().decode ())
97-
9884 def _send_df (self , df : DataFrame , entity_type : str , pbar : tqdm [NoReturn ]) -> None :
9985 table = Table .from_pandas (df )
10086 batches = table .to_batches (self ._chunk_size )
10187 flight_descriptor = {"name" : self ._graph_name , "entity_type" : entity_type }
102- flight_descriptor = self ._versioned_flight_desriptor (flight_descriptor )
10388
104- # Write schema
105- upload_descriptor = flight .FlightDescriptor .for_command (json .dumps (flight_descriptor ).encode ("utf-8" ))
106- writer , _ = self ._client .do_put (upload_descriptor , table .schema )
89+ writer , _ = self ._client .start_put (flight_descriptor , table .schema )
10790
10891 with writer :
10992 # Write table in chunks
@@ -126,17 +109,3 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
126109 if not future .exception ():
127110 continue
128111 raise future .exception () # type: ignore
129-
130- def _versioned_action_type (self , action_type : str ) -> str :
131- return self ._arrow_endpoint_version .prefix () + action_type
132-
133- def _versioned_flight_desriptor (self , flight_descriptor : Dict [str , Any ]) -> Dict [str , Any ]:
134- return (
135- flight_descriptor
136- if self ._arrow_endpoint_version == ArrowEndpointVersion .ALPHA
137- else {
138- "name" : "PUT_MESSAGE" ,
139- "version" : ArrowEndpointVersion .V1 .version (),
140- "body" : flight_descriptor ,
141- }
142- )
0 commit comments