Commit 06b1f14
unify explain api and update tests
1 parent: 7dbfee5

9 files changed (+483, -253 lines)

python/python/tests/test_explain.py

Lines changed: 5 additions & 5 deletions
@@ -1,4 +1,4 @@
-"""Tests for explain_datafusion API."""
+"""Tests for explain API."""
 
 import pyarrow as pa
 import pytest
@@ -25,7 +25,7 @@ def test_explain_simple_query(person_data):
     """Test explain output contains all expected sections."""
     config, people = person_data
     query = CypherQuery("MATCH (p:Person) RETURN p.name, p.age").with_config(config)
-    plan = query.explain_datafusion({"Person": people})
+    plan = query.explain({"Person": people})
 
     # Verify the plan is a non-empty string
     assert isinstance(plan, str)
@@ -48,7 +48,7 @@ def test_explain_with_clauses(person_data):
     query = CypherQuery(
         "MATCH (p:Person) WHERE p.age > 30 RETURN p.name ORDER BY p.age LIMIT 2"
     ).with_config(config)
-    plan = query.explain_datafusion({"Person": people})
+    plan = query.explain({"Person": people})
 
     assert isinstance(plan, str)
     assert "WHERE p.age > 30" in plan
@@ -63,11 +63,11 @@ def test_explain_error_handling(person_data):
     # Missing config
     query_no_config = CypherQuery("MATCH (p:Person) RETURN p.name")
     with pytest.raises(ValueError, match="Graph configuration is required"):
-        query_no_config.explain_datafusion({"Person": people})
+        query_no_config.explain({"Person": people})
 
     # Missing datasets
     query_with_config = CypherQuery("MATCH (p:Person) RETURN p.name").with_config(
         config
     )
     with pytest.raises(ValueError, match="No input datasets provided"):
-        query_with_config.explain_datafusion({})
+        query_with_config.explain({})
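
For orientation, here is a minimal sketch of how the renamed explain() entry point is called after this commit, based only on the calls exercised in the tests above. The import path is an assumption, and config stands in for whatever the person_data fixture builds; neither is shown in this diff.

import pyarrow as pa

# Hypothetical import path; the Python package name is not visible in this diff.
from lance_graph import CypherQuery


def print_plan(config) -> str:
    """Return the planner's explain output for a simple node query."""
    people = pa.table({"name": ["Alice", "Bob"], "age": [29, 35]})
    query = CypherQuery("MATCH (p:Person) RETURN p.name, p.age").with_config(config)
    # explain() replaces the old explain_datafusion(); it still takes a dict of
    # label/relationship names to in-memory tables and returns the plan as a str.
    plan = query.explain({"Person": people})
    print(plan)
    return plan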

python/python/tests/test_graph.py

Lines changed: 19 additions & 22 deletions
@@ -62,41 +62,38 @@ def graph_env(tmp_path):
     return config, datasets, people_table
 
 
-@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
-def test_basic_node_selection(graph_env, execute_method):
+def test_basic_node_selection(graph_env):
     config, datasets, _ = graph_env
     query = CypherQuery("MATCH (p:Person) RETURN p.name, p.age").with_config(config)
-    result = getattr(query, execute_method)({"Person": datasets["Person"]})
+    result = query.execute({"Person": datasets["Person"]})
     data = result.to_pydict()
 
     assert set(data.keys()) == {"p.name", "p.age"}
     assert len(data["p.name"]) == 4
     assert "Alice" in set(data["p.name"])
 
 
-@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
-def test_filtered_query(graph_env, execute_method):
+def test_filtered_query(graph_env):
     config, datasets, _ = graph_env
     query = CypherQuery(
         "MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age"
     ).with_config(config)
-    result = getattr(query, execute_method)({"Person": datasets["Person"]})
+    result = query.execute({"Person": datasets["Person"]})
     data = result.to_pydict()
 
     assert len(data["p.name"]) == 2
     assert set(data["p.name"]) == {"Bob", "David"}
     assert all(age > 30 for age in data["p.age"])
 
 
-@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
-def test_relationship_query(graph_env, execute_method):
+def test_relationship_query(graph_env):
     config, datasets, _ = graph_env
     query = CypherQuery(
         "MATCH (p:Person)-[:WORKS_FOR]->(c:Company) "
         "RETURN p.person_id AS person_id, p.name AS name, c.company_id AS company_id"
     ).with_config(config)
 
-    result = getattr(query, execute_method)(
+    result = query.execute(
         {
             "Person": datasets["Person"],
             "Company": datasets["Company"],
@@ -109,8 +106,7 @@ def test_relationship_query(graph_env, execute_method):
     assert data["company_id"] == [101, 101, 102, 103]
 
 
-@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
-def test_friendship_direct_and_network(graph_env, execute_method):
+def test_friendship_direct_and_network(graph_env):
     config, datasets, _ = graph_env
     # Direct friends of Alice (person_id = 1)
     query_direct = CypherQuery(
@@ -119,7 +115,7 @@ def test_friendship_direct_and_network(graph_env, execute_method):
         "RETURN b.person_id AS friend_id"
     ).with_config(config)
 
-    result_direct = getattr(query_direct, execute_method)(
+    result_direct = query_direct.execute(
         {
             "Person": datasets["Person"],
             "FRIEND_OF": datasets["FRIEND_OF"],
@@ -134,7 +130,7 @@ def test_friendship_direct_and_network(graph_env, execute_method):
         "RETURN f.person_id AS person1_id, t.person_id AS person2_id"
     ).with_config(config)
 
-    result_edges = getattr(query_edges, execute_method)(
+    result_edges = query_edges.execute(
         {
             "Person": datasets["Person"],
             "FRIEND_OF": datasets["FRIEND_OF"],
@@ -145,16 +141,15 @@ def test_friendship_direct_and_network(graph_env, execute_method):
     assert got == {(1, 2), (1, 3), (2, 4), (3, 4)}
 
 
-@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
-def test_two_hop_friends_of_friends(graph_env, execute_method):
+def test_two_hop_friends_of_friends(graph_env):
     config, datasets, _ = graph_env
     query = CypherQuery(
         "MATCH (a:Person)-[:FRIEND_OF]->(b:Person)-[:FRIEND_OF]->(c:Person) "
         "WHERE a.person_id = 1 "
         "RETURN a.person_id AS a_id, b.person_id AS b_id, c.person_id AS c_id"
     ).with_config(config)
 
-    result = getattr(query, execute_method)(
+    result = query.execute(
         {
             "Person": datasets["Person"],
             "FRIEND_OF": datasets["FRIEND_OF"],
@@ -164,29 +159,31 @@ def test_two_hop_friends_of_friends(graph_env, execute_method):
     assert set(data["c_id"]) == {4}
 
 
-@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
-def test_variable_length_path(graph_env, execute_method):
+def test_variable_length_path(graph_env):
     config, datasets, _ = graph_env
     query = CypherQuery(
         "MATCH (p1:Person)-[:FRIEND_OF*1..2]-(p2:Person) "
         "RETURN p1.person_id AS p1, p2.person_id AS p2"
     ).with_config(config)
-    _ = getattr(query, execute_method)(
+
+    result = query.execute(
         {
             "Person": datasets["Person"],
             "FRIEND_OF": datasets["FRIEND_OF"],
         }
     )
+    data = result.to_pydict()
+    got = set(zip(data["p1"], data["p2"]))
+    assert got == {(1, 2), (1, 3), (2, 4), (3, 4), (1, 4)}
 
 
-@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
-def test_distinct_clause(graph_env, execute_method):
+def test_distinct_clause(graph_env):
     config, datasets, _ = graph_env
     query = CypherQuery(
         "MATCH (p:Person)-[:WORKS_FOR]->(c:Company) RETURN DISTINCT c.company_name"
     ).with_config(config)
 
-    result = getattr(query, execute_method)(
+    result = query.execute(
         {
             "Person": datasets["Person"],
             "Company": datasets["Company"],

python/src/graph.rs

Lines changed: 3 additions & 37 deletions
@@ -335,41 +335,7 @@ impl CypherQuery {
         record_batch_to_python_table(py, &result_batch)
     }
 
-    /// Execute query using the DataFusion planner with in-memory datasets
-    ///
-    /// Parameters
-    /// ----------
-    /// datasets : dict
-    ///     Dictionary mapping table names to in-memory tables (pyarrow.Table, LanceDataset, etc.)
-    ///     Keys should match node labels and relationship types in the graph config.
-    ///
-    /// Returns
-    /// -------
-    /// pyarrow.Table
-    ///     Query results as Arrow table
-    ///
-    /// Raises
-    /// ------
-    /// ValueError
-    ///     If the query is invalid or datasets are missing
-    /// RuntimeError
-    ///     If query execution fails
-    fn execute_datafusion(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<PyObject> {
-        // Convert datasets to Arrow RecordBatch map
-        let arrow_datasets = python_datasets_to_batches(datasets)?;
-
-        // Clone for async move
-        let inner_query = self.inner.clone();
-
-        // Execute via runtime
-        let result_batch = RT
-            .block_on(Some(py), inner_query.execute_datafusion(arrow_datasets))?
-            .map_err(graph_error_to_pyerr)?;
-
-        record_batch_to_python_table(py, &result_batch)
-    }
-
-    /// Explain query uusing the DataFusion planner with in-memory datasets
+    /// Explain query using the DataFusion planner with in-memory datasets
     ///
     /// Parameters
     /// ----------
@@ -388,7 +354,7 @@ impl CypherQuery {
     ///     If the query is invalid or datasets are missing
     /// RuntimeError
     ///     If query explain fails
-    fn explain_datafusion(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
+    fn explain(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
         // Convert datasets to Arrow RecordBatch map
         let arrow_datasets = python_datasets_to_batches(datasets)?;
 
@@ -397,7 +363,7 @@ impl CypherQuery {
 
         // Execute via runtime
         let plan = RT
-            .block_on(Some(py), inner_query.explain_datafusion(arrow_datasets))?
+            .block_on(Some(py), inner_query.explain(arrow_datasets))?
             .map_err(graph_error_to_pyerr)?;
 
         Ok(plan)
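
As a rough migration note (a sketch, not part of the commit): execute_datafusion() is removed in favor of the existing execute(), and explain_datafusion() is renamed to explain(); the datasets dictionary and the return types are unchanged.

def run_and_explain(query, datasets):
    # Before this commit:
    #   plan  = query.explain_datafusion(datasets)
    #   table = query.execute_datafusion(datasets)
    # After this commit:
    plan = query.explain(datasets)   # str: the planner's explain output
    table = query.execute(datasets)  # pyarrow.Table with the query results
    return plan, table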
