1212from tqdm .auto import tqdm
1313
1414from .graph_constructor import GraphConstructor
15+ from .arrow_version import ArrowVersion
1516
1617
1718class ArrowGraphConstructor (GraphConstructor ):
@@ -21,13 +22,15 @@ def __init__(
2122 graph_name : str ,
2223 flight_client : flight .FlightClient ,
2324 concurrency : int ,
25+ arrow_version : ArrowVersion ,
2426 undirected_relationship_types : Optional [List [str ]],
2527 chunk_size : int = 10_000 ,
2628 ):
2729 self ._database = database
2830 self ._concurrency = concurrency
2931 self ._graph_name = graph_name
3032 self ._client = flight_client
33+ self ._arrow_version = arrow_version
3134 self ._undirected_relationship_types = (
3235 [] if undirected_relationship_types is None else undirected_relationship_types
3336 )
@@ -81,6 +84,7 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:
8184 return partitioned_dfs
8285
8386 def _send_action (self , action_type : str , meta_data : Dict [str , Any ]) -> None :
87+ action_type = self ._versioned_action_type (action_type )
8488 result = self ._client .do_action (flight .Action (action_type , json .dumps (meta_data ).encode ("utf-8" )))
8589
8690 # Consume result fully to sanity check and avoid cancelled streams
@@ -93,6 +97,7 @@ def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm) -> None:
9397 table = Table .from_pandas (df )
9498 batches = table .to_batches (self ._chunk_size )
9599 flight_descriptor = {"name" : self ._graph_name , "entity_type" : entity_type }
100+ flight_descriptor = self ._versioned_flight_desriptor (flight_descriptor )
96101
97102 # Write schema
98103 upload_descriptor = flight .FlightDescriptor .for_command (json .dumps (flight_descriptor ).encode ("utf-8" ))
@@ -117,3 +122,17 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
117122 if not future .exception ():
118123 continue
119124 raise future .exception () # type: ignore
125+
126+ def _versioned_action_type (self , action_type : str ) -> str :
127+ return self ._arrow_version .prefix () + action_type
128+
def _versioned_flight_desriptor(self, flight_descriptor: dict) -> dict:
    """Adapt a flight descriptor to the configured Arrow version.

    When the configured version equals ``ArrowVersion.ALPHA`` the descriptor
    is returned as-is (the same object, not a copy). For any other version it
    is nested under ``"body"`` of a ``"PUT_MESSAGE"`` envelope.

    Note that the envelope's ``"version"`` is always taken from
    ``ArrowVersion.V1``, not from ``self._arrow_version``.
    """
    # The alpha protocol takes the bare descriptor without an envelope.
    if self._arrow_version == ArrowVersion.ALPHA:
        return flight_descriptor

    # NOTE(review): relies on ArrowVersion defining name() as a callable
    # method; confirm against arrow_version.py.
    return {
        "name": "PUT_MESSAGE",
        "version": ArrowVersion.V1.name(),
        "body": flight_descriptor,
    }
0 commit comments