diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index e82b0ce13..481072ee8 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -524,12 +524,10 @@ cc_library( "//eval/eval:evaluator_core", "//eval/eval:regex_match_step", "//internal:casts", - "//internal:re2_options", "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_googlesource_code_re2//:re2", diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc index b94cae383..dcc7edd2b 100644 --- a/eval/compiler/regex_precompilation_optimization.cc +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -24,7 +24,6 @@ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/builtins.h" @@ -39,7 +38,6 @@ #include "eval/eval/evaluator_core.h" #include "eval/eval/regex_match_step.h" #include "internal/casts.h" -#include "internal/re2_options.h" #include "internal/status_macros.h" #include "re2/re2.h" @@ -106,9 +104,14 @@ class RegexProgramBuilder final { } programs_.erase(existing); } - auto program = - std::make_shared(pattern, cel::internal::MakeRE2Options()); - CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(*program, max_program_size_)); + auto program = std::make_shared(pattern); + if (max_program_size_ > 0 && program->ProgramSize() > max_program_size_) { + return absl::InvalidArgumentError("exceeded RE2 max program size"); + } + if (!program->ok()) { + return absl::InvalidArgumentError( + "invalid_argument unsupported RE2 pattern for matches"); + } programs_.insert({std::move(pattern), program}); return program; } diff --git a/eval/eval/regex_match_step_test.cc b/eval/eval/regex_match_step_test.cc index 8d54a0188..96d0e7a4a 100644 --- a/eval/eval/regex_match_step_test.cc +++ b/eval/eval/regex_match_step_test.cc @@ -76,7 +76,7 @@ TEST(RegexMatchStep, PrecompiledInvalidRegex) { ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options)); EXPECT_THAT(expr_builder->CreateExpression(&checked_expr), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("invalid regular expression"))); + HasSubstr("invalid_argument"))); } TEST(RegexMatchStep, PrecompiledInvalidProgramTooLarge) { @@ -94,7 +94,7 @@ TEST(RegexMatchStep, PrecompiledInvalidProgramTooLarge) { ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options)); EXPECT_THAT(expr_builder->CreateExpression(&checked_expr), StatusIs(absl::StatusCode::kInvalidArgument, - Eq("regular expressions exceeds max allowed size"))); + Eq("exceeded RE2 max program size"))); } } // namespace diff --git a/extensions/BUILD b/extensions/BUILD index 1e6e9204a..4753215cb 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -192,7 +192,6 @@ cc_library( "//common:value", "//eval/public:cel_function_registry", "//eval/public:cel_options", - "//internal:re2_options", "//internal:status_macros", "//runtime:function_adapter", "//runtime:function_registry", @@ -200,7 +199,6 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -765,7 +763,6 @@ cc_library( "//eval/public:cel_function_registry", "//eval/public:cel_options", "//internal:casts", - "//internal:re2_options", "//internal:status_macros", "//runtime:function_adapter", "//runtime:function_registry", @@ -774,7 +771,6 @@ cc_library( "//runtime/internal:runtime_impl", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", diff --git a/extensions/regex_ext.cc b/extensions/regex_ext.cc index c3d7cae53..99d1a9c4b 100644 --- a/extensions/regex_ext.cc +++ b/extensions/regex_ext.cc @@ -21,7 +21,6 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" -#include "absl/functional/bind_front.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -35,7 +34,6 @@ #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "internal/casts.h" -#include "internal/re2_options.h" #include "internal/status_macros.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" @@ -52,8 +50,7 @@ namespace { using ::cel::checker_internal::BuiltinsArena; -Value Extract(int regex_max_program_size, const StringValue& target, - const StringValue& regex, +Value Extract(const StringValue& target, const StringValue& regex, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -61,9 +58,11 @@ Value Extract(int regex_max_program_size, const StringValue& target, std::string regex_scratch; absl::string_view target_view = target.ToStringView(&target_scratch); absl::string_view regex_view = regex.ToStringView(®ex_scratch); - RE2 re2(regex_view, cel::internal::MakeRE2Options()); - CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) - .With(ErrorValueReturn()); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("given regex is invalid: %s", re2.error()))); + } const int group_count = re2.NumberOfCapturingGroups(); if (group_count > 1) { return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( @@ -84,8 +83,7 @@ Value Extract(int regex_max_program_size, const StringValue& target, return OptionalValue::None(); } -Value ExtractAll(int regex_max_program_size, const StringValue& target, - const StringValue& regex, +Value ExtractAll(const StringValue& target, const StringValue& regex, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -93,9 +91,11 @@ Value ExtractAll(int regex_max_program_size, const StringValue& target, std::string regex_scratch; absl::string_view target_view = target.ToStringView(&target_scratch); absl::string_view regex_view = regex.ToStringView(®ex_scratch); - RE2 re2(regex_view, cel::internal::MakeRE2Options()); - CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) - .With(ErrorValueReturn()); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("given regex is invalid: %s", re2.error()))); + } const int group_count = re2.NumberOfCapturingGroups(); if (group_count > 1) { return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( @@ -142,8 +142,8 @@ Value ExtractAll(int regex_max_program_size, const StringValue& target, return std::move(*builder).Build(); } -Value ReplaceAll(int regex_max_program_size, const StringValue& target, - const StringValue& regex, const StringValue& replacement, +Value ReplaceAll(const StringValue& target, const StringValue& regex, + const StringValue& replacement, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -154,9 +154,12 @@ Value ReplaceAll(int regex_max_program_size, const StringValue& target, absl::string_view regex_view = regex.ToStringView(®ex_scratch); absl::string_view replacement_view = replacement.ToStringView(&replacement_scratch); - RE2 re2(regex_view, cel::internal::MakeRE2Options()); - CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) - .With(ErrorValueReturn()); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("given regex is invalid: %s", re2.error()))); + } + std::string error_string; if (!re2.CheckRewriteString(replacement_view, &error_string)) { return ErrorValue(absl::InvalidArgumentError( @@ -169,9 +172,8 @@ Value ReplaceAll(int regex_max_program_size, const StringValue& target, return StringValue::From(std::move(output), arena); } -Value ReplaceN(int regex_max_program_size, const StringValue& target, - const StringValue& regex, const StringValue& replacement, - int64_t count, +Value ReplaceN(const StringValue& target, const StringValue& regex, + const StringValue& replacement, int64_t count, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -179,8 +181,8 @@ Value ReplaceN(int regex_max_program_size, const StringValue& target, return target; } if (count < 0) { - return ReplaceAll(regex_max_program_size, target, regex, replacement, - descriptor_pool, message_factory, arena); + return ReplaceAll(target, regex, replacement, descriptor_pool, + message_factory, arena); } std::string target_scratch; @@ -190,9 +192,11 @@ Value ReplaceN(int regex_max_program_size, const StringValue& target, absl::string_view regex_view = regex.ToStringView(®ex_scratch); absl::string_view replacement_view = replacement.ToStringView(&replacement_scratch); - RE2 re2(regex_view, cel::internal::MakeRE2Options()); - CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) - .With(ErrorValueReturn()); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("given regex is invalid: %s", re2.error()))); + } std::string error_string; if (!re2.CheckRewriteString(replacement_view, &error_string)) { return ErrorValue(absl::InvalidArgumentError( @@ -229,35 +233,25 @@ Value ReplaceN(int regex_max_program_size, const StringValue& target, } absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, - bool disable_extract, - int regex_max_program_size) { + bool disable_extract) { if (!disable_extract) { CEL_RETURN_IF_ERROR(( BinaryFunctionAdapter, StringValue, StringValue>:: - RegisterGlobalOverload( - "regex.extract", - absl::bind_front(&Extract, regex_max_program_size), registry))); + RegisterGlobalOverload("regex.extract", &Extract, registry))); } CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter, StringValue, StringValue>:: - RegisterGlobalOverload( - "regex.extractAll", - absl::bind_front(&ExtractAll, regex_max_program_size), - registry))); + RegisterGlobalOverload("regex.extractAll", &ExtractAll, registry))); CEL_RETURN_IF_ERROR( (TernaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, - StringValue>::RegisterGlobalOverload("regex.replace", - absl::bind_front( - &ReplaceAll, - regex_max_program_size), + StringValue>::RegisterGlobalOverload("regex.replace", &ReplaceAll, registry))); CEL_RETURN_IF_ERROR( - (QuaternaryFunctionAdapter, StringValue, - StringValue, StringValue, int64_t>:: - RegisterGlobalOverload( - "regex.replace", - absl::bind_front(&ReplaceN, regex_max_program_size), registry))); + (QuaternaryFunctionAdapter< + absl::StatusOr, StringValue, StringValue, StringValue, + int64_t>::RegisterGlobalOverload("regex.replace", &ReplaceN, + registry))); return absl::OkStatus(); } @@ -314,10 +308,9 @@ absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder) { "regex extensions requires the optional types to be enabled"); } if (runtime.expr_builder().options().enable_regex) { - CEL_RETURN_IF_ERROR(RegisterRegexExtensionFunctions( - builder.function_registry(), - /*disable_extract=*/false, - runtime.expr_builder().options().regex_max_program_size)); + CEL_RETURN_IF_ERROR( + RegisterRegexExtensionFunctions(builder.function_registry(), + /*disable_extract=*/false)); } return absl::OkStatus(); } @@ -327,8 +320,7 @@ absl::Status RegisterRegexExtensionFunctions( const google::api::expr::runtime::InterpreterOptions& options) { if (options.enable_regex) { return RegisterRegexExtensionFunctions(registry->InternalGetRegistry(), - /*disable_extract=*/true, - options.regex_max_program_size); + /*disable_extract=*/true); } return absl::OkStatus(); } diff --git a/extensions/regex_ext_test.cc b/extensions/regex_ext_test.cc index e69f7cce1..b2e452ff2 100644 --- a/extensions/regex_ext_test.cc +++ b/extensions/regex_ext_test.cc @@ -378,15 +378,15 @@ std::vector regexTestCases() { // Runtime Errors {EvaluationType::kRuntimeError, R"(regex.extract('foo', 'fo(o+)(abc'))", - "invalid regular expression: missing ): fo(o+)(abc"}, + "given regex is invalid: missing ): fo(o+)(abc"}, {EvaluationType::kRuntimeError, R"(regex.extractAll('foo bar', '[a-z'))", - "invalid regular expression: missing ]: [a-z"}, + "given regex is invalid: missing ]: [a-z"}, {EvaluationType::kRuntimeError, R"(regex.replace('foo bar', '[a-z', 'a'))", - "invalid regular expression: missing ]: [a-z"}, + "given regex is invalid: missing ]: [a-z"}, {EvaluationType::kRuntimeError, R"(regex.replace('foo bar', '[a-z', 'a', 1))", - "invalid regular expression: missing ]: [a-z"}, + "given regex is invalid: missing ]: [a-z"}, {EvaluationType::kRuntimeError, R"(regex.replace('id=123', r'id=(?P\d+)', r'value: \values'))", R"(invalid replacement string: Rewrite schema error: '\' must be followed by a digit or '\'.)"}, diff --git a/extensions/regex_functions.cc b/extensions/regex_functions.cc index 005987ae4..3b3c80a00 100644 --- a/extensions/regex_functions.cc +++ b/extensions/regex_functions.cc @@ -21,7 +21,6 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" -#include "absl/functional/bind_front.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -32,7 +31,6 @@ #include "common/value.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" -#include "internal/re2_options.h" #include "internal/status_macros.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" @@ -51,8 +49,8 @@ using ::google::api::expr::runtime::InterpreterOptions; // Extract matched group values from the given target string and rewrite the // string -Value ExtractString(int regex_max_program_size, const StringValue& target, - const StringValue& regex, const StringValue& rewrite, +Value ExtractString(const StringValue& target, const StringValue& regex, + const StringValue& rewrite, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -63,9 +61,10 @@ Value ExtractString(int regex_max_program_size, const StringValue& target, absl::string_view target_view = target.ToStringView(&target_scratch); absl::string_view rewrite_view = rewrite.ToStringView(&rewrite_scratch); - RE2 re2(regex_view, cel::internal::MakeRE2Options()); - CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) - .With(ErrorValueReturn()); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError("Given Regex is Invalid")); + } std::string output; bool result = RE2::Extract(target_view, re2, rewrite_view, &output); if (!result) { @@ -77,8 +76,7 @@ Value ExtractString(int regex_max_program_size, const StringValue& target, // Captures the first unnamed/named group value // NOTE: For capturing all the groups, use CaptureStringN instead -Value CaptureString(int regex_max_program_size, const StringValue& target, - const StringValue& regex, +Value CaptureString(const StringValue& target, const StringValue& regex, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -86,9 +84,10 @@ Value CaptureString(int regex_max_program_size, const StringValue& target, std::string target_scratch; absl::string_view regex_view = regex.ToStringView(®ex_scratch); absl::string_view target_view = target.ToStringView(&target_scratch); - RE2 re2(regex_view, cel::internal::MakeRE2Options()); - CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) - .With(ErrorValueReturn()); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError("Given Regex is Invalid")); + } std::string output; bool result = RE2::FullMatch(target_view, re2, &output); if (!result) { @@ -104,8 +103,7 @@ Value CaptureString(int regex_max_program_size, const StringValue& target, // a. For a named group - // b. For an unnamed group - absl::StatusOr CaptureStringN( - int regex_max_program_size, const StringValue& target, - const StringValue& regex, + const StringValue& target, const StringValue& regex, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -113,9 +111,10 @@ absl::StatusOr CaptureStringN( std::string regex_scratch; absl::string_view target_view = target.ToStringView(&target_scratch); absl::string_view regex_view = regex.ToStringView(®ex_scratch); - RE2 re2(regex_view, cel::internal::MakeRE2Options()); - CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) - .With(ErrorValueReturn()); + RE2 re2(regex_view); + if (!re2.ok()) { + return ErrorValue(absl::InvalidArgumentError("Given Regex is Invalid")); + } const int capturing_groups_count = re2.NumberOfCapturingGroups(); const auto& named_capturing_groups_map = re2.CapturingGroupNames(); if (capturing_groups_count <= 0) { @@ -149,33 +148,25 @@ absl::StatusOr CaptureStringN( return std::move(*builder).Build(); } -absl::Status RegisterRegexFunctions(FunctionRegistry& registry, - int max_regex_program_size) { +absl::Status RegisterRegexFunctions(FunctionRegistry& registry) { // Register Regex Extract Function CEL_RETURN_IF_ERROR( (TernaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, - StringValue>::RegisterGlobalOverload(kRegexExtract, - absl::bind_front( - &ExtractString, - max_regex_program_size), + StringValue>::RegisterGlobalOverload(kRegexExtract, &ExtractString, registry))); // Register Regex Captures Function - CEL_RETURN_IF_ERROR( - (BinaryFunctionAdapter, StringValue, StringValue>:: - RegisterGlobalOverload( - kRegexCapture, - absl::bind_front(&CaptureString, max_regex_program_size), - registry))); + CEL_RETURN_IF_ERROR(( + BinaryFunctionAdapter, StringValue, + StringValue>::RegisterGlobalOverload(kRegexCapture, + &CaptureString, + registry))); // Register Regex CaptureN Function CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter, StringValue, StringValue>:: - RegisterGlobalOverload( - kRegexCaptureN, - absl::bind_front(&CaptureStringN, max_regex_program_size), - registry))); + RegisterGlobalOverload(kRegexCaptureN, &CaptureStringN, registry))); return absl::OkStatus(); } @@ -216,8 +207,7 @@ absl::Status RegisterRegexDecls(TypeCheckerBuilder& builder) { absl::Status RegisterRegexFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { if (options.enable_regex) { - CEL_RETURN_IF_ERROR( - RegisterRegexFunctions(registry, options.regex_max_program_size)); + CEL_RETURN_IF_ERROR(RegisterRegexFunctions(registry)); } return absl::OkStatus(); } diff --git a/extensions/regex_functions_test.cc b/extensions/regex_functions_test.cc index 92a4da6bb..32416b7bd 100644 --- a/extensions/regex_functions_test.cc +++ b/extensions/regex_functions_test.cc @@ -185,9 +185,8 @@ std::vector createParams() { {// Extract String: Fails when rewritten string has too many placeholders (R"(re.extract('foo', 'f(o+)', '\\1\\2'))"), "Unable to extract string for the given regex"}, - {// Extract String: Fails when invalid regular expression - (R"(re.extract('foo', 'f(o+)(abc', '\\1\\2'))"), - "invalid regular expression"}, + {// Extract String: Fails when regex is invalid + (R"(re.extract('foo', 'f(o+)(abc', '\\1\\2'))"), "Regex is Invalid"}, {// Capture String: Empty regex (R"(re.capture('foo', ''))"), "Unable to capture groups for the given regex"}, @@ -200,8 +199,8 @@ std::vector createParams() { {// Capture String: Mismatched groups (R"(re.capture('foo', 'fo(o+)(s)'))"), "Unable to capture groups for the given regex"}, - {// Capture String: invalid regular expression - (R"(re.capture('foo', 'fo(o+)(abc'))"), "invalid regular expression"}, + {// Capture String: Regex is Invalid + (R"(re.capture('foo', 'fo(o+)(abc'))"), "Regex is Invalid"}, {// Capture String N: Empty regex (R"(re.captureN('foo', ''))"), "Capturing groups were not found in the given regex."}, @@ -214,8 +213,8 @@ std::vector createParams() { {// Capture String N: Mismatched groups (R"(re.captureN('foo', 'fo(o+)(s)'))"), "Unable to capture groups for the given regex"}, - {// Capture String N: invalid regular expression - (R"(re.captureN('foo', 'fo(o+)(abc'))"), "invalid regular expression"}, + {// Capture String N: Regex is Invalid + (R"(re.captureN('foo', 'fo(o+)(abc'))"), "Regex is Invalid"}, }; } diff --git a/internal/BUILD b/internal/BUILD index 59f68df9b..f7e5586db 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -76,16 +76,6 @@ cc_library( hdrs = ["casts.h"], ) -cc_library( - name = "re2_options", - hdrs = ["re2_options.h"], - deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_googlesource_code_re2//:re2", - ], -) - cc_library( name = "status_builder", hdrs = ["status_builder.h"], diff --git a/internal/re2_options.h b/internal/re2_options.h deleted file mode 100644 index 9c20ceb63..000000000 --- a/internal/re2_options.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2026 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RE2_OPTIONS_H_ -#define THIRD_PARTY_CEL_CPP_INTERNAL_RE2_OPTIONS_H_ - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "re2/re2.h" - -namespace cel::internal { - -inline RE2::Options MakeRE2Options() { - RE2::Options options; - options.set_log_errors(false); - return options; -} - -inline absl::Status CheckRE2(const RE2& re, int max_program_size) { - if (!re.ok()) { - switch (re.error_code()) { - case RE2::ErrorInternal: - return absl::InternalError( - absl::StrCat("internal RE2 error: ", re.error())); - case RE2::ErrorPatternTooLarge: - return absl::InvalidArgumentError( - absl::StrCat("regular expression too large: ", re.error())); - default: - return absl::InvalidArgumentError( - absl::StrCat("invalid regular expression: ", re.error())); - } - } - int program_size = re.ProgramSize(); - if (max_program_size > 0 && program_size > 0 && - program_size > max_program_size) { - return absl::InvalidArgumentError( - "regular expressions exceeds max allowed size"); - } - int reverse_program_size = re.ReverseProgramSize(); - if (max_program_size > 0 && reverse_program_size > 0 && - reverse_program_size > max_program_size) { - return absl::InvalidArgumentError( - "regular expressions exceeds max allowed size"); - } - return absl::OkStatus(); -} - -} // namespace cel::internal - -#endif // THIRD_PARTY_CEL_CPP_INTERNAL_RE2_OPTIONS_H_ diff --git a/runtime/regex_precompilation_test.cc b/runtime/regex_precompilation_test.cc index 85b47ef45..308c70be0 100644 --- a/runtime/regex_precompilation_test.cc +++ b/runtime/regex_precompilation_test.cc @@ -176,7 +176,7 @@ INSTANTIATE_TEST_SUITE_P( {"matches_global_false", R"(matches(string_var, r'string_var\d+'))", IsBoolValue(false)}, {"matches_bad_re2_expression", "matches('123', r'(? Value { - RE2 re2(regex.ToString(), cel::internal::MakeRE2Options()); - CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, max_size)) - .With(ErrorValueReturn()); + RE2 re2(regex.ToString()); + if (max_size > 0 && re2.ProgramSize() > max_size) { + return ErrorValue( + absl::InvalidArgumentError("exceeded RE2 max program size")); + } + if (!re2.ok()) { + return ErrorValue( + absl::InvalidArgumentError("invalid regex for match")); + } return BoolValue(RE2::PartialMatch(target.ToString(), re2)); };