From 0f5fdecb4f725b6f9795d09abb1c11493fd958d2 Mon Sep 17 00:00:00 2001 From: JoshuaTang <1240604020@qq.com> Date: Mon, 8 Dec 2025 17:36:15 -0800 Subject: [PATCH 1/4] feat: implement the to_sql api --- rust/lance-graph/src/query.rs | 90 ++++++ rust/lance-graph/tests/test_to_sql.rs | 400 ++++++++++++++++++++++++++ 2 files changed, 490 insertions(+) create mode 100644 rust/lance-graph/tests/test_to_sql.rs diff --git a/rust/lance-graph/src/query.rs b/rust/lance-graph/src/query.rs index 49f4d35..f5c3f5f 100644 --- a/rust/lance-graph/src/query.rs +++ b/rust/lance-graph/src/query.rs @@ -201,6 +201,61 @@ impl CypherQuery { self.explain_internal(Arc::new(catalog), ctx).await } + /// Convert the Cypher query to a DataFusion SQL string + /// + /// This method generates a SQL string that corresponds to the DataFusion logical plan + /// derived from the Cypher query. It uses the `datafusion-sql` unparser. + /// + /// **WARNING**: This method is experimental and the generated SQL dialect may change. + /// + /// **Case Sensitivity Limitation**: All table names in the generated SQL are lowercased + /// (e.g., `Person` becomes `person`, `Company` becomes `company`), due to the internal + /// handling of DataFusion's SQL unparser. Note that this only affects the SQL string + /// representation - actual query execution with `execute()` handles case-sensitive labels + /// correctly. + /// + /// If you need case-sensitive table names in the SQL output, consider: + /// - Using lowercase labels consistently in your Cypher queries and table names + /// - Post-processing the SQL string to replace table names with the correct case + /// + /// # Arguments + /// * `datasets` - HashMap of table name to RecordBatch (nodes and relationships) + /// + /// # Returns + /// A SQL string representing the query + pub async fn to_sql( + &self, + datasets: HashMap, + ) -> Result { + use datafusion_sql::unparser::plan_to_sql; + use std::sync::Arc; + + let _config = self.require_config()?; + + // Build catalog and context from datasets using the helper + let (catalog, ctx) = self + .build_catalog_and_context_from_datasets(datasets) + .await?; + + // Generate Logical Plan + let (_, df_plan) = self.create_logical_plans(Arc::new(catalog))?; + + // Optimize the plan using DataFusion's default optimizer rules + // This helps simplify the plan (e.g., merging projections) to produce cleaner SQL + let optimized_plan = ctx.state().optimize(&df_plan).map_err(|e| GraphError::PlanError { + message: format!("Failed to optimize plan: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + + // Unparse to SQL + let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError { + message: format!("Failed to unparse plan to SQL: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + + Ok(sql_ast.to_string()) + } + /// Execute query with a DataFusion SessionContext, automatically building the catalog /// /// This is a convenience method that builds the graph catalog by querying the @@ -1822,4 +1877,39 @@ mod tests { assert_eq!(names.value(1), "Bob"); assert_eq!(scores.value(1), 92); } + + #[tokio::test] + async fn test_to_sql() { + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema}; + use std::collections::HashMap; + use std::sync::Arc; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::new_empty(schema.clone()); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), batch); + + let cfg = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name") + .unwrap() + .with_config(cfg); + + let sql = query.to_sql(datasets).await.unwrap(); + println!("Generated SQL: {}", sql); + + assert!(sql.contains("SELECT")); + assert!(sql.to_lowercase().contains("from person")); + // Note: DataFusion unparser might quote identifiers or use aliases + // We check for "p.name" which is the expected output alias + assert!(sql.contains("p.name")); + } } diff --git a/rust/lance-graph/tests/test_to_sql.rs b/rust/lance-graph/tests/test_to_sql.rs new file mode 100644 index 0000000..d5ec2b6 --- /dev/null +++ b/rust/lance-graph/tests/test_to_sql.rs @@ -0,0 +1,400 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Integration tests for the to_sql API +//! +//! These tests verify that Cypher queries can be correctly converted to SQL strings. + +use arrow:: array:: { Int32Array, StringArray }; +use arrow:: datatypes:: { DataType, Field, Schema }; +use arrow:: record_batch:: RecordBatch; +use lance_graph:: { CypherQuery, GraphConfig }; +use std:: collections:: HashMap; +use std:: sync:: Arc; + +/// Helper to create a simple Person table +fn create_person_table() -> RecordBatch { + let schema = Arc::new (Schema:: new (vec![ + Field:: new ("person_id", DataType:: Int32, false), + Field:: new ("name", DataType:: Utf8, false), + Field:: new ("age", DataType:: Int32, false), + Field:: new ("city", DataType:: Utf8, false), + ])); + + let person_ids = Int32Array:: from(vec![1, 2, 3, 4]); + let names = StringArray:: from(vec!["Alice", "Bob", "Carol", "David"]); + let ages = Int32Array:: from(vec![28, 34, 29, 42]); + let cities = StringArray:: from(vec!["New York", "San Francisco", "New York", "Chicago"]); + + RecordBatch:: try_new( + schema, + vec![ + Arc:: new (person_ids), + Arc:: new (names), + Arc:: new (ages), + Arc:: new (cities), + ], + ) + .unwrap() +} + +/// Helper to create a Company table +fn create_company_table() -> RecordBatch { + let schema = Arc::new (Schema:: new (vec![ + Field:: new ("company_id", DataType:: Int32, false), + Field:: new ("company_name", DataType:: Utf8, false), + Field:: new ("industry", DataType:: Utf8, false), + ])); + + let company_ids = Int32Array:: from(vec![101, 102, 103]); + let names = StringArray:: from(vec!["TechCorp", "DataInc", "CloudSoft"]); + let industries = StringArray:: from(vec!["Technology", "Analytics", "Cloud"]); + + RecordBatch:: try_new( + schema, + vec![ + Arc:: new (company_ids), + Arc:: new (names), + Arc:: new (industries), + ], + ) + .unwrap() +} + +/// Helper to create a WORKS_FOR relationship table +fn create_works_for_table() -> RecordBatch { + let schema = Arc::new (Schema:: new (vec![ + Field:: new ("person_id", DataType:: Int32, false), + Field:: new ("company_id", DataType:: Int32, false), + Field:: new ("position", DataType:: Utf8, false), + Field:: new ("salary", DataType:: Int32, false), + ])); + + let person_ids = Int32Array:: from(vec![1, 2, 3, 4]); + let company_ids = Int32Array:: from(vec![101, 101, 102, 103]); + let positions = StringArray:: from(vec!["Engineer", "Designer", "Manager", "Director"]); + let salaries = Int32Array:: from(vec![120000, 95000, 130000, 180000]); + + RecordBatch:: try_new( + schema, + vec![ + Arc:: new (person_ids), + Arc:: new (company_ids), + Arc:: new (positions), + Arc:: new (salaries), + ], + ) + .unwrap() +} + +#[tokio::test] +async fn test_to_sql_simple_node_scan() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name") + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify SQL contains expected elements + assert!(sql.to_uppercase().contains("SELECT"), "SQL should contain SELECT"); + assert!(sql.to_lowercase().contains("person"), "SQL should reference person table"); + assert!(sql.contains("name"), "SQL should reference name column"); + + // SQL should be non-empty and valid + assert!(!sql.is_empty(), "Generated SQL should not be empty"); + println!("Generated SQL:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_with_filter() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new ("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age") + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify SQL contains filter condition + assert!(sql.contains("SELECT"), "SQL should contain SELECT"); + assert!(sql.contains("WHERE") || sql.contains("FILTER"), "SQL should contain WHERE clause"); + assert!(sql.contains("age"), "SQL should reference age column"); + assert!(sql.contains("30"), "SQL should contain filter value"); + + println!("Generated SQL with filter:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_with_multiple_properties() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name, p.age, p.city") + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify all columns are present + assert!(sql.contains("name"), "SQL should contain name"); + assert!(sql.contains("age"), "SQL should contain age"); + assert!(sql.contains("city"), "SQL should contain city"); + + println!("Generated SQL with multiple properties:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_with_relationship() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .with_node_label("Company", "company_id") + .with_relationship("WORKS_FOR", "person_id", "company_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + datasets.insert("Company".to_string(), create_company_table()); + datasets.insert("WORKS_FOR".to_string(), create_works_for_table()); + + let query = CypherQuery::new ( + "MATCH (p:Person)-[:WORKS_FOR]->(c:Company) RETURN p.name, c.company_name", + ) + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify SQL contains join + let sql_upper = sql.to_uppercase(); + let sql_lower = sql.to_lowercase(); + assert!(sql_upper.contains("SELECT"), "SQL should contain SELECT"); + assert!(sql_upper.contains("JOIN"), "SQL should contain JOIN"); + assert!(sql_lower.contains("person"), "SQL should reference person"); + assert!(sql_lower.contains("company"), "SQL should reference company"); + + println!("Generated SQL with relationship:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_with_relationship_filter() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .with_node_label("Company", "company_id") + .with_relationship("WORKS_FOR", "person_id", "company_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + datasets.insert("Company".to_string(), create_company_table()); + datasets.insert("WORKS_FOR".to_string(), create_works_for_table()); + + let query = CypherQuery::new ( + "MATCH (p:Person)-[w:WORKS_FOR]->(c:Company) WHERE w.salary > 100000 RETURN p.name, c.company_name, w.salary", + ) + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify SQL contains filter on relationship property + assert!(sql.contains("salary"), "SQL should reference salary"); + assert!(sql.contains("100000"), "SQL should contain filter value"); + + println!("Generated SQL with relationship filter:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_with_order_by() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name, p.age ORDER BY p.age DESC") + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify SQL contains ORDER BY + assert!( + sql.contains("ORDER BY") || sql.contains("order by"), + "SQL should contain ORDER BY" + ); + assert!(sql.contains("age"), "SQL should reference age in ORDER BY"); + + println!("Generated SQL with ORDER BY:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_with_limit() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name LIMIT 2") + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify SQL contains LIMIT + assert!( + sql.contains("LIMIT") || sql.contains("limit"), + "SQL should contain LIMIT" + ); + assert!(sql.contains("2"), "SQL should contain limit value"); + + println!("Generated SQL with LIMIT:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_with_distinct() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new ("MATCH (p:Person) RETURN DISTINCT p.city") + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify SQL is generated successfully + // Note: DISTINCT might be optimized away by DataFusion's optimizer in some cases + assert!(!sql.is_empty(), "SQL should be generated"); + assert!(sql.contains("city"), "SQL should reference city"); + + println!("Generated SQL with DISTINCT:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_with_alias() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name AS person_name, p.age AS person_age") + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify SQL contains aliases + assert!(sql.contains("AS") || sql.contains("as"), "SQL should contain AS for aliases"); + + println!("Generated SQL with aliases:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_complex_query() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .with_node_label("Company", "company_id") + .with_relationship("WORKS_FOR", "person_id", "company_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + datasets.insert("Company".to_string(), create_company_table()); + datasets.insert("WORKS_FOR".to_string(), create_works_for_table()); + + let query = CypherQuery::new ( + "MATCH (p:Person)-[w:WORKS_FOR]->(c:Company) \ + WHERE p.age > 30 AND c.industry = 'Technology' \ + RETURN p.name, c.company_name, w.position \ + ORDER BY p.age DESC \ + LIMIT 5", + ) + .unwrap() + .with_config(config); + + let sql = query.to_sql(datasets).await.unwrap(); + + // Verify complex query elements + assert!(sql.contains("SELECT"), "SQL should contain SELECT"); + assert!(sql.contains("JOIN") || sql.contains("join"), "SQL should contain JOIN"); + assert!(sql.contains("WHERE") || sql.contains("where"), "SQL should contain WHERE"); + assert!(sql.contains("ORDER BY") || sql.contains("order by"), "SQL should contain ORDER BY"); + assert!(sql.contains("LIMIT") || sql.contains("limit"), "SQL should contain LIMIT"); + + println!("Generated complex SQL:\n{}", sql); +} + +#[tokio::test] +async fn test_to_sql_missing_config() { + let mut datasets = HashMap::new (); + datasets.insert("Person".to_string(), create_person_table()); + + let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name").unwrap(); + // Note: No config set + + let result = query.to_sql(datasets).await; + + // Should fail without config + assert!(result.is_err(), "to_sql should fail without config"); + assert!( + result.unwrap_err().to_string().contains("configuration"), + "Error should mention missing configuration" + ); +} + +#[tokio::test] +async fn test_to_sql_empty_datasets() { + let config = GraphConfig:: builder() + .with_node_label("Person", "person_id") + .build() + .unwrap(); + + let datasets = HashMap::new (); // Empty + + let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name") + .unwrap() + .with_config(config); + + let result = query.to_sql(datasets).await; + + // Should fail with empty datasets + assert!(result.is_err(), "to_sql should fail with empty datasets"); + assert!( + result.unwrap_err().to_string().contains("No input datasets"), + "Error should mention missing datasets" + ); +} From dc23ced37df2ef8deca60d2bd54aab42489918ab Mon Sep 17 00:00:00 2001 From: JoshuaTang <1240604020@qq.com> Date: Mon, 8 Dec 2025 17:36:27 -0800 Subject: [PATCH 2/4] feat: add the python api for to_sql --- python/python/tests/test_to_sql.py | 202 +++++++++++++++++++++++++++++ python/src/graph.rs | 22 +++- 2 files changed, 220 insertions(+), 4 deletions(-) create mode 100644 python/python/tests/test_to_sql.py diff --git a/python/python/tests/test_to_sql.py b/python/python/tests/test_to_sql.py new file mode 100644 index 0000000..f758c37 --- /dev/null +++ b/python/python/tests/test_to_sql.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +"""Tests for the to_sql API that converts Cypher queries to SQL.""" + +import pyarrow as pa +import pytest +from lance_graph import CypherQuery, GraphConfig + + +@pytest.fixture +def knowledge_graph_env(): + """Create a complex knowledge graph with multiple entity types and relationships.""" + # Authors and their publications + authors_table = pa.table( + { + "author_id": [1, 2, 3, 4, 5], + "name": ["Alice Chen", "Bob Smith", "Carol Wang", "David Lee", "Eve Martinez"], + "institution": ["MIT", "Stanford", "CMU", "Berkeley", "MIT"], + "h_index": [45, 38, 52, 41, 29], + "country": ["USA", "USA", "USA", "USA", "Spain"], + } + ) + + papers_table = pa.table( + { + "paper_id": [101, 102, 103, 104, 105, 106], + "title": [ + "Deep Learning Advances", + "Graph Neural Networks", + "Transformer Architecture", + "Reinforcement Learning", + "Computer Vision Methods", + "Natural Language Processing", + ], + "year": [2020, 2021, 2019, 2022, 2021, 2020], + "citations": [450, 320, 890, 210, 380, 520], + "venue": ["NeurIPS", "ICML", "NeurIPS", "ICLR", "CVPR", "ACL"], + } + ) + + authorship_table = pa.table( + { + "author_id": [1, 1, 2, 2, 3, 3, 4, 5, 5], + "paper_id": [101, 102, 102, 103, 103, 104, 105, 105, 106], + "author_position": [1, 1, 2, 1, 2, 1, 1, 2, 1], + } + ) + + citations_table = pa.table( + { + "citing_paper_id": [102, 103, 104, 104, 105, 106], + "cited_paper_id": [101, 101, 102, 103, 103, 101], + } + ) + + config = ( + GraphConfig.builder() + .with_node_label("Author", "author_id") + .with_node_label("Paper", "paper_id") + .with_relationship("AUTHORED", "author_id", "paper_id") + .with_relationship("CITES", "citing_paper_id", "cited_paper_id") + .build() + ) + + datasets = { + "Author": authors_table, + "Paper": papers_table, + "AUTHORED": authorship_table, + "CITES": citations_table, + } + + return config, datasets + + +def test_multi_hop_relationship_with_aggregation(knowledge_graph_env): + """Test complex multi-hop query with aggregation and filtering. + + Find authors who have written highly cited papers (>400 citations) and count + how many such papers they have, filtering for prolific authors. + """ + config, datasets = knowledge_graph_env + query = CypherQuery( + """ + MATCH (a:Author)-[:AUTHORED]->(p:Paper) + WHERE p.citations > 400 + RETURN a.name, a.institution, COUNT(*) AS high_impact_papers + ORDER BY high_impact_papers DESC + """ + ).with_config(config) + + sql = query.to_sql(datasets) + + assert isinstance(sql, str) + sql_upper = sql.upper() + assert "SELECT" in sql_upper + assert "JOIN" in sql_upper + assert "WHERE" in sql_upper + assert "COUNT" in sql_upper + assert "GROUP BY" in sql_upper + assert "ORDER BY" in sql_upper + + +def test_citation_network_analysis(knowledge_graph_env): + """Test citation network traversal with multiple joins. + + Find papers that cite other papers, along with author information, + filtered by venue and year range. + """ + config, datasets = knowledge_graph_env + query = CypherQuery( + """ + MATCH (citing:Paper)-[:CITES]->(cited:Paper) + WHERE citing.year >= 2020 AND citing.venue = 'NeurIPS' + RETURN citing.title, cited.title, citing.year, cited.citations + ORDER BY cited.citations DESC + LIMIT 10 + """ + ).with_config(config) + + sql = query.to_sql(datasets) + + assert isinstance(sql, str) + sql_upper = sql.upper() + assert "SELECT" in sql_upper + assert "JOIN" in sql_upper + assert "WHERE" in sql_upper + assert "ORDER BY" in sql_upper + assert "LIMIT" in sql_upper + + +def test_collaborative_network_query(knowledge_graph_env): + """Test finding collaboration patterns through shared papers. + + Find pairs of authors who have co-authored papers, with filtering + on institution and h-index. + """ + config, datasets = knowledge_graph_env + query = CypherQuery( + """ + MATCH (a1:Author)-[:AUTHORED]->(p:Paper)<-[:AUTHORED]-(a2:Author) + WHERE a1.author_id < a2.author_id + AND a1.institution = 'MIT' + AND a2.h_index > 30 + RETURN DISTINCT a1.name, a2.name, p.title, p.year + ORDER BY p.year DESC + """ + ).with_config(config) + + sql = query.to_sql(datasets) + + assert isinstance(sql, str) + sql_upper = sql.upper() + assert "SELECT" in sql_upper + # DISTINCT may be converted to GROUP BY by the SQL unparser + assert "DISTINCT" in sql_upper or "GROUP BY" in sql_upper + assert "JOIN" in sql_upper + assert "WHERE" in sql_upper + assert "ORDER BY" in sql_upper + + +def test_parameterized_complex_query(knowledge_graph_env): + """Test complex query with multiple parameters. + + Find authors from a specific country with papers above a citation threshold, + published in recent years. + """ + config, datasets = knowledge_graph_env + query = ( + CypherQuery( + """ + MATCH (a:Author)-[:AUTHORED]->(p:Paper) + WHERE a.country = $country + AND p.citations > $min_citations + AND p.year >= $min_year + RETURN a.name, a.h_index, p.title, p.citations + ORDER BY p.citations DESC, a.h_index DESC + """ + ) + .with_config(config) + .with_parameter("country", "USA") + .with_parameter("min_citations", 300) + .with_parameter("min_year", 2020) + ) + + sql = query.to_sql(datasets) + + assert isinstance(sql, str) + sql_upper = sql.upper() + assert "SELECT" in sql_upper + assert "JOIN" in sql_upper + assert "WHERE" in sql_upper + assert "ORDER BY" in sql_upper + + +def test_to_sql_without_config_raises_error(knowledge_graph_env): + """Test that to_sql fails gracefully without config.""" + _, datasets = knowledge_graph_env + query = CypherQuery("MATCH (a:Author) RETURN a.name") + + with pytest.raises(Exception): + query.to_sql(datasets) diff --git a/python/src/graph.rs b/python/src/graph.rs index d8563be..8983265 100644 --- a/python/src/graph.rs +++ b/python/src/graph.rs @@ -269,6 +269,11 @@ impl CypherQuery { /// Convert query to SQL /// + /// Parameters + /// ---------- + /// datasets : dict + /// Dictionary mapping table names to Lance datasets + /// /// Returns /// ------- /// str @@ -278,10 +283,19 @@ impl CypherQuery { /// ------ /// RuntimeError /// If SQL generation fails - fn to_sql(&self) -> PyResult { - // SQL generation not yet implemented in lance-graph. - // Return the original query text for now to keep API stable. - Ok(self.inner.query_text().to_string()) + fn to_sql(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult { + // Convert datasets to Arrow RecordBatch map + let arrow_datasets = python_datasets_to_batches(datasets)?; + + // Clone for async move + let inner_query = self.inner.clone(); + + // Execute via runtime + let sql = RT + .block_on(Some(py), inner_query.to_sql(arrow_datasets))? + .map_err(graph_error_to_pyerr)?; + + Ok(sql) } /// Execute query against Lance datasets From 26ebbc4af40462265a198d304d29862bc89dacb8 Mon Sep 17 00:00:00 2001 From: JoshuaTang <1240604020@qq.com> Date: Mon, 8 Dec 2025 17:42:28 -0800 Subject: [PATCH 3/4] format code --- rust/lance-graph/src/query.rs | 11 +- rust/lance-graph/tests/test_to_sql.rs | 227 ++++++++++++++------------ 2 files changed, 134 insertions(+), 104 deletions(-) diff --git a/rust/lance-graph/src/query.rs b/rust/lance-graph/src/query.rs index f5c3f5f..0fb0a48 100644 --- a/rust/lance-graph/src/query.rs +++ b/rust/lance-graph/src/query.rs @@ -242,10 +242,13 @@ impl CypherQuery { // Optimize the plan using DataFusion's default optimizer rules // This helps simplify the plan (e.g., merging projections) to produce cleaner SQL - let optimized_plan = ctx.state().optimize(&df_plan).map_err(|e| GraphError::PlanError { - message: format!("Failed to optimize plan: {}", e), - location: snafu::Location::new(file!(), line!(), column!()), - })?; + let optimized_plan = ctx + .state() + .optimize(&df_plan) + .map_err(|e| GraphError::PlanError { + message: format!("Failed to optimize plan: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; // Unparse to SQL let sql_ast = plan_to_sql(&optimized_plan).map_err(|e| GraphError::PlanError { diff --git a/rust/lance-graph/tests/test_to_sql.rs b/rust/lance-graph/tests/test_to_sql.rs index d5ec2b6..1ba8a8d 100644 --- a/rust/lance-graph/tests/test_to_sql.rs +++ b/rust/lance-graph/tests/test_to_sql.rs @@ -5,107 +5,109 @@ //! //! These tests verify that Cypher queries can be correctly converted to SQL strings. -use arrow:: array:: { Int32Array, StringArray }; -use arrow:: datatypes:: { DataType, Field, Schema }; -use arrow:: record_batch:: RecordBatch; -use lance_graph:: { CypherQuery, GraphConfig }; -use std:: collections:: HashMap; -use std:: sync:: Arc; +use arrow::array::{Int32Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use lance_graph::{CypherQuery, GraphConfig}; +use std::collections::HashMap; +use std::sync::Arc; /// Helper to create a simple Person table fn create_person_table() -> RecordBatch { - let schema = Arc::new (Schema:: new (vec![ - Field:: new ("person_id", DataType:: Int32, false), - Field:: new ("name", DataType:: Utf8, false), - Field:: new ("age", DataType:: Int32, false), - Field:: new ("city", DataType:: Utf8, false), + let schema = Arc::new(Schema::new(vec![ + Field::new("person_id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + Field::new("city", DataType::Utf8, false), ])); - let person_ids = Int32Array:: from(vec![1, 2, 3, 4]); - let names = StringArray:: from(vec!["Alice", "Bob", "Carol", "David"]); - let ages = Int32Array:: from(vec![28, 34, 29, 42]); - let cities = StringArray:: from(vec!["New York", "San Francisco", "New York", "Chicago"]); + let person_ids = Int32Array::from(vec![1, 2, 3, 4]); + let names = StringArray::from(vec!["Alice", "Bob", "Carol", "David"]); + let ages = Int32Array::from(vec![28, 34, 29, 42]); + let cities = StringArray::from(vec!["New York", "San Francisco", "New York", "Chicago"]); - RecordBatch:: try_new( + RecordBatch::try_new( schema, vec![ - Arc:: new (person_ids), - Arc:: new (names), - Arc:: new (ages), - Arc:: new (cities), + Arc::new(person_ids), + Arc::new(names), + Arc::new(ages), + Arc::new(cities), ], ) - .unwrap() + .unwrap() } /// Helper to create a Company table fn create_company_table() -> RecordBatch { - let schema = Arc::new (Schema:: new (vec![ - Field:: new ("company_id", DataType:: Int32, false), - Field:: new ("company_name", DataType:: Utf8, false), - Field:: new ("industry", DataType:: Utf8, false), + let schema = Arc::new(Schema::new(vec![ + Field::new("company_id", DataType::Int32, false), + Field::new("company_name", DataType::Utf8, false), + Field::new("industry", DataType::Utf8, false), ])); - let company_ids = Int32Array:: from(vec![101, 102, 103]); - let names = StringArray:: from(vec!["TechCorp", "DataInc", "CloudSoft"]); - let industries = StringArray:: from(vec!["Technology", "Analytics", "Cloud"]); + let company_ids = Int32Array::from(vec![101, 102, 103]); + let names = StringArray::from(vec!["TechCorp", "DataInc", "CloudSoft"]); + let industries = StringArray::from(vec!["Technology", "Analytics", "Cloud"]); - RecordBatch:: try_new( + RecordBatch::try_new( schema, - vec![ - Arc:: new (company_ids), - Arc:: new (names), - Arc:: new (industries), - ], + vec![Arc::new(company_ids), Arc::new(names), Arc::new(industries)], ) - .unwrap() + .unwrap() } /// Helper to create a WORKS_FOR relationship table fn create_works_for_table() -> RecordBatch { - let schema = Arc::new (Schema:: new (vec![ - Field:: new ("person_id", DataType:: Int32, false), - Field:: new ("company_id", DataType:: Int32, false), - Field:: new ("position", DataType:: Utf8, false), - Field:: new ("salary", DataType:: Int32, false), + let schema = Arc::new(Schema::new(vec![ + Field::new("person_id", DataType::Int32, false), + Field::new("company_id", DataType::Int32, false), + Field::new("position", DataType::Utf8, false), + Field::new("salary", DataType::Int32, false), ])); - let person_ids = Int32Array:: from(vec![1, 2, 3, 4]); - let company_ids = Int32Array:: from(vec![101, 101, 102, 103]); - let positions = StringArray:: from(vec!["Engineer", "Designer", "Manager", "Director"]); - let salaries = Int32Array:: from(vec![120000, 95000, 130000, 180000]); + let person_ids = Int32Array::from(vec![1, 2, 3, 4]); + let company_ids = Int32Array::from(vec![101, 101, 102, 103]); + let positions = StringArray::from(vec!["Engineer", "Designer", "Manager", "Director"]); + let salaries = Int32Array::from(vec![120000, 95000, 130000, 180000]); - RecordBatch:: try_new( + RecordBatch::try_new( schema, vec![ - Arc:: new (person_ids), - Arc:: new (company_ids), - Arc:: new (positions), - Arc:: new (salaries), + Arc::new(person_ids), + Arc::new(company_ids), + Arc::new(positions), + Arc::new(salaries), ], ) - .unwrap() + .unwrap() } #[tokio::test] async fn test_to_sql_simple_node_scan() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); - let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name") + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name") .unwrap() .with_config(config); let sql = query.to_sql(datasets).await.unwrap(); // Verify SQL contains expected elements - assert!(sql.to_uppercase().contains("SELECT"), "SQL should contain SELECT"); - assert!(sql.to_lowercase().contains("person"), "SQL should reference person table"); + assert!( + sql.to_uppercase().contains("SELECT"), + "SQL should contain SELECT" + ); + assert!( + sql.to_lowercase().contains("person"), + "SQL should reference person table" + ); assert!(sql.contains("name"), "SQL should reference name column"); // SQL should be non-empty and valid @@ -115,15 +117,15 @@ async fn test_to_sql_simple_node_scan() { #[tokio::test] async fn test_to_sql_with_filter() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); - let query = CypherQuery::new ("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age") + let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age") .unwrap() .with_config(config); @@ -131,7 +133,10 @@ async fn test_to_sql_with_filter() { // Verify SQL contains filter condition assert!(sql.contains("SELECT"), "SQL should contain SELECT"); - assert!(sql.contains("WHERE") || sql.contains("FILTER"), "SQL should contain WHERE clause"); + assert!( + sql.contains("WHERE") || sql.contains("FILTER"), + "SQL should contain WHERE clause" + ); assert!(sql.contains("age"), "SQL should reference age column"); assert!(sql.contains("30"), "SQL should contain filter value"); @@ -140,15 +145,15 @@ async fn test_to_sql_with_filter() { #[tokio::test] async fn test_to_sql_with_multiple_properties() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); - let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name, p.age, p.city") + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age, p.city") .unwrap() .with_config(config); @@ -164,23 +169,23 @@ async fn test_to_sql_with_multiple_properties() { #[tokio::test] async fn test_to_sql_with_relationship() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .with_node_label("Company", "company_id") .with_relationship("WORKS_FOR", "person_id", "company_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); datasets.insert("Company".to_string(), create_company_table()); datasets.insert("WORKS_FOR".to_string(), create_works_for_table()); - let query = CypherQuery::new ( + let query = CypherQuery::new( "MATCH (p:Person)-[:WORKS_FOR]->(c:Company) RETURN p.name, c.company_name", ) - .unwrap() - .with_config(config); + .unwrap() + .with_config(config); let sql = query.to_sql(datasets).await.unwrap(); @@ -190,21 +195,24 @@ async fn test_to_sql_with_relationship() { assert!(sql_upper.contains("SELECT"), "SQL should contain SELECT"); assert!(sql_upper.contains("JOIN"), "SQL should contain JOIN"); assert!(sql_lower.contains("person"), "SQL should reference person"); - assert!(sql_lower.contains("company"), "SQL should reference company"); + assert!( + sql_lower.contains("company"), + "SQL should reference company" + ); println!("Generated SQL with relationship:\n{}", sql); } #[tokio::test] async fn test_to_sql_with_relationship_filter() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .with_node_label("Company", "company_id") .with_relationship("WORKS_FOR", "person_id", "company_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); datasets.insert("Company".to_string(), create_company_table()); datasets.insert("WORKS_FOR".to_string(), create_works_for_table()); @@ -226,15 +234,15 @@ async fn test_to_sql_with_relationship_filter() { #[tokio::test] async fn test_to_sql_with_order_by() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); - let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name, p.age ORDER BY p.age DESC") + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name, p.age ORDER BY p.age DESC") .unwrap() .with_config(config); @@ -252,15 +260,15 @@ async fn test_to_sql_with_order_by() { #[tokio::test] async fn test_to_sql_with_limit() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); - let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name LIMIT 2") + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name LIMIT 2") .unwrap() .with_config(config); @@ -278,15 +286,15 @@ async fn test_to_sql_with_limit() { #[tokio::test] async fn test_to_sql_with_distinct() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); - let query = CypherQuery::new ("MATCH (p:Person) RETURN DISTINCT p.city") + let query = CypherQuery::new("MATCH (p:Person) RETURN DISTINCT p.city") .unwrap() .with_config(config); @@ -302,68 +310,84 @@ async fn test_to_sql_with_distinct() { #[tokio::test] async fn test_to_sql_with_alias() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); - let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name AS person_name, p.age AS person_age") - .unwrap() - .with_config(config); + let query = + CypherQuery::new("MATCH (p:Person) RETURN p.name AS person_name, p.age AS person_age") + .unwrap() + .with_config(config); let sql = query.to_sql(datasets).await.unwrap(); // Verify SQL contains aliases - assert!(sql.contains("AS") || sql.contains("as"), "SQL should contain AS for aliases"); + assert!( + sql.contains("AS") || sql.contains("as"), + "SQL should contain AS for aliases" + ); println!("Generated SQL with aliases:\n{}", sql); } #[tokio::test] async fn test_to_sql_complex_query() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .with_node_label("Company", "company_id") .with_relationship("WORKS_FOR", "person_id", "company_id") .build() .unwrap(); - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); datasets.insert("Company".to_string(), create_company_table()); datasets.insert("WORKS_FOR".to_string(), create_works_for_table()); - let query = CypherQuery::new ( + let query = CypherQuery::new( "MATCH (p:Person)-[w:WORKS_FOR]->(c:Company) \ WHERE p.age > 30 AND c.industry = 'Technology' \ RETURN p.name, c.company_name, w.position \ ORDER BY p.age DESC \ LIMIT 5", ) - .unwrap() - .with_config(config); + .unwrap() + .with_config(config); let sql = query.to_sql(datasets).await.unwrap(); // Verify complex query elements assert!(sql.contains("SELECT"), "SQL should contain SELECT"); - assert!(sql.contains("JOIN") || sql.contains("join"), "SQL should contain JOIN"); - assert!(sql.contains("WHERE") || sql.contains("where"), "SQL should contain WHERE"); - assert!(sql.contains("ORDER BY") || sql.contains("order by"), "SQL should contain ORDER BY"); - assert!(sql.contains("LIMIT") || sql.contains("limit"), "SQL should contain LIMIT"); + assert!( + sql.contains("JOIN") || sql.contains("join"), + "SQL should contain JOIN" + ); + assert!( + sql.contains("WHERE") || sql.contains("where"), + "SQL should contain WHERE" + ); + assert!( + sql.contains("ORDER BY") || sql.contains("order by"), + "SQL should contain ORDER BY" + ); + assert!( + sql.contains("LIMIT") || sql.contains("limit"), + "SQL should contain LIMIT" + ); println!("Generated complex SQL:\n{}", sql); } #[tokio::test] async fn test_to_sql_missing_config() { - let mut datasets = HashMap::new (); + let mut datasets = HashMap::new(); datasets.insert("Person".to_string(), create_person_table()); - let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name").unwrap(); + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name").unwrap(); // Note: No config set let result = query.to_sql(datasets).await; @@ -378,14 +402,14 @@ async fn test_to_sql_missing_config() { #[tokio::test] async fn test_to_sql_empty_datasets() { - let config = GraphConfig:: builder() + let config = GraphConfig::builder() .with_node_label("Person", "person_id") .build() .unwrap(); - let datasets = HashMap::new (); // Empty + let datasets = HashMap::new(); // Empty - let query = CypherQuery::new ("MATCH (p:Person) RETURN p.name") + let query = CypherQuery::new("MATCH (p:Person) RETURN p.name") .unwrap() .with_config(config); @@ -394,7 +418,10 @@ async fn test_to_sql_empty_datasets() { // Should fail with empty datasets assert!(result.is_err(), "to_sql should fail with empty datasets"); assert!( - result.unwrap_err().to_string().contains("No input datasets"), + result + .unwrap_err() + .to_string() + .contains("No input datasets"), "Error should mention missing datasets" ); } From f0f4bf04efbbcfc868c61199a2e8a6df549f9c05 Mon Sep 17 00:00:00 2001 From: JoshuaTang <1240604020@qq.com> Date: Mon, 8 Dec 2025 18:27:44 -0800 Subject: [PATCH 4/4] format code --- python/python/tests/test_to_sql.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/python/tests/test_to_sql.py b/python/python/tests/test_to_sql.py index f758c37..4607b70 100644 --- a/python/python/tests/test_to_sql.py +++ b/python/python/tests/test_to_sql.py @@ -15,7 +15,13 @@ def knowledge_graph_env(): authors_table = pa.table( { "author_id": [1, 2, 3, 4, 5], - "name": ["Alice Chen", "Bob Smith", "Carol Wang", "David Lee", "Eve Martinez"], + "name": [ + "Alice Chen", + "Bob Smith", + "Carol Wang", + "David Lee", + "Eve Martinez", + ], "institution": ["MIT", "Stanford", "CMU", "Berkeley", "MIT"], "h_index": [45, 38, 52, 41, 29], "country": ["USA", "USA", "USA", "USA", "Spain"],