diff --git a/MODULE.bazel b/MODULE.bazel index fbe9b41fc..02404b645 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -100,3 +100,8 @@ http_jar( sha256 = "eae2dfa119a64327444672aff63e9ec35a20180dc5b8090b7a6ab85125df4d76", urls = ["https://www.antlr.org/download/antlr-" + ANTLR4_VERSION + "-complete.jar"], ) + +bazel_dep( + name = "yaml-cpp", + version = "0.9.0", +) diff --git a/env/BUILD b/env/BUILD new file mode 100644 index 000000000..d1f11b17c --- /dev/null +++ b/env/BUILD @@ -0,0 +1,293 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + deps = [ + "//common:constant", + "//internal:status_macros", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "env", + srcs = [ + "env.cc", + ], + hdrs = [ + "env.h", + ], + deps = [ + ":config", + "//checker:type_checker_builder", + "//common:constant", + "//common:decl", + "//common:type", + "//common:type_kind", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//env/internal:ext_registry", + "//internal:status_macros", + "//parser:macro", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_runtime", + srcs = ["env_runtime.cc"], + hdrs = ["env_runtime.h"], + deps = [ + ":config", + "//env/internal:runtime_ext_registry", + "//internal:status_macros", + "//runtime", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "//runtime:standard_functions", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_std_extensions", + srcs = ["env_std_extensions.cc"], + hdrs = ["env_std_extensions.h"], + deps = [ + ":env", + "//checker:optional", + "//compiler:optional", + "//extensions:bindings_ext", + "//extensions:comprehensions_v2", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:proto_ext", + "//extensions:regex_ext", + "//extensions:sets_functions", + "//extensions:strings", + ], +) + +cc_library( + name = "env_yaml", + srcs = ["env_yaml.cc"], + hdrs = ["env_yaml.h"], + copts = [ + "-fexceptions", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + "//common:constant", + "//internal:status_macros", + "//internal:strings", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@yaml-cpp", + ], +) + +cc_library( + name = "runtime_std_extensions", + srcs = ["runtime_std_extensions.cc"], + hdrs = ["runtime_std_extensions.h"], + deps = [ + ":env_runtime", + "//checker:optional", + "//env/internal:runtime_ext_registry", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext", + "//extensions:math_ext_decls", + "//extensions:sets_functions", + "//extensions:strings", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "config_test", + srcs = ["config_test.cc"], + deps = [ + ":config", + "//common:constant", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "env_test", + srcs = ["env_test.cc"], + deps = [ + ":config", + ":env", + "//checker:type_check_issue", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:ast_proto", + "//common:constant", + "//common:decl", + "//common:expr", + "//common:type", + "//common:value", + "//compiler", + "//internal:proto_matchers", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:reference_resolver", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_runtime_test", + srcs = ["env_runtime_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":env_yaml", + ":runtime_std_extensions", + "//checker:validation_result", + "//common:ast", + "//common:source", + "//common:value", + "//compiler", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_std_extensions_test", + srcs = ["env_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_std_extensions", + "//compiler", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "env_yaml_test", + srcs = ["env_yaml_test.cc"], + deps = [ + ":config", + ":env_yaml", + "//common:constant", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "runtime_std_extensions_test", + srcs = ["runtime_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":runtime_std_extensions", + "//checker:optional", + "//checker:validation_result", + "//common:ast", + "//common:value", + "//compiler", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:strings", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/config.cc b/env/config.cc new file mode 100644 index 000000000..ccb4de34c --- /dev/null +++ b/env/config.cc @@ -0,0 +1,190 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/config.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "internal/status_macros.h" + +namespace cel { + +namespace { + +const char* ConstantKindToTypeName(const ConstantKind& kind) { + return std::visit(absl::Overload{ + [](const std::monostate& arg) { return "dyn"; }, + [](const std::nullptr_t& arg) { return "null"; }, + [](bool arg) { return "bool"; }, + [](int64_t arg) { return "int"; }, + [](uint64_t arg) { return "uint"; }, + [](double arg) { return "double"; }, + [](const BytesConstant& arg) { return "bytes"; }, + [](const StringConstant& arg) { return "string"; }, + [](absl::Duration arg) { return "duration"; }, + [](absl::Time arg) { return "timestamp"; }, + }, + kind); +} +} // namespace + +absl::Status Config::AddExtensionConfig(std::string name, int version) { + for (const ExtensionConfig& extension_config : extension_configs_) { + if (extension_config.name == name) { + if (extension_config.version == version) { + return absl::OkStatus(); + } + return absl::AlreadyExistsError(absl::StrCat( + "Extension '", name, "' version ", extension_config.version, + " is already included. Cannot also include version ", version)); + } + } + extension_configs_.push_back( + ExtensionConfig{.name = std::move(name), .version = version}); + return absl::OkStatus(); +} + +absl::Status Config::SetStandardLibraryConfig( + const Config::StandardLibraryConfig& standard_library_config) { + if (!standard_library_config.included_macros.empty() && + !standard_library_config.excluded_macros.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded macros."); + } + + if (!standard_library_config.included_functions.empty() && + !standard_library_config.excluded_functions.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded functions."); + } + + absl::flat_hash_set included_function_names; + for (const auto& function : standard_library_config.included_functions) { + if (function.second.empty()) { + included_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.included_functions) { + if (included_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot include function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + absl::flat_hash_set excluded_function_names; + for (const auto& function : standard_library_config.excluded_functions) { + if (function.second.empty()) { + excluded_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.excluded_functions) { + if (excluded_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot exclude function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + standard_library_config_ = standard_library_config; + return absl::OkStatus(); +} + +absl::Status Config::AddVariableConfig(const VariableConfig& variable_config) { + for (const VariableConfig& existing_variable_config : variable_configs_) { + if (existing_variable_config.name == variable_config.name) { + return absl::AlreadyExistsError(absl::StrCat( + "Variable '", variable_config.name, "' is already included.")); + } + } + if (variable_config.value.has_value()) { + absl::string_view constant_type_name = + ConstantKindToTypeName(variable_config.value.kind()); + if (constant_type_name != variable_config.type_info.name) { + return absl::InvalidArgumentError( + absl::StrCat("Variable '", variable_config.name, "' has type ", + variable_config.type_info.name, + " but is assigned a constant value of type ", + constant_type_name, ".")); + } + } + variable_configs_.push_back(variable_config); + return absl::OkStatus(); +} + +absl::Status Config::ValidateFunctionConfig( + const FunctionConfig& function_config) { + for (const auto& overload : function_config.overload_configs) { + if (overload.is_member_function && overload.parameters.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Function '", function_config.name, "' overload '", + overload.overload_id, + "' is marked as a member function but has no parameters. Member " + "functions must have at least one parameter (target).")); + } + } + return absl::OkStatus(); +} + +absl::Status Config::AddFunctionConfig(const FunctionConfig& function_config) { + CEL_RETURN_IF_ERROR(ValidateFunctionConfig(function_config)); + function_configs_.push_back(function_config); + return absl::OkStatus(); +} + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config) { + os << "StandardLibraryConfig("; + if (!config.included_macros.empty()) { + os << "\n included_macros=" << absl::StrJoin(config.included_macros, ", "); + } + if (!config.excluded_macros.empty()) { + os << "\n excluded_macros=" << absl::StrJoin(config.excluded_macros, ", "); + } + if (!config.included_functions.empty()) { + os << "\n included_functions=" + << absl::StrJoin(config.included_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + if (!config.excluded_functions.empty()) { + os << "\n excluded_functions=" + << absl::StrJoin(config.excluded_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + os << "\n)"; + return os; +} + +} // namespace cel diff --git a/env/config.h b/env/config.h new file mode 100644 index 000000000..10b23d030 --- /dev/null +++ b/env/config.h @@ -0,0 +1,160 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ +#define THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "common/constant.h" + +namespace cel { + +class Config { + public: + void SetName(std::string name) { name_ = std::move(name); } + std::string GetName() const { return name_; } + + struct ContainerConfig { + std::string name; + // TODO(uncreated-issue/87): add support for aliases and abbreviations. + + bool IsEmpty() const { return name.empty(); } + }; + + void SetContainerConfig(ContainerConfig container_config) { + container_config_ = std::move(container_config); + } + + const ContainerConfig& GetContainerConfig() const { + return container_config_; + } + + struct ExtensionConfig { + static constexpr int kLatest = std::numeric_limits::max(); + + std::string name; + int version = kLatest; + }; + + absl::Status AddExtensionConfig(std::string name, + int version = ExtensionConfig::kLatest); + + const std::vector& GetExtensionConfigs() const { + return extension_configs_; + } + + struct StandardLibraryConfig { + // Exclude the entire standard library. + bool disable = false; + + // Exclude all standard library macros. + bool disable_macros = false; + + // Either included or excluded macros can be set, not both. If neither are + // set, all standard library macros are included. + absl::flat_hash_set included_macros; + absl::flat_hash_set excluded_macros; + + // Sets of pairs of function name and overload id to include or exclude. + // Either included or excluded functions can be set, not both. If neither + // are set, all standard library functions are included. + // If an overload is specified, only that overload is included or excluded. + // If no overload is specified (empty second element of pair), all overloads + // are included or excluded. + absl::flat_hash_set> included_functions; + absl::flat_hash_set> excluded_functions; + + bool IsEmpty() const { + return !disable && !disable_macros && included_macros.empty() && + excluded_macros.empty() && included_functions.empty() && + excluded_functions.empty(); + } + }; + + absl::Status SetStandardLibraryConfig( + const StandardLibraryConfig& standard_library_config); + + const StandardLibraryConfig& GetStandardLibraryConfig() const { + return standard_library_config_; + } + + struct TypeInfo { + std::string name; + std::vector params; + bool is_type_param = false; + }; + + struct VariableConfig { + std::string name; + std::string description; + TypeInfo type_info; + Constant value; + }; + + // Adds a variable config to the environment. The variable name and type + // are used by the CEL type checker to validate expressions. The variable + // value is used as an input value at runtime. + // + // Returns an error if a variable with the same name already exists, or if the + // type of the constant value does not match the specified type. + absl::Status AddVariableConfig(const VariableConfig& variable_config); + + const std::vector& GetVariableConfigs() const { + return variable_configs_; + } + + struct FunctionOverloadConfig { + std::string overload_id; + std::vector examples; + bool is_member_function = false; + std::vector parameters; + TypeInfo return_type; + }; + + struct FunctionConfig { + std::string name; + std::string description; + std::vector overload_configs; + }; + + absl::Status AddFunctionConfig(const FunctionConfig& function_config); + + const std::vector& GetFunctionConfigs() const { + return function_configs_; + } + + private: + std::string name_; + ContainerConfig container_config_; + std::vector extension_configs_; + StandardLibraryConfig standard_library_config_; + std::vector variable_configs_; + std::vector function_configs_; + + absl::Status ValidateFunctionConfig(const FunctionConfig& function_config); +}; + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ diff --git a/env/config_test.cc b/env/config_test.cc new file mode 100644 index 000000000..df0d6f875 --- /dev/null +++ b/env/config_test.cc @@ -0,0 +1,222 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/config.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/constant.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::UnorderedElementsAre; + +TEST(EnvConfigTest, ExtensionConfigs) { + Config config; + ASSERT_THAT( + config.AddExtensionConfig("math", Config::ExtensionConfig::kLatest), + IsOk()); + ASSERT_THAT(config.AddExtensionConfig("optional", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("strings"), IsOk()); + + EXPECT_THAT(config.GetExtensionConfigs(), + UnorderedElementsAre( + AllOf(Field(&Config::ExtensionConfig::name, "math"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)), + AllOf(Field(&Config::ExtensionConfig::name, "optional"), + Field(&Config::ExtensionConfig::version, 2)), + AllOf(Field(&Config::ExtensionConfig::name, "strings"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)))); +} + +TEST(EnvConfigTest, ExtensionConfigConflict) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 3), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +struct StandardLibraryConfigTestCase { + Config::StandardLibraryConfig standard_library_config; + std::string expected_error; // Empty if no error is expected. +}; + +class StandardLibraryConfigTest + : public testing::TestWithParam {}; + +TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { + const StandardLibraryConfigTestCase& param = GetParam(); + + Config config; + absl::Status status = + config.SetStandardLibraryConfig(param.standard_library_config); + if (param.expected_error.empty()) { + EXPECT_THAT(status, IsOk()); + } else { + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardLibraryConfigTest, StandardLibraryConfigTest, + ::testing::Values( + StandardLibraryConfigTestCase{ + .standard_library_config = {}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_macros = {"all", "exists"}, + .excluded_macros = {"map", "filter"}, + }, + .expected_error = "Cannot set both included and excluded macros.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add_list'", + })); + +TEST(VariableConfigTest, VariableConfig) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = + { + .name = "mytype", + .params = {{.name = "int"}, {.name = "A", .is_type_param = true}}, + }, + }; + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + + ASSERT_EQ(config.GetVariableConfigs().size(), 1); + const auto& added_config = config.GetVariableConfigs()[0]; + EXPECT_EQ(added_config.type_info.name, "mytype"); + ASSERT_THAT(added_config.type_info.params.size(), 2); + EXPECT_EQ(added_config.type_info.params[0].name, "int"); + EXPECT_FALSE(added_config.type_info.params[0].is_type_param); + EXPECT_EQ(added_config.type_info.params[1].name, "A"); + EXPECT_TRUE(added_config.type_info.params[1].is_type_param); +} + +TEST(VariableConfigTest, VariableConfigConflict) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), IsOk()); + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(VariableConfigTest, VariableConfigValueTypeMismatch) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + .value = Constant(StringConstant("hello")), + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Variable 'test' has type int but is assigned " + "a constant value of type string."))); +} + +TEST(FunctionConfigTest, FunctionConfig) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.description = "Ultimate test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_with_pill", + .examples = {"oracle.isTheOne('Neo', RED)"}, + .is_member_function = true, + .parameters = {{.name = "string"}, {.name = "Choice"}}, + .return_type = {.name = "bool"}, + }); + ASSERT_THAT(config.AddFunctionConfig(function_config), IsOk()); + ASSERT_EQ(config.GetFunctionConfigs().size(), 1); + const auto& added_config = config.GetFunctionConfigs()[0]; + EXPECT_EQ(added_config.name, "test"); + EXPECT_EQ(added_config.description, "Ultimate test"); + EXPECT_EQ(added_config.overload_configs.size(), 1); + + const auto& overload_config = added_config.overload_configs[0]; + EXPECT_EQ(overload_config.overload_id, "test_with_pill"); + EXPECT_THAT(overload_config.examples, + ElementsAre("oracle.isTheOne('Neo', RED)")); + EXPECT_TRUE(overload_config.is_member_function); + EXPECT_THAT( + overload_config.parameters, + ElementsAre(AllOf(Field(&Config::TypeInfo::name, "string"), + Field(&Config::TypeInfo::is_type_param, false)), + AllOf(Field(&Config::TypeInfo::name, "Choice"), + Field(&Config::TypeInfo::is_type_param, false)))); + EXPECT_THAT(overload_config.return_type, + Field(&Config::TypeInfo::name, "bool")); +} + +TEST(FunctionConfigTest, FunctionConfigInvalidMember) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_member_no_params", + .is_member_function = true, + .parameters = {}, + }); + EXPECT_THAT(config.AddFunctionConfig(function_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("is marked as a member function but has no " + "parameters"))); +} + +} // namespace +} // namespace cel diff --git a/env/env.cc b/env/env.cc new file mode 100644 index 000000000..2c2555f14 --- /dev/null +++ b/env/env.cc @@ -0,0 +1,318 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +bool ShouldIncludeMacro(const Config::StandardLibraryConfig& config, + absl::string_view macro) { + if (config.disable_macros) { + return false; + } + if (config.excluded_macros.contains(macro)) { + return false; + } + if (!config.included_macros.empty() && + !config.included_macros.contains(macro)) { + return false; + } + return true; +} + +bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, + absl::string_view function, + absl::string_view overload_id) { + if (config.excluded_functions.contains( + std::make_pair(std::string(function), std::string(overload_id))) || + config.excluded_functions.contains( + std::make_pair(std::string(function), ""))) { + return false; + } + if (!config.included_functions.empty() && + !config.included_functions.contains( + std::make_pair(std::string(function), "")) && + !config.included_functions.contains( + std::make_pair(std::string(function), std::string(overload_id)))) { + return false; + } + return true; +} + +absl::StatusOr MakeStdlibSubset( + const Config::StandardLibraryConfig& standard_library_config) { + CompilerLibrarySubset subset; + subset.library_id = "stdlib"; + // Capturing by reference is safe. The returned CompilerLibrarySubset's + // callbacks are only used during CompilerBuilder::Build() to configure + // contributed functions and macros. They are not retained by the constructed + // Compiler instance. The referenced config outlives the Build() call. + subset.should_include_macro = [&standard_library_config](const Macro& macro) { + return ShouldIncludeMacro(standard_library_config, macro.function()); + }; + subset.should_include_overload = [&standard_library_config]( + absl::string_view function, + absl::string_view overload_id) { + return ShouldIncludeFunction(standard_library_config, function, + overload_id); + }; + return subset; +} + +std::optional TypeNameToTypeKind(absl::string_view type_name) { + // Excluded types: + // kUnknown + // kError + // kTypeParam + // kFunction + // kEnum + + static const absl::NoDestructor< + absl::flat_hash_map> + kTypeNameToTypeKind({ + {"null", TypeKind::kNull}, + {"bool", TypeKind::kBool}, + {"int", TypeKind::kInt}, + {"uint", TypeKind::kUint}, + {"double", TypeKind::kDouble}, + {"string", TypeKind::kString}, + {"bytes", TypeKind::kBytes}, + {"timestamp", TypeKind::kTimestamp}, + {TimestampType::kName, TypeKind::kTimestamp}, + {"duration", TypeKind::kDuration}, + {DurationType::kName, TypeKind::kDuration}, + {"list", TypeKind::kList}, + {"map", TypeKind::kMap}, + {"", TypeKind::kDyn}, + {"any", TypeKind::kAny}, + {"dyn", TypeKind::kDyn}, + {BoolWrapperType::kName, TypeKind::kBoolWrapper}, + {IntWrapperType::kName, TypeKind::kIntWrapper}, + {UintWrapperType::kName, TypeKind::kUintWrapper}, + {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, + {StringWrapperType::kName, TypeKind::kStringWrapper}, + {BytesWrapperType::kName, TypeKind::kBytesWrapper}, + {"type", TypeKind::kType}, + }); + if (auto it = kTypeNameToTypeKind->find(type_name); + it != kTypeNameToTypeKind->end()) { + return it->second; + } + + return std::nullopt; +} + +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* descriptor_pool) { + if (type_info.is_type_param) { + return TypeParamType(type_info.name); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty() && descriptor_pool != nullptr) { + const google::protobuf::Descriptor* type = + descriptor_pool->FindMessageTypeByName(type_info.name); + if (type != nullptr) { + return MessageType(type); + } + } + // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types + std::vector parameter_types; + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(param, arena, descriptor_pool)); + parameter_types.push_back(parameter_type); + } + + return OpaqueType(arena, type_info.name, parameter_types); + } + + switch (*type_kind) { + case TypeKind::kNull: + return NullType(); + case TypeKind::kBool: + return BoolType(); + case TypeKind::kInt: + return IntType(); + case TypeKind::kUint: + return UintType(); + case TypeKind::kDouble: + return DoubleType(); + case TypeKind::kString: + return StringType(); + case TypeKind::kBytes: + return BytesType(); + case TypeKind::kDuration: + return DurationType(); + case TypeKind::kTimestamp: + return TimestampType(); + case TypeKind::kList: { + Type element_type; + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN( + element_type, + TypeInfoToType(type_info.params[0], arena, descriptor_pool)); + } else { + element_type = DynType(); + } + return ListType(arena, element_type); + } + case TypeKind::kMap: { + Type key_type = DynType(); + Type value_type = DynType(); + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], + arena, descriptor_pool)); + } + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN( + value_type, + TypeInfoToType(type_info.params[1], arena, descriptor_pool)); + } + return MapType(arena, key_type, value_type); + } + case TypeKind::kDyn: + return DynType(); + case TypeKind::kAny: + return AnyType(); + case TypeKind::kBoolWrapper: + return BoolWrapperType(); + case TypeKind::kIntWrapper: + return IntWrapperType(); + case TypeKind::kUintWrapper: + return UintWrapperType(); + case TypeKind::kDoubleWrapper: + return DoubleWrapperType(); + case TypeKind::kStringWrapper: + return StringWrapperType(); + case TypeKind::kBytesWrapper: + return BytesWrapperType(); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeType(arena, DynType()); + } + CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], arena, + descriptor_pool)); + return TypeType(arena, type); + } + default: + return DynType(); + } +} + +absl::StatusOr FunctionConfigToFunctionDecl( + const Config::FunctionConfig& function_config, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* descriptor_pool) { + FunctionDecl function_decl; + function_decl.set_name(function_config.name); + for (const Config::FunctionOverloadConfig& overload_config : + function_config.overload_configs) { + OverloadDecl overload_decl; + overload_decl.set_id(overload_config.overload_id); + overload_decl.set_member(overload_config.is_member_function); + for (const Config::TypeInfo& parameter : overload_config.parameters) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(parameter, arena, descriptor_pool)); + overload_decl.mutable_args().push_back(parameter_type); + } + CEL_ASSIGN_OR_RETURN( + Type return_type, + TypeInfoToType(overload_config.return_type, arena, descriptor_pool)); + overload_decl.set_result(return_type); + CEL_RETURN_IF_ERROR(function_decl.AddOverload(overload_decl)); + } + return function_decl; +} + +} // namespace + +absl::StatusOr> Env::NewCompiler() { + CEL_ASSIGN_OR_RETURN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(descriptor_pool_, compiler_options_)); + cel::TypeCheckerBuilder& checker_builder = + compiler_builder->GetCheckerBuilder(); + + checker_builder.set_container(config_.GetContainerConfig().name); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrary(StandardCompilerLibrary())); + CEL_ASSIGN_OR_RETURN(CompilerLibrarySubset standard_library_subset, + MakeStdlibSubset(config_.GetStandardLibraryConfig())); + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrarySubset(std::move(standard_library_subset))); + } + for (const Config::ExtensionConfig& extension_config : + config_.GetExtensionConfigs()) { + CEL_ASSIGN_OR_RETURN(CompilerLibrary library, + extension_registry_.GetCompilerLibrary( + extension_config.name, extension_config.version)); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(std::move(library))); + } + + google::protobuf::Arena* arena = checker_builder.arena(); + for (const Config::VariableConfig& variable_config : + config_.GetVariableConfigs()) { + VariableDecl variable_decl; + variable_decl.set_name(variable_config.name); + CEL_ASSIGN_OR_RETURN(Type type, + TypeInfoToType(variable_config.type_info, arena, + descriptor_pool_.get())); + variable_decl.set_type(type); + if (variable_config.value.has_value()) { + variable_decl.set_value(variable_config.value); + } + CEL_RETURN_IF_ERROR(checker_builder.AddVariable(variable_decl)); + } + + for (const Config::FunctionConfig& function_config : + config_.GetFunctionConfigs()) { + CEL_ASSIGN_OR_RETURN(FunctionDecl function_decl, + FunctionConfigToFunctionDecl(function_config, arena, + descriptor_pool_.get())); + CEL_RETURN_IF_ERROR(checker_builder.AddFunction(function_decl)); + } + + return compiler_builder->Build(); +} + +} // namespace cel diff --git a/env/env.h b/env/env.h new file mode 100644 index 000000000..f46e5947c --- /dev/null +++ b/env/env.h @@ -0,0 +1,71 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_H_ + +#include + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/internal/ext_registry.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Env class establishes the environment for compiling CEL expressions. +// +// It is used to configure compiler options, extension functions, and other +// customizable CEL features. +class Env { + public: + // Registers a `CompilerLibrary` with the environment. Note that the library + // does not automatically get added to a `Compiler`. `NewCompiler` relies + // on `Config` to determine which libraries to load. + void RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) { + extension_registry_.RegisterCompilerLibrary(name, alias, version, + std::move(library_factory)); + } + + void SetDescriptorPool( + std::shared_ptr descriptor_pool) { + descriptor_pool_ = std::move(descriptor_pool); + } + + const google::protobuf::DescriptorPool* GetDescriptorPool() const { + return descriptor_pool_.get(); + } + + void SetConfig(const Config& config) { config_ = config; } + + absl::StatusOr> NewCompiler(); + + private: + cel::env_internal::ExtensionRegistry extension_registry_; + std::shared_ptr descriptor_pool_; + CompilerOptions compiler_options_; + Config config_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_H_ diff --git a/env/env_runtime.cc b/env/env_runtime.cc new file mode 100644 index 000000000..7bd071570 --- /dev/null +++ b/env/env_runtime.cc @@ -0,0 +1,63 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_runtime.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" + +namespace cel { + +absl::StatusOr EnvRuntime::CreateRuntimeBuilder() { + for (const Config::ExtensionConfig& extension_config : + config_.GetExtensionConfigs()) { + if (extension_config.name == "optional") { + runtime_options_.enable_qualified_type_identifiers = true; + break; + } + } + + CEL_ASSIGN_OR_RETURN( + RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool_, runtime_options_)); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR(RegisterStandardFunctions( + runtime_builder.function_registry(), runtime_options_)); + } + + for (const Config::ExtensionConfig& extension_config : + config_.GetExtensionConfigs()) { + CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( + runtime_builder, runtime_options_, extension_config.name, + extension_config.version)); + } + return runtime_builder; +} + +absl::StatusOr> EnvRuntime::NewRuntime() { + CEL_ASSIGN_OR_RETURN(RuntimeBuilder runtime_builder, CreateRuntimeBuilder()); + return std::move(runtime_builder).Build(); +} + +} // namespace cel diff --git a/env/env_runtime.h b/env/env_runtime.h new file mode 100644 index 000000000..26760fa1b --- /dev/null +++ b/env/env_runtime.h @@ -0,0 +1,75 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" +#include "env/internal/runtime_ext_registry.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// EnvRuntime class establishes the environment for creating CEL runtimes. +// +// It is used to configure runtime options, extension functions, and other +// customizable CEL runtime features. +// +// The main reason that EnvRuntime is separate from Env is that compilation and +// evaluation are often done by different executables and we want to avoid +// linking both the compiler and runtime extensions into each binary +// unnecessarily. +// +// Even though EnvRuntime is separate from Env, the Config and DescriptorPool +// passed to EnvRuntime are expected to be the same as those passed to Env for +// compilation. This ensures consistency between compilation and runtime. +class EnvRuntime { + public: + void SetDescriptorPool( + std::shared_ptr descriptor_pool) { + descriptor_pool_ = std::move(descriptor_pool); + } + + void SetConfig(const Config& config) { config_ = config; } + + RuntimeOptions& mutable_runtime_options() { return runtime_options_; } + + absl::StatusOr CreateRuntimeBuilder(); + + // Shortcut for CreateRuntimeBuilder() followed by Build(). + absl::StatusOr> NewRuntime(); + + private: + cel::env_internal::RuntimeExtensionRegistry extension_registry_; + std::shared_ptr descriptor_pool_; + Config config_; + RuntimeOptions runtime_options_; + + cel::env_internal::RuntimeExtensionRegistry& GetRuntimeExtensionRegistry() { + return extension_registry_; + } + + friend void RegisterStandardExtensions(EnvRuntime& env_runtime); +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ diff --git a/env/env_runtime_test.cc b/env/env_runtime_test.cc new file mode 100644 index 000000000..9b6b13554 --- /dev/null +++ b/env/env_runtime_test.cc @@ -0,0 +1,159 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_runtime.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/source.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsEmpty; + +struct TestCase { + std::string config_yaml; + std::string expr; + bool expected_to_fail = false; +}; + +class EnvRuntimeTest : public testing::TestWithParam {}; + +TEST_P(EnvRuntimeTest, EndToEnd) { + const TestCase& param = GetParam(); + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.config_yaml)); + + Env env; + env.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env); + env.SetConfig(config); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + std::unique_ptr ast; + if (!param.expected_to_fail) { + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(param.expr)); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(ast, result.ReleaseAst()); + } else { + // Bypass type checking to allow compilation to succeed since we expect the + // runtime to fail. + ASSERT_OK_AND_ASSIGN(std::unique_ptr source, + NewSource(param.expr, "")); + ASSERT_OK_AND_ASSIGN(ast, compiler->GetParser().Parse(*source)); + } + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + + absl::StatusOr> program_or = + runtime->CreateProgram(std::move(ast)); + if (param.expected_to_fail) { + EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) + << " expr: " << param.expr; + return; + } + + ASSERT_THAT(program_or, IsOk()) << " expr: " << param.expr; + + std::unique_ptr program = std::move(program_or.value()); + ASSERT_NE(program, nullptr); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()) << " expr: " << param.expr; +} + +std::vector GetEnvRuntimeTestCases() { + return { + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8='", + }, + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + - name: "optional" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "optional.of(1).hasValue()", + }, + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "optional.of(1).hasValue()", + .expected_to_fail = true, + }, + TestCase{ + .config_yaml = R"yaml( + stdlib: + disable: true + )yaml", + .expr = "1 + 2 == 3", + .expected_to_fail = true, + }, + TestCase{ + .config_yaml = R"yaml( + stdlib: + disable: true + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "1 + 2 == 3", + .expected_to_fail = true, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest, + ::testing::ValuesIn(GetEnvRuntimeTestCases())); + +} // namespace +} // namespace cel diff --git a/env/env_std_extensions.cc b/env/env_std_extensions.cc new file mode 100644 index 000000000..f2041b979 --- /dev/null +++ b/env/env_std_extensions.cc @@ -0,0 +1,76 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_std_extensions.h" + +#include "checker/optional.h" +#include "compiler/optional.h" +#include "env/env.h" +#include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/proto_ext.h" +#include "extensions/regex_ext.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" + +namespace cel { + +void RegisterStandardExtensions(Env& env) { + env.RegisterCompilerLibrary("cel.lib.ext.bindings", "bindings", 0, []() { + return extensions::BindingsCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.encoders", "encoders", 0, []() { + return extensions::EncodersCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.lists", "lists", version, + [version]() { return extensions::ListsCompilerLibrary(version); }); + } + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.math", "math", version, + [version]() { return extensions::MathCompilerLibrary(version); }); + } + for (int version = 0; version <= kOptionalExtensionLatestVersion; ++version) { + env.RegisterCompilerLibrary("optional", "", version, [version]() { + return OptionalCompilerLibrary(version); + }); + } + env.RegisterCompilerLibrary("cel.lib.ext.protos", "protos", 0, []() { + return extensions::ProtoExtCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.sets", "sets", 0, []() { + return extensions::SetsCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.strings", "strings", version, + [version]() { return extensions::StringsCompilerLibrary(version); }); + } + env.RegisterCompilerLibrary( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + []() { return extensions::ComprehensionsV2CompilerLibrary(); }); + env.RegisterCompilerLibrary("cel.lib.ext.regex", "regex", 0, []() { + return extensions::RegexExtCompilerLibrary(); + }); +} + +} // namespace cel diff --git a/env/env_std_extensions.h b/env/env_std_extensions.h new file mode 100644 index 000000000..79cf37dbf --- /dev/null +++ b/env/env_std_extensions.h @@ -0,0 +1,42 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ + +#include "env/env.h" + +namespace cel { + +// Registers the standard CEL extensions with the given environment. This makes +// them available, but does not enable them. See Env::Config for how to enable +// extensions. +// +// Extensions are registered under the following names: +// +// - cel.lib.ext.bindings (alias: "bindings") +// - cel.lib.ext.encoders (alias: "encoders") +// - cel.lib.ext.lists (alias: "lists") +// - cel.lib.ext.math (alias: "math") +// - optional +// - cel.lib.ext.protos (alias: "protos") +// - cel.lib.ext.sets (alias: "sets") +// - cel.lib.ext.strings (alias: "strings") +// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") +// - cel.lib.ext.regex (alias: "regex") +void RegisterStandardExtensions(Env& env); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ diff --git a/env/env_std_extensions_test.cc b/env/env_std_extensions_test.cc new file mode 100644 index 000000000..7d9572cc0 --- /dev/null +++ b/env/env_std_extensions_test.cc @@ -0,0 +1,116 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_std_extensions.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::TestWithParam; + +struct TestCase { + std::string extension; + std::string expr; +}; + +class EnvStdExtensions : public testing::TestWithParam {}; + +TEST_P(EnvStdExtensions, RegistrationTest) { + const TestCase& param = GetParam(); + + Env env; + RegisterStandardExtensions(env); + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + + Config config; + ASSERT_THAT(config.AddExtensionConfig(param.extension), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(param.expr)); + ASSERT_TRUE(result.IsValid()) << "Expected no issues for expr: " << param.expr + << " but got: " << result.FormatError(); +} + +INSTANTIATE_TEST_SUITE_P( + RegistrationTest, EnvStdExtensions, + ::testing::Values( + TestCase{ + .extension = "cel.lib.ext.bindings", // official name + .expr = "cel.bind(t, true, t)", + }, + TestCase{ + .extension = "bindings", // alias + .expr = "cel.bind(t, true, t)", + }, + TestCase{ + .extension = "encoders", + .expr = "base64.encode(b'hello')", + }, + TestCase{ + .extension = "lists", + .expr = "[1, 2, 3].sort()", + }, + TestCase{ + .extension = "lists", + .expr = "['a'].sortBy(e, e)", + }, + TestCase{ + .extension = "math", + .expr = "math.sqrt(-1)", + }, + TestCase{ + .extension = "optional", + .expr = "[1, 2].first()", + }, + TestCase{ + .extension = "optional", + .expr = "[0][?1]", // optional syntax auto-enabled + }, + TestCase{ + .extension = "protos", + .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " + "cel.expr.conformance.proto2.nested_ext)", + }, + TestCase{ + .extension = "sets", + .expr = "sets.contains([1], [1])", + }, + TestCase{ + .extension = "strings", + .expr = "'foo'.reverse()", + }, + TestCase{ + .extension = "two-var-comprehensions", + .expr = "[1, 2, 3, 4].all(i, v, i < v)", + }, + TestCase{ + .extension = "regex", + .expr = "regex.replace('abc', '$', '_end')", + })); + +} // namespace +} // namespace cel diff --git a/env/env_test.cc b/env/env_test.cc new file mode 100644 index 000000000..dcd2d97fa --- /dev/null +++ b/env/env_test.cc @@ -0,0 +1,713 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env.h" + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "internal/proto_matchers.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/parser_interface.h" +#include "runtime/activation.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::internal::test::EqualsProto; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::NotNull; +using ::testing::Property; +using ::testing::UnorderedElementsAre; +using ::testing::Values; +using ::testing::ValuesIn; + +Expr TestMacroExpander(MacroExprFactory& factory, absl::Span args) { + return factory.NewStringConst("Hello"); +} + +class TestLibrary : public CompilerLibrary { + public: + explicit TestLibrary(int version) + : CompilerLibrary( + "testlib", + [version](ParserBuilder& builder) { + absl::Status status; + CEL_ASSIGN_OR_RETURN( + auto macro1, + cel::Macro::Global("testMacro1", 0, TestMacroExpander)); + status.Update(builder.AddMacro(macro1)); + if (version == 2) { + CEL_ASSIGN_OR_RETURN( + auto macro2, + cel::Macro::Global("testMacro2", 0, TestMacroExpander)); + status.Update(builder.AddMacro(macro2)); + } + return status; + }, + [version](TypeCheckerBuilder& builder) { + absl::Status status; + CEL_ASSIGN_OR_RETURN( + auto func1, cel::MakeFunctionDecl( + "testFunc1", MakeOverloadDecl(StringType()))); + status.Update(builder.AddFunction(func1)); + if (version == 2) { + CEL_ASSIGN_OR_RETURN( + auto func2, + cel::MakeFunctionDecl("testFunc2", + MakeOverloadDecl(StringType()))); + status.Update(builder.AddFunction(func2)); + } + return status; + }) {}; +}; + +absl::StatusOr CompileAndEvalExpr( + Env& env, absl::string_view expr, + const Activation& activation = Activation()) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, env.NewCompiler()); + if (compiler == nullptr) { + return absl::InternalError("Failed to create compiler"); + } + CEL_ASSIGN_OR_RETURN(ValidationResult result, compiler->Compile(expr)); + if (!result.GetIssues().empty()) { + return absl::InvalidArgumentError(result.FormatError()); + } + + cel::RuntimeOptions opts; + CEL_ASSIGN_OR_RETURN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder(env.GetDescriptorPool(), opts)); + CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver( + rt_builder, cel::ReferenceResolverEnabled::kAlways)); + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + if (runtime == nullptr) { + return absl::InternalError("Failed to create runtime"); + } + + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, result.ReleaseAst()); + if (ast == nullptr) { + return absl::InternalError("Failed to create AST"); + } + google::protobuf::Arena arena; + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + if (program == nullptr) { + return absl::InternalError("Failed to create program"); + } + CEL_ASSIGN_OR_RETURN(Value value, program->Evaluate(&arena, activation)); + return value; +} + +absl::StatusOr CompileAndEvalBooleanExpr( + Env& env, absl::string_view expr, + const Activation& activation = Activation()) { + CEL_ASSIGN_OR_RETURN(auto value, CompileAndEvalExpr(env, expr, activation)); + return value.GetBool(); +} + +class LibraryConfigTest : public testing::Test { + protected: + void SetUp() override { + env_.RegisterCompilerLibrary("testlib", "ml", 1, + []() { return TestLibrary(1); }); + env_.RegisterCompilerLibrary("testlib", "ml", 2, + []() { return TestLibrary(2); }); + env_.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + } + + Env env_; +}; + +TEST_F(LibraryConfigTest, DefaultVersion) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("testlib"), IsOk()); + + env_.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()")); + ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); + ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); + ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); + + EXPECT_THAT(result1.GetIssues(), IsEmpty()); + EXPECT_THAT(result2.GetIssues(), IsEmpty()); + EXPECT_THAT(result3.GetIssues(), IsEmpty()); + EXPECT_THAT(result4.GetIssues(), IsEmpty()); +} + +TEST_F(LibraryConfigTest, SpecificVersion) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("testlib", 1), IsOk()); + + env_.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()")); + ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); + ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); + ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); + + EXPECT_THAT(result1.GetIssues(), IsEmpty()); + EXPECT_THAT(result2.GetIssues(), IsEmpty()); + EXPECT_THAT(result3.GetIssues(), + UnorderedElementsAre( + Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'testMacro2'")))); + EXPECT_THAT(result4.GetIssues(), + UnorderedElementsAre( + Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'testFunc2'")))); +} + +struct StandardLibraryConfigTestCase { + Config::StandardLibraryConfig standard_library_config; + std::vector expected_valid_expressions; + std::vector expected_invalid_expressions; +}; + +class StandardLibraryConfigTest + : public testing::TestWithParam {}; + +TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { + const StandardLibraryConfigTestCase& param = GetParam(); + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + + Config config; + ASSERT_THAT(config.SetStandardLibraryConfig(param.standard_library_config), + IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + for (const std::string& expr : param.expected_valid_expressions) { + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); + EXPECT_THAT(result1.GetIssues(), IsEmpty()) + << "With config: " << param.standard_library_config + << ", expected no issues for expr: " << expr + << " but got: " << result1.FormatError(); + } + for (const std::string& expr : param.expected_invalid_expressions) { + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); + EXPECT_THAT(result1.GetIssues(), Not(IsEmpty())) + << "With config: " << param.standard_library_config + << ", expected compilation error for expr: " << expr << " but got: \'" + << result1.FormatError() << "\'"; + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardLibraryConfigTest, StandardLibraryConfigTest, + Values( + StandardLibraryConfigTestCase{ + .standard_library_config = {}, + .expected_valid_expressions = {"1 + 2", + "[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.disable = true}, + .expected_invalid_expressions = {"1 + 2", + "[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.disable_macros = true}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.excluded_macros = {"map", "all"}}, + .expected_valid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, + .expected_invalid_expressions = {"[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.included_macros = {"map", "all"}}, + .expected_valid_expressions = {"[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.excluded_functions = {{"_+_", ""}}}, + .expected_invalid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.excluded_functions = {{"_+_", "add_bytes"}, + {"_+_", "add_list"}, + {"_+_", "add_string"}}}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.included_functions = {{"_+_", ""}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, + .expected_invalid_expressions = {"'hello' + 'world'"}, + })); + +TEST(ContainerConfigTest, ContainerConfig) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig({.name = "cel.expr.conformance.proto2"}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +struct TypeInfoTestCase { + Config::TypeInfo type_info; + std::string expected_type_pb; +}; + +using TypeInfoTest = testing::TestWithParam; + +TEST_P(TypeInfoTest, TypeInfo) { + const TypeInfoTestCase& param = GetParam(); + cel::expr::Type expected_type_pb; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(param.expected_type_pb, + &expected_type_pb)); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + Config::VariableConfig variable_config; + variable_config.name = "test"; + variable_config.type_info = param.type_info; + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_THAT(compiler, NotNull()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("test")); + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << " error: " << result.FormatError(); + + // Obtain the inferred return type of the expression `test`. + const Ast* ast = result.GetAst(); + ASSERT_THAT(ast, NotNull()); + cel::expr::CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*ast, &checked_expr), IsOk()); + auto it = checked_expr.type_map().find(checked_expr.expr().id()); + ASSERT_NE(it, checked_expr.type_map().end()); + + cel::expr::Type actual_type_pb = it->second; + EXPECT_THAT(actual_type_pb, EqualsProto(expected_type_pb)); +} + +std::vector GetTypeInfoTestCases() { + return { + TypeInfoTestCase{ + .type_info = {.name = "int"}, + .expected_type_pb = "primitive: INT64", + }, + TypeInfoTestCase{ + .type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "list_type { elem_type { primitive: INT64 } }", + }, + TypeInfoTestCase{ + .type_info = {.name = "list"}, + .expected_type_pb = "list_type { elem_type { dyn {} }}", + }, + TypeInfoTestCase{ + .type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "map_type { key_type { primitive: STRING } " + "value_type { primitive: INT64 }}", + }, + TypeInfoTestCase{ + .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + .expected_type_pb = + "message_type: 'cel.expr.conformance.proto2.TestAllTypes'", + }, + TypeInfoTestCase{ + .type_info = {.name = "A", + .params = {Config::TypeInfo{.name = "B", + .is_type_param = true}}}, + // TypeParam is replaced with dyn by the type checker. + .expected_type_pb = + "abstract_type { name: 'A' parameter_types { dyn {} } }", + }, + TypeInfoTestCase{ + .type_info = {.name = "any"}, + .expected_type_pb = "well_known: ANY", + }, + TypeInfoTestCase{ + .type_info = {.name = "timestamp"}, + .expected_type_pb = "well_known: TIMESTAMP", + }, + TypeInfoTestCase{ + .type_info = {.name = "google.protobuf.DoubleValue"}, + .expected_type_pb = "wrapper: DOUBLE", + }, + TypeInfoTestCase{ + .type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "duration"}}}, + .expected_type_pb = "type: { well_known: DURATION }", + }, + TypeInfoTestCase{ + .type_info = {.name = "parameterized", + .params = {{.name = "A", .is_type_param = true}, + {.name = "double"}}}, + // TypeParam is replaced with dyn by the type checker. + .expected_type_pb = "abstract_type { name: 'parameterized' " + "parameter_types { dyn {} } " + "parameter_types { primitive: DOUBLE } }", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(VariableConfigTest, TypeInfoTest, + ValuesIn(GetTypeInfoTestCases())); + +struct VariableConfigWithValueTestCase { + Config::VariableConfig variable_config; + std::string validate_type_expr; + std::string validate_value_expr; +}; + +class VariableConfigWithValueTest + : public testing::TestWithParam {}; + +TEST_P(VariableConfigWithValueTest, VariableConfigWithValue) { + const VariableConfigWithValueTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + ASSERT_THAT(config.AddVariableConfig(param.variable_config), IsOk()); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN( + bool type_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_type_expr)); + ASSERT_TRUE(type_as_expected) << " expr: " << param.validate_type_expr; + if (!param.validate_value_expr.empty()) { + ASSERT_OK_AND_ASSIGN( + bool value_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_value_expr)); + ASSERT_TRUE(value_as_expected) << " expr: " << param.validate_value_expr; + } +} + +Config::VariableConfig MakeConstant( + absl::string_view variable_name, absl::string_view type_name, + absl::AnyInvocable setter) { + Config::VariableConfig variable_config; + variable_config.name = variable_name; + Constant c; + setter(c); + variable_config.type_info.name = type_name; + variable_config.value = c; + return variable_config; +} + +std::vector +GetVariableConfigWithValueTestCases() { + return { + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "null", [](auto& c) { c.set_null_value(nullptr); }), + .validate_type_expr = "type(x) == type(null)", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "bool", [](auto& c) { c.set_bool_value(true); }), + .validate_type_expr = "type(x) == bool", + .validate_value_expr = "x == true", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "int", [](Constant& c) { c.set_int_value(42); }), + .validate_type_expr = "type(x) == int", + .validate_value_expr = "x == 42", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "uint", [](Constant& c) { c.set_uint_value(777); }), + .validate_type_expr = "type(x) == uint", + .validate_value_expr = "x == 777u", + }, + VariableConfigWithValueTestCase{ + .variable_config = + MakeConstant("x", "double", + [](Constant& c) { c.set_double_value(1.0 / 3.0); }), + .validate_type_expr = "type(x) == double", + .validate_value_expr = "x > 0.333 && x < 0.334", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant("x", "bytes", + [](Constant& c) { + c.set_bytes_value(absl::string_view( + "\xff\x00\x01", 3)); + }), + .validate_type_expr = "type(x) == bytes", + .validate_value_expr = "x == b'\\xff\\x00\\x01'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "string", [](Constant& c) { c.set_string_value("hello"); }), + .validate_type_expr = "type(x) == string", + .validate_value_expr = "x == 'hello'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "timestamp", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_timestamp_value(absl::FromUnixSeconds(1767323045)); + }), + .validate_type_expr = + "type(x) == type(timestamp('2026-01-02T03:04:05Z'))", + .validate_value_expr = "x == timestamp('2026-01-02T03:04:05Z')", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "duration", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_duration_value(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3)); + }), + .validate_type_expr = "type(x) == type(duration('1h2m3s'))", + .validate_value_expr = "x == duration('1h2m3s')", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(VariableConfigTest, VariableConfigWithValueTest, + ValuesIn(GetVariableConfigWithValueTestCases())); + +struct FunctionConfigTestCase { + Config::FunctionConfig function_config; + std::vector variable_configs; + std::string expr; + std::string expected_error; +}; + +class FunctionConfigTest + : public testing::TestWithParam {}; + +TEST_P(FunctionConfigTest, FunctionConfig) { + const FunctionConfigTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + for (const Config::VariableConfig& variable_config : param.variable_configs) { + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + } + ASSERT_THAT(config.AddFunctionConfig(param.function_config), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + if (param.expected_error.empty()) { + EXPECT_TRUE(result.GetIssues().empty()) + << " expr: " << param.expr << " error: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + UnorderedElementsAre(Property(&TypeCheckIssue::message, + HasSubstr(param.expected_error)))) + << " expr: " << param.expr << " error: " << result.FormatError(); + } +} + +std::vector GetFunctionConfigTestCases() { + return {{ + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(int,int)", + .examples = {"add(1, 2) -> 3"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "add(1, 2)", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 3"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "1.add(2) == 3", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(string,string)", + .examples = + {"add('hello', 'world') -> 'hello world'"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "add('hello', 'world')", + .expected_error = "found no matching overload for 'add' applied to " + "'(string, string)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 'three'"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "1.add(2) == 3", + .expected_error = "found no matching overload for '_==_' applied to " + "'(string, int)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "double"}}}}, + .return_type = {.name = "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Matching opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "int"}}}}, + .return_type = {.name = "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Mismatched opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + .expected_error = "found no matching overload for 'sum' applied to " + "'(collection(double))'", + }, + }}; +} + +INSTANTIATE_TEST_SUITE_P(FunctionConfigTest, FunctionConfigTest, + ::testing::ValuesIn(GetFunctionConfigTestCases())); + +} // namespace +} // namespace cel diff --git a/env/env_yaml.cc b/env/env_yaml.cc new file mode 100644 index 000000000..0035709e9 --- /dev/null +++ b/env/env_yaml.cc @@ -0,0 +1,1026 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_yaml.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "yaml-cpp/emitter.h" +#include "yaml-cpp/emittermanip.h" +#include "yaml-cpp/exceptions.h" +#include "yaml-cpp/mark.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/node/parse.h" +#include "yaml-cpp/null.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel { + +namespace { + +std::string FormatYamlErrorMessage(absl::string_view yaml, + absl::string_view error, + const YAML::Mark& mark) { + if (mark.is_null()) { + return std::string(error); + } + std::string message; + absl::StrAppend(&message, mark.line + 1, ":", mark.column + 1, ": ", error, + "\n|"); + size_t start = mark.pos - mark.column; + size_t end = yaml.find('\n', mark.pos); + if (end == std::string::npos) { + end = yaml.size(); + } + + absl::StrAppend(&message, yaml.substr(start, end - start), "\n|", + std::string(mark.column, ' '), "^"); + + return message; +} + +absl::StatusOr LoadYaml(const std::string& yaml) { + try { + return YAML::Load(yaml); + } catch (YAML::ParserException& e) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, e.msg, e.mark)); + } +} + +absl::Status YamlError(absl::string_view yaml, const YAML::Node& node, + absl::string_view error) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, error, node.Mark())); +} + +std::string GetString(absl::string_view yaml, const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return ""; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + // This should never happen since we already checked that the node is a + // scalar and all scalars can be converted to strings. + return ""; + } +} + +bool IsBinary(const YAML::Node& node) { + return node.Tag() == "!!binary" || node.Tag() == "tag:yaml.org,2002:binary"; +} + +absl::StatusOr GetBinary(absl::string_view yaml, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar() || !IsBinary(node)) { + return ""; + } + std::string binary; + // Instead of using the YAML::Binary type, we use absl::Base64Unescape + // because YAML::Binary is lenient to Base64 decoding errors. + if (absl::Base64Unescape(GetString(yaml, node), &binary)) { + return binary; + } else { + return YamlError(yaml, node, + "Node '" + GetString(yaml, node) + + "' is not a valid Base64 encoded binary"); + } +} + +absl::StatusOr GetBool(absl::string_view yaml, absl::string_view key, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return false; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + return YamlError(yaml, node, + "Node '" + std::string(key) + "' is not a boolean"); + } +} + +absl::Status ParseName(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node name = root["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' is not a string"); + } + config.SetName(GetString(yaml, name)); + } + return absl::OkStatus(); +} + +absl::Status ParseContainerConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node container = root["container"]; + if (container.IsDefined()) { + if (!container.IsScalar()) { + return YamlError(yaml, container, "Node 'container' is not a string"); + } + config.SetContainerConfig({.name = GetString(yaml, container)}); + } + return absl::OkStatus(); +} + +absl::Status ParseExtensionConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node extensions = root["extensions"]; + if (!extensions.IsDefined()) { + return absl::OkStatus(); + } + if (!extensions.IsSequence()) { + return YamlError(yaml, extensions, "Node 'extensions' is not a sequence"); + } + + for (const YAML::Node& extension : extensions) { + if (!extension || !extension.IsMap()) { + return YamlError(yaml, extension, "Extension is not a map"); + } + const YAML::Node name = extension["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Extension name is not a string"); + } + std::string name_str = GetString(yaml, name); + + const YAML::Node version = extension["version"]; + std::string version_str = GetString(yaml, version); + int extension_version; + if (version.IsDefined()) { + bool is_valid_version = false; + if (version.IsScalar()) { + if (version_str == "latest") { + extension_version = Config::ExtensionConfig::kLatest; + is_valid_version = true; + } else { + if (absl::SimpleAtoi(version_str, &extension_version) && + extension_version >= 0) { + is_valid_version = true; + } + } + } + if (!is_valid_version) { + return YamlError( + yaml, version, + absl::StrCat("Extension '", name_str, + "' version is not a valid number or 'latest'")); + } + } else { + extension_version = Config::ExtensionConfig::kLatest; + } + absl::Status add_status = + config.AddExtensionConfig(name_str, extension_version); + if (!add_status.ok()) { + return YamlError(yaml, extension, add_status.message()); + } + } + return absl::OkStatus(); +} + +absl::StatusOr> ParseMacroList( + absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set macro_set; + const YAML::Node macros = standard_library[std::string(key)]; + if (!macros.IsDefined()) { + return macro_set; + } + if (!macros.IsSequence()) { + return YamlError(yaml, macros, + absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& macro : macros) { + if (!macro.IsScalar()) { + return YamlError(yaml, macro, + absl::StrCat("Entry in '", key, "' is not a string")); + } + macro_set.insert(GetString(yaml, macro)); + } + return macro_set; +} + +absl::StatusOr>> +ParseFunctionList(absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set> function_set; + const YAML::Node functions = standard_library[std::string(key)]; + if (!functions.IsDefined()) { + return function_set; + } + if (!functions.IsSequence()) { + return YamlError(yaml, functions, + absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& function : functions) { + if (!function.IsMap()) { + return YamlError(yaml, function, + absl::StrCat("Entry in '", key, "' is not a map")); + } + const YAML::Node name = function["name"]; + if (!name.IsDefined()) { + return YamlError( + yaml, function, + absl::StrCat("Function name in not specified in '", key, "'")); + } + if (!name.IsScalar()) { + return YamlError( + yaml, name, + absl::StrCat("Function name in '", key, "' entry is not a string")); + } + std::string name_str = GetString(yaml, name); + const YAML::Node overloads = function["overloads"]; + if (!overloads.IsDefined()) { + function_set.insert(std::make_pair(name_str, "")); + } else { + if (!overloads.IsSequence()) { + return YamlError( + yaml, overloads, + absl::StrCat("Overloads in '", key, "' entry is not a sequence")); + } + for (const YAML::Node& overload : overloads) { + if (!overload.IsMap()) { + return YamlError( + yaml, overload, + absl::StrCat("Overload in '", key, "' entry is not a map")); + } + const YAML::Node id = overload["id"]; + if (!id || !id.IsScalar()) { + return YamlError( + yaml, id, + absl::StrCat("Overload id in '", key, "' entry is not a string")); + } + function_set.insert(std::make_pair(name_str, GetString(yaml, id))); + } + } + } + return function_set; +} + +absl::Status ParseStandardLibraryConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node standard_library = root["stdlib"]; + if (!standard_library.IsDefined()) { + return absl::OkStatus(); + } + + if (!standard_library.IsMap()) { + return YamlError(yaml, standard_library, + "Standard library config ('stdlib') is not a map"); + } + + Config::StandardLibraryConfig standard_library_config; + + const YAML::Node disable = standard_library["disable"]; + if (disable.IsDefined()) { + if (!disable.IsScalar()) { + return YamlError(yaml, disable, "Node 'disable' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable, + GetBool(yaml, "disable", disable)); + } + + const YAML::Node disable_macros = standard_library["disable_macros"]; + if (disable_macros.IsDefined()) { + if (!disable_macros.IsScalar()) { + return YamlError(yaml, disable_macros, + "Node 'disable_macros' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable_macros, + GetBool(yaml, "disable_macros", disable_macros)); + } + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_macros, + ParseMacroList(yaml, standard_library, "include_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_macros, + ParseMacroList(yaml, standard_library, "exclude_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_functions, + ParseFunctionList(yaml, standard_library, "include_functions")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_functions, + ParseFunctionList(yaml, standard_library, "exclude_functions")); + + return config.SetStandardLibraryConfig(standard_library_config); +} + +absl::StatusOr ParseTypeInfo(const YAML::Node& node, + absl::string_view yaml) { + Config::TypeInfo type_config; + const YAML::Node type_name = node["type_name"]; + if (!type_name.IsDefined()) { + return type_config; + } + if (!type_name || !type_name.IsScalar()) { + return YamlError(yaml, type_name, "Node 'type_name' is not a string"); + } + type_config.name = GetString(yaml, type_name); + + const YAML::Node is_type_param = node["is_type_param"]; + if (is_type_param.IsDefined()) { + if (!is_type_param.IsScalar()) { + return YamlError(yaml, is_type_param, + "Node 'is_type_param' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(type_config.is_type_param, + GetBool(yaml, "is_type_param", is_type_param)); + } + + const YAML::Node params = node["params"]; + if (!params.IsDefined()) { + return type_config; + } + if (!params.IsSequence()) { + return YamlError(yaml, params, "Node 'params' is not a sequence"); + } + for (const YAML::Node& param : params) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_config, + ParseTypeInfo(param, yaml)); + type_config.params.push_back(param_config); + } + + return type_config; +} + +bool CompareTypeInfo(const Config::TypeInfo& a, const Config::TypeInfo& b) { + if (a.name != b.name) { + return a.name < b.name; + } + if (a.params.size() != b.params.size()) { + return a.params.size() < b.params.size(); + } + for (size_t i = 0; i < a.params.size(); ++i) { + if (CompareTypeInfo(a.params[i], b.params[i])) { + return true; + } + if (CompareTypeInfo(b.params[i], a.params[i])) { + return false; + } + } + return false; // They are equal +} + +ConstantKindCase GetConstantKindCase(absl::string_view type_name) { + static const auto kTypeNameToConstantKindCase = + absl::NoDestructor>({ + {"null", ConstantKindCase::kNull}, + {"bool", ConstantKindCase::kBool}, + {"int", ConstantKindCase::kInt}, + {"uint", ConstantKindCase::kUint}, + {"double", ConstantKindCase::kDouble}, + {"string", ConstantKindCase::kString}, + {"bytes", ConstantKindCase::kBytes}, + {"duration", ConstantKindCase::kDuration}, + {"timestamp", ConstantKindCase::kTimestamp}, + }); + if (auto it = kTypeNameToConstantKindCase->find(type_name); + it != kTypeNameToConstantKindCase->end()) { + return it->second; + } + return ConstantKindCase::kUnspecified; +} + +absl::StatusOr ParseConstantValue(absl::string_view yaml, + const YAML::Node& node, + ConstantKindCase constant_kind_case, + absl::string_view value) { + switch (constant_kind_case) { + case ConstantKindCase::kNull: + if (!value.empty()) { + return YamlError(yaml, node, "Failed to parse null constant"); + } + return Constant(nullptr); + case ConstantKindCase::kBool: + if (absl::EqualsIgnoreCase(value, "true")) { + return Constant(true); + } else if (absl::EqualsIgnoreCase(value, "false")) { + return Constant(false); + } else { + return YamlError(yaml, node, "Failed to parse bool constant"); + } + case ConstantKindCase::kInt: + int64_t int_value; + if (!absl::SimpleAtoi(value, &int_value)) { + return YamlError(yaml, node, "Failed to parse int constant"); + } + return Constant(int_value); + case ConstantKindCase::kUint: + uint64_t uint_value; + if (absl::EndsWith(value, "u")) { + value = value.substr(0, value.size() - 1); + } + if (!absl::SimpleAtoi(value, &uint_value)) { + return YamlError(yaml, node, "Failed to parse uint constant"); + } + return Constant(uint_value); + case ConstantKindCase::kDouble: + double double_value; + if (!absl::SimpleAtod(value, &double_value)) { + return YamlError(yaml, node, "Failed to parse double constant"); + } + return Constant(double_value); + case ConstantKindCase::kBytes: { + if (!IsBinary(node)) { + absl::StatusOr bytes_literal = + internal::ParseBytesLiteral(value); + if (bytes_literal.ok()) { + return Constant(BytesConstant(*bytes_literal)); + } + } + return Constant(BytesConstant(value)); + } + case ConstantKindCase::kString: + return Constant(StringConstant(value)); + case ConstantKindCase::kDuration: { + // Duration is deprecated as a builtin type, but still supported for + // compatibility. + absl::Duration duration_value; + if (!absl::ParseDuration(value, &duration_value)) { + return YamlError(yaml, node, "Failed to parse duration constant"); + } + return Constant(duration_value); + } + case ConstantKindCase::kTimestamp: { + // Timestamp is deprecated as a builtin type, but still supported for + // compatibility. + absl::Time timestamp_value; + std::string error; + // Format: YYYY-MM-DDThh:mm:ssZ + if (!absl::ParseTime("%Y-%m-%d%ET%H:%M:%E*SZ", value, ×tamp_value, + &error)) { + return YamlError( + yaml, node, + absl::StrCat("Failed to parse timestamp constant: ", error, + " supported format: YYYY-MM-DDThh:mm:ssZ")); + } + return Constant(timestamp_value); + } + default: + // This should never happen. + return YamlError(yaml, node, "Constant type is not supported"); + } +} + +absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node variables = root["variables"]; + if (!variables.IsDefined()) { + return absl::OkStatus(); + } + if (!variables.IsSequence()) { + return YamlError(yaml, variables, "Node 'variables' is not a sequence"); + } + + for (const YAML::Node& variable : variables) { + Config::VariableConfig variable_config; + if (!variable || !variable.IsMap()) { + return YamlError(yaml, variable, "Variable is not a map"); + } + const YAML::Node name = variable["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Variable name is not a string"); + } + variable_config.name = GetString(yaml, name); + const YAML::Node description = variable["description"]; + if (description.IsDefined()) { + if (!description.IsScalar()) { + return YamlError(yaml, description, + "Variable description is not a string"); + } + variable_config.description = GetString(yaml, description); + } + + CEL_ASSIGN_OR_RETURN(auto type_info, ParseTypeInfo(variable, yaml)); + ConstantKindCase constant_kind_case = GetConstantKindCase(type_info.name); + std::string value_str; + YAML::Node value = variable["value"]; + if (value.IsDefined()) { + if (constant_kind_case == ConstantKindCase::kUnspecified) { + return YamlError(yaml, value, + absl::StrCat("Constant type '", type_info.name, + "' is not supported")); + } + if (!value.IsScalar()) { + return YamlError(yaml, value, "Variable value is not a scalar"); + } + if (IsBinary(value)) { + CEL_ASSIGN_OR_RETURN(value_str, GetBinary(yaml, value)); + } else { + value_str = GetString(yaml, value); + } + } + + variable_config.type_info = type_info; + + if (constant_kind_case != ConstantKindCase::kUnspecified) { + CEL_ASSIGN_OR_RETURN( + variable_config.value, + ParseConstantValue(yaml, value, constant_kind_case, value_str)); + } + + CEL_RETURN_IF_ERROR(config.AddVariableConfig(variable_config)); + } + return absl::OkStatus(); +} + +absl::StatusOr ParseFunctionOverloadConfig( + absl::string_view yaml, const YAML::Node& overload) { + Config::FunctionOverloadConfig overload_config; + if (!overload || !overload.IsMap()) { + return YamlError(yaml, overload, "Function overload is not a map"); + } + const YAML::Node id = overload["id"]; + if (id.IsDefined()) { + if (!id.IsScalar()) { + return YamlError(yaml, id, "Function overload id is not a string"); + } + overload_config.overload_id = GetString(yaml, id); + } + const YAML::Node examples = overload["examples"]; + if (examples.IsDefined()) { + if (!examples.IsSequence()) { + return YamlError(yaml, examples, + "Function overload examples is not a sequence"); + } + for (const YAML::Node& example : examples) { + if (!example.IsScalar()) { + return YamlError(yaml, example, + "Function overload example is not a string"); + } + overload_config.examples.push_back(GetString(yaml, example)); + } + } + + const YAML::Node target = overload["target"]; + if (target.IsDefined()) { + if (!target.IsMap()) { + return YamlError(yaml, target, "Function overload target is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(target, yaml)); + overload_config.is_member_function = true; + overload_config.parameters.push_back(type_info); + } + + const YAML::Node args = overload["args"]; + if (args.IsDefined()) { + if (!args.IsSequence()) { + return YamlError(yaml, args, "Function overload args is not a sequence"); + } + for (const YAML::Node& arg : args) { + if (!arg.IsMap()) { + return YamlError(yaml, arg, "Function overload arg is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(arg, yaml)); + overload_config.parameters.push_back(type_info); + } + } + + const YAML::Node return_type = overload["return"]; + if (return_type.IsDefined()) { + if (!return_type.IsMap()) { + return YamlError(yaml, return_type, + "Function overload return type is not a map"); + } + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + ParseTypeInfo(return_type, yaml)); + } + return overload_config; +} + +absl::Status ParseFunctionConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node functions = root["functions"]; + if (!functions.IsDefined()) { + return absl::OkStatus(); + } + if (!functions.IsSequence()) { + return YamlError(yaml, functions, "Node 'functions' is not a sequence"); + } + + for (const YAML::Node& function : functions) { + Config::FunctionConfig function_config; + if (!function || !function.IsMap()) { + return YamlError(yaml, function, "Function is not a map"); + } + const YAML::Node name = function["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Function name is not a string"); + } + function_config.name = GetString(yaml, name); + const YAML::Node description = function["description"]; + if (description.IsDefined()) { + if (!description.IsScalar()) { + return YamlError(yaml, description, + "Function description is not a string"); + } + function_config.description = GetString(yaml, description); + } + const YAML::Node overloads = function["overloads"]; + if (overloads.IsDefined()) { + if (!overloads.IsSequence()) { + return YamlError(yaml, overloads, + "Function 'overloads' item is not a sequence"); + } + + for (const YAML::Node& overload : overloads) { + CEL_ASSIGN_OR_RETURN(Config::FunctionOverloadConfig overload_config, + ParseFunctionOverloadConfig(yaml, overload)); + function_config.overload_configs.push_back(std::move(overload_config)); + } + } + + CEL_RETURN_IF_ERROR(config.AddFunctionConfig(function_config)); + } + return absl::OkStatus(); +} + +void EmitContainerConfig(const Config& env_config, YAML::Emitter& out) { + const auto& container_config = env_config.GetContainerConfig(); + if (container_config.IsEmpty()) { + return; + } + + out << YAML::Key << "container"; + out << YAML::Value << YAML::DoubleQuoted << container_config.name; +} + +void EmitExtensionConfigs(const Config& env_config, YAML::Emitter& out) { + if (env_config.GetExtensionConfigs().empty()) { + return; + } + + // Sort the extensions to make the output deterministic. + std::vector sorted_extensions = + env_config.GetExtensionConfigs(); + absl::c_sort(sorted_extensions, [](const Config::ExtensionConfig& a, + const Config::ExtensionConfig& b) { + return a.name < b.name; + }); + out << YAML::Key << "extensions"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::ExtensionConfig& extension_config : sorted_extensions) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << extension_config.name; + if (extension_config.version != Config::ExtensionConfig::kLatest) { + out << YAML::Key << "version"; + out << YAML::Value << extension_config.version; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitMacroList(YAML::Emitter& out, absl::string_view key, + const absl::flat_hash_set& macros) { + if (macros.empty()) { + return; + } + out << YAML::Key << std::string(key); + out << YAML::Value << YAML::BeginSeq; + std::vector sorted_macros(macros.begin(), macros.end()); + absl::c_sort(sorted_macros); + for (const std::string& macro : sorted_macros) { + out << YAML::Value << YAML::DoubleQuoted << macro; + } + out << YAML::EndSeq; +} + +void EmitFunctionList( + YAML::Emitter& out, absl::string_view key, + const absl::flat_hash_set>& functions) { + if (functions.empty()) { + return; + } + + // Build a map from function name to a vector of overload ids. + // Using std::map ensures function names are sorted. + std::map> function_overloads; + for (const auto& pair : functions) { + function_overloads[pair.first].push_back(pair.second); + } + + out << YAML::Key << std::string(key) << YAML::Value << YAML::BeginSeq; + for (auto const& [name, overloads] : function_overloads) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << name; + + // If the only overload is the empty string, it signifies that all overloads + // of the function are included/excluded. In this case, we don't emit the + // "overloads" key. Otherwise, emit the specific overloads. + if (!(overloads.size() == 1 && overloads[0].empty())) { + // Sort overloads for deterministic output. + std::vector sorted_overloads = overloads; + absl::c_sort(sorted_overloads); + + out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; + for (const std::string& overload : sorted_overloads) { + out << YAML::BeginMap; + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload; + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitStandardLibraryConfig(const Config& env_config, YAML::Emitter& out) { + const Config::StandardLibraryConfig& standard_library_config = + env_config.GetStandardLibraryConfig(); + if (standard_library_config.IsEmpty()) { + return; + } + + out << YAML::Key << "stdlib" << YAML::Value << YAML::BeginMap; + if (standard_library_config.disable) { + out << YAML::Key << "disable" << YAML::Value << true; + } + if (standard_library_config.disable_macros) { + out << YAML::Key << "disable_macros" << YAML::Value << true; + } + EmitMacroList(out, "include_macros", standard_library_config.included_macros); + EmitMacroList(out, "exclude_macros", standard_library_config.excluded_macros); + EmitFunctionList(out, "include_functions", + standard_library_config.included_functions); + EmitFunctionList(out, "exclude_functions", + standard_library_config.excluded_functions); + out << YAML::EndMap; +} + +void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out) { + // Note: the map is already started when this is called, so we don't emit + // BeginMap here or EndMap at the end. + out << YAML::Key << "type_name"; + out << YAML::Value << YAML::DoubleQuoted << type_info.name; + if (type_info.is_type_param) { + out << YAML::Key << "is_type_param" << YAML::Value << true; + } + if (!type_info.params.empty()) { + out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& param : type_info.params) { + out << YAML::BeginMap; + EmitTypeInfo(param, out); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } +} + +void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { + const auto& variable_configs = env_config.GetVariableConfigs(); + if (variable_configs.empty()) { + return; + } + + // Sort variable_configs by name to ensure deterministic output. + std::vector sorted_variable_configs = + variable_configs; + absl::c_sort(sorted_variable_configs, + [](const Config::VariableConfig& a, + const Config::VariableConfig& b) { return a.name < b.name; }); + + out << YAML::Key << "variables"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::VariableConfig& variable_config : + sorted_variable_configs) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << variable_config.name; + if (!variable_config.description.empty()) { + out << YAML::Key << "description"; + out << YAML::Value << YAML::DoubleQuoted << variable_config.description; + } + EmitTypeInfo(variable_config.type_info, out); + if (variable_config.value.has_value()) { + const Constant& constant = variable_config.value; + switch (constant.kind_case()) { + case ConstantKindCase::kUnspecified: + case ConstantKindCase::kNull: + break; + case ConstantKindCase::kBool: + out << YAML::Key << "value" << YAML::Value << constant.bool_value(); + break; + case ConstantKindCase::kInt: + out << YAML::Key << "value" << YAML::Value << constant.int_value(); + break; + case ConstantKindCase::kUint: + out << YAML::Key << "value" << YAML::Value << constant.uint_value(); + break; + case ConstantKindCase::kDouble: + out << YAML::Key << "value" << YAML::Value << constant.double_value(); + break; + case ConstantKindCase::kBytes: { + out << YAML::Key << "value"; + const std::string& bytes_value = constant.bytes_value(); + std::string hex_escaped = "b\""; + for (unsigned char byte : bytes_value) { + absl::StrAppend(&hex_escaped, "\\x"); + absl::StrAppendFormat(&hex_escaped, "%02x", byte); + } + absl::StrAppend(&hex_escaped, "\""); + out << YAML::Value << hex_escaped; + break; + } + case ConstantKindCase::kString: + out << YAML::Key << "value"; + out << YAML::Value << YAML::DoubleQuoted << constant.string_value(); + break; + case ConstantKindCase::kDuration: + out << YAML::Key << "value" << YAML::Value; + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + out << absl::FormatDuration(constant.duration_value()); + break; + case ConstantKindCase::kTimestamp: + out << YAML::Key << "value" << YAML::Value; + out << absl::FormatTime( + "%4Y-%2m-%2d%ET%2H:%2M:%E*SZ", + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + constant.timestamp_value(), absl::UTCTimeZone()); + break; + } + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitFunctionOverloadConfig( + const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out) { + out << YAML::BeginMap; + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; + if (overload_config.is_member_function) { + out << YAML::Key << "target" << YAML::Value; + out << YAML::BeginMap; + if (overload_config.parameters.empty()) { + // This should never happen, but if it does, emit a dynamic type. + EmitTypeInfo({.name = "dyn"}, out); + } else { + EmitTypeInfo(overload_config.parameters[0], out); + } + out << YAML::EndMap; + if (overload_config.parameters.size() > 1) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (size_t i = 1; i < overload_config.parameters.size(); ++i) { + out << YAML::BeginMap; + EmitTypeInfo(overload_config.parameters[i], out); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } else { + if (!overload_config.parameters.empty()) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& parameter : overload_config.parameters) { + out << YAML::BeginMap; + EmitTypeInfo(parameter, out); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } + out << YAML::Key << "return"; + out << YAML::Value << YAML::BeginMap; + EmitTypeInfo(overload_config.return_type, out); + out << YAML::EndMap; + + out << YAML::EndMap; +} + +void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { + const std::vector& function_configs = + env_config.GetFunctionConfigs(); + if (function_configs.empty()) { + return; + } + + // Sort function_configs by name to ensure deterministic output. + std::vector sorted_function_configs = + function_configs; + absl::c_sort(sorted_function_configs, + [](const Config::FunctionConfig& a, + const Config::FunctionConfig& b) { return a.name < b.name; }); + + out << YAML::Key << "functions"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::FunctionConfig& function_config : + sorted_function_configs) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << function_config.name; + if (!function_config.description.empty()) { + out << YAML::Key << "description"; + out << YAML::Value << YAML::DoubleQuoted << function_config.description; + } + if (!function_config.overload_configs.empty()) { + // Sort overloads for deterministic output. + std::vector sorted_overloads = + function_config.overload_configs; + absl::c_sort(sorted_overloads, + [](const Config::FunctionOverloadConfig& a, + const Config::FunctionOverloadConfig& b) { + for (size_t i = 0; i < a.parameters.size(); ++i) { + // Order like this: foo(a), foo(a, b) + if (i >= b.parameters.size()) { + return false; + } + if (CompareTypeInfo(a.parameters[i], b.parameters[i])) { + return true; + } + if (CompareTypeInfo(b.parameters[i], a.parameters[i])) { + return false; + } + } + return false; + }); + + out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; + for (const Config::FunctionOverloadConfig& overload_config : + sorted_overloads) { + EmitFunctionOverloadConfig(overload_config, out); + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} +} // namespace + +absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { + Config config; + CEL_ASSIGN_OR_RETURN(YAML::Node root, LoadYaml(yaml)); + CEL_RETURN_IF_ERROR(ParseName(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseStandardLibraryConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseVariableConfigs(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseFunctionConfigs(config, yaml, root)); + return config; +} + +void EnvConfigToYaml(const Config& env_config, std::ostream& os) { + YAML::Emitter out(os); + out.SetIndent(2); + out << YAML::BeginMap; + if (!env_config.GetName().empty()) { + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << env_config.GetName(); + } + EmitContainerConfig(env_config, out); + EmitExtensionConfigs(env_config, out); + EmitStandardLibraryConfig(env_config, out); + EmitVariableConfigs(env_config, out); + EmitFunctionConfigs(env_config, out); + out << YAML::EndMap; +} + +} // namespace cel diff --git a/env/env_yaml.h b/env/env_yaml.h new file mode 100644 index 000000000..c96b45933 --- /dev/null +++ b/env/env_yaml.h @@ -0,0 +1,39 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" + +namespace cel { + +// EnvConfigFromYaml creates an environment configuration from a YAML string. +// +// To ensure safety, only pass trusted YAML input. yaml-cpp has some fuzz +// coverage, but its security model is unclear. Additionally, callers should be +// aware that improper CEL configuration can lead to unsafe or unpredictably +// expensive expressions. +absl::StatusOr EnvConfigFromYaml(const std::string& yaml); + +// EnvConfigToYaml serializes an environment configuration as a YAML string. +void EnvConfigToYaml(const Config& env_config, std::ostream& os); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc new file mode 100644 index 000000000..6fe348db6 --- /dev/null +++ b/env/env_yaml_test.cc @@ -0,0 +1,1466 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_yaml.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::AllOf; +using ::testing::ElementsAreArray; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +TEST(EnvYamlTest, ParseContainerConfig) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + container: "test.container" + )yaml")); + + EXPECT_THAT(config.GetContainerConfig(), + Field(&Config::ContainerConfig::name, "test.container")); +} + +TEST(EnvYamlTest, ParseExtensionConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + extensions: + - name: "math" + version: latest + - name: "optional" + version: 2 + - name: "strings" + )yaml")); + + EXPECT_THAT(config.GetExtensionConfigs(), + UnorderedElementsAre( + AllOf(Field(&Config::ExtensionConfig::name, "math"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)), + AllOf(Field(&Config::ExtensionConfig::name, "optional"), + Field(&Config::ExtensionConfig::version, 2)), + AllOf(Field(&Config::ExtensionConfig::name, "strings"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)))); +} + +TEST(EnvYamlTest, DefaultExtensionConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + )yaml")); + + EXPECT_THAT(config.GetExtensionConfigs(), IsEmpty()); +} + +TEST(EnvYamlTest, ParseStdlibConfig_ExclusionStyle) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + stdlib: + disable: true + disable_macros: true + exclude_macros: + - map + - filter + exclude_functions: + - name: "_+_" + overloads: + - id: add_bytes + - id: add_list + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml")); + + const auto& stdlib_config = config.GetStandardLibraryConfig(); + EXPECT_TRUE(stdlib_config.disable); + EXPECT_TRUE(stdlib_config.disable_macros); + EXPECT_THAT(stdlib_config.excluded_macros, + UnorderedElementsAre("map", "filter")); + EXPECT_THAT(stdlib_config.included_macros, IsEmpty()); + EXPECT_THAT( + stdlib_config.excluded_functions, + UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("timestamp", "string_to_timestamp"))) + << " Actual stdlib config: " << stdlib_config; +} + +TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + stdlib: + include_macros: + - map + - filter + include_functions: + - name: "_+_" + overloads: + - id: add_bytes + - id: add_list + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml")); + + const auto& stdlib_config = config.GetStandardLibraryConfig(); + EXPECT_THAT(stdlib_config.included_macros, + UnorderedElementsAre("map", "filter")); + EXPECT_THAT( + stdlib_config.included_functions, + UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("timestamp", "string_to_timestamp"))) + << " Actual stdlib config: " << stdlib_config; +} + +TEST(EnvYamlTest, ParseVariableConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "msg" + type_name: "google.expr.proto3.test.TestAllTypes" + description: >- + msg represents all possible type permutation which + CEL understands from a proto perspective + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "msg"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "google.expr.proto3.test.TestAllTypes"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, IsEmpty()); + EXPECT_EQ(variable_config.description, + "msg represents all possible type permutation which CEL " + "understands from a proto perspective"); +} + +TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "dict" + type_name: "map" + params: + - type_name: "string" + - type_name: "A" + is_type_param: true + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "dict"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "map"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, SizeIs(2)); + EXPECT_EQ(type_info.params[0].name, "string"); + EXPECT_FALSE(type_info.params[0].is_type_param); + EXPECT_THAT(type_info.params[0].params, IsEmpty()); + EXPECT_EQ(type_info.params[1].name, "A"); + EXPECT_TRUE(type_info.params[1].is_type_param); + EXPECT_THAT(type_info.params[1].params, IsEmpty()); +} + +struct ParseConstantTestCase { + std::string type_name; + std::string value; + std::string expected_error; // Empty if no error. + Constant expected_constant; +}; + +class EnvYamlParseConstantTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { + const ParseConstantTestCase& param = GetParam(); + const std::string yaml = absl::StrFormat( + R"yaml( + variables: + - name: "const" + type_name: "%s" + value: %s + )yaml", + param.type_name, param.value); + absl::StatusOr status_or_config = EnvConfigFromYaml(yaml); + if (!param.expected_error.empty()) { + EXPECT_THAT(status_or_config, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + return; + } + ASSERT_OK_AND_ASSIGN(Config config, status_or_config); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "const"); + EXPECT_EQ(variable_config.type_info.name, param.type_name); + EXPECT_EQ(variable_config.value, param.expected_constant); +} + +std::vector GetParseConstantTestCases() { + return { + ParseConstantTestCase{ + .type_name = "null", + .value = "\"\"", + .expected_constant = Constant(nullptr), + }, + ParseConstantTestCase{ + .type_name = "null", + .value = "anything", + .expected_error = "Failed to parse null constant", + }, + ParseConstantTestCase{ + .type_name = "bool", + .value = "TRUE", + .expected_constant = Constant(true), + }, + ParseConstantTestCase{ + .type_name = "bool", + .value = "false", + .expected_constant = Constant(false), + }, + ParseConstantTestCase{ + .type_name = "bool", + .value = "yes", + .expected_error = "Failed to parse bool constant", + }, + ParseConstantTestCase{ + .type_name = "int", + .value = "42", + .expected_constant = Constant(42), + }, + ParseConstantTestCase{ + .type_name = "int", + .value = "41.999", + .expected_error = "Failed to parse int constant", + }, + ParseConstantTestCase{ + .type_name = "uint", + .value = "42", + .expected_constant = Constant(42LLU), + }, + ParseConstantTestCase{ + .type_name = "uint", + .value = "42u", + .expected_constant = Constant(42LLU), + }, + ParseConstantTestCase{ + .type_name = "uint", + .value = "-1", + .expected_error = "Failed to parse uint constant", + }, + ParseConstantTestCase{ + .type_name = "double", + .value = "42.42", + .expected_constant = Constant(42.42), + }, + ParseConstantTestCase{ + .type_name = "double", + .value = "abc", + .expected_error = "Failed to parse double constant", + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "abc", + .expected_constant = Constant(BytesConstant("abc")), + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "b\"\\xFF\\x00\\x01\"", + .expected_constant = + Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "!!binary /wAB", + .expected_constant = + Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "!!binary YWJj=", + .expected_error = "Node 'YWJj=' is not a valid Base64 encoded binary", + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "abc", + .expected_constant = Constant(BytesConstant("abc")), + }, + ParseConstantTestCase{ + .type_name = "string", + .value = "abc", + .expected_constant = Constant(StringConstant("abc")), + }, + ParseConstantTestCase{ + .type_name = "string", + .value = "\"\\\"abc\\\"\"", + .expected_constant = Constant(StringConstant("\"abc\"")), + }, + ParseConstantTestCase{ + .type_name = "duration", + .value = "1s", + .expected_constant = Constant(absl::Seconds(1)), + }, + ParseConstantTestCase{ + .type_name = "duration", + .value = "abc", + .expected_error = "Failed to parse duration constant", + }, + ParseConstantTestCase{ + .type_name = "timestamp", + .value = "2023-01-01T00:00:00Z", + .expected_constant = Constant(absl::FromUnixSeconds(1672531200)), + }, + ParseConstantTestCase{ + .type_name = "timestamp", + .value = "abc", + .expected_error = "Failed to parse timestamp constant", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlParseConstantTest, EnvYamlParseConstantTest, + ::testing::ValuesIn(GetParseConstantTestCases())); + +struct ParseFunctionTestCase { + std::string yaml; + Config::FunctionConfig expected_function_config; +}; + +class EnvYamlParseFunctionTest + : public testing::TestWithParam {}; + +void ExpectTypeInfoEqual(const Config::TypeInfo& actual, + const Config::TypeInfo& expected) { + EXPECT_EQ(actual.name, expected.name); + EXPECT_EQ(actual.is_type_param, expected.is_type_param); + ASSERT_THAT(actual.params, SizeIs(expected.params.size())); + for (size_t i = 0; i < expected.params.size(); ++i) { + ExpectTypeInfoEqual(actual.params[i], expected.params[i]); + } +} + +TEST_P(EnvYamlParseFunctionTest, EnvYamlParseFunction) { + const ParseFunctionTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.yaml)); + + ASSERT_THAT(config.GetFunctionConfigs(), SizeIs(1)); + const Config::FunctionConfig& function_config = + config.GetFunctionConfigs()[0]; + const Config::FunctionConfig& expected = param.expected_function_config; + + EXPECT_EQ(function_config.name, expected.name); + EXPECT_EQ(function_config.description, expected.description); + + ASSERT_THAT(function_config.overload_configs, + SizeIs(expected.overload_configs.size())); + + for (size_t i = 0; i < expected.overload_configs.size(); ++i) { + const auto& actual_overload = function_config.overload_configs[i]; + const auto& expected_overload = expected.overload_configs[i]; + + EXPECT_EQ(actual_overload.overload_id, expected_overload.overload_id); + EXPECT_THAT(actual_overload.examples, + ElementsAreArray(expected_overload.examples)); + EXPECT_EQ(actual_overload.is_member_function, + expected_overload.is_member_function); + + ASSERT_THAT(actual_overload.parameters, + SizeIs(expected_overload.parameters.size())); + for (size_t j = 0; j < expected_overload.parameters.size(); ++j) { + ExpectTypeInfoEqual(actual_overload.parameters[j], + expected_overload.parameters[j]); + } + + ExpectTypeInfoEqual(actual_overload.return_type, + expected_overload.return_type); + } +} + +std::vector GetParseFunctionTestCases() { + return { + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - id: "wrapper_string_isEmpty" + examples: + - "''.isEmpty() // true" + target: + type_name: "google.protobuf.StringValue" + return: + type_name: "bool" + - id: "list_isEmpty" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + target: + type_name: "list" + params: + - type_name: "T" + is_type_param: true + return: + type_name: "bool" + )yaml", + .expected_function_config = + { + .name = "isEmpty", + .description = "determines whether a list is empty,\nor a " + "string has no characters", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "wrapper_string_isEmpty", + .examples = {"''.isEmpty() // true"}, + .is_member_function = true, + .parameters = + {{.name = "google.protobuf.StringValue"}}, + .return_type = {.name = "bool"}, + }, + Config::FunctionOverloadConfig{ + .overload_id = "list_isEmpty", + .examples = {"[].isEmpty() // true", + "[1].isEmpty() // false"}, + .is_member_function = true, + .parameters = {{.name = "list", + .params = {{.name = "T", + .is_type_param = + true}}}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "contains" + overloads: + - id: "global_contains" + examples: + - "contains([1, 2, 3], 2) // true" + args: + - type_name: "list" + params: + - type_name: "T" + is_type_param: true + - type_name: "T" + is_type_param: true + return: + type_name: "bool" + )yaml", + .expected_function_config = + { + .name = "contains", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "global_contains", + .examples = {"contains([1, 2, 3], 2) // true"}, + .is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlParseFunctionTest, EnvYamlParseFunctionTest, + ::testing::ValuesIn(GetParseFunctionTestCases())); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +class EnvYamlParseTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlParseTest, EnvYamlSyntaxError) { + const ParseTestCase& param = GetParam(); + absl::StatusOr config = EnvConfigFromYaml(param.yaml); + EXPECT_THAT(config, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); +} + +INSTANTIATE_TEST_SUITE_P( + EnvYamlParseTest, EnvYamlParseTest, + ::testing::Values( + ParseTestCase{ + .yaml = R"yaml( + name: + - error: "error" + )yaml", + .expected_error = "3:19: Node 'name' is not a string\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + - error: "error" + )yaml", + .expected_error = "3:19: Node 'container' is not a string\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + -name: "optional" + - name: "other" + )yaml", + .expected_error = "5:21: end of map not found\n" + "| - name: \"other\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: "bar" + )yaml", + .expected_error = "2:27: Node 'extensions' is not a sequence\n" + "| extensions: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: + - something: "bar" + )yaml", + .expected_error = "4:19: Extension name is not a string\n" + "| - something: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: last + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: last\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: -15 + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: -15\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: 1 + - name: "math" + version: 2 + )yaml", + .expected_error = "5:19: Extension 'math' version 1 is already " + "included. Cannot also include version 2\n" + "| - name: \"math\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: "error" + )yaml", + .expected_error = "2:23: Standard library config ('stdlib') " + "is not a map\n" + "| stdlib: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable: "error" + )yaml", + .expected_error = "3:26: Node 'disable' is not a boolean\n" + "| disable: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable_macros: "error" + )yaml", + .expected_error = "3:33: Node 'disable_macros' is not a boolean\n" + "| disable_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: "error" + )yaml", + .expected_error = "3:33: Node 'exclude_macros' is not a sequence\n" + "| exclude_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: + - foo: "error" + )yaml", + .expected_error = "4:19: Entry in 'exclude_macros' " + "is not a string\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: "error" + )yaml", + .expected_error = "3:36: Node 'include_functions' " + "is not a sequence\n" + "| include_functions: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - "error" + )yaml", + .expected_error = "4:19: Entry in 'include_functions' " + "is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - foo: "error" + )yaml", + .expected_error = "4:19: Function name in not specified in " + "'include_functions'\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "5:30: Overloads in 'include_functions' entry " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - foo_string + )yaml", + .expected_error = "6:21: Overload in 'include_functions' entry " + "is not a map\n" + "| - foo_string\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - id: + - foo_int64 + )yaml", + .expected_error = "7:21: Overload id in 'include_functions' entry " + "is not a string\n" + "| - foo_int64\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: + - type_name: "opaque" + )yaml", + .expected_error = "4:19: Variable name is not a string\n" + "| - type_name: \"opaque\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: + - params: + )yaml", + .expected_error = "5:21: Node 'type_name' is not a string\n" + "| - params:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + params: + - type_name: "int" + - type_name: "A" + is_type_param: maybe + )yaml", + .expected_error = "8:38: Node 'is_type_param' is not a boolean\n" + "| is_type_param: maybe\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "uint" + value: -1 + )yaml", + .expected_error = "5:26: Failed to parse uint constant\n" + "| value: -1\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: many + )yaml", + .expected_error = "2:26: Node 'functions' is not a sequence\n" + "| functions: many\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: + - overloads: + )yaml", + .expected_error = "4:19: Function name is not a string\n" + "| - overloads:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "4:30: Function 'overloads' item " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: + - "error" + )yaml", + .expected_error = "6:25: Function overload id is not a string\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + - "error" + )yaml", + .expected_error = "7:25: Function overload target is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + type_name: "Foo" + params: + - type_name: + - is_type_param: true + )yaml", + .expected_error = "10:31: Node 'type_name' is not a string\n" + "| " + "- is_type_param: true\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + args: "a bunch" + )yaml", + .expected_error = "6:29: Function overload args is not a sequence\n" + "| args: \"a bunch\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + return: "to sender" + )yaml", + .expected_error = "6:31: Function overload return type" + " is not a map\n" + "| return: \"to sender\"\n" + "| ^", + })); + +std::string Unindent(std::string_view yaml) { + std::vector lines = absl::StrSplit(yaml, '\n'); + int indent = -1; + std::vector unindented_lines; + for (auto& line : lines) { + std::size_t pos = line.find_first_not_of(" \t"); + if (pos == std::string::npos) { + // Skip blank lines. + continue; + } + if (indent == -1) { + indent = pos; + } + if (pos >= indent) { + unindented_lines.push_back(line.substr(indent)); + } else { + unindented_lines.push_back(line); + } + } + return absl::StrJoin(unindented_lines, "\n"); +} + +struct ExportTestCase { + absl::StatusOr config; + std::string expected_yaml; +}; + +class EnvYamlExportTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlExportTest, EnvYamlExport) { + const ExportTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config config, param.config); + std::stringstream ss; + EnvConfigToYaml(config, ss); + std::string yaml_output = Unindent(ss.str()); + std::string expected_yaml = Unindent(param.expected_yaml); + EXPECT_EQ(yaml_output, expected_yaml); +} + +std::vector GetExportTestCases() { + return { + ExportTestCase{ + .config = + []() { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("math")); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("optional", 2)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("bindings")); + return config; + }(), + .expected_yaml = R"yaml( + extensions: + - name: "bindings" + - name: "math" + - name: "optional" + version: 2 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable_macros = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable_macros: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + include_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("_+_", "add_bytes"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("_+_", "add_bytes"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + include_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "null"}, + .value = Constant(nullptr)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "null" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "bool"}, + .value = Constant(true)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "bool" + value: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "int"}, + .value = Constant(42)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "int" + value: 42 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "uint"}, + .value = Constant(777ULL)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "uint" + value: 777 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "double"}, + .value = Constant(0.75)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "double" + value: 0.75 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "bytes"}, + .value = Constant( + BytesConstant(absl::string_view("\xff\x00\x01", 3)))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "bytes" + value: b"\xff\x00\x01" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + Constant c; + c.set_string_value("'single' \"double\""); + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "string"}, + .value = Constant(StringConstant("'single' \"double\""))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "string" + value: "'single' \"double\"" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "duration"}, + .value = Constant(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "duration" + value: 1h2m3s + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "timestamp"}, + .value = Constant(absl::FromUnixSeconds(1767323045))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = + "google.expr.proto3.test.TestAllTypes"}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "google.expr.proto3.test.TestAllTypes" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = { + .name = "A", + .params = {{.name = "int"}, + {.name = "B", .is_type_param = true}}}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "A" + params: + - type_name: "int" + - type_name: "B" + is_type_param: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig({.name = "foo"})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "foo_overload_id", + .is_member_function = true, + .parameters = {{.name = "timestamp"}, + {.name = "A", .params = {{.name = "B"}}}}, + .return_type = {.name = "int"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "int" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "foo_overload_a", + .parameters = {{.name = "timestamp"}}, + .return_type = {.name = "list", + .params = {{.name = "int"}}}}, + {.overload_id = "foo_overload_b", + .parameters = {{.name = "double"}, + {.name = "A", .params = {{.name = "B"}}}}, + .return_type = {.name = "string"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_b" + args: + - type_name: "double" + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "string" + - id: "foo_overload_a" + args: + - type_name: "timestamp" + return: + type_name: "list" + params: + - type_name: "int" + )yaml", + }, + }; +}; + +INSTANTIATE_TEST_SUITE_P(EnvYamlExportTest, EnvYamlExportTest, + ::testing::ValuesIn(GetExportTestCases())); + +class EnvYamlRoundTripTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlRoundTripTest, EnvYamlRoundTrip) { + const std::string& yaml = Unindent(GetParam()); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); + + std::stringstream ss; + EnvConfigToYaml(config, ss); + EXPECT_EQ(ss.str(), yaml); +} + +std::vector GetRoundTripTestCases() { + return { + R"yaml( + stdlib: + disable: true + disable_macros: true + )yaml", + R"yaml( + name: "test.env" + container: "common.proto.prefix" + extensions: + - name: "math" + version: 0 + - name: "optional" + version: 2 + stdlib: + include_macros: + - "filter" + - "map" + include_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + extensions: + - name: "bindings" + - name: "math" + stdlib: + exclude_macros: + - "filter" + - "map" + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + variables: + - name: "a" + type_name: "null" + - name: "b" + type_name: "bool" + value: true + - name: "c" + type_name: "int" + value: 42 + - name: "d" + type_name: "uint" + value: 777 + - name: "e" + type_name: "double" + value: 0.75 + - name: "f" + type_name: "bytes" + value: b"\xff\x00\x01" + - name: "g" + type_name: "string" + value: "plain 'single' \"double\"" + - name: "h" + type_name: "duration" + value: 1h2m3s + - name: "i" + type_name: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + R"yaml( + functions: + - name: "bar" + - name: "foo" + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "int" + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + args: + - type_name: "timestamp" + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "list" + params: + - type_name: "int" + )yaml", + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlRoundTripTest, EnvYamlRoundTripTest, + ::testing::ValuesIn(GetRoundTripTestCases())); + +} // namespace +} // namespace cel diff --git a/env/internal/BUILD b/env/internal/BUILD new file mode 100644 index 000000000..ec4a0b15c --- /dev/null +++ b/env/internal/BUILD @@ -0,0 +1,87 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "ext_registry", + srcs = ["ext_registry.cc"], + hdrs = ["ext_registry.h"], + deps = [ + "//compiler", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "runtime_ext_registry", + srcs = ["runtime_ext_registry.cc"], + hdrs = ["runtime_ext_registry.h"], + deps = [ + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "ext_registry_test", + srcs = ["ext_registry_test.cc"], + deps = [ + ":ext_registry", + "//checker:type_checker_builder", + "//compiler", + "//internal:testing", + "//parser:parser_interface", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "runtime_ext_registry_test", + srcs = ["runtime_ext_registry_test.cc"], + deps = [ + ":runtime_ext_registry", + "//common:ast", + "//common:source", + "//common:value", + "//common:value_testing", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:options", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:function", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/internal/ext_registry.cc b/env/internal/ext_registry.cc new file mode 100644 index 000000000..b32239ac3 --- /dev/null +++ b/env/internal/ext_registry.cc @@ -0,0 +1,63 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/ext_registry.h" + +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +void ExtensionRegistry::RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) { + library_registry_.push_back( + LibraryRegistration(name, alias, version, std::move(library_factory))); +} + +absl::StatusOr ExtensionRegistry::GetCompilerLibrary( + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: ", name)); + } + version = max_version; + } + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.GetLibrary(); + } + } + + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: ", name, "#", version)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/ext_registry.h b/env/internal/ext_registry.h new file mode 100644 index 000000000..ab5b67a24 --- /dev/null +++ b/env/internal/ext_registry.h @@ -0,0 +1,74 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +// A registry for CEL compiler extension libraries. +// +// Used to register and retrieve CompilerLibraries by name (or alias) and +// version. +class ExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory); + + absl::StatusOr GetCompilerLibrary(absl::string_view name, + int version) const; + + private: + class LibraryRegistration final { + public: + LibraryRegistration( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) + : name_(name), + alias_(!alias.empty() ? alias : name), + version_(version), + factory_(std::move(library_factory)) {} + + CompilerLibrary GetLibrary() const { return factory_(); } + + private: + std::string name_; + std::string alias_; + int version_; + absl::AnyInvocable factory_; + + friend class ExtensionRegistry; + }; + + std::vector library_registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ diff --git a/env/internal/ext_registry_test.cc b/env/internal/ext_registry_test.cc new file mode 100644 index 000000000..9e345c781 --- /dev/null +++ b/env/internal/ext_registry_test.cc @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "internal/testing.h" +#include "parser/parser_interface.h" + +namespace cel::env_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::testing::Field; +using ::testing::HasSubstr; + +TEST(ExtensionRegistryTest, GetCompilerLibrary) { + ExtensionRegistry registry; + registry.RegisterCompilerLibrary("foo1", "f", 1, []() { + return CompilerLibrary("foo1_1", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo1", "f", 2, []() { + return CompilerLibrary("foo1_2", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo2", "", 1, []() { + return CompilerLibrary("foo2_1", nullptr, nullptr); + }); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 2), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 3), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo1#3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", 1), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", ExtensionRegistry::kLatest), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo3"))); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/internal/runtime_ext_registry.cc b/env/internal/runtime_ext_registry.cc new file mode 100644 index 000000000..dc78a38e3 --- /dev/null +++ b/env/internal/runtime_ext_registry.cc @@ -0,0 +1,64 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/runtime_ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +void RuntimeExtensionRegistry::AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) { + registry_.push_back(Registration(name, alias, version, + std::move(function_registration_callback))); +} + +absl::Status RuntimeExtensionRegistry::RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options, + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); + } + version = max_version; + } + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.RegisterExtensionFunctions(runtime_builder, + runtime_options); + } + } + + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/runtime_ext_registry.h b/env/internal/runtime_ext_registry.h new file mode 100644 index 000000000..67838519f --- /dev/null +++ b/env/internal/runtime_ext_registry.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +using FunctionRegistrationCallback = absl::AnyInvocable; + +// A registry for CEL runtime extension functions. +// +// Used to register runtime functions for extensions by name (or alias) and +// version. +class RuntimeExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback); + + absl::Status RegisterExtensionFunctions(RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options, + absl::string_view name, + int version) const; + + private: + class Registration final { + public: + Registration(absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) + : name_(name), + alias_(!alias.empty() ? alias : name), + version_(version), + function_registration_callback_( + std::move(function_registration_callback)) {} + + absl::Status RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) const { + return function_registration_callback_(runtime_builder, runtime_options); + } + + private: + std::string name_; + std::string alias_; + int version_; + FunctionRegistrationCallback function_registration_callback_; + + friend class RuntimeExtensionRegistry; + }; + + std::vector registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ diff --git a/env/internal/runtime_ext_registry_test.cc b/env/internal/runtime_ext_registry_test.cc new file mode 100644 index 000000000..fb8fcdd0a --- /dev/null +++ b/env/internal/runtime_ext_registry_test.cc @@ -0,0 +1,124 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/runtime_ext_registry.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/source.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/parser_interface.h" +#include "runtime/activation.h" +#include "runtime/function.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace cel::env_internal { +namespace { + +using ::cel::test::StringValueIs; + +Value Hello1(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, old " + input.ToString() + "!", + context.arena()); +} + +Value Hello2(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, new " + input.ToString() + "!", + context.arena()); +} + +class RuntimeExtensionRegistryTest : public testing::Test { + protected: + absl::StatusOr Run(std::string_view extension_name, int version, + std::string_view expr) { + RuntimeExtensionRegistry registry; + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 1, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("hello", &Hello1, + runtime_builder.function_registry()); + }); + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 2, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterMemberOverload("hello", &Hello2, + runtime_builder.function_registry()); + }); + + ParserOptions parser_options; + CEL_ASSIGN_OR_RETURN(std::unique_ptr parser, + NewParserBuilder(parser_options)->Build()); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr source, NewSource(expr, "")); + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, parser->Parse(*source)); + + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + cel::RuntimeOptions runtime_options; + CEL_ASSIGN_OR_RETURN( + cel::RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool, runtime_options)); + + CEL_RETURN_IF_ERROR(registry.RegisterExtensionFunctions( + runtime_builder, runtime_options, extension_name, version)); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + Activation activation; + return program->Evaluate(&arena_, activation); + } + + protected: + google::protobuf::Arena arena_; +}; + +TEST_F(RuntimeExtensionRegistryTest, RegisterFunctions_SpecificVersion) { + ASSERT_OK_AND_ASSIGN(Value value, + Run("hello_extension", 1, "hello('world')")); + EXPECT_THAT(value, StringValueIs("Hello, old world!")); +} + +TEST_F(RuntimeExtensionRegistryTest, RegisterFunctions_LatestVersion) { + ASSERT_OK_AND_ASSIGN( + Value value, Run("hello_extension_alias", + RuntimeExtensionRegistry::kLatest, "'world'.hello()")); + EXPECT_THAT(value, StringValueIs("Hello, new world!")); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/runtime_std_extensions.cc b/env/runtime_std_extensions.cc new file mode 100644 index 000000000..cff001979 --- /dev/null +++ b/env/runtime_std_extensions.cc @@ -0,0 +1,129 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "checker/optional.h" +#include "env/env_runtime.h" +#include "env/internal/runtime_ext_registry.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext.h" +#include "extensions/math_ext_decls.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" +#include "runtime/optional_types.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { + +void RegisterStandardExtensions(EnvRuntime& env_runtime) { + env_internal::RuntimeExtensionRegistry& registry = + env_runtime.GetRuntimeExtensionRegistry(); + registry.AddFunctionRegistration( + "cel.lib.ext.bindings", "bindings", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.encoders", "encoders", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterEncodersFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.lists", "lists", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterListsFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.math", "math", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= cel::kOptionalExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.optional", "optional", version, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::EnableOptionalTypes(runtime_builder); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.protos", "protos", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.sets", "sets", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterSetsFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.strings", "strings", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterStringsFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.regex", "regex", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // Not registering runtime functions yet. Include manually if needed. + return absl::OkStatus(); + }); +} + +} // namespace cel diff --git a/env/runtime_std_extensions.h b/env/runtime_std_extensions.h new file mode 100644 index 000000000..d7f714226 --- /dev/null +++ b/env/runtime_std_extensions.h @@ -0,0 +1,46 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ + +#include "env/env_runtime.h" + +namespace cel { + +// Registers the standard CEL extension functions with the given environment +// runtime. This makes them available, but does not enable them. See Env::Config +// for how to enable extensions. +// +// Included in the standard runtime environment: +// +// - cel.lib.ext.bindings (alias: "bindings") +// - cel.lib.ext.encoders (alias: "encoders") +// - cel.lib.ext.lists (alias: "lists") +// - cel.lib.ext.math (alias: "math") +// - optional +// - cel.lib.ext.protos (alias: "protos") +// - cel.lib.ext.sets (alias: "sets") +// - cel.lib.ext.strings (alias: "strings") +// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") +// +// NOTE: Not included in the standard runtime environment yet - include manually +// if needed: +// - cel.lib.ext.regex (alias: "regex") +// +void RegisterStandardExtensions(EnvRuntime& env_runtime); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ diff --git a/env/runtime_std_extensions_test.cc b/env/runtime_std_extensions_test.cc new file mode 100644 index 000000000..0b497d9ae --- /dev/null +++ b/env/runtime_std_extensions_test.cc @@ -0,0 +1,226 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "checker/optional.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/strings.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsEmpty; + +struct TestCase { + std::string extension_name; + std::vector extension_versions = {0}; + int latest_extension_version = 0; + std::string expr; +}; + +using RuntimeStdExtensionTest = testing::TestWithParam; + +TEST_P(RuntimeStdExtensionTest, RegisterStandardExtensions) { + const TestCase& param = GetParam(); + Env env; + env.SetDescriptorPool(cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env); + + Config compiler_config; + // For the compilation step, assume latest version of the extension to ensure + // a successful compilation. Later, we will test the runtime with different + // extension versions. + ASSERT_THAT(compiler_config.AddExtensionConfig( + param.extension_name, Config::ExtensionConfig::kLatest), + IsOk()); + env.SetConfig(compiler_config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + for (int version = 0; version <= param.latest_extension_version; ++version) { + // Request a specific version of the extension to be configured in the + // runtime. + Config runtime_config; + ASSERT_THAT( + runtime_config.AddExtensionConfig(param.extension_name, version), + IsOk()); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool( + cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(runtime_config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + absl::StatusOr> program_or = + runtime->CreateProgram(std::make_unique(*ast)); + + // If the function is not supported in this extension version, check that + // the program creation returned an error. + if (!absl::c_contains(param.extension_versions, version)) { + EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) + << " expr: " << param.expr << " version: " << version; + continue; + } + + ASSERT_THAT(program_or, IsOk()) + << " expr: " << param.expr << " version: " << version; + std::unique_ptr program = std::move(*program_or); + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()) + << " expr: " << param.expr << " version: " << version; + } +} + +std::vector GetRuntimeStdExtensionTestCases() { + return { + TestCase{ + // The "bindings" extension does include have any runtime functions. + .extension_name = "bindings", + .expr = "cel.bind(t, 42, t + 1) == 43", + }, + TestCase{ + .extension_name = "encoders", + .expr = "base64.encode(b'hello') == 'aGVsbG8='", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].slice(0, 1) == [3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[[1, 2], 3].flatten() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].sort() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.least([1, -2, 3]) == -2", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.floor(42.9) == 42.0", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.sqrt(4) == 2.0", + }, + TestCase{ + .extension_name = "optional", + .extension_versions = {0, 1, 2}, + .latest_extension_version = kOptionalExtensionLatestVersion, + .expr = "optional.of(1).hasValue()", + }, + TestCase{ + // No runtime functions. + .extension_name = "protos", + .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " + "cel.expr.conformance.proto2.nested_ext)", + }, + TestCase{ + .extension_name = "sets", + .expr = "sets.contains([1], [1])", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {0, 1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'Hello, who!'.replace('who', 'World') == 'Hello, World!'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "strings.quote('hello') == '\"hello\"'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "['hello', 'world'].join(', ') == 'hello, world'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'stressed'.reverse() == 'desserts'", + }, + TestCase{ + // No runtime functions. + .extension_name = "cel.lib.ext.comprev2", + .expr = "[1, 2, 3].map(i, i * 2) == [2, 4, 6]", + }, + TestCase{ + .extension_name = "cel.lib.ext.regex", + // Runtime functions are not registered. Add manually if needed. + .extension_versions = {}, + .expr = "regex.replace('abc', '$', '_end') == 'abc_end'", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P( + RuntimeStdExtensionTest, RuntimeStdExtensionTest, + ::testing::ValuesIn(GetRuntimeStdExtensionTestCases())); + +} // namespace +} // namespace cel diff --git a/extensions/BUILD b/extensions/BUILD index 1e6e9204a..fe97af46a 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -75,6 +75,7 @@ cc_library( srcs = ["math_ext.cc"], hdrs = ["math_ext.h"], deps = [ + ":math_ext_decls", "//common:casting", "//common:value", "//eval/public:cel_function_registry", diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc index 7b3655de3..4d133d90c 100644 --- a/extensions/math_ext.cc +++ b/extensions/math_ext.cc @@ -308,7 +308,8 @@ Value BitShiftRightUint(uint64_t lhs, int64_t rhs) { } // namespace absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { + const RuntimeOptions& options, + int version) { CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( kMathMin, Identity, registry))); @@ -360,6 +361,9 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, UnaryFunctionAdapter, ListValue>::RegisterGlobalOverload(kMathMax, MaxList, registry))); + if (version == 0) { + return absl::OkStatus(); + } CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( @@ -370,15 +374,6 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "math.round", RoundDouble, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtDouble, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtInt, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtUint, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "math.trunc", TruncDouble, registry))); @@ -453,6 +448,20 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, (BinaryFunctionAdapter::RegisterGlobalOverload( "math.bitShiftRight", BitShiftRightUint, registry))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtUint, registry))); + return absl::OkStatus(); } diff --git a/extensions/math_ext.h b/extensions/math_ext.h index 63d9e964b..fe000e476 100644 --- a/extensions/math_ext.h +++ b/extensions/math_ext.h @@ -18,6 +18,7 @@ #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" +#include "extensions/math_ext_decls.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" @@ -25,8 +26,9 @@ namespace cel::extensions { // Register extension functions for supporting mathematical operations above // and beyond the set defined in the CEL standard environment. -absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterMathExtensionFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + int version = kMathExtensionLatestVersion); absl::Status RegisterMathExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, diff --git a/extensions/strings.cc b/extensions/strings.cc index 652c72572..ed6f27319 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -306,17 +306,8 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) { } // namespace absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter, ListValue>::CreateDescriptor( - "join", /*receiver_style=*/true), - UnaryFunctionAdapter, ListValue>::WrapFunction( - Join1))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter, ListValue, StringValue>:: - CreateDescriptor("join", /*receiver_style=*/true), - BinaryFunctionAdapter, ListValue, - StringValue>::WrapFunction(Join2))); + const RuntimeOptions& options, + int version) { CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: CreateDescriptor("split", /*receiver_style=*/true), @@ -350,7 +341,6 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, int64_t>::CreateDescriptor("replace", /*receiver_style=*/true), QuaternaryFunctionAdapter, StringValue, StringValue, StringValue, int64_t>::WrapFunction(Replace2))); - CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterMemberOverload("charAt", &CharAt, @@ -388,9 +378,32 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "trim", &Trim, registry))); + if (version == 0) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "strings.quote", &Quote, registry))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, ListValue>::CreateDescriptor( + "join", /*receiver_style=*/true), + UnaryFunctionAdapter, ListValue>::WrapFunction( + Join1))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, ListValue, StringValue>:: + CreateDescriptor("join", /*receiver_style=*/true), + BinaryFunctionAdapter, ListValue, + StringValue>::WrapFunction(Join2))); + if (version == 2) { + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "reverse", &Reverse, registry))); diff --git a/extensions/strings.h b/extensions/strings.h index 5dab33c5d..3cbc9f19f 100644 --- a/extensions/strings.h +++ b/extensions/strings.h @@ -28,8 +28,9 @@ namespace cel::extensions { constexpr int kStringsExtensionLatestVersion = 4; // Register extension functions for strings. -absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + int version = kStringsExtensionLatestVersion); absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, diff --git a/runtime/standard/type_conversion_functions.cc b/runtime/standard/type_conversion_functions.cc index 50b6e28ea..76e95751b 100644 --- a/runtime/standard/type_conversion_functions.cc +++ b/runtime/standard/type_conversion_functions.cc @@ -69,7 +69,7 @@ Value FormatDouble(double v, const Function::InvokeContext& context) { return cel::ErrorValue(absl::InvalidArgumentError(absl::StrCat( "double format error: ", std::make_error_code(result.ec).message()))); } - absl::string_view out(buf, result.ptr); + absl::string_view out(buf, result.ptr - buf); return StringValue::From(out, arena); #endif }