Skip to content

Commit 0c5ad6e

Browse files
Merge pull request #613 from DarthMax/push_based_remote_projection
JobId-based remote writeback
2 parents 5a57f14 + 277a875 commit 0c5ad6e

File tree

7 files changed

+213
-78
lines changed

7 files changed

+213
-78
lines changed

graphdatascience/graph/graph_remote_project_runner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,24 @@ def __call__(
2020
concurrency: int = 4,
2121
undirected_relationship_types: Optional[List[str]] = None,
2222
inverse_indexed_relationship_types: Optional[List[str]] = None,
23+
batch_size: Optional[int] = None,
2324
) -> GraphCreateResult:
2425
if inverse_indexed_relationship_types is None:
2526
inverse_indexed_relationship_types = []
2627
if undirected_relationship_types is None:
2728
undirected_relationship_types = []
2829

30+
arrow_configuration = {}
31+
if batch_size is not None:
32+
arrow_configuration["batchSize"] = batch_size
33+
2934
params = CallParameters(
3035
graph_name=graph_name,
3136
query=query,
3237
concurrency=concurrency,
3338
undirected_relationship_types=undirected_relationship_types,
3439
inverse_indexed_relationship_types=inverse_indexed_relationship_types,
40+
arrow_configuration=arrow_configuration,
3541
)
3642

3743
result = self._query_runner.call_procedure(

graphdatascience/query_runner/aura_db_query_runner.py

Lines changed: 19 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import time
22
from typing import Any, Dict, List, Optional
3+
from uuid import uuid4
34

45
from pandas import DataFrame
56

@@ -104,7 +105,7 @@ def _remote_projection(
104105
database: Optional[str] = None,
105106
logging: bool = False,
106107
) -> DataFrame:
107-
self._inject_connection_parameters(params)
108+
self._inject_arrow_config(params["arrow_configuration"])
108109
return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, False)
109110

110111
def _remote_write_back(
@@ -119,21 +120,29 @@ def _remote_write_back(
119120
if params["config"] is None:
120121
params["config"] = {}
121122

123+
# we pop these out so that they are not retained for the GDS proc call
124+
db_arrow_config = params["config"].pop("arrowConfiguration", {}) # type: ignore
125+
self._inject_arrow_config(db_arrow_config)
126+
127+
job_id = params["config"]["jobId"] if "jobId" in params["config"] else str(uuid4()) # type: ignore
128+
params["config"]["jobId"] = job_id # type: ignore
129+
122130
params["config"]["writeToResultStore"] = True # type: ignore
131+
123132
gds_write_result = self._gds_query_runner.call_procedure(
124133
endpoint, params, yields, database, logging, custom_error
125134
)
126135

127-
write_params = {
136+
db_write_proc_params = {
128137
"graphName": params["graph_name"],
129138
"databaseName": self._gds_query_runner.database(),
130-
"writeConfiguration": self._extract_write_back_arguments(endpoint, params),
139+
"jobId": job_id,
140+
"arrowConfiguration": db_arrow_config,
131141
}
132-
self._inject_connection_parameters(write_params)
133142

134143
write_back_start = time.time()
135144
database_write_result = self._db_query_runner.call_procedure(
136-
"gds.arrow.write", CallParameters(write_params), yields, None, False, False
145+
"gds.arrow.write", CallParameters(db_write_proc_params), yields, None, False, False
137146
)
138147
write_millis = (time.time() - write_back_start) * 1000
139148
gds_write_result["writeMillis"] = write_millis
@@ -149,75 +158,13 @@ def _remote_write_back(
149158

150159
return gds_write_result
151160

152-
def _inject_connection_parameters(self, params: Dict[str, Any]) -> None:
161+
def _inject_arrow_config(self, params: Dict[str, Any]) -> None:
153162
host, port = self._gds_arrow_client.connection_info()
154163
token = self._gds_arrow_client.request_token()
155164
if token is None:
156165
token = "IGNORED"
157-
params["arrowConfiguration"] = {
158-
"host": host,
159-
"port": port,
160-
"token": token,
161-
"encrypted": self._encrypted,
162-
}
163166

164-
@staticmethod
165-
def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
166-
config = params.get("config", {})
167-
write_config = {}
168-
169-
if "writeConcurrency" in config:
170-
write_config["concurrency"] = config["writeConcurrency"]
171-
elif "concurrency" in config:
172-
write_config["concurrency"] = config["concurrency"]
173-
174-
if "gds.shortestPath" in proc_name or "gds.allShortestPaths" in proc_name:
175-
write_config["relationshipType"] = config["writeRelationshipType"]
176-
177-
write_node_ids = config.get("writeNodeIds")
178-
write_costs = config.get("writeCosts")
179-
180-
if write_node_ids and write_costs:
181-
write_config["relationshipProperties"] = ["totalCost", "nodeIds", "costs"]
182-
elif write_node_ids:
183-
write_config["relationshipProperties"] = ["totalCost", "nodeIds"]
184-
elif write_costs:
185-
write_config["relationshipProperties"] = ["totalCost", "costs"]
186-
else:
187-
write_config["relationshipProperties"] = ["totalCost"]
188-
189-
elif "gds.graph." in proc_name:
190-
if "gds.graph.nodeProperties.write" == proc_name:
191-
properties = params["properties"]
192-
write_config["nodeProperties"] = properties if isinstance(properties, list) else [properties]
193-
write_config["nodeLabels"] = params["entities"]
194-
195-
elif "gds.graph.nodeLabel.write" == proc_name:
196-
write_config["nodeLabels"] = [params["node_label"]]
197-
198-
elif "gds.graph.relationshipProperties.write" == proc_name:
199-
write_config["relationshipProperties"] = params["relationship_properties"]
200-
write_config["relationshipType"] = params["relationship_type"]
201-
202-
elif "gds.graph.relationship.write" == proc_name:
203-
if "relationship_property" in params and params["relationship_property"] != "":
204-
write_config["relationshipProperties"] = [params["relationship_property"]]
205-
write_config["relationshipType"] = params["relationship_type"]
206-
207-
else:
208-
raise ValueError(f"Unsupported procedure name: {proc_name}")
209-
210-
else:
211-
if "writeRelationshipType" in config:
212-
write_config["relationshipType"] = config["writeRelationshipType"]
213-
if "writeProperty" in config:
214-
write_config["relationshipProperties"] = [config["writeProperty"]]
215-
else:
216-
if "writeProperty" in config:
217-
write_config["nodeProperties"] = [config["writeProperty"]]
218-
if "nodeLabels" in params:
219-
write_config["nodeLabels"] = params["nodeLabels"]
220-
else:
221-
write_config["nodeLabels"] = ["*"]
222-
223-
return write_config
167+
params["host"] = host
168+
params["port"] = port
169+
params["token"] = token
170+
params["encrypted"] = self._encrypted

graphdatascience/tests/integration/test_remote_graph_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None:
4343
assert result["nodeCount"] == 3
4444

4545

46+
@pytest.mark.cloud_architecture
47+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
48+
def test_remote_projection_with_small_batch_size(gds_with_cloud_setup: AuraGraphDataScience) -> None:
49+
G, result = gds_with_cloud_setup.graph.project(
50+
GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)", batch_size=10
51+
)
52+
53+
assert G.name() == GRAPH_NAME
54+
assert result["nodeCount"] == 3
55+
56+
4657
@pytest.mark.cloud_architecture
4758
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
4859
def test_remote_write_back_page_rank(gds_with_cloud_setup: AuraGraphDataScience) -> None:

graphdatascience/tests/unit/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,13 @@ def encrypted(self) -> bool:
6868
return False
6969

7070
def last_query(self) -> str:
71+
if len(self.queries) == 0:
72+
return ""
7173
return self.queries[-1]
7274

7375
def last_params(self) -> Dict[str, Any]:
76+
if len(self.params) == 0:
77+
return {}
7478
return self.params[-1]
7579

7680
def set_database(self, database: str) -> None:

graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py

Whitespace-only changes.
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from typing import Tuple
2+
3+
from pandas import DataFrame
4+
5+
from graphdatascience import ServerVersion
6+
from graphdatascience.call_parameters import CallParameters
7+
from graphdatascience.query_runner.aura_db_query_runner import AuraDbQueryRunner
8+
from graphdatascience.tests.unit.conftest import CollectingQueryRunner
9+
10+
11+
class FakeArrowClient:
12+
13+
def connection_info(self) -> Tuple[str, str]:
14+
return "myHost", "1234"
15+
16+
def request_token(self) -> str:
17+
return "myToken"
18+
19+
20+
def test_extracts_parameters_projection() -> None:
21+
version = ServerVersion(2, 7, 0)
22+
db_query_runner = CollectingQueryRunner(version)
23+
gds_query_runner = CollectingQueryRunner(version)
24+
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
25+
qr = AuraDbQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore
26+
27+
qr.call_procedure(
28+
endpoint="gds.arrow.project",
29+
params=CallParameters(
30+
graph_name="g",
31+
query="RETURN 1",
32+
concurrency=2,
33+
undirRels=[],
34+
inverseRels=[],
35+
arrow_configuration={"batchSize": 100},
36+
),
37+
)
38+
39+
# doesn't run anything on GDS
40+
assert gds_query_runner.last_query() == ""
41+
assert gds_query_runner.last_params() == {}
42+
assert (
43+
db_query_runner.last_query()
44+
== "CALL gds.arrow.project($graph_name, $query, $concurrency, $undirRels, $inverseRels, $arrow_configuration)"
45+
)
46+
assert db_query_runner.last_params() == {
47+
"graph_name": "g",
48+
"query": "RETURN 1",
49+
"concurrency": 2,
50+
"undirRels": [],
51+
"inverseRels": [],
52+
"arrow_configuration": {
53+
"encrypted": False,
54+
"host": "myHost",
55+
"port": "1234",
56+
"token": "myToken",
57+
"batchSize": 100,
58+
},
59+
}
60+
61+
62+
def test_extracts_parameters_algo_write() -> None:
63+
version = ServerVersion(2, 7, 0)
64+
db_query_runner = CollectingQueryRunner(version)
65+
gds_query_runner = CollectingQueryRunner(version)
66+
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
67+
qr = AuraDbQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore
68+
69+
qr.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={"jobId": "my-job"}))
70+
71+
assert gds_query_runner.last_query() == "CALL gds.degree.write($graph_name, $config)"
72+
assert gds_query_runner.last_params() == {
73+
"graph_name": "g",
74+
"config": {"jobId": "my-job", "writeToResultStore": True},
75+
}
76+
assert (
77+
db_query_runner.last_query() == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)"
78+
)
79+
assert db_query_runner.last_params() == {
80+
"graphName": "g",
81+
"databaseName": "dummy",
82+
"jobId": "my-job",
83+
"arrowConfiguration": {"encrypted": False, "host": "myHost", "port": "1234", "token": "myToken"},
84+
}
85+
86+
87+
def test_arrow_and_write_configuration() -> None:
88+
version = ServerVersion(2, 7, 0)
89+
db_query_runner = CollectingQueryRunner(version)
90+
gds_query_runner = CollectingQueryRunner(version)
91+
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
92+
qr = AuraDbQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore
93+
94+
qr.call_procedure(
95+
endpoint="gds.degree.write",
96+
params=CallParameters(
97+
graph_name="g",
98+
config={"arrowConfiguration": {"batchSize": 1000}, "jobId": "my-job"},
99+
),
100+
)
101+
102+
assert gds_query_runner.last_query() == "CALL gds.degree.write($graph_name, $config)"
103+
assert gds_query_runner.last_params() == {
104+
"graph_name": "g",
105+
"config": {"writeToResultStore": True, "jobId": "my-job"},
106+
}
107+
assert (
108+
db_query_runner.last_query() == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)"
109+
)
110+
assert db_query_runner.last_params() == {
111+
"graphName": "g",
112+
"databaseName": "dummy",
113+
"jobId": "my-job",
114+
"arrowConfiguration": {
115+
"encrypted": False,
116+
"host": "myHost",
117+
"port": "1234",
118+
"token": "myToken",
119+
"batchSize": 1000,
120+
},
121+
}
122+
123+
124+
def test_arrow_and_write_configuration_graph_write() -> None:
125+
version = ServerVersion(2, 7, 0)
126+
db_query_runner = CollectingQueryRunner(version)
127+
gds_query_runner = CollectingQueryRunner(version)
128+
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
129+
qr = AuraDbQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore
130+
131+
qr.call_procedure(
132+
endpoint="gds.graph.nodeProperties.write",
133+
params=CallParameters(
134+
graph_name="g",
135+
properties=[],
136+
entities=[],
137+
config={"arrowConfiguration": {"batchSize": 42}, "jobId": "my-job"},
138+
),
139+
)
140+
141+
assert (
142+
gds_query_runner.last_query()
143+
== "CALL gds.graph.nodeProperties.write($graph_name, $properties, $entities, $config)"
144+
)
145+
assert gds_query_runner.last_params() == {
146+
"graph_name": "g",
147+
"entities": [],
148+
"properties": [],
149+
"config": {"writeToResultStore": True, "jobId": "my-job"},
150+
}
151+
assert (
152+
db_query_runner.last_query() == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)"
153+
)
154+
assert db_query_runner.last_params() == {
155+
"graphName": "g",
156+
"databaseName": "dummy",
157+
"jobId": "my-job",
158+
"arrowConfiguration": {
159+
"encrypted": False,
160+
"host": "myHost",
161+
"port": "1234",
162+
"token": "myToken",
163+
"batchSize": 42,
164+
},
165+
}

graphdatascience/tests/unit/test_graph_ops.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataSc
9595
aura_gds.graph.project("g", "RETURN gds.graph.project.remote(0, 1, null)")
9696

9797
assert (
98-
runner.last_query()
99-
== "CALL gds.arrow.project("
100-
+ "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)"
98+
runner.last_query() == "CALL gds.arrow.project("
99+
"$graph_name, $query, $concurrency, "
100+
"$undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)"
101101
)
102102
# injection of token and host into the params is done by the actual query runner
103103
assert runner.last_params() == {
@@ -106,6 +106,7 @@ def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataSc
106106
"inverse_indexed_relationship_types": [],
107107
"query": "RETURN gds.graph.project.remote(0, 1, null)",
108108
"undirected_relationship_types": [],
109+
"arrow_configuration": {},
109110
}
110111

111112

@@ -720,9 +721,9 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura
720721
)
721722

722723
assert (
723-
runner.last_query()
724-
== "CALL gds.arrow.project("
725-
+ "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)"
724+
runner.last_query() == "CALL gds.arrow.project("
725+
"$graph_name, $query, $concurrency, "
726+
"$undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)"
726727
)
727728

728729
assert runner.last_params() == {
@@ -738,4 +739,5 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura
738739
""",
739740
"undirected_relationship_types": ["R"],
740741
"inverse_indexed_relationship_types": ["R"],
742+
"arrow_configuration": {},
741743
}

0 commit comments

Comments
 (0)