diff --git a/rust/lance-graph/Cargo.lock b/rust/lance-graph/Cargo.lock index 6592e65..1152f1c 100644 --- a/rust/lance-graph/Cargo.lock +++ b/rust/lance-graph/Cargo.lock @@ -3267,6 +3267,7 @@ dependencies = [ "datafusion", "datafusion-common", "datafusion-expr", + "datafusion-functions-aggregate", "datafusion-sql", "futures", "lance", diff --git a/rust/lance-graph/Cargo.toml b/rust/lance-graph/Cargo.toml index a497198..a3f37d2 100644 --- a/rust/lance-graph/Cargo.toml +++ b/rust/lance-graph/Cargo.toml @@ -26,6 +26,7 @@ datafusion = { version = "49.0.2", default-features = false, features = [ datafusion-common = "49.0.2" datafusion-expr = "49.0.2" datafusion-sql = "49.0.2" +datafusion-functions-aggregate = "49.0.2" lance-core = "0.37.0" nom = "7.1" serde = { version = "1", features = ["derive"] } diff --git a/rust/lance-graph/src/datafusion_planner.rs b/rust/lance-graph/src/datafusion_planner.rs index 162b7c7..98ee913 100644 --- a/rust/lance-graph/src/datafusion_planner.rs +++ b/rust/lance-graph/src/datafusion_planner.rs @@ -21,6 +21,7 @@ use crate::source_catalog::GraphSourceCatalog; use datafusion::logical_expr::{ col, lit, BinaryExpr, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, }; +use datafusion_functions_aggregate::count::count; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -359,7 +360,7 @@ impl DataFusionPlanner { } => self.build_scan(ctx, variable, label, properties), LogicalOperator::Filter { input, predicate } => { let input_plan = self.build_operator(ctx, input)?; - let expr = self.to_df_boolean_expr(predicate); + let expr = Self::to_df_boolean_expr(predicate); LogicalPlanBuilder::from(input_plan) .filter(expr) .map_err(|e| self.plan_error("Failed to build filter", e))? 
@@ -368,25 +369,88 @@ impl DataFusionPlanner { } LogicalOperator::Project { input, projections } => { let input_plan = self.build_operator(ctx, input)?; - let exprs: Vec = projections + + // Check if any projection contains an aggregate function + let has_aggregates = projections .iter() - .map(|p| { - let expr = self.to_df_value_expr(&p.expression); - // Apply alias if provided, otherwise use Cypher dot notation - if let Some(alias) = &p.alias { - expr.alias(alias) + .any(|p| Self::contains_aggregate(&p.expression)); + + if has_aggregates { + // Build aggregate plan + // Separate group expressions (non-aggregates) from aggregate expressions + let mut group_exprs = Vec::new(); + let mut agg_exprs = Vec::new(); + // Store computed aliases for aggregates to reuse in final projection + let mut agg_aliases = Vec::new(); + + for p in projections { + let expr = Self::to_df_value_expr(&p.expression); + + if Self::contains_aggregate(&p.expression) { + // Aggregate expressions get aliased + let alias = if let Some(alias) = &p.alias { + alias.clone() + } else { + self.to_cypher_column_name(&p.expression) + }; + agg_exprs.push(expr.alias(&alias)); + agg_aliases.push(alias); } else { - // Convert to Cypher dot notation (e.g., p__name -> p.name) - let cypher_name = self.to_cypher_column_name(&p.expression); - expr.alias(cypher_name) + // Group expressions: use raw expression for grouping, no alias + group_exprs.push(expr); } - }) - .collect(); - LogicalPlanBuilder::from(input_plan) - .project(exprs) - .map_err(|e| self.plan_error("Failed to build projection", e))? 
- .build() - .map_err(|e| self.plan_error("Failed to build plan", e)) + } + + // After aggregation, add a projection to apply aliases to group columns + let mut final_projection = Vec::new(); + let mut agg_idx = 0; + for p in projections { + if !Self::contains_aggregate(&p.expression) { + // Re-create the expression and apply alias + let expr = Self::to_df_value_expr(&p.expression); + let aliased = if let Some(alias) = &p.alias { + expr.alias(alias) + } else { + let cypher_name = self.to_cypher_column_name(&p.expression); + expr.alias(cypher_name) + }; + final_projection.push(aliased); + } else { + // For aggregates, reference the column using the same alias we computed earlier + final_projection.push(col(&agg_aliases[agg_idx])); + agg_idx += 1; + } + } + + LogicalPlanBuilder::from(input_plan) + .aggregate(group_exprs, agg_exprs) + .map_err(|e| self.plan_error("Failed to build aggregate", e))? + .project(final_projection) + .map_err(|e| self.plan_error("Failed to project after aggregate", e))? + .build() + .map_err(|e| self.plan_error("Failed to build plan", e)) + } else { + // Regular projection + let exprs: Vec = projections + .iter() + .map(|p| { + let expr = Self::to_df_value_expr(&p.expression); + // Apply alias if provided, otherwise use Cypher dot notation + if let Some(alias) = &p.alias { + expr.alias(alias) + } else { + // Convert to Cypher dot notation (e.g., p__name -> p.name) + let cypher_name = self.to_cypher_column_name(&p.expression); + expr.alias(cypher_name) + } + }) + .collect(); + LogicalPlanBuilder::from(input_plan) + .project(exprs) + .map_err(|e| self.plan_error("Failed to build projection", e))? 
+ .build() + .map_err(|e| self.plan_error("Failed to build plan", e)) + } } LogicalOperator::Distinct { input } => { let input_plan = self.build_operator(ctx, input)?; @@ -405,7 +469,7 @@ impl DataFusionPlanner { let sort_exprs: Vec = sort_items .iter() .map(|item| { - let expr = self.to_df_value_expr(&item.expression); + let expr = Self::to_df_value_expr(&item.expression); let asc = matches!(item.direction, crate::ast::SortDirection::Ascending); SortExpr { expr, @@ -513,8 +577,9 @@ impl DataFusionPlanner { let filter_exprs: Vec = properties .iter() .map(|(k, v)| { - let lit_expr = self - .to_df_value_expr(&crate::ast::ValueExpression::Literal(v.clone())); + let lit_expr = Self::to_df_value_expr( + &crate::ast::ValueExpression::Literal(v.clone()), + ); Expr::BinaryExpr(BinaryExpr { left: Box::new(col(k)), op: Operator::Eq, @@ -683,7 +748,7 @@ impl DataFusionPlanner { // Apply relationship property filters (e.g., -[r {since: 2020}]->) for (k, v) in relationship_properties.iter() { - let lit_expr = self.to_df_value_expr(&crate::ast::ValueExpression::Literal(v.clone())); + let lit_expr = Self::to_df_value_expr(&crate::ast::ValueExpression::Literal(v.clone())); let filter_expr = Expr::BinaryExpr(BinaryExpr { left: Box::new(col(k)), op: Operator::Eq, @@ -776,7 +841,7 @@ impl DataFusionPlanner { // Apply target property filters (e.g., (b {age: 30})) for (k, v) in params.target_properties.iter() { - let lit_expr = self.to_df_value_expr(&crate::ast::ValueExpression::Literal(v.clone())); + let lit_expr = Self::to_df_value_expr(&crate::ast::ValueExpression::Literal(v.clone())); let filter_expr = Expr::BinaryExpr(BinaryExpr { left: Box::new(col(k)), op: Operator::Eq, @@ -1313,7 +1378,7 @@ impl DataFusionPlanner { // Apply target property filters for (k, v) in target_properties.iter() { - let lit_expr = self.to_df_value_expr(&crate::ast::ValueExpression::Literal(v.clone())); + let lit_expr = Self::to_df_value_expr(&crate::ast::ValueExpression::Literal(v.clone())); let 
filter_expr = Expr::BinaryExpr(BinaryExpr { left: Box::new(col(k)), op: Operator::Eq, @@ -1382,7 +1447,7 @@ impl DataFusionPlanner { // Expression Translators // ============================================================================ - fn to_df_boolean_expr(&self, expr: &crate::ast::BooleanExpression) -> Expr { + fn to_df_boolean_expr(expr: &crate::ast::BooleanExpression) -> Expr { use crate::ast::{BooleanExpression as BE, ComparisonOperator as CO}; match expr { BE::Comparison { @@ -1390,8 +1455,8 @@ impl DataFusionPlanner { operator, right, } => { - let l = self.to_df_value_expr(left); - let r = self.to_df_value_expr(right); + let l = Self::to_df_value_expr(left); + let r = Self::to_df_value_expr(right); let op = match operator { CO::Equal => Operator::Eq, CO::NotEqual => Operator::NotEq, @@ -1408,32 +1473,30 @@ impl DataFusionPlanner { } BE::In { expression, list } => { use datafusion::logical_expr::expr::InList as DFInList; - let expr = self.to_df_value_expr(expression); - let list_exprs = list - .iter() - .map(|item| self.to_df_value_expr(item)) - .collect::>(); + let expr = Self::to_df_value_expr(expression); + let list_exprs = list.iter().map(Self::to_df_value_expr).collect::>(); Expr::InList(DFInList::new(Box::new(expr), list_exprs, false)) } BE::And(l, r) => Expr::BinaryExpr(BinaryExpr { - left: Box::new(self.to_df_boolean_expr(l)), + left: Box::new(Self::to_df_boolean_expr(l)), op: Operator::And, - right: Box::new(self.to_df_boolean_expr(r)), + right: Box::new(Self::to_df_boolean_expr(r)), }), BE::Or(l, r) => Expr::BinaryExpr(BinaryExpr { - left: Box::new(self.to_df_boolean_expr(l)), + left: Box::new(Self::to_df_boolean_expr(l)), op: Operator::Or, - right: Box::new(self.to_df_boolean_expr(r)), + right: Box::new(Self::to_df_boolean_expr(r)), }), - BE::Not(inner) => Expr::Not(Box::new(self.to_df_boolean_expr(inner))), - BE::Exists(prop) => Expr::IsNotNull(Box::new( - self.to_df_value_expr(&crate::ast::ValueExpression::Property(prop.clone())), - )), + 
BE::Not(inner) => Expr::Not(Box::new(Self::to_df_boolean_expr(inner))), + BE::Exists(prop) => Expr::IsNotNull(Box::new(Self::to_df_value_expr( + &crate::ast::ValueExpression::Property(prop.clone()), + ))), _ => lit(true), } } - fn to_df_value_expr(&self, expr: &crate::ast::ValueExpression) -> Expr { + /// Convert ValueExpression to DataFusion Expr + fn to_df_value_expr(expr: &crate::ast::ValueExpression) -> Expr { use crate::ast::{PropertyValue as PV, ValueExpression as VE}; match expr { VE::Property(prop) => { @@ -1455,7 +1518,56 @@ impl DataFusionPlanner { let qualified_name = format!("{}__{}", prop.variable, prop.property); col(&qualified_name) } - VE::Function { .. } | VE::Arithmetic { .. } => lit(0), + VE::Function { name, args } => { + // Handle aggregation functions + match name.to_lowercase().as_str() { + "count" => { + if args.len() == 1 { + // Check for COUNT(*) + let arg_expr = if let VE::Variable(v) = &args[0] { + if v == "*" { + lit(1) + } else { + Self::to_df_value_expr(&args[0]) + } + } else { + Self::to_df_value_expr(&args[0]) + }; + + // Use DataFusion's count helper function + count(arg_expr) + } else { + // Invalid argument count - return placeholder + lit(0) + } + } + _ => { + // Unsupported function - return placeholder for now + lit(0) + } + } + } + VE::Arithmetic { .. } => lit(0), + } + } + + /// Check if a ValueExpression contains an aggregate function + fn contains_aggregate(expr: &crate::ast::ValueExpression) -> bool { + use crate::ast::ValueExpression as VE; + match expr { + VE::Function { name, args } => { + // Check if this is an aggregate function + let is_aggregate = matches!( + name.to_lowercase().as_str(), + "count" | "sum" | "avg" | "min" | "max" + ); + // Also check arguments recursively + is_aggregate || args.iter().any(Self::contains_aggregate) + } + VE::Arithmetic { left, right, .. 
} => { + Self::contains_aggregate(left) || Self::contains_aggregate(right) + } + _ => false, } } @@ -1463,6 +1575,7 @@ impl DataFusionPlanner { /// /// This generates user-friendly column names following Cypher conventions: /// - Property references: `p.name` (variable.property) + /// - Functions: `function_name(arg)` with simplified argument representation /// - Other expressions: Use the expression as-is /// /// This is used when no explicit alias is provided in RETURN clauses. @@ -1478,8 +1591,24 @@ impl DataFusionPlanner { // Handle nested property references format!("{}.{}", prop.variable, prop.property) } + VE::Function { name, args } => { + // Generate descriptive function name: count(*), count(p.name), etc. + if args.len() == 1 { + let arg_repr = match &args[0] { + VE::Variable(v) => v.clone(), + VE::Property(prop) => format!("{}.{}", prop.variable, prop.property), + _ => "expr".to_string(), + }; + format!("{}({})", name.to_lowercase(), arg_repr) + } else if args.is_empty() { + format!("{}()", name.to_lowercase()) + } else { + // Multiple args - just use function name + name.to_lowercase() + } + } _ => { - // For other expressions (literals, functions), use a generic name + // For other expressions (literals, arithmetic), use a generic name // In practice, these should always have explicit aliases "expr".to_string() } @@ -1522,8 +1651,6 @@ mod tests { #[test] fn test_df_boolean_expr_in_list() { - let cfg = crate::config::GraphConfig::builder().build().unwrap(); - let planner = DataFusionPlanner::new(cfg); let expr = BooleanExpression::In { expression: ValueExpression::Property(PropertyRef { variable: "rel".into(), @@ -1535,7 +1662,7 @@ mod tests { ], }; - if let Expr::InList(in_list) = planner.to_df_boolean_expr(&expr) { + if let Expr::InList(in_list) = DataFusionPlanner::to_df_boolean_expr(&expr) { assert!(!in_list.negated); assert_eq!(in_list.list.len(), 2); match *in_list.expr { diff --git a/rust/lance-graph/src/parser.rs b/rust/lance-graph/src/parser.rs 
index dc90f7b..813b457 100644 --- a/rust/lance-graph/src/parser.rs +++ b/rust/lance-graph/src/parser.rs @@ -360,12 +360,56 @@ fn comparison_operator(input: &str) -> IResult<&str, ComparisonOperator> { // Parse a value expression fn value_expression(input: &str) -> IResult<&str, ValueExpression> { alt(( + function_call, map(property_reference, ValueExpression::Property), map(property_value, ValueExpression::Literal), map(identifier, |id| ValueExpression::Variable(id.to_string())), ))(input) } +// Parse a function call: function_name(args) +fn function_call(input: &str) -> IResult<&str, ValueExpression> { + let (input, name) = identifier(input)?; + let (input, _) = multispace0(input)?; + let (input, _) = char('(')(input)?; + let (input, _) = multispace0(input)?; + + // Handle COUNT(*) special case - only allow * for COUNT function + if let Ok((input_after_star, _)) = char::<_, nom::error::Error<&str>>('*')(input) { + // Validate that this is COUNT function + if name.to_lowercase() == "count" { + let (input, _) = multispace0(input_after_star)?; + let (input, _) = char(')')(input)?; + return Ok(( + input, + ValueExpression::Function { + name: name.to_string(), + args: vec![ValueExpression::Variable("*".to_string())], + }, + )); + } else { + // Not COUNT - fail parsing to try regular argument parsing + // This will naturally fail since * is not a valid value_expression + } + } + + // Parse regular function arguments + let (input, args) = separated_list0( + tuple((multispace0, char(','), multispace0)), + value_expression, + )(input)?; + let (input, _) = multispace0(input)?; + let (input, _) = char(')')(input)?; + + Ok(( + input, + ValueExpression::Function { + name: name.to_string(), + args, + }, + )) +} + fn value_expression_list(input: &str) -> IResult<&str, Vec> { delimited( tuple((char('['), multispace0)), @@ -828,4 +872,80 @@ mod tests { assert_eq!(result.limit, Some(10)); assert!(result.order_by.is_some()); } + + #[test] + fn test_parse_count_star() { + let query = 
"MATCH (n:Person) RETURN count(*) AS total"; + let result = parse_cypher_query(query).unwrap(); + + assert_eq!(result.return_clause.items.len(), 1); + let item = &result.return_clause.items[0]; + assert_eq!(item.alias, Some("total".to_string())); + + match &item.expression { + ValueExpression::Function { name, args } => { + assert_eq!(name, "count"); + assert_eq!(args.len(), 1); + match &args[0] { + ValueExpression::Variable(v) => assert_eq!(v, "*"), + _ => panic!("Expected Variable(*) in count(*)"), + } + } + _ => panic!("Expected Function expression"), + } + } + + #[test] + fn test_parse_count_property() { + let query = "MATCH (n:Person) RETURN count(n.age)"; + let result = parse_cypher_query(query).unwrap(); + + assert_eq!(result.return_clause.items.len(), 1); + let item = &result.return_clause.items[0]; + + match &item.expression { + ValueExpression::Function { name, args } => { + assert_eq!(name, "count"); + assert_eq!(args.len(), 1); + match &args[0] { + ValueExpression::Property(prop) => { + assert_eq!(prop.variable, "n"); + assert_eq!(prop.property, "age"); + } + _ => panic!("Expected Property in count(n.age)"), + } + } + _ => panic!("Expected Function expression"), + } + } + + #[test] + fn test_parse_non_count_function_rejects_star() { + // FOO(*) should fail to parse since * is only allowed for COUNT + let query = "MATCH (n:Person) RETURN foo(*)"; + let result = parse_cypher_query(query); + assert!(result.is_err(), "foo(*) should not parse successfully"); + } + + #[test] + fn test_parse_count_with_multiple_args() { + // COUNT with multiple arguments parses successfully + // but will be rejected during semantic validation + let query = "MATCH (n:Person) RETURN count(n.age, n.name)"; + let result = parse_cypher_query(query); + assert!( + result.is_ok(), + "Parser should accept multiple args (validation happens in semantic phase)" + ); + + // Verify the AST structure + let ast = result.unwrap(); + match &ast.return_clause.items[0].expression { + 
ValueExpression::Function { name, args } => { + assert_eq!(name, "count"); + assert_eq!(args.len(), 2); + } + _ => panic!("Expected Function expression"), + } + } } diff --git a/rust/lance-graph/src/semantic.rs b/rust/lance-graph/src/semantic.rs index c7b6fed..1008204 100644 --- a/rust/lance-graph/src/semantic.rs +++ b/rust/lance-graph/src/semantic.rs @@ -268,7 +268,27 @@ impl SemanticAnalyzer { }); } } - ValueExpression::Function { args, .. } => { + ValueExpression::Function { name, args } => { + // Validate function-specific arity and signature rules + match name.to_lowercase().as_str() { + "count" | "sum" | "avg" | "min" | "max" => { + if args.len() != 1 { + return Err(GraphError::PlanError { + message: format!( + "{} requires exactly 1 argument, got {}", + name.to_uppercase(), + args.len() + ), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + } + _ => { + // Other functions - no validation yet + } + } + + // Validate arguments recursively for arg in args { self.analyze_value_expression(arg)?; } @@ -874,6 +894,63 @@ mod tests { .all(|e| !e.contains("Undefined variable: 'n'"))); } + #[test] + fn test_count_with_multiple_args_fails_validation() { + // COUNT(n.age, n.name) should fail semantic validation + let expr = ValueExpression::Function { + name: "count".to_string(), + args: vec![ + ValueExpression::Property(PropertyRef::new("n", "age")), + ValueExpression::Property(PropertyRef::new("n", "name")), + ], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result + .errors + .iter() + .any(|e| e.contains("COUNT requires exactly 1 argument")), + "Expected error about COUNT arity, got: {:?}", + result.errors + ); + } + + #[test] + fn test_count_with_zero_args_fails_validation() { + // COUNT() with no arguments should fail + let expr = ValueExpression::Function { + name: "count".to_string(), + args: vec![], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result 
+ .errors + .iter() + .any(|e| e.contains("COUNT requires exactly 1 argument")), + "Expected error about COUNT arity, got: {:?}", + result.errors + ); + } + + #[test] + fn test_count_with_one_arg_passes_validation() { + // COUNT(n.age) should pass validation + let expr = ValueExpression::Function { + name: "count".to_string(), + args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result + .errors + .iter() + .all(|e| !e.contains("COUNT requires exactly 1 argument")), + "COUNT with 1 arg should not produce arity error, got: {:?}", + result.errors + ); + } + #[test] fn test_arithmetic_with_non_numeric_literal_error() { // RETURN "x" + 1 diff --git a/rust/lance-graph/tests/test_datafusion_pipeline.rs b/rust/lance-graph/tests/test_datafusion_pipeline.rs index a9b036a..9acc826 100644 --- a/rust/lance-graph/tests/test_datafusion_pipeline.rs +++ b/rust/lance-graph/tests/test_datafusion_pipeline.rs @@ -2414,3 +2414,199 @@ async fn test_datafusion_varlength_count() { // Alice can reach 4 people within 2 hops assert_eq!(out.num_rows(), 4); } + +// ============================================================================ +// Aggregation Function Tests +// ============================================================================ + +#[tokio::test] +async fn test_count_star_all_nodes() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (a:Person) RETURN count(*) AS total") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + assert_eq!(result.num_rows(), 1); + let count_col = result + .column_by_name("total") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + 
assert_eq!(count_col.value(0), 5); +} + +#[tokio::test] +async fn test_count_with_filter() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = + CypherQuery::new("MATCH (a:Person) WHERE a.age > 30 RETURN count(*) AS older_than_30") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + assert_eq!(result.num_rows(), 1); + let count_col = result + .column_by_name("older_than_30") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + // Bob (35) and David (40) are older than 30 + assert_eq!(count_col.value(0), 2); +} + +#[tokio::test] +async fn test_count_property() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN count(p.name) AS person_count") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + assert_eq!(result.num_rows(), 1); + let count_col = result + .column_by_name("person_count") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(count_col.value(0), 5); +} + +#[tokio::test] +async fn test_count_with_grouping() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = + CypherQuery::new("MATCH (p:Person) RETURN p.city, count(*) AS count ORDER BY p.city") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + // Should have 5 groups: NULL
(David), Chicago, New York, San Francisco, Seattle + assert_eq!(result.num_rows(), 5); + + let city_col = result + .column_by_name("p.city") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let count_col = result + .column_by_name("count") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // NULL city: 1 person (David) + assert!(city_col.is_null(0)); + assert_eq!(count_col.value(0), 1); + + // Chicago: 1 person (Charlie) + assert_eq!(city_col.value(1), "Chicago"); + assert_eq!(count_col.value(1), 1); + + // New York: 1 person (Alice) + assert_eq!(city_col.value(2), "New York"); + assert_eq!(count_col.value(2), 1); + + // San Francisco: 1 person (Bob) + assert_eq!(city_col.value(3), "San Francisco"); + assert_eq!(count_col.value(3), 1); + + // Seattle: 1 person (Eve) + assert_eq!(city_col.value(4), "Seattle"); + assert_eq!(count_col.value(4), 1); +} + +#[tokio::test] +async fn test_count_without_alias_has_descriptive_name() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN count(*)") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + assert_eq!(result.num_rows(), 1); + // Should have column named "count(*)" not "expr" or "count" + let count_col = result.column_by_name("count(*)"); + assert!( + count_col.is_some(), + "Expected column named 'count(*)' but schema is: {:?}", + result.schema() + ); +} + +#[tokio::test] +async fn test_count_property_without_alias_has_descriptive_name() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN count(p.name)") + .unwrap() + .with_config(config); + + let mut 
datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + assert_eq!(result.num_rows(), 1); + // Should have column named "count(p.name)" not "expr" + let count_col = result.column_by_name("count(p.name)"); + assert!( + count_col.is_some(), + "Expected column named 'count(p.name)' but schema is: {:?}", + result.schema() + ); +}