Skip to content

Commit 9c161e5

Browse files
authored
feat(semantic): validate numeric literals (#14)
* feat(semantic): validate numeric literals * format code
1 parent 9422227 commit 9c161e5

File tree

1 file changed

+114
-89
lines changed

1 file changed

+114
-89
lines changed

rust/lance-graph/src/semantic.rs

Lines changed: 114 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,25 @@ impl SemanticAnalyzer {
274274
}
275275
}
276276
ValueExpression::Arithmetic { left, right, .. } => {
277+
// Validate arithmetic operands recursively
277278
self.analyze_value_expression(left)?;
278279
self.analyze_value_expression(right)?;
280+
281+
// If both sides are literals, ensure they are numeric
282+
let is_numeric_literal = |pv: &PropertyValue| {
283+
matches!(pv, PropertyValue::Integer(_) | PropertyValue::Float(_))
284+
};
285+
286+
if let (ValueExpression::Literal(l1), ValueExpression::Literal(l2)) =
287+
(&**left, &**right)
288+
{
289+
if !(is_numeric_literal(l1) && is_numeric_literal(l2)) {
290+
return Err(GraphError::PlanError {
291+
message: "Arithmetic requires numeric literal operands".to_string(),
292+
location: snafu::Location::new(file!(), line!(), column!()),
293+
});
294+
}
295+
}
279296
}
280297
}
281298
Ok(())
@@ -437,6 +454,53 @@ mod tests {
437454
.unwrap()
438455
}
439456

457+
// Helper: analyze a query that only has a single RETURN expression
458+
fn analyze_return_expr(expr: ValueExpression) -> Result<SemanticResult> {
459+
let query = CypherQuery {
460+
match_clauses: vec![],
461+
where_clause: None,
462+
return_clause: ReturnClause {
463+
distinct: false,
464+
items: vec![ReturnItem {
465+
expression: expr,
466+
alias: None,
467+
}],
468+
},
469+
limit: None,
470+
order_by: None,
471+
skip: None,
472+
};
473+
let mut analyzer = SemanticAnalyzer::new(test_config());
474+
analyzer.analyze(&query)
475+
}
476+
477+
// Helper: analyze a query with a single MATCH (var:label) and a RETURN expression
478+
fn analyze_return_with_match(
479+
var: &str,
480+
label: &str,
481+
expr: ValueExpression,
482+
) -> Result<SemanticResult> {
483+
let node = NodePattern::new(Some(var.to_string())).with_label(label);
484+
let query = CypherQuery {
485+
match_clauses: vec![MatchClause {
486+
patterns: vec![GraphPattern::Node(node)],
487+
}],
488+
where_clause: None,
489+
return_clause: ReturnClause {
490+
distinct: false,
491+
items: vec![ReturnItem {
492+
expression: expr,
493+
alias: None,
494+
}],
495+
},
496+
limit: None,
497+
order_by: None,
498+
skip: None,
499+
};
500+
let mut analyzer = SemanticAnalyzer::new(test_config());
501+
analyzer.analyze(&query)
502+
}
503+
440504
#[test]
441505
fn test_merge_node_variable_metadata() {
442506
// MATCH (n:Person {age: 30}), (n:Employee {dept: "X"})
@@ -757,30 +821,12 @@ mod tests {
757821

758822
#[test]
759823
fn test_function_argument_undefined_variable_in_return() {
760-
// MATCH (n:Person) RETURN toUpper(m.name)
761-
let node = NodePattern::new(Some("n".to_string())).with_label("Person");
762-
let query = CypherQuery {
763-
match_clauses: vec![MatchClause {
764-
patterns: vec![GraphPattern::Node(node)],
765-
}],
766-
where_clause: None,
767-
return_clause: ReturnClause {
768-
distinct: false,
769-
items: vec![ReturnItem {
770-
expression: ValueExpression::Function {
771-
name: "toUpper".to_string(),
772-
args: vec![ValueExpression::Property(PropertyRef::new("m", "name"))],
773-
},
774-
alias: None,
775-
}],
776-
},
777-
limit: None,
778-
order_by: None,
779-
skip: None,
824+
// RETURN toUpper(m.name)
825+
let expr = ValueExpression::Function {
826+
name: "toUpper".to_string(),
827+
args: vec![ValueExpression::Property(PropertyRef::new("m", "name"))],
780828
};
781-
782-
let mut analyzer = SemanticAnalyzer::new(test_config());
783-
let result = analyzer.analyze(&query).unwrap();
829+
let result = analyze_return_expr(expr).unwrap();
784830
assert!(result
785831
.errors
786832
.iter()
@@ -790,56 +836,23 @@ mod tests {
790836
#[test]
791837
fn test_function_argument_valid_variable_ok() {
792838
// MATCH (n:Person) RETURN toUpper(n.name)
793-
let node = NodePattern::new(Some("n".to_string())).with_label("Person");
794-
let query = CypherQuery {
795-
match_clauses: vec![MatchClause {
796-
patterns: vec![GraphPattern::Node(node)],
797-
}],
798-
where_clause: None,
799-
return_clause: ReturnClause {
800-
distinct: false,
801-
items: vec![ReturnItem {
802-
expression: ValueExpression::Function {
803-
name: "toUpper".to_string(),
804-
args: vec![ValueExpression::Property(PropertyRef::new("n", "name"))],
805-
},
806-
alias: None,
807-
}],
808-
},
809-
limit: None,
810-
order_by: None,
811-
skip: None,
839+
let expr = ValueExpression::Function {
840+
name: "toUpper".to_string(),
841+
args: vec![ValueExpression::Property(PropertyRef::new("n", "name"))],
812842
};
813-
814-
let mut analyzer = SemanticAnalyzer::new(test_config());
815-
let result = analyzer.analyze(&query).unwrap();
843+
let result = analyze_return_with_match("n", "Person", expr).unwrap();
816844
assert!(result.errors.is_empty());
817845
}
818846

819847
#[test]
820848
fn test_arithmetic_with_undefined_variable_in_return() {
821849
// RETURN x + 1
822-
let query = CypherQuery {
823-
match_clauses: vec![],
824-
where_clause: None,
825-
return_clause: ReturnClause {
826-
distinct: false,
827-
items: vec![ReturnItem {
828-
expression: ValueExpression::Arithmetic {
829-
left: Box::new(ValueExpression::Variable("x".to_string())),
830-
operator: ArithmeticOperator::Add,
831-
right: Box::new(ValueExpression::Literal(PropertyValue::Integer(1))),
832-
},
833-
alias: None,
834-
}],
835-
},
836-
limit: None,
837-
order_by: None,
838-
skip: None,
850+
let expr = ValueExpression::Arithmetic {
851+
left: Box::new(ValueExpression::Variable("x".to_string())),
852+
operator: ArithmeticOperator::Add,
853+
right: Box::new(ValueExpression::Literal(PropertyValue::Integer(1))),
839854
};
840-
841-
let mut analyzer = SemanticAnalyzer::new(test_config());
842-
let result = analyzer.analyze(&query).unwrap();
855+
let result = analyze_return_expr(expr).unwrap();
843856
assert!(result
844857
.errors
845858
.iter()
@@ -848,35 +861,47 @@ mod tests {
848861

849862
#[test]
850863
fn test_arithmetic_with_defined_property_ok() {
851-
// MATCH (n:Person) RETURN 1 + n.age
852-
let node = NodePattern::new(Some("n".to_string())).with_label("Person");
853-
let query = CypherQuery {
854-
match_clauses: vec![MatchClause {
855-
patterns: vec![GraphPattern::Node(node)],
856-
}],
857-
where_clause: None,
858-
return_clause: ReturnClause {
859-
distinct: false,
860-
items: vec![ReturnItem {
861-
expression: ValueExpression::Arithmetic {
862-
left: Box::new(ValueExpression::Literal(PropertyValue::Integer(1))),
863-
operator: ArithmeticOperator::Add,
864-
right: Box::new(ValueExpression::Property(PropertyRef::new("n", "age"))),
865-
},
866-
alias: None,
867-
}],
868-
},
869-
limit: None,
870-
order_by: None,
871-
skip: None,
864+
let expr = ValueExpression::Arithmetic {
865+
left: Box::new(ValueExpression::Literal(PropertyValue::Integer(1))),
866+
operator: ArithmeticOperator::Add,
867+
right: Box::new(ValueExpression::Property(PropertyRef::new("n", "age"))),
872868
};
873-
874-
let mut analyzer = SemanticAnalyzer::new(test_config());
875-
let result = analyzer.analyze(&query).unwrap();
869+
let result = analyze_return_with_match("n", "Person", expr).unwrap();
876870
// Should not report undefined variable 'n'
877871
assert!(result
878872
.errors
879873
.iter()
880874
.all(|e| !e.contains("Undefined variable: 'n'")));
881875
}
876+
877+
#[test]
878+
fn test_arithmetic_with_non_numeric_literal_error() {
879+
// RETURN "x" + 1
880+
let expr = ValueExpression::Arithmetic {
881+
left: Box::new(ValueExpression::Literal(PropertyValue::String(
882+
"x".to_string(),
883+
))),
884+
operator: ArithmeticOperator::Add,
885+
right: Box::new(ValueExpression::Literal(PropertyValue::Integer(1))),
886+
};
887+
let result = analyze_return_expr(expr).unwrap();
888+
// The semantic analyzer returns Ok with errors collected in the result
889+
assert!(result
890+
.errors
891+
.iter()
892+
.any(|e| e.contains("Arithmetic requires numeric literal operands")));
893+
}
894+
895+
#[test]
896+
fn test_arithmetic_with_numeric_literals_ok() {
897+
// RETURN 1 + 2.0
898+
let expr = ValueExpression::Arithmetic {
899+
left: Box::new(ValueExpression::Literal(PropertyValue::Integer(1))),
900+
operator: ArithmeticOperator::Add,
901+
right: Box::new(ValueExpression::Literal(PropertyValue::Float(2.0))),
902+
};
903+
let result = analyze_return_expr(expr);
904+
assert!(result.is_ok(), "Expected Ok but got {:?}", result);
905+
assert!(result.unwrap().errors.is_empty());
906+
}
882907
}

0 commit comments

Comments
 (0)