Skip to content

Commit 0253f6e

Browse files
committed
Use arrow backend in pandas 2.0
1 parent 7ac0417 commit 0253f6e

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -9,15 +9,17 @@
99
from types import TracebackType
1010
from typing import Any, Callable, Dict, Iterable, Optional, Type, Union
1111

12+
import pandas
1213
import pyarrow
1314
from neo4j.exceptions import ClientError
14-
from pandas import DataFrame
1515
from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Table, chunked_array, flight
1616
from pyarrow import __version__ as arrow_version
1717
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
1818
from pyarrow.types import is_dictionary
1919
from tenacity import retry, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
2020

21+
from graphdatascience.server_version.server_version import ServerVersion
22+
2123
from ..version import __version__
2224
from .arrow_endpoint_version import ArrowEndpointVersion
2325
from .arrow_info import ArrowInfo
@@ -154,7 +156,7 @@ def get_node_properties(
154156
node_labels: Optional[list[str]] = None,
155157
list_node_labels: bool = False,
156158
concurrency: Optional[int] = None,
157-
) -> DataFrame:
159+
) -> pandas.DataFrame:
158160
"""
159161
Get node properties from the graph.
160162
@@ -194,7 +196,7 @@ def get_node_properties(
194196

195197
return self._do_get(database, graph_name, proc, concurrency, config)
196198

197-
def get_node_labels(self, graph_name: str, database: str, concurrency: Optional[int] = None) -> DataFrame:
199+
def get_node_labels(self, graph_name: str, database: str, concurrency: Optional[int] = None) -> pandas.DataFrame:
198200
"""
199201
Get all nodes and their labels from the graph.
200202
@@ -216,7 +218,7 @@ def get_node_labels(self, graph_name: str, database: str, concurrency: Optional[
216218

217219
def get_relationships(
218220
self, graph_name: str, database: str, relationship_types: list[str], concurrency: Optional[int] = None
219-
) -> DataFrame:
221+
) -> pandas.DataFrame:
220222
"""
221223
Get relationships from the graph.
222224
@@ -251,7 +253,7 @@ def get_relationship_properties(
251253
relationship_properties: Union[str, list[str]],
252254
relationship_types: list[str],
253255
concurrency: Optional[int] = None,
254-
) -> DataFrame:
256+
) -> pandas.DataFrame:
255257
"""
256258
Get relationships and their properties from the graph.
257259
@@ -488,7 +490,7 @@ def abort(self, graph_name: str) -> None:
488490
def upload_nodes(
489491
self,
490492
graph_name: str,
491-
node_data: Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], DataFrame],
493+
node_data: Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], pandas.DataFrame],
492494
batch_size: int = 10_000,
493495
progress_callback: Callable[[int], None] = lambda x: None,
494496
) -> None:
@@ -511,7 +513,7 @@ def upload_nodes(
511513
def upload_relationships(
512514
self,
513515
graph_name: str,
514-
relationship_data: Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], DataFrame],
516+
relationship_data: Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], pandas.DataFrame],
515517
batch_size: int = 10_000,
516518
progress_callback: Callable[[int], None] = lambda x: None,
517519
) -> None:
@@ -534,7 +536,7 @@ def upload_relationships(
534536
def upload_triplets(
535537
self,
536538
graph_name: str,
537-
triplet_data: Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], DataFrame],
539+
triplet_data: Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], pandas.DataFrame],
538540
batch_size: int = 10_000,
539541
progress_callback: Callable[[int], None] = lambda x: None,
540542
) -> None:
@@ -590,13 +592,13 @@ def _upload_data(
590592
self,
591593
graph_name: str,
592594
entity_type: str,
593-
data: Union[pyarrow.Table, list[pyarrow.RecordBatch], DataFrame],
595+
data: Union[pyarrow.Table, list[pyarrow.RecordBatch], pandas.DataFrame],
594596
batch_size: int,
595597
progress_callback: Callable[[int], None],
596598
) -> None:
597599
if isinstance(data, pyarrow.Table):
598600
batches = data.to_batches(batch_size)
599-
elif isinstance(data, DataFrame):
601+
elif isinstance(data, pandas.DataFrame):
600602
batches = pyarrow.Table.from_pandas(data).to_batches(batch_size)
601603
else:
602604
batches = data
@@ -635,7 +637,7 @@ def _do_get(
635637
procedure_name: str,
636638
concurrency: Optional[int],
637639
configuration: dict[str, Any],
638-
) -> DataFrame:
640+
) -> pandas.DataFrame:
639641
payload: dict[str, Any] = {
640642
"database_name": database,
641643
"graph_name": graph_name,
@@ -674,7 +676,12 @@ def _do_get(
674676
message=r"Passing a BlockManager to DataFrame is deprecated",
675677
)
676678

677-
return self._sanitize_arrow_table(arrow_table).to_pandas() # type: ignore
679+
arrow_table = self._sanitize_arrow_table(arrow_table)
680+
681+
if ServerVersion.from_string(pandas.__version__) >= ServerVersion(2, 0, 0):
682+
return arrow_table.to_pandas(types_mapper=pandas.ArrowDtype) # type: ignore
683+
else:
684+
return arrow_table.to_pandas() # type: ignore
678685

679686
def __enter__(self) -> GdsArrowClient:
680687
return self

0 commit comments

Comments (0)