Skip to content

Commit 85550a9

Browse files
committed
Allow specifying arrow and writing configuration through proc call
signatures this allows for example: gds.degree.write(G, arrowConfiguration={batchSize: 1}) gds.graph.nodeProperties.write(G, "prop", [], arrowConfiguration={batchSize: 1})
1 parent 5a57f14 commit 85550a9

File tree

3 files changed

+204
-28
lines changed

3 files changed

+204
-28
lines changed

graphdatascience/query_runner/aura_db_query_runner.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _remote_projection(
104104
database: Optional[str] = None,
105105
logging: bool = False,
106106
) -> DataFrame:
107-
self._inject_connection_parameters(params)
107+
self._inject_arrow_config(params["arrow_configuration"])
108108
return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, False)
109109

110110
def _remote_write_back(
@@ -119,21 +119,27 @@ def _remote_write_back(
119119
if params["config"] is None:
120120
params["config"] = {}
121121

122+
# we pop these out so that they are not retained for the GDS proc call
123+
db_write_config = params["config"].pop("writeConfiguration", {}) # type: ignore
124+
db_arrow_config = params["config"].pop("arrowConfiguration", {}) # type: ignore
125+
self._inject_write_config(endpoint, params, db_write_config)
126+
self._inject_arrow_config(db_arrow_config)
127+
122128
params["config"]["writeToResultStore"] = True # type: ignore
123129
gds_write_result = self._gds_query_runner.call_procedure(
124130
endpoint, params, yields, database, logging, custom_error
125131
)
126132

127-
write_params = {
133+
db_write_proc_params = {
128134
"graphName": params["graph_name"],
129135
"databaseName": self._gds_query_runner.database(),
130-
"writeConfiguration": self._extract_write_back_arguments(endpoint, params),
136+
"writeConfiguration": db_write_config,
137+
"arrowConfiguration": db_arrow_config,
131138
}
132-
self._inject_connection_parameters(write_params)
133139

134140
write_back_start = time.time()
135141
database_write_result = self._db_query_runner.call_procedure(
136-
"gds.arrow.write", CallParameters(write_params), yields, None, False, False
142+
"gds.arrow.write", CallParameters(db_write_proc_params), yields, None, False, False
137143
)
138144
write_millis = (time.time() - write_back_start) * 1000
139145
gds_write_result["writeMillis"] = write_millis
@@ -149,22 +155,20 @@ def _remote_write_back(
149155

150156
return gds_write_result
151157

152-
def _inject_connection_parameters(self, params: Dict[str, Any]) -> None:
158+
def _inject_arrow_config(self, params: Dict[str, Any]) -> None:
153159
host, port = self._gds_arrow_client.connection_info()
154160
token = self._gds_arrow_client.request_token()
155161
if token is None:
156162
token = "IGNORED"
157-
params["arrowConfiguration"] = {
158-
"host": host,
159-
"port": port,
160-
"token": token,
161-
"encrypted": self._encrypted,
162-
}
163+
164+
params["host"] = host
165+
params["port"] = port
166+
params["token"] = token
167+
params["encrypted"] = self._encrypted
163168

164169
@staticmethod
165-
def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
166-
config = params.get("config", {})
167-
write_config = {}
170+
def _inject_write_config(proc_name: str, proc_params: Dict[str, Any], write_config: Dict[str, Any]) -> None:
171+
config = proc_params.get("config", {})
168172

169173
if "writeConcurrency" in config:
170174
write_config["concurrency"] = config["writeConcurrency"]
@@ -188,21 +192,21 @@ def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dic
188192

189193
elif "gds.graph." in proc_name:
190194
if "gds.graph.nodeProperties.write" == proc_name:
191-
properties = params["properties"]
195+
properties = proc_params["properties"]
192196
write_config["nodeProperties"] = properties if isinstance(properties, list) else [properties]
193-
write_config["nodeLabels"] = params["entities"]
197+
write_config["nodeLabels"] = proc_params["entities"]
194198

195199
elif "gds.graph.nodeLabel.write" == proc_name:
196-
write_config["nodeLabels"] = [params["node_label"]]
200+
write_config["nodeLabels"] = [proc_params["node_label"]]
197201

198202
elif "gds.graph.relationshipProperties.write" == proc_name:
199-
write_config["relationshipProperties"] = params["relationship_properties"]
200-
write_config["relationshipType"] = params["relationship_type"]
203+
write_config["relationshipProperties"] = proc_params["relationship_properties"]
204+
write_config["relationshipType"] = proc_params["relationship_type"]
201205

202206
elif "gds.graph.relationship.write" == proc_name:
203-
if "relationship_property" in params and params["relationship_property"] != "":
204-
write_config["relationshipProperties"] = [params["relationship_property"]]
205-
write_config["relationshipType"] = params["relationship_type"]
207+
if "relationship_property" in proc_params and proc_params["relationship_property"] != "":
208+
write_config["relationshipProperties"] = [proc_params["relationship_property"]]
209+
write_config["relationshipType"] = proc_params["relationship_type"]
206210

207211
else:
208212
raise ValueError(f"Unsupported procedure name: {proc_name}")
@@ -215,9 +219,7 @@ def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dic
215219
else:
216220
if "writeProperty" in config:
217221
write_config["nodeProperties"] = [config["writeProperty"]]
218-
if "nodeLabels" in params:
219-
write_config["nodeLabels"] = params["nodeLabels"]
222+
if "nodeLabels" in proc_params:
223+
write_config["nodeLabels"] = proc_params["nodeLabels"]
220224
else:
221225
write_config["nodeLabels"] = ["*"]
222-
223-
return write_config

graphdatascience/tests/unit/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,13 @@ def encrypted(self) -> bool:
6868
return False
6969

7070
def last_query(self) -> str:
71+
if len(self.queries) == 0:
72+
return ""
7173
return self.queries[-1]
7274

73-
def last_params(self) -> Dict[str, Any]:
75+
def last_params(self) -> dict[str, Any]:
76+
if len(self.queries) == 0:
77+
return {}
7478
return self.params[-1]
7579

7680
def set_database(self, database: str) -> None:
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from typing import Tuple
2+
3+
from pandas import DataFrame
4+
5+
from graphdatascience import ServerVersion
6+
from graphdatascience.call_parameters import CallParameters
7+
from graphdatascience.query_runner.aura_db_arrow_query_runner import (
8+
AuraDbArrowQueryRunner,
9+
)
10+
from graphdatascience.tests.unit.conftest import CollectingQueryRunner
11+
12+
13+
class FakeArrowClient:
14+
15+
def connection_info(self) -> Tuple[str, str]:
16+
return "myHost", "1234"
17+
18+
def request_token(self) -> str:
19+
return "myToken"
20+
21+
22+
def test_extracts_parameters_projection() -> None:
23+
version = ServerVersion(2, 7, 0)
24+
db_query_runner = CollectingQueryRunner(version)
25+
gds_query_runner = CollectingQueryRunner(version)
26+
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
27+
qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore
28+
29+
qr.call_procedure(
30+
endpoint="gds.arrow.project",
31+
params=CallParameters(
32+
graph_name="g",
33+
query="RETURN 1",
34+
concurrency=2,
35+
undirRels=[],
36+
inverseRels=[],
37+
arrow_configuration={"batchSize": 100},
38+
),
39+
)
40+
41+
# doesn't run anything on GDS
42+
assert gds_query_runner.last_query() == ""
43+
assert gds_query_runner.last_params() == {}
44+
assert (
45+
db_query_runner.last_query()
46+
== "CALL gds.arrow.project($graph_name, $query, $concurrency, $undirRels, $inverseRels, $arrow_configuration)"
47+
)
48+
assert db_query_runner.last_params() == {
49+
"graph_name": "g",
50+
"query": "RETURN 1",
51+
"concurrency": 2,
52+
"undirRels": [],
53+
"inverseRels": [],
54+
"arrow_configuration": {
55+
"encrypted": False,
56+
"host": "myHost",
57+
"port": "1234",
58+
"token": "myToken",
59+
"batchSize": 100,
60+
},
61+
}
62+
63+
64+
def test_extracts_parameters_algo_write() -> None:
65+
version = ServerVersion(2, 7, 0)
66+
db_query_runner = CollectingQueryRunner(version)
67+
gds_query_runner = CollectingQueryRunner(version)
68+
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
69+
qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore
70+
71+
qr.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={}))
72+
73+
assert gds_query_runner.last_query() == "CALL gds.degree.write($graph_name, $config)"
74+
assert gds_query_runner.last_params() == {
75+
"graph_name": "g",
76+
"config": {"writeToResultStore": True},
77+
}
78+
assert (
79+
db_query_runner.last_query()
80+
== "CALL gds.arrow.write($graphName, $databaseName, $writeConfiguration, $arrowConfiguration)"
81+
)
82+
assert db_query_runner.last_params() == {
83+
"graphName": "g",
84+
"databaseName": "dummy",
85+
"writeConfiguration": {"nodeLabels": ["*"]},
86+
"arrowConfiguration": {"encrypted": False, "host": "myHost", "port": "1234", "token": "myToken"},
87+
}
88+
89+
90+
def test_arrow_and_write_configuration() -> None:
91+
version = ServerVersion(2, 7, 0)
92+
db_query_runner = CollectingQueryRunner(version)
93+
gds_query_runner = CollectingQueryRunner(version)
94+
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
95+
qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore
96+
97+
qr.call_procedure(
98+
endpoint="gds.degree.write",
99+
params=CallParameters(
100+
graph_name="g",
101+
config={"arrowConfiguration": {"batchSize": 1000}, "writeConfiguration": {"writeMode": "FOOBAR"}},
102+
),
103+
)
104+
105+
assert gds_query_runner.last_query() == "CALL gds.degree.write($graph_name, $config)"
106+
assert gds_query_runner.last_params() == {
107+
"graph_name": "g",
108+
"config": {"writeToResultStore": True},
109+
}
110+
assert (
111+
db_query_runner.last_query()
112+
== "CALL gds.arrow.write($graphName, $databaseName, $writeConfiguration, $arrowConfiguration)"
113+
)
114+
assert db_query_runner.last_params() == {
115+
"graphName": "g",
116+
"databaseName": "dummy",
117+
"writeConfiguration": {"nodeLabels": ["*"], "writeMode": "FOOBAR"},
118+
"arrowConfiguration": {
119+
"encrypted": False,
120+
"host": "myHost",
121+
"port": "1234",
122+
"token": "myToken",
123+
"batchSize": 1000,
124+
},
125+
}
126+
127+
128+
def test_arrow_and_write_configuration_graph_write() -> None:
129+
version = ServerVersion(2, 7, 0)
130+
db_query_runner = CollectingQueryRunner(version)
131+
gds_query_runner = CollectingQueryRunner(version)
132+
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
133+
qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore
134+
135+
qr.call_procedure(
136+
endpoint="gds.graph.nodeProperties.write",
137+
params=CallParameters(
138+
graph_name="g",
139+
properties=[],
140+
entities=[],
141+
config={"arrowConfiguration": {"batchSize": 42}, "writeConfiguration": {"writeMode": "FOOBAR"}},
142+
),
143+
)
144+
145+
assert (
146+
gds_query_runner.last_query()
147+
== "CALL gds.graph.nodeProperties.write($graph_name, $properties, $entities, $config)"
148+
)
149+
assert gds_query_runner.last_params() == {
150+
"graph_name": "g",
151+
"entities": [],
152+
"properties": [],
153+
"config": {"writeToResultStore": True},
154+
}
155+
assert (
156+
db_query_runner.last_query()
157+
== "CALL gds.arrow.write($graphName, $databaseName, $writeConfiguration, $arrowConfiguration)"
158+
)
159+
assert db_query_runner.last_params() == {
160+
"graphName": "g",
161+
"databaseName": "dummy",
162+
"writeConfiguration": {"nodeLabels": [], "nodeProperties": [], "writeMode": "FOOBAR"},
163+
"arrowConfiguration": {
164+
"encrypted": False,
165+
"host": "myHost",
166+
"port": "1234",
167+
"token": "myToken",
168+
"batchSize": 42,
169+
},
170+
}

0 commit comments

Comments
 (0)