208 changes: 208 additions & 0 deletions python/python/tests/test_to_sql.py
@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors

"""Tests for the to_sql API that converts Cypher queries to SQL."""

import pyarrow as pa
import pytest
from lance_graph import CypherQuery, GraphConfig


@pytest.fixture
def knowledge_graph_env():
"""Create a complex knowledge graph with multiple entity types and relationships."""
# Authors and their publications
authors_table = pa.table(
{
"author_id": [1, 2, 3, 4, 5],
"name": [
"Alice Chen",
"Bob Smith",
"Carol Wang",
"David Lee",
"Eve Martinez",
],
"institution": ["MIT", "Stanford", "CMU", "Berkeley", "MIT"],
"h_index": [45, 38, 52, 41, 29],
"country": ["USA", "USA", "USA", "USA", "Spain"],
}
)

papers_table = pa.table(
{
"paper_id": [101, 102, 103, 104, 105, 106],
"title": [
"Deep Learning Advances",
"Graph Neural Networks",
"Transformer Architecture",
"Reinforcement Learning",
"Computer Vision Methods",
"Natural Language Processing",
],
"year": [2020, 2021, 2019, 2022, 2021, 2020],
"citations": [450, 320, 890, 210, 380, 520],
"venue": ["NeurIPS", "ICML", "NeurIPS", "ICLR", "CVPR", "ACL"],
}
)

authorship_table = pa.table(
{
"author_id": [1, 1, 2, 2, 3, 3, 4, 5, 5],
"paper_id": [101, 102, 102, 103, 103, 104, 105, 105, 106],
"author_position": [1, 1, 2, 1, 2, 1, 1, 2, 1],
}
)

citations_table = pa.table(
{
"citing_paper_id": [102, 103, 104, 104, 105, 106],
"cited_paper_id": [101, 101, 102, 103, 103, 101],
}
)

config = (
GraphConfig.builder()
.with_node_label("Author", "author_id")
.with_node_label("Paper", "paper_id")
.with_relationship("AUTHORED", "author_id", "paper_id")
.with_relationship("CITES", "citing_paper_id", "cited_paper_id")
.build()
)

datasets = {
"Author": authors_table,
"Paper": papers_table,
"AUTHORED": authorship_table,
"CITES": citations_table,
}

return config, datasets


def test_multi_hop_relationship_with_aggregation(knowledge_graph_env):
"""Test complex multi-hop query with aggregation and filtering.

Find authors who have written highly cited papers (>400 citations) and count
how many such papers each of them has, ordered by that count.
"""
config, datasets = knowledge_graph_env
query = CypherQuery(
"""
MATCH (a:Author)-[:AUTHORED]->(p:Paper)
WHERE p.citations > 400
RETURN a.name, a.institution, COUNT(*) AS high_impact_papers
ORDER BY high_impact_papers DESC
"""
).with_config(config)

sql = query.to_sql(datasets)

assert isinstance(sql, str)
sql_upper = sql.upper()
assert "SELECT" in sql_upper
assert "JOIN" in sql_upper
assert "WHERE" in sql_upper
assert "COUNT" in sql_upper
assert "GROUP BY" in sql_upper
assert "ORDER BY" in sql_upper


def test_citation_network_analysis(knowledge_graph_env):
"""Test citation network traversal with multiple joins.

Find papers that cite other papers, filtered by the citing paper's venue
and publication year, returning citing and cited titles with citation counts.
"""
config, datasets = knowledge_graph_env
query = CypherQuery(
"""
MATCH (citing:Paper)-[:CITES]->(cited:Paper)
WHERE citing.year >= 2020 AND citing.venue = 'NeurIPS'
RETURN citing.title, cited.title, citing.year, cited.citations
ORDER BY cited.citations DESC
LIMIT 10
"""
).with_config(config)

sql = query.to_sql(datasets)

assert isinstance(sql, str)
sql_upper = sql.upper()
assert "SELECT" in sql_upper
assert "JOIN" in sql_upper
assert "WHERE" in sql_upper
assert "ORDER BY" in sql_upper
assert "LIMIT" in sql_upper


def test_collaborative_network_query(knowledge_graph_env):
"""Test finding collaboration patterns through shared papers.

Find pairs of authors who have co-authored papers, with filtering
on institution and h-index.
"""
config, datasets = knowledge_graph_env
query = CypherQuery(
"""
MATCH (a1:Author)-[:AUTHORED]->(p:Paper)<-[:AUTHORED]-(a2:Author)
WHERE a1.author_id < a2.author_id
AND a1.institution = 'MIT'
AND a2.h_index > 30
RETURN DISTINCT a1.name, a2.name, p.title, p.year
ORDER BY p.year DESC
"""
).with_config(config)

sql = query.to_sql(datasets)

assert isinstance(sql, str)
sql_upper = sql.upper()
assert "SELECT" in sql_upper
# DISTINCT may be converted to GROUP BY by the SQL unparser
assert "DISTINCT" in sql_upper or "GROUP BY" in sql_upper
assert "JOIN" in sql_upper
assert "WHERE" in sql_upper
assert "ORDER BY" in sql_upper


def test_parameterized_complex_query(knowledge_graph_env):
"""Test complex query with multiple parameters.

Find authors from a specific country with papers above a citation threshold,
published in recent years.
"""
config, datasets = knowledge_graph_env
query = (
CypherQuery(
"""
MATCH (a:Author)-[:AUTHORED]->(p:Paper)
WHERE a.country = $country
AND p.citations > $min_citations
AND p.year >= $min_year
RETURN a.name, a.h_index, p.title, p.citations
ORDER BY p.citations DESC, a.h_index DESC
"""
)
.with_config(config)
.with_parameter("country", "USA")
.with_parameter("min_citations", 300)
.with_parameter("min_year", 2020)
)

sql = query.to_sql(datasets)

assert isinstance(sql, str)
sql_upper = sql.upper()
assert "SELECT" in sql_upper
assert "JOIN" in sql_upper
assert "WHERE" in sql_upper
assert "ORDER BY" in sql_upper


def test_to_sql_without_config_raises_error(knowledge_graph_env):
"""Test that to_sql fails gracefully without config."""
_, datasets = knowledge_graph_env
query = CypherQuery("MATCH (a:Author) RETURN a.name")

with pytest.raises(Exception):
query.to_sql(datasets)
22 changes: 18 additions & 4 deletions python/src/graph.rs
@@ -269,6 +269,11 @@ impl CypherQuery {

/// Convert query to SQL
///
/// Parameters
/// ----------
/// datasets : dict
/// Dictionary mapping table names to Lance datasets
///
/// Returns
/// -------
/// str
@@ -278,10 +283,19 @@
/// ------
/// RuntimeError
/// If SQL generation fails
fn to_sql(&self) -> PyResult<String> {
// SQL generation not yet implemented in lance-graph.
// Return the original query text for now to keep API stable.
Ok(self.inner.query_text().to_string())
fn to_sql(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
// Convert datasets to Arrow RecordBatch map
let arrow_datasets = python_datasets_to_batches(datasets)?;

// Clone for async move
let inner_query = self.inner.clone();

// Execute via runtime
let sql = RT
.block_on(Some(py), inner_query.to_sql(arrow_datasets))?
.map_err(graph_error_to_pyerr)?;

Ok(sql)
}

/// Execute query against Lance datasets
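As a usage note for the Python binding above: the snippet below is a minimal, hedged sketch of how `to_sql` is called from Python, condensed from the tests in `python/python/tests/test_to_sql.py`. It assumes the `lance_graph` package surface shown there (`CypherQuery`, `GraphConfig`) and passes PyArrow tables as the dataset values, exactly as the test fixture does; any other accepted dataset types depend on what `python_datasets_to_batches` supports.

```python
# Minimal sketch, condensed from the PR's tests (assumption: PyArrow tables are
# accepted as dataset values, as in the test fixture).
import pyarrow as pa
from lance_graph import CypherQuery, GraphConfig

authors = pa.table({"author_id": [1, 2], "name": ["Alice Chen", "Bob Smith"]})
papers = pa.table({"paper_id": [101, 102], "title": ["Deep Learning Advances", "Graph Neural Networks"]})
authored = pa.table({"author_id": [1, 2], "paper_id": [101, 102]})

config = (
    GraphConfig.builder()
    .with_node_label("Author", "author_id")
    .with_node_label("Paper", "paper_id")
    .with_relationship("AUTHORED", "author_id", "paper_id")
    .build()
)

query = CypherQuery(
    "MATCH (a:Author)-[:AUTHORED]->(p:Paper) RETURN a.name, p.title"
).with_config(config)

# Keys of the datasets dict must match the node labels / relationship names in the config.
datasets = {"Author": authors, "Paper": papers, "AUTHORED": authored}

sql = query.to_sql(datasets)  # returns a SQL string; raises if SQL generation fails
print(sql)
```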
93 changes: 93 additions & 0 deletions rust/lance-graph/src/query.rs
@@ -201,6 +201,64 @@ impl CypherQuery {
self.explain_internal(Arc::new(catalog), ctx).await
}

/// Convert the Cypher query to a DataFusion SQL string
///
/// This method generates a SQL string that corresponds to the DataFusion logical plan
/// derived from the Cypher query. It uses the `datafusion-sql` unparser.
///
/// **WARNING**: This method is experimental and the generated SQL dialect may change.
///
/// **Case Sensitivity Limitation**: All table names in the generated SQL are lowercased
/// (e.g., `Person` becomes `person`, `Company` becomes `company`) because of how
/// DataFusion's SQL unparser handles identifiers internally. This affects only the SQL
/// string representation; actual query execution with `execute()` handles case-sensitive
/// labels correctly.
///
/// If you need case-sensitive table names in the SQL output, consider:
/// - Using lowercase labels consistently in your Cypher queries and table names
/// - Post-processing the SQL string to replace table names with the correct case
///
/// # Arguments
/// * `datasets` - HashMap of table name to RecordBatch (nodes and relationships)
///
/// # Returns
/// A SQL string representing the query
pub async fn to_sql(
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
) -> Result<String> {
use datafusion_sql::unparser::plan_to_sql;
use std::sync::Arc;

let _config = self.require_config()?;

// Build catalog and context from datasets using the helper
let (catalog, ctx) = self
.build_catalog_and_context_from_datasets(datasets)
.await?;

// Generate Logical Plan
let (_, df_plan) = self.create_logical_plans(Arc::new(catalog))?;

// Optimize the plan using DataFusion's default optimizer rules
// This helps simplify the plan (e.g., merging projections) to produce cleaner SQL
let optimized_plan = ctx
.state()
.optimize(&df_plan)
.map_err(|e| GraphError::PlanError {
message: format!("Failed to optimize plan: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;

// Unparse to SQL
let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError {
message: format!("Failed to unparse plan to SQL: {}", e),
location: snafu::Location::new(file!(), line!(), column!()),
})?;

Ok(sql_ast.to_string())
}

/// Execute query with a DataFusion SessionContext, automatically building the catalog
///
/// This is a convenience method that builds the graph catalog by querying the
@@ -1822,4 +1880,39 @@ mod tests {
assert_eq!(names.value(1), "Bob");
assert_eq!(scores.value(1), 92);
}

#[tokio::test]
async fn test_to_sql() {
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema};
use std::collections::HashMap;
use std::sync::Arc;

let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
]));
let batch = RecordBatch::new_empty(schema.clone());

let mut datasets = HashMap::new();
datasets.insert("Person".to_string(), batch);

let cfg = GraphConfig::builder()
.with_node_label("Person", "id")
.build()
.unwrap();

let query = CypherQuery::new("MATCH (p:Person) RETURN p.name")
.unwrap()
.with_config(cfg);

let sql = query.to_sql(datasets).await.unwrap();
println!("Generated SQL: {}", sql);

assert!(sql.contains("SELECT"));
assert!(sql.to_lowercase().contains("from person"));
// Note: the DataFusion unparser might quote identifiers or use aliases.
// We check for "p.name", which is the expected output alias.
assert!(sql.contains("p.name"));
}
}
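Regarding the case-sensitivity limitation documented on `to_sql` above: the sketch below is one illustrative (not authoritative) way to post-process the generated SQL string from Python so that lowercased table names are restored to the label casing used in the graph config, following the doc comment's own suggestion. The helper name, the label list, and the whole-word replacement strategy are assumptions made for illustration.

```python
# Illustrative post-processing sketch (assumption: restoring label casing by whole-word
# replacement on the unparsed SQL is good enough for the labels in the graph config).
import re

def restore_label_case(sql: str, labels: list[str]) -> str:
    """Rewrite lowercased table names back to their original label casing."""
    for label in labels:
        # \b limits the rewrite to whole identifiers, e.g. "person" but not "personnel".
        sql = re.sub(rf"\b{re.escape(label.lower())}\b", label, sql)
    return sql

# Example: a lowercased table name produced by the unparser is restored to "Person".
print(restore_label_case("SELECT p.name FROM person AS p", ["Person"]))
```

This naive approach will also rewrite matching words inside string literals or aliases, so it is only a starting point; using lowercase labels consistently, as the doc comment suggests, avoids the issue entirely.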