Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery {

// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = Arc::unwrap_or_clone(filter.input);
let original_schema = cur_input.schema().columns();
for subquery_expr in with_subqueries {
match extract_subquery_info(subquery_expr) {
// The subquery expression is at the top level of the filter
Expand All @@ -115,6 +116,13 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
cur_input = LogicalPlan::Filter(new_filter);
}

if cur_input.schema().fields().len() != original_schema.len() {
cur_input = LogicalPlanBuilder::from(cur_input)
.project(original_schema.into_iter().map(Expr::from))?
.build()?;
}

Ok(Transformed::yes(cur_input))
}

Expand Down Expand Up @@ -1736,13 +1744,14 @@ mod tests {
plan,
@r"
Projection: customer.c_custkey [c_custkey:Int64]
Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
TableScan: customer [c_custkey:Int64, c_name:Utf8]
SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
Projection: orders.o_custkey [o_custkey:Int64]
Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
TableScan: customer [c_custkey:Int64, c_name:Utf8]
SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
Projection: orders.o_custkey [o_custkey:Int64]
Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
"
)
}
Expand Down
86 changes: 79 additions & 7 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,9 @@ fn optimize_projections(
}
LogicalPlan::Join(join) => {
let left_len = join.left.schema().fields().len();
let right_len = join.right.schema().fields().len();
let (left_req_indices, right_req_indices) =
split_join_requirements(left_len, indices, &join.join_type);
split_join_requirements(left_len, right_len, indices, &join.join_type);
let left_indices =
left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
let right_indices =
Expand Down Expand Up @@ -746,6 +747,7 @@ fn outer_columns_helper_multi<'a, 'b>(
/// # Parameters
///
/// * `left_len` - The length of the left child.
/// * `right_len` - The length of the right child.
/// * `indices` - A slice of requirement indices.
/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`).
///
Expand All @@ -757,21 +759,29 @@ fn outer_columns_helper_multi<'a, 'b>(
/// adjusted based on the join type.
fn split_join_requirements(
left_len: usize,
right_len: usize,
indices: RequiredIndices,
join_type: &JoinType,
) -> (RequiredIndices, RequiredIndices) {
match join_type {
// In these cases requirements are split between left/right children:
JoinType::Inner
| JoinType::Left
| JoinType::Right
| JoinType::Full
| JoinType::LeftMark
| JoinType::RightMark => {
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
// Decrease right side indices by `left_len` so that they point to valid
// positions within the right child:
indices.split_off(left_len)
}
JoinType::LeftMark => {
// LeftMark output: [left_cols(0..left_len), mark]
// The mark column is synthetic (produced by the join itself),
// so discard it and route only to the left child.
let (left_indices, _mark) = indices.split_off(left_len);
(left_indices, RequiredIndices::new())
}
JoinType::RightMark => {
// Same as LeftMark, but for the right child.
let (right_indices, _mark) = indices.split_off(right_len);
(RequiredIndices::new(), right_indices)
}
// All requirements can be re-routed to left child directly.
JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
// All requirements can be re-routed to right side directly.
Expand Down Expand Up @@ -2311,6 +2321,68 @@ mod tests {
)
}

// Regression test for https://github.com/apache/datafusion/issues/20083
// Optimizer must not fail when LeftMark joins from EXISTS OR EXISTS
// feed into a Left join.
#[test]
fn optimize_projections_exists_or_exists_with_outer_join() -> Result<()> {
use datafusion_expr::utils::disjunction;
use datafusion_expr::{exists, out_ref_col};

let table_a = test_table_scan_with_name("a")?;
let table_b = test_table_scan_with_name("b")?;

let sq_a = Arc::new(
LogicalPlanBuilder::from(test_table_scan_with_name("sq_a")?)
.filter(col("sq_a.a").eq(out_ref_col(DataType::UInt32, "a.a")))?
.project(vec![lit(1)])?
.build()?,
);

let sq_b = Arc::new(
LogicalPlanBuilder::from(test_table_scan_with_name("sq_b")?)
.filter(col("sq_b.b").eq(out_ref_col(DataType::UInt32, "a.b")))?
.project(vec![lit(1)])?
.build()?,
);

let plan = LogicalPlanBuilder::from(table_a)
.filter(disjunction(vec![exists(sq_a), exists(sq_b)]).unwrap())?
.join(table_b, JoinType::Left, (vec!["a"], vec!["a"]), None)?
.build()?;

let optimizer = Optimizer::new();
let config = OptimizerContext::new();
optimizer.optimize(plan, &config, observe)?;

Ok(())
}

#[test]
fn optimize_projections_left_mark_join_with_projection() -> Result<()> {
let table_a = test_table_scan_with_name("a")?;
let table_b = test_table_scan_with_name("b")?;
let table_c = test_table_scan_with_name("c")?;

let plan = LogicalPlanBuilder::from(table_a)
.join(table_b, JoinType::LeftMark, (vec!["a"], vec!["a"]), None)?
.project(vec![col("a.a"), col("a.b"), col("a.c")])?
.join(table_c, JoinType::Left, (vec!["a"], vec!["a"]), None)?
.build()?;

assert_optimized_plan_equal!(
plan,
@r"
Left Join: a.a = c.a
Projection: a.a, a.b, a.c
LeftMark Join: a.a = b.a
TableScan: a projection=[a, b, c]
TableScan: b projection=[a]
TableScan: c projection=[a, b, c]
"
)
}

fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}

fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
Expand Down
Loading