Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions rust/lance-graph/src/datafusion_planner/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::ast::{BooleanExpression, PropertyValue, ValueExpression};
use datafusion::logical_expr::{col, lit, BinaryExpr, Expr, Operator};
use datafusion_functions_aggregate::average::avg;
use datafusion_functions_aggregate::count::count;
use datafusion_functions_aggregate::min_max::max;
use datafusion_functions_aggregate::min_max::min;
use datafusion_functions_aggregate::sum::sum;

/// Convert BooleanExpression to DataFusion Expr
Expand Down Expand Up @@ -131,6 +133,22 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr {
lit(0)
}
}
"min" => {
if args.len() == 1 {
let arg_expr = to_df_value_expr(&args[0]);
min(arg_expr)
} else {
lit(0)
}
}
"max" => {
if args.len() == 1 {
let arg_expr = to_df_value_expr(&args[0]);
max(arg_expr)
} else {
lit(0)
}
}
_ => {
// Unsupported function - return placeholder for now
lit(0)
Expand Down Expand Up @@ -555,6 +573,44 @@ mod tests {
assert!(s.contains("p__amount"), "Should contain column reference");
}

/// `to_df_value_expr` should turn a one-argument `MIN` call into a DataFusion
/// min aggregate over the qualified property column.
#[test]
fn test_value_expr_function_min() {
    // Build the argument first so the function node reads top-down.
    let amount = ValueExpression::Property(PropertyRef {
        variable: "p".into(),
        property: "amount".into(),
    });
    // NOTE(review): the name is upper case while the planner's match arm is
    // lower case, so this presumably relies on case-insensitive lookup.
    let call = ValueExpression::Function {
        name: "MIN".into(),
        args: vec![amount],
    };

    // The Debug rendering is inspected because the exact casing of the
    // aggregate name in the output is not pinned down here.
    let rendered = format!("{:?}", to_df_value_expr(&call));
    let mentions_min = ["min", "Min"]
        .iter()
        .any(|needle| rendered.contains(*needle));

    assert!(mentions_min, "Should be MIN function");
    assert!(
        rendered.contains("p__amount"),
        "Should contain column reference"
    );
}

/// `to_df_value_expr` should turn a one-argument `MAX` call into a DataFusion
/// max aggregate over the qualified property column.
#[test]
fn test_value_expr_function_max() {
    // Build the argument first so the function node reads top-down.
    let amount = ValueExpression::Property(PropertyRef {
        variable: "p".into(),
        property: "amount".into(),
    });
    // NOTE(review): the name is upper case while the planner's match arm is
    // lower case, so this presumably relies on case-insensitive lookup.
    let call = ValueExpression::Function {
        name: "MAX".into(),
        args: vec![amount],
    };

    // The Debug rendering is inspected because the exact casing of the
    // aggregate name in the output is not pinned down here.
    let rendered = format!("{:?}", to_df_value_expr(&call));
    let mentions_max = ["max", "Max"]
        .iter()
        .any(|needle| rendered.contains(*needle));

    assert!(mentions_max, "Should be MAX function");
    assert!(
        rendered.contains("p__amount"),
        "Should contain column reference"
    );
}

// ========================================================================
// Unit tests for contains_aggregate()
// ========================================================================
Expand Down Expand Up @@ -588,6 +644,38 @@ mod tests {
);
}

/// A `MIN(...)` function call must be classified as an aggregate.
#[test]
fn test_contains_aggregate_min() {
    let value_property = ValueExpression::Property(PropertyRef {
        variable: "p".into(),
        property: "value".into(),
    });
    let min_call = ValueExpression::Function {
        name: "MIN".into(),
        args: vec![value_property],
    };

    let detected = contains_aggregate(&min_call);
    assert!(detected, "MIN should be detected as aggregate");
}

/// A `MAX(...)` function call must be classified as an aggregate.
#[test]
fn test_contains_aggregate_max() {
    let value_property = ValueExpression::Property(PropertyRef {
        variable: "p".into(),
        property: "value".into(),
    });
    let max_call = ValueExpression::Function {
        name: "MAX".into(),
        args: vec![value_property],
    };

    let detected = contains_aggregate(&max_call);
    assert!(detected, "MAX should be detected as aggregate");
}

#[test]
fn test_contains_aggregate_property() {
let expr = ValueExpression::Property(PropertyRef {
Expand Down
156 changes: 156 additions & 0 deletions rust/lance-graph/tests/test_datafusion_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3123,6 +3123,162 @@ async fn test_avg_without_alias_has_descriptive_name() {
);
}

/// `min(p.age)` without grouping yields a single row holding the smallest age.
#[tokio::test]
async fn test_min_property() {
    // Register the person table under the label used in the MATCH clause.
    let mut datasets = HashMap::new();
    datasets.insert("Person".to_string(), create_person_dataset());

    let config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .build()
        .unwrap();
    let query = CypherQuery::new("MATCH (p:Person) RETURN min(p.age) AS min_age")
        .unwrap()
        .with_config(config);

    let batch = query
        .execute(datasets, Some(ExecutionStrategy::DataFusion))
        .await
        .unwrap();

    // A global aggregate collapses the input to exactly one row.
    assert_eq!(batch.num_rows(), 1);

    let min_age_column = batch.column_by_name("min_age").unwrap();
    let min_ages = min_age_column
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();

    // Ages: 25, 35, 30, 40, 28 => min = 25
    assert_eq!(min_ages.value(0), 25);
}

/// `max(p.age)` without grouping yields a single row holding the largest age.
#[tokio::test]
async fn test_max_property() {
    // Register the person table under the label used in the MATCH clause.
    let mut datasets = HashMap::new();
    datasets.insert("Person".to_string(), create_person_dataset());

    let config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .build()
        .unwrap();
    let query = CypherQuery::new("MATCH (p:Person) RETURN max(p.age) AS max_age")
        .unwrap()
        .with_config(config);

    let batch = query
        .execute(datasets, Some(ExecutionStrategy::DataFusion))
        .await
        .unwrap();

    // A global aggregate collapses the input to exactly one row.
    assert_eq!(batch.num_rows(), 1);

    let max_age_column = batch.column_by_name("max_age").unwrap();
    let max_ages = max_age_column
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();

    // Ages: 25, 35, 30, 40, 28 => max = 40
    assert_eq!(max_ages.value(0), 40);
}

/// `min(p.age)` and `max(p.age)` grouped by `p.city` each produce one row per
/// city, ordered by city.
///
/// The fixture has exactly one person per city (including one NULL city), so
/// within every group the minimum and the maximum are both that person's age.
/// This checks that grouping and ordering work for both aggregates; it cannot
/// tell `min` and `max` apart on its own.
#[tokio::test]
async fn test_min_max_with_grouping() {
    let person_batch = create_person_dataset();
    let config = GraphConfig::builder()
        .with_node_label("Person", "id")
        .build()
        .unwrap();

    let query_min =
        CypherQuery::new("MATCH (p:Person) RETURN p.city, min(p.age) AS min_age ORDER BY p.city")
            .unwrap()
            .with_config(config.clone());

    let query_max =
        CypherQuery::new("MATCH (p:Person) RETURN p.city, max(p.age) AS max_age ORDER BY p.city")
            .unwrap()
            .with_config(config);

    let mut datasets = HashMap::new();
    datasets.insert("Person".to_string(), person_batch);

    let result_min = query_min
        .execute(datasets.clone(), Some(ExecutionStrategy::DataFusion))
        .await
        .unwrap();

    let result_max = query_max
        .execute(datasets, Some(ExecutionStrategy::DataFusion))
        .await
        .unwrap();

    assert_eq!(result_min.num_rows(), 5);
    assert_eq!(result_max.num_rows(), 5);

    // Columns of the min(...) result.
    let min_result_cities = result_min
        .column_by_name("p.city")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let min_result_ages = result_min
        .column_by_name("min_age")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();

    // Columns of the max(...) result.
    let max_result_cities = result_max
        .column_by_name("p.city")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let max_result_ages = result_max
        .column_by_name("max_age")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();

    // Expected (city, age) per output row. With ORDER BY p.city the NULL city
    // sorts first, matching the ordering asserted here for row 0.
    let expected_rows: [(Option<&str>, i64); 5] = [
        (None, 40),                  // David (city is NULL)
        (Some("Chicago"), 30),       // Charlie
        (Some("New York"), 25),      // Alice
        (Some("San Francisco"), 35), // Bob
        (Some("Seattle"), 28),       // Eve
    ];

    for (row, &(city, age)) in expected_rows.iter().enumerate() {
        match city {
            Some(name) => {
                assert_eq!(min_result_cities.value(row), name);
                assert_eq!(max_result_cities.value(row), name);
            }
            None => {
                assert!(min_result_cities.is_null(row));
                assert!(max_result_cities.is_null(row));
            }
        }
        // One person per group, so min and max share the same expected age.
        assert_eq!(min_result_ages.value(row), age);
        assert_eq!(max_result_ages.value(row), age);
    }
}

// ============================================================================
// Disconnected Pattern (Join) Tests
// ============================================================================
Expand Down
Loading