Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/duckdb/extension/parquet/parquet_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "duckdb/common/serializer/write_stream.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/function/table_function.hpp"
#include "duckdb/main/extension_helper.hpp"
#include "duckdb/main/client_context.hpp"
#include "duckdb/main/connection.hpp"
#include "duckdb/parser/parsed_data/create_copy_function_info.hpp"
Expand Down Expand Up @@ -374,6 +375,12 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file

if (encryption_config) {
auto &config = DBConfig::GetConfig(context);

// To ensure we can write, we need to autoload httpfs
if (!config.encryption_util || !config.encryption_util->SupportsEncryption()) {
ExtensionHelper::TryAutoLoadExtension(context, "httpfs");
}

if (config.encryption_util && debug_use_openssl) {
// Use OpenSSL
encryption_util = config.encryption_util;
Expand Down
19 changes: 10 additions & 9 deletions src/duckdb/src/common/encryption_key_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ EncryptionKey::EncryptionKey(data_ptr_t encryption_key_p) {
D_ASSERT(memcmp(key, encryption_key_p, MainHeader::DEFAULT_ENCRYPTION_KEY_LENGTH) == 0);

// zero out the encryption key in memory
memset(encryption_key_p, 0, MainHeader::DEFAULT_ENCRYPTION_KEY_LENGTH);
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(encryption_key_p,
MainHeader::DEFAULT_ENCRYPTION_KEY_LENGTH);
LockEncryptionKey(key);
}

Expand All @@ -37,7 +38,7 @@ void EncryptionKey::LockEncryptionKey(data_ptr_t key, idx_t key_len) {
}

void EncryptionKey::UnlockEncryptionKey(data_ptr_t key, idx_t key_len) {
memset(key, 0, key_len);
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(key, key_len);
#if defined(_WIN32)
VirtualUnlock(key, key_len);
#else
Expand All @@ -64,15 +65,16 @@ EncryptionKeyManager &EncryptionKeyManager::Get(DatabaseInstance &db) {

string EncryptionKeyManager::GenerateRandomKeyID() {
uint8_t key_id[KEY_ID_BYTES];
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomDataStatic(key_id, KEY_ID_BYTES);
RandomEngine engine;
engine.RandomData(key_id, KEY_ID_BYTES);
string key_id_str(reinterpret_cast<const char *>(key_id), KEY_ID_BYTES);
return key_id_str;
}

void EncryptionKeyManager::AddKey(const string &key_name, data_ptr_t key) {
derived_keys.emplace(key_name, EncryptionKey(key));
// Zero-out the encryption key
std::memset(key, 0, DERIVED_KEY_LENGTH);
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(key, DERIVED_KEY_LENGTH);
}

bool EncryptionKeyManager::HasKey(const string &key_name) const {
Expand Down Expand Up @@ -107,7 +109,7 @@ string EncryptionKeyManager::Base64Decode(const string &key) {
auto output = duckdb::unique_ptr<unsigned char[]>(new unsigned char[result_size]);
Blob::FromBase64(key, output.get(), result_size);
string decoded_key(reinterpret_cast<const char *>(output.get()), result_size);
memset(output.get(), 0, result_size);
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(output.get(), result_size);
return decoded_key;
}

Expand All @@ -124,10 +126,9 @@ void EncryptionKeyManager::DeriveKey(string &user_key, data_ptr_t salt, data_ptr

KeyDerivationFunctionSHA256(reinterpret_cast<const_data_ptr_t>(decoded_key.data()), decoded_key.size(), salt,
derived_key);

// wipe the original and decoded key
std::fill(user_key.begin(), user_key.end(), 0);
std::fill(decoded_key.begin(), decoded_key.end(), 0);
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(data_ptr_cast(&user_key[0]), user_key.size());
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(data_ptr_cast(&decoded_key[0]),
decoded_key.size());
user_key.clear();
decoded_key.clear();
}
Expand Down
10 changes: 10 additions & 0 deletions src/duckdb/src/common/random_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,14 @@ void RandomEngine::SetSeed(uint64_t seed) {
random_state->pcg.seed(seed);
}

void RandomEngine::RandomData(duckdb::data_ptr_t data, duckdb::idx_t len) {
while (len) {
const auto random_integer = NextRandomInteger();
const auto next = duckdb::MinValue<duckdb::idx_t>(len, sizeof(random_integer));
memcpy(data, duckdb::const_data_ptr_cast(&random_integer), next);
data += next;
len -= next;
}
}

} // namespace duckdb
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ bool PhysicalTableScan::Equals(const PhysicalOperator &other_p) const {
return false;
}
auto &other = other_p.Cast<PhysicalTableScan>();
if (function.function != other.function.function) {
if (function != other.function) {
return false;
}
if (column_ids != other.column_ids) {
Expand Down
6 changes: 3 additions & 3 deletions src/duckdb/src/function/table/version/pragma_version.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef DUCKDB_PATCH_VERSION
#define DUCKDB_PATCH_VERSION "2-dev305"
#define DUCKDB_PATCH_VERSION "2-dev332"
#endif
#ifndef DUCKDB_MINOR_VERSION
#define DUCKDB_MINOR_VERSION 4
Expand All @@ -8,10 +8,10 @@
#define DUCKDB_MAJOR_VERSION 1
#endif
#ifndef DUCKDB_VERSION
#define DUCKDB_VERSION "v1.4.2-dev305"
#define DUCKDB_VERSION "v1.4.2-dev332"
#endif
#ifndef DUCKDB_SOURCE_ID
#define DUCKDB_SOURCE_ID "8090b8d52e"
#define DUCKDB_SOURCE_ID "0efe5ccb5b"
#endif
#include "duckdb/function/table/system_functions.hpp"
#include "duckdb/main/database.hpp"
Expand Down
24 changes: 24 additions & 0 deletions src/duckdb/src/function/table_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,30 @@ TableFunction::TableFunction(const vector<LogicalType> &arguments, table_functio
TableFunction::TableFunction() : TableFunction("", {}, nullptr, nullptr, nullptr, nullptr) {
}

bool TableFunction::operator==(const TableFunction &rhs) const {
return name == rhs.name && arguments == rhs.arguments && varargs == rhs.varargs && bind == rhs.bind &&
bind_replace == rhs.bind_replace && bind_operator == rhs.bind_operator && init_global == rhs.init_global &&
init_local == rhs.init_local && function == rhs.function && in_out_function == rhs.in_out_function &&
in_out_function_final == rhs.in_out_function_final && statistics == rhs.statistics &&
dependency == rhs.dependency && cardinality == rhs.cardinality &&
pushdown_complex_filter == rhs.pushdown_complex_filter && pushdown_expression == rhs.pushdown_expression &&
to_string == rhs.to_string && dynamic_to_string == rhs.dynamic_to_string &&
table_scan_progress == rhs.table_scan_progress && get_partition_data == rhs.get_partition_data &&
get_bind_info == rhs.get_bind_info && type_pushdown == rhs.type_pushdown &&
get_multi_file_reader == rhs.get_multi_file_reader && supports_pushdown_type == rhs.supports_pushdown_type &&
get_partition_info == rhs.get_partition_info && get_partition_stats == rhs.get_partition_stats &&
get_virtual_columns == rhs.get_virtual_columns && get_row_id_columns == rhs.get_row_id_columns &&
serialize == rhs.serialize && deserialize == rhs.deserialize &&
verify_serialization == rhs.verify_serialization && projection_pushdown == rhs.projection_pushdown &&
filter_pushdown == rhs.filter_pushdown && filter_prune == rhs.filter_prune &&
sampling_pushdown == rhs.sampling_pushdown && late_materialization == rhs.late_materialization &&
global_initialization == rhs.global_initialization;
}

bool TableFunction::operator!=(const TableFunction &rhs) const {
return !(*this == rhs);
}

bool TableFunction::Equal(const TableFunction &rhs) const {
// number of types
if (this->arguments.size() != rhs.arguments.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class EncryptionKeyManager : public ObjectCacheEntry {
static void KeyDerivationFunctionSHA256(data_ptr_t user_key, idx_t user_key_size, data_ptr_t salt,
data_ptr_t derived_key);
static string Base64Decode(const string &key);

//! Generate a (non-cryptographically secure) random key ID
static string GenerateRandomKeyID();

public:
Expand Down
5 changes: 5 additions & 0 deletions src/duckdb/src/include/duckdb/common/encryption_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class EncryptionUtil {

virtual ~EncryptionUtil() {
}

//! Whether the EncryptionUtil supports encryption (some may only support decryption)
DUCKDB_API virtual bool SupportsEncryption() {
return true;
}
};

} // namespace duckdb
1 change: 1 addition & 0 deletions src/duckdb/src/include/duckdb/common/http_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ struct PostRequestInfo : public BaseRequest {
class HTTPClient {
public:
virtual ~HTTPClient() = default;
virtual void Initialize(HTTPParams &http_params) = 0;

virtual unique_ptr<HTTPResponse> Get(GetRequestInfo &info) = 0;
virtual unique_ptr<HTTPResponse> Put(PutRequestInfo &info) = 0;
Expand Down
2 changes: 2 additions & 0 deletions src/duckdb/src/include/duckdb/common/random_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class RandomEngine {

void SetSeed(uint64_t seed);

void RandomData(duckdb::data_ptr_t data, duckdb::idx_t len);

static RandomEngine &Get(ClientContext &context);

mutex lock;
Expand Down
2 changes: 2 additions & 0 deletions src/duckdb/src/include/duckdb/function/table_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,8 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou
TableFunctionInitialization global_initialization = TableFunctionInitialization::INITIALIZE_ON_EXECUTE;

DUCKDB_API bool Equal(const TableFunction &rhs) const;
DUCKDB_API bool operator==(const TableFunction &rhs) const;
DUCKDB_API bool operator!=(const TableFunction &rhs) const;
};

} // namespace duckdb
2 changes: 1 addition & 1 deletion src/duckdb/src/include/duckdb/main/database.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class DatabaseInstance : public enable_shared_from_this<DatabaseInstance> {

DUCKDB_API SettingLookupResult TryGetCurrentSetting(const string &key, Value &result) const;

DUCKDB_API shared_ptr<EncryptionUtil> GetEncryptionUtil() const;
DUCKDB_API shared_ptr<EncryptionUtil> GetEncryptionUtil();

shared_ptr<AttachedDatabase> CreateAttachedDatabase(ClientContext &context, AttachInfo &info,
AttachOptions &options);
Expand Down
7 changes: 0 additions & 7 deletions src/duckdb/src/include/duckdb/storage/storage_options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,4 @@ struct StorageOptions {
void Initialize(const unordered_map<string, Value> &options);
};

inline void ClearUserKey(shared_ptr<string> const &encryption_key) {
if (encryption_key && !encryption_key->empty()) {
memset(&(*encryption_key)[0], 0, encryption_key->size());
encryption_key->clear();
}
}

} // namespace duckdb
2 changes: 1 addition & 1 deletion src/duckdb/src/logging/log_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ void LogManager::SetEnableStructuredLoggers(vector<string> &enabled_logger_types
throw InvalidInputException("Unknown log type: '%s'", enabled_logger_type);
}

new_config.enabled_log_types.insert(enabled_logger_type);
new_config.enabled_log_types.insert(lookup->name);

min_log_level = MinValue(min_log_level, lookup->level);
}
Expand Down
10 changes: 8 additions & 2 deletions src/duckdb/src/main/database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,12 +507,18 @@ SettingLookupResult DatabaseInstance::TryGetCurrentSetting(const string &key, Va
return db_config.TryGetCurrentSetting(key, result);
}

shared_ptr<EncryptionUtil> DatabaseInstance::GetEncryptionUtil() const {
shared_ptr<EncryptionUtil> DatabaseInstance::GetEncryptionUtil() {
if (!config.encryption_util || !config.encryption_util->SupportsEncryption()) {
ExtensionHelper::TryAutoLoadExtension(*this, "httpfs");
}

if (config.encryption_util) {
return config.encryption_util;
}

return make_shared_ptr<duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLSFactory>();
auto result = make_shared_ptr<duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLSFactory>();

return std::move(result);
}

ValidChecker &DatabaseInstance::GetValidChecker() {
Expand Down
5 changes: 4 additions & 1 deletion src/duckdb/src/main/http/http_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,12 @@ BaseRequest::BaseRequest(RequestType type, const string &url, const HTTPHeaders
class HTTPLibClient : public HTTPClient {
public:
HTTPLibClient(HTTPParams &http_params, const string &proto_host_port) {
client = make_uniq<duckdb_httplib::Client>(proto_host_port);
Initialize(http_params);
}
void Initialize(HTTPParams &http_params) override {
auto sec = static_cast<time_t>(http_params.timeout);
auto usec = static_cast<time_t>(http_params.timeout_usec);
client = make_uniq<duckdb_httplib::Client>(proto_host_port);
client->set_follow_location(http_params.follow_location);
client->set_keep_alive(http_params.keep_alive);
client->set_write_timeout(sec, usec);
Expand Down
36 changes: 33 additions & 3 deletions src/duckdb/src/storage/single_file_block_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ void DeserializeEncryptionData(ReadStream &stream, data_t *dest, idx_t size) {

void GenerateDBIdentifier(uint8_t *db_identifier) {
memset(db_identifier, 0, MainHeader::DB_IDENTIFIER_LEN);
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomDataStatic(db_identifier,
MainHeader::DB_IDENTIFIER_LEN);
RandomEngine engine;
engine.RandomData(db_identifier, MainHeader::DB_IDENTIFIER_LEN);
}

void EncryptCanary(MainHeader &main_header, const shared_ptr<EncryptionState> &encryption_state,
Expand Down Expand Up @@ -362,6 +362,15 @@ void SingleFileBlockManager::CheckAndAddEncryptionKey(MainHeader &main_header) {
void SingleFileBlockManager::CreateNewDatabase(QueryContext context) {
auto flags = GetFileFlags(true);

auto encryption_enabled = options.encryption_options.encryption_enabled;
if (encryption_enabled) {
if (!db.GetDatabase().GetEncryptionUtil()->SupportsEncryption() && !options.read_only) {
throw InvalidConfigurationException(
"The database was opened with encryption enabled, but DuckDB currently has a read-only crypto module "
"loaded. Please re-open using READONLY, or ensure httpfs is loaded using `LOAD httpfs`.");
}
}

// open the RDBMS handle
auto &fs = FileSystem::Get(db);
handle = fs.OpenFile(path, flags);
Expand All @@ -376,7 +385,6 @@ void SingleFileBlockManager::CreateNewDatabase(QueryContext context) {
// Derive the encryption key and add it to the cache.
// Not used for plain databases.
data_t derived_key[MainHeader::DEFAULT_ENCRYPTION_KEY_LENGTH];
auto encryption_enabled = options.encryption_options.encryption_enabled;

// We need the unique database identifier, if the storage version is new enough.
// If encryption is enabled, we also use it as the salt.
Expand Down Expand Up @@ -487,6 +495,15 @@ void SingleFileBlockManager::LoadExistingDatabase(QueryContext context) {
if (main_header.IsEncrypted()) {
if (options.encryption_options.encryption_enabled) {
//! Encryption is set

//! Check if our encryption module can write, if not, we should throw here
if (!db.GetDatabase().GetEncryptionUtil()->SupportsEncryption() && !options.read_only) {
throw InvalidConfigurationException(
"The database is encrypted, but DuckDB currently has a read-only crypto module loaded. Either "
"re-open the database using `ATTACH '..' (READONLY)`, or ensure httpfs is loaded using `LOAD "
"httpfs`.");
}

//! Check if the given key upon attach is correct
// Derive the encryption key and add it to cache
CheckAndAddEncryptionKey(main_header);
Expand All @@ -506,6 +523,19 @@ void SingleFileBlockManager::LoadExistingDatabase(QueryContext context) {
path, EncryptionTypes::CipherToString(config_cipher),
EncryptionTypes::CipherToString(stored_cipher));
}

// This avoids the cipher from being downgrades by an attacker FIXME: we likely want to have a propervalidation
// of the cipher used instead of this trick to avoid downgrades
if (stored_cipher != EncryptionTypes::GCM) {
if (config_cipher == EncryptionTypes::INVALID) {
throw CatalogException(
"Cannot open encrypted database \"%s\" without explicitly specifying the "
"encryption cipher for security reasons. Please make sure you understand the security implications "
"and re-attach the database specifying the desired cipher.",
path);
}
}

// this is ugly, but the storage manager does not know the cipher type before
db.GetStorageManager().SetCipher(stored_cipher);
}
Expand Down
8 changes: 8 additions & 0 deletions src/duckdb/src/storage/storage_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ bool StorageManager::InMemory() const {
return path == IN_MEMORY_PATH;
}

inline void ClearUserKey(shared_ptr<string> const &encryption_key) {
if (encryption_key && !encryption_key->empty()) {
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(data_ptr_cast(&(*encryption_key)[0]),
encryption_key->size());
encryption_key->clear();
}
}

void StorageManager::Initialize(QueryContext context) {
bool in_memory = InMemory();
if (in_memory && read_only) {
Expand Down
5 changes: 5 additions & 0 deletions src/duckdb/third_party/mbedtls/include/mbedtls_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class AESStateMBEDTLS : public duckdb::EncryptionState {
DUCKDB_API void GenerateRandomData(duckdb::data_ptr_t data, duckdb::idx_t len) override;
DUCKDB_API void FinalizeGCM(duckdb::data_ptr_t tag, duckdb::idx_t tag_len);
DUCKDB_API const mbedtls_cipher_info_t *GetCipher(size_t key_len);
DUCKDB_API static void SecureClearData(duckdb::data_ptr_t data, duckdb::idx_t len);

private:
DUCKDB_API void InitializeInternal(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len);
Expand All @@ -98,6 +99,10 @@ class AESStateMBEDTLS : public duckdb::EncryptionState {
}

~AESStateMBEDTLSFactory() override {} //

DUCKDB_API bool SupportsEncryption() override {
return false;
}
};
};

Expand Down
Loading