99from types import TracebackType
1010from typing import Any , Callable , Dict , Iterable , Optional , Type , Union
1111
12+ import pandas
1213import pyarrow
1314from neo4j .exceptions import ClientError
14- from pandas import DataFrame
1515from pyarrow import Array , ChunkedArray , DictionaryArray , RecordBatch , Table , chunked_array , flight
1616from pyarrow import __version__ as arrow_version
1717from pyarrow .flight import ClientMiddleware , ClientMiddlewareFactory
1818from pyarrow .types import is_dictionary
1919from tenacity import retry , retry_if_exception_type , stop_after_attempt , stop_after_delay , wait_exponential
2020
21+ from graphdatascience .server_version .server_version import ServerVersion
22+
2123from ..version import __version__
2224from .arrow_endpoint_version import ArrowEndpointVersion
2325from .arrow_info import ArrowInfo
@@ -154,7 +156,7 @@ def get_node_properties(
154156 node_labels : Optional [list [str ]] = None ,
155157 list_node_labels : bool = False ,
156158 concurrency : Optional [int ] = None ,
157- ) -> DataFrame :
159+ ) -> pandas . DataFrame :
158160 """
159161 Get node properties from the graph.
160162
@@ -194,7 +196,7 @@ def get_node_properties(
194196
195197 return self ._do_get (database , graph_name , proc , concurrency , config )
196198
197- def get_node_labels (self , graph_name : str , database : str , concurrency : Optional [int ] = None ) -> DataFrame :
199+ def get_node_labels (self , graph_name : str , database : str , concurrency : Optional [int ] = None ) -> pandas . DataFrame :
198200 """
199201 Get all nodes and their labels from the graph.
200202
@@ -216,7 +218,7 @@ def get_node_labels(self, graph_name: str, database: str, concurrency: Optional[
216218
217219 def get_relationships (
218220 self , graph_name : str , database : str , relationship_types : list [str ], concurrency : Optional [int ] = None
219- ) -> DataFrame :
221+ ) -> pandas . DataFrame :
220222 """
221223 Get relationships from the graph.
222224
@@ -251,7 +253,7 @@ def get_relationship_properties(
251253 relationship_properties : Union [str , list [str ]],
252254 relationship_types : list [str ],
253255 concurrency : Optional [int ] = None ,
254- ) -> DataFrame :
256+ ) -> pandas . DataFrame :
255257 """
256258 Get relationships and their properties from the graph.
257259
@@ -488,7 +490,7 @@ def abort(self, graph_name: str) -> None:
488490 def upload_nodes (
489491 self ,
490492 graph_name : str ,
491- node_data : Union [pyarrow .Table , Iterable [pyarrow .RecordBatch ], DataFrame ],
493+ node_data : Union [pyarrow .Table , Iterable [pyarrow .RecordBatch ], pandas . DataFrame ],
492494 batch_size : int = 10_000 ,
493495 progress_callback : Callable [[int ], None ] = lambda x : None ,
494496 ) -> None :
@@ -511,7 +513,7 @@ def upload_nodes(
511513 def upload_relationships (
512514 self ,
513515 graph_name : str ,
514- relationship_data : Union [pyarrow .Table , Iterable [pyarrow .RecordBatch ], DataFrame ],
516+ relationship_data : Union [pyarrow .Table , Iterable [pyarrow .RecordBatch ], pandas . DataFrame ],
515517 batch_size : int = 10_000 ,
516518 progress_callback : Callable [[int ], None ] = lambda x : None ,
517519 ) -> None :
@@ -534,7 +536,7 @@ def upload_relationships(
534536 def upload_triplets (
535537 self ,
536538 graph_name : str ,
537- triplet_data : Union [pyarrow .Table , Iterable [pyarrow .RecordBatch ], DataFrame ],
539+ triplet_data : Union [pyarrow .Table , Iterable [pyarrow .RecordBatch ], pandas . DataFrame ],
538540 batch_size : int = 10_000 ,
539541 progress_callback : Callable [[int ], None ] = lambda x : None ,
540542 ) -> None :
@@ -590,13 +592,13 @@ def _upload_data(
590592 self ,
591593 graph_name : str ,
592594 entity_type : str ,
593- data : Union [pyarrow .Table , list [pyarrow .RecordBatch ], DataFrame ],
595+ data : Union [pyarrow .Table , list [pyarrow .RecordBatch ], pandas . DataFrame ],
594596 batch_size : int ,
595597 progress_callback : Callable [[int ], None ],
596598 ) -> None :
597599 if isinstance (data , pyarrow .Table ):
598600 batches = data .to_batches (batch_size )
599- elif isinstance (data , DataFrame ):
601+ elif isinstance (data , pandas . DataFrame ):
600602 batches = pyarrow .Table .from_pandas (data ).to_batches (batch_size )
601603 else :
602604 batches = data
@@ -635,7 +637,7 @@ def _do_get(
635637 procedure_name : str ,
636638 concurrency : Optional [int ],
637639 configuration : dict [str , Any ],
638- ) -> DataFrame :
640+ ) -> pandas . DataFrame :
639641 payload : dict [str , Any ] = {
640642 "database_name" : database ,
641643 "graph_name" : graph_name ,
@@ -674,7 +676,12 @@ def _do_get(
674676 message = r"Passing a BlockManager to DataFrame is deprecated" ,
675677 )
676678
677- return self ._sanitize_arrow_table (arrow_table ).to_pandas () # type: ignore
679+ arrow_table = self ._sanitize_arrow_table (arrow_table )
680+
681+ if ServerVersion .from_string (pandas .__version__ ) >= ServerVersion (2 , 0 , 0 ):
682+ return arrow_table .to_pandas (types_mapper = pandas .ArrowDtype ) # type: ignore
683+ else :
684+ return arrow_table .to_pandas () # type: ignore
678685
# Context-manager entry point: returns the client instance itself so it can be
# used as `with GdsArrowClient(...) as client:`; cleanup is presumably handled
# by the matching __exit__ (not visible in this diff view — confirm it closes
# the underlying Flight connection).
679686 def __enter__ (self ) -> GdsArrowClient :
680687 return self
0 commit comments