Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions eval/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,10 @@ cc_library(
"//eval/eval:evaluator_core",
"//eval/eval:regex_match_step",
"//internal:casts",
"//internal:re2_options",
"//internal:status_macros",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_googlesource_code_re2//:re2",
Expand Down
13 changes: 8 additions & 5 deletions eval/compiler/regex_precompilation_optimization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "base/builtins.h"
Expand All @@ -39,7 +38,6 @@
#include "eval/eval/evaluator_core.h"
#include "eval/eval/regex_match_step.h"
#include "internal/casts.h"
#include "internal/re2_options.h"
#include "internal/status_macros.h"
#include "re2/re2.h"

Expand Down Expand Up @@ -106,9 +104,14 @@ class RegexProgramBuilder final {
}
programs_.erase(existing);
}
auto program =
std::make_shared<RE2>(pattern, cel::internal::MakeRE2Options());
CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(*program, max_program_size_));
auto program = std::make_shared<RE2>(pattern);
if (max_program_size_ > 0 && program->ProgramSize() > max_program_size_) {
return absl::InvalidArgumentError("exceeded RE2 max program size");
}
if (!program->ok()) {
return absl::InvalidArgumentError(
"invalid_argument unsupported RE2 pattern for matches");
}
programs_.insert({std::move(pattern), program});
return program;
}
Expand Down
4 changes: 2 additions & 2 deletions eval/eval/regex_match_step_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ TEST(RegexMatchStep, PrecompiledInvalidRegex) {
ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options));
EXPECT_THAT(expr_builder->CreateExpression(&checked_expr),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("invalid regular expression")));
HasSubstr("invalid_argument")));
}

TEST(RegexMatchStep, PrecompiledInvalidProgramTooLarge) {
Expand All @@ -94,7 +94,7 @@ TEST(RegexMatchStep, PrecompiledInvalidProgramTooLarge) {
ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options));
EXPECT_THAT(expr_builder->CreateExpression(&checked_expr),
StatusIs(absl::StatusCode::kInvalidArgument,
Eq("regular expressions exceeds max allowed size")));
Eq("exceeded RE2 max program size")));
}

} // namespace
Expand Down
4 changes: 0 additions & 4 deletions extensions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,13 @@ cc_library(
"//common:value",
"//eval/public:cel_function_registry",
"//eval/public:cel_options",
"//internal:re2_options",
"//internal:status_macros",
"//runtime:function_adapter",
"//runtime:function_registry",
"//runtime:runtime_options",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
Expand Down Expand Up @@ -765,7 +763,6 @@ cc_library(
"//eval/public:cel_function_registry",
"//eval/public:cel_options",
"//internal:casts",
"//internal:re2_options",
"//internal:status_macros",
"//runtime:function_adapter",
"//runtime:function_registry",
Expand All @@ -774,7 +771,6 @@ cc_library(
"//runtime/internal:runtime_impl",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
Expand Down
90 changes: 41 additions & 49 deletions extensions/regex_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/functional/bind_front.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
Expand All @@ -35,7 +34,6 @@
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "internal/casts.h"
#include "internal/re2_options.h"
#include "internal/status_macros.h"
#include "runtime/function_adapter.h"
#include "runtime/function_registry.h"
Expand All @@ -52,18 +50,19 @@ namespace {

using ::cel::checker_internal::BuiltinsArena;

Value Extract(int regex_max_program_size, const StringValue& target,
const StringValue& regex,
Value Extract(const StringValue& target, const StringValue& regex,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) {
std::string target_scratch;
std::string regex_scratch;
absl::string_view target_view = target.ToStringView(&target_scratch);
absl::string_view regex_view = regex.ToStringView(&regex_scratch);
RE2 re2(regex_view, cel::internal::MakeRE2Options());
CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size))
.With(ErrorValueReturn());
RE2 re2(regex_view);
if (!re2.ok()) {
return ErrorValue(absl::InvalidArgumentError(
absl::StrFormat("given regex is invalid: %s", re2.error())));
}
const int group_count = re2.NumberOfCapturingGroups();
if (group_count > 1) {
return ErrorValue(absl::InvalidArgumentError(absl::StrFormat(
Expand All @@ -84,18 +83,19 @@ Value Extract(int regex_max_program_size, const StringValue& target,
return OptionalValue::None();
}

Value ExtractAll(int regex_max_program_size, const StringValue& target,
const StringValue& regex,
Value ExtractAll(const StringValue& target, const StringValue& regex,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) {
std::string target_scratch;
std::string regex_scratch;
absl::string_view target_view = target.ToStringView(&target_scratch);
absl::string_view regex_view = regex.ToStringView(&regex_scratch);
RE2 re2(regex_view, cel::internal::MakeRE2Options());
CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size))
.With(ErrorValueReturn());
RE2 re2(regex_view);
if (!re2.ok()) {
return ErrorValue(absl::InvalidArgumentError(
absl::StrFormat("given regex is invalid: %s", re2.error())));
}
const int group_count = re2.NumberOfCapturingGroups();
if (group_count > 1) {
return ErrorValue(absl::InvalidArgumentError(absl::StrFormat(
Expand Down Expand Up @@ -142,8 +142,8 @@ Value ExtractAll(int regex_max_program_size, const StringValue& target,
return std::move(*builder).Build();
}

Value ReplaceAll(int regex_max_program_size, const StringValue& target,
const StringValue& regex, const StringValue& replacement,
Value ReplaceAll(const StringValue& target, const StringValue& regex,
const StringValue& replacement,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) {
Expand All @@ -154,9 +154,12 @@ Value ReplaceAll(int regex_max_program_size, const StringValue& target,
absl::string_view regex_view = regex.ToStringView(&regex_scratch);
absl::string_view replacement_view =
replacement.ToStringView(&replacement_scratch);
RE2 re2(regex_view, cel::internal::MakeRE2Options());
CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size))
.With(ErrorValueReturn());
RE2 re2(regex_view);
if (!re2.ok()) {
return ErrorValue(absl::InvalidArgumentError(
absl::StrFormat("given regex is invalid: %s", re2.error())));
}

std::string error_string;
if (!re2.CheckRewriteString(replacement_view, &error_string)) {
return ErrorValue(absl::InvalidArgumentError(
Expand All @@ -169,18 +172,17 @@ Value ReplaceAll(int regex_max_program_size, const StringValue& target,
return StringValue::From(std::move(output), arena);
}

Value ReplaceN(int regex_max_program_size, const StringValue& target,
const StringValue& regex, const StringValue& replacement,
int64_t count,
Value ReplaceN(const StringValue& target, const StringValue& regex,
const StringValue& replacement, int64_t count,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) {
if (count == 0) {
return target;
}
if (count < 0) {
return ReplaceAll(regex_max_program_size, target, regex, replacement,
descriptor_pool, message_factory, arena);
return ReplaceAll(target, regex, replacement, descriptor_pool,
message_factory, arena);
}

std::string target_scratch;
Expand All @@ -190,9 +192,11 @@ Value ReplaceN(int regex_max_program_size, const StringValue& target,
absl::string_view regex_view = regex.ToStringView(&regex_scratch);
absl::string_view replacement_view =
replacement.ToStringView(&replacement_scratch);
RE2 re2(regex_view, cel::internal::MakeRE2Options());
CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size))
.With(ErrorValueReturn());
RE2 re2(regex_view);
if (!re2.ok()) {
return ErrorValue(absl::InvalidArgumentError(
absl::StrFormat("given regex is invalid: %s", re2.error())));
}
std::string error_string;
if (!re2.CheckRewriteString(replacement_view, &error_string)) {
return ErrorValue(absl::InvalidArgumentError(
Expand Down Expand Up @@ -229,35 +233,25 @@ Value ReplaceN(int regex_max_program_size, const StringValue& target,
}

absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry,
bool disable_extract,
int regex_max_program_size) {
bool disable_extract) {
if (!disable_extract) {
CEL_RETURN_IF_ERROR((
BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
RegisterGlobalOverload(
"regex.extract",
absl::bind_front(&Extract, regex_max_program_size), registry)));
RegisterGlobalOverload("regex.extract", &Extract, registry)));
}
CEL_RETURN_IF_ERROR(
(BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
RegisterGlobalOverload(
"regex.extractAll",
absl::bind_front(&ExtractAll, regex_max_program_size),
registry)));
RegisterGlobalOverload("regex.extractAll", &ExtractAll, registry)));
CEL_RETURN_IF_ERROR(
(TernaryFunctionAdapter<
absl::StatusOr<Value>, StringValue, StringValue,
StringValue>::RegisterGlobalOverload("regex.replace",
absl::bind_front(
&ReplaceAll,
regex_max_program_size),
StringValue>::RegisterGlobalOverload("regex.replace", &ReplaceAll,
registry)));
CEL_RETURN_IF_ERROR(
(QuaternaryFunctionAdapter<absl::StatusOr<Value>, StringValue,
StringValue, StringValue, int64_t>::
RegisterGlobalOverload(
"regex.replace",
absl::bind_front(&ReplaceN, regex_max_program_size), registry)));
(QuaternaryFunctionAdapter<
absl::StatusOr<Value>, StringValue, StringValue, StringValue,
int64_t>::RegisterGlobalOverload("regex.replace", &ReplaceN,
registry)));
return absl::OkStatus();
}

Expand Down Expand Up @@ -314,10 +308,9 @@ absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder) {
"regex extensions requires the optional types to be enabled");
}
if (runtime.expr_builder().options().enable_regex) {
CEL_RETURN_IF_ERROR(RegisterRegexExtensionFunctions(
builder.function_registry(),
/*disable_extract=*/false,
runtime.expr_builder().options().regex_max_program_size));
CEL_RETURN_IF_ERROR(
RegisterRegexExtensionFunctions(builder.function_registry(),
/*disable_extract=*/false));
}
return absl::OkStatus();
}
Expand All @@ -327,8 +320,7 @@ absl::Status RegisterRegexExtensionFunctions(
const google::api::expr::runtime::InterpreterOptions& options) {
if (options.enable_regex) {
return RegisterRegexExtensionFunctions(registry->InternalGetRegistry(),
/*disable_extract=*/true,
options.regex_max_program_size);
/*disable_extract=*/true);
}
return absl::OkStatus();
}
Expand Down
8 changes: 4 additions & 4 deletions extensions/regex_ext_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,15 +378,15 @@ std::vector<RegexExtTestCase> regexTestCases() {

// Runtime Errors
{EvaluationType::kRuntimeError, R"(regex.extract('foo', 'fo(o+)(abc'))",
"invalid regular expression: missing ): fo(o+)(abc"},
"given regex is invalid: missing ): fo(o+)(abc"},
{EvaluationType::kRuntimeError, R"(regex.extractAll('foo bar', '[a-z'))",
"invalid regular expression: missing ]: [a-z"},
"given regex is invalid: missing ]: [a-z"},
{EvaluationType::kRuntimeError,
R"(regex.replace('foo bar', '[a-z', 'a'))",
"invalid regular expression: missing ]: [a-z"},
"given regex is invalid: missing ]: [a-z"},
{EvaluationType::kRuntimeError,
R"(regex.replace('foo bar', '[a-z', 'a', 1))",
"invalid regular expression: missing ]: [a-z"},
"given regex is invalid: missing ]: [a-z"},
{EvaluationType::kRuntimeError,
R"(regex.replace('id=123', r'id=(?P<value>\d+)', r'value: \values'))",
R"(invalid replacement string: Rewrite schema error: '\' must be followed by a digit or '\'.)"},
Expand Down
Loading