Skip to content

Commit dbd08cc

Browse files
committed
Make the knowledge graph support query with where-in clause
1 parent 53218ac commit dbd08cc

File tree

3 files changed

+136
-6
lines changed

3 files changed

+136
-6
lines changed

python/python/knowledge_graph/main.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,16 @@ def _build_query_prompt(
397397
"Use the schema summary to craft queries that directly answer the "
398398
"question."
399399
),
400+
(
401+
" • Use the schema summary and allowed relationship_type values to "
402+
"identify candidate relationship directions and types."
403+
),
404+
(
405+
" • When the schema lists relationship_type values and the question "
406+
"does not narrow them down, treat the list as exhaustive and include "
407+
"every value in your filter using OR clauses or "
408+
"WHERE rel.relationship_type IN [...]."
409+
),
400410
(
401411
"Always specify node labels and relationship types in MATCH patterns "
402412
"that introduce aliases."
@@ -405,12 +415,25 @@ def _build_query_prompt(
405415
(" • MATCH (e:Entity) to scan entity rows (name, name_lower, entity_id)."),
406416
(
407417
" • MATCH (src:Entity)-[rel:RELATIONSHIP]->(dst:Entity) to traverse "
408-
"relationships (relationship_type column)."
418+
"relationships (relationship_type column); `src` aligns with "
419+
"`source_entity_id` and `dst` with `target_entity_id`."
420+
),
421+
(
422+
" • Decide which node should be `src` versus `dst` based on the "
423+
"relationship meaning in the question and schema hints."
424+
),
425+
(
426+
" • Map natural language roles (team, person, product, etc.) to the "
427+
"`entity_type` column so queries filter to the expected entities."
409428
),
410429
" • Use WHERE e.column = 'value' for node-level filters.",
411430
(
412431
" • Filter relationships with WHERE rel.relationship_type = 'VALUE' "
413-
"or by comparing rel.source_entity_id / rel.target_entity_id."
432+
"or by comparing rel.source_entity_id / rel.target_entity_id; when the "
433+
"question does not name a specific relationship type, include every "
434+
"relevant value from the schema summary using OR clauses or "
435+
"WHERE rel.relationship_type IN [...], explicitly note which values "
436+
"you considered, and avoid emitting only a single guessed type."
414437
),
415438
(
416439
" • Select columns using the aliases you define, such as e.name or "
@@ -421,8 +444,19 @@ def _build_query_prompt(
421444
"filter rel.relationship_type instead of [:TYPE]."
422445
),
423446
(
424-
"Example: MATCH (src:Entity)-[rel:RELATIONSHIP]->(dst:Entity) "
425-
f"WHERE rel.relationship_type = '{example_rel_type}' RETURN rel."
447+
"Example: MATCH (part:Entity)-[rel:RELATIONSHIP]->(whole:Entity) "
448+
f"WHERE rel.relationship_type = '{example_rel_type}' "
449+
"RETURN part.name, whole.name."
450+
),
451+
(
452+
"Example: MATCH (a:Entity)-[rel:RELATIONSHIP]->(b:Entity) WHERE "
453+
"rel.relationship_type = 'TYPE_A' OR rel.relationship_type = 'TYPE_B' "
454+
"RETURN a.name, b.name."
455+
),
456+
(
457+
"Example: MATCH (src:Entity)-[rel:RELATIONSHIP]->(dst:Entity) WHERE "
458+
"rel.relationship_type IN ['TYPE_A', 'TYPE_B', 'TYPE_C'] "
459+
"RETURN src.name, dst.name."
426460
),
427461
(
428462
"Example: MATCH (dst:Entity) WHERE dst.name_lower = 'acme corp' "

rust/lance-graph/src/datafusion_planner.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,15 @@ impl DataFusionPlanner {
264264
right: Box::new(r),
265265
})
266266
}
267+
BE::In { expression, list } => {
268+
use datafusion::logical_expr::expr::InList as DFInList;
269+
let expr = self.to_df_value_expr(expression);
270+
let list_exprs = list
271+
.iter()
272+
.map(|item| self.to_df_value_expr(item))
273+
.collect::<Vec<_>>();
274+
Expr::InList(DFInList::new(Box::new(expr), list_exprs, false))
275+
}
267276
BE::And(l, r) => Expr::BinaryExpr(BinaryExpr {
268277
left: Box::new(self.to_df_boolean_expr(l)),
269278
op: Operator::And,
@@ -334,6 +343,35 @@ mod tests {
334343
)
335344
}
336345

346+
#[test]
347+
fn test_df_boolean_expr_in_list() {
348+
let cfg = crate::config::GraphConfig::builder().build().unwrap();
349+
let planner = DataFusionPlanner::new(cfg);
350+
let expr = BooleanExpression::In {
351+
expression: ValueExpression::Property(PropertyRef {
352+
variable: "rel".into(),
353+
property: "relationship_type".into(),
354+
}),
355+
list: vec![
356+
ValueExpression::Literal(PropertyValue::String("WORKS_FOR".into())),
357+
ValueExpression::Literal(PropertyValue::String("PART_OF".into())),
358+
],
359+
};
360+
361+
if let Expr::InList(in_list) = planner.to_df_boolean_expr(&expr) {
362+
assert!(!in_list.negated);
363+
assert_eq!(in_list.list.len(), 2);
364+
match *in_list.expr {
365+
Expr::Column(ref col_expr) => {
366+
assert_eq!(col_expr.name(), "relationship_type");
367+
}
368+
other => panic!("Expected column expression, got {:?}", other),
369+
}
370+
} else {
371+
panic!("Expected InList expression");
372+
}
373+
}
374+
337375
#[test]
338376
fn test_df_planner_scan_filter_project() {
339377
let scan = LogicalOperator::ScanByLabel {

rust/lance-graph/src/parser.rs

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use nom::{
1313
bytes::complete::{tag, tag_no_case, take_while1},
1414
character::complete::{char, multispace0, multispace1},
1515
combinator::{map, opt, recognize},
16-
multi::{many0, separated_list0},
16+
multi::{many0, separated_list0, separated_list1},
1717
sequence::{delimited, pair, preceded, tuple},
1818
IResult,
1919
};
@@ -316,14 +316,28 @@ fn comparison_expression(input: &str) -> IResult<&str, BooleanExpression> {
316316
let (input, _) = multispace0(input)?;
317317
let (input, left) = value_expression(input)?;
318318
let (input, _) = multispace0(input)?;
319+
let left_clone = left.clone();
320+
321+
if let Ok((input_after_in, (_, _, list))) =
322+
tuple((tag_no_case("IN"), multispace0, value_expression_list))(input)
323+
{
324+
return Ok((
325+
input_after_in,
326+
BooleanExpression::In {
327+
expression: left,
328+
list,
329+
},
330+
));
331+
}
332+
319333
let (input, operator) = comparison_operator(input)?;
320334
let (input, _) = multispace0(input)?;
321335
let (input, right) = value_expression(input)?;
322336

323337
Ok((
324338
input,
325339
BooleanExpression::Comparison {
326-
left,
340+
left: left_clone,
327341
operator,
328342
right,
329343
},
@@ -352,6 +366,17 @@ fn value_expression(input: &str) -> IResult<&str, ValueExpression> {
352366
))(input)
353367
}
354368

369+
fn value_expression_list(input: &str) -> IResult<&str, Vec<ValueExpression>> {
370+
delimited(
371+
tuple((char('['), multispace0)),
372+
separated_list1(
373+
tuple((multispace0, char(','), multispace0)),
374+
value_expression,
375+
),
376+
tuple((multispace0, char(']'))),
377+
)(input)
378+
}
379+
355380
// Parse a property reference: variable.property
356381
fn property_reference(input: &str) -> IResult<&str, PropertyRef> {
357382
let (input, variable) = identifier(input)?;
@@ -726,6 +751,39 @@ mod tests {
726751
}
727752
}
728753

754+
#[test]
755+
fn test_parse_query_with_in_clause() {
756+
let query = "MATCH (src:Entity)-[rel:RELATIONSHIP]->(dst:Entity) WHERE rel.relationship_type IN ['WORKS_FOR', 'PART_OF'] RETURN src.name";
757+
let result = parse_cypher_query(query).unwrap();
758+
759+
let where_clause = result.where_clause.expect("Expected WHERE clause");
760+
match where_clause.expression {
761+
BooleanExpression::In { expression, list } => {
762+
match expression {
763+
ValueExpression::Property(prop_ref) => {
764+
assert_eq!(prop_ref.variable, "rel");
765+
assert_eq!(prop_ref.property, "relationship_type");
766+
}
767+
_ => panic!("Expected property reference in IN expression"),
768+
}
769+
assert_eq!(list.len(), 2);
770+
match &list[0] {
771+
ValueExpression::Literal(PropertyValue::String(val)) => {
772+
assert_eq!(val, "WORKS_FOR");
773+
}
774+
_ => panic!("Expected first list item to be a string literal"),
775+
}
776+
match &list[1] {
777+
ValueExpression::Literal(PropertyValue::String(val)) => {
778+
assert_eq!(val, "PART_OF");
779+
}
780+
_ => panic!("Expected second list item to be a string literal"),
781+
}
782+
}
783+
other => panic!("Expected IN expression, got {:?}", other),
784+
}
785+
}
786+
729787
#[test]
730788
fn test_parse_query_with_limit() {
731789
let query = "MATCH (n:Person) RETURN n.name LIMIT 10";

0 commit comments

Comments
 (0)