diff --git a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp index 34a1b2bca..baade094d 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp @@ -40,9 +40,9 @@ class WindowHashGroup { public: using HashGroupPtr = unique_ptr; using OrderMasks = HashedSortGroup::OrderMasks; - using ExecutorGlobalStatePtr = unique_ptr; + using ExecutorGlobalStatePtr = unique_ptr; using ExecutorGlobalStates = vector; - using ExecutorLocalStatePtr = unique_ptr; + using ExecutorLocalStatePtr = unique_ptr; using ExecutorLocalStates = vector; using ThreadLocalStates = vector; using Task = WindowSourceTask; @@ -765,7 +765,8 @@ void WindowLocalSourceState::Sink(ExecutionContext &context, InterruptState &int } for (idx_t w = 0; w < executors.size(); ++w) { - executors[w]->Sink(context, sink_chunk, coll_chunk, input_idx, *gestates[w], *local_states[w], interrupt); + OperatorSinkInput sink {*gestates[w], *local_states[w], interrupt}; + executors[w]->Sink(context, sink_chunk, coll_chunk, input_idx, sink); } window_hash_group->sunk += input_chunk.size(); @@ -790,7 +791,8 @@ void WindowLocalSourceState::Finalize(ExecutionContext &context, InterruptState auto &gestates = window_hash_group->gestates; auto &local_states = window_hash_group->thread_states.at(task->thread_idx); for (idx_t w = 0; w < executors.size(); ++w) { - executors[w]->Finalize(context, *gestates[w], *local_states[w], window_hash_group->collection, interrupt); + OperatorSinkInput sink {*gestates[w], *local_states[w], interrupt}; + executors[w]->Finalize(context, window_hash_group->collection, sink); } // Mark this range as done @@ -944,8 +946,6 @@ void WindowLocalSourceState::GetData(ExecutionContext &context, DataChunk &resul output_chunk.Reset(); for (idx_t expr_idx = 0; expr_idx < executors.size(); ++expr_idx) { auto &executor = *executors[expr_idx]; - auto &gstate = *gestates[expr_idx]; - auto &lstate = *local_states[expr_idx]; auto &result = output_chunk.data[expr_idx]; if (eval_chunk.data.empty()) { eval_chunk.SetCardinality(input_chunk); @@ -953,7 +953,8 @@ void WindowLocalSourceState::GetData(ExecutionContext &context, DataChunk &resul eval_chunk.Reset(); eval_exec.Execute(input_chunk, eval_chunk); } - executor.Evaluate(context, position, eval_chunk, result, lstate, gstate, interrupt); + OperatorSinkInput sink {*gestates[expr_idx], *local_states[expr_idx], interrupt}; + executor.Evaluate(context, position, eval_chunk, result, sink); } output_chunk.SetCardinality(input_chunk); output_chunk.Verify(); diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index cc14d55a6..c3609018b 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev3047" +#define DUCKDB_PATCH_VERSION "0-dev3109" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 4 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.4.0-dev3047" +#define DUCKDB_VERSION "v1.4.0-dev3109" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "129b1fe55e" +#define DUCKDB_SOURCE_ID "d229d97f40" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/function/window/window_aggregate_function.cpp b/src/duckdb/src/function/window/window_aggregate_function.cpp index b8b7f4aa0..c325a6eb3 100644 --- a/src/duckdb/src/function/window/window_aggregate_function.cpp +++ b/src/duckdb/src/function/window/window_aggregate_function.cpp @@ -21,7 +21,7 @@ class WindowAggregateExecutorGlobalState : public WindowExecutorGlobalState { const ValidityMask &order_mask); // aggregate global state - unique_ptr gsink; + unique_ptr gsink; // the filter reference expression. const Expression *filter_ref; @@ -91,21 +91,21 @@ WindowAggregateExecutorGlobalState::WindowAggregateExecutorGlobalState(ClientCon gsink = executor.aggregator->GetGlobalState(client, group_count, partition_mask); } -unique_ptr WindowAggregateExecutor::GetGlobalState(ClientContext &client, - const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { +unique_ptr WindowAggregateExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { return make_uniq(client, *this, payload_count, partition_mask, order_mask); } class WindowAggregateExecutorLocalState : public WindowExecutorBoundsLocalState { public: - WindowAggregateExecutorLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate, + WindowAggregateExecutorLocalState(ExecutionContext &context, const GlobalSinkState &gstate, const WindowAggregator &aggregator) - : WindowExecutorBoundsLocalState(context, gstate), filter_executor(gstate.client) { + : WindowExecutorBoundsLocalState(context, gstate.Cast()), + filter_executor(context.client) { auto &gastate = gstate.Cast(); - aggregator_state = aggregator.GetLocalState(*gastate.gsink); + aggregator_state = aggregator.GetLocalState(context, *gastate.gsink); // evaluate the FILTER clause and stuff it into a large mask for compactness and reuse auto filter_ref = gastate.filter_ref; @@ -117,23 +117,22 @@ class WindowAggregateExecutorLocalState : public WindowExecutorBoundsLocalState public: // state of aggregator - unique_ptr aggregator_state; + unique_ptr aggregator_state; //! Executor for any filter clause ExpressionExecutor filter_executor; //! Result of filtering SelectionVector filter_sel; }; -unique_ptr -WindowAggregateExecutor::GetLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate) const { +unique_ptr WindowAggregateExecutor::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { return make_uniq(context, gstate, *aggregator); } void WindowAggregateExecutor::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, - const idx_t input_idx, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, InterruptState &interrupt) const { - auto &gastate = gstate.Cast(); - auto &lastate = lstate.Cast(); + const idx_t input_idx, OperatorSinkInput &sink) const { + auto &gastate = sink.global_state.Cast(); + auto &lastate = sink.local_state.Cast(); auto &filter_sel = lastate.filter_sel; auto &filter_executor = lastate.filter_executor; @@ -145,11 +144,10 @@ void WindowAggregateExecutor::Sink(ExecutionContext &context, DataChunk &sink_ch } D_ASSERT(aggregator); - auto &gestate = *gastate.gsink; - auto &lestate = *lastate.aggregator_state; - aggregator->Sink(context, gestate, lestate, sink_chunk, coll_chunk, input_idx, filtering, filtered, interrupt); + OperatorSinkInput asink {*gastate.gsink, *lastate.aggregator_state, sink.interrupt_state}; + aggregator->Sink(context, sink_chunk, coll_chunk, input_idx, filtering, filtered, asink); - WindowExecutor::Sink(context, sink_chunk, coll_chunk, input_idx, gstate, lstate, interrupt); + WindowExecutor::Sink(context, sink_chunk, coll_chunk, input_idx, sink); } static void ApplyWindowStats(const WindowBoundary &boundary, FrameDelta &delta, BaseStatistics *base, bool is_start) { @@ -215,12 +213,11 @@ static void ApplyWindowStats(const WindowBoundary &boundary, FrameDelta &delta, } } -void WindowAggregateExecutor::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, CollectionPtr collection, - InterruptState &interrupt) const { - WindowExecutor::Finalize(context, gstate, lstate, collection, interrupt); +void WindowAggregateExecutor::Finalize(ExecutionContext &context, CollectionPtr collection, + OperatorSinkInput &sink) const { + WindowExecutor::Finalize(context, collection, sink); - auto &gastate = gstate.Cast(); + auto &gastate = sink.global_state.Cast(); auto &gsink = gastate.gsink; D_ASSERT(aggregator); @@ -239,21 +236,20 @@ void WindowAggregateExecutor::Finalize(ExecutionContext &context, WindowExecutor base = wexpr.expr_stats.empty() ? nullptr : wexpr.expr_stats[1].get(); ApplyWindowStats(wexpr.end, stats[1], base, false); - auto &lastate = lstate.Cast(); - aggregator->Finalize(context, *gsink, *lastate.aggregator_state, collection, stats, interrupt); + auto &lastate = sink.local_state.Cast(); + OperatorSinkInput asink {*gsink, *lastate.aggregator_state, sink.interrupt_state}; + aggregator->Finalize(context, collection, stats, asink); } -void WindowAggregateExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &gastate = gstate.Cast(); - auto &lastate = lstate.Cast(); +void WindowAggregateExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &gastate = sink.global_state.Cast(); + auto &lastate = sink.local_state.Cast(); auto &gsink = gastate.gsink; D_ASSERT(aggregator); - auto &agg_state = *lastate.aggregator_state; - - aggregator->Evaluate(context, *gsink, agg_state, lastate.bounds, result, count, row_idx, interrupt); + OperatorSinkInput asink {*gsink, *lastate.aggregator_state, sink.interrupt_state}; + aggregator->Evaluate(context, lastate.bounds, result, count, row_idx, asink); } } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_aggregator.cpp b/src/duckdb/src/function/window/window_aggregator.cpp index bb2ada29f..3ac9c91c9 100644 --- a/src/duckdb/src/function/window/window_aggregator.cpp +++ b/src/duckdb/src/function/window/window_aggregator.cpp @@ -9,9 +9,6 @@ namespace duckdb { //===--------------------------------------------------------------------===// // WindowAggregator //===--------------------------------------------------------------------===// -WindowAggregatorState::WindowAggregatorState() : allocator(Allocator::DefaultAllocator()) { -} - WindowAggregator::WindowAggregator(const BoundWindowExpression &wexpr) : wexpr(wexpr), aggr(wexpr), result_type(wexpr.return_type), state_size(aggr.function.state_size(aggr.function)), exclude_mode(wexpr.exclude_clause) { @@ -31,20 +28,36 @@ WindowAggregator::WindowAggregator(const BoundWindowExpression &wexpr, WindowSha WindowAggregator::~WindowAggregator() { } -unique_ptr WindowAggregator::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &) const { +WindowAggregatorGlobalState::WindowAggregatorGlobalState(ClientContext &client, const WindowAggregator &aggregator_p, + idx_t group_count) + : client(client), allocator(Allocator::DefaultAllocator()), aggregator(aggregator_p), aggr(aggregator.wexpr), + locals(0), finalized(0) { + + if (aggr.filter) { + // Start with all invalid and set the ones that pass + filter_mask.Initialize(group_count, false); + } else { + filter_mask.InitializeEmpty(group_count); + } +} + +unique_ptr WindowAggregator::GetGlobalState(ClientContext &context, idx_t group_count, + const ValidityMask &) const { return make_uniq(context, *this, group_count); } +WindowAggregatorLocalState::WindowAggregatorLocalState(ExecutionContext &context) + : allocator(Allocator::DefaultAllocator()) { +} + void WindowAggregatorLocalState::Sink(ExecutionContext &context, WindowAggregatorGlobalState &gastate, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx) { } -void WindowAggregator::Sink(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, - DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, - optional_ptr filter_sel, idx_t filtered, InterruptState &interrupt) { - auto &gastate = gstate.Cast(); - auto &lastate = lstate.Cast(); +void WindowAggregator::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + optional_ptr filter_sel, idx_t filtered, OperatorSinkInput &sink) { + auto &gastate = sink.global_state.Cast(); + auto &lastate = sink.local_state.Cast(); lastate.Sink(context, gastate, sink_chunk, coll_chunk, input_idx); if (filter_sel) { auto &filter_mask = gastate.filter_mask; @@ -79,10 +92,10 @@ void WindowAggregatorLocalState::Finalize(ExecutionContext &context, WindowAggre } } -void WindowAggregator::Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats, InterruptState &interrupt) { - auto &gasink = gstate.Cast(); - auto &lastate = lstate.Cast(); +void WindowAggregator::Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink) { + auto &gasink = sink.global_state.Cast(); + auto &lastate = sink.local_state.Cast(); lastate.Finalize(context, gasink, collection); } diff --git a/src/duckdb/src/function/window/window_constant_aggregator.cpp b/src/duckdb/src/function/window/window_constant_aggregator.cpp index 94c77b37e..1bc0c246a 100644 --- a/src/duckdb/src/function/window/window_constant_aggregator.cpp +++ b/src/duckdb/src/function/window/window_constant_aggregator.cpp @@ -77,7 +77,7 @@ WindowConstantAggregatorGlobalState::WindowConstantAggregatorGlobalState(ClientC //===--------------------------------------------------------------------===// class WindowConstantAggregatorLocalState : public WindowAggregatorLocalState { public: - explicit WindowConstantAggregatorLocalState(const WindowConstantAggregatorGlobalState &gstate); + WindowConstantAggregatorLocalState(ExecutionContext &context, const WindowConstantAggregatorGlobalState &gstate); ~WindowConstantAggregatorLocalState() override { } @@ -103,8 +103,9 @@ class WindowConstantAggregatorLocalState : public WindowAggregatorLocalState { }; WindowConstantAggregatorLocalState::WindowConstantAggregatorLocalState( - const WindowConstantAggregatorGlobalState &gstate) - : gstate(gstate), statep(Value::POINTER(0)), statef(gstate.statef.aggr), partition(0) { + ExecutionContext &context, const WindowConstantAggregatorGlobalState &gstate) + : WindowAggregatorLocalState(context), gstate(gstate), statep(Value::POINTER(0)), statef(gstate.statef.aggr), + partition(0) { matches.Initialize(); // Start the aggregates @@ -201,16 +202,15 @@ WindowConstantAggregator::WindowConstantAggregator(BoundWindowExpression &wexpr, } } -unique_ptr WindowConstantAggregator::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &partition_mask) const { +unique_ptr WindowConstantAggregator::GetGlobalState(ClientContext &context, idx_t group_count, + const ValidityMask &partition_mask) const { return make_uniq(context, *this, group_count, partition_mask); } -void WindowConstantAggregator::Sink(ExecutionContext &context, WindowAggregatorState &gsink, - WindowAggregatorState &lstate, DataChunk &sink_chunk, DataChunk &coll_chunk, +void WindowConstantAggregator::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, idx_t filtered, - InterruptState &interrupt) { - auto &lastate = lstate.Cast(); + OperatorSinkInput &sink) { + auto &lastate = sink.local_state.Cast(); lastate.Sink(context, sink_chunk, coll_chunk, input_idx, filter_sel, filtered); } @@ -299,11 +299,10 @@ void WindowConstantAggregatorLocalState::Sink(ExecutionContext &context, DataChu } } -void WindowConstantAggregator::Finalize(ExecutionContext &context, WindowAggregatorState &gstate, - WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats, InterruptState &interrupt) { - auto &gastate = gstate.Cast(); - auto &lastate = lstate.Cast(); +void WindowConstantAggregator::Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink) { + auto &gastate = sink.global_state.Cast(); + auto &lastate = sink.local_state.Cast(); // Single-threaded combine lock_guard finalize_guard(gastate.lock); @@ -315,20 +314,20 @@ void WindowConstantAggregator::Finalize(ExecutionContext &context, WindowAggrega } } -unique_ptr WindowConstantAggregator::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(gstate.Cast()); +unique_ptr WindowConstantAggregator::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { + return make_uniq(context, gstate.Cast()); } -void WindowConstantAggregator::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, - WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &gasink = gsink.Cast(); +void WindowConstantAggregator::Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, + idx_t row_idx, OperatorSinkInput &sink) const { + auto &gasink = sink.global_state.Cast(); const auto &partition_offsets = gasink.partition_offsets; const auto &results = *gasink.results; auto begins = FlatVector::GetData(bounds.data[FRAME_BEGIN]); // Chunk up the constants and copy them one at a time - auto &lcstate = lstate.Cast(); + auto &lcstate = sink.local_state.Cast(); idx_t matched = 0; idx_t target_offset = 0; for (idx_t i = 0; i < count; ++i) { diff --git a/src/duckdb/src/function/window/window_custom_aggregator.cpp b/src/duckdb/src/function/window/window_custom_aggregator.cpp index 3fb3b4b70..a2de32d3d 100644 --- a/src/duckdb/src/function/window/window_custom_aggregator.cpp +++ b/src/duckdb/src/function/window/window_custom_aggregator.cpp @@ -33,7 +33,8 @@ WindowCustomAggregator::~WindowCustomAggregator() { class WindowCustomAggregatorLocalState : public WindowAggregatorLocalState { public: - WindowCustomAggregatorLocalState(const AggregateObject &aggr, const WindowExcludeMode exclude_mode); + WindowCustomAggregatorLocalState(ExecutionContext &context, const AggregateObject &aggr, + const WindowExcludeMode exclude_mode); ~WindowCustomAggregatorLocalState() override; public: @@ -54,13 +55,12 @@ class WindowCustomAggregatorGlobalState : public WindowAggregatorGlobalState { WindowCustomAggregatorGlobalState(ClientContext &client, const WindowCustomAggregator &aggregator, idx_t group_count) : WindowAggregatorGlobalState(client, aggregator, group_count) { - gcstate = make_uniq(aggr, aggregator.exclude_mode); } //! Traditional packed filter mask for API ValidityMask filter_packed; //! Data pointer that contains a single local state, used for global custom window execution state - unique_ptr gcstate; + unique_ptr glstate; //! The argument data CollectionPtr collection; //! Column global validity flags @@ -69,9 +69,10 @@ class WindowCustomAggregatorGlobalState : public WindowAggregatorGlobalState { FrameStats stats; }; -WindowCustomAggregatorLocalState::WindowCustomAggregatorLocalState(const AggregateObject &aggr, +WindowCustomAggregatorLocalState::WindowCustomAggregatorLocalState(ExecutionContext &context, + const AggregateObject &aggr, const WindowExcludeMode exclude_mode) - : aggr(aggr), state(aggr.function.state_size(aggr.function)), + : WindowAggregatorLocalState(context), aggr(aggr), state(aggr.function.state_size(aggr.function)), statef(Value::POINTER(CastPointerToValue(state.data()))), frames(3, {0, 0}) { // if we have a frame-by-frame method, share the single state aggr.function.initialize(aggr.function, state.data()); @@ -86,22 +87,21 @@ WindowCustomAggregatorLocalState::~WindowCustomAggregatorLocalState() { } } -unique_ptr WindowCustomAggregator::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &) const { +unique_ptr WindowCustomAggregator::GetGlobalState(ClientContext &context, idx_t group_count, + const ValidityMask &) const { return make_uniq(context, *this, group_count); } -void WindowCustomAggregator::Finalize(ExecutionContext &context, WindowAggregatorState &gstate, - WindowAggregatorState &lstate, CollectionPtr collection, const FrameStats &stats, - InterruptState &interrupt) { +void WindowCustomAggregator::Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink) { // Single threaded Finalize for now - auto &gcsink = gstate.Cast(); + auto &gcsink = sink.global_state.Cast(); lock_guard gestate_guard(gcsink.lock); if (gcsink.finalized) { return; } - WindowAggregator::Finalize(context, gstate, lstate, collection, stats, interrupt); + WindowAggregator::Finalize(context, collection, stats, sink); gcsink.collection = collection; auto inputs = collection->inputs.get(); @@ -113,10 +113,12 @@ void WindowCustomAggregator::Finalize(ExecutionContext &context, WindowAggregato auto &filter_mask = gcsink.filter_mask; auto &filter_packed = gcsink.filter_packed; filter_mask.Pack(filter_packed, filter_mask.Capacity()); + gcsink.glstate = GetLocalState(context, gcsink); if (aggr.function.window_init) { - auto &gcstate = *gcsink.gcstate; - WindowPartitionInput partition(context, inputs, count, child_idx, all_valids, filter_packed, stats, interrupt); + auto &gcstate = gcsink.glstate->Cast(); + WindowPartitionInput partition(context, inputs, count, child_idx, all_valids, filter_packed, stats, + sink.interrupt_state); AggregateInputData aggr_input_data(aggr.GetFunctionData(), gcstate.allocator); aggr.function.window_init(aggr_input_data, partition, gcstate.state.data()); @@ -125,19 +127,20 @@ void WindowCustomAggregator::Finalize(ExecutionContext &context, WindowAggregato ++gcsink.finalized; } -unique_ptr WindowCustomAggregator::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(aggr, exclude_mode); +unique_ptr WindowCustomAggregator::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { + return make_uniq(context, aggr, exclude_mode); } -void WindowCustomAggregator::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, - WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &lcstate = lstate.Cast(); +void WindowCustomAggregator::Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, + idx_t row_idx, OperatorSinkInput &sink) const { + auto &lcstate = sink.local_state.Cast(); auto &frames = lcstate.frames; const_data_ptr_t gstate_p = nullptr; - auto &gcsink = gsink.Cast(); - if (gcsink.gcstate) { - gstate_p = gcsink.gcstate->state.data(); + auto &gcsink = sink.global_state.Cast(); + if (gcsink.glstate) { + auto &gcstate = gcsink.glstate->Cast(); + gstate_p = gcstate.state.data(); } auto collection = gcsink.collection; @@ -146,10 +149,10 @@ void WindowCustomAggregator::Evaluate(ExecutionContext &context, const WindowAgg auto &filter_packed = gcsink.filter_packed; auto &stats = gcsink.stats; WindowPartitionInput partition(context, inputs, collection->size(), child_idx, all_valids, filter_packed, stats, - interrupt); + sink.interrupt_state); EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { // Extract the range - AggregateInputData aggr_input_data(aggr.GetFunctionData(), lstate.allocator); + AggregateInputData aggr_input_data(aggr.GetFunctionData(), lcstate.allocator); aggr.function.window(aggr_input_data, partition, gstate_p, lcstate.state.data(), frames, result, i); }); } diff --git a/src/duckdb/src/function/window/window_distinct_aggregator.cpp b/src/duckdb/src/function/window/window_distinct_aggregator.cpp index 71814ea0b..21a4e9f1c 100644 --- a/src/duckdb/src/function/window/window_distinct_aggregator.cpp +++ b/src/duckdb/src/function/window/window_distinct_aggregator.cpp @@ -180,7 +180,8 @@ optional_ptr WindowDistinctAggregatorGlobalState::InitializeLoca class WindowDistinctAggregatorLocalState : public WindowAggregatorLocalState { public: - explicit WindowDistinctAggregatorLocalState(const WindowDistinctAggregatorGlobalState &aggregator); + WindowDistinctAggregatorLocalState(ExecutionContext &context, + const WindowDistinctAggregatorGlobalState &aggregator); ~WindowDistinctAggregatorLocalState() override { statef.Destroy(); @@ -230,10 +231,10 @@ class WindowDistinctAggregatorLocalState : public WindowAggregatorLocalState { }; WindowDistinctAggregatorLocalState::WindowDistinctAggregatorLocalState( - const WindowDistinctAggregatorGlobalState &gdstate) - : tree_allocator(gdstate.CreateTreeAllocator()), update_v(LogicalType::POINTER), source_v(LogicalType::POINTER), - target_v(LogicalType::POINTER), gdstate(gdstate), statef(gdstate.aggr), statep(LogicalType::POINTER), - statel(LogicalType::POINTER), flush_count(0) { + ExecutionContext &context, const WindowDistinctAggregatorGlobalState &gdstate) + : WindowAggregatorLocalState(context), tree_allocator(gdstate.CreateTreeAllocator()), + update_v(LogicalType::POINTER), source_v(LogicalType::POINTER), target_v(LogicalType::POINTER), gdstate(gdstate), + statef(gdstate.aggr), statep(LogicalType::POINTER), statel(LogicalType::POINTER), flush_count(0) { InitSubFrames(frames, gdstate.aggregator.exclude_mode); sort_chunk.Initialize(Allocator::DefaultAllocator(), gdstate.sort_types); @@ -241,19 +242,18 @@ WindowDistinctAggregatorLocalState::WindowDistinctAggregatorLocalState( gdstate.locals++; } -unique_ptr WindowDistinctAggregator::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &partition_mask) const { +unique_ptr WindowDistinctAggregator::GetGlobalState(ClientContext &context, idx_t group_count, + const ValidityMask &partition_mask) const { return make_uniq(context, *this, group_count); } -void WindowDistinctAggregator::Sink(ExecutionContext &context, WindowAggregatorState &gsink, - WindowAggregatorState &lstate, DataChunk &sink_chunk, DataChunk &coll_chunk, +void WindowDistinctAggregator::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, idx_t filtered, - InterruptState &interrupt) { - WindowAggregator::Sink(context, gsink, lstate, sink_chunk, coll_chunk, input_idx, filter_sel, filtered, interrupt); + OperatorSinkInput &sink) { + WindowAggregator::Sink(context, sink_chunk, coll_chunk, input_idx, filter_sel, filtered, sink); - auto &ldstate = lstate.Cast(); - ldstate.Sink(context, sink_chunk, coll_chunk, input_idx, filter_sel, filtered, interrupt); + auto &ldstate = sink.local_state.Cast(); + ldstate.Sink(context, sink_chunk, coll_chunk, input_idx, filter_sel, filtered, sink.interrupt_state); } void WindowDistinctAggregatorLocalState::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, @@ -388,11 +388,10 @@ bool WindowDistinctAggregatorGlobalState::TryPrepareNextStage(WindowDistinctAggr return true; } -void WindowDistinctAggregator::Finalize(ExecutionContext &context, WindowAggregatorState &gsink, - WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats, InterruptState &interrupt) { - auto &gdsink = gsink.Cast(); - auto &ldstate = lstate.Cast(); +void WindowDistinctAggregator::Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink) { + auto &gdsink = sink.global_state.Cast(); + auto &ldstate = sink.local_state.Cast(); ldstate.Finalize(context, gdsink, collection); // Sort, merge and build the tree in parallel @@ -693,17 +692,17 @@ void WindowDistinctAggregatorLocalState::Evaluate(ExecutionContext &context, statef.Destroy(); } -unique_ptr WindowDistinctAggregator::GetLocalState(const WindowAggregatorState &gstate) const { +unique_ptr WindowDistinctAggregator::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { auto &gdstate = gstate.Cast(); - return make_uniq(gdstate); + return make_uniq(context, gdstate); } -void WindowDistinctAggregator::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, - WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { +void WindowDistinctAggregator::Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, + idx_t row_idx, OperatorSinkInput &sink) const { - const auto &gdstate = gsink.Cast(); - auto &ldstate = lstate.Cast(); + const auto &gdstate = sink.global_state.Cast(); + auto &ldstate = sink.local_state.Cast(); ldstate.Evaluate(context, gdstate, bounds, result, count, row_idx); } diff --git a/src/duckdb/src/function/window/window_executor.cpp b/src/duckdb/src/function/window/window_executor.cpp index 2d2e256b4..d0a2f8847 100644 --- a/src/duckdb/src/function/window/window_executor.cpp +++ b/src/duckdb/src/function/window/window_executor.cpp @@ -1,7 +1,5 @@ #include "duckdb/function/window/window_executor.hpp" - #include "duckdb/function/window/window_shared_expressions.hpp" - #include "duckdb/planner/expression/bound_window_expression.hpp" namespace duckdb { @@ -14,7 +12,7 @@ WindowExecutorBoundsLocalState::WindowExecutorBoundsLocalState(ExecutionContext : WindowExecutorLocalState(context, gstate), partition_mask(gstate.partition_mask), order_mask(gstate.order_mask), state(gstate.executor.wexpr, gstate.payload_count) { vector bounds_types(8, LogicalType(LogicalTypeId::UBIGINT)); - bounds.Initialize(Allocator::Get(gstate.client), bounds_types); + bounds.Initialize(Allocator::Get(context.client), bounds_types); } void WindowExecutorBoundsLocalState::UpdateBounds(WindowExecutorGlobalState &gstate, idx_t row_idx, @@ -48,13 +46,13 @@ bool WindowExecutor::IgnoreNulls() const { } void WindowExecutor::Evaluate(ExecutionContext &context, idx_t row_idx, DataChunk &eval_chunk, Vector &result, - WindowExecutorLocalState &lstate, WindowExecutorGlobalState &gstate, - InterruptState &interrupt) const { - auto &lbstate = lstate.Cast(); - lbstate.UpdateBounds(gstate, row_idx, eval_chunk, lstate.range_cursor); + OperatorSinkInput &sink) const { + auto &gbstate = sink.global_state.Cast(); + auto &lbstate = sink.local_state.Cast(); + lbstate.UpdateBounds(gbstate, row_idx, eval_chunk, lbstate.range_cursor); const auto count = eval_chunk.size(); - EvaluateInternal(context, gstate, lstate, eval_chunk, result, count, row_idx, interrupt); + EvaluateInternal(context, eval_chunk, result, count, row_idx, sink); result.Verify(count); } @@ -72,39 +70,38 @@ WindowExecutorGlobalState::WindowExecutorGlobalState(ClientContext &client, cons WindowExecutorLocalState::WindowExecutorLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate) { } -void WindowExecutorLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, InterruptState &interrupt) { +void WindowExecutorLocalState::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx, OperatorSinkInput &sink) { } -void WindowExecutorLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - CollectionPtr collection, InterruptState &interrupt) { - const auto range_idx = gstate.executor.range_idx; +void WindowExecutorLocalState::Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) { + auto &gbstate = sink.global_state.Cast(); + const auto range_idx = gbstate.executor.range_idx; if (range_idx != DConstants::INVALID_INDEX) { range_cursor = make_uniq(*collection, range_idx); } } -unique_ptr WindowExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { +unique_ptr WindowExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -unique_ptr WindowExecutor::GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const { - return make_uniq(context, gstate); +unique_ptr WindowExecutor::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { + return make_uniq(context, gstate.Cast()); } void WindowExecutor::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, - const idx_t input_idx, WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - InterruptState &interrupt) const { - lstate.Sink(context, gstate, sink_chunk, coll_chunk, input_idx, interrupt); + const idx_t input_idx, OperatorSinkInput &sink) const { + auto &lbstate = sink.local_state.Cast(); + lbstate.Sink(context, sink_chunk, coll_chunk, input_idx, sink); } -void WindowExecutor::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, CollectionPtr collection, - InterruptState &interrupt) const { - lstate.Finalize(context, gstate, collection, interrupt); +void WindowExecutor::Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) const { + auto &lbstate = sink.local_state.Cast(); + lbstate.Finalize(context, collection, sink); } } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_index_tree.cpp b/src/duckdb/src/function/window/window_index_tree.cpp index 78c70c60b..1d8352788 100644 --- a/src/duckdb/src/function/window/window_index_tree.cpp +++ b/src/duckdb/src/function/window/window_index_tree.cpp @@ -1,4 +1,5 @@ #include "duckdb/function/window/window_index_tree.hpp" +#include "duckdb/function/window/window_collection.hpp" #include @@ -14,7 +15,7 @@ WindowIndexTree::WindowIndexTree(ClientContext &context, const BoundOrderModifie : WindowIndexTree(context, order_bys.orders, sort_idx, count) { } -unique_ptr WindowIndexTree::GetLocalState(ExecutionContext &context) { +unique_ptr WindowIndexTree::GetLocalState(ExecutionContext &context) { return make_uniq(context, *this); } diff --git a/src/duckdb/src/function/window/window_naive_aggregator.cpp b/src/duckdb/src/function/window/window_naive_aggregator.cpp index 2201f19cf..8bacb1943 100644 --- a/src/duckdb/src/function/window/window_naive_aggregator.cpp +++ b/src/duckdb/src/function/window/window_naive_aggregator.cpp @@ -48,7 +48,7 @@ class WindowNaiveLocalState : public WindowAggregatorLocalState { using RowSet = std::unordered_set; - explicit WindowNaiveLocalState(const WindowNaiveAggregator &gsink); + WindowNaiveLocalState(ExecutionContext &context, const WindowNaiveAggregator &aggregator); void Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; @@ -97,9 +97,9 @@ class WindowNaiveLocalState : public WindowAggregatorLocalState { SelectionVector orderby_sel; }; -WindowNaiveLocalState::WindowNaiveLocalState(const WindowNaiveAggregator &aggregator) - : aggregator(aggregator), state(aggregator.state_size * STANDARD_VECTOR_SIZE), statef(LogicalType::POINTER), - statep((LogicalType::POINTER)), flush_count(0), hashes(LogicalType::HASH) { +WindowNaiveLocalState::WindowNaiveLocalState(ExecutionContext &context, const WindowNaiveAggregator &aggregator) + : WindowAggregatorLocalState(context), aggregator(aggregator), state(aggregator.state_size * STANDARD_VECTOR_SIZE), + statef(LogicalType::POINTER), statep((LogicalType::POINTER)), flush_count(0), hashes(LogicalType::HASH) { InitSubFrames(frames, aggregator.exclude_mode); update_sel.Initialize(); @@ -360,16 +360,16 @@ void WindowNaiveLocalState::Evaluate(ExecutionContext &context, const WindowAggr } } -unique_ptr WindowNaiveAggregator::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(*this); +unique_ptr WindowNaiveAggregator::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { + return make_uniq(context, *this); } -void WindowNaiveAggregator::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, - WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - const auto &gnstate = gsink.Cast(); - auto &lnstate = lstate.Cast(); - lnstate.Evaluate(context, gnstate, bounds, result, count, row_idx, interrupt); +void WindowNaiveAggregator::Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, + idx_t row_idx, OperatorSinkInput &sink) const { + const auto &gnstate = sink.global_state.Cast(); + auto &lnstate = sink.local_state.Cast(); + lnstate.Evaluate(context, gnstate, bounds, result, count, row_idx, sink.interrupt_state); } } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_rank_function.cpp b/src/duckdb/src/function/window/window_rank_function.cpp index 9de1aab12..af70521a0 100644 --- a/src/duckdb/src/function/window/window_rank_function.cpp +++ b/src/duckdb/src/function/window/window_rank_function.cpp @@ -48,11 +48,10 @@ class WindowPeerLocalState : public WindowExecutorBoundsLocalState { } //! Accumulate the secondary sort values - void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, InterruptState &interrupt) override; + void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + OperatorSinkInput &sink) override; //! Finish the sinking and prepare to scan - void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection, - InterruptState &interrupt) override; + void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) override; void NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx); @@ -63,26 +62,25 @@ class WindowPeerLocalState : public WindowExecutorBoundsLocalState { //! The corresponding global peer state const WindowPeerGlobalState &gpstate; //! The optional sorting state for secondary sorts - unique_ptr local_tree; + unique_ptr local_tree; }; -void WindowPeerLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, InterruptState &interrupt) { - WindowExecutorBoundsLocalState::Sink(context, gstate, sink_chunk, coll_chunk, input_idx, interrupt); +void WindowPeerLocalState::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx, OperatorSinkInput &sink) { + WindowExecutorBoundsLocalState::Sink(context, sink_chunk, coll_chunk, input_idx, sink); if (local_tree) { auto &local_tokens = local_tree->Cast(); - local_tokens.Sink(context, sink_chunk, input_idx, nullptr, 0, interrupt); + local_tokens.Sink(context, sink_chunk, input_idx, nullptr, 0, sink.interrupt_state); } } -void WindowPeerLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - CollectionPtr collection, InterruptState &interrupt) { - WindowExecutorBoundsLocalState::Finalize(context, gstate, collection, interrupt); +void WindowPeerLocalState::Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) { + WindowExecutorBoundsLocalState::Finalize(context, collection, sink); if (local_tree) { auto &local_tokens = local_tree->Cast(); - local_tokens.Finalize(context, interrupt); + local_tokens.Finalize(context, sink.interrupt_state); local_tokens.window_tree.Build(); } } @@ -111,15 +109,14 @@ WindowPeerExecutor::WindowPeerExecutor(BoundWindowExpression &wexpr, WindowShare } } -unique_ptr WindowPeerExecutor::GetGlobalState(ClientContext &client, - const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { +unique_ptr WindowPeerExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -unique_ptr WindowPeerExecutor::GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const { +unique_ptr WindowPeerExecutor::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { return make_uniq(context, gstate.Cast()); } @@ -130,11 +127,10 @@ WindowRankExecutor::WindowRankExecutor(BoundWindowExpression &wexpr, WindowShare : WindowPeerExecutor(wexpr, shared) { } -void WindowRankExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &gpeer = gstate.Cast(); - auto &lpeer = lstate.Cast(); +void WindowRankExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx, OperatorSinkInput &sink) const { + auto &gpeer = sink.global_state.Cast(); + auto &lpeer = sink.local_state.Cast(); auto rdata = FlatVector::GetData(result); if (gpeer.use_framing) { @@ -174,12 +170,12 @@ WindowDenseRankExecutor::WindowDenseRankExecutor(BoundWindowExpression &wexpr, W : WindowPeerExecutor(wexpr, shared) { } -void WindowDenseRankExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &lpeer = lstate.Cast(); +void WindowDenseRankExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &gpeer = sink.global_state.Cast(); + auto &lpeer = sink.local_state.Cast(); - auto &order_mask = gstate.order_mask; + auto &order_mask = gpeer.order_mask; auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); auto rdata = FlatVector::GetData(result); @@ -241,12 +237,10 @@ static inline double PercentRank(const idx_t begin, const idx_t end, const uint6 return denom > 0 ? ((double)rank - 1) / denom : 0; } -void WindowPercentRankExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx, - InterruptState &interrupt) const { - auto &gpeer = gstate.Cast(); - auto &lpeer = lstate.Cast(); +void WindowPercentRankExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &gpeer = sink.global_state.Cast(); + auto &lpeer = sink.local_state.Cast(); auto rdata = FlatVector::GetData(result); if (gpeer.use_framing) { @@ -295,11 +289,10 @@ static inline double CumeDist(const idx_t begin, const idx_t end, const idx_t pe return denom > 0 ? (num / denom) : 0; } -void WindowCumeDistExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &gpeer = gstate.Cast(); - auto &lpeer = lstate.Cast(); +void WindowCumeDistExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &gpeer = sink.global_state.Cast(); + auto &lpeer = sink.local_state.Cast(); auto rdata = FlatVector::GetData(result); if (gpeer.use_framing) { diff --git a/src/duckdb/src/function/window/window_rownumber_function.cpp b/src/duckdb/src/function/window/window_rownumber_function.cpp index 347b1a7dc..f0929d642 100644 --- a/src/duckdb/src/function/window/window_rownumber_function.cpp +++ b/src/duckdb/src/function/window/window_rownumber_function.cpp @@ -55,36 +55,33 @@ class WindowRowNumberLocalState : public WindowExecutorBoundsLocalState { } //! Accumulate the secondary sort values - void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, InterruptState &interrupt) override; + void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + OperatorSinkInput &sink) override; //! Finish the sinking and prepare to scan - void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection, - InterruptState &interrupt) override; + void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) override; //! The corresponding global peer state const WindowRowNumberGlobalState &grstate; //! The optional sorting state for secondary sorts - unique_ptr local_tree; + unique_ptr local_tree; }; -void WindowRowNumberLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, - DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, - InterruptState &interrupt) { - WindowExecutorBoundsLocalState::Sink(context, gstate, sink_chunk, coll_chunk, input_idx, interrupt); +void WindowRowNumberLocalState::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx, OperatorSinkInput &sink) { + WindowExecutorBoundsLocalState::Sink(context, sink_chunk, coll_chunk, input_idx, sink); if (local_tree) { auto &local_tokens = local_tree->Cast(); - local_tokens.Sink(context, sink_chunk, input_idx, nullptr, 0, interrupt); + local_tokens.Sink(context, sink_chunk, input_idx, nullptr, 0, sink.interrupt_state); } } -void WindowRowNumberLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - CollectionPtr collection, InterruptState &interrupt) { - WindowExecutorBoundsLocalState::Finalize(context, gstate, collection, interrupt); +void WindowRowNumberLocalState::Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) { + WindowExecutorBoundsLocalState::Finalize(context, collection, sink); if (local_tree) { auto &local_tokens = local_tree->Cast(); - local_tokens.Finalize(context, interrupt); + local_tokens.Finalize(context, sink.interrupt_state); local_tokens.window_tree.Build(); } } @@ -100,23 +97,21 @@ WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, W } } -unique_ptr WindowRowNumberExecutor::GetGlobalState(ClientContext &client, - const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { +unique_ptr WindowRowNumberExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -unique_ptr -WindowRowNumberExecutor::GetLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate) const { +unique_ptr WindowRowNumberExecutor::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { return make_uniq(context, gstate.Cast()); } -void WindowRowNumberExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &grstate = gstate.Cast(); - auto &lrstate = lstate.Cast(); +void WindowRowNumberExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &grstate = sink.global_state.Cast(); + auto &lrstate = sink.local_state.Cast(); auto rdata = FlatVector::GetData(result); if (grstate.use_framing) { @@ -151,11 +146,10 @@ WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, WindowSha ntile_idx = shared.RegisterEvaluate(wexpr.children[0]); } -void WindowNtileExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &grstate = gstate.Cast(); - auto &lrstate = lstate.Cast(); +void WindowNtileExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &grstate = sink.global_state.Cast(); + auto &lrstate = sink.local_state.Cast(); auto partition_begin = FlatVector::GetData(lrstate.bounds.data[PARTITION_BEGIN]); auto partition_end = FlatVector::GetData(lrstate.bounds.data[PARTITION_END]); if (grstate.use_framing) { diff --git a/src/duckdb/src/function/window/window_segment_tree.cpp b/src/duckdb/src/function/window/window_segment_tree.cpp index 38fb851e1..e96c41c23 100644 --- a/src/duckdb/src/function/window/window_segment_tree.cpp +++ b/src/duckdb/src/function/window/window_segment_tree.cpp @@ -133,7 +133,7 @@ class WindowSegmentTreePart { class WindowSegmentTreeLocalState : public WindowAggregatorLocalState { public: - WindowSegmentTreeLocalState() { + explicit WindowSegmentTreeLocalState(ExecutionContext &context) : WindowAggregatorLocalState(context) { } void Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; @@ -145,11 +145,11 @@ class WindowSegmentTreeLocalState : public WindowAggregatorLocalState { unique_ptr right_part; }; -void WindowSegmentTree::Finalize(ExecutionContext &context, WindowAggregatorState &gsink, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats, InterruptState &interrupt) { - WindowAggregator::Finalize(context, gsink, lstate, collection, stats, interrupt); +void WindowSegmentTree::Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink) { + WindowAggregator::Finalize(context, collection, stats, sink); - auto &gasink = gsink.Cast(); + auto &gasink = sink.global_state.Cast(); ++gasink.finalized; } @@ -182,13 +182,14 @@ WindowSegmentTreePart::WindowSegmentTreePart(ArenaAllocator &allocator, const Ag WindowSegmentTreePart::~WindowSegmentTreePart() { } -unique_ptr WindowSegmentTree::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &partition_mask) const { +unique_ptr WindowSegmentTree::GetGlobalState(ClientContext &context, idx_t group_count, + const ValidityMask &partition_mask) const { return make_uniq(context, *this, group_count); } -unique_ptr WindowSegmentTree::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(); +unique_ptr WindowSegmentTree::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { + return make_uniq(context); } void WindowSegmentTreePart::FlushStates(bool combining) { @@ -391,11 +392,10 @@ void WindowSegmentTreeLocalState::Finalize(ExecutionContext &context, WindowAggr } } -void WindowSegmentTree::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, - WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const { - const auto >state = gsink.Cast(); - auto <state = lstate.Cast(); +void WindowSegmentTree::Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, + idx_t row_idx, OperatorSinkInput &sink) const { + const auto >state = sink.global_state.Cast(); + auto <state = sink.local_state.Cast(); ltstate.Evaluate(context, gtstate, bounds, result, count, row_idx); } diff --git a/src/duckdb/src/function/window/window_token_tree.cpp b/src/duckdb/src/function/window/window_token_tree.cpp index bcad348cf..bfea697db 100644 --- a/src/duckdb/src/function/window/window_token_tree.cpp +++ b/src/duckdb/src/function/window/window_token_tree.cpp @@ -1,4 +1,5 @@ #include "duckdb/function/window/window_token_tree.hpp" +#include "duckdb/function/window/window_collection.hpp" namespace duckdb { @@ -83,7 +84,7 @@ static void BuildTokens(WindowTokenTree &token_tree, vector &tokens) { } } -unique_ptr WindowTokenTree::GetLocalState(ExecutionContext &context) { +unique_ptr WindowTokenTree::GetLocalState(ExecutionContext &context) { return make_uniq(context, *this); } diff --git a/src/duckdb/src/function/window/window_value_function.cpp b/src/duckdb/src/function/window/window_value_function.cpp index a4bb0695d..0258b7d6b 100644 --- a/src/duckdb/src/function/window/window_value_function.cpp +++ b/src/duckdb/src/function/window/window_value_function.cpp @@ -69,16 +69,15 @@ class WindowValueLocalState : public WindowExecutorBoundsLocalState { } //! Accumulate the secondary sort values - void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, InterruptState &interrupt) override; + void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + OperatorSinkInput &sink) override; //! Finish the sinking and prepare to scan - void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection, - InterruptState &interrupt) override; + void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) override; //! The corresponding global value state const WindowValueGlobalState &gvstate; //! The optional sorting state for secondary sorts - unique_ptr local_value; + unique_ptr local_value; //! Reusable selection vector for NULLs SelectionVector sort_nulls; //! The frame boundaries, used for EXCLUDE @@ -88,11 +87,12 @@ class WindowValueLocalState : public WindowExecutorBoundsLocalState { unique_ptr cursor; }; -void WindowValueLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, InterruptState &interrupt) { - WindowExecutorBoundsLocalState::Sink(context, gstate, sink_chunk, coll_chunk, input_idx, interrupt); +void WindowValueLocalState::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx, OperatorSinkInput &sink) { + WindowExecutorBoundsLocalState::Sink(context, sink_chunk, coll_chunk, input_idx, sink); if (local_value) { + const auto &gvstate = sink.global_state.Cast(); idx_t filtered = 0; optional_ptr filter_sel; @@ -103,7 +103,7 @@ void WindowValueLocalState::Sink(ExecutionContext &context, WindowExecutorGlobal UnifiedVectorFormat child_data; child.ToUnifiedFormat(coll_count, child_data); const auto &validity = child_data.validity; - if (gstate.executor.IgnoreNulls() && !validity.AllValid()) { + if (gvstate.executor.IgnoreNulls() && !validity.AllValid()) { const auto &sel = *child_data.sel; for (sel_t i = 0; i < coll_count; ++i) { const auto idx = sel.get_index(i); @@ -115,17 +115,16 @@ void WindowValueLocalState::Sink(ExecutionContext &context, WindowExecutorGlobal } auto &value_state = local_value->Cast(); - value_state.Sink(context, sink_chunk, input_idx, filter_sel, filtered, interrupt); + value_state.Sink(context, sink_chunk, input_idx, filter_sel, filtered, sink.interrupt_state); } } -void WindowValueLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - CollectionPtr collection, InterruptState &interrupt) { - WindowExecutorBoundsLocalState::Finalize(context, gstate, collection, interrupt); +void WindowValueLocalState::Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) { + WindowExecutorBoundsLocalState::Finalize(context, collection, sink); if (local_value) { auto &value_state = local_value->Cast(); - value_state.Finalize(context, interrupt); + value_state.Finalize(context, sink.interrupt_state); value_state.index_tree.Build(); } @@ -158,24 +157,21 @@ WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, WindowSha default_idx = shared.RegisterEvaluate(wexpr.default_expr); } -unique_ptr WindowValueExecutor::GetGlobalState(ClientContext &client, - const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { +unique_ptr WindowValueExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -void WindowValueExecutor::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, CollectionPtr collection, - InterruptState &interrupt) const { - auto &gvstate = gstate.Cast(); +void WindowValueExecutor::Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) const { + auto &gvstate = sink.global_state.Cast(); gvstate.Finalize(collection); - WindowExecutor::Finalize(context, gstate, lstate, collection, interrupt); + WindowExecutor::Finalize(context, collection, sink); } -unique_ptr WindowValueExecutor::GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const { +unique_ptr WindowValueExecutor::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { const auto &gvstate = gstate.Cast(); return make_uniq(context, gvstate); } @@ -246,36 +242,34 @@ class WindowLeadLagLocalState : public WindowValueLocalState { } //! Accumulate the secondary sort values - void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, InterruptState &interrupt) override; + void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + OperatorSinkInput &sink) override; //! Finish the sinking and prepare to scan - void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection, - InterruptState &interrupt) override; + void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) override; //! The optional sorting state for the secondary sort row mapping - unique_ptr local_row; + unique_ptr local_row; }; -void WindowLeadLagLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, InterruptState &interrupt) { - WindowValueLocalState::Sink(context, gstate, sink_chunk, coll_chunk, input_idx, interrupt); +void WindowLeadLagLocalState::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx, OperatorSinkInput &sink) { + WindowValueLocalState::Sink(context, sink_chunk, coll_chunk, input_idx, sink); if (local_row) { idx_t filtered = 0; optional_ptr filter_sel; auto &row_state = local_row->Cast(); - row_state.Sink(context, sink_chunk, input_idx, filter_sel, filtered, interrupt); + row_state.Sink(context, sink_chunk, input_idx, filter_sel, filtered, sink.interrupt_state); } } -void WindowLeadLagLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - CollectionPtr collection, InterruptState &interrupt) { - WindowValueLocalState::Finalize(context, gstate, collection, interrupt); +void WindowLeadLagLocalState::Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) { + WindowValueLocalState::Finalize(context, collection, sink); if (local_row) { auto &row_state = local_row->Cast(); - row_state.Finalize(context, interrupt); + row_state.Finalize(context, sink.interrupt_state); row_state.window_tree.Build(); } } @@ -287,24 +281,22 @@ WindowLeadLagExecutor::WindowLeadLagExecutor(BoundWindowExpression &wexpr, Windo : WindowValueExecutor(wexpr, shared) { } -unique_ptr WindowLeadLagExecutor::GetGlobalState(ClientContext &client, - const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { +unique_ptr WindowLeadLagExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -unique_ptr -WindowLeadLagExecutor::GetLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate) const { +unique_ptr WindowLeadLagExecutor::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { const auto &glstate = gstate.Cast(); return make_uniq(context, glstate); } -void WindowLeadLagExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &glstate = gstate.Cast(); - auto &llstate = lstate.Cast(); +void WindowLeadLagExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &glstate = sink.global_state.Cast(); + auto &llstate = sink.local_state.Cast(); auto &cursor = *llstate.cursor; WindowInputExpression leadlag_offset(eval_chunk, offset_idx); @@ -457,11 +449,10 @@ WindowFirstValueExecutor::WindowFirstValueExecutor(BoundWindowExpression &wexpr, : WindowValueExecutor(wexpr, shared) { } -void WindowFirstValueExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &gvstate = gstate.Cast(); - auto &lvstate = lstate.Cast(); +void WindowFirstValueExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &gvstate = sink.global_state.Cast(); + auto &lvstate = sink.local_state.Cast(); auto &cursor = *lvstate.cursor; auto &bounds = lvstate.bounds; auto &frames = lvstate.frames; @@ -507,11 +498,10 @@ WindowLastValueExecutor::WindowLastValueExecutor(BoundWindowExpression &wexpr, W : WindowValueExecutor(wexpr, shared) { } -void WindowLastValueExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &gvstate = gstate.Cast(); - auto &lvstate = lstate.Cast(); +void WindowLastValueExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &gvstate = sink.global_state.Cast(); + auto &lvstate = sink.local_state.Cast(); auto &cursor = *lvstate.cursor; auto &bounds = lvstate.bounds; auto &frames = lvstate.frames; @@ -567,11 +557,10 @@ WindowNthValueExecutor::WindowNthValueExecutor(BoundWindowExpression &wexpr, Win : WindowValueExecutor(wexpr, shared) { } -void WindowNthValueExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, - idx_t count, idx_t row_idx, InterruptState &interrupt) const { - auto &gvstate = gstate.Cast(); - auto &lvstate = lstate.Cast(); +void WindowNthValueExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { + auto &gvstate = sink.global_state.Cast(); + auto &lvstate = sink.local_state.Cast(); auto &cursor = *lvstate.cursor; auto &bounds = lvstate.bounds; auto &frames = lvstate.frames; @@ -899,16 +888,14 @@ class WindowFillLocalState : public WindowLeadLagLocalState { } //! Finish the sinking and prepare to scan - void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection, - InterruptState &interrupt) override; + void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) override; //! Cursor for the secondary sort values unique_ptr order_cursor; }; -void WindowFillLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - CollectionPtr collection, InterruptState &interrupt) { - WindowLeadLagLocalState::Finalize(context, gstate, collection, interrupt); +void WindowFillLocalState::Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) { + WindowLeadLagLocalState::Finalize(context, collection, sink); // Prepare to scan auto &gfstate = gvstate.Cast(); @@ -917,24 +904,22 @@ void WindowFillLocalState::Finalize(ExecutionContext &context, WindowExecutorGlo } } -unique_ptr WindowFillExecutor::GetGlobalState(ClientContext &client, - const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { +unique_ptr WindowFillExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -unique_ptr WindowFillExecutor::GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const { +unique_ptr WindowFillExecutor::GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const { const auto &gfstate = gstate.Cast(); return make_uniq(context, gfstate); } -void WindowFillExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const { +void WindowFillExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx, OperatorSinkInput &sink) const { - auto &lfstate = lstate.Cast(); + auto &lfstate = sink.local_state.Cast(); auto &cursor = *lfstate.cursor; // Assume the best and just batch copy all the values @@ -948,7 +933,7 @@ void WindowFillExecutor::EvaluateInternal(ExecutionContext &context, WindowExecu } // Missing values - linear interpolation - auto &gfstate = gstate.Cast(); + auto &gfstate = sink.global_state.Cast(); auto partition_begin = FlatVector::GetData(lfstate.bounds.data[PARTITION_BEGIN]); auto partition_end = FlatVector::GetData(lfstate.bounds.data[PARTITION_END]); diff --git a/src/duckdb/src/include/duckdb/function/window/window_aggregate_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_aggregate_function.hpp index 766dc61ef..29c05f97d 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_aggregate_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_aggregate_function.hpp @@ -20,16 +20,13 @@ class WindowAggregateExecutor : public WindowExecutor { WindowAggregationMode mode); void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, - WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - InterruptState &interrupt) const override; - void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - CollectionPtr collection, InterruptState &interrupt) const override; + OperatorSinkInput &sink) const override; + void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) const override; - unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const override; + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; const WindowAggregationMode mode; @@ -40,9 +37,8 @@ class WindowAggregateExecutor : public WindowExecutor { unique_ptr filter_ref; protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_aggregator.hpp index 2edb8745e..631c4596d 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_aggregator.hpp @@ -17,27 +17,6 @@ class WindowCollection; class WindowCursor; struct WindowSharedExpressions; -class WindowAggregatorState { -public: - WindowAggregatorState(); - virtual ~WindowAggregatorState() { - } - - template - TARGET &Cast() { - DynamicCastCheck(this); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - DynamicCastCheck(this); - return reinterpret_cast(*this); - } - - //! Allocator for aggregates - ArenaAllocator allocator; -}; - class WindowAggregator { public: using CollectionPtr = optional_ptr; @@ -110,21 +89,20 @@ class WindowAggregator { virtual ~WindowAggregator(); // Threading states - virtual unique_ptr GetGlobalState(ClientContext &client, idx_t group_count, - const ValidityMask &partition_mask) const; - virtual unique_ptr GetLocalState(const WindowAggregatorState &gstate) const = 0; + virtual unique_ptr GetGlobalState(ClientContext &client, idx_t group_count, + const ValidityMask &partition_mask) const; + virtual unique_ptr GetLocalState(ExecutionContext &context, + const GlobalSinkState &gstate) const = 0; // Build - virtual void Sink(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, - DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, - optional_ptr filter_sel, idx_t filtered, InterruptState &interrupt); - virtual void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats, InterruptState &interrupt); + virtual void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + optional_ptr filter_sel, idx_t filtered, OperatorSinkInput &sink); + virtual void Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink); // Probe - virtual void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, - InterruptState &interrupt) const = 0; + virtual void Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, + idx_t row_idx, OperatorSinkInput &sink) const = 0; //! The window function const BoundWindowExpression &wexpr; @@ -142,22 +120,16 @@ class WindowAggregator { vector child_idx; }; -class WindowAggregatorGlobalState : public WindowAggregatorState { +class WindowAggregatorGlobalState : public GlobalSinkState { public: - WindowAggregatorGlobalState(ClientContext &client, const WindowAggregator &aggregator_p, idx_t group_count) - : client(client), aggregator(aggregator_p), aggr(aggregator.wexpr), locals(0), finalized(0) { - - if (aggr.filter) { - // Start with all invalid and set the ones that pass - filter_mask.Initialize(group_count, false); - } else { - filter_mask.InitializeEmpty(group_count); - } - } + WindowAggregatorGlobalState(ClientContext &client, const WindowAggregator &aggregator_p, idx_t group_count); //! The client we are in ClientContext &client; + //! Global allocator + ArenaAllocator allocator; + //! The aggregator data const WindowAggregator &aggregator; @@ -177,19 +149,21 @@ class WindowAggregatorGlobalState : public WindowAggregatorState { std::atomic finalized; }; -class WindowAggregatorLocalState : public WindowAggregatorState { +class WindowAggregatorLocalState : public LocalSinkState { public: using CollectionPtr = optional_ptr; static void InitSubFrames(SubFrames &frames, const WindowExcludeMode exclude_mode); - WindowAggregatorLocalState() { - } + explicit WindowAggregatorLocalState(ExecutionContext &context); void Sink(ExecutionContext &context, WindowAggregatorGlobalState &gastate, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t row_idx); virtual void Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, CollectionPtr collection); + //! Global allocator + ArenaAllocator allocator; + //! The state used for reading the collection unique_ptr cursor; }; diff --git a/src/duckdb/src/include/duckdb/function/window/window_constant_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_constant_aggregator.hpp index a81075b62..2e5cc2cc3 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_constant_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_constant_aggregator.hpp @@ -22,18 +22,16 @@ class WindowConstantAggregator : public WindowAggregator { ~WindowConstantAggregator() override { } - unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &partition_mask) const override; - void Sink(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, - DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered, InterruptState &interrupt) override; - void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats, InterruptState &interrupt) override; - - unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, - InterruptState &interrupt) const override; + unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, + const ValidityMask &partition_mask) const override; + void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + optional_ptr filter_sel, idx_t filtered, OperatorSinkInput &sink) override; + void Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink) override; + + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; + void Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_custom_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_custom_aggregator.hpp index a86e1fe94..6ba86c75c 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_custom_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_custom_aggregator.hpp @@ -19,15 +19,14 @@ class WindowCustomAggregator : public WindowAggregator { WindowCustomAggregator(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared); ~WindowCustomAggregator() override; - unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &partition_mask) const override; - void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats, InterruptState &interrupt) override; - - unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, - InterruptState &interrupt) const override; + unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, + const ValidityMask &partition_mask) const override; + void Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink) override; + + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; + void Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_distinct_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_distinct_aggregator.hpp index a2a19929a..0904bbfdd 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_distinct_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_distinct_aggregator.hpp @@ -20,19 +20,17 @@ class WindowDistinctAggregator : public WindowAggregator { ClientContext &client); // Build - unique_ptr GetGlobalState(ClientContext &client, idx_t group_count, - const ValidityMask &partition_mask) const override; - void Sink(ExecutionContext &context, WindowAggregatorState &gsink, WindowAggregatorState &lstate, - DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered, InterruptState &interrupt) override; - void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats, InterruptState &interrupt) override; + unique_ptr GetGlobalState(ClientContext &client, idx_t group_count, + const ValidityMask &partition_mask) const override; + void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + optional_ptr filter_sel, idx_t filtered, OperatorSinkInput &sink) override; + void Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink) override; // Evaluate - unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, - InterruptState &interrupt) const override; + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; + void Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; //! Context for sorting ClientContext &context; diff --git a/src/duckdb/src/include/duckdb/function/window/window_executor.hpp b/src/duckdb/src/include/duckdb/function/window/window_executor.hpp index 1459f2d19..ae1d46670 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_executor.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_executor.hpp @@ -8,37 +8,19 @@ #pragma once +#include "duckdb/execution/physical_operator_states.hpp" #include "duckdb/function/window/window_boundaries_state.hpp" #include "duckdb/function/window/window_collection.hpp" namespace duckdb { class WindowCollection; -class InterruptState; struct WindowSharedExpressions; -class WindowExecutorState { -public: - WindowExecutorState() {}; - virtual ~WindowExecutorState() { - } - - template - TARGET &Cast() { - DynamicCastCheck(this); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - DynamicCastCheck(this); - return reinterpret_cast(*this); - } -}; - class WindowExecutor; -class WindowExecutorGlobalState : public WindowExecutorState { +class WindowExecutorGlobalState : public GlobalSinkState { public: using CollectionPtr = optional_ptr; @@ -54,16 +36,15 @@ class WindowExecutorGlobalState : public WindowExecutorState { vector arg_types; }; -class WindowExecutorLocalState : public WindowExecutorState { +class WindowExecutorLocalState : public LocalSinkState { public: using CollectionPtr = optional_ptr; WindowExecutorLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate); - virtual void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, InterruptState &interrupt); - virtual void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection, - InterruptState &interrupt); + virtual void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + OperatorSinkInput &sink); + virtual void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink); //! The state used for reading the range collection unique_ptr range_cursor; @@ -95,21 +76,18 @@ class WindowExecutor { virtual bool IgnoreNulls() const; - virtual unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const; - virtual unique_ptr GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const; + virtual unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const; + virtual unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const; virtual void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, - WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - InterruptState &interrupt) const; + OperatorSinkInput &sink) const; - virtual void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, CollectionPtr collection, InterruptState &interrupt) const; + virtual void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) const; void Evaluate(ExecutionContext &context, idx_t row_idx, DataChunk &eval_chunk, Vector &result, - WindowExecutorLocalState &lstate, WindowExecutorGlobalState &gstate, InterruptState &interrupt) const; + OperatorSinkInput &sink) const; // The function const BoundWindowExpression &wexpr; @@ -123,9 +101,8 @@ class WindowExecutor { column_t range_idx = DConstants::INVALID_INDEX; protected: - virtual void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const = 0; + virtual void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx, OperatorSinkInput &sink) const = 0; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_index_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_index_tree.hpp index e131bba1b..058b9929f 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_index_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_index_tree.hpp @@ -33,7 +33,7 @@ class WindowIndexTree : public WindowMergeSortTree { const idx_t count); ~WindowIndexTree() override = default; - unique_ptr GetLocalState(ExecutionContext &context) override; + unique_ptr GetLocalState(ExecutionContext &context) override; //! Find the Nth index in the set of subframes //! Returns {nth index, 0} or {nth offset, overflow} diff --git a/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp index a79ce50a9..b560f2fc2 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp @@ -11,7 +11,6 @@ #include "duckdb/execution/merge_sort_tree.hpp" #include "duckdb/planner/bound_result_modifier.hpp" -#include "duckdb/function/window/window_aggregator.hpp" #include "duckdb/common/sorting/sort.hpp" namespace duckdb { @@ -20,7 +19,7 @@ enum class WindowMergeSortStage : uint8_t { INIT, COMBINE, FINALIZE, SORTED, FIN class WindowMergeSortTree; -class WindowMergeSortTreeLocalState : public WindowAggregatorState { +class WindowMergeSortTreeLocalState : public LocalSinkState { public: WindowMergeSortTreeLocalState(ExecutionContext &context, WindowMergeSortTree &index_tree); @@ -56,7 +55,7 @@ class WindowMergeSortTree { const vector &order_idx, const idx_t count, bool unique = false); virtual ~WindowMergeSortTree() = default; - virtual unique_ptr GetLocalState(ExecutionContext &context) = 0; + virtual unique_ptr GetLocalState(ExecutionContext &context) = 0; //! Make a local sort for a thread optional_ptr InitializeLocalSort(ExecutionContext &context) const; diff --git a/src/duckdb/src/include/duckdb/function/window/window_naive_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_naive_aggregator.hpp index 26dc3a2fd..1801908f8 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_naive_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_naive_aggregator.hpp @@ -20,10 +20,9 @@ class WindowNaiveAggregator : public WindowAggregator { WindowNaiveAggregator(const WindowAggregateExecutor &executor, WindowSharedExpressions &shared); ~WindowNaiveAggregator() override; - unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, - InterruptState &interrupt) const override; + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; + void Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; //! The parent executor const WindowAggregateExecutor &executor; diff --git a/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp index ebed62dc0..46678c4ad 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp @@ -16,11 +16,10 @@ class WindowPeerExecutor : public WindowExecutor { public: WindowPeerExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); - unique_ptr GetGlobalState(ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetGlobalState(ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const override; + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; //! The column indices of any ORDER BY argument expressions vector arg_order_idx; @@ -31,9 +30,8 @@ class WindowRankExecutor : public WindowPeerExecutor { WindowRankExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; class WindowDenseRankExecutor : public WindowPeerExecutor { @@ -41,9 +39,8 @@ class WindowDenseRankExecutor : public WindowPeerExecutor { WindowDenseRankExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; class WindowPercentRankExecutor : public WindowPeerExecutor { @@ -51,9 +48,8 @@ class WindowPercentRankExecutor : public WindowPeerExecutor { WindowPercentRankExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; class WindowCumeDistExecutor : public WindowPeerExecutor { @@ -61,9 +57,8 @@ class WindowCumeDistExecutor : public WindowPeerExecutor { WindowCumeDistExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp index f548a6e30..dc6ce6ea9 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp @@ -16,11 +16,10 @@ class WindowRowNumberExecutor : public WindowExecutor { public: WindowRowNumberExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); - unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const override; + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; //! The evaluation index of the NTILE column column_t ntile_idx = DConstants::INVALID_INDEX; @@ -28,9 +27,8 @@ class WindowRowNumberExecutor : public WindowExecutor { vector arg_order_idx; protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; // NTILE is just scaled ROW_NUMBER @@ -39,9 +37,8 @@ class WindowNtileExecutor : public WindowRowNumberExecutor { WindowNtileExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_segment_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_segment_tree.hpp index bad3aebb2..40fcb4155 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_segment_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_segment_tree.hpp @@ -18,15 +18,14 @@ class WindowSegmentTree : public WindowAggregator { WindowSegmentTree(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared); - unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &partition_mask) const override; - unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats, InterruptState &interrupt) override; - - void Evaluate(ExecutionContext &context, const WindowAggregatorState &gstate, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, - InterruptState &interrupt) const override; + unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, + const ValidityMask &partition_mask) const override; + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; + void Finalize(ExecutionContext &context, CollectionPtr collection, const FrameStats &stats, + OperatorSinkInput &sink) override; + + void Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp index e6cd41869..c2790d965 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp @@ -24,7 +24,7 @@ class WindowTokenTree : public WindowMergeSortTree { : WindowTokenTree(context, order_bys.orders, order_idx, count, unique) { } - unique_ptr GetLocalState(ExecutionContext &context) override; + unique_ptr GetLocalState(ExecutionContext &context) override; //! Thread-safe post-sort cleanup void Finished() override; diff --git a/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp index de581a915..4504cfc30 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp @@ -17,14 +17,12 @@ class WindowValueExecutor : public WindowExecutor { public: WindowValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); - void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - CollectionPtr collection, InterruptState &interrupt) const override; + void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) const override; - unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const override; + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; //! The column index of the value column column_t child_idx = DConstants::INVALID_INDEX; @@ -42,16 +40,14 @@ class WindowLeadLagExecutor : public WindowValueExecutor { public: WindowLeadLagExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); - unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const override; + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; class WindowFirstValueExecutor : public WindowValueExecutor { @@ -59,9 +55,8 @@ class WindowFirstValueExecutor : public WindowValueExecutor { WindowFirstValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; class WindowLastValueExecutor : public WindowValueExecutor { @@ -69,9 +64,8 @@ class WindowLastValueExecutor : public WindowValueExecutor { WindowLastValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; class WindowNthValueExecutor : public WindowValueExecutor { @@ -79,9 +73,8 @@ class WindowNthValueExecutor : public WindowValueExecutor { WindowNthValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; class WindowFillExecutor : public WindowValueExecutor { @@ -93,19 +86,17 @@ class WindowFillExecutor : public WindowValueExecutor { return false; } - unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(ExecutionContext &context, - const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const override; + unique_ptr GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const override; //! Secondary order collection index idx_t order_idx = DConstants::INVALID_INDEX; protected: - void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx, InterruptState &interrupt) const override; + void EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, + OperatorSinkInput &sink) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/client_context.hpp b/src/duckdb/src/include/duckdb/main/client_context.hpp index 15beb4699..1b48b9330 100644 --- a/src/duckdb/src/include/duckdb/main/client_context.hpp +++ b/src/duckdb/src/include/duckdb/main/client_context.hpp @@ -224,8 +224,7 @@ class ClientContext : public enable_shared_from_this { private: //! Parse statements and resolve pragmas from a query - bool ParseStatements(ClientContextLock &lock, const string &query, vector> &result, - ErrorData &error); + vector> ParseStatements(ClientContextLock &lock, const string &query); //! Issues a query to the database and returns a Pending Query Result unique_ptr PendingQueryInternal(ClientContextLock &lock, unique_ptr statement, const PendingQueryParameters ¶meters, bool verify = true); diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/date_trunc_simplification.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/date_trunc_simplification.hpp new file mode 100644 index 000000000..b9e03a295 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/date_trunc_simplification.hpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/date_trunc_simplification.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enum_util.hpp" +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// DateTruncSimplificationRule rewrites an expression of the form +// +// date_trunc(part, column) const_rhs +// +// such that the date_trunc is instead applied to the constant RHS and then simplified further. +// The rules applied are as follows: +// +// date_trunc(part, column) >= const_rhs --> column >= date_trunc(part, date_add(const_rhs, INTERVAL 1 part)) +// - but if date_trunc(const_rhs) = const_rhs, then we can do column >= const_rhs +// +// date_trunc(part, column) <= const_rhs --> column <= date_trunc(part, date_add(const_rhs, INTERVAL 1 part)) +// +// date_trunc(part, column) > const_rhs --> column >= date_trunc(part, date_add(const_rhs, INTERVAL 1 part)) +// - note the change from > to >=! +// +// date_trunc(part, column) < const_rhs --> column < date_trunc(part, date_add(const_rhs, INTERVAL 1 part)) +// - but if date_trunc(const_rhs) = const_rhs, then we can do column < const_rhs +// +// date_trunc(part, column) == const_rhs --> column >= date_trunc(part, const_rhs) AND +// column < date_trunc(part, date_add(const_rhs, INTERVAL 1 part)) +// - but if date_trunc(const_rhs) != const_rhs, then this one is unsatisfiable +// +// date_trunc(part, column) <> const_rhs --> column < date_trunc(part, const_rhs) OR +// column >= date_trunc(part, date_add(const_rhs, INTERVAL 1 part)) +// - but if date_trunc(const_rhs) != const_rhs, then this is always satisfied +// +// date_trunc(part, column) IS NOT DISTINCT FROM const_rhs --> (column >= date_trunc(part, const_rhs) AND +// column < date_trunc(part, +// date_add(const_rhs, INTERVAL 1 part)) AND +// column IS NOT NULL) +// - but if const_rhs is NULL, then this is just 'column IS NULL' +// +// date_trunc(part, column) IS DISTINCT FROM const_rhs --> (column < date_trunc(part, const_rhs) OR +// column >= date_trunc(part, +// date_add(const_rhs, INTERVAL 1 part)) OR +// column IS NULL) +// - but if const_rhs is NULL, then this is just 'column IS NOT NULL' +// +class DateTruncSimplificationRule : public Rule { +public: + explicit DateTruncSimplificationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; + + static string DatePartToFunc(const DatePartSpecifier &date_part); + + unique_ptr CreateTrunc(const BoundConstantExpression &date_part, const BoundConstantExpression &rhs, + const LogicalType &return_type); + unique_ptr CreateTruncAdd(const BoundConstantExpression &date_part, const BoundConstantExpression &rhs, + const LogicalType &return_type); + + bool DateIsTruncated(const BoundConstantExpression &date_part, const BoundConstantExpression &rhs); + + unique_ptr CastAndEvaluate(unique_ptr rhs, const LogicalType &return_type); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp index 361de546d..4b0099cd8 100644 --- a/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp @@ -4,6 +4,7 @@ #include "duckdb/optimizer/rule/conjunction_simplification.hpp" #include "duckdb/optimizer/rule/constant_folding.hpp" #include "duckdb/optimizer/rule/date_part_simplification.hpp" +#include "duckdb/optimizer/rule/date_trunc_simplification.hpp" #include "duckdb/optimizer/rule/distributivity.hpp" #include "duckdb/optimizer/rule/empty_needle_removal.hpp" #include "duckdb/optimizer/rule/like_optimizations.hpp" diff --git a/src/duckdb/src/main/client_context.cpp b/src/duckdb/src/main/client_context.cpp index c4d98192c..606d86420 100644 --- a/src/duckdb/src/main/client_context.cpp +++ b/src/duckdb/src/main/client_context.cpp @@ -642,13 +642,19 @@ vector> ClientContext::ParseStatements(const string &qu } vector> ClientContext::ParseStatementsInternal(ClientContextLock &lock, const string &query) { - Parser parser(GetParserOptions()); - parser.ParseQuery(query); + try { + Parser parser(GetParserOptions()); + parser.ParseQuery(query); - PragmaHandler handler(*this); - handler.HandlePragmaStatements(lock, parser.statements); + PragmaHandler handler(*this); + handler.HandlePragmaStatements(lock, parser.statements); - return std::move(parser.statements); + return std::move(parser.statements); + } catch (std::exception &ex) { + auto error = ErrorData(ex); + ProcessError(error, query); + error.Throw(); + } } void ClientContext::HandlePragmaStatements(vector> &statements) { @@ -956,10 +962,11 @@ unique_ptr ClientContext::Query(unique_ptr statement, unique_ptr ClientContext::Query(const string &query, bool allow_stream_result) { auto lock = LockContext(); - ErrorData error; vector> statements; - if (!ParseStatements(*lock, query, statements, error)) { - return ErrorResult(std::move(error), query); + try { + statements = ParseStatements(*lock, query); + } catch (const std::exception &ex) { + return ErrorResult(ErrorData(ex), query); } if (statements.empty()) { // no statements, return empty successful result @@ -1012,17 +1019,10 @@ unique_ptr ClientContext::Query(const string &query, bool allow_str return result; } -bool ClientContext::ParseStatements(ClientContextLock &lock, const string &query, - vector> &result, ErrorData &error) { - try { - InitialCleanup(lock); - // parse the query and transform it into a set of statements - result = ParseStatementsInternal(lock, query); - return true; - } catch (std::exception &ex) { - error = ErrorData(ex); - return false; - } +vector> ClientContext::ParseStatements(ClientContextLock &lock, const string &query) { + InitialCleanup(lock); + // parse the query and transform it into a set of statements + return ParseStatementsInternal(lock, query); } unique_ptr ClientContext::PendingQuery(const string &query, bool allow_stream_result) { diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp index f5938010e..64bd5d7ed 100644 --- a/src/duckdb/src/main/config.cpp +++ b/src/duckdb/src/main/config.cpp @@ -617,7 +617,7 @@ idx_t DBConfig::ParseMemoryLimit(const string &arg) { } else if (unit == "tib") { multiplier = 1024LL * 1024LL * 1024LL * 1024LL; } else { - throw ParserException("Unknown unit for memory_limit: '%s' (expected: KB, MB, GB, TB for 1000^i units or KiB, " + throw ParserException("Unknown unit for memory: '%s' (expected: KB, MB, GB, TB for 1000^i units or KiB, " "MiB, GiB, TiB for 1024^i units)", unit); } diff --git a/src/duckdb/src/optimizer/optimizer.cpp b/src/duckdb/src/optimizer/optimizer.cpp index 34d2f44bb..28ee7cbeb 100644 --- a/src/duckdb/src/optimizer/optimizer.cpp +++ b/src/duckdb/src/optimizer/optimizer.cpp @@ -46,6 +46,7 @@ Optimizer::Optimizer(Binder &binder, ClientContext &context) : context(context), rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); diff --git a/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp b/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp new file mode 100644 index 000000000..2004e97fa --- /dev/null +++ b/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp @@ -0,0 +1,449 @@ +#include "duckdb/optimizer/rule/date_trunc_simplification.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/function/cast/default_casts.hpp" + +namespace duckdb { + +DateTruncSimplificationRule::DateTruncSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto op = make_uniq(); + + auto lhs = make_uniq(); + lhs->function = make_uniq(unordered_set {"date_trunc", "datetrunc"}); + lhs->matchers.push_back(make_uniq()); + lhs->matchers.push_back(make_uniq()); + lhs->policy = SetMatcher::Policy::ORDERED; + + auto rhs = make_uniq(); + + op->matchers.push_back(std::move(lhs)); + op->matchers.push_back(std::move(rhs)); + op->policy = SetMatcher::Policy::UNORDERED; + + root = std::move(op); +} + +unique_ptr DateTruncSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &expr = bindings[0].get().Cast(); + auto comparison_type = expr.GetExpressionType(); + + auto &date_part = bindings[2].get().Cast(); + // We must have only a column on the LHS. + if (bindings[3].get().GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return nullptr; + } + + auto &column_part = bindings[3].get().Cast(); + auto &rhs = bindings[4].get().Cast(); + + // Determine whether or not the column name is on the lhs or rhs. + const bool col_is_lhs = (expr.left->GetExpressionClass() == ExpressionClass::BOUND_FUNCTION); + + // We want to treat rhs >= col equivalently to col <= rhs. + // So, get the expression type if it was ordered such that the constant was actually on the right hand side. + ExpressionType rhs_comparison_type = comparison_type; + if (!col_is_lhs) { + rhs_comparison_type = FlipComparisonExpression(comparison_type); + } + + // Check whether trunc(date_part, constant_rhs) = constant_rhs. + const bool is_truncated = DateIsTruncated(date_part, rhs); + + switch (rhs_comparison_type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + // We handle two very similar optimizations here: + // + // date_trunc(part, column) = constant_rhs --> column >= date_trunc(part, constant_rhs) AND + // column < date_trunc(part, date_add(constant_rhs, + // INTERVAL 1 part)) + // or, if date_trunc(part, constant_rhs) <> constant_rhs, this is unsatisfiable + // + // ---- + // + // date_trunc(part, column) IS NOT DISTINCT FROM constant_rhs + // + // Here we have two cases: when constant_rhs is NULL, this simplifies to: + // + // column IS NULL + // + // Otherwise, the expression becomes: + // + // (column >= date_trunc(part, constant_rhs) AND + // column < date_trunc(part, date_add(constant_rhs, INTERVAL 1 part)) AND + // column IS NOT NULL) + // + { + // First check if we can just return `column IS NULL`. + if (rhs_comparison_type == ExpressionType::COMPARE_NOT_DISTINCT_FROM && rhs.value.IsNull()) { + auto op = make_uniq(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN); + op->children.push_back(column_part.Copy()); + return std::move(op); + } else { + if (!is_truncated) { + return make_uniq(Value::BOOLEAN(false)); + } + + auto trunc = CreateTrunc(date_part, rhs, column_part.return_type); + if (!trunc) { + return nullptr; + } + + auto trunc_add = CreateTruncAdd(date_part, rhs, column_part.return_type); + if (!trunc_add) { + return nullptr; + } + + auto gteq = make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, + column_part.Copy(), std::move(trunc)); + auto lt = make_uniq(ExpressionType::COMPARE_LESSTHAN, column_part.Copy(), + std::move(trunc_add)); + + // For IS NOT DISTINCT FROM, we also have to add the extra NULL term. + if (rhs_comparison_type == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + auto comp = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(gteq), + std::move(lt)); + + auto isnotnull = + make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); + isnotnull->children.push_back(column_part.Copy()); + + return make_uniq(ExpressionType::CONJUNCTION_AND, std::move(comp), + std::move(isnotnull)); + } else { + return make_uniq(ExpressionType::CONJUNCTION_AND, std::move(gteq), + std::move(lt)); + } + } + } + + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_DISTINCT_FROM: + // We handle two very similar optimizations here: + // + // date_trunc(part, column) <> constant_rhs --> column < date_trunc(part, constant_rhs) OR + // column >= date_trunc(part, date_add(constant_rhs, + // INTERVAL 1 part)) + // or, if date_trunc(part, constant_rhs) <> constant_rhs, this is always true + // + // ---- + // + // date_trunc(part, column) IS DISTINCT FROM constant_rhs + // + // Here we have two cases: when constant_rhs is NULL, this simplifies to: + // + // column IS NOT NULL + // + // Otherwise, the expression becomes: + // + // (column < date_trunc(part, constant_rhs) OR + // column >= date_trunc(part, date_add(constant_rhs, INTERVAL 1 part)) OR + // column IS NULL) + // + { + if (rhs_comparison_type == ExpressionType::COMPARE_DISTINCT_FROM && rhs.value.IsNull()) { + // Return 'column IS NOT NULL'. + auto op = + make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); + op->children.push_back(column_part.Copy()); + return std::move(op); + } else { + if (!is_truncated) { + return make_uniq(Value::BOOLEAN(true)); + } + + auto trunc = CreateTrunc(date_part, rhs, column_part.return_type); + if (!trunc) { + return nullptr; + } + + auto trunc_add = CreateTruncAdd(date_part, rhs, column_part.return_type); + if (!trunc_add) { + return nullptr; + } + + auto lt = make_uniq(ExpressionType::COMPARE_LESSTHAN, column_part.Copy(), + std::move(trunc)); + auto gteq = make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, + column_part.Copy(), std::move(trunc_add)); + + // If this is a DISTINCT FROM, we need to add the 'column IS NULL' term. + if (rhs_comparison_type == ExpressionType::COMPARE_DISTINCT_FROM) { + auto comp = make_uniq(ExpressionType::CONJUNCTION_OR, std::move(gteq), + std::move(lt)); + + auto isnull = + make_uniq(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN); + isnull->children.push_back(column_part.Copy()); + + return make_uniq(ExpressionType::CONJUNCTION_OR, std::move(comp), + std::move(isnull)); + } else { + return make_uniq(ExpressionType::CONJUNCTION_OR, std::move(gteq), + std::move(lt)); + } + } + } + return nullptr; + + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + // date_trunc(part, column) < constant_rhs --> column < date_trunc(part, date_add(constant_rhs, + // INTERVAL 1 part)) + // date_trunc(part, column) >= constant_rhs --> column >= date_trunc(part, date_add(constant_rhs, + // INTERVAL 1 part)) + { + // The optimization for < and >= is a little tricky: if trunc(rhs) = rhs, then we need to just + // use the rhs as-is, instead of using trunc(rhs + 1 date_part). + if (!is_truncated) { + // Create date_trunc(part, date_add(rhs, INTERVAL 1 part)) and fold the constant. + auto trunc = CreateTruncAdd(date_part, rhs, column_part.return_type); + if (!trunc) { + return nullptr; // Something went wrong---don't do the optimization. + } + + if (col_is_lhs) { + expr.left = column_part.Copy(); + expr.right = std::move(trunc); + } else { + expr.right = column_part.Copy(); + expr.left = std::move(trunc); + } + } else { + // If the RHS is already truncated (i.e. date_trunc(part, rhs) = rhs), then we can use + // it as-is. + if (col_is_lhs) { + expr.left = column_part.Copy(); + // Determine whether the RHS needs to be casted. + if (rhs.return_type.id() != expr.left->return_type.id()) { + expr.right = CastAndEvaluate(std::move(expr.right), expr.left->return_type); + } + } else { + expr.right = column_part.Copy(); + // Determine whether the RHS needs to be casted. + if (rhs.return_type.id() != expr.right->return_type.id()) { + expr.left = CastAndEvaluate(std::move(expr.left), expr.right->return_type); + } + } + } + + changes_made = true; + return nullptr; + } + + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHAN: + // date_trunc(part, column) <= constant_rhs --> column <= date_trunc(part, date_add(constant_rhs, + // INTERVAL 1 part)) + // date_trunc(part, column) > constant_rhs --> column >= date_trunc(part, date_add(constant_rhs, + // INTERVAL 1 part)) + { + // Create date_trunc(part, date_add(rhs, INTERVAL 1 part)) and fold the constant. + auto trunc = CreateTruncAdd(date_part, rhs, column_part.return_type); + if (!trunc) { + return nullptr; // Something went wrong---don't do the optimization. + } + + if (col_is_lhs) { + expr.left = column_part.Copy(); + expr.right = std::move(trunc); + } else { + expr.right = column_part.Copy(); + expr.left = std::move(trunc); + } + + // If this is a >, we need to change it to >= for correctness. + if (rhs_comparison_type == ExpressionType::COMPARE_GREATERTHAN) { + if (col_is_lhs) { + expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_GREATERTHANOREQUALTO); + } else { + expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_LESSTHANOREQUALTO); + } + } + + changes_made = true; + return nullptr; + } + + default: + return nullptr; + } +} + +string DateTruncSimplificationRule::DatePartToFunc(const DatePartSpecifier &date_part) { + switch (date_part) { + // These specifiers can be used as intervals. + case DatePartSpecifier::YEAR: + return "to_years"; + case DatePartSpecifier::MONTH: + return "to_months"; + case DatePartSpecifier::DAY: + return "to_days"; + case DatePartSpecifier::DECADE: + return "to_decades"; + case DatePartSpecifier::CENTURY: + return "to_centuries"; + case DatePartSpecifier::MILLENNIUM: + return "to_millennia"; + case DatePartSpecifier::MICROSECONDS: + return "to_microseconds"; + case DatePartSpecifier::MILLISECONDS: + return "to_milliseconds"; + case DatePartSpecifier::SECOND: + return "to_seconds"; + case DatePartSpecifier::MINUTE: + return "to_minutes"; + case DatePartSpecifier::HOUR: + return "to_hours"; + case DatePartSpecifier::WEEK: + return "to_weeks"; + case DatePartSpecifier::QUARTER: + return "to_quarters"; + + // These specifiers cannot be used as intervals and can only be used as + // date parts. + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::ISOYEAR: + case DatePartSpecifier::YEARWEEK: + case DatePartSpecifier::ERA: + case DatePartSpecifier::TIMEZONE: + case DatePartSpecifier::TIMEZONE_HOUR: + case DatePartSpecifier::TIMEZONE_MINUTE: + default: + return ""; + } +} + +unique_ptr DateTruncSimplificationRule::CreateTrunc(const BoundConstantExpression &date_part, + const BoundConstantExpression &rhs, + const LogicalType &return_type) { + FunctionBinder binder(rewriter.context); + ErrorData error; + + vector> args; + args.emplace_back(std::move(date_part.Copy())); + args.emplace_back(std::move(rhs.Copy())); + auto trunc = binder.BindScalarFunction(DEFAULT_SCHEMA, "date_trunc", std::move(args), error); + + // Ensure that the RHS type matches the column type. + if (trunc->return_type.id() != return_type.id()) { + trunc = BoundCastExpression::AddDefaultCastToType(std::move(trunc), return_type, true); + } + + if (trunc->IsFoldable()) { + Value result; + if (!ExpressionExecutor::TryEvaluateScalar(rewriter.context, *trunc, result)) { + return std::move(trunc); + } + + return make_uniq(result); + } + + return std::move(trunc); +} + +unique_ptr DateTruncSimplificationRule::CreateTruncAdd(const BoundConstantExpression &date_part, + const BoundConstantExpression &rhs, + const LogicalType &return_type) { + DatePartSpecifier part = GetDatePartSpecifier(StringValue::Get(date_part.value)); + const string interval_func_name = DatePartToFunc(part); + + // If the date part cannot be represented as an interval, then we cannot + // perform the optimization. + if (interval_func_name.empty()) { + return nullptr; + } + + FunctionBinder binder(rewriter.context); + ErrorData error; + + vector> args1; + auto constant_param = make_uniq(Value::INTEGER(1)); + args1.emplace_back(std::move(constant_param)); + auto interval = binder.BindScalarFunction(DEFAULT_SCHEMA, interval_func_name, std::move(args1), error); + if (!interval) { + return nullptr; // Something wrong---just don't do the optimization. + } + + vector> args2; + args2.emplace_back(std::move(rhs.Copy())); + args2.emplace_back(std::move(interval)); + auto add = binder.BindScalarFunction(DEFAULT_SCHEMA, "+", std::move(args2), error); + + vector> args3; + args3.emplace_back(std::move(date_part.Copy())); + args3.emplace_back(std::move(add)); + auto trunc = binder.BindScalarFunction(DEFAULT_SCHEMA, "date_trunc", std::move(args3), error); + + // Ensure that the RHS type matches the column type. + if (trunc->return_type.id() != return_type.id()) { + trunc = BoundCastExpression::AddDefaultCastToType(std::move(trunc), return_type, true); + } + + if (trunc->IsFoldable()) { + Value result; + if (!ExpressionExecutor::TryEvaluateScalar(rewriter.context, *trunc, result)) { + return std::move(trunc); + } + + return make_uniq(result); + } + + return std::move(trunc); +} + +bool DateTruncSimplificationRule::DateIsTruncated(const BoundConstantExpression &date_part, + const BoundConstantExpression &rhs) { + // If the rhs is null, then the date is "truncated" in the sense that date_trunc(..., NULL) is also NULL. + if (rhs.value.IsNull()) { + return true; + } + + // Create the node date_trunc(date_part, rhs). + auto trunc = CreateTrunc(date_part, rhs, rhs.return_type); + + Value trunc_result, result; + if (!ExpressionExecutor::TryEvaluateScalar(rewriter.context, *trunc, trunc_result)) { + return false; + } + if (!ExpressionExecutor::TryEvaluateScalar(rewriter.context, rhs, result)) { + return false; + } + + return (result == trunc_result); +} + +unique_ptr DateTruncSimplificationRule::CastAndEvaluate(unique_ptr rhs, + const LogicalType &return_type) { + auto cast = BoundCastExpression::AddDefaultCastToType(std::move(rhs), return_type, true); + if (cast->IsFoldable()) { + Value result; + if (!ExpressionExecutor::TryEvaluateScalar(rewriter.context, *cast, result)) { + return cast; + } + + return make_uniq(result); + } + + return cast; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/bitpacking.cpp b/src/duckdb/src/storage/compression/bitpacking.cpp index fd30ef195..fa1ffaeba 100644 --- a/src/duckdb/src/storage/compression/bitpacking.cpp +++ b/src/duckdb/src/storage/compression/bitpacking.cpp @@ -257,7 +257,10 @@ struct BitpackingState { BitpackingPrimitives::MinimumBitWidth(static_cast(min_max_delta_diff)); auto regular_required_bitwidth = BitpackingPrimitives::MinimumBitWidth(min_max_diff); - if (delta_required_bitwidth < regular_required_bitwidth && mode != BitpackingMode::FOR) { + //! `min_max_diff` is uninitialized if `can_do_for` isn't true + bool prefer_for = can_do_for && delta_required_bitwidth >= regular_required_bitwidth; + + if (!prefer_for && mode != BitpackingMode::FOR) { SubtractFrameOfReference(delta_buffer, minimum_delta); OP::WriteDeltaFor(reinterpret_cast(delta_buffer), compression_buffer_validity, diff --git a/src/duckdb/src/storage/compression/zstd.cpp b/src/duckdb/src/storage/compression/zstd.cpp index ac0d5ef24..408855284 100644 --- a/src/duckdb/src/storage/compression/zstd.cpp +++ b/src/duckdb/src/storage/compression/zstd.cpp @@ -900,16 +900,15 @@ struct ZSTDScanState : public SegmentScanState { for (idx_t i = 0; i < count; i++) { uncompressed_length += string_lengths[i]; } - auto empty_string = StringVector::EmptyString(result, uncompressed_length); - auto uncompressed_data = empty_string.GetDataWriteable(); + auto &buffer = StringVector::GetStringBuffer(result); + auto uncompressed_data = buffer.AllocateShrinkableBuffer(uncompressed_length); auto string_data = FlatVector::GetData(result); - DecompressString(scan_state, reinterpret_cast(uncompressed_data), uncompressed_length); + DecompressString(scan_state, uncompressed_data, uncompressed_length); idx_t offset = 0; - auto uncompressed_data_const = empty_string.GetData(); for (idx_t i = 0; i < count; i++) { - string_data[result_offset + i] = string_t(uncompressed_data_const + offset, string_lengths[i]); + string_data[result_offset + i] = string_t(char_ptr_cast(uncompressed_data + offset), string_lengths[i]); offset += string_lengths[i]; } scan_state.scanned_count += count; diff --git a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp index a56d60c21..81e4f54c1 100644 --- a/src/duckdb/src/storage/table/row_group.cpp +++ b/src/duckdb/src/storage/table/row_group.cpp @@ -72,7 +72,10 @@ void RowGroup::MoveToCollection(RowGroupCollection &collection_p, idx_t new_star if (is_loaded && !is_loaded[c]) { // we only need to set the column start position if it is already loaded // if it is not loaded - we will set the correct start position upon loading - continue; + lock_guard l(row_group_lock); + if (!is_loaded[c]) { + continue; + } } columns[c]->SetStart(new_start); } diff --git a/src/duckdb/ub_src_optimizer_rule.cpp b/src/duckdb/ub_src_optimizer_rule.cpp index ee39a8864..3fa057ede 100644 --- a/src/duckdb/ub_src_optimizer_rule.cpp +++ b/src/duckdb/ub_src_optimizer_rule.cpp @@ -10,6 +10,8 @@ #include "src/optimizer/rule/date_part_simplification.cpp" +#include "src/optimizer/rule/date_trunc_simplification.cpp" + #include "src/optimizer/rule/distinct_aggregate_optimizer.cpp" #include "src/optimizer/rule/distributivity.cpp"