Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/duckdb/extension/core_functions/function_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ static const StaticFunctionDefinition core_functions[] = {
DUCKDB_AGGREGATE_FUNCTION_SET(StringAggFun),
DUCKDB_SCALAR_FUNCTION_ALIAS(StrposFun),
DUCKDB_SCALAR_FUNCTION(StructInsertFun),
DUCKDB_SCALAR_FUNCTION(StructUpdateFun),
DUCKDB_AGGREGATE_FUNCTION_SET(SumFun),
DUCKDB_AGGREGATE_FUNCTION_SET(SumNoOverflowFun),
DUCKDB_AGGREGATE_FUNCTION_ALIAS(SumkahanFun),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,14 @@ struct StructInsertFun {
static ScalarFunction GetFunction();
};

// Registration metadata for the struct_update scalar function.
// struct_update(struct, name := value, ...) replaces matching fields of the
// input STRUCT and appends any non-matching named arguments as new fields.
struct StructUpdateFun {
	static constexpr const char *Name = "struct_update";
	static constexpr const char *Parameters = "struct,any";
	// Grammar fix: values are updated *of/in* an existing STRUCT, not added "to" it.
	static constexpr const char *Description = "Updates field(s)/value(s) of an existing STRUCT with the argument values. The entry name(s) will be the bound variable name(s)";
	static constexpr const char *Example = "struct_update({'a': 1}, a := 2)";
	static constexpr const char *Categories = "";

	static ScalarFunction GetFunction();
};

} // namespace duckdb
161 changes: 161 additions & 0 deletions src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#include "core_functions/scalar/struct_functions.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/parser/expression/bound_expression.hpp"
#include "duckdb/function/scalar/nested_functions.hpp"
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/storage/statistics/struct_stats.hpp"
#include "duckdb/planner/expression_binder.hpp"

namespace duckdb {

// Executes struct_update: builds the result STRUCT by referencing each input
// field's vector, substituting the argument vector for every field whose name
// matches a named argument, and appending the remaining named arguments as new
// trailing fields. Field-name matching is case-insensitive.
static void StructUpdateFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	auto &starting_vec = args.data[0];
	starting_vec.Verify(args.size());

	auto &starting_child_entries = StructVector::GetEntries(starting_vec);
	auto &result_child_entries = StructVector::GetEntries(result);

	auto &starting_types = StructType::GetChildTypes(starting_vec.GetType());

	// Map each named argument (its alias) to its argument index.
	auto &func_args = state.expr.Cast<BoundFunctionExpression>().children;
	auto new_entries = case_insensitive_tree_t<idx_t>();
	auto is_new_field = vector<bool>(args.ColumnCount(), true);

	for (idx_t arg_idx = 1; arg_idx < func_args.size(); arg_idx++) {
		auto &new_child = func_args[arg_idx];
		new_entries.emplace(new_child->alias, arg_idx);
	}

	// Assign the original child entries to the STRUCT, substituting updates in place.
	for (idx_t field_idx = 0; field_idx < starting_child_entries.size(); field_idx++) {
		auto &starting_child = starting_child_entries[field_idx];
		// Pass the std::string key directly; no need to go through .c_str(),
		// which forces an extra strlen + temporary string construction.
		auto update = new_entries.find(starting_types[field_idx].first);

		if (update == new_entries.end()) {
			// No update present, copy from source
			result_child_entries[field_idx]->Reference(*starting_child);
		} else {
			// We found a replacement of the same name to update
			auto arg_idx = update->second;
			result_child_entries[field_idx]->Reference(args.data[arg_idx]);
			is_new_field[arg_idx] = false;
		}
	}

	// Assign the new (not updated) children to the end of the result vector.
	for (idx_t arg_idx = 1, field_idx = starting_child_entries.size(); arg_idx < args.ColumnCount(); arg_idx++) {
		if (is_new_field[arg_idx]) {
			result_child_entries[field_idx++]->Reference(args.data[arg_idx]);
		}
	}

	result.Verify(args.size());
	if (args.AllConstant()) {
		result.SetVectorType(VectorType::CONSTANT_VECTOR);
	}
}

// Bind phase for struct_update: validates the arguments (first must be a
// STRUCT, the rest must be uniquely-named `name := value` pairs) and derives
// the result STRUCT type — updated fields keep their position, new fields are
// appended at the end. Fixes the "incomming" typo throughout.
static unique_ptr<FunctionData> StructUpdateBind(ClientContext &context, ScalarFunction &bound_function,
                                                 vector<unique_ptr<Expression>> &arguments) {
	if (arguments.empty()) {
		throw InvalidInputException("Missing required arguments for struct_update function.");
	}
	if (LogicalTypeId::STRUCT != arguments[0]->return_type.id()) {
		throw InvalidInputException("The first argument to struct_update must be a STRUCT");
	}
	if (arguments.size() < 2) {
		throw InvalidInputException("Can't update nothing into a STRUCT");
	}

	child_list_t<LogicalType> new_children;
	auto &existing_children = StructType::GetChildTypes(arguments[0]->return_type);

	auto incoming_children = case_insensitive_tree_t<idx_t>();
	auto is_new_field = vector<bool>(arguments.size(), true);

	// Validate incoming arguments and record names
	for (idx_t arg_idx = 1; arg_idx < arguments.size(); arg_idx++) {
		auto &child = arguments[arg_idx];
		if (child->alias.empty()) {
			throw BinderException("Need named argument for struct update, e.g., a := b");
		} else if (incoming_children.find(child->alias) != incoming_children.end()) {
			throw InvalidInputException("Duplicate named argument provided for %s", child->alias.c_str());
		}
		incoming_children.emplace(child->alias, arg_idx);
	}

	// Existing fields keep their position; updated ones take the new type.
	for (idx_t field_idx = 0; field_idx < existing_children.size(); field_idx++) {
		auto &existing_child = existing_children[field_idx];
		auto update = incoming_children.find(existing_child.first);
		if (update == incoming_children.end()) {
			// No update provided for the named value
			new_children.push_back(make_pair(existing_child.first, existing_child.second));
		} else {
			// Update the struct with the new data of the same name
			auto arg_idx = update->second;
			auto &new_child = arguments[arg_idx];
			new_children.push_back(make_pair(new_child->alias, new_child->return_type));
			is_new_field[arg_idx] = false;
		}
	}

	// Loop through the additional arguments (name/value pairs)
	for (idx_t arg_idx = 1; arg_idx < arguments.size(); arg_idx++) {
		if (is_new_field[arg_idx]) {
			auto &child = arguments[arg_idx];
			new_children.push_back(make_pair(child->alias, child->return_type));
		}
	}

	bound_function.return_type = LogicalType::STRUCT(new_children);
	return make_uniq<VariableReturnBindData>(bound_function.return_type);
}

// Statistics propagation for struct_update: mirrors the bind-time field layout.
// Untouched fields inherit the input struct's child statistics, updated fields
// take the statistics of the replacing argument, and new fields are appended.
// Fixes the "incomming" typo throughout.
unique_ptr<BaseStatistics> StructUpdateStats(ClientContext &context, FunctionStatisticsInput &input) {
	auto &child_stats = input.child_stats;
	auto &expr = input.expr;

	auto incoming_children = case_insensitive_tree_t<idx_t>();
	auto is_new_field = vector<bool>(expr.children.size(), true);
	auto new_stats = StructStats::CreateUnknown(expr.return_type);

	// Map each named argument (its alias) to its argument index.
	for (idx_t arg_idx = 1; arg_idx < expr.children.size(); arg_idx++) {
		auto &new_child = expr.children[arg_idx];
		incoming_children.emplace(new_child->alias, arg_idx);
	}

	auto existing_type = child_stats[0].GetType();
	auto existing_count = StructType::GetChildCount(existing_type);
	auto existing_stats = StructStats::GetChildStats(child_stats[0]);
	for (idx_t field_idx = 0; field_idx < existing_count; field_idx++) {
		auto &existing_child = existing_stats[field_idx];
		auto update = incoming_children.find(StructType::GetChildName(existing_type, field_idx));
		if (update == incoming_children.end()) {
			// No update for this field: keep the original child statistics.
			StructStats::SetChildStats(new_stats, field_idx, existing_child);
		} else {
			// Updated field: adopt the replacing argument's statistics.
			auto arg_idx = update->second;
			StructStats::SetChildStats(new_stats, field_idx, child_stats[arg_idx]);
			is_new_field[arg_idx] = false;
		}
	}

	// New fields go to the end, in argument order — same layout as the bind.
	for (idx_t arg_idx = 1, field_idx = existing_count; arg_idx < expr.children.size(); arg_idx++) {
		if (is_new_field[arg_idx]) {
			StructStats::SetChildStats(new_stats, field_idx++, child_stats[arg_idx]);
		}
	}

	return new_stats.ToUnique();
}

ScalarFunction StructUpdateFun::GetFunction() {
	// struct_update takes a variable number of named arguments, so the fixed
	// argument list is empty and varargs is ANY; the concrete STRUCT return
	// type is resolved in the bind function.
	ScalarFunction struct_update({}, LogicalTypeId::STRUCT, StructUpdateFunction, StructUpdateBind, nullptr,
	                             StructUpdateStats);
	struct_update.varargs = LogicalType::ANY;
	struct_update.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
	struct_update.serialize = VariableReturnBindData::Serialize;
	struct_update.deserialize = VariableReturnBindData::Deserialize;
	return struct_update;
}

} // namespace duckdb
3 changes: 2 additions & 1 deletion src/duckdb/extension/parquet/include/column_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ class ColumnWriter {
throw NotImplementedException("Writer does not need analysis");
}

virtual void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) = 0;
virtual void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count,
bool vector_can_span_multiple_pages) = 0;

virtual void BeginWrite(ColumnWriterState &state) = 0;
virtual void Write(ColumnWriterState &state, Vector &vector, idx_t count) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class ArrayColumnWriter : public ListColumnWriter {

public:
void Analyze(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override;
void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override;
void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count,
bool vector_can_span_multiple_pages) override;
void Write(ColumnWriterState &state, Vector &vector, idx_t count) override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class ListColumnWriter : public ColumnWriter {
bool HasAnalyze() override;
void Analyze(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override;
void FinalizeAnalyze(ColumnWriterState &state) override;
void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override;
void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count,
bool vector_can_span_multiple_pages) override;

void BeginWrite(ColumnWriterState &state) override;
void Write(ColumnWriterState &state, Vector &vector, idx_t count) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ class PrimitiveColumnWriter : public ColumnWriter {

public:
unique_ptr<ColumnWriterState> InitializeWriteState(duckdb_parquet::RowGroup &row_group) override;
void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override;
void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count,
bool vector_can_span_multiple_pages) override;
void BeginWrite(ColumnWriterState &state) override;
void Write(ColumnWriterState &state, Vector &vector, idx_t count) override;
void FinalizeWrite(ColumnWriterState &state) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class StructColumnWriter : public ColumnWriter {
bool HasAnalyze() override;
void Analyze(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override;
void FinalizeAnalyze(ColumnWriterState &state) override;
void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override;
void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count,
bool vector_can_span_multiple_pages) override;

void BeginWrite(ColumnWriterState &state) override;
void Write(ColumnWriterState &state, Vector &vector, idx_t count) override;
Expand Down
2 changes: 1 addition & 1 deletion src/duckdb/extension/parquet/include/zstd_file_system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace duckdb {

class ZStdFileSystem : public CompressedFileSystem {
public:
unique_ptr<FileHandle> OpenCompressedFile(unique_ptr<FileHandle> handle, bool write) override;
unique_ptr<FileHandle> OpenCompressedFile(QueryContext context, unique_ptr<FileHandle> handle, bool write) override;

std::string GetName() const override {
return "ZStdFileSystem";
Expand Down
2 changes: 1 addition & 1 deletion src/duckdb/extension/parquet/parquet_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGro

for (auto &chunk : buffer.Chunks({column_ids})) {
for (idx_t i = 0; i < next; i++) {
col_writers[i].get().Prepare(*write_states[i], nullptr, chunk.data[i], chunk.size());
col_writers[i].get().Prepare(*write_states[i], nullptr, chunk.data[i], chunk.size(), true);
}
}

Expand Down
7 changes: 5 additions & 2 deletions src/duckdb/extension/parquet/writer/array_column_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ void ArrayColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *p
child_writer->Analyze(*state.child_state, &state_p, array_child, array_size * count);
}

void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) {
void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count,
bool vector_can_span_multiple_pages) {
auto &state = state_p.Cast<ListColumnWriterState>();

auto array_size = ArrayType::GetSize(vector.GetType());
Expand Down Expand Up @@ -66,7 +67,9 @@ void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *p
state.parent_index += vcount;

auto &array_child = ArrayVector::GetEntry(vector);
child_writer->Prepare(*state.child_state, &state_p, array_child, count * array_size);
// The elements of a single array should not span multiple Parquet pages
// So, we force the entire vector to fit on a single page by setting "vector_can_span_multiple_pages=false"
child_writer->Prepare(*state.child_state, &state_p, array_child, count * array_size, false);
}

void ArrayColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) {
Expand Down
7 changes: 5 additions & 2 deletions src/duckdb/extension/parquet/writer/list_column_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ static idx_t GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset,
return total_length;
}

void ListColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) {
void ListColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count,
bool vector_can_span_multiple_pages) {
auto &state = state_p.Cast<ListColumnWriterState>();

auto list_data = FlatVector::GetData<list_entry_t>(vector);
Expand Down Expand Up @@ -111,7 +112,9 @@ void ListColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *pa
auto &list_child = ListVector::GetEntry(vector);
Vector child_list(list_child);
auto child_length = GetConsecutiveChildList(vector, child_list, 0, count);
child_writer->Prepare(*state.child_state, &state_p, child_list, child_length);
// The elements of a single list should not span multiple Parquet pages
// So, we force the entire vector to fit on a single page by setting "vector_can_span_multiple_pages=false"
child_writer->Prepare(*state.child_state, &state_p, child_list, child_length, false);
}

void ListColumnWriter::BeginWrite(ColumnWriterState &state_p) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ unique_ptr<ColumnWriterPageState> PrimitiveColumnWriter::InitializePageState(Pri
void PrimitiveColumnWriter::FlushPageState(WriteStream &temp_writer, ColumnWriterPageState *state) {
}

void PrimitiveColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector,
idx_t count) {
void PrimitiveColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count,
bool vector_can_span_multiple_pages) {
auto &state = state_p.Cast<PrimitiveColumnWriterState>();
auto &col_chunk = state.row_group.columns[state.col_idx];

Expand Down Expand Up @@ -70,6 +70,10 @@ void PrimitiveColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterStat
if (validity.RowIsValid(vector_index)) {
page_info.estimated_page_size += GetRowSize(vector, vector_index, state);
if (page_info.estimated_page_size >= MAX_UNCOMPRESSED_PAGE_SIZE) {
if (!vector_can_span_multiple_pages && i != 0) {
// Vector is not allowed to span multiple pages, and we already started writing it
continue;
}
PageInformation new_info;
new_info.offset = page_info.offset + page_info.row_count;
state.page_info.push_back(new_info);
Expand Down
6 changes: 4 additions & 2 deletions src/duckdb/extension/parquet/writer/struct_column_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ void StructColumnWriter::FinalizeAnalyze(ColumnWriterState &state_p) {
}
}

void StructColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) {
void StructColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count,
bool vector_can_span_multiple_pages) {
auto &state = state_p.Cast<StructColumnWriterState>();

auto &validity = FlatVector::Validity(vector);
Expand All @@ -69,7 +70,8 @@ void StructColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *
HandleDefineLevels(state_p, parent, validity, count, PARQUET_DEFINE_VALID, MaxDefine() - 1);
auto &child_vectors = StructVector::GetEntries(vector);
for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) {
child_writers[child_idx]->Prepare(*state.child_states[child_idx], &state_p, *child_vectors[child_idx], count);
child_writers[child_idx]->Prepare(*state.child_states[child_idx], &state_p, *child_vectors[child_idx], count,
vector_can_span_multiple_pages);
}
}

Expand Down
13 changes: 7 additions & 6 deletions src/duckdb/extension/parquet/zstd_file_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct ZstdStreamWrapper : public StreamWrapper {
bool writing = false;

public:
void Initialize(CompressedFile &file, bool write) override;
void Initialize(QueryContext context, CompressedFile &file, bool write) override;
bool Read(StreamData &stream_data) override;
void Write(CompressedFile &file, StreamData &stream_data, data_ptr_t buffer, int64_t nr_bytes) override;

Expand All @@ -32,7 +32,7 @@ ZstdStreamWrapper::~ZstdStreamWrapper() {
}
}

void ZstdStreamWrapper::Initialize(CompressedFile &file, bool write) {
void ZstdStreamWrapper::Initialize(QueryContext context, CompressedFile &file, bool write) {
Close();
this->file = &file;
this->writing = write;
Expand Down Expand Up @@ -156,9 +156,9 @@ void ZstdStreamWrapper::Close() {

class ZStdFile : public CompressedFile {
public:
ZStdFile(unique_ptr<FileHandle> child_handle_p, const string &path, bool write)
ZStdFile(QueryContext context, unique_ptr<FileHandle> child_handle_p, const string &path, bool write)
: CompressedFile(zstd_fs, std::move(child_handle_p), path) {
Initialize(write);
Initialize(context, write);
}

FileCompressionType GetFileCompressionType() override {
Expand All @@ -168,9 +168,10 @@ class ZStdFile : public CompressedFile {
ZStdFileSystem zstd_fs;
};

unique_ptr<FileHandle> ZStdFileSystem::OpenCompressedFile(unique_ptr<FileHandle> handle, bool write) {
// Wraps a raw file handle in a zstd (de)compressing handle. Removes the stale
// pre-refactor return (without the QueryContext argument) that the diff left
// behind — it was unreachable and would not match the ZStdFile constructor.
unique_ptr<FileHandle> ZStdFileSystem::OpenCompressedFile(QueryContext context, unique_ptr<FileHandle> handle,
                                                          bool write) {
	// Copy the path before the handle is moved into the compressed wrapper.
	auto path = handle->path;
	return make_uniq<ZStdFile>(context, std::move(handle), path, write);
}

unique_ptr<StreamWrapper> ZStdFileSystem::CreateStream() {
Expand Down
Loading
Loading