Skip to content

Commit 12c4f7b

Browse files
committed
Fill all progress bar on completion
1 parent 0156d50 commit 12c4f7b

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
* Fixed an issue where source and target IDs of relationships in heterogeneous OGBL graphs were not parsed correctly.
2121
* Fixed an issue where configuration parameters such as `aggregation` were ignored by `gds.graph.toUndirected`.
2222
* Fixed an issue where the `database` given for the `GraphDataScience` construction was not used for metadata retrieval, causing an exception to be raised if the default "neo4j" database was missing.
23+
* Fixed an issue where progress bars would not complete.
2324

2425

2526
## Improvements

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import math
44
import warnings
55
from concurrent.futures import ThreadPoolExecutor
6-
from typing import Any, Dict, List, Optional
6+
from typing import Any, Dict, List, NoReturn, Optional
77

88
import numpy
99
import pyarrow.flight as flight
@@ -93,7 +93,7 @@ def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
9393

9494
json.loads(collected_result[0].body.to_pybytes().decode())
9595

96-
def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm) -> None:
96+
def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> None:
9797
table = Table.from_pandas(df)
9898
batches = table.to_batches(self._chunk_size)
9999
flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import warnings
77
from concurrent.futures import Future, ThreadPoolExecutor, wait
8-
from typing import Any, Dict, List, Optional, Tuple, Union
8+
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
99
from uuid import uuid4
1010

1111
import neo4j
@@ -226,7 +226,7 @@ def _forward_cypher_warnings(self, notification: Dict[str, Any]) -> None:
226226
self._logger.info(notification)
227227

228228
def _log(self, job_id: str, future: "Future[Any]", database: Optional[str] = None) -> None:
229-
pbar = None
229+
pbars: Dict[str, tqdm[NoReturn]] = {}
230230
warn_if_failure = True
231231

232232
while wait([future], timeout=self._LOG_POLLING_INTERVAL).not_done:
@@ -248,16 +248,18 @@ def _log(self, job_id: str, future: "Future[Any]", database: Optional[str] = Non
248248
continue
249249

250250
progress_percent = progress["progress"][0]
251-
if not progress_percent == "n/a":
252-
task_name = progress["taskName"][0].split("|--")[-1][1:]
253-
pbar = pbar or tqdm(total=100, unit="%", desc=task_name)
254-
else:
251+
if progress_percent == "n/a":
255252
return
256253

254+
task_name = progress["taskName"][0].split("|--")[-1][1:]
255+
if task_name not in pbars:
256+
pbars[task_name] = tqdm(total=100, unit="%", desc=task_name)
257+
pbar = pbars[task_name]
258+
257259
parsed_progress = float(progress_percent[:-1])
258260
pbar.update(parsed_progress - pbar.n)
259261

260-
if pbar:
262+
for pbar in pbars.values():
261263
pbar.update(100 - pbar.n)
262264

263265
def set_database(self, database: str) -> None:

requirements/dev/dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ tox == 4.11.3
1010
types-setuptools == 68.1.0.1
1111
sphinx == 7.2.6
1212
types-requests
13+
types-tqdm

0 commit comments

Comments
 (0)