Skip to content

Commit 25517b1

Browse files
committed
Send versioned actions and put messages
1 parent 7eb3a5f commit 25517b1

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from tqdm.auto import tqdm
1313

1414
from .graph_constructor import GraphConstructor
15+
from .arrow_version import ArrowVersion
1516

1617

1718
class ArrowGraphConstructor(GraphConstructor):
@@ -21,13 +22,15 @@ def __init__(
2122
graph_name: str,
2223
flight_client: flight.FlightClient,
2324
concurrency: int,
25+
arrow_version: ArrowVersion,
2426
undirected_relationship_types: Optional[List[str]],
2527
chunk_size: int = 10_000,
2628
):
2729
self._database = database
2830
self._concurrency = concurrency
2931
self._graph_name = graph_name
3032
self._client = flight_client
33+
self._arrow_version = arrow_version
3134
self._undirected_relationship_types = (
3235
[] if undirected_relationship_types is None else undirected_relationship_types
3336
)
@@ -81,6 +84,7 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:
8184
return partitioned_dfs
8285

8386
def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
87+
action_type = self._versioned_action_type(action_type)
8488
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
8589

8690
# Consume result fully to sanity check and avoid cancelled streams
@@ -93,6 +97,7 @@ def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm) -> None:
9397
table = Table.from_pandas(df)
9498
batches = table.to_batches(self._chunk_size)
9599
flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
100+
flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
96101

97102
# Write schema
98103
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
@@ -117,3 +122,17 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
117122
if not future.exception():
118123
continue
119124
raise future.exception() # type: ignore
125+
126+
def _versioned_action_type(self, action_type: str) -> str:
127+
return self._arrow_version.prefix() + action_type
128+
129+
def _versioned_flight_desriptor(self, flight_descriptor: dict) -> dict:
130+
return (
131+
flight_descriptor
132+
if self._arrow_version == ArrowVersion.ALPHA
133+
else {
134+
"name": "PUT_MESSAGE",
135+
"version": ArrowVersion.V1.name(),
136+
"body": flight_descriptor,
137+
}
138+
)

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def create_graph_constructor(
326326
)
327327

328328
return ArrowGraphConstructor(
329-
database, graph_name, self._flight_client, concurrency, undirected_relationship_types
329+
database, graph_name, self._flight_client, concurrency, self._arrow_version, undirected_relationship_types
330330
)
331331

332332
def _sanitize_arrow_table(self, arrow_table: Table) -> Table:

0 commit comments

Comments
 (0)