diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 00000000..4f4988e8 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,367 @@ +# rego-cpp Copilot Instructions + +## Project Overview + +`rego-cpp` is a C++ interpreter for [Rego](https://www.openpolicyagent.org/docs/latest/policy-language/), the policy language of [Open Policy Agent (OPA)](https://www.openpolicyagent.org/). It targets Rego v1.8.0 and is designed for embedding policy evaluation directly into C++ applications, as well as for use from C, Rust, Python, and .NET via language wrappers. + +The interpreter is built on top of [Trieste](https://github.com/microsoft/trieste), a term-rewriting framework from Microsoft Research. Evaluation proceeds by parsing Rego source into an AST and then iteratively rewriting it through a series of compiler passes until a final result is produced. + +## Repository Structure + +``` +rego-cpp/ +├── include/rego/ # Public headers +│ ├── rego.hh # C++ API — AST token types, interpreter, built-in system +│ └── rego_c.h # C API — flat C interface for use by other languages +├── src/ # Core library implementation +│ ├── interpreter.cc # Top-level interpreter: manages passes and compilation +│ ├── virtual_machine.cc# Bytecode-like execution engine (op-block evaluation) +│ ├── parse.cc # Rego lexer and parser (Trieste-based) +│ ├── resolver.cc # Variable resolution and unification +│ ├── rego.cc # Main Interpreter/Rewriter entry point +│ ├── rego_c.cc # C API wrapper over the C++ API +│ ├── bundle.cc # OPA bundle loading +│ ├── bundle_binary.cc # Binary bundle format support +│ ├── bundle_json.cc # JSON bundle format support +│ ├── bigint.cc # Arbitrary-precision integer arithmetic +│ ├── encoding.cc # Base64/hex encoding helpers +│ ├── json.cc # JSON parsing and serialization +│ ├── yaml.cc # YAML parsing +│ ├── opblock.cc # Op-block data structures +│ ├── dependency_graph.cc # Rule dependency analysis +│ ├── 
output.cc # Query result formatting +│ ├── file_to_rego.cc # File-based module loading +│ ├── rego_to_bundle.cc # Compilation to bundle format +│ ├── internal.cc / internal.hh # Shared internal utilities +│ └── builtins/ # Standard Rego built-in functions +│ ├── array.cc # array.* +│ ├── bits.cc # bits.* +│ ├── core.cc # print, type_name, etc. +│ ├── crypto.cc # crypto.* +│ ├── encoding.cc # base64/hex builtins +│ ├── glob.cc # glob.* +│ ├── graph.cc # graph.* +│ ├── graphql.cc # graphql.* +│ ├── http.cc # http.send (stub) +│ ├── json.cc # json.* +│ ├── jwt.cc # io.jwt.* +│ ├── net.cc # net.* +│ ├── numbers.cc # numbers.* +│ ├── objects.cc # object.* +│ ├── opa.cc # opa.* +│ ├── regex.cc # regex.* +│ ├── rego.cc # rego.* +│ ├── semver.cc # semver.* +│ ├── time.cc # time.* +│ ├── units.cc # units.* +│ └── uuid.cc # uuid.* +├── tests/ # Test suite +│ ├── main.cc # Test runner entry point +│ ├── cpp_api.cc # Unit tests for the C++ API +│ ├── c_api.cc # Unit tests for the C API +│ ├── builtins.cc # Built-in function tests +│ ├── test_case.cc/h # YAML test case infrastructure +│ ├── regocpp.yaml # rego-cpp–specific YAML test cases +│ ├── bugs.yaml # Regression tests for bugs +│ ├── bigint.yaml # Big integer tests +│ ├── cts/ # Conformance test suite cases +│ ├── opa/ # OPA-compatible tests (cloned from OPA repo) +│ ├── aci/ # Azure Container Instances policy tests +│ └── cheriot/ # CHERIoT policy tests +├── tools/ # Command-line tools +│ ├── main.cc # `rego` CLI: eval, test, inspect subcommands +│ └── fuzzer.cc # Trieste generative fuzzer (rego_fuzzer) +├── examples/ # Usage examples by language +│ ├── cpp/ # C++ examples (example.cc, custom_builtin.cc) +│ ├── c/ # C examples (example.c, command-line tool) +│ ├── rust/ # Rust examples +│ ├── python/ # Python examples +│ └── dotnet/ # .NET examples +├── wrappers/ # Language binding source +│ ├── rust/ # Rust crate wrapping the C API +│ ├── python/ # Python ctypes/cffi wrapper +│ └── dotnet/ # .NET P/Invoke wrapper +├── 
cmake/ # CMake package config templates +├── doc/ # Doxygen documentation sources +├── CMakeLists.txt # Root build definition +├── CMakePresets.json # Named build presets +└── VERSION # Semantic version file (MAJOR.MINOR.PATCH) +``` + +## Build System + +- **Language**: C++20 +- **Build tool**: CMake ≥ 3.15 with Ninja +- **Key presets** (defined in `CMakePresets.json`): + - `debug-clang` / `debug` — Debug build with tests and tools + - `release-clang` / `release` — Release build + - `debug-clang-opa` / `release-clang-opa` — Build including OPA compatibility tests + - `asan-clang` — AddressSanitizer build +- **CMake options** (all default `OFF` unless noted): + - `REGOCPP_BUILD_TOOLS` — build the `rego` CLI + - `REGOCPP_BUILD_TESTS` — build the test binaries + - `REGOCPP_BUILD_DOCS` — build Doxygen documentation + - `REGOCPP_BUILD_SHARED` — build `rego_shared` as a shared library + - `REGOCPP_OPA_TESTS` — include OPA conformance tests + - `REGOCPP_USE_SNMALLOC` — use snmalloc allocator (default `ON`) + - `REGOCPP_SANITIZE` — sanitizer flags (e.g., `address`) + +Typical workflow: + +```bash +mkdir build && cd build +cmake .. --preset release-clang +ninja install +ctest +``` + +After editing C++ source files, **always** run the formatting target before committing: + +```bash +ninja regocpp_format +``` + +CI checks formatting with clang-format 18 and will reject unformatted code. + +## Key Dependencies + +- **[Trieste](https://github.com/microsoft/trieste)** — term rewriting framework; provides the AST node types (`TokenDef`), well-formedness definitions, logging, JSON/YAML parsers, and the rewriting pass infrastructure. Fetched via `FetchContent`. +- **snmalloc** — high-performance memory allocator (optional, fetched via `FetchContent`). + +## API Design + +The library exposes two public interfaces: + +1. **C++ API** (`include/rego/rego.hh`): The primary interface. 
Provides access to the `Interpreter` class (via `rego::Interpreter`), AST token type constants (`rego::Module`, `rego::Rule`, etc.), the `BuiltIn` registration system for custom built-ins, `BigInt`, and all core types. Uses the `rego` namespace and inherits Trieste types via `using namespace trieste`. + +2. **C API** (`include/rego/rego_c.h`): A flat C interface for interoperability with other languages. Wraps the C++ API using opaque handle types (`regoInterpreter*`, `regoNode*`, etc.). Implemented in `src/rego_c.cc`. + +## Coding Conventions + +- All source files begin with the Microsoft copyright header. +- The `rego` namespace is used throughout; internal helpers live in anonymous namespaces. +- AST token types are declared as `inline const auto` globals in `rego.hh` using Trieste's `TokenDef`. +- Well-formedness rules for each compiler pass are defined inline alongside the pass logic. **Always read the well-formedness definition before writing code that traverses the AST** — nodes are wrapped (e.g., Array elements live inside Term nodes). Use `unwrap()` helpers, not direct `type()` comparisons. +- Built-in functions are registered via the `BuiltIn` class and grouped by OPA namespace in `src/builtins/`. +- Test cases are expressed as YAML files whenever possible, using the OPA test case format. +- Error messages in built-in functions must match OPA's reference implementation exactly — conformance tests compare error strings literally. +- **Fix root causes, not symptoms.** When a test fails, investigate *why* the incorrect behavior occurs — trace the logic, inspect intermediate state, and identify the underlying defect. Do not apply surface-level patches (e.g., special-casing an output, suppressing an error, or working around stale state) just to make a test pass. A correct fix eliminates the class of bug, not just the one observable instance. +- **Move slow to go fast.** Make small, testable changes one step at a time instead of large simultaneous edits. 
After each change, compile and run the relevant tests before moving on. Small increments are easier to verify, easier to debug when something breaks, and produce cleaner diffs. Resist the urge to batch multiple logical changes into a single edit. + +## Trieste Development Workflow + +rego-cpp is built on [Trieste](https://github.com/microsoft/trieste), a multi-pass term-rewriting framework. Understanding the Trieste workflow is essential for any pass or AST work. + +### Analysis Perspectives + +For non-trivial features (new syntax, new passes, AST restructuring), analyze the problem from **four perspectives** before writing code: + +1. **Reference Implementation (OPA)**: Inspect OPA's IR plan output to understand how OPA compiles the feature. Use `opa build --bundle --target plan` and examine `plan.json` for internal built-in names, calling conventions, and undefined-handling patterns. + +2. **AST Pipeline Impact**: Map the feature to specific passes in the two pipelines: + - **File-to-rego** (18 passes in `src/file_to_rego.cc`): parsing through structured AST + - **Rego-to-bundle** (11 passes in `src/rego_to_bundle.cc`): structured AST through executable bytecode + - Identify which passes need modification, new token types, WF changes, and VM changes. + +3. **Well-formedness Chain**: Trace the WF definitions from the first affected pass to the last. WF definitions are incremental — each extends the previous with `|`. Verify no downstream pass breaks. + +4. **Test Strategy**: Plan verification at each stage — YAML test cases, OPA conformance tests, generative fuzzing (`./build/tools/rego_fuzzer `), and ASan builds. + +### Multi-perspective Planning Process + +When planning a non-trivial code change, use four sub-planners running in parallel to generate competing plans, then synthesise the best elements into a final plan. + +#### Step 1 — Gather sub-plans + +Spawn **four fresh subagents**, each prompted to use one of the following skills. 
Each subagent receives the same task description and context but plans through a different lens: + +| Subagent | Skill | Focus | +|----------|-------|-------| +| Speed Planner | `/plan-speed` | Runtime performance, low allocations, minimal passes, cache efficiency | +| Security Planner | `/plan-security` | Defence in depth, safe error handling, bounded resources, fuzz coverage | +| Usability Planner | `/plan-usability` | Clarity, readability, correctness, consistent naming, one-concept-per-pass | +| Conservative Planner | `/plan-conservative` | Smallest diff, maximum reuse, no speculative generality, backwards compat | + +Prompt each subagent with: +> You are planning a change to the rego-cpp project. Use the `/[skill-name]` skill to guide your planning. Here is the task: [task description and relevant context]. Produce a numbered plan following the output format defined in the skill. + +#### Step 2 — Evaluate the four plans + +Review the four plans yourself and produce a short evaluation covering: + +- **Convergence**: where two or more plans agree on the same approach. High convergence suggests a clearly correct design. +- **Unique insights**: ideas that appear in only one plan and are worth incorporating. +- **Conflicts**: where plans disagree. For each conflict, state which perspective you favour and why. +- **Gaps**: anything none of the four plans addressed. + +#### Step 3 — Synthesise the final plan + +Spawn a **fifth subagent** (the synthesiser). Provide it with: +- The original task description. +- All four sub-plans (labelled by perspective). +- Your evaluation from Step 2. + +Prompt the synthesiser with: +> You are producing the final plan for a change to rego-cpp. You have received four sub-plans from different perspectives (Speed, Security, Usability, Conservative) and an evaluation of those plans. Synthesise them into a single coherent, numbered plan that balances all four concerns. Where the evaluation favours one perspective, follow it. 
Where the evaluation is neutral, prefer the Conservative approach. Output the final plan in the standard format: Goal, Steps (with file paths and descriptions balancing all four perspectives), Rationale (explaining the synthesis), and Trade-offs (any conflicts between perspectives and how they were resolved). + +#### Step 4 — Review the synthesised plan + +Before presenting the plan, run an iterative review loop: + +1. Spawn a subagent to review the synthesised plan. Provide it with the original task description, the four sub-plans, your evaluation, and the synthesised plan. Ask it to check for: logical errors in the step ordering, steps that contradict each other, missing error handling or edge cases, violations of rego-cpp conventions, and anything the synthesis dropped that should have been kept. +2. If the review finds issues, revise the plan yourself and spawn a **different** subagent to review the revised version. +3. Repeat until a review comes back clean (no issues found). + +#### Step 5 — Present for approval + +Present the reviewed plan to the user along with a brief summary of: +- Key points of agreement across the four sub-planners. +- Notable trade-offs made during synthesis. +- Any minority opinions from individual sub-planners that were overruled. +- Issues caught and resolved during the review loop (if any). + +#### When to use multi-perspective planning + +Use the full five-step process for design decisions where the shape of the solution is uncertain: new language features, new passes, API changes, AST restructuring, or cross-cutting concerns that touch multiple pipeline stages. + +For tasks that are primarily implementation of a well-understood algorithm (e.g. a new built-in function with a clear OPA specification), a single conservative plan with emphasis on incremental testing is sufficient. Use the full process when the design is uncertain, not when the algorithm is known. 
+
+### Pass Implementation Pattern
+
+Each pass is a `PassDef` with pattern → effect rewrite rules:
+
+```cpp
+PassDef my_pass()
+{
+  return {
+    "my_pass",                 // Name
+    wf_my_pass,                // Output well-formedness
+    dir::bottomup | dir::once, // Traversal direction
+    {
+      In(Parent) * T(Child)[C] >> [](Match& _) { return NewNode << _(C); },
+    }
+  };
+}
+```
+
+Key principles from Trieste:
+- **Prefer many small passes over few complex ones** — "there is no downside to having many passes"
+- **Implement one pass at a time** and test between changes
+- **Add error rules** for invalid inputs that WF allows — generative testing will find them
+- **Rule order matters** — first match wins; put specific rules before general ones
+- **Operator precedence via separate passes** — higher precedence operators in earlier passes (e.g., `arithbin_first` for ×÷% before `arithbin_second` for +−)
+
+### Incremental Implementation
+
+1. Write test cases first (YAML in `tests/regocpp.yaml` or `tests/bugs.yaml`)
+2. Modify the WF definition for the pass output
+3. Add rewrite rules (positive rules first, then error rules)
+4. Run targeted tests: `./build/tests/rego_test -wf tests/regocpp.yaml`
+5. Dump the AST to verify: `./build/tools/rego eval --dump_passes .copilot/pass-debug/ <module.rego> '<query>'`
+6. Move to the next pass and repeat
+
+## Running Tests
+
+### Test Driver (`rego_test`)
+
+The test binary `./build/tests/rego_test` accepts YAML test case files or directories as arguments. When given a directory, it recursively discovers all YAML test files within it.
+
+```bash
+# Run a specific YAML test file
+./build/tests/rego_test -wf tests/regocpp.yaml
+
+# Run all tests in a directory
+./build/tests/rego_test -wf tests/cts/
+
+# Run with well-formedness checking disabled (faster, no WF validation)
+./build/tests/rego_test tests/bugs.yaml
+```
+
+The `-wf` flag enables well-formedness checking at each pass boundary (recommended during development).
+
+### OPA Conformance Tests
+
+OPA test cases are **not** checked into the repo — they live in the build directory, fetched by CMake from the OPA repository. The test root is:
+
+```
+build/opa/v1/test/cases/testdata/v1/
+```
+
+Each subdirectory under v1 is a separate test suite (e.g., `stringinterpolation`, `aggregates`, `with`). To run individual OPA test suites without running the full (slow) OPA test:
+
+```bash
+# Run from the build directory — paths are relative to the working directory
+cd build && ./tests/rego_test -wf opa/v1/test/cases/testdata/v1/stringinterpolation
+
+# Run a different OPA suite
+cd build && ./tests/rego_test -wf opa/v1/test/cases/testdata/v1/with
+
+# List available OPA test suites
+ls build/opa/v1/test/cases/testdata/v1/
+```
+
+### CTest
+
+Use `ctest` to run predefined test targets:
+
+```bash
+cd build && ctest --output-on-failure        # all tests
+ctest -R rego_test_regocpp                   # just rego-cpp tests
+ctest -R rego_test_opa                       # full OPA conformance suite
+ctest -R "rego_test_regocpp|rego_test_bugs"  # multiple targets
+```
+
+When iterating on a specific feature, **prefer running individual OPA subdirectory tests** over the full `rego_test_opa` target — it runs much faster.
+
+### Debugging with lldb
+
+Debug builds (e.g., `build-mbedtls` with `CMAKE_BUILD_TYPE=Debug`) include full debug symbols. Use `lldb` to diagnose test failures, crashes, or incorrect results:
+
+```bash
+# Break at a specific function and run a single test case
+cd build && lldb ./tests/rego_test -- opa/v1/test/cases/testdata/v1/<suite>/<case>.yaml
+(lldb) b <function-name>
+(lldb) run
+
+# Useful commands once stopped
+(lldb) bt               # backtrace
+(lldb) frame variable   # show local variables
+(lldb) p <expression>   # print expression
+(lldb) n / s / c        # next / step / continue
+```
+
+This is particularly useful for debugging backend-specific failures (e.g., a test passes with OpenSSL but fails with mbedTLS) where the issue is in crypto or encoding logic.
+
+### Generative Fuzzer (`rego_fuzzer`)
+
+The `rego_fuzzer` tool generates random ASTs from the Trieste WF chain and tests that each pass handles all structurally valid inputs. It is parameterized by a transform (`file_to_rego`, `rego_to_bundle`, `json_to_bundle`, `bundle_to_json`), a sample count, and a seed.
+
+```bash
+# Basic run (default: 100 samples, random seed)
+./build/tools/rego_fuzzer rego_to_bundle
+
+# With specific count and seed, stop on first failure
+./build/tools/rego_fuzzer rego_to_bundle -c 1000 -f -s 42
+
+# Reproduce a specific failure
+./build/tools/rego_fuzzer rego_to_bundle -c 1 -s <seed>
+```
+
+Passing the fuzzer means running with `-c 1000` three times (different seeds) with exit code 0 each time. CTest targets (`rego_fuzzer_*`) run each transform with the default count.
+
+## Investigating New OPA Features
+
+When implementing a new OPA feature (especially new syntax or internal built-ins), **inspect OPA's IR plan output first** to understand the reference implementation:
+
+1. Download the latest OPA binary matching `REGOCPP_OPA_VERSION`
+2. Create a minimal policy in `.copilot/opa-ir-test/` that exercises the feature
+3. Run: `opa build --bundle --target plan -e <entrypoint> -o bundle.tar.gz`
+4. Extract and inspect `plan.json` — look for new entries in `static.builtin_funcs`, calling conventions, and undefined-handling patterns
+
+This reveals internal built-in names (e.g., `internal.template_string`), argument conventions, and patterns that must be matched for compatibility. Always test with both constant and variable expressions since OPA's optimizer may fold constants.
+
+## Scratch / Temporary Files
+
+Use the `.copilot/` directory at the repo root for all temporary files, downloaded executables, test scripts, and scratch work produced during development. This keeps temporary artifacts visible and inspectable within the workspace rather than scattered in `/tmp`. The `.copilot/` directory is gitignored.
Organize by task, e.g.: +- `.copilot/opa-ir-test/` — OPA IR analysis scratch files +- `.copilot/bin/` — downloaded tool binaries (e.g., OPA CLI) diff --git a/.github/skills/bump-version/SKILL.md b/.github/skills/bump-version/SKILL.md new file mode 100644 index 00000000..6dad4fb1 --- /dev/null +++ b/.github/skills/bump-version/SKILL.md @@ -0,0 +1,46 @@ +--- +name: bump-version +description: 'Bump the rego-cpp version number for a new release. Use when: preparing a release, updating version strings after a tag, or when instructed to bump the version. Updates all version files across the main library and wrapper packages to keep them in sync.' +--- + +# Bumping the Version + +Update all version strings across the rego-cpp project for a new release. + +## When to Use + +- Preparing a new release (major, minor, or patch) +- After discovering wrapper versions are out of sync with the main VERSION file +- When instructed to bump versions + +## Files to Update + +Every release requires updating version strings in **all** of the following +locations. Missing any one of them creates a version mismatch between the +library and its wrapper packages. + +| File | Field / Pattern | Example | +|------|----------------|---------| +| `VERSION` | Entire file contents | `1.3.0` | +| `wrappers/python/setup.py` | `VERSION = "X.Y.Z"` | `VERSION = "1.3.0"` | +| `wrappers/rust/regorust/Cargo.toml` | `version = "X.Y.Z"` | `version = "1.3.0"` | +| `wrappers/dotnet/Rego/Rego.csproj` | `X.Y.Z` | `1.3.0` | + +## Procedure + +1. Read the current version from the `VERSION` file at the repo root. +2. Determine the new version (from user instruction or by incrementing). +3. Update all four files listed above. +4. Verify no other files reference the old version: + ```bash + grep -rn '"OLD_VERSION"' wrappers/ VERSION + ``` +5. Update the CHANGELOG with the new version header if not already present. 
+ +## Common Mistakes + +- **Forgetting wrapper versions**: The wrapper packages (Python, Rust, .NET) + each have their own version string that must match the main `VERSION` file. + These are easy to miss because they live in different directories and formats. +- **Cargo.lock stale**: After updating `Cargo.toml`, run `cargo update` in the + Rust wrapper directory if a lockfile exists, or CI may fail. diff --git a/.github/skills/code-review/SKILL.md b/.github/skills/code-review/SKILL.md new file mode 100644 index 00000000..6280bd6c --- /dev/null +++ b/.github/skills/code-review/SKILL.md @@ -0,0 +1,151 @@ +--- +name: code-review +description: 'Perform a multi-perspective code review of rego-cpp changes. Use when: reviewing a release, auditing a branch diff, evaluating a PR, or performing a pre-merge code review. Launches four parallel review subagents (Security, Performance, Usability, Conservative), verifies key findings, synthesises a unified report with severity-ranked findings, and produces actionable remediation recommendations.' +--- + +# Multi-Perspective Code Review + +Perform a structured code review by examining changes from four independent +perspectives, cross-checking findings against source code, and producing a +unified report with actionable recommendations. + +## When to Use + +- Before tagging a release +- Reviewing a large branch diff or PR +- Auditing a new subsystem (crypto, parsing, VM changes) +- When a single-perspective review would miss cross-cutting concerns + +## Background + +A single reviewer tends toward their own bias — a security expert over-flags +performance patterns, a performance expert under-flags input validation. This +skill runs four parallel reviews, each with a strict lens, then synthesises +findings where multiple perspectives converge or provide unique insight. 
+ +## Perspectives + +| Perspective | Lens | Skill file | +|-------------|------|------------| +| **Security** | Defence in depth, memory safety, bounded resources, error handling, adversarial inputs, C API boundaries, fuzz coverage | [plan-security](../plan-security/SKILL.md) | +| **Performance** | Allocation minimisation, cache-friendly access, pass count, hot-path awareness, algorithmic complexity | [plan-speed](../plan-speed/SKILL.md) | +| **Usability** | Correctness, clarity, naming, WF precision, error message quality, one-concept-per-pass, API ergonomics | [plan-usability](../plan-usability/SKILL.md) | +| **Conservative** | Smallest diff, backwards compatibility, API stability, reuse, no speculative generality, blast radius | [plan-conservative](../plan-conservative/SKILL.md) | + +## Procedure + +### Step 1: Identify the Diff + +Determine the commit range or branch diff to review. + +```bash +# Example: changes since last release tag +git diff --stat v1.2.0..HEAD +``` + +Group changed files by subsystem (parser, builtins, VM, C API, build system, +wrappers) to assign review focus areas. + +### Step 2: Launch Four Review Subagents + +Spawn four Explore subagents **in parallel**, one per perspective. Each +subagent receives: + +1. The same list of changed files and feature summary +2. The perspective-specific review lens (from the table above) +3. Specific files to examine based on the subsystem grouping +4. Instructions to classify findings by severity and provide file/line references + +**Prompt template for each subagent:** + +> You are performing a {PERSPECTIVE}-focused code review of rego-cpp. +> The changes add: {FEATURE_SUMMARY}. +> +> Your review lens: **{LENS_DESCRIPTION}** +> +> THOROUGHNESS: thorough +> +> Please examine these files and report findings: +> {FILE_LIST_WITH_SPECIFIC_QUESTIONS} +> +> For each finding, classify severity as {SEVERITY_SCALE} and provide the +> file path and approximate line numbers. Return a structured report. 
+ +Severity scales per perspective: +- **Security**: CRITICAL / HIGH / MEDIUM / LOW / INFO +- **Performance**: HIGH / MEDIUM / LOW impact +- **Usability**: CONCERN / SUGGESTION / POSITIVE +- **Conservative**: BREAKING / HIGH-RISK / MEDIUM-RISK / LOW-RISK / OK + +### Step 3: Verify Key Findings + +After collecting all four reports, identify the highest-severity findings and +**spot-check them against source code**. Launch a verification subagent: + +> For each claim below, read the relevant code and report whether the claim +> is CONFIRMED, PARTIALLY CONFIRMED, or REFUTED. Provide exact code evidence. +> {LIST_OF_CLAIMS_TO_VERIFY} + +This step prevents false positives from propagating into the final report. +Mark any unverifiable claims as such. + +### Step 4: Synthesise the Report + +Produce a unified report with these sections: + +#### Convergence +Where two or more perspectives agree on the same finding. High convergence +indicates high confidence. + +#### Findings by Severity +A single table combining all verified findings, normalised to a unified +severity scale: + +| Unified Severity | Mapping | +|-----------------|---------| +| CRITICAL / HIGH | Security CRITICAL/HIGH, Performance HIGH, Usability CONCERN (correctness bug), Conservative BREAKING | +| MEDIUM | Security MEDIUM, Performance MEDIUM, Usability CONCERN (non-correctness), Conservative HIGH-RISK | +| LOW | Security LOW, Performance LOW, Usability SUGGESTION, Conservative MEDIUM-RISK | + +Each finding gets: number, description, originating perspective(s), verification +status, file path and line references. + +#### Positive Highlights +Things the code does well, called out by any perspective. This provides +balanced feedback and reinforces good patterns. + +#### Recommendations +Ordered by priority. 
Split into: +- **Before release**: correctness bugs, UB, security issues +- **After release**: performance optimisation, tech debt, hardening + +#### Trade-offs +Where perspectives conflict (e.g., security wants more validation, performance +wants less overhead), state the conflict and the recommended resolution. + +### Step 5: Calibrate Against Existing Test Coverage + +Before finalising recommendations, check whether existing test suites +(OPA conformance tests, regocpp.yaml, fuzzer) already cover the flagged +scenarios. The OPA test suite is comprehensive — findings about "missing +test coverage" must be verified against: + +```bash +# List OPA test suites +ls build/opa/v1/test/cases/testdata/v1/ + +# Check specific suite coverage +grep 'note:' build/opa/v1/test/cases/testdata/v1/{suite}/*.yaml +``` + +Drop or downgrade recommendations that duplicate existing OPA coverage. + +## Output Format + +The final report should be a structured markdown document (presented in chat, +not saved to a file unless requested) with the sections described in Step 4. + +## Reference + +- [Example remediation plan from v1.3.0 review](./references/v1.3.0-remediation-plan.md) — + a concrete example of findings and the resulting fix plan. diff --git a/.github/skills/opa-compat-check/SKILL.md b/.github/skills/opa-compat-check/SKILL.md new file mode 100644 index 00000000..724a7769 --- /dev/null +++ b/.github/skills/opa-compat-check/SKILL.md @@ -0,0 +1,207 @@ +--- +name: opa-compat-check +description: 'Check OPA Rego version compatibility for rego-cpp. Use when: updating OPA version, checking for new OPA releases, auditing rego-cpp compatibility, planning OPA upgrade work, reviewing OPA release notes for rego-cpp impact. Fetches OPA release notes, compares with current rego-cpp support, and produces an actionable compatibility report.' +argument-hint: 'Optional: specific OPA version to check against (e.g. "1.9.0"). Omit to check latest.' 
+--- + +# OPA Rego Compatibility Check + +Determine what changes (if any) rego-cpp needs to maintain compatibility with the latest OPA Rego release. + +## When to Use + +- Checking if a new OPA version has been released since rego-cpp's last update +- Planning work to upgrade rego-cpp to a newer OPA version +- Auditing the current compatibility gap between rego-cpp and OPA +- Reviewing what changed in OPA that affects rego-cpp + +## Procedure + +Follow these steps in order. Do not skip steps. + +### Step 1: Determine Current rego-cpp OPA Version + +Read the `REGOCPP_OPA_VERSION` variable from the root `CMakeLists.txt`: + +``` +grep "REGOCPP_OPA_VERSION" CMakeLists.txt +``` + +This is the OPA version rego-cpp currently targets. Record it as `CURRENT_VERSION`. + +Also read the `VERSION` file in the repo root to get the rego-cpp library version. + +### Step 2: Fetch Latest OPA Release Information + +Fetch the OPA releases page to find the latest version and release notes: + +1. Fetch `https://github.com/open-policy-agent/opa/releases` with query "latest release version" +2. Identify the latest stable release version. Record it as `LATEST_VERSION`. +3. If the user specified a target version, use that instead of the latest. + +If `LATEST_VERSION` equals `CURRENT_VERSION`, report that rego-cpp is up to date and stop. + +### Step 3: Fetch Release Notes for Each Intermediate Version + +For each OPA version between `CURRENT_VERSION` (exclusive) and `LATEST_VERSION` (inclusive): + +1. Fetch `https://github.com/open-policy-agent/opa/releases/tag/v{VERSION}` with query "release notes changes new features built-ins deprecations breaking changes Rego language" +2. Also fetch `https://www.openpolicyagent.org/docs/latest/policy-reference/` with query "built-in functions list" to get the current built-in function catalog. + +Collect all changes across versions. + +### Step 4: Categorize Changes by Impact Area + +Sort every change into the categories below. 
Refer to [change-surface.md](./references/change-surface.md) for details on how each category maps to rego-cpp code. + +| Category | What to Look For | rego-cpp Impact | +|----------|-----------------|-----------------| +| **New Built-in Functions** | New entries in OPA's built-in function list | Add implementation in `src/builtins/` | +| **Modified Built-in Semantics** | Changed behavior of existing built-ins | Update existing builtin impl | +| **Deprecated Built-ins** | Functions marked deprecated | Update `is_deprecated()` in `src/builtins.cc` | +| **Removed Built-ins** | Functions removed entirely | Remove from builtin registry | +| **Language Syntax Changes** | New keywords, operators, or grammar rules | Update parser in `src/parse.cc`, tokens in `include/rego/rego.hh` | +| **Evaluation Semantics** | Changes to how policies are evaluated | Update compiler passes or VM in `src/virtual_machine.cc` | +| **Bundle Format Changes** | Changes to OPA bundle structure | Update `src/bundle.cc`, `src/bundle_json.cc`, `src/bundle_binary.cc` | +| **Conformance Test Changes** | New or modified OPA test cases | Automatically picked up when version is bumped | +| **No rego-cpp Impact** | Go runtime changes, CLI changes, server changes, plugin API changes | Document as not applicable | + +### Step 5: Cross-reference with rego-cpp Built-in Coverage + +For any new or modified built-ins, check whether rego-cpp already implements them: + +1. Search `src/builtins/` for the function name +2. Check `src/builtins.cc` lookup dispatch to see if the namespace is routed +3. Check `src/builtins/builtins.hh` for the namespace factory declaration + +Mark each new built-in as one of: +- **Already implemented** — no action needed +- **Namespace exists, function missing** — add to existing file +- **New namespace** — new file needed in `src/builtins/`, plus registration + +See [builtin-pattern.md](./references/builtin-pattern.md) for the implementation pattern. 
+ +### Step 6: Produce the Compatibility Report + +Generate a structured report with the following sections: + +```markdown +# OPA Rego Compatibility Report + +## Version Summary +- **rego-cpp version**: {from VERSION file} +- **Current OPA target**: {CURRENT_VERSION} +- **Latest OPA release**: {LATEST_VERSION} +- **Versions to bridge**: {list of intermediate versions} + +## Required Changes + +### New Built-in Functions +For each new built-in: +- Function signature (name, args, return type) +- Which `src/builtins/` file to modify or create +- Complexity estimate (trivial/moderate/complex) + +### Built-in Semantic Changes +For each changed built-in: +- What changed and how +- Which file to modify + +### Deprecations +For each deprecated item: +- Add to `is_deprecated()` list in `src/builtins.cc` + +### Language Changes +For each syntax/grammar change: +- What changed in the grammar +- Parser modifications needed +- New tokens needed in `include/rego/rego.hh` +- Documentation updates needed in `README.md` grammar section + +### Evaluation Changes +For each evaluation behavior change: +- Description of change +- Affected compiler passes or VM behavior + +### Bundle Format Changes +For each format change: +- What changed +- Files to update + +## Version Bump Checklist +- [ ] Update `REGOCPP_OPA_VERSION` in `CMakeLists.txt` +- [ ] Update OPA version reference in `README.md` +- [ ] Sync `README.md` grammar section with OPA grammar +- [ ] Implement new built-ins (list each) +- [ ] Update `is_deprecated()` for newly deprecated built-ins +- [ ] Apply parser changes (if any) +- [ ] Apply evaluation changes (if any) +- [ ] Run OPA conformance tests: `cmake --preset debug-clang-opa && ninja -C build && ctest --test-dir build -R opa` +- [ ] Update `CHANGELOG` with new version entry +- [ ] Update `VERSION` file + +## No Impact (documented for completeness) +- List of OPA changes that don't affect rego-cpp +``` + +### Step 7: Analyze OPA IR for Complex Features + +For any 
**language syntax changes** or **new internal built-ins** identified in Step 4, use OPA's plan IR to understand the implementation before coding:
+
+1. **Ensure latest OPA binary** is available:
+   ```bash
+   curl -L -o /tmp/opa https://github.com/open-policy-agent/opa/releases/download/v{LATEST_VERSION}/opa_linux_amd64_static
+   chmod +x /tmp/opa
+   /tmp/opa version # verify it matches LATEST_VERSION
+   ```
+
+2. **Create minimal test policies** in `.copilot/opa-ir-test/`, one per feature:
+   ```bash
+   mkdir -p .copilot/opa-ir-test
+   cat > .copilot/opa-ir-test/policy.rego << 'EOF'
+   package test
+   p := <expression exercising the new feature>
+   EOF
+   ```
+
+3. **Build and inspect the IR**:
+   ```bash
+   /tmp/opa build --bundle .copilot/opa-ir-test --target plan -e test/p -o .copilot/opa-ir-test/bundle.tar.gz
+   mkdir -p .copilot/opa-ir-test/output && cd .copilot/opa-ir-test/output
+   tar xzf ../bundle.tar.gz && python3 -m json.tool plan.json
+   ```
+
+4. **Look for**:
+   - New entries in `static.builtin_funcs` — these are internal built-ins rego-cpp must implement
+   - The exact calling convention (arg types, arity, return type)
+   - How undefined values are handled (e.g., `BlockStmt` + `Set` wrapping patterns)
+   - Whether OPA's optimizer folds constant cases differently from variable cases (test both)
+
+5. **Record findings** in the compatibility report under the relevant section.
+
+See [change-surface.md § Analyzing OPA's IR](./references/change-surface.md) for detailed examples.
+
+### Step 8: Validate Findings
+
+If the workspace has a build directory and the OPA test suite available:
+
+1. **Remove the stale OPA clone** before reconfiguring: `rm -rf build/opa`
+2. **Reconfigure**: `cmake .. --preset debug-clang-opa` (this re-clones at the new version tag)
+3. 
**Run targeted tests** for specific areas during development:
+   ```bash
+   ./tests/rego_test -wf opa/v1/test/cases/testdata/v1/<category>
+   ```
+   Subdirectory names match OPA built-in names without separators (e.g., `regexfind`, `numbersrangestep`, `stringinterpolation`).
+4. **Run the full conformance suite** for final validation:
+   ```bash
+   ctest -R rego_test_opa --output-on-failure
+   ```
+
+## Important Notes
+
+- OPA release notes are at `https://github.com/open-policy-agent/opa/releases`
+- OPA built-in reference is at `https://www.openpolicyagent.org/docs/latest/policy-reference/`
+- The OPA conformance test suite is cloned from `https://github.com/open-policy-agent/opa/` at the tag matching `REGOCPP_OPA_VERSION`
+- rego-cpp only implements the Rego **language** and **built-in functions**. Changes to OPA's Go runtime, REST API, plugin system, or CLI tool do not apply.
+- Some built-ins may be platform-dependent (e.g., time zone functions require `cpp_lib_chrono >= 201907L`). Flag these in the report.
+- `http.send` is a stub in rego-cpp — new HTTP-related functionality is typically marked as a placeholder.
diff --git a/.github/skills/opa-compat-check/references/builtin-pattern.md b/.github/skills/opa-compat-check/references/builtin-pattern.md
new file mode 100644
index 00000000..3a1b43cc
--- /dev/null
+++ b/.github/skills/opa-compat-check/references/builtin-pattern.md
@@ -0,0 +1,179 @@
+# Built-in Function Implementation Pattern
+
+When OPA adds new built-in functions, rego-cpp must provide matching implementations. This document shows the exact patterns used.
+
+## File Structure
+
+Each OPA namespace has a dedicated file in `src/builtins/`:
+```
+src/builtins/
+├── builtins.hh # Factory declarations
+├── array.cc # array.concat, array.reverse, array.slice
+├── strings.cc # strings.*, concat, contains, etc.
+├── ... 
# one file per namespace +``` + +## Implementation Pattern + +A built-in consists of three parts: behavior function, factory function, and namespace dispatch. + +### 1. Behavior Function (anonymous namespace) + +```cpp +#include "builtins.hh" +#include "rego.hh" + +namespace +{ + using namespace rego; + namespace bi = rego::builtins; + + Node my_func(const Nodes& args) + { + // Unwrap and validate arguments + Node x = unwrap_arg(args, UnwrapOpt(0).func("namespace.my_func").type(Array)); + if (x->type() == Error) + { + return x; + } + + // Implement the built-in logic + // Return an AST node (Array, Object, Set, JSONString, Int, Float, True, False, Null) + return result_node; + } +} +``` + +### 2. Factory Function (anonymous namespace) + +```cpp +namespace +{ + BuiltIn my_func_factory() + { + const Node my_func_decl = bi::Decl + << (bi::ArgSeq + << (bi::Arg + << (bi::Name ^ "x") + << (bi::Description ^ "description of x") + << (bi::Type << bi::Any))) + << (bi::Result + << (bi::Name ^ "result") + << (bi::Description ^ "description of result") + << (bi::Type << bi::Any)); + return BuiltInDef::create({"namespace.my_func"}, my_func_decl, my_func); + } +} +``` + +### 3. Namespace Dispatch Function + +In the same file, inside `namespace rego::builtins`: + +```cpp +namespace rego +{ + namespace builtins + { + BuiltIn my_namespace(const Location& name) + { + assert(name.view().starts_with("my_namespace.")); + std::string_view view = name.view().substr(13); // skip "my_namespace." + if (view == "my_func") + { + return my_func_factory(); + } + // ... 
more functions in this namespace + return nullptr; // unknown function + } + } +} +``` + +## Type Constants for Declarations + +Available types for `bi::Type`: +- `bi::Any` — any type +- `bi::String`, `bi::Number`, `bi::Boolean`, `bi::Null` — scalars +- `bi::DynamicArray << (bi::Type << bi::Any)` — array of any +- `bi::StaticArray << (bi::Type << bi::String) << ...` — fixed-type array +- `bi::DynamicObject << (bi::Type << bi::String) << (bi::Type << bi::Any)` — object +- `bi::Set << (bi::Type << bi::Any)` — set + +## Argument Unwrapping + +```cpp +// By position and expected type: +Node arg = unwrap_arg(args, UnwrapOpt(0).func("name.func").type(ExpectedType)); + +// Multiple accepted types: +Node arg = unwrap_arg(args, UnwrapOpt(0).func("name.func").types({Array, Set})); +``` + +## Important: AST Node Wrapping + +Array/Object/Set children are wrapped in Term nodes. **Never** compare `child->type()` directly: +```cpp +// WRONG — will not match because child is a Term wrapping an Array +if (child->type() == Array) { ... } + +// CORRECT — unwrap handles the Term wrapper +auto maybe = unwrap(child, Array); +if (maybe.success) { ... use maybe.node ... } +``` +Read the well-formedness definitions in each pass to understand the node structure. + +## Adding to an Existing Namespace + +1. Add behavior function and factory in the namespace's `.cc` file +2. Add a new `if` branch in the namespace dispatch function + +## Adding a New Namespace + +1. Create `src/builtins/.cc` with the pattern above +2. Add declaration to `src/builtins/builtins.hh`: + ```cpp + BuiltIn my_namespace(const Location& name); + ``` +3. Add routing in `src/builtins.cc` `BuiltInsDef::lookup()` dispatch tree +4. Add source file to `src/CMakeLists.txt` + +## Placeholder for Unsupported Built-ins + +For built-ins that intentionally cannot be supported: + +```cpp +BuiltIn http(const Location& name) +{ + // ... 
+ return BuiltInDef::placeholder(
+ name, decl, "http.send is not supported");
+}
+```
+
+## Internal / Compiler-Generated Built-ins
+
+Some OPA built-ins are not user-facing — they are emitted by the compiler during
+desugaring (e.g., `internal.template_string` for `$"..."` template strings).
+These follow the same registration pattern but have distinct characteristics:
+
+- **Name prefix**: `internal.` — routed through the `internal` namespace dispatch
+- **Not in OPA's public built-in docs**: Discovered by inspecting OPA's IR plan output
+- **Called from compiler-generated code only**: The desugaring pass in `src/rego_to_bundle.cc` (or a dedicated pass) emits `ExprCall` nodes that reference these functions
+- **Argument conventions may differ**: Internal built-ins may receive pre-processed arguments (e.g., arrays with sentinel values like empty sets for undefined)
+
+### Investigating Internal Built-ins
+
+Use OPA's plan IR to discover the exact calling convention:
+```bash
+/tmp/opa build --bundle <bundle-dir> --target plan -e <entrypoint> -o bundle.tar.gz
+tar xzf bundle.tar.gz && python3 -m json.tool plan.json
+```
+
+Look at `static.builtin_funcs` for the declaration and `funcs.funcs[].blocks` for actual call sites. 
+ +### Example: `internal.template_string` +- **Signature**: `internal.template_string(array[any]) -> string` +- **Array contents**: Interleaved literal string chunks and expression values +- **Undefined encoding**: Potentially-undefined expressions are wrapped in a set at the IR level (empty set = undefined → produces `""` in output; set with one element = defined value) +- **Stringification**: Each non-string element is stringified (JSON-like representation); strings are used raw (not quoted) diff --git a/.github/skills/opa-compat-check/references/change-surface.md b/.github/skills/opa-compat-check/references/change-surface.md new file mode 100644 index 00000000..3389266f --- /dev/null +++ b/.github/skills/opa-compat-check/references/change-surface.md @@ -0,0 +1,156 @@ +# rego-cpp Change Surface for OPA Compatibility + +When OPA releases a new version, the following areas of rego-cpp may need updates. This document maps OPA change types to specific files and patterns in the rego-cpp codebase. + +## 1. Built-in Functions + +### Files +- `src/builtins/*.cc` — Individual built-in implementations, one file per OPA namespace +- `src/builtins/builtins.hh` — Factory function declarations for each namespace +- `src/builtins.cc` — Lookup dispatch tree and `BuiltInsDef` manager (including `is_deprecated()`) + +### How Built-ins Are Resolved +The `BuiltInsDef::lookup()` function in `src/builtins.cc` uses a binary dispatch tree keyed on the namespace prefix (text before the first `.`). When a new namespace is added, this dispatch tree must be extended. The tree is hand-coded, not auto-generated. + +### Adding a New Built-in to an Existing Namespace +1. Add the implementation function in the appropriate `src/builtins/.cc` +2. Register it in the namespace's factory function (the function declared in `builtins.hh`) +3. No changes to the dispatch tree needed + +### Adding a New Namespace +1. Create `src/builtins/.cc` +2. Add factory declaration to `src/builtins/builtins.hh` +3. 
Add routing entry in the lookup dispatch tree in `src/builtins.cc` +4. Add the source file to `src/CMakeLists.txt` + +### Deprecating a Built-in +Add the function name to the `deprecated` vector in `BuiltInsDef::is_deprecated()` in `src/builtins.cc`. + +### Marking a Built-in as Unavailable +Use `BuiltInDef::placeholder()` to create an entry that returns a descriptive error message without implementing the function. This is used for built-ins that cannot be supported (e.g., `http.send` requires network access). + +## 2. Parser / Language Syntax + +### Files +- `src/parse.cc` — Rego lexer and parser (Trieste-based) +- `include/rego/rego.hh` — AST token type definitions (`TokenDef` globals) + +### What Triggers Parser Changes +- New keywords (e.g., `every`, `contains`, `in` were added historically) +- New operators +- Grammar rule changes (e.g., new expression forms) +- Changes to import syntax + +### Token Definition Pattern +New tokens are added as `inline const auto` globals in `include/rego/rego.hh`: +```cpp +inline const auto NewToken = TokenDef("rego-newtoken", flag::print); +``` + +## 3. Evaluation / Compiler Passes + +### Files +- `src/interpreter.cc` — Pass pipeline management +- `src/virtual_machine.cc` — VM execution engine +- `src/resolver.cc` — Variable resolution and unification +- `src/dependency_graph.cc` — Rule dependency analysis + +### What Triggers Evaluation Changes +- Changes to how partial evaluation works +- Changes to conflict resolution between rules +- New evaluation capabilities (e.g., new comprehension types) +- Changes to the `with` keyword behavior + +## 4. 
Bundle Format + +### Files +- `src/bundle.cc` — Bundle loading orchestration +- `src/bundle_json.cc` — JSON bundle format +- `src/bundle_binary.cc` — Binary bundle format (rego-cpp specific) +- `src/rego_to_bundle.cc` — Compilation to bundle format + +### What Triggers Bundle Changes +- Changes to OPA's bundle manifest format +- New metadata fields in bundles +- Changes to the wasm/plan format (rego-cpp uses its own VM, but tracks format) + +## 5. Version and Test Infrastructure + +### Files +- `CMakeLists.txt` — `REGOCPP_OPA_VERSION` variable, OPA repo clone +- `README.md` — Version compatibility statement and EBNF grammar +- `CHANGELOG` — Release history +- `VERSION` — rego-cpp semantic version +- `tests/CMakeLists.txt` — Test suite configuration + +### Conformance Tests +OPA tests are automatically cloned from the OPA repo at the tag matching `REGOCPP_OPA_VERSION`. When the version is bumped: +- Must `rm -rf build/opa` first — CMake only clones if the directory doesn't exist +- New tests are picked up automatically +- Tests requiring unimplemented built-ins are skipped via `all_builtins_available()` +- Platform-specific tests (e.g., time zones) are skipped based on compiler capabilities + +### Running Targeted Tests +The OPA test cases live in subdirectories under `build/opa/v1/test/cases/testdata/v1/`. Run a specific category: +```bash +./tests/rego_test -wf opa/v1/test/cases/testdata/v1/ +``` +Subdirectory names match OPA built-in names with no separators (e.g., `regexfind`, `numbersrangestep`, `stringinterpolation`). Always run targeted tests first during development, then the full suite for final validation. + +## 6. Analyzing OPA's IR for New Features + +When OPA adds a significant new feature (new syntax, new internal built-in), **inspect the IR plan** OPA produces to understand the implementation pattern. This is critical for ensuring rego-cpp's compilation matches OPA's semantics. 
+ +### Setup +```bash +# Download latest OPA binary (always upgrade before analyzing!) +curl -L -o /tmp/opa https://github.com/open-policy-agent/opa/releases/download/v{VERSION}/opa_linux_amd64_static +chmod +x /tmp/opa + +# Use .copilot/ in the repo for scratch files +mkdir -p .copilot/opa-ir-test +``` + +### Build and Inspect IR +```bash +# Create a minimal policy exercising the feature +cat > .copilot/opa-ir-test/policy.rego << 'EOF' +package test +p := +EOF + +# Build IR plan +/tmp/opa build --bundle .copilot/opa-ir-test --target plan -e test/p -o .copilot/opa-ir-test/bundle.tar.gz + +# Extract and inspect +mkdir -p .copilot/opa-ir-test/output && cd .copilot/opa-ir-test/output +tar xzf ../bundle.tar.gz +python3 -m json.tool plan.json +``` + +### What to Look For +- **New internal built-in names** in `static.builtin_funcs` (e.g., `internal.template_string`) +- **Calling convention**: argument types, arity, return type +- **Undefined handling patterns**: OPA often wraps potentially-undefined expressions in `BlockStmt` + `Set` patterns (empty set = undefined, set{value} = defined) +- **Constant folding**: OPA's optimizer may fold constant expressions into a single value — test with both constant and variable expressions to see the unoptimized IR + +### Example: String Interpolation (`internal.template_string`) +OPA compiles `$"hello {expr} world"` into: +1. `MakeArrayStmt` with capacity = number of chunks + expressions +2. `ArrayAppendStmt` for literal text chunks (as string_index operands) +3. For each `{expr}`: + - If potentially undefined: wrap evaluation in `BlockStmt` + `MakeSetStmt`/`SetAddStmt`, then append the set + - If constant: evaluate and `ArrayAppendStmt` the value directly +4. `CallStmt` to `internal.template_string` with the array as sole argument + +## 7. 
Typical No-Impact Changes in OPA + +These OPA changes do NOT affect rego-cpp: +- Go runtime performance improvements +- OPA REST API / server changes +- Plugin system / discovery changes +- OPA CLI tool changes (flags, subcommands) +- Logging / telemetry changes +- Wasm compiler changes (rego-cpp has its own VM) +- OPA Docker image changes +- OPA SDK changes (Go-specific) diff --git a/.github/skills/opa-compat-check/reports/2026-03-22-v1.14.1.md b/.github/skills/opa-compat-check/reports/2026-03-22-v1.14.1.md new file mode 100644 index 00000000..f06dd2a9 --- /dev/null +++ b/.github/skills/opa-compat-check/reports/2026-03-22-v1.14.1.md @@ -0,0 +1,91 @@ +# OPA Rego Compatibility Transition Record — v1.8.0 -> v1.14.1 + +Recorded: 2026-03-22 + +## Version Summary +- **rego-cpp version**: 1.2.0 +- **Previous OPA target**: v1.8.0 +- **Updated OPA target**: v1.14.1 +- **Versions bridged**: v1.9.0, v1.10.0, v1.10.1, v1.11.0, v1.11.1, v1.12.0, v1.12.1, v1.12.2, v1.13.0, v1.13.1, v1.13.2, v1.14.0, v1.14.1 + +## Completed Changes + +### 1. New Built-in Function: `array.flatten` (v1.13.0) + +- **Status**: Implemented +- **Implementation**: Added `array.flatten(arr)` in `src/builtins/array.cc` with recursive flattening and registration in the array built-ins table. +- **Result**: Feature is available and documented in `CHANGELOG`. + +### 2. Language Change: String Interpolation (v1.12.0) — **MAJOR** + +- **Status**: Implemented +- **Implementation**: + - Added `TemplateString` token support in `include/rego/rego.hh` and propagated handling through internal AST/type plumbing. + - Extended parsing in `src/parse.cc` for template-string lexing/parsing, including template expression boundaries and escape-sensitive paths. + - Added pipeline support in `src/file_to_rego.cc`, `src/rego_to_bundle.cc`, and related scalar/string handling. + - Added runtime support for template-string evaluation/composition in VM/opblock paths. 
+- **Validation run**: `./bin/rego_test -wf ../opa/v1/test/cases/testdata/v1/stringinterpolation/` completed successfully (exit code 0). + +### 3. Language Change: Keywords Allowed in References (v1.6.0) + +- **Status**: Addressed during the v1.14.1 migration via parser/pipeline updates. +- **Note**: No migration blocker remained for this item after the parser work. A dedicated before/after failure artifact was not captured in this record. + +### 4. Built-in Semantic Change: `json.match_schema` Arrays (v1.13.0) + +- **Status**: No implementation change required +- **Current state**: `json.match_schema` remains a documented placeholder with "JSON schema is not supported" behavior. + +### 5. Built-in Semantic Change: `strings.render_template` Error (v1.13.0) + +- **Status**: No implementation change required +- **Current state**: `strings.render_template` remains unsupported/placeholder. + +### 6. Built-in Behavioral Fix: `numbers.range_step` (v1.13.0) + +- **Status**: Implemented/fixed +- **Implementation**: Updated `numbers.range_step` behavior in `src/builtins/numbers.cc` to align with current OPA expectations. +- **Result**: Included in `CHANGELOG` as a behavioral fix. 
+ +## Version Bump Execution Log + +- [x] Updated `REGOCPP_OPA_VERSION` to `1.14.1` in `CMakeLists.txt` +- [x] Updated README OPA support statement to v1.14.1 +- [x] Implemented `array.flatten` +- [x] Implemented template-string support across parser/AST/pipeline/runtime +- [x] Updated `numbers.range_step` behavior +- [x] Updated `CHANGELOG` +- [x] Updated `VERSION` +- [x] Ran targeted OPA suite for string interpolation (`rego_test -wf .../stringinterpolation/`) +- [x] Ran full OPA conformance target: `ctest -R rego_test_opa --output-on-failure` (from `build/`) +- [x] Re-ran full OPA conformance target on a fresh build; result: `100% tests passed, 0 tests failed out of 1` (`rego_test_opa`, 96.68s) + +## No Impact (documented for completeness) + +All of the following OPA changes across v1.9.0–v1.14.1 do **not** affect rego-cpp: + +| Version | Change | Why No Impact | +|---------|--------|---------------| +| v1.9.0 | Compile Rego Queries Into SQL Filters | Go Compile API, not language | +| v1.9.0 | Improved rule indexing for naked refs | Go runtime optimization | +| v1.10.0 | Non-static arm64 binaries | Go build change | +| v1.10.0 | `opa test --fail-on-empty` | CLI tool flag | +| v1.10.1 | `split` infinite loop fix | Go-specific bug | +| v1.11.0 | Immutable releases | GitHub release process | +| v1.11.0 | Concurrent Rego parsing in bundle loader | Go performance | +| v1.11.0 | Custom SemVer implementation | Go internal refactor | +| v1.11.1 | Memory exhaustion via gzip header | OPA server security | +| v1.11.1 | Decision logs dropped fix | OPA server bug | +| v1.12.0 | Context cancellation in builtins | Go runtime concurrency | +| v1.12.1 | `regex.replace` anchor revert | Go-specific behavioral revert | +| v1.12.3 | Bundle polling misconfiguration | OPA server plugin | +| v1.13.0 | Decision Logger immediate trigger | OPA server feature | +| v1.14.0 | Improved rule indexing (var assignments, `x in {...}`) | Go runtime optimization | +| v1.14.0 | Custom storage 
backend API | Go SDK | +| v1.14.0 | `--h2c` with unix domain socket | CLI tool feature | +| v1.14.1 | Rule indexer revert + dep bumps | Go-specific bug fix | +| All | Go dependency bumps, CI changes, docs, website | Infrastructure | + +## Closing Summary + +The migration to OPA v1.14.1 was executed with code and documentation updates in place, including `array.flatten`, template-string support, and `numbers.range_step` alignment. Full OPA conformance was verified on a fresh build (`rego_test_opa` passing). This file now reflects completed work rather than an implementation plan. diff --git a/.github/skills/plan-conservative/SKILL.md b/.github/skills/plan-conservative/SKILL.md new file mode 100644 index 00000000..45cf1f5d --- /dev/null +++ b/.github/skills/plan-conservative/SKILL.md @@ -0,0 +1,114 @@ +--- +name: plan-conservative +description: > + Conservative planning skill for rego-cpp changes. Produces plans with + the smallest possible changeset, fewest new abstractions, minimal disruption + to existing code, strict backwards compatibility, and maximum reuse of + existing tokens, passes, and patterns. Use this skill when planning code + changes and a minimal-change perspective is needed. +user-invocable: false +--- + +# Conservative Planner + +You are a change-averse planner. Every decision you make must be justified +through the lens of **minimal disruption**. Your plans should produce the +smallest diff that correctly implements the requested change, touching the +fewest files and introducing the fewest new concepts. + +## Core Principles + +1. **Smallest diff wins.** Given two correct approaches, always choose the one + that changes fewer lines, fewer files, and fewer existing abstractions. Every + changed line is a potential regression; every new file is maintenance burden. + +2. **Reuse before creating.** Before introducing a new token, pass, function, or + abstraction, exhaustively check whether an existing one can serve the purpose. 
+ Trieste provides many built-in tokens (`Group`, `Seq`, `Lift`, `Error`) — + use them. Existing passes in the file-to-rego or rego-to-bundle pipeline may + already handle a related transformation and can be extended with one or two + additional rules. + +3. **No speculative generality.** Do not add configuration, parameters, or + abstractions "in case they're needed later." Implement exactly what is asked + for, nothing more. A feature that isn't requested is a feature that doesn't + need to exist. + +4. **Backwards compatibility is sacred.** Public APIs — the C++ API in + `include/rego/rego.hh`, the C API in `include/rego/rego_c.h`, and the + language wrappers in `wrappers/` — must not change in ways that break + existing users. If a breaking change is unavoidable, flag it explicitly and + explain why no non-breaking alternative exists. + +5. **Prefer extending over replacing.** WF specs are designed for incremental + extension with `|`. Add new shapes rather than rewriting existing specs. + Add new rewrite rules to existing passes rather than creating new passes. + +6. **Avoid ripple effects.** A change to a WF spec forces every downstream pass + to be consistent. Prefer changes that affect the fewest downstream specs and + passes. If a new token must be introduced, confine its lifetime to as few + passes as possible. + +7. **One concern at a time.** Do not bundle cleanup, refactoring, or + improvements with the requested change. If existing code is messy but + functional, leave it alone. The goal is to implement the request, not to + improve the neighbourhood. + +8. **Preserve existing patterns.** If the surrounding code uses a particular + idiom (e.g. `dir::topdown`, anonymous namespace for pass functions, specific + error message style), follow it exactly — even if you know a "better" way. + Consistency with neighbours beats local optimality. + +9. 
**Measure the blast radius.** For every proposed step, state how many files + it touches and whether it changes any public interface. If a step touches + more than two files, consider whether it can be split or simplified. + +10. **OPA conformance is a constraint, not a goal.** Only implement what is + needed to pass the relevant OPA conformance tests. Do not add OPA + compatibility features that are not tested or requested. + +## Planning Output Format + +Produce a numbered plan with: + +- **Goal**: one-sentence summary. +- **Blast radius**: total files modified, total files created, any public API + changes (ideally zero). +- **Steps**: numbered list of changes, each with the file path and a description + of what changes. For each step, state the **line count delta** (approximate + lines added / removed). +- **Reuse inventory**: existing tokens, passes, and helpers that are reused + instead of creating new ones, with justification. +- **Rejected alternatives**: approaches that were considered but rejected because + they had a larger changeset or more ripple effects. +- **Compatibility**: confirmation that no existing public API is broken, or an + explicit list of breaking changes with justification. + +## rego-cpp-specific Conservative Guidance + +- Before creating a new token, check `include/rego/rego.hh` for existing tokens + that might already serve the purpose, and check Trieste's built-in tokens + (`Group`, `Seq`, `Lift`, `Error`, `Top`). +- Before creating a new pass, check whether an existing pass in + `src/file_to_rego.cc` or `src/rego_to_bundle.cc` can absorb the new rewrite + rules. Adding three rules to an existing pass is cheaper than adding a new + pass with its own WF spec. +- WF spec changes propagate: changing a token's children in one spec may require + updates to every subsequent spec. Prefer adding optional children (`~Token`) + or extending choice sets over restructuring. 
+- If the change only affects one stage of the pipeline, only modify that stage's + files. Do not "clean up" adjacent stages. +- Prefer `dir::once` when it suffices — it avoids introducing a fixed-point loop + that might interact unexpectedly with existing rules in the same pass. +- Built-in functions can often be added to an existing namespace file in + `src/builtins/` without touching any other file besides the registration in + `src/builtins.cc`. This is the ideal blast radius for a new built-in. +- When adding OPA conformance, run only the specific subdirectory test rather + than the full OPA suite: + ```bash + cd build && ./tests/rego_test -wf opa/v1/test/cases/testdata/v1/ + ``` +- Avoid modifying `src/virtual_machine.cc` unless new opcodes are strictly + required. VM changes have the widest blast radius in the project. +- The C API wrapper in `src/rego_c.cc` should only change when the C API + header changes. Do not add C API surface area for internal features. diff --git a/.github/skills/plan-security/SKILL.md b/.github/skills/plan-security/SKILL.md new file mode 100644 index 00000000..893c9287 --- /dev/null +++ b/.github/skills/plan-security/SKILL.md @@ -0,0 +1,116 @@ +--- +name: plan-security +description: > + Security-focused planning skill for rego-cpp changes. Produces plans + that prioritise defence in depth, safe memory handling, bounded resource + consumption, robust error representation, thorough fuzz coverage, and + resistance to adversarial inputs. Use this skill when planning code changes + and a security-oriented perspective is needed. +user-invocable: false +--- + +# Security Planner + +You are a security-obsessed planner. Every decision you make must be justified +through the lens of **defensive correctness**. Your plans should produce code +that is resilient to malformed, malicious, and adversarial Rego policies, JSON +data, and bundle inputs, and that fails safely when invariants are violated. + +## Core Principles + +1. 
**Validate at every boundary.** Any data entering the system — Rego source + text, JSON/YAML data documents, bundle files, AST nodes from a prior pass, + user-supplied options via the C/C++ API — must be validated before use. + Never trust the shape of an AST node without WF confirmation. + +2. **Bound all resource consumption.** Recursive descent, pattern expansion, and + fixed-point iteration can all diverge on crafted inputs. Every loop and + recursion must have an explicit or structural bound. Prefer `dir::once` or + bounded iteration counts when unbounded rewriting is not necessary. + +3. **Fail safely with Error nodes.** When an invariant is violated, emit an + `Error << (ErrorMsg ^ "description") << (ErrorAst << node)` rather than + crashing, asserting, or silently producing a wrong tree. Error nodes are + exempt from WF checks and propagate cleanly. + +4. **Memory safety by construction.** Use Trieste's intrusive reference counting + (`Node`) consistently. Never hold raw pointers to nodes across rewrite + boundaries — the tree may be mutated. Avoid iterator invalidation by not + modifying a child vector while iterating over it. Run AddressSanitizer + (`cmake --preset asan-clang`) as a standard validation step. + +5. **Minimise attack surface.** Expose only the tokens, passes, and APIs that + are necessary. Keep internal passes in anonymous namespaces. Avoid + `flag::lookup` / `flag::lookdown` unless symbol resolution is genuinely + required — each widens the scope of what an adversarial input can reference. + +6. **C API boundary safety.** The C API (`include/rego/rego_c.h`, + `src/rego_c.cc`) is a trust boundary — callers may pass null pointers, + invalid handles, or out-of-bounds indices. Every C API function must + validate its inputs before forwarding to the C++ layer. + +7. **Bundle input validation.** Bundle loading (`src/bundle.cc`, + `src/bundle_binary.cc`, `src/bundle_json.cc`) processes untrusted external + data. 
Validate bundle structure, file sizes, and nesting depth before + parsing contents. Reject malformed bundles with clear error messages. + +8. **Regex safety.** RE2 is safe by design (no backtracking), but overly broad + patterns can still match unintended input. Anchor patterns where possible and + use word boundaries (`\b`) to prevent partial matches leaking through. The + `regex.*` built-ins in `src/builtins/regex.cc` should reject patterns that + RE2 cannot safely compile. + +9. **Fuzz-test everything.** Every new pass must be covered by WF-driven fuzz + testing (`rego_fuzzer`). If a change alters a WF spec, verify that the fuzzer + still generates meaningful inputs. Run with `-c 1000` across three different + seeds: + ```bash + ./build/tools/rego_fuzzer file_to_rego -c 1000 -f + ./build/tools/rego_fuzzer rego_to_bundle -c 1000 -f + ``` + +10. **Principle of least privilege.** A pass should only read/write the tokens it + declares in its WF spec. If a pass does not need symbol tables, do not mark + tokens with `flag::symtab`. If a pass does not need to see the entire tree, + restrict its `In()` context. + +11. **Audit trail.** When a plan introduces new error paths, document what + triggers them, what the user sees, and how the error can be resolved. Error + messages in built-in functions must match OPA's reference implementation + exactly — conformance tests compare error strings literally. + +## Planning Output Format + +Produce a numbered plan with: + +- **Goal**: one-sentence summary. +- **Threat model**: which classes of bad input or misuse this change must handle. +- **Steps**: numbered list of changes, each with the file path and a description + of what changes and *how it defends against the identified threats*. +- **Error handling**: for each new code path, describe the error node produced + and what triggers it. +- **Fuzz coverage**: which WF specs are new or changed, and confirmation that + the fuzzer will exercise them. 
+- **Residual risks**: anything that is *not* defended against and why (e.g. + "denial of service via 100 GB input is out of scope"). + +## rego-cpp-specific Security Guidance + +- `flag::defbeforeuse` prevents forward-reference attacks in symbol tables — use + it when definition order matters. +- `flag::shadowing` limits lookup scope — use it to prevent inner scopes from + accidentally resolving to outer definitions. +- WF specs are the primary safety net: a tight WF spec after every pass ensures + that no malformed tree shape survives into later processing stages. +- The `post()` hook on a `PassDef` is an ideal place for global invariant checks + that individual rewrite rules cannot enforce. +- Never silently drop nodes — either rewrite them into valid output or wrap them + in `Error`. Silent drops can hide injection of unexpected structure. +- Built-in functions that process strings (`src/builtins/encoding.cc`, + `src/builtins/regex.cc`, `src/builtins/jwt.cc`) must handle malformed input + gracefully. Use `json::unescape()` when processing `get_string()` values, as + raw token text contains escape sequences intact. +- The `http.send` built-in (`src/builtins/http.cc`) is an SSRF risk — it must + not be enabled without explicit user opt-in and must validate URLs. +- Cryptographic built-ins (`src/builtins/crypto.cc`, `src/builtins/jwt.cc`) + must use well-vetted libraries and never implement custom crypto primitives. diff --git a/.github/skills/plan-speed/SKILL.md b/.github/skills/plan-speed/SKILL.md new file mode 100644 index 00000000..d4deac0d --- /dev/null +++ b/.github/skills/plan-speed/SKILL.md @@ -0,0 +1,94 @@ +--- +name: plan-speed +description: > + Performance-focused planning skill for rego-cpp changes. Produces plans + that prioritise runtime speed, low allocation counts, cache-friendly data access, + minimal pass counts, efficient pattern matching, and fast policy evaluation. 
+ Use this skill when planning code changes and a performance-oriented perspective + is needed. +user-invocable: false +--- + +# Speed Planner + +You are a performance-obsessed planner. Every decision you make must be justified +through the lens of **runtime efficiency**. Your plans should produce code that +evaluates Rego policies as fast as possible on real-world inputs. + +## Core Principles + +1. **Algorithmic complexity first.** Always choose the approach with the best + asymptotic complexity. If two designs are equivalent in big-O, prefer the one + with lower constant factors. + +2. **Minimise allocations.** Heap allocations are expensive. Prefer reusing + existing AST nodes over creating new ones. Favour in-place mutation of the AST + when the semantics allow it. Use `Seq` to splice results rather than building + intermediate containers. + +3. **Cache-friendly traversal.** Prefer `dir::topdown` when children are + accessed immediately after the parent, and `dir::bottomup` when results + bubble up. Choose the direction that keeps working-set locality tight. + +4. **Reduce pass count.** Each pass is a full tree traversal. Merge logically + related rewrites into a single pass whenever doing so does not compromise + correctness. Prefer fewer, broader passes over many narrow ones. The + file-to-rego pipeline has 18 passes and rego-to-bundle has 11 — additions + should justify their traversal cost. + +5. **Pattern matching efficiency.** Keep patterns specific — narrow `In()` + contexts and leading `T()` tokens help the dispatch map skip irrelevant + subtrees quickly. Avoid catch-all patterns (`Any++`) at the head of a rule. + +6. **Compile-time computation.** Push work to compile time where possible: + `constexpr` values, static token definitions, template-based dispatch. + +7. **Avoid redundant work.** If a pass can terminate early (e.g. a `dir::once` + pass), say so. If a `pre()` hook can short-circuit an entire subtree, use it. + +8. 
**Built-in function efficiency.** Built-in functions in `src/builtins/` are + called frequently during evaluation. Avoid unnecessary AST node creation, + string copies, and repeated `unwrap()` calls in hot paths. Cache intermediate + results when a built-in processes collections. + +9. **VM hot path awareness.** The virtual machine (`src/virtual_machine.cc`) + is the innermost evaluation loop. Changes to opblock evaluation, variable + unification, or rule indexing have outsized performance impact. Profile + before and after any VM changes. + +10. **Benchmark-aware.** When proposing a plan, call out which steps have + measurable performance impact and suggest how to validate the improvement + (e.g. "run the OPA conformance suite before and after and compare wall time", + or "evaluate a large policy bundle and measure throughput"). + +## Planning Output Format + +Produce a numbered plan with: + +- **Goal**: one-sentence summary. +- **Steps**: numbered list of changes, each with the file path and a description + of what changes and *why it is fast*. +- **Performance rationale**: a short paragraph at the end explaining the + expected performance characteristics and any trade-offs made for speed. +- **Risks**: anything that could make this slower than expected (e.g. branch + misprediction under certain input distributions, increased compile time). + +## rego-cpp-specific Performance Guidance + +- Rewrite rules that fire frequently should appear early in the rule list so the + dispatcher finds them first. +- Token flags like `flag::symtab` add overhead to every node of that type; only + request them when symbol lookup is genuinely needed. +- `dir::once` avoids fixed-point iteration — use it when a single sweep suffices. +- WF validation is not free; keep WF specs as tight as possible so the validator + can reject malformed trees early without deep inspection. 
+- Prefer `T(A, B, C)` over `T(A) / T(B) / T(C)` — the multi-token form uses a + bitset check rather than sequential alternatives. +- In the resolver (`src/resolver.cc`), unification is called on every rule + evaluation. Minimize allocations in the unification path. +- Bundle loading (`src/bundle.cc`, `src/bundle_binary.cc`) is a startup cost. + Prefer lazy parsing of bundle components over eager full-tree construction. +- `BigInt` operations (`src/bigint.cc`) can be expensive for large values. + Short-circuit to native integer arithmetic when values fit in 64 bits. +- The dependency graph (`src/dependency_graph.cc`) is built once per module set. + Prefer efficient graph representations (adjacency lists over matrices). diff --git a/.github/skills/plan-usability/SKILL.md b/.github/skills/plan-usability/SKILL.md new file mode 100644 index 00000000..5f8aae7b --- /dev/null +++ b/.github/skills/plan-usability/SKILL.md @@ -0,0 +1,108 @@ +--- +name: plan-usability +description: > + Usability-focused planning skill for rego-cpp changes. Produces plans + that prioritise clear, readable, self-documenting code, consistent naming, + well-structured pass pipelines, precise WF specs, ergonomic APIs, and + correctness above all else. Use this skill when planning code changes and a + clarity-and-correctness perspective is needed. +user-invocable: false +--- + +# Usability Planner + +You are a usability-obsessed planner. Every decision you make must be justified +through the lens of **clarity, correctness, and developer experience**. Your +plans should produce code that is a pleasure to read, easy to extend, and +obviously correct by inspection. + +## Core Principles + +1. **Correctness is non-negotiable.** A change that is unclear or ambiguous is + a change that will eventually be wrong. Prefer designs where the correct + behaviour is the only possible behaviour — use the type system, WF specs, + and Trieste's structural constraints to make illegal states unrepresentable. + +2. 
**Readable code is maintainable code.** Every token name, variable, function, + and pass should have a name that communicates its purpose without needing a + comment. Follow existing naming conventions (`snake_case` for functions, + `PascalCase` for tokens). If a name requires explanation, choose a better name. + +3. **One concept per pass.** Each pass should do exactly one conceptual + transformation. If a pass description requires "and" to explain, it should + probably be two passes. The small cost in traversal is repaid many times over + in debuggability and testability. + +4. **WF specs as documentation.** A well-written WF spec is the best + documentation of what the AST looks like at each stage. Invest time in making + WF specs precise, well-formatted, and incrementally defined. Align the `|` + operators for visual scanning. + +5. **Consistent patterns.** Mimic the structure of existing passes in the same + pipeline. If neighbouring passes use `dir::topdown`, a new pass should too + unless there is a compelling reason otherwise. If errors are reported with a + specific message style, follow that style. Error messages in built-in + functions must match OPA's reference implementation exactly. + +6. **Explicit over implicit.** Prefer explicit token types over reusing generic + ones. Prefer named captures (`[Id]`, `[Rhs]`) over positional child access. + Prefer spelled-out WF shapes over shorthand that obscures structure. + +7. **Small, composable pieces.** Favour small rewrite rules that each handle one + case clearly over a single rule with complex conditional logic. The rewriting + DSL is designed for this — lean into it. + +8. **API ergonomics.** rego-cpp has three API surfaces: C++ (`rego.hh`), + C (`rego_c.h`), and language wrappers (Rust, Python, .NET). Changes to the + public API should consider how downstream users will discover and use it. + Function signatures should be self-explanatory. 
The C API must be usable + without knowledge of the C++ internals. + +9. **Test clarity.** When proposing test changes, each YAML test case should + test one feature and have a descriptive name. Prefer many small test cases + over few large ones. Use `tests/regocpp.yaml` for rego-cpp-specific features + and `tests/bugs.yaml` for regression tests. + +10. **Route through the standard pipeline.** When adding a new compound node + type, route its sub-expressions through the existing `Group → Literal → Expr` + pipeline rather than creating a custom parallel path. The standard pipeline + already handles `with`/`as`, `some`, comprehensions, and other features. + Convert specialised tokens to standard types as early as possible. + +## Planning Output Format + +Produce a numbered plan with: + +- **Goal**: one-sentence summary. +- **Design rationale**: why this structure was chosen for clarity and + correctness, and what alternatives were rejected. +- **Steps**: numbered list of changes, each with the file path and a description + of what changes and *how it improves or maintains code clarity*. +- **Naming decisions**: any new tokens, passes, or functions introduced, with + justification for the chosen names. +- **WF spec changes**: the before/after WF shape for affected passes, formatted + for readability. +- **Consistency check**: confirmation that the change follows existing patterns + in the codebase, or justification for diverging. + +## rego-cpp-specific Usability Guidance + +- Token names in `include/rego/rego.hh` use `PascalCase` and are declared as + `inline const auto` globals using Trieste's `TokenDef`. +- Pass functions return `PassDef` and are named descriptively + (e.g. `build_refs()`, `merge_data()`). +- Error messages in `ErrorMsg` should be actionable — tell the user what went + wrong and, if possible, what to do about it. +- The file-to-rego pipeline should read top-to-bottom as a narrative: parse → + group → structure → resolve → validate. 
When adding a new pass, explain where + it fits in the narrative and why it belongs there. +- Use Trieste's built-in `Lift` and `Seq` tokens for their intended purposes + rather than inventing ad-hoc equivalents. +- Built-in functions in `src/builtins/` are grouped by OPA namespace. A new + built-in belongs in the file matching its OPA namespace (e.g., + `strings.contains` → `src/builtins/strings.cc` if it existed, or the nearest + match). +- The `unwrap()` helper expresses intent better than manual child traversal. + Prefer `unwrap(node, Type)` over `node->front()->front()`. +- YAML test cases are the preferred way to specify expected behaviour. Each case + should have a `note` field that describes what is being tested. diff --git a/.github/skills/rego-fuzzer/SKILL.md b/.github/skills/rego-fuzzer/SKILL.md new file mode 100644 index 00000000..8385f71b --- /dev/null +++ b/.github/skills/rego-fuzzer/SKILL.md @@ -0,0 +1,192 @@ +--- +name: rego-fuzzer +description: 'Pass the rego-cpp Trieste fuzzer for a given pass collection. Use when: verifying that compiler passes are robust to all valid WF inputs, debugging fuzzer failures, fixing generative testing regressions, or validating pass changes against random inputs. The fuzzer generates random ASTs from the Trieste well-formedness chain and checks that each pass handles all structurally valid inputs without crashing or producing malformed output.' +argument-hint: 'Specify the transform to fuzz (file_to_rego, rego_to_bundle, json_to_bundle, bundle_to_json) or say "all" to run all transforms.' +--- + +# Passing the Rego Fuzzer + +Verify that rego-cpp compiler passes are robust to all valid inputs by running the Trieste generative fuzzer. 
+ +## When to Use + +- After modifying or adding a compiler pass +- After changing a well-formedness (WF) definition +- After adding new AST node types or rewrite rules +- When a CI fuzzer run has failed and you need to reproduce and fix the issue +- As a final validation step before merging pass pipeline changes + +## Background + +The `rego_fuzzer` tool uses Trieste's generative testing framework. For each pass in a transform pipeline, it: + +1. Reads the **input well-formedness definition** for that pass +2. Generates random ASTs that are structurally valid according to that WF +3. Runs the pass on each generated AST +4. Checks that the output conforms to the pass's **output well-formedness definition** + +This catches edge cases that hand-written tests miss — any structurally valid input the WF permits can be generated. + +## Transforms + +The fuzzer is parameterized by a **transform**, which is a named collection of passes: + +| Transform | Description | Passes | +|-----------|-------------|--------| +| `file_to_rego` | Parsing through structured AST | 18 passes in `src/file_to_rego.cc` | +| `rego_to_bundle` | Structured AST to executable bytecode | 11 passes in `src/rego_to_bundle.cc` | +| `json_to_bundle` | JSON bundle to internal bundle format | Passes in bundle pipeline | +| `bundle_to_json` | Internal bundle to JSON bundle format | Passes in bundle pipeline | + +## Procedure + +### Step 1: Build the Fuzzer + +The fuzzer binary is built when `REGOCPP_BUILD_TOOLS` is enabled (it is in all standard presets): + +```bash +cd build && ninja rego_fuzzer +``` + +The binary is located at `./build/tools/rego_fuzzer`. + +### Step 2: Determine Which Transforms to Test + +- If the user specified a transform, use that one. +- If the user said "all", test all four: `file_to_rego`, `rego_to_bundle`, `json_to_bundle`, `bundle_to_json`. 
+- If the user didn't specify, infer from the files they changed: + - Changes in `src/file_to_rego.cc` or `src/parse.cc` → `file_to_rego` + - Changes in `src/rego_to_bundle.cc` → `rego_to_bundle` + - Changes in `src/bundle_json.cc` or `src/bundle.cc` → `json_to_bundle` and `bundle_to_json` + - Changes in `include/rego/rego.hh` (WF definitions) → all transforms + - Changes in `src/internal.hh` → all transforms + +### Step 3: Run the Fuzzer + +For each transform, run the fuzzer **three times** with count 1000. **Do not provide a seed** — the fuzzer picks a random seed each time, ensuring the three runs cover different inputs. (The fuzzer tests seeds sequentially from the starting seed, so providing consecutive seeds like 1, 2, 3 would result in nearly complete overlap.) Use `--failfast` (`-f`) to stop on the first failure in each run. + +```bash +cd build + +# Run 1 +./tools/rego_fuzzer -c 1000 -f + +# Run 2 +./tools/rego_fuzzer -c 1000 -f + +# Run 3 +./tools/rego_fuzzer -c 1000 -f +``` + +**Passing criterion**: all three runs must produce output containing **no** `Failed pass:` lines. **Do not rely on the exit code alone** — the fuzzer may exit 0 even when a pass fails. Always read the tail of the output (e.g., pipe through `tail -5`) and check for `Failed pass:` or `Failed!` text. + +If a run fails, proceed to Step 4 before running additional transforms. + +### Step 4: Diagnose Failures + +When the fuzzer fails, it produces structured output with the following sections (see [references/example-failure.md](./references/example-failure.md) for a complete annotated example): + +``` +Testing x1, seed: 1452196526 + +: unexpected rego-templatestring, expected a rego-STRING, rego-INT, rego-FLOAT, +rego-true, rego-false or rego-null $85 +~~~ +(rego-templatestring) + + +============ +Pass: index_strings_locals, seed: 1452196526 +------------ +(top ...) <-- full input AST (what was fed into the pass) +------------ +(top ...) 
<-- full output AST (what the pass produced) +============ +Failed pass: index_strings_locals, seed: 1452196526 +``` + +The output structure is: + +1. **Header**: `Testing xN, seed: S` +2. **WF error message**: Describes the well-formedness violation — which node type was found and what types were expected. Includes a node id (`$NN`), an underline (`~~~`), and the offending node shown as `(rego-X)`. +3. **Pass and seed**: `Pass: , seed: ` — identifies which pass failed. +4. **Input AST**: The full AST that was generated and fed **into** the failing pass (between `---` separators). This is the WF-valid input that triggered the bug. +5. **Output AST**: The full AST the pass **produced** (between `---` and `===` separators). Compare this against the pass's output WF to see exactly what's wrong. +6. **Failure summary**: `Failed pass: , seed: ` — the last line, repeating the identification. + +A successful run produces only the header line and exits with code 0: + +``` +Testing x3, seed: 42 +``` + +To reproduce a failure for debugging, re-run with the exact seed and count 1: + +```bash +./tools/rego_fuzzer -c 1 -s +``` + +Add `-l Info` for additional logging if the AST dump is not sufficient: + +```bash +./tools/rego_fuzzer -c 1 -s -l Info +``` + +#### How to Read the Failure + +1. **Start from the error message** at the top — it tells you the node type that violated the output WF and what was expected instead. +2. **Find the offending node in the input AST** — search for that node type in the input dump. This shows how the fuzzer-generated input contains a structurally valid (per the input WF) combination that the pass doesn't handle. +3. **Check the output AST** — the pass left the offending node unchanged or transformed it incorrectly, violating the output WF. +4. **Read the pass's WF definitions** — the input WF tells you what shapes the pass must be prepared to handle; the output WF tells you what shapes it must produce. 
+ +#### Common Failure Categories + +| Symptom | Likely Cause | Fix | +|---------|-------------|-----| +| WF violation after pass X | A rewrite rule in pass X produces output not matching `wf_X` | Add or fix a rewrite rule to handle the input pattern | +| Unhandled node type | A pattern the pass doesn't match but the input WF allows | Add a rewrite rule or error rule for the pattern | +| Crash / assertion failure | Null dereference or out-of-bounds access in a rewrite rule | Add guards or handle the empty-children case | +| Infinite loop (timeout) | Fixpoint pass rules that don't converge | Add `dir::once` or fix the rules so they make progress | + +### Step 5: Fix and Re-verify + +After fixing a failure: + +1. **Re-run with the specific failing seed** to confirm the fix: + ```bash + ./tools/rego_fuzzer -c 1 -s + ``` + +2. **Re-run the full three-pass verification** (Step 3) to ensure no regressions. + +3. **Run the standard test suite** to check the fix didn't break deterministic tests: + ```bash + cd build && ./tests/rego_test -wf tests/regocpp.yaml + ``` + +### Step 6: Report Results + +Summarize the results for each transform: + +``` +Fuzzer results for : + Run 1 (seed 1, count 1000): PASS + Run 2 (seed 2, count 1000): PASS + Run 3 (seed 3, count 1000): PASS +``` + +If any failures were found and fixed, include: +- The failing seed(s) and pass name(s) +- A brief description of the root cause +- What was changed to fix it + +## Tips + +- **Start with a low count** (e.g., `-c 10`) when iterating on a fix to get fast feedback, then scale up to `-c 1000` for the final verification. +- **The seed is deterministic** — the same seed always produces the same random ASTs, making failures reproducible. +- **Error rules are the primary fix** for fuzzer failures. When the fuzzer finds an input your pass doesn't handle, add an error rule that catches the pattern and produces a meaningful `err()` node. This is preferable to trying to handle every exotic WF-valid combination. 
+- **Read the WF definition** of the failing pass's input — it tells you exactly what shapes the fuzzer might generate. +- **CTest also runs the fuzzer** with the default count (100). To run fuzzer tests via CTest: + ```bash + ctest --test-dir build -R rego_fuzzer + ``` diff --git a/.github/skills/rego-fuzzer/references/example-failure.md b/.github/skills/rego-fuzzer/references/example-failure.md new file mode 100644 index 00000000..b696fb3a --- /dev/null +++ b/.github/skills/rego-fuzzer/references/example-failure.md @@ -0,0 +1,109 @@ +## Example: Fuzzer Failure Output + +This is a real fuzzer failure output from `rego_to_bundle` with seed `1452196526`. +It demonstrates the structure of failure output for diagnosing issues. + +### Command + +```bash +./tools/rego_fuzzer rego_to_bundle -c 1 -s 1452196526 +``` + +### Output + +``` +Testing x1, seed: 1452196526 + +: unexpected rego-templatestring, expected a rego-STRING, rego-INT, rego-FLOAT, +rego-true, rego-false or rego-null $85 +~~~ +(rego-templatestring) + + +============ +Pass: index_strings_locals, seed: 1452196526 +------------ +(top + {} + (rego-bundle + { + $6 = rego-basedocument + $86 = + rego-virtualdocument + rego-virtualdocument + h = rego-function + tegj = rego-function} + (rego-basedocument + {} + (rego-ident 2:$6) + (rego-baseobject + {... symbol table ...} + (rego-baseobjectitem + (rego-ident 3:$17) + (rego-dataterm + (rego-scalar + (rego-true)))) + ... + (rego-baseobjectitem + (rego-ident 3:$82) + (rego-dataterm + (rego-scalar + (rego-templatestring)))))) <-- the offending node + ... + (rego-modulefileseq))) +------------ +(top + {} + (rego-bundle + { + h = rego-function + tegj = rego-function} + (rego-data + (rego-object + ... + (rego-objectitem + (rego-term + (rego-scalar + (rego-STRING 3:$82))) + (rego-term + (rego-scalar + (rego-templatestring)))))) <-- still present in output + ... 
+ (rego-modulefileseq))) +============ +Failed pass: index_strings_locals, seed: 1452196526 +``` + +### Output Structure + +The failure output has this structure: + +1. **Header**: `Testing xN, seed: S` +2. **WF error message**: Describes the well-formedness violation — which node type was + found and what types were expected. The `$NN` is the node id, `~~~` underlines the + error location, and the indented `(rego-X)` shows the offending node. +3. **Separator**: `============` +4. **Pass identification**: `Pass: , seed: ` +5. **Input AST**: The full AST that was fed **into** the failing pass (between `---` separators) +6. **Output AST**: The full AST the pass **produced** (between `---` and `===` separators) +7. **Failure summary**: `Failed pass: , seed: ` + +### How to Read This Example + +- The **error message** says `rego-templatestring` was unexpected inside a `Scalar` — only + `STRING`, `INT`, `FLOAT`, `true`, `false`, and `null` are valid there according to the + output WF of the `index_strings_locals` pass. +- The **input AST** shows where the `rego-templatestring` entered the pass: nested inside + `(rego-dataterm (rego-scalar (rego-templatestring)))` in the base document data. +- The **output AST** shows that the pass propagated the `rego-templatestring` through to + its output unchanged, violating the output WF. +- The **fix** would be to add a rewrite rule or error rule in the `index_strings_locals` + pass to handle the `TemplateString` node type. + +### Successful Output (for comparison) + +A successful run produces only the header line and exits with code 0: + +``` +Testing x3, seed: 42 +``` diff --git a/.github/skills/regocpp-builtins/SKILL.md b/.github/skills/regocpp-builtins/SKILL.md new file mode 100644 index 00000000..88733d3a --- /dev/null +++ b/.github/skills/regocpp-builtins/SKILL.md @@ -0,0 +1,350 @@ +--- +name: regocpp-builtins +description: 'Add, update, or remove OPA Rego built-in functions in rego-cpp. 
Use when: implementing a new builtin, replacing a placeholder with a real implementation, adding a new OPA builtin namespace, updating builtin declarations to match a new OPA version, removing deprecated builtins, or debugging builtin dispatch/registration. Covers the full lifecycle: declaration, implementation, dispatch registration, CMake wiring, and OPA conformance testing.' +argument-hint: 'Describe which builtin(s) to add, update, or remove.' +--- + +# rego-cpp Built-in Function Development + +Add, update, and remove OPA Rego built-in functions in rego-cpp. + +## When to Use + +- Implementing a new builtin (replacing a placeholder or adding from scratch) +- Adding a new OPA builtin namespace (new `src/builtins/.cc` file) +- Updating builtin declarations to track a new OPA version +- Replacing placeholder stubs with real implementations +- Removing deprecated builtins +- Debugging builtin dispatch or registration issues + +## Architecture Overview + +Built-in functions follow a three-layer architecture: + +``` +BuiltInsDef::lookup(name) ← Dispatch layer (src/builtins.cc) + → builtins::(name) ← Namespace router (src/builtins/.cc) + → _factory() ← Factory (returns BuiltIn with decl + behavior) + → (args) ← Implementation (unwrap args, compute, return) +``` + +### Key Files + +| File | Purpose | +|------|---------| +| `src/builtins/builtins.hh` | Namespace dispatch function declarations | +| `src/builtins.cc` | `BuiltInsDef::lookup` — hand-coded binary dispatch tree | +| `src/builtins/.cc` | One file per OPA namespace (e.g., `crypto.cc`, `jwt.cc`) | +| `src/CMakeLists.txt` | SOURCES list — must include new `.cc` files | +| `include/rego/rego.hh` | Public API — `BuiltIn`, `BuiltInDef`, `UnwrapOpt`, helpers | + +### The Binary Dispatch Tree + +`BuiltInsDef::lookup` in `src/builtins.cc` uses a **hand-coded binary search tree** (generated by `src/builtins/binary_tree.py`) that routes the namespace prefix of a builtin name to the corresponding namespace function. 
When adding a new namespace, this tree must be regenerated or manually updated. + +Special case: `"io"` prefix routes to `builtins::jwt(name)` for `io.jwt.*` builtins. + +## Procedure + +### Adding a New Builtin to an Existing Namespace + +1. **Read the existing namespace file** (`src/builtins/.cc`) to understand the patterns in use. + +2. **Write the implementation function:** + ```cpp + Node my_func(const Nodes& args) + { + // Unwrap and validate arguments + Node x = unwrap_arg(args, UnwrapOpt(0).type(JSONString).func("namespace.my_func")); + if (x->type() == Error) + return x; + + // Extract values + std::string val = get_string(x); + + // Compute result + std::string result = do_something(val); + + // Return wrapped result + return JSONString ^ result; + } + ``` + +3. **Write the factory function:** + ```cpp + BuiltIn my_func_factory() + { + const Node my_func_decl = bi::Decl + << (bi::ArgSeq + << (bi::Arg << (bi::Name ^ "x") + << (bi::Description ^ "input string") + << (bi::Type << bi::String))) + << (bi::Result << (bi::Name ^ "y") + << (bi::Description ^ "result description") + << (bi::Type << bi::String)); + return BuiltInDef::create({"namespace.my_func"}, my_func_decl, my_func); + } + ``` + +4. **Register in the namespace router** (the public function at the bottom of the file): + ```cpp + BuiltIn namespace_func(const Location& name) + { + // ... existing dispatches ... + if (view == "my_func") + { + return my_func_factory(); + } + return nullptr; + } + ``` + +5. **Run OPA conformance tests:** + ```bash + cd build && ./tests/rego_test -wf opa/v1/test/cases/testdata/v1/ + ``` + +### Adding a New Namespace + +When adding an entirely new OPA namespace (new `.cc` file): + +1. **Create `src/builtins/.cc`** following the pattern of existing files. Include the anonymous namespace for internal functions and the `rego::builtins` namespace for the public dispatch function. + +2. 
**Declare the dispatch function** in `src/builtins/builtins.hh`: + ```cpp + namespace rego::builtins { + BuiltIn my_namespace(const Location& name); + } + ``` + +3. **Add to the dispatch tree** in `src/builtins.cc` — find the correct position in `BuiltInsDef::lookup` based on the namespace prefix string and add the routing branch. Alternatively, regenerate the tree using `src/builtins/binary_tree.py`. + +4. **Add the source file to `src/CMakeLists.txt`:** + ```cmake + set( SOURCES + # ... existing sources ... + builtins/.cc + ) + ``` + +5. **Rebuild and test.** + +### Replacing a Placeholder with a Real Implementation + +Many builtins are registered as `BuiltInDef::placeholder(...)` which returns an error message when called. To replace: + +1. **Keep the existing declaration** (`bi::Decl << ...`) — it defines the argument and return types. + +2. **Write the implementation function** that takes `const Nodes& args` and returns a `Node`. + +3. **Change the factory** from: + ```cpp + return BuiltInDef::placeholder({"name"}, decl, "message"); + ``` + to: + ```cpp + return BuiltInDef::create({"name"}, decl, implementation_function); + ``` + +4. **If the builtin requires a platform dependency** (e.g., OpenSSL), use compile-time guards: + ```cpp + #ifdef REGOCPP_HAS_CRYPTO + return BuiltInDef::create({"crypto.sha256"}, sha256_decl, sha256); + #else + return BuiltInDef::placeholder({"crypto.sha256"}, sha256_decl, Message); + #endif + ``` + +### Removing a Deprecated Builtin + +1. Check the deprecated list in `BuiltInsDef::is_deprecated` in `src/builtins.cc`. +2. Add the builtin name to the `deprecated` vector if not already present. +3. Deprecated builtins return `RegoTypeError` when called, regardless of implementation. 
+ +## Key Patterns + +### Argument Unwrapping + +```cpp +// Single type +Node x = unwrap_arg(args, UnwrapOpt(0).type(JSONString)); + +// Multiple accepted types +Node x = unwrap_arg(args, UnwrapOpt(0).types({JSONString, Int, Float})); + +// With function name for error messages +Node x = unwrap_arg(args, UnwrapOpt(0).type(JSONString).func("crypto.sha256")); + +// With custom error details +Node x = unwrap_arg(args, UnwrapOpt(0).type(JSONString) + .func("crypto.sha256").specify_number(true)); +``` + +Always check for errors after unwrapping: +```cpp +if (x->type() == Error) + return x; +``` + +### Value Extraction + +```cpp +std::string val = get_string(node); // strips quotes +BigInt ival = get_int(node); +double dval = get_double(node); +bool bval = get_bool(node); + +// Optional variants (return std::nullopt on wrong type) +auto maybe_str = try_get_string(node); +auto maybe_int = try_get_int(node); +``` + +### Result Construction + +For **scalar** results, return bare token nodes: +```cpp +return JSONString ^ "result"; // string result +return Int ^ BigInt(42); // integer result +return Float ^ 3.14; // float result +return True ^ "true"; // boolean true +return False ^ "false"; // boolean false +return Undefined; // undefined (no result) +return err(args[0], "error message"); // error +return err(args[0], "msg", EvalTypeError); // typed error +``` + +For **compound** results (arrays, objects, nested structures), use the rego API helpers declared in `include/rego/rego.hh`. 
These handle all Term/Scalar wrapping and cloning correctly, avoiding well-formedness errors: + +```cpp +// Booleans, strings, numbers, null — produce correctly-wrapped Scalar nodes +return boolean(true); // same as True ^ "true" but self-documenting +return rego::string("hello"); // note: qualify as rego::string to avoid std::string +return number(3.14); +return null(); + +// Arrays — items are auto-wrapped via Resolver::to_term() +return array({boolean(true), rego::string("ok")}); +return array({header_term, payload_term, rego::string(sig_hex)}); + +// Objects — built from object_item() nodes +return object({ + object_item(rego::string("key"), rego::string("value")), + object_item(rego::string("count"), number(42.0)) +}); + +// Nested: array of [bool, object, object] +return array({boolean(false), object({}), object({})}); +``` + +**IMPORTANT**: Never manually construct compound result nodes with `NodeDef::create(Array)`, `Term <<`, `Scalar <<`, or `push_back`. These patterns produce nodes that violate well-formedness rules. Always use `array()`, `object()`, `object_item()`, `boolean()`, `rego::string()`, `number()`, and `null()` instead. These helpers call `Resolver::to_term()` internally, which handles all wrapping (Term, Scalar) and cloning correctly regardless of whether the input is a bare token, a Scalar, or an already-wrapped Term. + +### Declaration Types + +```cpp +bi::String, bi::Number, bi::Boolean, bi::Null, bi::Any // Scalar types +bi::DynamicArray << (bi::Type << bi::String) // array of strings +bi::DynamicObject << (bi::Type << bi::String) << (bi::Type << bi::Any) // object +bi::StaticArray << (bi::Type << bi::Boolean) << (bi::Type << bi::String) // [bool, string] +bi::Set << (bi::Type << bi::String) // set of strings +``` + +### Shared Code Between Namespaces + +When multiple namespaces share implementation logic (e.g., crypto primitives shared between `crypto.*` and `io.jwt.*`): + +1. Create a shared internal header: `src/builtins/.hh` +2. 
Create a shared implementation: `src/builtins/<name>.cc`
+3. Add the `.cc` to `src/CMakeLists.txt` SOURCES
+4. Include from both namespace files
+
+Use compile-time backend selection for platform-dependent code:
+```cmake
+set(REGOCPP_CRYPTO_BACKEND "" CACHE STRING "Crypto backend: openssl3, '' (disabled)")
+if(REGOCPP_CRYPTO_BACKEND STREQUAL "openssl3")
+  find_package(OpenSSL 3.0 REQUIRED)
+  target_link_libraries(rego PUBLIC OpenSSL::SSL OpenSSL::Crypto)
+  target_compile_definitions(rego PUBLIC REGOCPP_HAS_CRYPTO=1 REGOCPP_CRYPTO_OPENSSL3=1)
+endif()
+```
+
+### Parsing JSON Inside Builtins
+
+When a builtin needs to inspect or validate JSON data (e.g., JWT headers/payloads, JWK keys), **always use the Trieste JSON parser** (`trieste/json.h`) instead of manual string searching. Manual JSON parsing (e.g., `json.find("\"field\"")`, character-by-character extraction) is brittle and will break on whitespace variations, escaped characters, nested structures, and field-name substrings.
+
+**Two JSON AST types exist** — use the right one for the task:
+
+| AST type | Namespace | Produced by | Use for |
+|----------|-----------|-------------|---------|
+| JSON AST | `json::Object`, `json::Array`, `json::String`, ... | `json::reader().synthetic(str).read()` | Internal inspection: field lookup, type checking, claim validation |
+| Rego AST | `rego::Object`, `rego::Array`, `rego::JSONString`, ... 
| `json::reader().synthetic(str) >> json_to_rego(true)` | Return values to the Rego evaluator |
+
+**For internal inspection**, parse into the JSON AST and use `json::select` with RFC 6901 JSON Pointer paths:
+
+```cpp
+#include <trieste/json.h>
+
+// Parse raw JSON string into JSON AST
+Node ast = parse_json(json_str); // json::reader().synthetic(str).read()
+
+// Field lookup — paths use RFC 6901 format with leading "/"
+auto alg = ::json::select_string(ast, {"/alg"}); // std::optional<std::string>
+auto exp = ::json::select_number(ast, {"/exp"}); // std::optional<double>
+auto ok = ::json::select_boolean(ast, {"/active"}); // std::optional<bool>
+
+// Check field existence (select returns Error node if missing)
+Node field = ::json::select(ast, {"/enc"});
+if (field->type() != Error) { /* field exists */ }
+
+// Check field type
+Node aud = ::json::select(ast, {"/aud"});
+if (aud->type() == ::json::Array) { /* it's an array */ }
+
+// Nested paths
+auto deep = ::json::select_string(ast, {"/foo/bar/baz"});
+```
+
+**CRITICAL**: The path argument is a `Location` initialized from a string literal with `{"/field"}` syntax. The leading `/` is required by RFC 6901. Using `Location("field")` without the `/` will fail silently.
+
+**For return values**, use `parse_json_to_term()` (which runs `json_to_rego`) to produce Rego-typed nodes suitable for the evaluator. Parse into the JSON AST first for validation, then convert to Rego terms only at the end when building the return value.
+
+**For Rego Object nodes** (e.g., constraint objects passed as builtin arguments), use `try_get_string(node)` and `try_get_double(node)` — these already handle Term/Scalar unwrapping. Do NOT navigate with `node / Scalar` before calling them.
+
+## Testing
+
+### OPA Conformance Tests
+
+OPA test cases live in `build/opa/v1/test/cases/testdata/v1/<testdir>/`. Directory names match OPA builtin names with no separators (e.g., `cryptohmacsha256`, `jwtdecodeverify`). 
+ +```bash +# Run a specific builtin's tests +cd build && ./tests/rego_test -wf opa/v1/test/cases/testdata/v1/ + +# List available test directories +ls build/opa/v1/test/cases/testdata/v1/ | grep + +# Run all OPA tests (slow) +ctest -R rego_test_opa +``` + +### Custom Test Cases + +Add YAML test cases to `tests/regocpp.yaml` or `tests/bugs.yaml`: + +```yaml +- note: mybuiltin/basic + query: data.test.p = x + modules: + - | + package test + p := crypto.sha256("hello") + want_result: + - x: 2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824 +``` + +### Error Message Matching + +Error messages must match OPA exactly — conformance tests compare strings literally. When implementing error handling, check OPA's actual error output for the builtin. + +## Reference Plans + +- [Crypto & JWT Implementation Plan](./references/crypto-jwt-plan.md) — Phased plan for implementing `crypto.*` and `io.jwt.*` builtins with a shared OpenSSL core diff --git a/.github/skills/regocpp-builtins/references/PHASE5-BCRYPT.md b/.github/skills/regocpp-builtins/references/PHASE5-BCRYPT.md new file mode 100644 index 00000000..057ee156 --- /dev/null +++ b/.github/skills/regocpp-builtins/references/PHASE5-BCRYPT.md @@ -0,0 +1,161 @@ +# Phase 5: Windows BCrypt Backend — Implementation Guide + +## Overview + +Implement the Windows BCrypt/CNG crypto backend in `src/builtins/crypto_bcrypt.cc`. The file already has working implementations for hashing and HMAC (8 functions), plus fully-stubbed TODO implementations for the remaining 8 functions (verify, sign, X.509, RSA key parsing). 
+ +## Architecture + +``` +crypto_core.hh ← Backend-agnostic API (16 functions) +crypto_utils.hh ← Shared platform-independent utilities (inline) +crypto_openssl3.cc ← OpenSSL 3 backend (Linux/macOS) — REFERENCE IMPLEMENTATION +crypto_bcrypt.cc ← Windows BCrypt backend — THIS FILE +``` + +**Build command:** +```cmd +cmake -B build -DREGOCPP_CRYPTO_BACKEND=bcrypt -DREGOCPP_BUILD_TESTS=ON -DREGOCPP_BUILD_TOOLS=ON -DREGOCPP_OPA_TESTS=ON +cmake --build build +``` + +**CMake defines set:** `REGOCPP_HAS_CRYPTO=1`, `REGOCPP_CRYPTO_BCRYPT=1` +**Libraries linked:** `bcrypt`, `crypt32` + +## What's Already Done + +### Fully implemented (8 functions): +- `md5_hex`, `sha1_hex`, `sha256_hex` — via `BCryptHash` +- `hmac_md5_hex`, `hmac_sha1_hex`, `hmac_sha256_hex`, `hmac_sha512_hex` — via `BCryptCreateHash` with HMAC flag +- `hmac_equal` — constant-time compare (shared via `crypto_utils.hh`) + +### Delegated to shared utilities (3 functions): +- `base64url_encode_nopad`, `base64url_decode` — pure C++ (shared header) +- `parse_algorithm` — string→enum (shared header) + +### Stubbed with TODO + strategy notes (8 functions): +- `verify_signature` — JWT signature verification +- `verify_signature_any_key` — JWKS multi-key verification +- `sign` — JWT signing +- `parse_certificates` — X.509 certificate parsing +- `parse_and_verify_certificates` — X.509 chain validation +- `parse_certificate_request` — CSR parsing +- `parse_rsa_private_key` — RSA private key → JWK +- `parse_private_keys` — multiple private keys → JWK array + +## Implementation Order (recommended) + +### Step 1: verify_signature (HMAC subset first) +Start with HMAC verification (HS256/384/512) since you already have `hmac_hex()`: +1. For HMAC algos: compute HMAC of `signing_input`, compare raw bytes against `signature_bytes` +2. Test: `cd build && .\tests\rego_test -wf opa\v1\test\cases\testdata\v1\jwtverifyhs256` + +### Step 2: verify_signature (RSA) +1. Parse PEM public key or PEM certificate from `key_or_cert` +2. 
For PEM cert: `CertCreateCertificateContext` → `CryptImportPublicKeyInfoEx2` → `BCRYPT_KEY_HANDLE` +3. For PEM pubkey: `CryptDecodeObjectEx(X509_PUBLIC_KEY_INFO)` → `CryptImportPublicKeyInfoEx2` +4. For JWK: manually construct `BCRYPT_RSAPUBLIC_BLOB` from n/e components +5. `BCryptVerifySignature` with `BCRYPT_PAD_PKCS1` (RS*) or `BCRYPT_PAD_PSS` (PS*) +6. Test: `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\jwtverifyrsa` + +### Step 3: verify_signature (ECDSA + EdDSA) +1. EC key import from PEM/JWK → `BCRYPT_ECCPUBLIC_BLOB` +2. **Important:** JWT ECDSA signatures are raw `r||s` concatenation, NOT DER. Convert accordingly. +3. EdDSA requires Windows 10 1903+ (`BCRYPT_ECC_CURVE_25519`) +4. Test: `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\jwtverifyeddsa` + +### Step 4: verify_signature_any_key +1. Copy the JWKS iteration logic from `crypto_openssl3.cc` — it's mostly platform-independent JSON parsing +2. Use `::json::reader().synthetic(key_or_cert).read()` then `::json::select(ast, {"/keys"})` +3. For each key in the JSON array, extract the JWK string and call `verify_signature()` +4. Test: `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\jwtdecodeverify` + +### Step 5: sign +1. Parse JWK private key JSON +2. Import into `BCRYPT_KEY_HANDLE` (differs per algorithm family) +3. `BCryptSignHash` for RSA/EC, `hmac_hex` for HMAC +4. Test: `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\jwtencodesign` + +### Step 6: parse_certificates +1. Use `decode_cert_input()` + `extract_pem_der_blocks("CERTIFICATE")` from `crypto_utils.hh` +2. For each DER block: `CertCreateCertificateContext(X509_ASN_ENCODING, ...)` +3. Extract CN: `CertGetNameStringA(pCert, CERT_NAME_ATTR_TYPE, 0, szOID_COMMON_NAME, ...)` +4. Extract SANs: `CertFindExtension(szOID_SUBJECT_ALT_NAME2)` → `CryptDecodeObjectEx` → `CERT_ALT_NAME_INFO` +5. DER base64: `::base64_encode(der_block, false)` +6. 
Test: `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\cryptox509parsecertificates` + +### Step 7: parse_and_verify_certificates +1. Parse certs as in Step 6 +2. Build cert chain: add root to temp cert store, create chain with `CertGetCertificateChain` +3. OPA convention: input order is root first, leaf last +4. **Return verified chain in leaf-first order** (matching OPA behavior) +5. Test: `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\cryptox509parseandverifycertificates` + +### Step 8: parse_certificate_request + RSA key parsing +1. CSR: `CryptDecodeObjectEx(X509_CERT_REQUEST_TO_BE_SIGNED)` or manually parse ASN.1 +2. RSA key: `CryptDecodeObjectEx(PKCS_RSA_PRIVATE_KEY)` for PKCS#1, or decode PKCS#8 wrapper first +3. Export components via `BCryptExportKey(BCRYPT_RSAFULLPRIVATE_BLOB)` to get n/e/d/p/q/dp/dq/qi +4. Convert each to base64url for JWK format +5. Tests: + - `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\cryptox509parsecertificaterequest` + - `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\cryptox509parsersaprivatekey` + - `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\cryptoparsersaprivatekeys` + - `.\tests\rego_test -wf opa\v1\test\cases\testdata\v1\cryptox509parsekeypair` + +## Key Patterns from the OpenSSL Reference + +### PEM key parsing (reusable across backends) +The shared `extract_pem_der_blocks()` in `crypto_utils.hh` handles PEM → DER extraction. Use it for certificates, keys, and CSRs. + +### JWK parsing +Use Trieste JSON: `auto ast = ::json::reader().synthetic(json_str).read()`. Then `::json::select_string(ast, {"/kty"})` etc. + +### Error messages must match OPA exactly +Conformance tests compare error strings literally. Check OPA's error messages in the test YAML files and match them exactly. + +### Input unescaping +The builtin implementations in `crypto.cc` and `jwt.cc` already call `json::unescape(get_string(...))` before passing strings to the backend. 
You do NOT need to handle JSON escaping in `crypto_bcrypt.cc`. + +## Test Directories (all under `build/opa/v1/test/cases/testdata/v1/`) + +| Directory | Tests | Category | +|-----------|-------|----------| +| `cryptomd5` | 1 | Hashing | +| `cryptosha1` | 2 | Hashing | +| `cryptosha256` | 5 | Hashing | +| `cryptohmacmd5` | 2 | HMAC | +| `cryptohmacsha1` | 2 | HMAC | +| `cryptohmacsha256` | 2 | HMAC | +| `cryptohmacsha512` | 1 | HMAC | +| `cryptohmacequal` | 1 | HMAC | +| `jwtverifyhs256` | 5 | JWT verify (HMAC) | +| `jwtverifyhs384` | 5 | JWT verify (HMAC) | +| `jwtverifyhs512` | 5 | JWT verify (HMAC) | +| `jwtverifyrsa` | 47 | JWT verify (RSA) | +| `jwtverifyeddsa` | 5 | JWT verify (EdDSA) | +| `jwtdecodeverify` | 47 | JWT decode+verify | +| `jwtbuiltins` | 3 | JWT misc | +| `jwtencodesign` | 5 | JWT sign | +| `jwtencodesignraw` | 7 | JWT sign raw | +| `jwtencodesignheadererrors` | 6 | JWT sign errors | +| `jwtencodesignpayloaderrors` | 5 | JWT sign errors | +| `cryptox509parsecertificates` | 10 | X.509 | +| `cryptox509parseandverifycertificates` | 2 | X.509 verify | +| `cryptox509parsecertificaterequest` | 5 | CSR | +| `cryptox509parsekeypair` | 2 | Keypair | +| `cryptox509parsersaprivatekey` | 2 | RSA key | +| `cryptoparsersaprivatekeys` | 3 | Private keys | + +**Target: 168/168 tests passing** (current score on Linux with OpenSSL backend). 
+ +## Files You'll Modify + +| File | What to change | +|------|---------------| +| `src/builtins/crypto_bcrypt.cc` | Implement the 8 TODO functions | + +Files you should NOT need to modify: +- `crypto_core.hh` — API is stable +- `crypto_utils.hh` — shared utilities are complete +- `crypto.cc` / `jwt.cc` — builtin dispatch is backend-agnostic +- `CMakeLists.txt` / `src/CMakeLists.txt` — build config is ready diff --git a/.github/skills/regocpp-builtins/references/crypto-jwt-plan.md b/.github/skills/regocpp-builtins/references/crypto-jwt-plan.md new file mode 100644 index 00000000..bd731c4b --- /dev/null +++ b/.github/skills/regocpp-builtins/references/crypto-jwt-plan.md @@ -0,0 +1,221 @@ +# Crypto & JWT Builtins Implementation Plan + +Phased plan for implementing `crypto.*` and `io.jwt.*` builtins in rego-cpp with a shared crypto core and compile-time backend selection. + +## Current State + +- All 14 `crypto.*` and 17 `io.jwt.*` builtins are **placeholders** (`BuiltInDef::placeholder`) +- No crypto library dependency exists +- 25 OPA conformance test directories cover the full API surface + +## CMake Backend Selection + +```cmake +set(REGOCPP_CRYPTO_BACKEND "" CACHE STRING + "Crypto backend for crypto/JWT builtins. Options: openssl3, '' (disabled)") +set_property(CACHE REGOCPP_CRYPTO_BACKEND PROPERTY STRINGS "" "openssl3") + +if(REGOCPP_CRYPTO_BACKEND STREQUAL "openssl3") + find_package(OpenSSL 3.0 REQUIRED) + target_link_libraries(rego PUBLIC OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(rego PUBLIC REGOCPP_HAS_CRYPTO=1 REGOCPP_CRYPTO_OPENSSL3=1) +elseif(NOT REGOCPP_CRYPTO_BACKEND STREQUAL "") + message(FATAL_ERROR "Unknown crypto backend: ${REGOCPP_CRYPTO_BACKEND}. Options: openssl3") +endif() +``` + +When `REGOCPP_CRYPTO_BACKEND` is empty (default), all crypto/JWT builtins remain as placeholders. This preserves the current zero-dependency build. 
+
+## Shared Core Architecture
+
+```
+src/builtins/
+├── crypto_core.hh ← Backend-agnostic API header
+├── crypto_openssl3.cc ← OpenSSL 3 implementation
+├── crypto.cc ← MODIFY: replace placeholders with real impl
+└── jwt.cc ← MODIFY: replace placeholders with real impl
+```
+
+### `crypto_core.hh` — Backend-Agnostic API
+
+```cpp
+#pragma once
+#ifdef REGOCPP_HAS_CRYPTO
+
+namespace rego::crypto_core {
+  // ── Hashing ──
+  std::string md5_hex(std::string_view data);
+  std::string sha1_hex(std::string_view data);
+  std::string sha256_hex(std::string_view data);
+
+  // ── HMAC ──
+  std::string hmac_hex(std::string_view algo, std::string_view data, std::string_view key);
+  bool hmac_equal(std::string_view mac1, std::string_view mac2); // constant-time
+
+  // ── Base64url (raw bytes for JWT) ──
+  std::string base64url_encode_nopad(std::string_view data);
+  std::string base64url_decode_bytes(std::string_view data);
+
+  // ── Signature Verification ──
+  enum class Algorithm { HS256, HS384, HS512, RS256, RS384, RS512,
+                         PS256, PS384, PS512, ES256, ES384, ES512, EdDSA };
+  Algorithm parse_algorithm(std::string_view name);
+
+  bool verify_signature(Algorithm algo,
+                        std::string_view signing_input,
+                        std::string_view signature,
+                        std::string_view key_or_cert);
+
+  std::string sign(Algorithm algo,
+                   std::string_view signing_input,
+                   std::string_view key);
+
+  // ── PEM / X.509 ──
+  struct ParsedCert { /* X.509 fields as JSON-compatible map */ };
+  struct ParsedKey { /* Key material */ };
+  std::vector<ParsedCert> parse_certificates(std::string_view pem_or_der);
+  std::pair<bool, std::vector<ParsedCert>>
+    parse_and_verify_certificates(std::string_view chain);
+  ParsedCert parse_csr(std::string_view pem_or_der);
+  std::vector<ParsedKey> parse_private_keys(std::string_view pem);
+  ParsedKey parse_rsa_private_key(std::string_view pem);
+}
+
+#endif // REGOCPP_HAS_CRYPTO
+```
+
+### Shared Core Reuse Matrix
+
+| `crypto_core` Primitive | Used by `crypto.cc` | Used by `jwt.cc` |
+|---|---|---|
+| `md5_hex` | `crypto.md5` | — |
+| 
`sha1_hex` | `crypto.sha1` | — | +| `sha256_hex` | `crypto.sha256` | — | +| `hmac_hex` | `crypto.hmac.*` | `io.jwt.encode_sign` (HS* signing) | +| `hmac_equal` | `crypto.hmac.equal` | — | +| `verify_signature` | — | All `io.jwt.verify_*`, `io.jwt.decode_verify` | +| `sign` | — | `io.jwt.encode_sign`, `io.jwt.encode_sign_raw` | +| `parse_certificates` | `crypto.x509.parse_certificates` | `io.jwt.decode_verify` (cert constraint) | +| `parse_and_verify_certs` | `crypto.x509.parse_and_verify_*` | — | +| `base64url_encode_nopad` | — | All JWT encode/sign | +| `base64url_decode_bytes` | — | All JWT decode/verify | + +## Implementation Phases + +### Phase 1: Infrastructure + Hashing (~300 LOC) + +**Goal:** Validate dependency integration and shared core pattern. + +**Steps:** +1. Add `REGOCPP_CRYPTO_BACKEND` to `CMakeLists.txt` and presets +2. Create `crypto_core.hh` with hash + HMAC signatures +3. Create `crypto_openssl3.cc` implementing hash + HMAC via EVP API +4. Add `crypto_openssl3.cc` to `src/CMakeLists.txt` +5. Replace placeholders in `crypto.cc` for: + - `crypto.md5`, `crypto.sha1`, `crypto.sha256` + - `crypto.hmac.md5`, `crypto.hmac.sha1`, `crypto.hmac.sha256`, `crypto.hmac.sha512` + - `crypto.hmac.equal` + +**OPA tests (8 directories):** +``` +cryptomd5, cryptosha1, cryptosha256, +cryptohmacmd5, cryptohmacsha1, cryptohmacsha256, cryptohmacsha512, +cryptohmacequal +``` + +### Phase 2: JWT Decode + Verify (~500 LOC) + +**Goal:** Implement all JWT verification and decoding. + +**Steps:** +1. Add JWT token parsing helpers to `crypto_core` (split on `.`, base64url decode) +2. Add signature verification to `crypto_openssl3.cc` (EVP_DigestVerify for RSA/ECDSA/EdDSA, HMAC for HS*) +3. Add JWK → EVP_PKEY conversion (oct, RSA, EC, OKP key types) +4. 
Implement in `jwt.cc`: + - `io.jwt.decode` — split + decode + hex-encode sig + - `io.jwt.verify_hs256/384/512` — HMAC verify + - `io.jwt.verify_rs256/384/512` — RSA PKCS#1 v1.5 verify + - `io.jwt.verify_ps256/384/512` — RSA-PSS verify + - `io.jwt.verify_es256/384/512` — ECDSA verify + - `io.jwt.verify_eddsa` — Ed25519 verify + - `io.jwt.decode_verify` — verify + claim validation (exp, nbf, iss, aud, time) + +**Key subtleties:** +- `decode_verify` returns `[false, {}, {}]` on failure (not an error) +- JWK key sets matched by `kid` header +- PEM certificates parsed to extract public key + +**OPA tests (7 directories):** +``` +jwtbuiltins, jwtverifyhs256, jwtverifyhs384, jwtverifyhs512, +jwtverifyrsa, jwtverifyeddsa, jwtdecodeverify +``` + +### Phase 3: JWT Encode + Sign (~300 LOC) + +**Goal:** Complete JWT support. + +**Steps:** +1. Add signing to `crypto_openssl3.cc` (EVP_DigestSign, HMAC) +2. Implement in `jwt.cc`: + - `io.jwt.encode_sign` — JSON-serialize header+payload, base64url-encode, sign + - `io.jwt.encode_sign_raw` — same but skip JSON re-serialization + +**OPA tests (4 directories):** +``` +jwtencodesign, jwtencodesignraw, +jwtencodesignheadererrors, jwtencodesignpayloaderrors +``` + +### Phase 4: X.509 + Key Parsing (~800 LOC, highest complexity) + +**Goal:** Complete all crypto builtins. + +**Steps:** +1. Add X.509 parsing to `crypto_openssl3.cc` (PEM/DER → structured objects) +2. Implement in `crypto.cc`: + - `crypto.x509.parse_certificates` + - `crypto.x509.parse_and_verify_certificates` + - `crypto.x509.parse_and_verify_certificates_with_options` + - `crypto.x509.parse_certificate_request` + - `crypto.x509.parse_keypair` + - `crypto.x509.parse_rsa_private_key` + - `crypto.parse_private_keys` + +**This is the hardest phase** — X.509 certificate output must match OPA's Go `x509.Certificate` struct serialization field-by-field. Expect significant iteration matching field names, date formats, extension representations, and OID handling. 
+ +**OPA tests (6 directories):** +``` +cryptox509parsecertificates, cryptox509parseandverifycertificates, +cryptox509parsecertificaterequest, cryptox509parsekeypair, +cryptox509parsersaprivatekey, cryptoparsersaprivatekeys +``` + +## Phase 5: Windows Native Backend (BCrypt/CNG) + +**Goal:** Provide a Windows-native crypto backend using BCrypt/CNG so that Windows builds don't require an external OpenSSL install. + +**Steps:** +1. Add `bcrypt` option to `REGOCPP_CRYPTO_BACKEND` in `CMakeLists.txt` +2. Create `src/builtins/crypto_bcrypt.cc` implementing the `crypto_core.hh` API using Windows BCrypt/CNG +3. Wire up CMake: link `bcrypt.lib`, define `REGOCPP_CRYPTO_BCRYPT=1` +4. Add Windows-specific presets (e.g., `debug-msvc-opa`, `release-msvc-opa`) +5. Validate all OPA conformance tests pass on Windows + +**Notes:** +- BCrypt/CNG is available on Windows Vista+ (no external dependency) +- EdDSA (Ed25519) support may require Windows 10 1903+ or a fallback +- The `crypto_core.hh` abstraction layer is designed for this — only the new `.cc` file touches Windows APIs + +## Risk Assessment + +| Risk | Mitigation | +|---|---| +| X.509 output format mismatch | Inspect OPA test expected output field-by-field before implementing | +| OpenSSL API differences | Require OpenSSL ≥ 3.0; use EVP API exclusively | +| EdDSA support | Available in OpenSSL 3.0+ via `EVP_PKEY_ED25519` | +| JWK parsing edge cases | OPA tests cover RSA, EC, oct, OKP; implement incrementally | +| Cross-platform OpenSSL availability | `REGOCPP_CRYPTO_BACKEND=""` preserves zero-dependency build | +| Error message matching | Conformance tests compare strings literally; verify against OPA output | +| Future backend portability | Backend-agnostic API in `crypto_core.hh`; only `crypto_openssl3.cc` touches OpenSSL | +| BCrypt EdDSA gaps | Ed25519 requires Windows 10 1903+; may need version check or graceful fallback | diff --git a/.github/skills/trieste-dev/SKILL.md b/.github/skills/trieste-dev/SKILL.md new file 
mode 100644 index 00000000..d63664e4 --- /dev/null +++ b/.github/skills/trieste-dev/SKILL.md @@ -0,0 +1,255 @@ +--- +name: trieste-dev +description: 'Plan and implement Trieste-based compiler passes and AST transformations for rego-cpp. Use when: adding new compiler passes, modifying AST structure, implementing new Rego language features, debugging pass failures, working with well-formedness definitions, or performing any multi-step implementation that touches the Trieste pass pipeline. Includes the multi-planner approach for complex features.' +argument-hint: 'Describe the feature or pass work to plan or implement.' +--- + +# Trieste Development Workflow + +Plan and implement Trieste-based compiler passes, AST transformations, and language features in rego-cpp. + +## When to Use + +- Adding or modifying a compiler pass in the file-to-rego or rego-to-bundle pipeline +- Implementing new Rego language syntax (new tokens, grammar rules) +- Changing well-formedness definitions +- Debugging pass failures or well-formedness violations +- Implementing complex multi-step features that touch the AST pipeline +- Any task requiring coordination across parser, passes, built-ins, and VM + +## Core Concepts + +Trieste is a multi-pass term-rewriting system. Understanding these concepts is mandatory before proceeding: + +- **Pass**: A `PassDef` that takes an AST conforming to an input well-formedness (WF) definition and rewrites it to conform to an output WF definition. Passes run repeatedly until no more rules match (fixpoint), unless `dir::once` is specified. +- **Well-formedness (WF)**: A structural specification of valid AST shapes. Each pass declares its output WF. WF definitions are **incremental** — each extends the previous with `|` (choice). +- **Pattern → Effect rules**: Each pass contains rules of the form `Pattern >> Effect`. Patterns match AST subtrees; effects produce replacement subtrees. 
+- **Driver/Reader/Rewriter**: Trieste helpers that chain passes into pipelines. rego-cpp uses `Reader` for parsing and `Rewriter` for transformation. +- **Generative testing**: Trieste can generate random ASTs from WF definitions to fuzz each pass. This discovers edge cases in rewrite rules. + +## Procedure + +### Step 0: Understand the Current AST + +Before any implementation, you must understand the AST structure at the point you're modifying. + +1. **Read the well-formedness definitions** for the passes surrounding your change: + - File-to-rego passes: defined in `src/file_to_rego.cc` (WF definitions inline with passes) + - Rego-to-bundle passes: defined in `src/rego_to_bundle.cc` + - Base WF: `include/rego/rego.hh` → `wf` + - Bundle WF: `include/rego/rego.hh` → `wf_bundle` + - Internal WF: `src/internal.hh` → `wf_bundle_input` + +2. **Dump the AST** at the relevant pass to see the actual tree shape: + ```bash + ./build/tools/rego eval --dump_passes .copilot/pass-debug/ -p '' + ``` + Or write a minimal `.rego` file and use `--wf` to check well-formedness. + +3. **Never assume node structure** — always verify by reading the WF definition. Nodes are typically wrapped (e.g., Array elements inside Term nodes). Use `unwrap()` helpers. + +### Step 1: Multi-Planner Analysis + +For any non-trivial feature, use the **multi-planner approach** — analyze the problem from multiple perspectives before writing code. This prevents costly rework. + +#### Perspective 1: Reference Implementation (OPA) + +How does OPA implement this feature? + +1. **Check OPA's documentation** for the feature's specification +2. **Inspect OPA's IR output** to see how OPA compiles the feature: + ```bash + mkdir -p .copilot/opa-ir-test + # Create minimal policy exercising the feature + cat > .copilot/opa-ir-test/policy.rego << 'EOF' + package test + # ... 
minimal example using the feature + EOF + /tmp/opa build --bundle .copilot/opa-ir-test --target plan -e test/ -o .copilot/opa-ir-test/bundle.tar.gz + cd .copilot/opa-ir-test && tar xzf bundle.tar.gz && python3 -m json.tool plan.json + ``` +3. **Test both constant and variable expressions** — OPA's optimizer may fold constants, hiding the general compilation path +4. **Record**: internal built-in names, calling conventions, undefined-handling patterns + +#### Perspective 2: AST Pipeline Impact + +Where in the rego-cpp pipeline does this feature need to be handled? + +1. **Parser changes?** — Does this require new tokens in `include/rego/rego.hh` and rules in `src/parse.cc`? +2. **Which file-to-rego passes are affected?** — Map the feature to specific passes in the 18-pass file-to-rego pipeline (see [pass-pipeline.md](./references/pass-pipeline.md)) +3. **Which rego-to-bundle passes are affected?** — Map to the 11-pass rego-to-bundle pipeline +4. **VM changes?** — Does `src/virtual_machine.cc` need new opcodes or evaluation logic? +5. **Built-in additions?** — Any new built-in functions required? +6. **New Term alternative?** — If adding a new node type to `Term`, audit all type-dispatch sites: + - `src/dependency_graph.cc` — `add_lhs_var` / `add_rhs` must handle the new type + - `src/resolver.cc` — variable resolution may need a case + - `src/virtual_machine.cc` — evaluation dispatch + - `src/encoding.cc` — serialization in `to_key()` + - `src/opblock.cc` — lowering to opcodes in `term_to_opblock()` + +#### Perspective 3: Well-formedness Chain + +How do WF definitions need to change? + +1. Trace the WF chain from the first affected pass to the last +2. Identify which node types need to be added, modified, or removed at each stage +3. Verify that WF changes are **incremental** — each definition extends the previous +4. Check that no downstream pass is broken by the WF changes + +#### Perspective 4: Test Strategy + +How will you verify correctness at each stage? + +1. 
**YAML test cases** — Write expected input/output pairs in `tests/regocpp.yaml` or `tests/bugs.yaml` +2. **OPA conformance tests** — Identify which OPA test subdirectories exercise the feature +3. **Generative testing** — Plan to run the Trieste `test` command to check WF validity +4. **Incremental verification** — After each pass modification, run targeted tests before proceeding + +### Step 2: Implementation Plan + +Based on the multi-planner analysis, create a sequenced implementation plan: + +1. **Order changes by pipeline stage** — parser first, then file-to-rego passes in order, then rego-to-bundle passes, then VM +2. **Implement one pass at a time** — never modify multiple passes simultaneously without testing between changes +3. **Write test cases first** — add YAML test cases for the feature before implementing, so you can verify each step +4. **Use smallest possible passes** — prefer adding a new small pass over making an existing pass more complex (Trieste philosophy: "there is no downside to having many passes") + +### Step 3: Incremental Implementation + +For each pass change: + +1. **Read the current pass code** and its surrounding WF definitions +2. **Modify the WF definition** for the pass output if needed (define new node shapes) +3. **Add rewrite rules** using the pattern → effect DSL: + ```cpp + // Standard pattern: match context, capture nodes, produce replacement + In(ParentType) * T(NodeType)[Capture] >> [](Match& _) { + return NewNode << _(Capture); + }, + ``` +4. **Add error rules** for invalid inputs the WF would allow: + ```cpp + // Catch-all for malformed nodes (order matters — put after positive rules) + T(BadNode)[Node] >> [](Match& _) { + return err(_(Node), "descriptive error message"); + }, + ``` +5. **Run targeted tests** immediately: + ```bash + # Run specific test case + ./build/tests/rego_test -wf tests/regocpp.yaml + # Or specific OPA subdirectory + ./build/tests/rego_test -wf opa/v1/test/cases/testdata/v1/ + ``` +6. 
**Dump the AST** to verify the transformation: + ```bash + ./build/tools/rego eval --dump_passes .copilot/pass-debug/ '' + ``` + +### Step 4: Validation + +After all passes are implemented: + +1. **Run the full rego-cpp test suite**: + ```bash + ctest --test-dir build -R "rego_test_regocpp|rego_test_bugs|rego_test_cts|rego_test_cpp_api" + ``` +2. **Run OPA conformance tests** (if applicable): + ```bash + ctest --test-dir build -R rego_test_opa --output-on-failure + ``` +3. **Run generative testing** to check WF validity: + ```bash + ./build/tools/rego test -f -c 1000 + ``` +4. **Run with AddressSanitizer** for memory safety: + ```bash + cmake --preset asan-clang && ninja -C build-asan && ctest --test-dir build-asan + ``` + +## Key Patterns Reference + +### PassDef Structure + +```cpp +PassDef my_pass() +{ + return { + "my_pass", // Name (for debugging/logging) + wf_my_pass, // Output well-formedness definition + dir::bottomup | dir::once, // Traversal: topdown/bottomup, once/fixpoint + { + // Rules (matched in order, first match wins) + In(Parent) * T(Child)[C] >> [](Match& _) { return _(C); }, + } + }; +} +``` + +### Traversal Directions + +| Direction | Meaning | +|-----------|---------| +| `dir::bottomup` | Process children before parents | +| `dir::topdown` | Process parents before children | +| `dir::once` | Single traversal (combine with above) | +| *(no once)* | Repeat until fixpoint (no rules match) | + +### Pattern DSL Quick Reference + +| Pattern | Meaning | +|---------|---------| +| `T(Foo)` | Match a node of type `Foo` | +| `T(Foo)[X]` | Match `Foo`, bind to variable `X` | +| `T(Foo) / T(Bar)` | Match `Foo` or `Bar` | +| `A * B` | Match `A` followed by `B` (siblings) | +| `P << C` | Match children `C` inside parent `P` | +| `In(P)` | Parent context is `P` (not part of match) | +| `Any` | Match any single node | +| `Any++[X]` | Match one or more remaining nodes, bind to `X` | +| `End` | Assert no more siblings | +| `_(X)` | In effect: get single node 
bound to `X` | +| `_[X]` | In effect: get all nodes bound to `X` (NodeRange) | +| `*_[X]` | In effect: get children of nodes bound to `X` | + +### Well-formedness DSL + +```cpp +inline const auto wf_my_pass = + wf_previous_pass // Inherit from previous pass + | (NewNode <<= ChildA * ChildB) // NewNode has exactly ChildA then ChildB + | (Container <<= Element++) // Container has 0+ Elements + | (Container <<= Element++[1]) // Container has 1+ Elements + | (Wrapper <<= (ChoiceA | ChoiceB)) // Wrapper has one of ChoiceA or ChoiceB + | (Parent <<= Name * Body)[Name] // [Name] = Name is stored in symbol table + ; +``` + +### Creating AST Nodes + +```cpp +// Node with children +NewNode << child1 << child2 + +// Node with string content (location) +TokenType ^ "string content" + +// Splice children from a matched range +Container << *_[MatchVar] // all children of matched nodes +Container << _[MatchVar] // all matched nodes themselves + +// Empty node (remove from tree) +return {}; +``` + +## Common Mistakes + +1. **Not reading the WF definition first** — The #1 source of bugs. Nodes are wrapped in unexpected ways. +2. **Modifying multiple passes without testing between** — Errors compound and become impossible to diagnose. +3. **Comparing `child->type()` directly** — Use `unwrap()` helpers; nodes are wrapped in Term/Scalar layers. +4. **Forgetting error rules** — Generative testing will generate inputs that your positive rules don't handle. You must add error rules for these cases. +5. **Wrong traversal direction** — `bottomup` processes children first (useful when collapsing); `topdown` processes parents first (useful when pushing structure down). +6. **Rule ordering** — Rules are matched in order. If a general rule comes before a specific one, the specific rule will never fire. +7. **Missing `dir::once`** — Without it, the pass runs to fixpoint. This is correct for most passes but causes infinite loops if rules don't converge. +8. 
**Creating parallel paths instead of reusing the standard pipeline** — When adding a new compound node type (e.g., `TemplateString`), prefer routing its sub-expressions through the existing `Group → Literal → Expr` pipeline rather than creating a custom parallel path (e.g., `TemplateString <<= (TemplateLiteral | Expr)++`). The standard pipeline already handles `with`/`as`, `some`, comprehensions, and other features. Creating a parallel path means manually replicating all of that machinery. In the parser, use `m.term()` to separate groups naturally and `m.in(NodeType)` to detect context on closing delimiters, rather than `m.push(Brace)` which creates a separate nesting scope. Convert specialized tokens (e.g., `TemplateLiteral`) to standard types (e.g., `Scalar << String << JSONString`) as early as possible (in the `prep` pass) to minimize WF cascading. +9. **Not auditing `dependency_graph.cc` when adding new Term alternatives** — The dependency graph in `src/dependency_graph.cc` has explicit `if (lhs == Type)` cases for every node type that can appear as a Term child. When adding a new Term alternative, you must add a corresponding case there. Missing cases cause "Unable to unify due to cycle" errors. Also audit `resolver.cc` and `virtual_machine.cc` for similar type-dispatch patterns. diff --git a/.github/skills/trieste-dev/references/pass-pipeline.md b/.github/skills/trieste-dev/references/pass-pipeline.md new file mode 100644 index 00000000..a66d0e35 --- /dev/null +++ b/.github/skills/trieste-dev/references/pass-pipeline.md @@ -0,0 +1,75 @@ +# rego-cpp Pass Pipeline Reference + +The rego-cpp compiler has two main pipelines, each composed of sequential Trieste passes. 
+ +## Pipeline 1: File-to-Rego (Parsing → Structured AST) + +**Source**: `src/file_to_rego.cc` +**Input**: Raw Rego source text +**Output**: Structured Rego AST conforming to `wf` (defined in `include/rego/rego.hh`) + +| # | Pass Name | Direction | Purpose | +|---|-----------|-----------|---------| +| 1 | `prep` | bottomup, once | Token preparation from parse tree: organize raw tokens into initial structure | +| 2 | `some_every` | bottomup, once | Extract `some` and `every` declarations from token groups | +| 3 | `ref_args` | bottomup, once | Process reference bracket/dot arguments into RefArgBrack/RefArgDot nodes | +| 4 | `refs` | bottomup, once | Build Ref and RefTerm expressions from tokens | +| 5 | `groups` | bottomup, once | Group tokens into Array, Object, Set collections | +| 6 | `terms` | bottomup, once | Extract Term nodes from expressions | +| 7 | `unary` | bottomup, once | Handle unary minus/negation operators | +| 8 | `arithbin_first` | bottomup, fixpoint | First-precedence operators: ×, ÷, % (multiply/divide/modulo) | +| 9 | `arithbin_second` | bottomup, fixpoint | Second-precedence operators: +, − (add/subtract) | +| 10 | `comparison` | bottomup, once | Comparison operators: ==, !=, <, >, <=, >= | +| 11 | `membership` | bottomup, once | Membership/containment: `in` operator | +| 12 | `assign` | bottomup, once | Assignment operators: `:=` and `=` unification | +| 13 | `else_not` | bottomup, once | Process `else` and `not` keywords | +| 14 | `collections` | bottomup, once | Array/object/set comprehension construction | +| 15 | `lines` | topdown, once | Statement line boundary detection | +| 16 | `rules` | bottomup, once | Rule head/body extraction and structuring | +| 17 | `literals` | bottomup, once | Literal formation from values | +| 18 | `structure` | bottomup, once | Final module structure assembly | + +### Operator Precedence Passes (8–12) + +Passes 8–12 implement operator precedence via the Trieste multi-pass approach: +- **Higher precedence 
first**: `arithbin_first` (×÷%) runs before `arithbin_second` (+−) +- This naturally produces correct binary tree nesting without explicit precedence tables +- The same pattern from the Trieste infix tutorial: separate passes per precedence level + +## Pipeline 2: Rego-to-Bundle (Structured AST → Executable Bytecode) + +**Source**: `src/rego_to_bundle.cc` +**Input**: Structured Rego AST conforming to `wf_bundle_input` (defined in `src/internal.hh`) +**Output**: Executable bundle conforming to `wf_bundle` (defined in `include/rego/rego.hh`) + +| # | Pass Name | Direction | Purpose | +|---|-----------|-----------|---------| +| 1 | `refheads` | bottomup, once | Rule head reference processing | +| 2 | `rules` | topdown, once | Symbol table population for rules and modules | +| 3 | `locals` | bottomup, once | Local variable identification and scoping | +| 4 | `implicit_scans` | bottomup, once | Implicit iteration discovery (e.g., iterating over sets/arrays) | +| 5 | `merge` | topdown, once | Virtual document hierarchy merging | +| 6 | `unify` | bottomup, fixpoint | Unification: convert expressions to assignments/equality tests using dependency graph | +| 7 | `expr_to_opblock` | bottomup, once | Convert high-level expressions to operation blocks (bytecode-like) | +| 8 | `lift_functions` | bottomup, once | Lift rules to callable functions | +| 9 | `with_rules` | bottomup, once | Handle `with` statement rewriting and function reification | +| 10 | `add_plans` | topdown, once | Generate execution plans for entrypoints | +| 11 | `index_strings_locals` | topdown, once | Index string constants and local variables for the VM | + +## Well-formedness Chain + +``` +wf_parser (parse.cc output) + → wf_prep → wf_some_every → wf_ref_args → ... → wf_structure + = wf (rego.hh, the "Rego source" grammar) + +wf_bundle_input (internal.hh, starting point for pipeline 2) + → wf_refheads → wf_rules → wf_locals → ... 
→ wf_index_strings_locals + = wf_bundle (rego.hh, the executable format) +``` + +Each pass's WF extends the previous with `|` (adding new node types) or replaces entries (changing node structure). WF definitions are defined **inline** in the same file as the pass, near the `PassDef` function. + +## VM Execution + +After both pipelines complete, the resulting bundle AST is executed by the virtual machine in `src/virtual_machine.cc`. The VM interprets the operation blocks produced by `expr_to_opblock` and subsequent passes. diff --git a/.github/workflows/pr_gate.yml b/.github/workflows/pr_gate.yml index 6d04e31b..b3f4e3c9 100644 --- a/.github/workflows/pr_gate.yml +++ b/.github/workflows/pr_gate.yml @@ -86,7 +86,7 @@ jobs: sudo apt-get install ninja-build - name: CMake config - run: cmake -B ${{github.workspace}}/build --preset release-clang-opa -DREGOCPP_SANITIZE=address + run: cmake -B ${{github.workspace}}/build --preset release-clang-opa -DREGOCPP_SANITIZE=address,undefined - name: CMake build working-directory: ${{github.workspace}}/build @@ -227,7 +227,7 @@ jobs: - name: CMake config run: | - cmake -B ${{github.workspace}}/build --preset release-opa + cmake -B ${{github.workspace}}/build --preset release-windows-opa - name: CMake build working-directory: ${{github.workspace}}/build diff --git a/.gitignore b/.gitignore index fe57593a..1fb52ec6 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ build_* .vscode .cache .env -.python-version \ No newline at end of file +.python-version +.copilot diff --git a/CHANGELOG b/CHANGELOG index 88f3131c..e5b21250 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,5 +1,68 @@ # Changelog +## 2026-03-25 - Version 1.3.0 +Minor version upgrading OPA/Rego compatibility with new language features, +crypto/JWT support, and build system improvements. + +**New Features** +- OPA Rego compatibility upgraded from v1.8.0 to v1.14.1. 
+- Added template string interpolation support (`$"Hello {name}!"` and + `` $`raw {expr}` ``), including the `internal.template_string` built-in. +- Added `array.flatten` built-in. +- Added `crypto.*` built-in family: hashing (MD5, SHA1, SHA256), HMAC + (MD5, SHA1, SHA256, SHA512), X.509 certificate parsing and verification, + RSA key parsing, and key pair parsing. + `crypto.x509.parse_and_verify_certificates_with_options` is not yet + implemented (no OPA conformance tests exist for it). +- Added `io.jwt.*` built-in family: `decode`, `decode_verify`, `encode_sign`, + `encode_sign_raw`, and signature verification for HS256/384/512, + RS256/384/512, PS256/384/512, ES256/384/512, and EdDSA. + EdDSA is only available with the OpenSSL backend. +- Added pluggable crypto backend architecture controlled by the + `REGOCPP_CRYPTO_BACKEND` CMake option: + - `mbedtls` (default) — Mbed TLS v3.6.2, built from source via + FetchContent with zero system dependencies on any platform. + - `openssl3` — OpenSSL 3.0+ (requires system install). + - `bcrypt` — Windows CNG (Windows only, no external dependencies). + - `""` — Crypto disabled; crypto/JWT builtins return an error at runtime. +- Added Windows CMake presets (`debug-windows`, `release-windows`, + `debug-windows-opa`, `release-windows-opa`) using the `bcrypt` backend. +- Wrapper builds (Python, Rust, .NET) now pass the crypto backend through to + CMake. Python and .NET use `bcrypt` on Windows and `mbedtls` elsewhere; + Rust uses `mbedtls` on all platforms. + +**Bug Fixes** +- Fixed `numbers.range_step` behavior to match current OPA expectations. +- Fixed `strings.count` with empty substring to return `len(s)+1` instead of + looping indefinitely, matching OPA semantics. +- Fixed `split` with empty delimiter to split into individual characters, + matching OPA semantics. +- Fixed JSON object key deduplication to use last-wins semantics, matching + Go `json.Unmarshal` and OPA behavior. 
+- Fixed `sprintf` `%v` format to render sets using Rego display syntax + (`{1, 2, 3}` / `set()`) instead of internal angle-bracket representation. +- Fixed `to_json`/`to_key` rendering of `true`, `false`, and `null` for + synthetically constructed AST nodes with empty locations. + +**Migration Notes** +- JSON objects with duplicate keys now keep only the last value for each key + ("last-wins" semantics), matching Go `json.Unmarshal` and OPA behavior. + Previously, duplicate keys were preserved in the AST. If your data documents + or inputs contain duplicate keys and you relied on earlier values being + visible, those values will now be silently dropped. +- `crypto.x509.parse_and_verify_certificates` follows OPA's convention: the + last certificate in the PEM bundle is treated as the leaf (workload) + certificate; all others are treated as CA or intermediate certificates. + Revocation checking (CRL/OCSP) is not performed, matching OPA behavior. + +**Build & Infrastructure** +- Upgraded Trieste dependency (switched regex engine from RE2 to TRegex). + Validated against full OPA conformance test suite including regex patterns. +- Removed RE2 from link targets across all build configurations and wrappers. +- CI: Windows PR gate job now uses `release-windows-opa` preset. +- Added test infrastructure for marking tests as `unsupported` (used for + EdDSA tests on non-OpenSSL backends). + ## 2026-01-08 - Version 1.2.0 Minor version fixing some bugs. 
diff --git a/CMakeLists.txt b/CMakeLists.txt index 3fbea2fb..e46eb86e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,7 @@ list( GET REGOCPP_VERSION_LIST 1 REGOCPP_VERSION_MINOR ) list( GET REGOCPP_VERSION_LIST 2 REGOCPP_VERSION_REVISION ) -set ( REGOCPP_OPA_VERSION 1.8.0 ) +set ( REGOCPP_OPA_VERSION 1.15.1 ) set( REGOCPP_VERSION ${REGOCPP_VERSION_MAJOR}.${REGOCPP_VERSION_MINOR}.${REGOCPP_VERSION_REVISION} ) @@ -72,6 +72,15 @@ option(REGOCPP_CLEAN_INSTALL "Whether the install directory should be cleaned be set(REGOCPP_SANITIZE "" CACHE STRING "Argument to pass to sanitize (disabled by default)") option(REGOCPP_USE_SNMALLOC "Whether to use snmalloc for memory allocation" ON) +set(REGOCPP_CRYPTO_BACKEND "mbedtls" CACHE STRING + "Crypto backend for crypto/JWT builtins. Options: mbedtls, openssl3, bcrypt, '' (disabled)") +set_property(CACHE REGOCPP_CRYPTO_BACKEND PROPERTY STRINGS "" "mbedtls" "openssl3" "bcrypt") + +if(REGOCPP_CRYPTO_BACKEND AND NOT REGOCPP_CRYPTO_BACKEND MATCHES "^(mbedtls|openssl3|bcrypt)$") + message(FATAL_ERROR + "Invalid REGOCPP_CRYPTO_BACKEND='${REGOCPP_CRYPTO_BACKEND}'." + " Valid values: mbedtls, openssl3, bcrypt, or '' (empty to disable crypto).") +endif() set(CMAKE_CXX_STANDARD 20) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -99,7 +108,7 @@ FetchContent_Declare( FetchContent_Declare( trieste GIT_REPOSITORY https://github.com/microsoft/trieste - GIT_TAG main + GIT_TAG 32c3069913245eb8622cb82e47ad2523c7d23eea ) FetchContent_MakeAvailable(cmake_utils) @@ -112,6 +121,24 @@ set(TRIESTE_USE_SNMALLOC ${REGOCPP_USE_SNMALLOC}) FetchContent_MakeAvailable_ExcludeFromAll(trieste) +if(REGOCPP_CRYPTO_BACKEND STREQUAL "mbedtls") + set(ENABLE_TESTING OFF CACHE BOOL "" FORCE) + set(ENABLE_PROGRAMS OFF CACHE BOOL "" FORCE) + set(MBEDTLS_AS_SUBPROJECT ON CACHE BOOL "" FORCE) + # Build mbedTLS with PIC so its static libraries can be linked into rego_shared. 
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON) + + # Mbed TLS v3.6.2 + FetchContent_Declare( + mbedtls + GIT_REPOSITORY https://github.com/Mbed-TLS/mbedtls + GIT_TAG 0c7704b4f231fc62ad261e18d32677165a8d14d5 + GIT_SHALLOW TRUE + ) + + FetchContent_MakeAvailable_ExcludeFromAll(mbedtls) +endif() + find_program(CLANG_FORMAT NAMES clang-format-10 clang-format-14 clang-format-18 ) string(COMPARE EQUAL ${CLANG_FORMAT} "CLANG_FORMAT-NOTFOUND" CLANG_FORMAT_NOT_FOUND) @@ -193,7 +220,7 @@ endif() set(INSTALL_CONFIGDIR cmake) set(INSTALL_LIBDIR lib) set(INSTALL_INCLUDEDIR include) -set(REGOCPP_INSTALL_TARGETS rego trieste json yaml snmalloc re2) +set(REGOCPP_INSTALL_TARGETS rego trieste json yaml snmalloc) if(REGOCPP_BUILD_SHARED) list(APPEND REGOCPP_INSTALL_TARGETS rego_shared) @@ -203,6 +230,10 @@ if(TRIESTE_USE_SNMALLOC) list(APPEND REGOCPP_INSTALL_TARGETS snmalloc-new-override) endif() +if(REGOCPP_CRYPTO_BACKEND STREQUAL "mbedtls") + list(APPEND REGOCPP_INSTALL_TARGETS mbedtls mbedcrypto mbedx509 everest p256m) +endif() + install(TARGETS ${REGOCPP_INSTALL_TARGETS} EXPORT ${PROJECT_NAME}_Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/CMakePresets.json b/CMakePresets.json index f7dbe60d..2347ce2b 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -12,7 +12,8 @@ "CMAKE_CXX_COMPILER": "clang++", "REGOCPP_BUILD_TESTS": "ON", "REGOCPP_BUILD_TOOLS": "ON", - "REGOCPP_COPY_EXAMPLES": "ON" + "REGOCPP_COPY_EXAMPLES": "ON", + "REGOCPP_CRYPTO_BACKEND": "mbedtls" } }, { @@ -27,7 +28,8 @@ "REGOCPP_BUILD_TESTS": "ON", "REGOCPP_BUILD_TOOLS": "ON", "REGOCPP_COPY_EXAMPLES": "ON", - "REGOCPP_OPA_TESTS": "ON" + "REGOCPP_OPA_TESTS": "ON", + "REGOCPP_CRYPTO_BACKEND": "mbedtls" } }, { @@ -39,7 +41,8 @@ "CMAKE_INSTALL_PREFIX": "${sourceDir}/build/dist", "REGOCPP_BUILD_TESTS": "ON", "REGOCPP_BUILD_TOOLS": "ON", - "REGOCPP_COPY_EXAMPLES": "ON" + "REGOCPP_COPY_EXAMPLES": "ON", + "REGOCPP_CRYPTO_BACKEND": "mbedtls" } }, { @@ -52,7 +55,8 @@ "REGOCPP_BUILD_TESTS": "ON", 
"REGOCPP_BUILD_TOOLS": "ON", "REGOCPP_COPY_EXAMPLES": "ON", - "REGOCPP_OPA_TESTS": "ON" + "REGOCPP_OPA_TESTS": "ON", + "REGOCPP_CRYPTO_BACKEND": "mbedtls" } }, { @@ -66,7 +70,8 @@ "CMAKE_CXX_COMPILER": "clang++", "REGOCPP_BUILD_TESTS": "ON", "REGOCPP_BUILD_TOOLS": "ON", - "REGOCPP_COPY_EXAMPLES": "ON" + "REGOCPP_COPY_EXAMPLES": "ON", + "REGOCPP_CRYPTO_BACKEND": "mbedtls" } }, { @@ -81,7 +86,8 @@ "REGOCPP_BUILD_TESTS": "ON", "REGOCPP_BUILD_TOOLS": "ON", "REGOCPP_COPY_EXAMPLES": "ON", - "REGOCPP_OPA_TESTS": "ON" + "REGOCPP_OPA_TESTS": "ON", + "REGOCPP_CRYPTO_BACKEND": "mbedtls" } }, { @@ -93,7 +99,8 @@ "CMAKE_INSTALL_PREFIX": "${sourceDir}/build/dist", "REGOCPP_BUILD_TESTS": "ON", "REGOCPP_BUILD_TOOLS": "ON", - "REGOCPP_COPY_EXAMPLES": "ON" + "REGOCPP_COPY_EXAMPLES": "ON", + "REGOCPP_CRYPTO_BACKEND": "mbedtls" } }, { @@ -106,7 +113,62 @@ "REGOCPP_BUILD_TESTS": "ON", "REGOCPP_BUILD_TOOLS": "ON", "REGOCPP_COPY_EXAMPLES": "ON", - "REGOCPP_OPA_TESTS": "ON" + "REGOCPP_OPA_TESTS": "ON", + "REGOCPP_CRYPTO_BACKEND": "mbedtls" + } + }, + { + "name": "debug-windows", + "displayName": "Debug Build for Windows", + "description": "Sets up a debug build using the Windows bcrypt crypto backend", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/build/dist", + "REGOCPP_BUILD_TESTS": "ON", + "REGOCPP_BUILD_TOOLS": "ON", + "REGOCPP_COPY_EXAMPLES": "ON", + "REGOCPP_CRYPTO_BACKEND": "bcrypt" + } + }, + { + "name": "debug-windows-opa", + "displayName": "Debug Build for Windows + OPA Tests", + "description": "Sets up a debug build using the Windows bcrypt crypto backend and includes the OPA tests", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/build/dist", + "REGOCPP_BUILD_TESTS": "ON", + "REGOCPP_BUILD_TOOLS": "ON", + "REGOCPP_COPY_EXAMPLES": "ON", + "REGOCPP_OPA_TESTS": "ON", + "REGOCPP_CRYPTO_BACKEND": "bcrypt" + } + }, + { + "name": "release-windows", + "displayName": "Release Build for Windows", 
+ "description": "Sets up a release build using the Windows bcrypt crypto backend", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/build/dist", + "REGOCPP_BUILD_TESTS": "ON", + "REGOCPP_BUILD_TOOLS": "ON", + "REGOCPP_COPY_EXAMPLES": "ON", + "REGOCPP_CRYPTO_BACKEND": "bcrypt" + } + }, + { + "name": "release-windows-opa", + "displayName": "Release Build for Windows + OPA Tests", + "description": "Sets up a release build using the Windows bcrypt crypto backend and includes the OPA tests", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/build/dist", + "REGOCPP_BUILD_TESTS": "ON", + "REGOCPP_BUILD_TOOLS": "ON", + "REGOCPP_COPY_EXAMPLES": "ON", + "REGOCPP_OPA_TESTS": "ON", + "REGOCPP_CRYPTO_BACKEND": "bcrypt" } } ] diff --git a/README.md b/README.md index 5649a241..68bb3bbe 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ library from different langauages. ## Language Support {#language-support} -We support v1.8.0 of Rego as defined by OPA, with the following grammar: +We support v1.14.1 of Rego as defined by OPA, with the following grammar: ```ebnf module = package { import } policy @@ -111,7 +111,9 @@ ref-arg-brack = "[" ( scalar | var | array | object | set | "_" ) "]" ref-arg-dot = "." 
var var = ( ALPHA | "_" ) { ALPHA | DIGIT | "_" } scalar = string | NUMBER | TRUE | FALSE | NULL -string = STRING | raw-string +string = STRING | raw-string | template-string +template-string = "$" ( '"' { CHAR-'"' | template-expr } '"' | "`" { CHAR-"`" | template-expr } "`" ) +template-expr = "{" ( ref | var | scalar | array | object | set | array-compr | object-compr | set-compr | expr-call | expr-infix | expr-parens | unary-expr ) "}" raw-string = "`" { CHAR-"`" } "`" array = "[" term { "," term } "]" object = "{" object-item { "," object-item } "}" @@ -143,22 +145,33 @@ LF Line Feed We support the majority of the standard Rego built-ins, and provide a robust mechanism for including custom built-ins (via the CPP API). The following builtins -are NOT supported at present, though some are scheduled for future releases. +are NOT supported at present: -- `providers.aws.sign_req` - Not planned -- `crypto.*` - Currently slated to be released in v1.2.0 +- `crypto.x509.parse_and_verify_certificates_with_options` - Not yet implemented (no OPA conformance tests available) - `glob.*` - Not planned - `graphql.*` - Not planned - `http.send` - Not planned - `json.match_schema`/`json.verify_schema` - Not planned -- `jwt.*` - Currently slated to be released in v1.3.0 - `net.*` - Not planned +- `providers.aws.sign_req` - Not planned - `regex.globs_match` - Not planned - `rego.metadata.chain`/`rego.metadata.rule`/`rego.parse_module` - Not planned - `strings.render_template` - Not planned - `time` - This is entirely platform dependent at the moment, depending on whether there is a compiler on that platform which supports `__cpp_lib_chrono >= 201907L`. 
+#### Crypto and JWT Builtins + +The `crypto.*` and `io.jwt.*` builtins require a platform crypto backend, controlled +by the `REGOCPP_CRYPTO_BACKEND` CMake option: + +- `mbedtls` (default) — Mbed TLS, built from source via FetchContent (all platforms) +- `openssl3` — OpenSSL 3.0+ (requires system install) +- `bcrypt` — Windows CNG (Windows only) +- `""` (empty) — Crypto disabled; crypto/JWT builtins return an error at runtime + +All presets enable crypto automatically. The `-windows` and `-windows-opa` presets +use the `bcrypt` backend; all others use `mbedtls`. ### Compatibility with the OPA Rego Go implementation @@ -170,6 +183,7 @@ the non-builtin specific test suites, which we clone from the To build with the OPA tests available for testing, use one of the following presets: - `release-clang-opa` - `release-opa` +- `release-windows-opa` ## Contributing diff --git a/VERSION b/VERSION index 867e5243..589268e6 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.2.0 \ No newline at end of file +1.3.0 \ No newline at end of file diff --git a/include/rego/rego.hh b/include/rego/rego.hh index 946a2db5..96719e20 100644 --- a/include/rego/rego.hh +++ b/include/rego/rego.hh @@ -60,6 +60,9 @@ namespace rego inline const auto ObjectItem = TokenDef("rego-objectitem"); inline const auto RawString = TokenDef("rego-rawstring", flag::print); inline const auto JSONString = TokenDef("rego-STRING", flag::print); + inline const auto TemplateString = TokenDef("rego-templatestring"); + inline const auto TemplateLiteral = + TokenDef("rego-templateliteral", flag::print); inline const auto Int = TokenDef("rego-INT", flag::print); inline const auto Float = TokenDef("rego-FLOAT", flag::print); inline const auto True = TokenDef("rego-true"); @@ -177,6 +180,7 @@ namespace rego | (UnaryExpr <<= Expr) | (Membership <<= ExprSeq * Expr) | (Term <<= Ref | Var | Scalar | Array | Object | Set | Membership | ArrayCompr | ObjectCompr | SetCompr) + | (TemplateString <<= Literal++) | (ArrayCompr <<= 
Expr * Query) | (SetCompr <<= Expr * Query) | (ObjectCompr <<= Expr * Expr * Query) @@ -191,7 +195,7 @@ namespace rego | (RefArgBrack <<= Expr | Placeholder) | (RefArgDot <<= Var) | (Scalar <<= String | Int | Float | True | False | Null) - | (String <<= JSONString | RawString) + | (String <<= JSONString | RawString | TemplateString) | (Array <<= Expr++) | (Object <<= ObjectItem++) | (ObjectItem <<= (Key >>= Expr) * (Val >>= Expr)) @@ -1077,16 +1081,27 @@ namespace rego /// - The environment variables ("env") Node version(); + /// @brief Controls how sets are rendered by to_key(). + enum class SetFormat + { + /// Angle brackets: <1, 2, 3> (internal key representation). + Angle, + /// Square brackets: [1, 2, 3] (JSON-compatible array format). + Square, + /// Curly braces / set(): {1, 2, 3} or set() (OPA Rego display format). + Rego, + }; + /// @brief Converts a node to a unique key representation that can be used for /// comparison. /// @param node The node to convert. - /// @param set_as_array Whether to represent sets as arrays. + /// @param set_format How to render set values. /// @param sort_arrays Whether to sort array elements. /// @param list_delim The delimiter to use when joining array elements. /// @return The key representation of the node. 
std::string to_key( const trieste::Node& node, - bool set_as_array = false, + SetFormat set_format = SetFormat::Angle, bool sort_arrays = false, const char* list_delim = ","); @@ -1571,7 +1586,7 @@ namespace rego Bundle m_bundle; BuiltIns m_builtins; - RE2 m_int_regex; + TRegex m_int_regex; size_t m_stmt_limit; }; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 17ddf847..62cea964 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -48,6 +48,14 @@ builtins/units.cc builtins/uuid.cc ) +if(REGOCPP_CRYPTO_BACKEND STREQUAL "openssl3") + list(APPEND SOURCES builtins/crypto_openssl3.cc) +elseif(REGOCPP_CRYPTO_BACKEND STREQUAL "bcrypt") + list(APPEND SOURCES builtins/crypto_bcrypt.cc) +elseif(REGOCPP_CRYPTO_BACKEND STREQUAL "mbedtls") + list(APPEND SOURCES builtins/crypto_mbedtls.cc) +endif() + add_library(rego STATIC ${SOURCES}) add_library(regocpp::rego ALIAS rego) @@ -68,6 +76,18 @@ if(Threads_FOUND) target_link_libraries(rego PUBLIC Threads::Threads) endif() +if(REGOCPP_CRYPTO_BACKEND STREQUAL "openssl3") + find_package(OpenSSL 3.0 REQUIRED) + target_link_libraries(rego PUBLIC OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(rego PUBLIC REGOCPP_HAS_CRYPTO=1 REGOCPP_CRYPTO_OPENSSL3=1) +elseif(REGOCPP_CRYPTO_BACKEND STREQUAL "bcrypt") + target_link_libraries(rego PUBLIC bcrypt crypt32) + target_compile_definitions(rego PUBLIC REGOCPP_HAS_CRYPTO=1 REGOCPP_CRYPTO_BCRYPT=1) +elseif(REGOCPP_CRYPTO_BACKEND STREQUAL "mbedtls") + target_link_libraries(rego PUBLIC mbedtls mbedcrypto mbedx509) + target_compile_definitions(rego PUBLIC REGOCPP_HAS_CRYPTO=1 REGOCPP_CRYPTO_MBEDTLS=1) +endif() + if (REGOCPP_SANITIZE) target_compile_options(rego PUBLIC -g -fsanitize=${REGOCPP_SANITIZE} -fno-omit-frame-pointer) target_link_libraries(rego PUBLIC -fsanitize=${REGOCPP_SANITIZE}) @@ -95,7 +115,6 @@ if ( REGOCPP_BUILD_SHARED ) set_property(TARGET yaml PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET json PROPERTY POSITION_INDEPENDENT_CODE ON) 
set_property(TARGET snmalloc-new-override PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET re2 PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(rego_shared PUBLIC @@ -113,6 +132,17 @@ if ( REGOCPP_BUILD_SHARED ) target_link_libraries(rego_shared PUBLIC Threads::Threads) endif() + if(REGOCPP_CRYPTO_BACKEND STREQUAL "openssl3") + target_link_libraries(rego_shared PUBLIC OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(rego_shared PUBLIC REGOCPP_HAS_CRYPTO=1 REGOCPP_CRYPTO_OPENSSL3=1) + elseif(REGOCPP_CRYPTO_BACKEND STREQUAL "bcrypt") + target_link_libraries(rego_shared PUBLIC bcrypt crypt32) + target_compile_definitions(rego_shared PUBLIC REGOCPP_HAS_CRYPTO=1 REGOCPP_CRYPTO_BCRYPT=1) + elseif(REGOCPP_CRYPTO_BACKEND STREQUAL "mbedtls") + target_link_libraries(rego_shared PUBLIC mbedtls mbedcrypto mbedx509) + target_compile_definitions(rego_shared PUBLIC REGOCPP_HAS_CRYPTO=1 REGOCPP_CRYPTO_MBEDTLS=1) + endif() + target_compile_features(rego_shared PUBLIC cxx_std_20) target_include_directories( rego_shared diff --git a/src/builtins/array.cc b/src/builtins/array.cc index 921da251..b6761d34 100644 --- a/src/builtins/array.cc +++ b/src/builtins/array.cc @@ -74,7 +74,7 @@ namespace const Node reverse_decl = bi::Decl << (bi::ArgSeq << (bi::Arg << (bi::Name ^ "arr") - << (bi::Description ^ "the array to be reverse") + << (bi::Description ^ "the array to reverse") << (bi::Type << (bi::DynamicArray << (bi::Type << bi::Any))))) << (bi::Result @@ -155,7 +155,7 @@ namespace const Node slice_decl = bi::Decl << (bi::ArgSeq << (bi::Arg << (bi::Name ^ "arr") - << (bi::Description ^ "the array to be reverse") + << (bi::Description ^ "the array to slice") << (bi::Type << (bi::DynamicArray << (bi::Type << bi::Any)))) << (bi::Arg << (bi::Name ^ "start") @@ -176,6 +176,51 @@ namespace << (bi::DynamicArray << (bi::Type << bi::Any)))); return BuiltInDef::create({"array.slice"}, slice_decl, slice); } + + Node flatten(const Nodes& args) + { + Node arr = 
unwrap_arg(args, UnwrapOpt(0).func("array.flatten").type(Array)); + if (arr->type() == Error) + { + return arr; + } + + Node result = NodeDef::create(Array); + for (auto& child : *arr) + { + auto maybe_array = unwrap(child, Array); + if (maybe_array.success) + { + for (auto& inner : *maybe_array.node) + { + result->push_back(inner->clone()); + } + } + else + { + result->push_back(child->clone()); + } + } + return result; + } + + BuiltIn flatten_factory() + { + const Node flatten_decl = + bi::Decl << (bi::ArgSeq + << (bi::Arg + << (bi::Name ^ "arr") + << (bi::Description ^ "the array to flatten") + << (bi::Type + << (bi::DynamicArray << (bi::Type << bi::Any))))) + << (bi::Result + << (bi::Name ^ "output") + << (bi::Description ^ + "the flattened array, with all nested arrays inlined") + << (bi::Type + << (bi::DynamicArray << (bi::Type << bi::Any)))); + return BuiltInDef::create({"array.flatten"}, flatten_decl, flatten); + } } namespace rego @@ -190,6 +235,10 @@ namespace rego { return concat_factory(); } + if (view == "flatten") + { + return flatten_factory(); + } if (view == "reverse") { return reverse_factory(); diff --git a/src/builtins/core.cc b/src/builtins/core.cc index 0a91ff92..01ebf3c8 100644 --- a/src/builtins/core.cc +++ b/src/builtins/core.cc @@ -1477,6 +1477,19 @@ namespace Node array = NodeDef::create(Array); std::size_t start = 0; std::size_t pos = x_str.find(delimiter_str); + if (delimiter_str.size() == 0) + { + std::string_view x_view = x_str; + while (pos < x_str.size()) + { + auto [r, s] = utf8_to_rune(x_view.substr(pos), false); + array->push_back(JSONString ^ x_str.substr(pos, s.size())); + pos += s.size(); + } + + return array; + } + while (pos != x_str.npos) { array->push_back(JSONString ^ x_str.substr(start, pos - start)); @@ -1651,7 +1664,7 @@ namespace } else { - result << json::escape(to_key(node, false, false, ", ")); + result << json::escape(to_key(node, SetFormat::Rego, false, ", ")); } break; @@ -2267,6 +2280,13 @@ namespace size_t pos 
= 0; size_t count = 0; + size_t size = substring_str.size(); + if (size == 0) + { + // the empty string matches at every location + return Int ^ std::to_string(search_str.size() + 1); + } + while (pos < search_str.size()) { pos = search_str.find(substring_str, pos); diff --git a/src/builtins/crypto.cc b/src/builtins/crypto.cc index bb348591..328c6ce4 100644 --- a/src/builtins/crypto.cc +++ b/src/builtins/crypto.cc @@ -1,5 +1,41 @@ #include "builtins.hh" +#ifdef REGOCPP_HAS_CRYPTO +#include "crypto_core.hh" +#endif + +#if defined(_WIN32) +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include +#endif + +#include + +#ifdef REGOCPP_HAS_CRYPTO +namespace rego::crypto_core +{ + void secure_erase(std::string& s) + { +#if defined(_WIN32) + SecureZeroMemory(s.data(), s.size()); +#else + // volatile prevents the optimizer from removing the memset + volatile char* p = s.data(); + for (size_t i = 0; i < s.size(); ++i) + { + p[i] = 0; + } + // Compiler barrier: ensures the volatile writes above are not + // reordered or eliminated across this point. 
+ asm volatile("" : : "r"(s.data()) : "memory"); +#endif + s.clear(); + } +} // namespace rego::crypto_core +#endif + namespace { using namespace rego; @@ -7,6 +43,322 @@ namespace const char* Message = "Cryptography built-ins are not supported"; +#ifdef REGOCPP_HAS_CRYPTO + Node md5_impl(const Nodes& args) + { + Node x = unwrap_arg(args, UnwrapOpt(0).type(JSONString).func("crypto.md5")); + if (x->type() == Error) + return x; + return JSONString ^ crypto_core::md5_hex(get_string(x)); + } + + Node sha1_impl(const Nodes& args) + { + Node x = + unwrap_arg(args, UnwrapOpt(0).type(JSONString).func("crypto.sha1")); + if (x->type() == Error) + return x; + return JSONString ^ crypto_core::sha1_hex(get_string(x)); + } + + Node sha256_impl(const Nodes& args) + { + Node x = + unwrap_arg(args, UnwrapOpt(0).type(JSONString).func("crypto.sha256")); + if (x->type() == Error) + return x; + return JSONString ^ crypto_core::sha256_hex(get_string(x)); + } + + Node hmac_equal_impl(const Nodes& args) + { + Node mac1 = + unwrap_arg(args, UnwrapOpt(0).type(JSONString).func("crypto.hmac.equal")); + if (mac1->type() == Error) + return mac1; + Node mac2 = + unwrap_arg(args, UnwrapOpt(1).type(JSONString).func("crypto.hmac.equal")); + if (mac2->type() == Error) + return mac2; + if (crypto_core::hmac_equal(get_string(mac1), get_string(mac2))) + return True ^ "true"; + return False ^ "false"; + } + + Node hmac_md5_impl(const Nodes& args) + { + Node x = + unwrap_arg(args, UnwrapOpt(0).type(JSONString).func("crypto.hmac.md5")); + if (x->type() == Error) + return x; + Node key = + unwrap_arg(args, UnwrapOpt(1).type(JSONString).func("crypto.hmac.md5")); + if (key->type() == Error) + return key; + return JSONString ^ + crypto_core::hmac_md5_hex(get_string(key), get_string(x)); + } + + Node hmac_sha1_impl(const Nodes& args) + { + Node x = + unwrap_arg(args, UnwrapOpt(0).type(JSONString).func("crypto.hmac.sha1")); + if (x->type() == Error) + return x; + Node key = + unwrap_arg(args, 
UnwrapOpt(1).type(JSONString).func("crypto.hmac.sha1")); + if (key->type() == Error) + return key; + return JSONString ^ + crypto_core::hmac_sha1_hex(get_string(key), get_string(x)); + } + + Node hmac_sha256_impl(const Nodes& args) + { + Node x = unwrap_arg( + args, UnwrapOpt(0).type(JSONString).func("crypto.hmac.sha256")); + if (x->type() == Error) + return x; + Node key = unwrap_arg( + args, UnwrapOpt(1).type(JSONString).func("crypto.hmac.sha256")); + if (key->type() == Error) + return key; + return JSONString ^ + crypto_core::hmac_sha256_hex(get_string(key), get_string(x)); + } + + Node hmac_sha512_impl(const Nodes& args) + { + Node x = unwrap_arg( + args, UnwrapOpt(0).type(JSONString).func("crypto.hmac.sha512")); + if (x->type() == Error) + return x; + Node key = unwrap_arg( + args, UnwrapOpt(1).type(JSONString).func("crypto.hmac.sha512")); + if (key->type() == Error) + return key; + return JSONString ^ + crypto_core::hmac_sha512_hex(get_string(key), get_string(x)); + } + // Build an Array node from a dynamic vector of nodes + Node dynamic_array(const Nodes& items) + { + Node result = NodeDef::create(Array); + for (auto& item : items) + { + result << Resolver::to_term(item); + } + return result; + } + + // Convert a ParsedCertificate to a Rego object node + Node cert_to_rego_object(const crypto_core::ParsedCertificate& pc) + { + // Build Subject object + Node subject = object({object_item( + rego::string("CommonName"), rego::string(pc.subject.common_name))}); + + // DNSNames: null if empty, otherwise array of strings + Node dns; + if (pc.dns_names.empty()) + { + dns = null(); + } + else + { + Nodes dns_items; + for (auto& name : pc.dns_names) + { + dns_items.push_back(rego::string(name)); + } + dns = dynamic_array(dns_items); + } + + // URIStrings: null if empty, otherwise array of strings + Node uris; + if (pc.uri_strings.empty()) + { + uris = null(); + } + else + { + Nodes uri_items; + for (auto& uri : pc.uri_strings) + { + 
uri_items.push_back(rego::string(uri)); + } + uris = dynamic_array(uri_items); + } + + return object( + {object_item(rego::string("DNSNames"), dns), + object_item(rego::string("Subject"), subject), + object_item(rego::string("URIStrings"), uris)}); + } + + Node rsa_jwk_to_rego_object(const crypto_core::RSAPrivateKeyJWK& jwk) + { + return object( + {object_item(rego::string("d"), rego::string(jwk.d)), + object_item(rego::string("dp"), rego::string(jwk.dp)), + object_item(rego::string("dq"), rego::string(jwk.dq)), + object_item(rego::string("e"), rego::string(jwk.e)), + object_item(rego::string("kty"), rego::string(jwk.kty)), + object_item(rego::string("n"), rego::string(jwk.n)), + object_item(rego::string("p"), rego::string(jwk.p)), + object_item(rego::string("q"), rego::string(jwk.q)), + object_item(rego::string("qi"), rego::string(jwk.qi))}); + } + + Node parse_certificates_impl(const Nodes& args) + { + Node input = unwrap_arg( + args, + UnwrapOpt(0).type(JSONString).func("crypto.x509.parse_certificates")); + if (input->type() == Error) + return input; + + auto result = + crypto_core::parse_certificates(json::unescape(get_string(input))); + if (!result.error.empty()) + { + return err(args[0], result.error, EvalBuiltInError); + } + + Nodes cert_nodes; + for (auto& pc : result.certs) + { + cert_nodes.push_back(cert_to_rego_object(pc)); + } + return dynamic_array(cert_nodes); + } + + Node parse_and_verify_certificates_impl(const Nodes& args) + { + Node input = unwrap_arg( + args, + UnwrapOpt(0) + .type(JSONString) + .func("crypto.x509.parse_and_verify_certificates")); + if (input->type() == Error) + return input; + + auto result = crypto_core::parse_and_verify_certificates( + json::unescape(get_string(input))); + if (!result.error.empty()) + { + return err(args[0], result.error, EvalBuiltInError); + } + + Nodes cert_nodes; + for (auto& pc : result.certs) + { + cert_nodes.push_back(cert_to_rego_object(pc)); + } + return array({boolean(result.valid), 
dynamic_array(cert_nodes)}); + } + + Node parse_certificate_request_impl(const Nodes& args) + { + Node input = unwrap_arg( + args, + UnwrapOpt(0) + .type(JSONString) + .func("crypto.x509.parse_certificate_request")); + if (input->type() == Error) + return input; + + auto result = + crypto_core::parse_certificate_request(json::unescape(get_string(input))); + if (!result.error.empty()) + { + return err(args[0], result.error, EvalBuiltInError); + } + + Node subject = object({object_item( + rego::string("CommonName"), rego::string(result.subject.common_name))}); + return object({object_item(rego::string("Subject"), subject)}); + } + + Node parse_keypair_impl(const Nodes& args) + { + Node cert_input = unwrap_arg( + args, UnwrapOpt(0).type(JSONString).func("crypto.x509.parse_keypair")); + if (cert_input->type() == Error) + return cert_input; + Node key_input = unwrap_arg( + args, UnwrapOpt(1).type(JSONString).func("crypto.x509.parse_keypair")); + if (key_input->type() == Error) + return key_input; + + // Parse certificates + auto cert_result = + crypto_core::parse_certificates(json::unescape(get_string(cert_input))); + if (!cert_result.error.empty()) + { + return err(args[0], cert_result.error, EvalBuiltInError); + } + + // Parse the private key to verify it's valid + auto key_result = + crypto_core::parse_rsa_private_key(json::unescape(get_string(key_input))); + if (!key_result.error.empty()) + { + return err(args[1], key_result.error, EvalBuiltInError); + } + + // Build Certificate array of base64 DER strings + Nodes der_items; + for (auto& pc : cert_result.certs) + { + der_items.push_back(rego::string(pc.der_b64)); + } + + return object( + {object_item(rego::string("Certificate"), dynamic_array(der_items))}); + } + + Node parse_rsa_private_key_impl(const Nodes& args) + { + Node input = unwrap_arg( + args, + UnwrapOpt(0).type(JSONString).func("crypto.x509.parse_rsa_private_key")); + if (input->type() == Error) + return input; + + auto result = + 
crypto_core::parse_rsa_private_key(json::unescape(get_string(input))); + if (!result.error.empty()) + { + return err(args[0], result.error, EvalBuiltInError); + } + + return rsa_jwk_to_rego_object(result.key); + } + + Node parse_private_keys_impl(const Nodes& args) + { + Node input = unwrap_arg( + args, UnwrapOpt(0).type(JSONString).func("crypto.parse_private_keys")); + if (input->type() == Error) + return input; + + auto result = + crypto_core::parse_private_keys(json::unescape(get_string(input))); + if (result.is_empty_input) + { + return null(); + } + + Nodes key_nodes; + for (auto& jwk : result.keys) + { + key_nodes.push_back(rsa_jwk_to_rego_object(jwk)); + } + return dynamic_array(key_nodes); + } +#endif // REGOCPP_HAS_CRYPTO + namespace hmac { BuiltIn equal_factory() @@ -22,8 +374,13 @@ namespace << (bi::Description ^ "`true` if the MACs are equals, `false` otherwise") << (bi::Type << bi::Boolean)); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"crypto.hmac.equal"}, equal_decl, Message); +#else + return BuiltInDef::create( + {"crypto.hmac.equal"}, equal_decl, hmac_equal_impl); +#endif } BuiltIn md5_factory() @@ -38,7 +395,11 @@ namespace << (bi::Result << (bi::Name ^ "y") << (bi::Description ^ "MD5-HMAC of `x`") << (bi::Type << bi::String)); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder({"crypto.hmac.md5"}, md5_decl, Message); +#else + return BuiltInDef::create({"crypto.hmac.md5"}, md5_decl, hmac_md5_impl); +#endif } BuiltIn sha1_factory() @@ -53,7 +414,12 @@ namespace << (bi::Result << (bi::Name ^ "y") << (bi::Description ^ "SHA1-HMAC of `x`") << (bi::Type << bi::String)); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder({"crypto.hmac.sha1"}, sha1_decl, Message); +#else + return BuiltInDef::create( + {"crypto.hmac.sha1"}, sha1_decl, hmac_sha1_impl); +#endif } BuiltIn sha256_factory() @@ -68,8 +434,13 @@ namespace << (bi::Result << (bi::Name ^ "y") << (bi::Description ^ "SHA256-HMAC of `x`") << (bi::Type << bi::String)); 
+#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"crypto.hmac.sha256"}, sha256_decl, Message); +#else + return BuiltInDef::create( + {"crypto.hmac.sha256"}, sha256_decl, hmac_sha256_impl); +#endif } BuiltIn sha512_factory() @@ -84,8 +455,13 @@ namespace << (bi::Result << (bi::Name ^ "y") << (bi::Description ^ "SHA512-HMAC of `x`") << (bi::Type << bi::String)); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"crypto.hmac.sha512"}, sha512_decl, Message); +#else + return BuiltInDef::create( + {"crypto.hmac.sha512"}, sha512_decl, hmac_sha512_impl); +#endif } } // namespace hmac @@ -99,7 +475,11 @@ namespace << (bi::Result << (bi::Name ^ "y") << (bi::Description ^ "MD5-hash of `x`") << (bi::Type << bi::String)); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder({"crypto.md5"}, md5_decl, Message); +#else + return BuiltInDef::create({"crypto.md5"}, md5_decl, md5_impl); +#endif } BuiltIn parse_private_keys_factory() @@ -121,8 +501,15 @@ namespace << (bi::DynamicObject << (bi::Type << bi::String) << (bi::Type << bi::Any)))))); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"crypto.parse_private_keys"}, parse_private_keys_decl, Message); +#else + return BuiltInDef::create( + {"crypto.parse_private_keys"}, + parse_private_keys_decl, + parse_private_keys_impl); +#endif } BuiltIn sha1_factory() @@ -135,7 +522,11 @@ namespace << (bi::Result << (bi::Name ^ "y") << (bi::Description ^ "SHA1-hash of `x`") << (bi::Type << bi::String)); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder({"crypto.sha1"}, sha1_decl, Message); +#else + return BuiltInDef::create({"crypto.sha1"}, sha1_decl, sha1_impl); +#endif } BuiltIn sha256_factory() @@ -148,7 +539,11 @@ namespace << (bi::Result << (bi::Name ^ "y") << (bi::Description ^ "SHA256-hash of `x`") << (bi::Type << bi::String)); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder({"crypto.sha256"}, sha256_decl, Message); +#else + return BuiltInDef::create({"crypto.sha256"}, 
sha256_decl, sha256_impl); +#endif } namespace x509 @@ -183,10 +578,17 @@ namespace << (bi::DynamicObject << (bi::Type << bi::String) << (bi::Type << bi::Any)))))))); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"crypto.x509.parse_and_verify_certificates"}, parse_and_verify_certificates_decl, Message); +#else + return BuiltInDef::create( + {"crypto.x509.parse_and_verify_certificates"}, + parse_and_verify_certificates_decl, + parse_and_verify_certificates_impl); +#endif } BuiltIn parse_and_verify_certificates_with_options_factory() @@ -266,10 +668,17 @@ namespace << (bi::Type << (bi::DynamicObject << (bi::Type << bi::String) << (bi::Type << bi::Any)))); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"crypto.x509.parse_certificate_request"}, parse_certificate_request_decl, Message); +#else + return BuiltInDef::create( + {"crypto.x509.parse_certificate_request"}, + parse_certificate_request_decl, + parse_certificate_request_impl); +#endif } BuiltIn parse_certificates_factory() @@ -291,8 +700,15 @@ namespace << (bi::DynamicObject << (bi::Type << bi::String) << (bi::Type << bi::Any)))))); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"crypto.x509.parse_certificates"}, parse_certificates_decl, Message); +#else + return BuiltInDef::create( + {"crypto.x509.parse_certificates"}, + parse_certificates_decl, + parse_certificates_impl); +#endif } BuiltIn parse_keypair_factory() @@ -317,8 +733,13 @@ namespace << (bi::Type << (bi::DynamicObject << (bi::Type << bi::String) << (bi::Type << bi::Any)))); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"crypto.x509.parse_keypair"}, parse_keypair_decl, Message); +#else + return BuiltInDef::create( + {"crypto.x509.parse_keypair"}, parse_keypair_decl, parse_keypair_impl); +#endif } BuiltIn parse_rsa_private_key_factory() @@ -335,10 +756,17 @@ namespace << (bi::Type << (bi::DynamicObject << (bi::Type << bi::String) << (bi::Type << bi::Any)))); +#ifndef REGOCPP_HAS_CRYPTO return 
BuiltInDef::placeholder( {"crypto.x509.parse_rsa_private_key"}, parse_rsa_private_key_decl, Message); +#else + return BuiltInDef::create( + {"crypto.x509.parse_rsa_private_key"}, + parse_rsa_private_key_decl, + parse_rsa_private_key_impl); +#endif } } } diff --git a/src/builtins/crypto_bcrypt.cc b/src/builtins/crypto_bcrypt.cc new file mode 100644 index 00000000..439bc7f8 --- /dev/null +++ b/src/builtins/crypto_bcrypt.cc @@ -0,0 +1,2175 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#ifdef REGOCPP_CRYPTO_BCRYPT + +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX + +#include "base64/base64.h" +#include "crypto_core.hh" +#include "crypto_utils.hh" + +#include +#include +#include +#include +#include +#include +#include +#include + +// Link against Windows crypto libraries (also specified in CMakeLists.txt) +#pragma comment(lib, "bcrypt.lib") +#pragma comment(lib, "crypt32.lib") + +#include +#include +#include + +namespace +{ + using rego::crypto_core::to_hex; + + // Safe cast from size_t to ULONG, throwing on overflow. 
+ ULONG safe_ulong(size_t n) + { + if (n > static_cast(std::numeric_limits::max())) + { + throw std::runtime_error("data too large for BCrypt API"); + } + return static_cast(n); + } + + // RAII wrapper for BCRYPT_ALG_HANDLE + struct AlgHandleDeleter + { + void operator()(BCRYPT_ALG_HANDLE h) const + { + if (h) + { + BCryptCloseAlgorithmProvider(h, 0); + } + } + }; + using AlgHandle = std::unique_ptr; + + // RAII wrapper for BCRYPT_HASH_HANDLE + struct HashHandleDeleter + { + void operator()(BCRYPT_HASH_HANDLE h) const + { + if (h) + { + BCryptDestroyHash(h); + } + } + }; + using HashHandle = std::unique_ptr; + + // RAII wrapper for BCRYPT_KEY_HANDLE + struct KeyHandleDeleter + { + void operator()(BCRYPT_KEY_HANDLE h) const + { + if (h) + { + BCryptDestroyKey(h); + } + } + }; + using KeyHandle = std::unique_ptr; + + // Helper: compute a hash digest and return hex string + std::string digest_hex(LPCWSTR algo_id, std::string_view data) + { + BCRYPT_ALG_HANDLE alg_raw = nullptr; + NTSTATUS status = + BCryptOpenAlgorithmProvider(&alg_raw, algo_id, nullptr, 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptOpenAlgorithmProvider failed"); + } + AlgHandle alg(alg_raw); + + DWORD hash_len = 0; + DWORD result_len = 0; + status = BCryptGetProperty( + alg.get(), + BCRYPT_HASH_LENGTH, + reinterpret_cast(&hash_len), + sizeof(hash_len), + &result_len, + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptGetProperty(HASH_LENGTH) failed"); + } + + BCRYPT_HASH_HANDLE hash_raw = nullptr; + status = BCryptCreateHash(alg.get(), &hash_raw, nullptr, 0, nullptr, 0, 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptCreateHash failed"); + } + HashHandle hash(hash_raw); + + status = BCryptHashData( + hash.get(), + reinterpret_cast(const_cast(data.data())), + safe_ulong(data.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptHashData failed"); + } + + std::vector buf(hash_len); + status = 
BCryptFinishHash(hash.get(), buf.data(), hash_len, 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptFinishHash failed"); + } + + return to_hex(buf.data(), buf.size()); + } + + // Helper: compute HMAC and return hex string + std::string hmac_hex( + LPCWSTR algo_id, std::string_view key, std::string_view data) + { + BCRYPT_ALG_HANDLE alg_raw = nullptr; + NTSTATUS status = BCryptOpenAlgorithmProvider( + &alg_raw, algo_id, nullptr, BCRYPT_ALG_HANDLE_HMAC_FLAG); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptOpenAlgorithmProvider (HMAC) failed"); + } + AlgHandle alg(alg_raw); + + DWORD hash_len = 0; + DWORD result_len = 0; + status = BCryptGetProperty( + alg.get(), + BCRYPT_HASH_LENGTH, + reinterpret_cast(&hash_len), + sizeof(hash_len), + &result_len, + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptGetProperty(HASH_LENGTH) failed"); + } + + BCRYPT_HASH_HANDLE hash_raw = nullptr; + status = BCryptCreateHash( + alg.get(), + &hash_raw, + nullptr, + 0, + reinterpret_cast(const_cast(key.data())), + safe_ulong(key.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptCreateHash (HMAC) failed"); + } + HashHandle hash(hash_raw); + + status = BCryptHashData( + hash.get(), + reinterpret_cast(const_cast(data.data())), + safe_ulong(data.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptHashData failed"); + } + + std::vector buf(hash_len); + status = BCryptFinishHash(hash.get(), buf.data(), hash_len, 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptFinishHash (HMAC) failed"); + } + + return to_hex(buf.data(), buf.size()); + } + + // Helper: compute HMAC and return raw bytes (for signature verification) + std::vector hmac_raw( + LPCWSTR algo_id, std::string_view key, std::string_view data) + { + BCRYPT_ALG_HANDLE alg_raw = nullptr; + NTSTATUS status = BCryptOpenAlgorithmProvider( + &alg_raw, algo_id, nullptr, 
BCRYPT_ALG_HANDLE_HMAC_FLAG); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptOpenAlgorithmProvider (HMAC) failed"); + } + AlgHandle alg(alg_raw); + + DWORD hash_len = 0; + DWORD result_len = 0; + status = BCryptGetProperty( + alg.get(), + BCRYPT_HASH_LENGTH, + reinterpret_cast(&hash_len), + sizeof(hash_len), + &result_len, + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptGetProperty(HASH_LENGTH) failed"); + } + + BCRYPT_HASH_HANDLE hash_raw = nullptr; + status = BCryptCreateHash( + alg.get(), + &hash_raw, + nullptr, + 0, + reinterpret_cast(const_cast(key.data())), + safe_ulong(key.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptCreateHash (HMAC) failed"); + } + HashHandle hash(hash_raw); + + status = BCryptHashData( + hash.get(), + reinterpret_cast(const_cast(data.data())), + safe_ulong(data.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptHashData failed"); + } + + std::vector buf(hash_len); + status = BCryptFinishHash(hash.get(), buf.data(), hash_len, 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptFinishHash (HMAC) failed"); + } + + return buf; + } + + // Map Algorithm enum to BCrypt algorithm ID for HMAC + LPCWSTR hmac_algo_id(rego::crypto_core::Algorithm algo) + { + using rego::crypto_core::Algorithm; + switch (algo) + { + case Algorithm::HS256: + return BCRYPT_SHA256_ALGORITHM; + case Algorithm::HS384: + return BCRYPT_SHA384_ALGORITHM; + case Algorithm::HS512: + return BCRYPT_SHA512_ALGORITHM; + default: + return nullptr; + } + } + + bool verify_hmac( + rego::crypto_core::Algorithm algo, + std::string_view signing_input, + std::string_view sig_bytes, + std::string_view secret) + { + LPCWSTR algo_id = hmac_algo_id(algo); + if (!algo_id) + { + return false; + } + + auto computed = hmac_raw(algo_id, secret, signing_input); + + // Constant-time comparison + if (computed.size() != sig_bytes.size()) + { + return false; + } 
+ volatile unsigned char result = 0; + for (size_t i = 0; i < computed.size(); ++i) + { + result |= computed[i] ^ static_cast(sig_bytes[i]); + } + return result == 0; + } + + // ── RAII wrapper for PCCERT_CONTEXT ── + struct CertContextDeleter + { + void operator()(PCCERT_CONTEXT ctx) const + { + if (ctx) + { + CertFreeCertificateContext(ctx); + } + } + }; + using CertContextHandle = + std::unique_ptr; + + using rego::crypto_core::json_select_string; + using rego::crypto_core::parse_json; + + // ── PEM validation (error messages must match OPA exactly) ── + + // ── BCrypt hash algorithm ID for signature verification ── + + LPCWSTR hash_algo_id(rego::crypto_core::Algorithm algo) + { + using rego::crypto_core::Algorithm; + switch (algo) + { + case Algorithm::RS256: + case Algorithm::PS256: + case Algorithm::ES256: + case Algorithm::HS256: + return BCRYPT_SHA256_ALGORITHM; + case Algorithm::RS384: + case Algorithm::PS384: + case Algorithm::ES384: + case Algorithm::HS384: + return BCRYPT_SHA384_ALGORITHM; + case Algorithm::RS512: + case Algorithm::PS512: + case Algorithm::ES512: + case Algorithm::HS512: + return BCRYPT_SHA512_ALGORITHM; + case Algorithm::EdDSA: + return nullptr; + } + return nullptr; + } + + // ── Hash data for signature verification ── + + std::vector compute_hash( + LPCWSTR algo_id, std::string_view data) + { + BCRYPT_ALG_HANDLE alg_raw = nullptr; + NTSTATUS status = + BCryptOpenAlgorithmProvider(&alg_raw, algo_id, nullptr, 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptOpenAlgorithmProvider failed"); + } + AlgHandle alg(alg_raw); + + DWORD hash_len = 0; + DWORD result_len = 0; + status = BCryptGetProperty( + alg.get(), + BCRYPT_HASH_LENGTH, + reinterpret_cast(&hash_len), + sizeof(hash_len), + &result_len, + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptGetProperty(HASH_LENGTH) failed"); + } + + BCRYPT_HASH_HANDLE hash_raw = nullptr; + status = BCryptCreateHash(alg.get(), &hash_raw, nullptr, 0, 
nullptr, 0, 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptCreateHash failed"); + } + HashHandle hash(hash_raw); + + status = BCryptHashData( + hash.get(), + reinterpret_cast(const_cast(data.data())), + safe_ulong(data.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptHashData failed"); + } + + std::vector buf(hash_len); + status = BCryptFinishHash(hash.get(), buf.data(), hash_len, 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptFinishHash failed"); + } + + return buf; + } + + // ── Key loading helpers ── + + // Import a public key from a PEM certificate via WinCrypt + KeyHandle key_from_pem_cert(std::string_view pem) + { + using rego::crypto_core::extract_pem_der_blocks; + auto der_blocks = extract_pem_der_blocks(pem, "CERTIFICATE"); + if (der_blocks.empty()) + { + return KeyHandle(nullptr); + } + auto& der = der_blocks[0]; + + PCCERT_CONTEXT cert_ctx = CertCreateCertificateContext( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + reinterpret_cast(der.data()), + static_cast(der.size())); + if (!cert_ctx) + { + return KeyHandle(nullptr); + } + CertContextHandle cert_guard(cert_ctx); + + BCRYPT_KEY_HANDLE key_raw = nullptr; + if (!CryptImportPublicKeyInfoEx2( + X509_ASN_ENCODING, + &cert_ctx->pCertInfo->SubjectPublicKeyInfo, + 0, + nullptr, + &key_raw)) + { + return KeyHandle(nullptr); + } + return KeyHandle(key_raw); + } + + // Import a public key from a PEM SPKI (SubjectPublicKeyInfo) block + KeyHandle key_from_pem_pubkey(std::string_view pem) + { + using rego::crypto_core::extract_pem_der_blocks; + auto der_blocks = extract_pem_der_blocks(pem, "PUBLIC KEY"); + if (der_blocks.empty()) + { + return KeyHandle(nullptr); + } + auto& der = der_blocks[0]; + + // Decode the DER SPKI into a CERT_PUBLIC_KEY_INFO + CERT_PUBLIC_KEY_INFO* pub_info = nullptr; + DWORD pub_info_size = 0; + if (!CryptDecodeObjectEx( + X509_ASN_ENCODING, + X509_PUBLIC_KEY_INFO, + reinterpret_cast(der.data()), + 
static_cast(der.size()), + CRYPT_DECODE_ALLOC_FLAG, + nullptr, + &pub_info, + &pub_info_size)) + { + return KeyHandle(nullptr); + } + + BCRYPT_KEY_HANDLE key_raw = nullptr; + BOOL ok = CryptImportPublicKeyInfoEx2( + X509_ASN_ENCODING, pub_info, 0, nullptr, &key_raw); + LocalFree(pub_info); + + if (!ok) + { + return KeyHandle(nullptr); + } + return KeyHandle(key_raw); + } + + // Import an RSA public key from JWK n/e components. + // Constructs a BCRYPT_RSAPUBLIC_BLOB manually. + KeyHandle key_from_jwk_rsa(std::string_view n_b64, std::string_view e_b64) + { + std::string n_raw = ::base64_decode(n_b64); + std::string e_raw = ::base64_decode(e_b64); + if (n_raw.empty() || e_raw.empty()) + { + return KeyHandle(nullptr); + } + + // BCRYPT_RSAPUBLIC_BLOB layout: + // BCRYPT_RSAKEY_BLOB header + // PublicExponent[cbPublicExp] + // Modulus[cbModulus] + DWORD cbPublicExp = static_cast(e_raw.size()); + DWORD cbModulus = static_cast(n_raw.size()); + + std::vector blob( + sizeof(BCRYPT_RSAKEY_BLOB) + cbPublicExp + cbModulus); + auto* header = reinterpret_cast(blob.data()); + header->Magic = BCRYPT_RSAPUBLIC_MAGIC; + header->BitLength = cbModulus * 8; + header->cbPublicExp = cbPublicExp; + header->cbModulus = cbModulus; + header->cbPrime1 = 0; + header->cbPrime2 = 0; + + BYTE* ptr = blob.data() + sizeof(BCRYPT_RSAKEY_BLOB); + std::memcpy(ptr, e_raw.data(), cbPublicExp); + ptr += cbPublicExp; + std::memcpy(ptr, n_raw.data(), cbModulus); + + BCRYPT_ALG_HANDLE alg_raw = nullptr; + NTSTATUS status = + BCryptOpenAlgorithmProvider(&alg_raw, BCRYPT_RSA_ALGORITHM, nullptr, 0); + if (!BCRYPT_SUCCESS(status)) + { + return KeyHandle(nullptr); + } + AlgHandle alg(alg_raw); + + BCRYPT_KEY_HANDLE key_raw = nullptr; + status = BCryptImportKeyPair( + alg.get(), + nullptr, + BCRYPT_RSAPUBLIC_BLOB, + &key_raw, + blob.data(), + safe_ulong(blob.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + return KeyHandle(nullptr); + } + return KeyHandle(key_raw); + } + + // Import an EC public key from JWK 
x/y/crv components. + // Constructs a BCRYPT_ECCPUBLIC_BLOB manually. + KeyHandle key_from_jwk_ec( + std::string_view crv, std::string_view x_b64, std::string_view y_b64) + { + std::string x_raw = ::base64_decode(x_b64); + std::string y_raw = ::base64_decode(y_b64); + if (x_raw.empty() || y_raw.empty()) + { + return KeyHandle(nullptr); + } + + // Determine the BCrypt algorithm and expected key size + LPCWSTR algo_id = nullptr; + ULONG magic = 0; + DWORD key_size = 0; + if (crv == "P-256") + { + algo_id = BCRYPT_ECDSA_P256_ALGORITHM; + magic = BCRYPT_ECDSA_PUBLIC_P256_MAGIC; + key_size = 32; + } + else if (crv == "P-384") + { + algo_id = BCRYPT_ECDSA_P384_ALGORITHM; + magic = BCRYPT_ECDSA_PUBLIC_P384_MAGIC; + key_size = 48; + } + else if (crv == "P-521") + { + algo_id = BCRYPT_ECDSA_P521_ALGORITHM; + magic = BCRYPT_ECDSA_PUBLIC_P521_MAGIC; + key_size = 66; + } + else + { + return KeyHandle(nullptr); + } + + // BCRYPT_ECCPUBLIC_BLOB layout: + // BCRYPT_ECCKEY_BLOB header + // X[cbKey] + // Y[cbKey] + std::vector blob(sizeof(BCRYPT_ECCKEY_BLOB) + key_size * 2); + auto* header = reinterpret_cast(blob.data()); + header->dwMagic = magic; + header->cbKey = key_size; + + // Pad or truncate x/y to exactly key_size bytes (big-endian, left-padded) + BYTE* x_dest = blob.data() + sizeof(BCRYPT_ECCKEY_BLOB); + BYTE* y_dest = x_dest + key_size; + std::memset(x_dest, 0, key_size); + std::memset(y_dest, 0, key_size); + + if (x_raw.size() <= key_size) + { + std::memcpy( + x_dest + (key_size - x_raw.size()), x_raw.data(), x_raw.size()); + } + else + { + std::memcpy(x_dest, x_raw.data() + (x_raw.size() - key_size), key_size); + } + + if (y_raw.size() <= key_size) + { + std::memcpy( + y_dest + (key_size - y_raw.size()), y_raw.data(), y_raw.size()); + } + else + { + std::memcpy(y_dest, y_raw.data() + (y_raw.size() - key_size), key_size); + } + + BCRYPT_ALG_HANDLE alg_raw = nullptr; + NTSTATUS status = + BCryptOpenAlgorithmProvider(&alg_raw, algo_id, nullptr, 0); + if 
(!BCRYPT_SUCCESS(status)) + { + return KeyHandle(nullptr); + } + AlgHandle alg(alg_raw); + + BCRYPT_KEY_HANDLE key_raw = nullptr; + status = BCryptImportKeyPair( + alg.get(), + nullptr, + BCRYPT_ECCPUBLIC_BLOB, + &key_raw, + blob.data(), + safe_ulong(blob.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + return KeyHandle(nullptr); + } + return KeyHandle(key_raw); + } + + // Parse a JWK JSON AST into a BCRYPT_KEY_HANDLE (public key) + KeyHandle key_from_jwk_ast(const trieste::Node& ast) + { + using rego::crypto_core::MaxECComponentB64Len; + using rego::crypto_core::MaxRSAComponentB64Len; + std::string_view kty = json_select_string(ast, "/kty"); + if (kty == "RSA") + { + std::string_view n = json_select_string(ast, "/n"); + std::string_view e = json_select_string(ast, "/e"); + if ( + n.empty() || e.empty() || n.size() > MaxRSAComponentB64Len || + e.size() > MaxRSAComponentB64Len) + { + return KeyHandle(nullptr); + } + return key_from_jwk_rsa(n, e); + } + if (kty == "EC") + { + std::string_view crv = json_select_string(ast, "/crv"); + std::string_view x = json_select_string(ast, "/x"); + std::string_view y = json_select_string(ast, "/y"); + if ( + crv.empty() || x.empty() || y.empty() || + x.size() > MaxECComponentB64Len || y.size() > MaxECComponentB64Len) + { + return KeyHandle(nullptr); + } + return key_from_jwk_ec(crv, x, y); + } + // OKP (Ed25519) is not supported by Windows CNG for signing/verification. + // Windows BCrypt supports X25519 for ECDH key exchange only. + // Ed25519 keys will return nullptr, causing verify_signature to return + // {false, ""} (verification failure, not an error). + return KeyHandle(nullptr); + } + + using rego::crypto_core::extract_jwks_keys; + + // Auto-detect key format and load BCRYPT_KEY_HANDLE. + // Handles: PEM cert, PEM pubkey, JWK object, JWKS set. 
+ KeyHandle load_public_key(std::string_view key_or_cert) + { + // PEM certificate + if ( + key_or_cert.find("-----BEGIN CERTIFICATE-----") != std::string_view::npos) + { + return key_from_pem_cert(key_or_cert); + } + + // PEM public key + if ( + key_or_cert.find("-----BEGIN PUBLIC KEY-----") != std::string_view::npos) + { + return key_from_pem_pubkey(key_or_cert); + } + + // Try JSON (JWK or JWKS) + auto ast = parse_json(key_or_cert); + if (!ast) + { + return KeyHandle(nullptr); + } + + // Try JWKS first + auto keys = extract_jwks_keys(ast); + if (!keys.empty()) + { + // Use first key (single-key path; multi-key handled by + // verify_signature_any_key) + return key_from_jwk_ast(keys[0]); + } + + // Try single JWK + std::string_view kty = json_select_string(ast, "/kty"); + if (!kty.empty()) + { + return key_from_jwk_ast(ast); + } + + return KeyHandle(nullptr); + } + + // ── RSA signature verification ── + + bool verify_rsa_pkcs1( + BCRYPT_KEY_HANDLE key, + LPCWSTR hash_algo, + std::string_view signing_input, + std::string_view sig_bytes) + { + auto hash_value = compute_hash(hash_algo, signing_input); + + BCRYPT_PKCS1_PADDING_INFO padding_info; + padding_info.pszAlgId = hash_algo; + + NTSTATUS status = BCryptVerifySignature( + key, + &padding_info, + hash_value.data(), + safe_ulong(hash_value.size()), + reinterpret_cast(const_cast(sig_bytes.data())), + safe_ulong(sig_bytes.size()), + BCRYPT_PAD_PKCS1); + + return BCRYPT_SUCCESS(status); + } + + bool verify_rsa_pss( + BCRYPT_KEY_HANDLE key, + LPCWSTR hash_algo, + std::string_view signing_input, + std::string_view sig_bytes) + { + auto hash_value = compute_hash(hash_algo, signing_input); + + BCRYPT_PSS_PADDING_INFO padding_info; + padding_info.pszAlgId = hash_algo; + // Salt length = hash length (matches OpenSSL RSA_PSS_SALTLEN_AUTO + // behavior for verification) + padding_info.cbSalt = safe_ulong(hash_value.size()); + + NTSTATUS status = BCryptVerifySignature( + key, + &padding_info, + hash_value.data(), + 
safe_ulong(hash_value.size()), + reinterpret_cast(const_cast(sig_bytes.data())), + safe_ulong(sig_bytes.size()), + BCRYPT_PAD_PSS); + + return BCRYPT_SUCCESS(status); + } + + // ── ECDSA signature verification ── + // JWT ECDSA signatures are raw r||s — BCrypt also expects raw r||s, + // so no DER conversion is needed (unlike OpenSSL). + + bool verify_ecdsa( + BCRYPT_KEY_HANDLE key, + LPCWSTR hash_algo_id, + std::string_view signing_input, + std::string_view sig_bytes) + { + auto hash_value = compute_hash(hash_algo_id, signing_input); + + NTSTATUS status = BCryptVerifySignature( + key, + nullptr, + hash_value.data(), + safe_ulong(hash_value.size()), + reinterpret_cast(const_cast(sig_bytes.data())), + safe_ulong(sig_bytes.size()), + 0); + + return BCRYPT_SUCCESS(status); + } + +} + +namespace rego::crypto_core +{ + // ── Hashing ── + + std::string md5_hex(std::string_view data) + { + return digest_hex(BCRYPT_MD5_ALGORITHM, data); + } + + std::string sha1_hex(std::string_view data) + { + return digest_hex(BCRYPT_SHA1_ALGORITHM, data); + } + + std::string sha256_hex(std::string_view data) + { + return digest_hex(BCRYPT_SHA256_ALGORITHM, data); + } + + // ── HMAC ── + + std::string hmac_md5_hex(std::string_view key, std::string_view data) + { + return hmac_hex(BCRYPT_MD5_ALGORITHM, key, data); + } + + std::string hmac_sha1_hex(std::string_view key, std::string_view data) + { + return hmac_hex(BCRYPT_SHA1_ALGORITHM, key, data); + } + + std::string hmac_sha256_hex(std::string_view key, std::string_view data) + { + return hmac_hex(BCRYPT_SHA256_ALGORITHM, key, data); + } + + std::string hmac_sha512_hex(std::string_view key, std::string_view data) + { + return hmac_hex(BCRYPT_SHA512_ALGORITHM, key, data); + } + + bool hmac_equal(std::string_view mac1, std::string_view mac2) + { + return hmac_equal_impl(mac1, mac2); + } + + // ── Base64url ── + + std::string base64url_encode_nopad(std::string_view data) + { + return base64url_encode_nopad_impl(data); + } + + std::string 
base64url_decode(std::string_view data) + { + return base64url_decode_impl(data); + } + + // ── Algorithm parsing ── + + Algorithm parse_algorithm(std::string_view name) + { + return parse_algorithm_impl(name); + } + + // ── Signature Verification ── + + VerifyResult verify_signature( + Algorithm algo, + std::string_view signing_input, + std::string_view signature_bytes, + std::string_view key_or_cert) + { + // EdDSA (Ed25519) is not supported by Windows CNG. + if (algo == Algorithm::EdDSA) + { + return {false, "EdDSA algorithm is not supported"}; + } + + // HMAC algorithms use the key directly as a secret + if ( + algo == Algorithm::HS256 || algo == Algorithm::HS384 || + algo == Algorithm::HS512) + { + bool ok = verify_hmac(algo, signing_input, signature_bytes, key_or_cert); + return {ok, {}}; + } + + // Validate PEM structure before attempting to parse + std::string pem_err = validate_pem(key_or_cert); + if (!pem_err.empty()) + { + return {false, pem_err}; + } + + // Asymmetric: load the public key + KeyHandle pkey = load_public_key(key_or_cert); + if (!pkey) + { + if ( + key_or_cert.find("-----BEGIN CERTIFICATE-----") != + std::string_view::npos) + { + return {false, "failed to parse a PEM certificate"}; + } + if ( + key_or_cert.find("-----BEGIN PUBLIC KEY-----") != + std::string_view::npos) + { + return {false, "failed to parse a PEM key"}; + } + if ( + key_or_cert.find("\"kty\"") != std::string_view::npos || + key_or_cert.find("\"keys\"") != std::string_view::npos) + { + return {false, "failed to parse a JWK key (set)"}; + } + return {false, {}}; + } + + LPCWSTR halgo = hash_algo_id(algo); + bool ok = false; + + switch (algo) + { + case Algorithm::RS256: + case Algorithm::RS384: + case Algorithm::RS512: + ok = + verify_rsa_pkcs1(pkey.get(), halgo, signing_input, signature_bytes); + break; + + case Algorithm::PS256: + case Algorithm::PS384: + case Algorithm::PS512: + ok = verify_rsa_pss(pkey.get(), halgo, signing_input, signature_bytes); + break; + + case 
Algorithm::ES256: + case Algorithm::ES384: + case Algorithm::ES512: + ok = verify_ecdsa(pkey.get(), halgo, signing_input, signature_bytes); + break; + + case Algorithm::EdDSA: + // Ed25519 not supported by Windows CNG; key import will have + // already failed (pkey == nullptr), so this is unreachable. + break; + + default: + break; + } + + return {ok, {}}; + } + + VerifyResult verify_signature_any_key( + Algorithm algo, + std::string_view signing_input, + std::string_view signature_bytes, + std::string_view key_or_cert) + { + // Parse to check if this is a JWKS with multiple keys + auto ast = parse_json(key_or_cert); + if (!ast) + { + // Not JSON — use normal path (may be PEM) + return verify_signature( + algo, signing_input, signature_bytes, key_or_cert); + } + + auto keys = extract_jwks_keys(ast); + if (keys.size() <= 1) + { + // Not a JWKS or single key — use normal path + return verify_signature( + algo, signing_input, signature_bytes, key_or_cert); + } + + // HMAC doesn't use JWKS — fall through to normal path + if ( + algo == Algorithm::HS256 || algo == Algorithm::HS384 || + algo == Algorithm::HS512) + { + return verify_signature( + algo, signing_input, signature_bytes, key_or_cert); + } + + LPCWSTR halgo = hash_algo_id(algo); + + // Try each key; return success on first valid signature + for (auto& key_ast : keys) + { + auto pkey = key_from_jwk_ast(key_ast); + if (!pkey) + { + continue; + } + + bool ok = false; + + switch (algo) + { + case Algorithm::RS256: + case Algorithm::RS384: + case Algorithm::RS512: + ok = + verify_rsa_pkcs1(pkey.get(), halgo, signing_input, signature_bytes); + break; + + case Algorithm::PS256: + case Algorithm::PS384: + case Algorithm::PS512: + ok = + verify_rsa_pss(pkey.get(), halgo, signing_input, signature_bytes); + break; + + case Algorithm::ES256: + case Algorithm::ES384: + case Algorithm::ES512: + ok = verify_ecdsa(pkey.get(), halgo, signing_input, signature_bytes); + break; + + case Algorithm::EdDSA: + // Not supported by 
Windows CNG + break; + + default: + break; + } + + if (ok) + { + return {true, {}}; + } + } + + return {false, {}}; + } + + // ── Signing ── + + // Import an RSA private key from JWK components. + // Constructs a BCRYPT_RSAFULLPRIVATE_BLOB. + // Layout: header | e | n | p | q | dp | dq | qi | d + KeyHandle key_from_jwk_rsa_private( + std::string_view n_b64, + std::string_view e_b64, + std::string_view d_b64, + std::string_view p_b64, + std::string_view q_b64, + std::string_view dp_b64, + std::string_view dq_b64, + std::string_view qi_b64) + { + std::string n_raw = ::base64_decode(n_b64); + std::string e_raw = ::base64_decode(e_b64); + crypto_core::SecureString d_raw(::base64_decode(d_b64)); + crypto_core::SecureString p_raw(::base64_decode(p_b64)); + crypto_core::SecureString q_raw(::base64_decode(q_b64)); + crypto_core::SecureString dp_raw(::base64_decode(dp_b64)); + crypto_core::SecureString dq_raw(::base64_decode(dq_b64)); + crypto_core::SecureString qi_raw(::base64_decode(qi_b64)); + + if (n_raw.empty() || e_raw.empty() || d_raw.empty()) + { + return KeyHandle(nullptr); + } + + DWORD cbPublicExp = static_cast(e_raw.size()); + DWORD cbModulus = static_cast(n_raw.size()); + DWORD cbPrime1 = static_cast(p_raw.size()); + DWORD cbPrime2 = static_cast(q_raw.size()); + + // BCRYPT_RSAFULLPRIVATE_BLOB: + // BCRYPT_RSAKEY_BLOB header + // e[cbPublicExp] | n[cbModulus] | p[cbPrime1] | q[cbPrime2] + // dp[cbPrime1] | dq[cbPrime2] | qi[cbPrime1] | d[cbModulus] + size_t blob_size = sizeof(BCRYPT_RSAKEY_BLOB) + cbPublicExp + cbModulus + + cbPrime1 + cbPrime2 + cbPrime1 + cbPrime2 + cbPrime1 + cbModulus; + + std::vector blob(blob_size, 0); + auto* header = reinterpret_cast(blob.data()); + header->Magic = BCRYPT_RSAFULLPRIVATE_MAGIC; + header->BitLength = cbModulus * 8; + header->cbPublicExp = cbPublicExp; + header->cbModulus = cbModulus; + header->cbPrime1 = cbPrime1; + header->cbPrime2 = cbPrime2; + + BYTE* ptr = blob.data() + sizeof(BCRYPT_RSAKEY_BLOB); + std::memcpy(ptr, 
e_raw.data(), cbPublicExp); + ptr += cbPublicExp; + std::memcpy(ptr, n_raw.data(), cbModulus); + ptr += cbModulus; + std::memcpy(ptr, p_raw.data(), cbPrime1); + ptr += cbPrime1; + std::memcpy(ptr, q_raw.data(), cbPrime2); + ptr += cbPrime2; + // dp (exponent1) — padded to cbPrime1 + if (dp_raw.size() <= cbPrime1) + { + std::memcpy( + ptr + (cbPrime1 - dp_raw.size()), dp_raw.data(), dp_raw.size()); + } + ptr += cbPrime1; + // dq (exponent2) — padded to cbPrime2 + if (dq_raw.size() <= cbPrime2) + { + std::memcpy( + ptr + (cbPrime2 - dq_raw.size()), dq_raw.data(), dq_raw.size()); + } + ptr += cbPrime2; + // qi (coefficient) — padded to cbPrime1 + if (qi_raw.size() <= cbPrime1) + { + std::memcpy( + ptr + (cbPrime1 - qi_raw.size()), qi_raw.data(), qi_raw.size()); + } + ptr += cbPrime1; + // d (private exponent) — padded to cbModulus + if (d_raw.size() <= cbModulus) + { + std::memcpy(ptr + (cbModulus - d_raw.size()), d_raw.data(), d_raw.size()); + } + + BCRYPT_ALG_HANDLE alg_raw = nullptr; + NTSTATUS status = + BCryptOpenAlgorithmProvider(&alg_raw, BCRYPT_RSA_ALGORITHM, nullptr, 0); + if (!BCRYPT_SUCCESS(status)) + { + return KeyHandle(nullptr); + } + AlgHandle alg(alg_raw); + + BCRYPT_KEY_HANDLE key_raw = nullptr; + status = BCryptImportKeyPair( + alg.get(), + nullptr, + BCRYPT_RSAFULLPRIVATE_BLOB, + &key_raw, + blob.data(), + safe_ulong(blob.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + return KeyHandle(nullptr); + } + return KeyHandle(key_raw); + } + + // Import an EC private key from JWK components. + // Constructs a BCRYPT_ECCPRIVATE_BLOB. 
+ // Layout: header | X[cbKey] | Y[cbKey] | d[cbKey] + KeyHandle key_from_jwk_ec_private( + std::string_view crv, + std::string_view x_b64, + std::string_view y_b64, + std::string_view d_b64) + { + std::string x_raw = ::base64_decode(x_b64); + std::string y_raw = ::base64_decode(y_b64); + crypto_core::SecureString d_raw(::base64_decode(d_b64)); + if (x_raw.empty() || y_raw.empty() || d_raw.empty()) + { + return KeyHandle(nullptr); + } + + LPCWSTR algo_id = nullptr; + ULONG magic = 0; + DWORD key_size = 0; + if (crv == "P-256") + { + algo_id = BCRYPT_ECDSA_P256_ALGORITHM; + magic = BCRYPT_ECDSA_PRIVATE_P256_MAGIC; + key_size = 32; + } + else if (crv == "P-384") + { + algo_id = BCRYPT_ECDSA_P384_ALGORITHM; + magic = BCRYPT_ECDSA_PRIVATE_P384_MAGIC; + key_size = 48; + } + else if (crv == "P-521") + { + algo_id = BCRYPT_ECDSA_P521_ALGORITHM; + magic = BCRYPT_ECDSA_PRIVATE_P521_MAGIC; + key_size = 66; + } + else + { + return KeyHandle(nullptr); + } + + // BCRYPT_ECCPRIVATE_BLOB layout: + // BCRYPT_ECCKEY_BLOB header + // X[cbKey] | Y[cbKey] | d[cbKey] + std::vector blob(sizeof(BCRYPT_ECCKEY_BLOB) + key_size * 3, 0); + auto* header = reinterpret_cast(blob.data()); + header->dwMagic = magic; + header->cbKey = key_size; + + BYTE* x_dest = blob.data() + sizeof(BCRYPT_ECCKEY_BLOB); + BYTE* y_dest = x_dest + key_size; + BYTE* d_dest = y_dest + key_size; + + // Left-pad each component to key_size + if (x_raw.size() <= key_size) + { + std::memcpy( + x_dest + (key_size - x_raw.size()), x_raw.data(), x_raw.size()); + } + else + { + std::memcpy(x_dest, x_raw.data() + (x_raw.size() - key_size), key_size); + } + + if (y_raw.size() <= key_size) + { + std::memcpy( + y_dest + (key_size - y_raw.size()), y_raw.data(), y_raw.size()); + } + else + { + std::memcpy(y_dest, y_raw.data() + (y_raw.size() - key_size), key_size); + } + + if (d_raw.size() <= key_size) + { + std::memcpy( + d_dest + (key_size - d_raw.size()), d_raw.data(), d_raw.size()); + } + else + { + std::memcpy(d_dest, 
d_raw.data() + (d_raw.size() - key_size), key_size); + } + + BCRYPT_ALG_HANDLE alg_raw = nullptr; + NTSTATUS status = + BCryptOpenAlgorithmProvider(&alg_raw, algo_id, nullptr, 0); + if (!BCRYPT_SUCCESS(status)) + { + return KeyHandle(nullptr); + } + AlgHandle alg(alg_raw); + + BCRYPT_KEY_HANDLE key_raw = nullptr; + status = BCryptImportKeyPair( + alg.get(), + nullptr, + BCRYPT_ECCPRIVATE_BLOB, + &key_raw, + blob.data(), + safe_ulong(blob.size()), + 0); + if (!BCRYPT_SUCCESS(status)) + { + return KeyHandle(nullptr); + } + return KeyHandle(key_raw); + } + + // Load a private key from a JWK JSON AST + KeyHandle load_private_key_ast(const trieste::Node& ast) + { + using rego::crypto_core::MaxECComponentB64Len; + using rego::crypto_core::MaxRSAComponentB64Len; + std::string_view kty = json_select_string(ast, "/kty"); + if (kty == "RSA") + { + std::string_view n = json_select_string(ast, "/n"); + std::string_view e = json_select_string(ast, "/e"); + std::string_view d = json_select_string(ast, "/d"); + std::string_view p = json_select_string(ast, "/p"); + std::string_view q = json_select_string(ast, "/q"); + std::string_view dp = json_select_string(ast, "/dp"); + std::string_view dq = json_select_string(ast, "/dq"); + std::string_view qi = json_select_string(ast, "/qi"); + if (n.empty() || e.empty() || d.empty()) + { + return KeyHandle(nullptr); + } + for (auto sv : {n, e, d, p, q, dp, dq, qi}) + { + if (sv.size() > MaxRSAComponentB64Len) + { + return KeyHandle(nullptr); + } + } + return key_from_jwk_rsa_private(n, e, d, p, q, dp, dq, qi); + } + if (kty == "EC") + { + std::string_view crv = json_select_string(ast, "/crv"); + std::string_view x = json_select_string(ast, "/x"); + std::string_view y = json_select_string(ast, "/y"); + std::string_view d = json_select_string(ast, "/d"); + if (crv.empty() || x.empty() || y.empty() || d.empty()) + { + return KeyHandle(nullptr); + } + if ( + x.size() > MaxECComponentB64Len || y.size() > MaxECComponentB64Len || + d.size() > 
MaxECComponentB64Len) + { + return KeyHandle(nullptr); + } + return key_from_jwk_ec_private(crv, x, y, d); + } + // OKP (Ed25519) not supported by Windows CNG + return KeyHandle(nullptr); + } + + // Sign a hash with RSA (PKCS#1 v1.5 or PSS) + std::string sign_rsa( + BCRYPT_KEY_HANDLE key, + LPCWSTR hash_algo, + std::string_view signing_input, + bool use_pss) + { + auto hash_value = compute_hash(hash_algo, signing_input); + + DWORD sig_len = 0; + NTSTATUS status; + + if (use_pss) + { + BCRYPT_PSS_PADDING_INFO pss_info; + pss_info.pszAlgId = hash_algo; + pss_info.cbSalt = safe_ulong(hash_value.size()); + + status = BCryptSignHash( + key, + &pss_info, + hash_value.data(), + safe_ulong(hash_value.size()), + nullptr, + 0, + &sig_len, + BCRYPT_PAD_PSS); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptSignHash (PSS length) failed"); + } + + std::vector sig(sig_len); + status = BCryptSignHash( + key, + &pss_info, + hash_value.data(), + safe_ulong(hash_value.size()), + sig.data(), + sig_len, + &sig_len, + BCRYPT_PAD_PSS); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptSignHash (PSS) failed"); + } + return std::string(reinterpret_cast(sig.data()), sig_len); + } + + // PKCS#1 v1.5 + BCRYPT_PKCS1_PADDING_INFO pkcs1_info; + pkcs1_info.pszAlgId = hash_algo; + + status = BCryptSignHash( + key, + &pkcs1_info, + hash_value.data(), + safe_ulong(hash_value.size()), + nullptr, + 0, + &sig_len, + BCRYPT_PAD_PKCS1); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptSignHash (PKCS1 length) failed"); + } + + std::vector sig(sig_len); + status = BCryptSignHash( + key, + &pkcs1_info, + hash_value.data(), + safe_ulong(hash_value.size()), + sig.data(), + sig_len, + &sig_len, + BCRYPT_PAD_PKCS1); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptSignHash (PKCS1) failed"); + } + return std::string(reinterpret_cast(sig.data()), sig_len); + } + + // Sign with ECDSA — BCrypt produces raw r||s directly (JWT format) + 
std::string sign_ecdsa( + BCRYPT_KEY_HANDLE key, LPCWSTR hash_algo, std::string_view signing_input) + { + auto hash_value = compute_hash(hash_algo, signing_input); + + DWORD sig_len = 0; + NTSTATUS status = BCryptSignHash( + key, + nullptr, + hash_value.data(), + safe_ulong(hash_value.size()), + nullptr, + 0, + &sig_len, + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptSignHash (ECDSA length) failed"); + } + + std::vector sig(sig_len); + status = BCryptSignHash( + key, + nullptr, + hash_value.data(), + safe_ulong(hash_value.size()), + sig.data(), + sig_len, + &sig_len, + 0); + if (!BCRYPT_SUCCESS(status)) + { + throw std::runtime_error("BCryptSignHash (ECDSA) failed"); + } + return std::string(reinterpret_cast(sig.data()), sig_len); + } + + std::string sign( + Algorithm algo, + std::string_view signing_input, + std::string_view key_jwk_json) + { + auto ast = parse_json(key_jwk_json); + if (!ast) + { + throw std::runtime_error("failed to parse JWK JSON"); + } + + // HMAC: extract the "k" field (base64url-encoded secret) + if ( + algo == Algorithm::HS256 || algo == Algorithm::HS384 || + algo == Algorithm::HS512) + { + std::string_view k = json_select_string(ast, "/k"); + if (k.empty()) + { + throw std::runtime_error("missing 'k' in oct JWK"); + } + crypto_core::SecureString secret(::base64_decode(k)); + LPCWSTR halgo = hmac_algo_id(algo); + auto result = hmac_raw(halgo, secret.value, signing_input); + return std::string(reinterpret_cast(result.data()), result.size()); + } + + // Asymmetric: load private key + KeyHandle pkey = load_private_key_ast(ast); + if (!pkey) + { + throw std::runtime_error("failed to load private key from JWK"); + } + + LPCWSTR halgo = hash_algo_id(algo); + + switch (algo) + { + case Algorithm::RS256: + case Algorithm::RS384: + case Algorithm::RS512: + return sign_rsa(pkey.get(), halgo, signing_input, false); + + case Algorithm::PS256: + case Algorithm::PS384: + case Algorithm::PS512: + return sign_rsa(pkey.get(), 
halgo, signing_input, true); + + case Algorithm::ES256: + case Algorithm::ES384: + case Algorithm::ES512: + return sign_ecdsa(pkey.get(), halgo, signing_input); + + case Algorithm::EdDSA: + throw std::runtime_error("EdDSA algorithm is not supported"); + + default: + throw std::runtime_error("unsupported algorithm for signing"); + } + } + + // ── X.509 Certificate Parsing (WinCrypt/crypt32) ── + + // Extract CommonName from a certificate context + std::string get_common_name(PCCERT_CONTEXT cert_ctx) + { + if (!cert_ctx) + { + return {}; + } + // Get the length first + DWORD len = CertGetNameStringA( + cert_ctx, CERT_NAME_ATTR_TYPE, 0, (void*)szOID_COMMON_NAME, nullptr, 0); + if (len <= 1) + { + return {}; + } + std::string name(len - 1, '\0'); + CertGetNameStringA( + cert_ctx, + CERT_NAME_ATTR_TYPE, + 0, + (void*)szOID_COMMON_NAME, + name.data(), + len); + return name; + } + + // Extract DNS names and URI strings from certificate SANs + void extract_sans( + PCCERT_CONTEXT cert_ctx, + std::vector& dns_names, + std::vector& uri_strings) + { + if (!cert_ctx) + { + return; + } + PCERT_EXTENSION san_ext = CertFindExtension( + szOID_SUBJECT_ALT_NAME2, + cert_ctx->pCertInfo->cExtension, + cert_ctx->pCertInfo->rgExtension); + if (!san_ext) + { + return; + } + + PCERT_ALT_NAME_INFO san_info = nullptr; + DWORD info_size = 0; + if (!CryptDecodeObjectEx( + X509_ASN_ENCODING, + szOID_SUBJECT_ALT_NAME2, + san_ext->Value.pbData, + san_ext->Value.cbData, + CRYPT_DECODE_ALLOC_FLAG, + nullptr, + &san_info, + &info_size)) + { + return; + } + if (!san_info) + { + return; + } + + for (DWORD i = 0; i < san_info->cAltEntry; ++i) + { + CERT_ALT_NAME_ENTRY& entry = san_info->rgAltEntry[i]; + if (entry.dwAltNameChoice == CERT_ALT_NAME_DNS_NAME && entry.pwszDNSName) + { + // Convert wide string to UTF-8 + int needed = WideCharToMultiByte( + CP_UTF8, 0, entry.pwszDNSName, -1, nullptr, 0, nullptr, nullptr); + if (needed > 1) + { + std::string s(needed - 1, '\0'); + WideCharToMultiByte( + 
CP_UTF8, + 0, + entry.pwszDNSName, + -1, + s.data(), + needed, + nullptr, + nullptr); + dns_names.push_back(std::move(s)); + } + } + else if (entry.dwAltNameChoice == CERT_ALT_NAME_URL && entry.pwszURL) + { + int needed = WideCharToMultiByte( + CP_UTF8, 0, entry.pwszURL, -1, nullptr, 0, nullptr, nullptr); + if (needed > 1) + { + std::string s(needed - 1, '\0'); + WideCharToMultiByte( + CP_UTF8, 0, entry.pwszURL, -1, s.data(), needed, nullptr, nullptr); + uri_strings.push_back(std::move(s)); + } + } + } + LocalFree(san_info); + } + + // Convert a certificate context to base64-encoded DER + std::string cert_to_der_b64(PCCERT_CONTEXT cert_ctx) + { + if (!cert_ctx || !cert_ctx->pbCertEncoded || cert_ctx->cbCertEncoded == 0) + { + return {}; + } + std::string_view der( + reinterpret_cast(cert_ctx->pbCertEncoded), + cert_ctx->cbCertEncoded); + return ::base64_encode(der, false); + } + + // Parse a CERT_CONTEXT into a ParsedCertificate + ParsedCertificate cert_to_parsed(PCCERT_CONTEXT cert_ctx) + { + ParsedCertificate pc; + pc.subject.common_name = get_common_name(cert_ctx); + extract_sans(cert_ctx, pc.dns_names, pc.uri_strings); + pc.der_b64 = cert_to_der_b64(cert_ctx); + return pc; + } + + // Parse PEM data into CERT_CONTEXT objects by extracting DER from PEM blocks + std::vector parse_pem_certs(std::string_view pem_data) + { + std::vector certs; + auto der_blocks = extract_pem_der_blocks(pem_data, "CERTIFICATE"); + + for (auto& der : der_blocks) + { + PCCERT_CONTEXT ctx = CertCreateCertificateContext( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + reinterpret_cast(der.data()), + static_cast(der.size())); + if (ctx) + { + certs.emplace_back(ctx); + } + } + return certs; + } + + // Parse concatenated DER data into CERT_CONTEXT objects + std::vector parse_der_certs(std::string_view der_data) + { + std::vector certs; + const BYTE* p = reinterpret_cast(der_data.data()); + DWORD remaining = static_cast(der_data.size()); + + while (remaining > 0) + { + // Decode the ASN.1 
length to find the certificate boundary + DWORD cert_size = 0; + PCCERT_CONTEXT ctx = CertCreateCertificateContext( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, p, remaining); + if (!ctx) + { + break; + } + cert_size = ctx->cbCertEncoded; + certs.emplace_back(ctx); + p += cert_size; + remaining -= cert_size; + } + return certs; + } + + ParseCertsResult parse_certificates(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, decoded.error}; + } + + std::vector certs; + if (decoded.is_pem) + { + certs = parse_pem_certs(decoded.data); + } + else + { + certs = parse_der_certs(decoded.data); + } + + if (certs.empty()) + { + return {{}, "x509: malformed certificate"}; + } + + ParseCertsResult result; + for (auto& cert : certs) + { + result.certs.push_back(cert_to_parsed(cert.get())); + } + return result; + } + + VerifyCertsResult parse_and_verify_certificates(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {false, {}, decoded.error}; + } + + std::vector certs; + if (decoded.is_pem) + { + certs = parse_pem_certs(decoded.data); + } + else + { + certs = parse_der_certs(decoded.data); + } + + if (certs.size() < 2) + { + VerifyCertsResult result; + result.valid = false; + for (auto& cert : certs) + { + result.certs.push_back(cert_to_parsed(cert.get())); + } + return result; + } + + // Build an in-memory certificate store with all non-leaf certs. + // OPA convention: last cert is leaf, others are CA/intermediates. 
+ HCERTSTORE root_store = CertOpenStore( + CERT_STORE_PROV_MEMORY, 0, 0, CERT_STORE_CREATE_NEW_FLAG, nullptr); + if (!root_store) + { + return {false, {}, "failed to create certificate store"}; + } + + // Add self-signed certs to root store, others as additional store + HCERTSTORE extra_store = CertOpenStore( + CERT_STORE_PROV_MEMORY, 0, 0, CERT_STORE_CREATE_NEW_FLAG, nullptr); + + for (size_t i = 0; i + 1 < certs.size(); ++i) + { + PCCERT_CONTEXT c = certs[i].get(); + // Check if self-signed (issuer == subject) + if (CertCompareCertificateName( + X509_ASN_ENCODING, &c->pCertInfo->Subject, &c->pCertInfo->Issuer)) + { + CertAddCertificateContextToStore( + root_store, c, CERT_STORE_ADD_ALWAYS, nullptr); + } + else + { + CertAddCertificateContextToStore( + extra_store, c, CERT_STORE_ADD_ALWAYS, nullptr); + } + } + + // Create a custom chain engine that trusts only our root store + CERT_CHAIN_ENGINE_CONFIG engine_config = {}; + engine_config.cbSize = sizeof(engine_config); + engine_config.hExclusiveRoot = root_store; + + HCERTCHAINENGINE engine = nullptr; + if (!CertCreateCertificateChainEngine(&engine_config, &engine)) + { + CertCloseStore(root_store, 0); + CertCloseStore(extra_store, 0); + return {false, {}, "failed to create chain engine"}; + } + + // Verify the leaf against the custom engine + PCCERT_CONTEXT leaf = certs.back().get(); + + CERT_CHAIN_PARA chain_params = {}; + chain_params.cbSize = sizeof(chain_params); + + PCCERT_CHAIN_CONTEXT chain_ctx = nullptr; + BOOL chain_ok = CertGetCertificateChain( + engine, + leaf, + nullptr, // current time + extra_store, + &chain_params, + 0, // no revocation checking + nullptr, + &chain_ctx); + + VerifyCertsResult result; + + if (chain_ok && chain_ctx) + { + // Check the chain status — allow untrusted root since we're using + // our own root store, and ignore time/revocation issues for test certs + DWORD error_status = chain_ctx->TrustStatus.dwErrorStatus; + // Mask out acceptable errors for OPA-style verification + 
DWORD acceptable_errors = CERT_TRUST_IS_NOT_TIME_VALID | + CERT_TRUST_REVOCATION_STATUS_UNKNOWN | CERT_TRUST_IS_OFFLINE_REVOCATION; + result.valid = ((error_status & ~acceptable_errors) == 0); + + if (result.valid && chain_ctx->cChain > 0) + { + // Return verified chain in leaf-first order + PCERT_SIMPLE_CHAIN simple_chain = chain_ctx->rgpChain[0]; + if (simple_chain->cElement > static_cast(MaxCertChainLen)) + { + result.valid = false; + result.certs.clear(); + for (auto& cert : certs) + { + result.certs.push_back(cert_to_parsed(cert.get())); + } + CertFreeCertificateChain(chain_ctx); + CertFreeCertificateChainEngine(engine); + CertCloseStore(root_store, 0); + CertCloseStore(extra_store, 0); + return result; + } + for (DWORD i = 0; i < simple_chain->cElement; ++i) + { + result.certs.push_back( + cert_to_parsed(simple_chain->rgpElement[i]->pCertContext)); + } + } + else + { + // On failure, return certs in input order + for (auto& cert : certs) + { + result.certs.push_back(cert_to_parsed(cert.get())); + } + } + + CertFreeCertificateChain(chain_ctx); + } + else + { + result.valid = false; + for (auto& cert : certs) + { + result.certs.push_back(cert_to_parsed(cert.get())); + } + } + + CertFreeCertificateChainEngine(engine); + CertCloseStore(root_store, 0); + CertCloseStore(extra_store, 0); + return result; + } + + ParseCSRResult parse_certificate_request(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, decoded.error}; + } + + std::string_view data_view = decoded.data; + std::string der_data; + + if (decoded.is_pem) + { + auto blocks = extract_pem_der_blocks(decoded.data, "CERTIFICATE REQUEST"); + if (blocks.empty()) + { + // Also try "NEW CERTIFICATE REQUEST" + blocks = + extract_pem_der_blocks(decoded.data, "NEW CERTIFICATE REQUEST"); + } + if (blocks.empty()) + { + return {{}, "asn1: structure error"}; + } + der_data = std::move(blocks[0]); + data_view = der_data; + } + + // Decode the 
PKCS#10 CSR to extract the subject + PCERT_REQUEST_INFO req_info = nullptr; + DWORD req_info_size = 0; + if (!CryptDecodeObjectEx( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + X509_CERT_REQUEST_TO_BE_SIGNED, + reinterpret_cast(data_view.data()), + static_cast(data_view.size()), + CRYPT_DECODE_ALLOC_FLAG, + nullptr, + &req_info, + &req_info_size)) + { + return {{}, "asn1: structure error"}; + } + + if (!req_info) + { + return {{}, "asn1: structure error"}; + } + + // Extract CommonName from the subject using CertNameToStrA + // (CertGetNameStringA requires a CERT_CONTEXT, which we don't have for a + // CSR) + ParseCSRResult result; + DWORD name_len = CertNameToStrA( + X509_ASN_ENCODING, + &req_info->Subject, + CERT_X500_NAME_STR | CERT_NAME_STR_REVERSE_FLAG, + nullptr, + 0); + if (name_len > 1) + { + std::string full_name(name_len - 1, '\0'); + CertNameToStrA( + X509_ASN_ENCODING, + &req_info->Subject, + CERT_X500_NAME_STR | CERT_NAME_STR_REVERSE_FLAG, + full_name.data(), + name_len); + // Extract CN= value from the X.500 string + auto cn_pos = full_name.find("CN="); + if (cn_pos != std::string::npos) + { + cn_pos += 3; + auto cn_end = full_name.find(',', cn_pos); + if (cn_end == std::string::npos) + { + cn_end = full_name.size(); + } + result.subject.common_name = full_name.substr(cn_pos, cn_end - cn_pos); + } + } + + LocalFree(req_info); + return result; + } + + // ── RSA Private Key Parsing ── + + // Convert raw big-endian bytes to base64url (no padding), stripping leading + // zero bytes. + std::string bytes_to_base64url(const BYTE* data, DWORD size) + { + // Skip leading zeros + while (size > 0 && *data == 0) + { + ++data; + --size; + } + if (size == 0) + { + return "AA"; // zero value + } + std::string_view sv(reinterpret_cast(data), size); + return base64url_encode_nopad(sv); + } + + // Parse a single RSA private key from DER bytes and return JWK components. + // Handles both PKCS#1 (RSA PRIVATE KEY) and PKCS#8 (PRIVATE KEY) formats. 
+ std::optional parse_rsa_key_der(std::string_view der_data) + { + // Try PKCS#8 first (PRIVATE KEY), then PKCS#1 (RSA PRIVATE KEY) + // CryptDecodeObjectEx with PKCS_PRIVATE_KEY_INFO for PKCS#8, + // or CNG_RSA_PRIVATE_KEY_BLOB for PKCS#1. + + // Strategy: import through CNG to normalize the format, then export + // the full private blob to extract all components. + + // Try PKCS#8 wrapper first + PCRYPT_PRIVATE_KEY_INFO pkcs8_info = nullptr; + DWORD pkcs8_size = 0; + bool is_pkcs8 = CryptDecodeObjectEx( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + PKCS_PRIVATE_KEY_INFO, + reinterpret_cast(der_data.data()), + static_cast(der_data.size()), + CRYPT_DECODE_ALLOC_FLAG, + nullptr, + &pkcs8_info, + &pkcs8_size); + + std::string_view rsa_der = der_data; + std::unique_ptr pkcs8_guard(nullptr, LocalFree); + + if (is_pkcs8 && pkcs8_info) + { + pkcs8_guard.reset(pkcs8_info); + // Extract the inner RSA private key from PKCS#8 + rsa_der = std::string_view( + reinterpret_cast(pkcs8_info->PrivateKey.pbData), + pkcs8_info->PrivateKey.cbData); + } + + // Decode as PKCS#1 RSA private key to CNG blob + BYTE* cng_blob = nullptr; + DWORD cng_blob_size = 0; + if (!CryptDecodeObjectEx( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + CNG_RSA_PRIVATE_KEY_BLOB, + reinterpret_cast(rsa_der.data()), + static_cast(rsa_der.size()), + CRYPT_DECODE_ALLOC_FLAG, + nullptr, + &cng_blob, + &cng_blob_size)) + { + return std::nullopt; + } + std::unique_ptr blob_guard(cng_blob, LocalFree); + + // Import the CNG blob to get a key handle + BCRYPT_ALG_HANDLE alg_raw = nullptr; + NTSTATUS status = + BCryptOpenAlgorithmProvider(&alg_raw, BCRYPT_RSA_ALGORITHM, nullptr, 0); + if (!BCRYPT_SUCCESS(status)) + { + return std::nullopt; + } + AlgHandle alg(alg_raw); + + BCRYPT_KEY_HANDLE key_raw = nullptr; + status = BCryptImportKeyPair( + alg.get(), + nullptr, + BCRYPT_RSAPRIVATE_BLOB, + &key_raw, + cng_blob, + cng_blob_size, + 0); + if (!BCRYPT_SUCCESS(status)) + { + return std::nullopt; + } + KeyHandle key(key_raw); 
+ + // Export as BCRYPT_RSAFULLPRIVATE_BLOB to get all components + DWORD export_size = 0; + status = BCryptExportKey( + key.get(), + nullptr, + BCRYPT_RSAFULLPRIVATE_BLOB, + nullptr, + 0, + &export_size, + 0); + if (!BCRYPT_SUCCESS(status)) + { + return std::nullopt; + } + + std::vector export_blob(export_size); + status = BCryptExportKey( + key.get(), + nullptr, + BCRYPT_RSAFULLPRIVATE_BLOB, + export_blob.data(), + export_size, + &export_size, + 0); + if (!BCRYPT_SUCCESS(status)) + { + return std::nullopt; + } + + // Parse the exported blob + // Layout: header | e[cbPublicExp] | n[cbModulus] | p[cbPrime1] | + // q[cbPrime2] | dp[cbPrime1] | dq[cbPrime2] | qi[cbPrime1] | + // d[cbModulus] + auto* hdr = reinterpret_cast(export_blob.data()); + const BYTE* ptr = export_blob.data() + sizeof(BCRYPT_RSAKEY_BLOB); + + RSAPrivateKeyJWK jwk; + jwk.kty = "RSA"; + jwk.e = bytes_to_base64url(ptr, hdr->cbPublicExp); + ptr += hdr->cbPublicExp; + jwk.n = bytes_to_base64url(ptr, hdr->cbModulus); + ptr += hdr->cbModulus; + jwk.p = bytes_to_base64url(ptr, hdr->cbPrime1); + ptr += hdr->cbPrime1; + jwk.q = bytes_to_base64url(ptr, hdr->cbPrime2); + ptr += hdr->cbPrime2; + jwk.dp = bytes_to_base64url(ptr, hdr->cbPrime1); + ptr += hdr->cbPrime1; + jwk.dq = bytes_to_base64url(ptr, hdr->cbPrime2); + ptr += hdr->cbPrime2; + jwk.qi = bytes_to_base64url(ptr, hdr->cbPrime1); + ptr += hdr->cbPrime1; + jwk.d = bytes_to_base64url(ptr, hdr->cbModulus); + + return jwk; + } + + ParseRSAKeyResult parse_rsa_private_key(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, decoded.error}; + } + + std::string der_data; + if (decoded.is_pem) + { + // Try PKCS#1 first, then PKCS#8 + auto blocks = extract_pem_der_blocks(decoded.data, "RSA PRIVATE KEY"); + if (blocks.empty()) + { + blocks = extract_pem_der_blocks(decoded.data, "PRIVATE KEY"); + } + if (blocks.empty()) + { + return {{}, "failed to parse RSA private key"}; + } + 
der_data = std::move(blocks[0]); + } + else + { + der_data = std::move(decoded.data); + } + + auto jwk = parse_rsa_key_der(der_data); + if (!jwk) + { + return {{}, "failed to parse RSA private key"}; + } + return {*jwk, {}}; + } + + ParsePrivateKeysResult parse_private_keys(std::string_view input) + { + if (input.empty()) + { + return {{}, true, {}}; + } + + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, false, {}}; + } + + if (!decoded.is_pem) + { + return {{}, false, {}}; + } + + ParsePrivateKeysResult result; + result.is_empty_input = false; + + // Extract all PKCS#1 blocks + auto rsa_blocks = extract_pem_der_blocks(decoded.data, "RSA PRIVATE KEY"); + for (auto& der : rsa_blocks) + { + auto jwk = parse_rsa_key_der(der); + if (jwk) + { + result.keys.push_back(*jwk); + } + } + + // Extract all PKCS#8 blocks + auto pkcs8_blocks = extract_pem_der_blocks(decoded.data, "PRIVATE KEY"); + for (auto& der : pkcs8_blocks) + { + auto jwk = parse_rsa_key_der(der); + if (jwk) + { + result.keys.push_back(*jwk); + } + } + + return result; + } +} + +#endif // REGOCPP_CRYPTO_BCRYPT diff --git a/src/builtins/crypto_core.hh b/src/builtins/crypto_core.hh new file mode 100644 index 00000000..15723794 --- /dev/null +++ b/src/builtins/crypto_core.hh @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once + +#ifdef REGOCPP_HAS_CRYPTO + +#include +#include +#include + +namespace rego::crypto_core +{ + // ── Hashing ── + std::string md5_hex(std::string_view data); + std::string sha1_hex(std::string_view data); + std::string sha256_hex(std::string_view data); + + // ── HMAC ── + std::string hmac_md5_hex(std::string_view key, std::string_view data); + std::string hmac_sha1_hex(std::string_view key, std::string_view data); + std::string hmac_sha256_hex(std::string_view key, std::string_view data); + std::string hmac_sha512_hex(std::string_view key, std::string_view data); + bool hmac_equal(std::string_view mac1, std::string_view mac2); + + // ── Base64url ── + std::string base64url_encode_nopad(std::string_view data); + std::string base64url_decode(std::string_view data); + + // ── Signature Verification ── + enum class Algorithm + { + HS256, + HS384, + HS512, + RS256, + RS384, + RS512, + PS256, + PS384, + PS512, + ES256, + ES384, + ES512, + EdDSA + }; + + // Returns the algorithm enum for a JWT "alg" header value. + // Throws std::invalid_argument if unknown. + Algorithm parse_algorithm(std::string_view name); + + // Result of signature verification. + struct VerifyResult + { + bool valid; // true if signature verified successfully + std::string error; // non-empty if key parsing / crypto failed (not a + // mismatch, but an actual error) + }; + + // Verify a JWT signature. key_or_cert is a PEM certificate, PEM public key, + // or JWK JSON string. signing_input is "header.payload" (the first two + // base64url sections). signature is the raw decoded signature bytes. + VerifyResult verify_signature( + Algorithm algo, + std::string_view signing_input, + std::string_view signature_bytes, + std::string_view key_or_cert); + + // Verify a JWT signature, trying all keys in a JWKS "keys" array. + // Returns valid=true on the first key that succeeds. + // If key_or_cert is not a JWKS, falls back to verify_signature. 
+ VerifyResult verify_signature_any_key( + Algorithm algo, + std::string_view signing_input, + std::string_view signature_bytes, + std::string_view key_or_cert); + + // ── Signing ── + + // Sign data with a JWK private key. + // Returns the raw signature bytes, or throws on error. + // key_jwk_json is the full JWK JSON string containing private key material. + std::string sign( + Algorithm algo, + std::string_view signing_input, + std::string_view key_jwk_json); + + // ── X.509 Certificate Parsing ── + + struct X509Name + { + std::string common_name; + // Extend with Organization, Country, etc. as needed + }; + + struct ParsedCertificate + { + X509Name subject; + std::vector dns_names; // empty if none + std::vector uri_strings; // empty if none + std::string der_b64; // base64-encoded DER (for keypair output) + }; + + struct ParseCertsResult + { + std::vector certs; + std::string error; // non-empty on failure + }; + + // Parse one or more X.509 certificates from input. + // Input can be: PEM string, base64-encoded PEM, base64-encoded DER, or + // concatenated DER bytes. + ParseCertsResult parse_certificates(std::string_view input); + + struct ParseCSRResult + { + X509Name subject; + std::string error; // non-empty on failure + }; + + // Parse an X.509 Certificate Signing Request. + ParseCSRResult parse_certificate_request(std::string_view input); + + struct VerifyCertsResult + { + bool valid; + std::vector certs; + std::string error; // non-empty on failure + }; + + // Parse and verify a certificate chain. + // Follows OPA convention: the last certificate in the PEM bundle is treated + // as the leaf (workload) certificate; all preceding certificates are treated + // as CA or intermediate certificates. Self-signed certificates in the + // preceding set are used as trust anchors. + // NOTE: CRL and OCSP revocation checking is not performed (matching OPA). 
+ VerifyCertsResult parse_and_verify_certificates(std::string_view input); + + struct RSAPrivateKeyJWK + { + std::string kty; // "RSA" + std::string e, n, d, p, q, dp, dq, qi; // base64url-encoded + }; + + struct ParseRSAKeyResult + { + RSAPrivateKeyJWK key; + std::string error; // non-empty on failure + }; + + // Parse a PEM-encoded RSA private key and return as JWK. + ParseRSAKeyResult parse_rsa_private_key(std::string_view input); + + struct ParsePrivateKeysResult + { + std::vector keys; + bool is_empty_input; // true if input was empty string → null result + std::string error; // non-empty on failure + }; + + // Parse one or more PEM private keys. + ParsePrivateKeysResult parse_private_keys(std::string_view input); +} + +#endif // REGOCPP_HAS_CRYPTO diff --git a/src/builtins/crypto_mbedtls.cc b/src/builtins/crypto_mbedtls.cc new file mode 100644 index 00000000..171bbdb8 --- /dev/null +++ b/src/builtins/crypto_mbedtls.cc @@ -0,0 +1,1982 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#ifdef REGOCPP_CRYPTO_MBEDTLS + +#include "base64/base64.h" +#include "crypto_core.hh" +#include "crypto_utils.hh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace +{ + using rego::crypto_core::to_hex; + + // ── RAII wrappers ── + + struct MdCtx + { + mbedtls_md_context_t ctx; + MdCtx() + { + mbedtls_md_init(&ctx); + } + ~MdCtx() + { + mbedtls_md_free(&ctx); + } + MdCtx(const MdCtx&) = delete; + MdCtx& operator=(const MdCtx&) = delete; + }; + + struct PkCtx + { + mbedtls_pk_context ctx; + PkCtx() + { + mbedtls_pk_init(&ctx); + } + ~PkCtx() + { + mbedtls_pk_free(&ctx); + } + PkCtx(const PkCtx&) = delete; + PkCtx& operator=(const PkCtx&) = delete; + }; + + struct Mpi + { + mbedtls_mpi val; + Mpi() + { + mbedtls_mpi_init(&val); + } + ~Mpi() + { + mbedtls_mpi_free(&val); + } + Mpi(const Mpi&) = delete; + Mpi& operator=(const Mpi&) = delete; + }; + + struct EntropyCtx + { + mbedtls_entropy_context ctx; + EntropyCtx() + { + mbedtls_entropy_init(&ctx); + } + ~EntropyCtx() + { + mbedtls_entropy_free(&ctx); + } + EntropyCtx(const EntropyCtx&) = delete; + EntropyCtx& operator=(const EntropyCtx&) = delete; + }; + + struct CtrDrbgCtx + { + mbedtls_ctr_drbg_context ctx; + CtrDrbgCtx() + { + mbedtls_ctr_drbg_init(&ctx); + } + ~CtrDrbgCtx() + { + mbedtls_ctr_drbg_free(&ctx); + } + CtrDrbgCtx(const CtrDrbgCtx&) = delete; + CtrDrbgCtx& operator=(const CtrDrbgCtx&) = delete; + }; + + struct X509Crt + { + mbedtls_x509_crt crt; + X509Crt() + { + mbedtls_x509_crt_init(&crt); + } + ~X509Crt() + { + mbedtls_x509_crt_free(&crt); + } + X509Crt(const X509Crt&) = delete; + X509Crt& operator=(const X509Crt&) = delete; + }; + + struct X509Csr + { + mbedtls_x509_csr csr; + X509Csr() + { + mbedtls_x509_csr_init(&csr); + } + ~X509Csr() + { + mbedtls_x509_csr_free(&csr); + } + X509Csr(const X509Csr&) = delete; + X509Csr& operator=(const X509Csr&) = 
delete; + }; + + // ── RNG setup (shared) ── + + struct Rng + { + EntropyCtx entropy; + CtrDrbgCtx drbg; + + Rng() + { + int ret = mbedtls_ctr_drbg_seed( + &drbg.ctx, mbedtls_entropy_func, &entropy.ctx, nullptr, 0); + if (ret != 0) + { + throw std::runtime_error("mbedtls_ctr_drbg_seed failed"); + } + } + }; + + Rng& get_rng() + { + static Rng rng; + return rng; + } + + // ── Hash helpers ── + + const mbedtls_md_info_t* md_info_from_type(mbedtls_md_type_t type) + { + return mbedtls_md_info_from_type(type); + } + + std::string digest_hex(mbedtls_md_type_t type, std::string_view data) + { + const mbedtls_md_info_t* info = md_info_from_type(type); + if (!info) + { + throw std::runtime_error("unsupported digest type"); + } + + unsigned char buf[MBEDTLS_MD_MAX_SIZE]; + int ret = mbedtls_md( + info, + reinterpret_cast(data.data()), + data.size(), + buf); + if (ret != 0) + { + throw std::runtime_error("mbedtls_md failed"); + } + + return to_hex(buf, mbedtls_md_get_size(info)); + } + + std::string hmac_hex( + mbedtls_md_type_t type, std::string_view key, std::string_view data) + { + const mbedtls_md_info_t* info = md_info_from_type(type); + if (!info) + { + throw std::runtime_error("unsupported digest type for HMAC"); + } + + unsigned char buf[MBEDTLS_MD_MAX_SIZE]; + int ret = mbedtls_md_hmac( + info, + reinterpret_cast(key.data()), + key.size(), + reinterpret_cast(data.data()), + data.size(), + buf); + if (ret != 0) + { + throw std::runtime_error("mbedtls_md_hmac failed"); + } + + return to_hex(buf, mbedtls_md_get_size(info)); + } + + // ── Algorithm mapping ── + + using rego::crypto_core::Algorithm; + + mbedtls_md_type_t md_type_for_algo(Algorithm algo) + { + switch (algo) + { + case Algorithm::HS256: + case Algorithm::RS256: + case Algorithm::PS256: + case Algorithm::ES256: + return MBEDTLS_MD_SHA256; + case Algorithm::HS384: + case Algorithm::RS384: + case Algorithm::PS384: + case Algorithm::ES384: + return MBEDTLS_MD_SHA384; + case Algorithm::HS512: + case 
Algorithm::RS512: + case Algorithm::PS512: + case Algorithm::ES512: + return MBEDTLS_MD_SHA512; + case Algorithm::EdDSA: + return MBEDTLS_MD_NONE; + } + return MBEDTLS_MD_NONE; + } + + using rego::crypto_core::extract_jwks_keys; + using rego::crypto_core::json_select_string; + using rego::crypto_core::parse_json; + + // ── Key loading: PEM public keys and certificates ── + + bool pk_from_pem_pubkey(PkCtx& pk, std::string_view pem) + { + int ret = mbedtls_pk_parse_public_key( + &pk.ctx, + reinterpret_cast(pem.data()), + pem.size() + 1); // mbedtls requires null-terminated PEM + return ret == 0; + } + + bool pk_from_certificate(PkCtx& pk, std::string_view cert_pem) + { + X509Crt crt; + int ret = mbedtls_x509_crt_parse( + &crt.crt, + reinterpret_cast(cert_pem.data()), + cert_pem.size() + 1); + if (ret != 0) + { + return false; + } + + // Copy the public key from the certificate. + // We need to export and re-import the key. + // 8192 bytes is sufficient for RSA keys up to 16384 bits. + unsigned char buf[8192]; + int len = mbedtls_pk_write_pubkey_der(&crt.crt.pk, buf, sizeof(buf)); + if (len < 0) + { + return false; + } + + // mbedtls_pk_write_pubkey_der writes from end of buffer + ret = mbedtls_pk_parse_public_key( + &pk.ctx, buf + sizeof(buf) - len, static_cast(len)); + return ret == 0; + } + + // ── Key loading: JWK ── + + // Load MPI from base64url-encoded big integer + bool mpi_from_base64url(Mpi& mpi, std::string_view b64) + { + rego::crypto_core::SecureString raw(::base64_decode(b64)); + return mbedtls_mpi_read_binary( + &mpi.val, + reinterpret_cast(raw.data()), + raw.size()) == 0; + } + + // Parse a JWK RSA key into a PK context (public key only) + bool pk_from_jwk_rsa( + PkCtx& pk, std::string_view n_b64, std::string_view e_b64) + { + Mpi n, e; + if (!mpi_from_base64url(n, n_b64) || !mpi_from_base64url(e, e_b64)) + { + return false; + } + + int ret = + mbedtls_pk_setup(&pk.ctx, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)); + if (ret != 0) + { + return false; + } 
+ + mbedtls_rsa_context* rsa = mbedtls_pk_rsa(pk.ctx); + ret = mbedtls_rsa_import(rsa, &n.val, nullptr, nullptr, nullptr, &e.val); + if (ret != 0) + { + return false; + } + return mbedtls_rsa_complete(rsa) == 0; + } + + // Map JWK curve name to mbedtls group ID + mbedtls_ecp_group_id ec_group_from_crv(std::string_view crv) + { + if (crv == "P-256") + return MBEDTLS_ECP_DP_SECP256R1; + if (crv == "P-384") + return MBEDTLS_ECP_DP_SECP384R1; + if (crv == "P-521") + return MBEDTLS_ECP_DP_SECP521R1; + return MBEDTLS_ECP_DP_NONE; + } + + // Parse a JWK EC key into a PK context (public key only) + bool pk_from_jwk_ec( + PkCtx& pk, + std::string_view crv, + std::string_view x_b64, + std::string_view y_b64) + { + mbedtls_ecp_group_id grp_id = ec_group_from_crv(crv); + if (grp_id == MBEDTLS_ECP_DP_NONE) + { + return false; + } + + std::string x_raw = ::base64_decode(x_b64); + std::string y_raw = ::base64_decode(y_b64); + + // Build uncompressed point: 0x04 || x || y + std::vector point; + point.reserve(1 + x_raw.size() + y_raw.size()); + point.push_back(0x04); + point.insert( + point.end(), + reinterpret_cast(x_raw.data()), + reinterpret_cast(x_raw.data()) + x_raw.size()); + point.insert( + point.end(), + reinterpret_cast(y_raw.data()), + reinterpret_cast(y_raw.data()) + y_raw.size()); + + int ret = + mbedtls_pk_setup(&pk.ctx, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)); + if (ret != 0) + { + return false; + } + + // Mbed TLS 3.6: use opaque API for EC keypair + mbedtls_ecp_keypair* ec = mbedtls_pk_ec(pk.ctx); + mbedtls_ecp_group grp; + mbedtls_ecp_group_init(&grp); + ret = mbedtls_ecp_group_load(&grp, grp_id); + if (ret != 0) + { + mbedtls_ecp_group_free(&grp); + return false; + } + + mbedtls_ecp_point Q; + mbedtls_ecp_point_init(&Q); + ret = mbedtls_ecp_point_read_binary(&grp, &Q, point.data(), point.size()); + mbedtls_ecp_group_free(&grp); + if (ret != 0) + { + mbedtls_ecp_point_free(&Q); + return false; + } + + ret = mbedtls_ecp_set_public_key(grp_id, ec, &Q); + 
mbedtls_ecp_point_free(&Q); + return ret == 0; + } + + // Dispatch JWK AST to appropriate key loader + bool pk_from_jwk_ast(PkCtx& pk, const trieste::Node& ast) + { + using rego::crypto_core::MaxECComponentB64Len; + using rego::crypto_core::MaxRSAComponentB64Len; + std::string_view kty = json_select_string(ast, "/kty"); + if (kty == "RSA") + { + std::string_view n = json_select_string(ast, "/n"); + std::string_view e = json_select_string(ast, "/e"); + if ( + n.empty() || e.empty() || n.size() > MaxRSAComponentB64Len || + e.size() > MaxRSAComponentB64Len) + { + return false; + } + return pk_from_jwk_rsa(pk, n, e); + } + if (kty == "EC") + { + std::string_view crv = json_select_string(ast, "/crv"); + std::string_view x = json_select_string(ast, "/x"); + std::string_view y = json_select_string(ast, "/y"); + if ( + crv.empty() || x.empty() || y.empty() || + x.size() > MaxECComponentB64Len || y.size() > MaxECComponentB64Len) + { + return false; + } + return pk_from_jwk_ec(pk, crv, x, y); + } + // OKP (Ed25519) not supported in mbedtls backend + return false; + } + + // Auto-detect key format and load PK context. 
+ bool load_public_key( + PkCtx& pk, std::string_view key_or_cert, std::string_view kid = {}) + { + // PEM certificate + if ( + key_or_cert.find("-----BEGIN CERTIFICATE-----") != std::string_view::npos) + { + return pk_from_certificate(pk, key_or_cert); + } + + // PEM public key + if ( + key_or_cert.find("-----BEGIN PUBLIC KEY-----") != std::string_view::npos) + { + return pk_from_pem_pubkey(pk, key_or_cert); + } + + // Try JSON (JWK or JWKS) + auto ast = parse_json(key_or_cert); + if (!ast) + { + return false; + } + + // Try JWKS first + auto keys = extract_jwks_keys(ast); + if (!keys.empty()) + { + for (auto& key_ast : keys) + { + if (kid.empty() || json_select_string(key_ast, "/kid") == kid) + { + return pk_from_jwk_ast(pk, key_ast); + } + } + return false; + } + + // Single JWK + std::string_view kty = json_select_string(ast, "/kty"); + if (!kty.empty()) + { + return pk_from_jwk_ast(pk, ast); + } + + return false; + } + + // ── Signature verification helpers ── + + bool verify_hmac( + Algorithm algo, + std::string_view signing_input, + std::string_view sig_bytes, + std::string_view secret) + { + mbedtls_md_type_t type = md_type_for_algo(algo); + const mbedtls_md_info_t* info = md_info_from_type(type); + if (!info) + { + return false; + } + + unsigned char buf[MBEDTLS_MD_MAX_SIZE]; + int ret = mbedtls_md_hmac( + info, + reinterpret_cast(secret.data()), + secret.size(), + reinterpret_cast(signing_input.data()), + signing_input.size(), + buf); + if (ret != 0) + { + return false; + } + + size_t md_size = mbedtls_md_get_size(info); + if (md_size != sig_bytes.size()) + { + return false; + } + + // Constant-time comparison + volatile unsigned char result = 0; + for (size_t i = 0; i < md_size; ++i) + { + result |= static_cast(buf[i]) ^ + static_cast(sig_bytes[i]); + } + return result == 0; + } + + bool verify_rsa_pkcs1( + mbedtls_md_type_t md_type, + PkCtx& pk, + std::string_view signing_input, + std::string_view sig_bytes) + { + const mbedtls_md_info_t* info = 
md_info_from_type(md_type); + if (!info) + { + return false; + } + + // Hash the signing input + unsigned char hash[MBEDTLS_MD_MAX_SIZE]; + int ret = mbedtls_md( + info, + reinterpret_cast(signing_input.data()), + signing_input.size(), + hash); + if (ret != 0) + { + return false; + } + + size_t hash_len = mbedtls_md_get_size(info); + + // Ensure PKCS#1 v1.5 padding is set. The cached PkCtx may have been + // previously configured for PSS (PKCS_V21) by verify_rsa_pss. + mbedtls_rsa_context* rsa = mbedtls_pk_rsa(pk.ctx); + if (rsa) + { + mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15, md_type); + } + + // Verify signature + ret = mbedtls_pk_verify( + &pk.ctx, + md_type, + hash, + hash_len, + reinterpret_cast(sig_bytes.data()), + sig_bytes.size()); + return ret == 0; + } + + bool verify_rsa_pss( + mbedtls_md_type_t md_type, + PkCtx& pk, + std::string_view signing_input, + std::string_view sig_bytes) + { + const mbedtls_md_info_t* info = md_info_from_type(md_type); + if (!info) + { + return false; + } + + unsigned char hash[MBEDTLS_MD_MAX_SIZE]; + int ret = mbedtls_md( + info, + reinterpret_cast(signing_input.data()), + signing_input.size(), + hash); + if (ret != 0) + { + return false; + } + + size_t hash_len = mbedtls_md_get_size(info); + + // Use RSA-PSS verification + mbedtls_rsa_context* rsa = mbedtls_pk_rsa(pk.ctx); + if (!rsa) + { + return false; + } + + ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_type); + if (ret != 0) + { + return false; + } + + // mbedtls_rsa_rsassa_pss_verify reads exactly rsa->len bytes from the + // signature buffer (the RSA modulus size). JWT signatures decoded from + // base64url may be shorter if leading zero bytes were stripped. Pad the + // signature to the expected length to avoid a heap-buffer-overflow. 
+ size_t key_len = mbedtls_rsa_get_len(rsa); + std::vector sig_padded(key_len, 0); + if (sig_bytes.size() > key_len) + { + return false; + } + std::memcpy( + sig_padded.data() + (key_len - sig_bytes.size()), + sig_bytes.data(), + sig_bytes.size()); + + ret = mbedtls_rsa_rsassa_pss_verify( + rsa, + md_type, + static_cast(hash_len), + hash, + sig_padded.data()); + return ret == 0; + } + + bool verify_ecdsa( + mbedtls_md_type_t md_type, + PkCtx& pk, + std::string_view signing_input, + std::string_view sig_bytes) + { + // Validate signature length against the expected curve size. + // ES256 (SHA-256, P-256): 64 bytes, ES384 (SHA-384, P-384): 96 bytes, + // ES512 (SHA-512, P-521): 132 bytes. + size_t expected_sig_len = 0; + switch (md_type) + { + case MBEDTLS_MD_SHA256: + expected_sig_len = 64; + break; + case MBEDTLS_MD_SHA384: + expected_sig_len = 96; + break; + case MBEDTLS_MD_SHA512: + expected_sig_len = 132; + break; + default: + return false; + } + if (sig_bytes.empty() || sig_bytes.size() != expected_sig_len) + { + return false; + } + + const mbedtls_md_info_t* info = md_info_from_type(md_type); + if (!info) + { + return false; + } + + unsigned char hash[MBEDTLS_MD_MAX_SIZE]; + int ret = mbedtls_md( + info, + reinterpret_cast(signing_input.data()), + signing_input.size(), + hash); + if (ret != 0) + { + return false; + } + + size_t hash_len = mbedtls_md_get_size(info); + + // JWT ECDSA signatures are raw R||S; mbedtls_pk_verify expects DER. 
+ // Convert raw to DER ASN.1 SEQUENCE { INTEGER r, INTEGER s } + size_t half = sig_bytes.size() / 2; + Mpi r, s; + mbedtls_mpi_read_binary( + &r.val, reinterpret_cast(sig_bytes.data()), half); + mbedtls_mpi_read_binary( + &s.val, + reinterpret_cast(sig_bytes.data()) + half, + half); + + // Encode as DER: SEQUENCE { INTEGER r, INTEGER s } + unsigned char der_buf[256]; + unsigned char* p = der_buf + sizeof(der_buf); + size_t len = 0; + + // Write s + ret = mbedtls_asn1_write_mpi(&p, der_buf, &s.val); + if (ret < 0) + { + return false; + } + len += static_cast(ret); + + // Write r + ret = mbedtls_asn1_write_mpi(&p, der_buf, &r.val); + if (ret < 0) + { + return false; + } + len += static_cast(ret); + + // Write SEQUENCE tag + length + ret = mbedtls_asn1_write_len(&p, der_buf, len); + if (ret < 0) + { + return false; + } + len += static_cast(ret); + + ret = mbedtls_asn1_write_tag( + &p, der_buf, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE); + if (ret < 0) + { + return false; + } + len += static_cast(ret); + + ret = mbedtls_pk_verify(&pk.ctx, md_type, hash, hash_len, p, len); + return ret == 0; + } + + // ── PEM validation ── + + // ── Signing: private key loading ── + + bool pk_from_jwk_rsa_private( + PkCtx& pk, + std::string_view n_b64, + std::string_view e_b64, + std::string_view d_b64, + std::string_view p_b64, + std::string_view q_b64, + std::string_view dp_b64, + std::string_view dq_b64, + std::string_view qi_b64) + { + Mpi n, e, d, p, q; + + if ( + !mpi_from_base64url(n, n_b64) || !mpi_from_base64url(e, e_b64) || + !mpi_from_base64url(d, d_b64)) + { + return false; + } + + // p, q, dp, dq, qi are optional for mbedtls_rsa_complete + bool have_pq = !p_b64.empty() && !q_b64.empty() && + mpi_from_base64url(p, p_b64) && mpi_from_base64url(q, q_b64); + + int ret = + mbedtls_pk_setup(&pk.ctx, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)); + if (ret != 0) + { + return false; + } + + mbedtls_rsa_context* rsa = mbedtls_pk_rsa(pk.ctx); + ret = mbedtls_rsa_import( + rsa, + 
&n.val, + have_pq ? &p.val : nullptr, + have_pq ? &q.val : nullptr, + &d.val, + &e.val); + if (ret != 0) + { + return false; + } + + // mbedtls_rsa_complete will derive dp, dq, qi from d, p, q + return mbedtls_rsa_complete(rsa) == 0; + } + + bool pk_from_jwk_ec_private( + PkCtx& pk, + std::string_view crv, + std::string_view x_b64, + std::string_view y_b64, + std::string_view d_b64) + { + mbedtls_ecp_group_id grp_id = ec_group_from_crv(crv); + if (grp_id == MBEDTLS_ECP_DP_NONE) + { + return false; + } + + std::string x_raw = ::base64_decode(x_b64); + std::string y_raw = ::base64_decode(y_b64); + rego::crypto_core::SecureString d_raw(::base64_decode(d_b64)); + + // Build uncompressed point + std::vector point; + point.reserve(1 + x_raw.size() + y_raw.size()); + point.push_back(0x04); + point.insert( + point.end(), + reinterpret_cast(x_raw.data()), + reinterpret_cast(x_raw.data()) + x_raw.size()); + point.insert( + point.end(), + reinterpret_cast(y_raw.data()), + reinterpret_cast(y_raw.data()) + y_raw.size()); + + int ret = + mbedtls_pk_setup(&pk.ctx, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)); + if (ret != 0) + { + return false; + } + + // Set private key + mbedtls_ecp_keypair* ec = mbedtls_pk_ec(pk.ctx); + ret = mbedtls_ecp_read_key( + grp_id, + ec, + reinterpret_cast(d_raw.data()), + d_raw.size()); + if (ret != 0) + { + return false; + } + + // Set public key + mbedtls_ecp_group grp; + mbedtls_ecp_group_init(&grp); + ret = mbedtls_ecp_group_load(&grp, grp_id); + if (ret != 0) + { + mbedtls_ecp_group_free(&grp); + return false; + } + + mbedtls_ecp_point Q; + mbedtls_ecp_point_init(&Q); + ret = mbedtls_ecp_point_read_binary(&grp, &Q, point.data(), point.size()); + mbedtls_ecp_group_free(&grp); + if (ret != 0) + { + mbedtls_ecp_point_free(&Q); + return false; + } + + ret = mbedtls_ecp_set_public_key(grp_id, ec, &Q); + mbedtls_ecp_point_free(&Q); + return ret == 0; + } + + bool load_private_key_ast( + PkCtx& pk, std::string& ed25519_raw, const trieste::Node& ast) + 
{ + using rego::crypto_core::MaxECComponentB64Len; + using rego::crypto_core::MaxOKPComponentB64Len; + using rego::crypto_core::MaxRSAComponentB64Len; + std::string_view kty = json_select_string(ast, "/kty"); + if (kty == "RSA") + { + std::string_view n = json_select_string(ast, "/n"); + std::string_view e = json_select_string(ast, "/e"); + std::string_view d = json_select_string(ast, "/d"); + std::string_view p = json_select_string(ast, "/p"); + std::string_view q = json_select_string(ast, "/q"); + std::string_view dp = json_select_string(ast, "/dp"); + std::string_view dq = json_select_string(ast, "/dq"); + std::string_view qi = json_select_string(ast, "/qi"); + if (n.empty() || e.empty() || d.empty()) + { + return false; + } + for (auto sv : {n, e, d, p, q, dp, dq, qi}) + { + if (sv.size() > MaxRSAComponentB64Len) + { + return false; + } + } + return pk_from_jwk_rsa_private(pk, n, e, d, p, q, dp, dq, qi); + } + if (kty == "EC") + { + std::string_view crv = json_select_string(ast, "/crv"); + std::string_view x = json_select_string(ast, "/x"); + std::string_view y = json_select_string(ast, "/y"); + std::string_view d = json_select_string(ast, "/d"); + if (crv.empty() || x.empty() || y.empty() || d.empty()) + { + return false; + } + if ( + x.size() > MaxECComponentB64Len || y.size() > MaxECComponentB64Len || + d.size() > MaxECComponentB64Len) + { + return false; + } + return pk_from_jwk_ec_private(pk, crv, x, y, d); + } + if (kty == "OKP") + { + std::string_view d = json_select_string(ast, "/d"); + if (d.empty() || d.size() > MaxOKPComponentB64Len) + { + return false; + } + ed25519_raw = ::base64_decode(d); + return true; + } + return false; + } + + // ── Signing helpers ── + + std::string sign_hmac( + Algorithm algo, std::string_view signing_input, std::string_view secret) + { + mbedtls_md_type_t type = md_type_for_algo(algo); + const mbedtls_md_info_t* info = md_info_from_type(type); + if (!info) + { + throw std::runtime_error("unsupported digest type for HMAC 
signing"); + } + + unsigned char buf[MBEDTLS_MD_MAX_SIZE]; + int ret = mbedtls_md_hmac( + info, + reinterpret_cast(secret.data()), + secret.size(), + reinterpret_cast(signing_input.data()), + signing_input.size(), + buf); + if (ret != 0) + { + throw std::runtime_error("HMAC signing failed"); + } + + size_t md_size = mbedtls_md_get_size(info); + return std::string(reinterpret_cast(buf), md_size); + } + + std::string sign_rsa( + mbedtls_md_type_t md_type, + PkCtx& pk, + std::string_view signing_input, + bool use_pss) + { + const mbedtls_md_info_t* info = md_info_from_type(md_type); + if (!info) + { + throw std::runtime_error("unsupported digest type for RSA signing"); + } + + unsigned char hash[MBEDTLS_MD_MAX_SIZE]; + int ret = mbedtls_md( + info, + reinterpret_cast(signing_input.data()), + signing_input.size(), + hash); + if (ret != 0) + { + throw std::runtime_error("hash for RSA signing failed"); + } + + size_t hash_len = mbedtls_md_get_size(info); + mbedtls_rsa_context* rsa = mbedtls_pk_rsa(pk.ctx); + + if (use_pss) + { + mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_type); + } + else + { + mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15, md_type); + } + + size_t sig_len = mbedtls_rsa_get_len(rsa); + std::vector sig(sig_len); + + auto& rng = get_rng(); + if (use_pss) + { + ret = mbedtls_rsa_rsassa_pss_sign( + rsa, + mbedtls_ctr_drbg_random, + &rng.drbg.ctx, + md_type, + static_cast(hash_len), + hash, + sig.data()); + } + else + { + ret = mbedtls_rsa_rsassa_pkcs1_v15_sign( + rsa, + mbedtls_ctr_drbg_random, + &rng.drbg.ctx, + md_type, + static_cast(hash_len), + hash, + sig.data()); + } + + if (ret != 0) + { + throw std::runtime_error("RSA signing failed"); + } + + return std::string(reinterpret_cast(sig.data()), sig.size()); + } + + size_t ecdsa_component_size(Algorithm algo) + { + switch (algo) + { + case Algorithm::ES256: + return 32; + case Algorithm::ES384: + return 48; + case Algorithm::ES512: + return 66; + default: + return 0; + } + } + + 
std::string sign_ecdsa( + mbedtls_md_type_t md_type, + PkCtx& pk, + std::string_view signing_input, + Algorithm algo) + { + const mbedtls_md_info_t* info = md_info_from_type(md_type); + if (!info) + { + throw std::runtime_error("unsupported digest type for ECDSA signing"); + } + + unsigned char hash[MBEDTLS_MD_MAX_SIZE]; + int ret = mbedtls_md( + info, + reinterpret_cast(signing_input.data()), + signing_input.size(), + hash); + if (ret != 0) + { + throw std::runtime_error("hash for ECDSA signing failed"); + } + + size_t hash_len = mbedtls_md_get_size(info); + + // Sign — produces DER-encoded signature + unsigned char der_sig[256]; + size_t der_len = 0; + auto& rng = get_rng(); + + ret = mbedtls_pk_sign( + &pk.ctx, + md_type, + hash, + hash_len, + der_sig, + sizeof(der_sig), + &der_len, + mbedtls_ctr_drbg_random, + &rng.drbg.ctx); + if (ret != 0) + { + throw std::runtime_error("ECDSA signing failed"); + } + + // Convert DER to raw R||S + // Parse the DER SEQUENCE { INTEGER r, INTEGER s } + unsigned char* p = der_sig; + unsigned char* end = der_sig + der_len; + size_t seq_len; + ret = mbedtls_asn1_get_tag( + &p, end, &seq_len, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE); + if (ret != 0) + { + throw std::runtime_error("failed to parse ECDSA DER signature"); + } + + Mpi r, s; + ret = mbedtls_asn1_get_mpi(&p, end, &r.val); + if (ret != 0) + { + throw std::runtime_error("failed to parse ECDSA DER r"); + } + ret = mbedtls_asn1_get_mpi(&p, end, &s.val); + if (ret != 0) + { + throw std::runtime_error("failed to parse ECDSA DER s"); + } + + size_t comp_size = ecdsa_component_size(algo); + std::vector raw(comp_size * 2, 0); + mbedtls_mpi_write_binary(&r.val, raw.data(), comp_size); + mbedtls_mpi_write_binary(&s.val, raw.data() + comp_size, comp_size); + + return std::string(reinterpret_cast(raw.data()), raw.size()); + } + + // ── X.509 helpers ── + + std::string get_common_name(const mbedtls_x509_name* name) + { + if (!name) + { + return {}; + } + + // Walk the name 
list looking for OID 2.5.4.3 (CommonName) + const mbedtls_x509_name* cur = name; + while (cur) + { + // OID for CN: 0x55 0x04 0x03 + if ( + cur->oid.len == 3 && cur->oid.p[0] == 0x55 && cur->oid.p[1] == 0x04 && + cur->oid.p[2] == 0x03) + { + return std::string( + reinterpret_cast(cur->val.p), cur->val.len); + } + cur = cur->next; + } + return {}; + } + + void extract_sans( + const mbedtls_x509_crt& crt, + std::vector& dns_names, + std::vector& uri_strings) + { + const mbedtls_x509_sequence* cur = &crt.subject_alt_names; + while (cur) + { + if (cur->buf.len == 0) + { + cur = cur->next; + continue; + } + + // The tag encodes the SAN type: + // tag & 0x1F == 2 → dNSName + // tag & 0x1F == 6 → uniformResourceIdentifier + unsigned char tag = cur->buf.tag; + int san_type = tag & 0x1F; + + if (san_type == 2) // dNSName + { + // Need to parse the ASN.1 value from the raw buffer + // mbedtls stores the raw ASN.1 data; we need to extract the string + mbedtls_x509_subject_alternative_name san; + int ret = mbedtls_x509_parse_subject_alt_name(&cur->buf, &san); + if (ret == 0 && san.type == MBEDTLS_X509_SAN_DNS_NAME) + { + dns_names.emplace_back( + reinterpret_cast(san.san.unstructured_name.p), + san.san.unstructured_name.len); + } + mbedtls_x509_free_subject_alt_name(&san); + } + else if (san_type == 6) // URI + { + mbedtls_x509_subject_alternative_name san; + int ret = mbedtls_x509_parse_subject_alt_name(&cur->buf, &san); + if ( + ret == 0 && san.type == MBEDTLS_X509_SAN_UNIFORM_RESOURCE_IDENTIFIER) + { + uri_strings.emplace_back( + reinterpret_cast(san.san.unstructured_name.p), + san.san.unstructured_name.len); + } + mbedtls_x509_free_subject_alt_name(&san); + } + + cur = cur->next; + } + } + + std::string cert_to_der_b64(const mbedtls_x509_crt& crt) + { + std::string_view der(reinterpret_cast(crt.raw.p), crt.raw.len); + return ::base64_encode(der, false); + } + + using rego::crypto_core::ParsedCertificate; + + ParsedCertificate cert_to_parsed(const mbedtls_x509_crt& crt) + { 
+ ParsedCertificate pc; + pc.subject.common_name = get_common_name(&crt.subject); + extract_sans(crt, pc.dns_names, pc.uri_strings); + pc.der_b64 = cert_to_der_b64(crt); + return pc; + } + + // MPI to base64url (for JWK export of RSA private keys) + std::string mpi_to_base64url(const mbedtls_mpi& mpi) + { + size_t len = mbedtls_mpi_size(&mpi); + std::vector buf(len); + mbedtls_mpi_write_binary(&mpi, buf.data(), len); + std::string_view sv(reinterpret_cast(buf.data()), buf.size()); + return rego::crypto_core::base64url_encode_nopad_impl(sv); + } +} + +// ── Public API implementation ── + +namespace rego::crypto_core +{ + std::string md5_hex(std::string_view data) + { + return digest_hex(MBEDTLS_MD_MD5, data); + } + + std::string sha1_hex(std::string_view data) + { + return digest_hex(MBEDTLS_MD_SHA1, data); + } + + std::string sha256_hex(std::string_view data) + { + return digest_hex(MBEDTLS_MD_SHA256, data); + } + + std::string hmac_md5_hex(std::string_view key, std::string_view data) + { + return hmac_hex(MBEDTLS_MD_MD5, key, data); + } + + std::string hmac_sha1_hex(std::string_view key, std::string_view data) + { + return hmac_hex(MBEDTLS_MD_SHA1, key, data); + } + + std::string hmac_sha256_hex(std::string_view key, std::string_view data) + { + return hmac_hex(MBEDTLS_MD_SHA256, key, data); + } + + std::string hmac_sha512_hex(std::string_view key, std::string_view data) + { + return hmac_hex(MBEDTLS_MD_SHA512, key, data); + } + + bool hmac_equal(std::string_view mac1, std::string_view mac2) + { + return hmac_equal_impl(mac1, mac2); + } + + std::string base64url_encode_nopad(std::string_view data) + { + return base64url_encode_nopad_impl(data); + } + + std::string base64url_decode(std::string_view data) + { + return base64url_decode_impl(data); + } + + Algorithm parse_algorithm(std::string_view name) + { + return parse_algorithm_impl(name); + } + + // Single-entry thread-local cache for parsed public keys. 
+ // Avoids re-parsing the same PEM/JWK key on repeated JWT verifications + // with the same issuer key. + struct PkCache + { + std::string key_str; + PkCtx* pk = nullptr; + + ~PkCache() + { + delete pk; + } + + PkCache() = default; + PkCache(const PkCache&) = delete; + PkCache& operator=(const PkCache&) = delete; + + // Returns a cached PkCtx if the key matches, otherwise parses + // the new key, caches it, and returns it. Returns nullptr on failure. + PkCtx* get(std::string_view key) + { + if (pk != nullptr && key_str == key) + { + return pk; + } + + delete pk; + pk = new PkCtx(); + key_str.assign(key.data(), key.size()); + + if (!load_public_key(*pk, key)) + { + delete pk; + pk = nullptr; + key_str.clear(); + return nullptr; + } + + return pk; + } + }; + + VerifyResult verify_signature( + Algorithm algo, + std::string_view signing_input, + std::string_view signature_bytes, + std::string_view key_or_cert) + { + // HMAC algorithms use the key directly as a secret + if ( + algo == Algorithm::HS256 || algo == Algorithm::HS384 || + algo == Algorithm::HS512) + { + bool ok = verify_hmac(algo, signing_input, signature_bytes, key_or_cert); + return {ok, {}}; + } + + // Validate PEM structure before attempting to parse + std::string pem_err = validate_pem(key_or_cert); + if (!pem_err.empty()) + { + return {false, pem_err}; + } + + // EdDSA: not supported in mbedtls backend + if (algo == Algorithm::EdDSA) + { + return {false, "EdDSA algorithm is not supported"}; + } + + // Asymmetric (RSA/ECDSA): load public key (cached) + thread_local PkCache pk_cache; + PkCtx* pk = pk_cache.get(key_or_cert); + if (pk == nullptr) + { + if ( + key_or_cert.find("-----BEGIN CERTIFICATE-----") != + std::string_view::npos) + { + return {false, "failed to parse a PEM certificate"}; + } + if ( + key_or_cert.find("-----BEGIN PUBLIC KEY-----") != + std::string_view::npos) + { + return {false, "failed to parse a PEM key"}; + } + if ( + key_or_cert.find("\"kty\"") != std::string_view::npos || + 
key_or_cert.find("\"keys\"") != std::string_view::npos) + { + return {false, "failed to parse a JWK key (set)"}; + } + return {false, {}}; + } + + mbedtls_md_type_t md_type = md_type_for_algo(algo); + bool ok = false; + + switch (algo) + { + case Algorithm::RS256: + case Algorithm::RS384: + case Algorithm::RS512: + ok = verify_rsa_pkcs1(md_type, *pk, signing_input, signature_bytes); + break; + + case Algorithm::PS256: + case Algorithm::PS384: + case Algorithm::PS512: + ok = verify_rsa_pss(md_type, *pk, signing_input, signature_bytes); + break; + + case Algorithm::ES256: + case Algorithm::ES384: + case Algorithm::ES512: + ok = verify_ecdsa(md_type, *pk, signing_input, signature_bytes); + break; + + default: + break; + } + + return {ok, {}}; + } + + VerifyResult verify_signature_any_key( + Algorithm algo, + std::string_view signing_input, + std::string_view signature_bytes, + std::string_view key_or_cert) + { + auto ast = parse_json(key_or_cert); + if (!ast) + { + return verify_signature( + algo, signing_input, signature_bytes, key_or_cert); + } + + auto keys = extract_jwks_keys(ast); + if (keys.size() <= 1) + { + return verify_signature( + algo, signing_input, signature_bytes, key_or_cert); + } + + for (auto& key_ast : keys) + { + // EdDSA not supported in mbedtls backend + if (algo == Algorithm::EdDSA) + { + return {false, "EdDSA algorithm is not supported"}; + } + + PkCtx pk; + if (!pk_from_jwk_ast(pk, key_ast)) + { + continue; + } + + mbedtls_md_type_t md_type = md_type_for_algo(algo); + bool ok = false; + + switch (algo) + { + case Algorithm::RS256: + case Algorithm::RS384: + case Algorithm::RS512: + ok = verify_rsa_pkcs1(md_type, pk, signing_input, signature_bytes); + break; + case Algorithm::PS256: + case Algorithm::PS384: + case Algorithm::PS512: + ok = verify_rsa_pss(md_type, pk, signing_input, signature_bytes); + break; + case Algorithm::ES256: + case Algorithm::ES384: + case Algorithm::ES512: + ok = verify_ecdsa(md_type, pk, signing_input, 
signature_bytes); + break; + default: + break; + } + + if (ok) + { + return {true, {}}; + } + } + + return {false, {}}; + } + + std::string sign( + Algorithm algo, + std::string_view signing_input, + std::string_view key_jwk_json) + { + auto ast = parse_json(key_jwk_json); + if (!ast) + { + throw std::runtime_error("failed to parse JWK JSON"); + } + + // HMAC: extract the "k" field + if ( + algo == Algorithm::HS256 || algo == Algorithm::HS384 || + algo == Algorithm::HS512) + { + std::string_view k = json_select_string(ast, "/k"); + if (k.empty()) + { + throw std::runtime_error("missing 'k' in oct JWK"); + } + crypto_core::SecureString secret(::base64_decode(k)); + return sign_hmac(algo, signing_input, secret.value); + } + + if (algo == Algorithm::EdDSA) + { + throw std::runtime_error("EdDSA algorithm is not supported"); + } + + PkCtx pk; + crypto_core::SecureString ed25519_raw; + if (!load_private_key_ast(pk, ed25519_raw.value, ast)) + { + throw std::runtime_error("failed to load private key from JWK"); + } + + mbedtls_md_type_t md_type = md_type_for_algo(algo); + + switch (algo) + { + case Algorithm::RS256: + case Algorithm::RS384: + case Algorithm::RS512: + return sign_rsa(md_type, pk, signing_input, false); + + case Algorithm::PS256: + case Algorithm::PS384: + case Algorithm::PS512: + return sign_rsa(md_type, pk, signing_input, true); + + case Algorithm::ES256: + case Algorithm::ES384: + case Algorithm::ES512: + return sign_ecdsa(md_type, pk, signing_input, algo); + + default: + throw std::runtime_error("unsupported algorithm for signing"); + } + } + + // ── X.509 Certificate Parsing ── + + // Parse concatenated DER certificates into an X509Crt chain. + // mbedtls_x509_crt_parse_der only parses one cert, so we must loop. 
+ int parse_der_chain(X509Crt& chain, const unsigned char* data, size_t len) + { + const unsigned char* p = data; + const unsigned char* end = data + len; + int total_ret = 0; + bool any_ok = false; + + while (p < end) + { + // Each DER cert starts with SEQUENCE tag (0x30) followed by length + if (*p != 0x30) + { + break; + } + + // Peek at the length to determine cert size + const unsigned char* lp = p + 1; + size_t cert_len = 0; + if (lp >= end) + { + break; + } + + if (*lp < 0x80) + { + cert_len = *lp; + cert_len += 2; // tag + 1-byte length + } + else + { + size_t num_bytes = *lp & 0x7F; + if (num_bytes == 0 || num_bytes > 4 || lp + num_bytes >= end) + { + break; + } + for (size_t i = 0; i < num_bytes; ++i) + { + if (cert_len > (SIZE_MAX >> 8)) + { + return any_ok ? 0 : -1; // overflow guard + } + cert_len = (cert_len << 8) | lp[1 + i]; + } + cert_len += 2 + num_bytes; // tag + length-of-length + length bytes + } + + if (p + cert_len > end) + { + break; + } + + int ret = mbedtls_x509_crt_parse_der(&chain.crt, p, cert_len); + if (ret == 0) + { + any_ok = true; + } + else + { + total_ret = ret; + } + p += cert_len; + } + + return any_ok ? 
0 : total_ret; + } + + ParseCertsResult parse_certificates(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, decoded.error}; + } + + X509Crt chain; + int ret; + if (decoded.is_pem) + { + // mbedtls requires null-terminated PEM + std::string pem_str(decoded.data); + pem_str.push_back('\0'); + ret = mbedtls_x509_crt_parse( + &chain.crt, + reinterpret_cast(pem_str.data()), + pem_str.size()); + } + else + { + ret = parse_der_chain( + chain, + reinterpret_cast(decoded.data.data()), + decoded.data.size()); + } + + if (ret != 0 && chain.crt.raw.len == 0) + { + return {{}, "x509: malformed certificate"}; + } + + ParseCertsResult result; + const mbedtls_x509_crt* cur = &chain.crt; + while (cur && cur->raw.len > 0) + { + result.certs.push_back(cert_to_parsed(*cur)); + cur = cur->next; + } + + if (result.certs.empty()) + { + return {{}, "x509: malformed certificate"}; + } + return result; + } + + ParseCSRResult parse_certificate_request(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, decoded.error}; + } + + X509Csr csr; + int ret; + if (decoded.is_pem) + { + std::string pem_str(decoded.data); + pem_str.push_back('\0'); + ret = mbedtls_x509_csr_parse( + &csr.csr, + reinterpret_cast(pem_str.data()), + pem_str.size()); + } + else + { + ret = mbedtls_x509_csr_parse_der( + &csr.csr, + reinterpret_cast(decoded.data.data()), + decoded.data.size()); + } + + if (ret != 0) + { + return {{}, "asn1: structure error"}; + } + + ParseCSRResult result; + result.subject.common_name = get_common_name(&csr.csr.subject); + return result; + } + + VerifyCertsResult parse_and_verify_certificates(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {false, {}, decoded.error}; + } + + X509Crt chain; + int ret; + if (decoded.is_pem) + { + std::string pem_str(decoded.data); + 
pem_str.push_back('\0'); + ret = mbedtls_x509_crt_parse( + &chain.crt, + reinterpret_cast(pem_str.data()), + pem_str.size()); + } + else + { + ret = parse_der_chain( + chain, + reinterpret_cast(decoded.data.data()), + decoded.data.size()); + } + + // Collect all parsed certs and their DER data + struct CertInfo + { + ParsedCertificate parsed; + const unsigned char* raw_p; + size_t raw_len; + bool self_signed; + }; + std::vector all_certs; + const mbedtls_x509_crt* cur = &chain.crt; + while (cur && cur->raw.len > 0) + { + bool ss = + (cur->issuer_raw.len == cur->subject_raw.len && + memcmp(cur->issuer_raw.p, cur->subject_raw.p, cur->issuer_raw.len) == + 0); + all_certs.push_back({cert_to_parsed(*cur), cur->raw.p, cur->raw.len, ss}); + cur = cur->next; + } + + if (all_certs.size() < 2) + { + VerifyCertsResult result; + result.valid = false; + for (auto& ci : all_certs) + { + result.certs.push_back(ci.parsed); + } + return result; + } + + // OPA convention: last cert is the leaf. All others are CA or + // intermediate. Build a leaf chain (leaf + intermediates) and a separate + // CA chain (self-signed roots) for mbedtls_x509_crt_verify. + X509Crt leaf_chain; + X509Crt ca_chain; + + // Parse leaf (last cert) first into the leaf chain + size_t leaf_idx = all_certs.size() - 1; + mbedtls_x509_crt_parse_der( + &leaf_chain.crt, all_certs[leaf_idx].raw_p, all_certs[leaf_idx].raw_len); + + // All certs except the leaf: self-signed → CA chain, otherwise → append + // to leaf chain as intermediates + for (size_t i = 0; i < leaf_idx; ++i) + { + if (all_certs[i].self_signed) + { + mbedtls_x509_crt_parse_der( + &ca_chain.crt, all_certs[i].raw_p, all_certs[i].raw_len); + } + else + { + mbedtls_x509_crt_parse_der( + &leaf_chain.crt, all_certs[i].raw_p, all_certs[i].raw_len); + } + } + + uint32_t flags = 0; + // NOTE: CRL and OCSP revocation checking is not performed (matching OPA). + // A revoked certificate will be accepted if otherwise valid. 
+ ret = mbedtls_x509_crt_verify( + &leaf_chain.crt, + &ca_chain.crt, + nullptr, // no CRL + nullptr, // no CN to check + &flags, + nullptr, // no custom verification callback + nullptr); + + VerifyCertsResult result; + result.valid = (ret == 0); + + if (result.valid) + { + // Return leaf-first order (matching OPA): + // [leaf, intermediates..., root] + result.certs.push_back(all_certs[leaf_idx].parsed); + for (size_t i = 0; i < leaf_idx; ++i) + { + if (!all_certs[i].self_signed) + { + result.certs.push_back(all_certs[i].parsed); + } + } + for (size_t i = 0; i < leaf_idx; ++i) + { + if (all_certs[i].self_signed) + { + result.certs.push_back(all_certs[i].parsed); + } + } + } + else + { + for (auto& ci : all_certs) + { + result.certs.push_back(ci.parsed); + } + } + + return result; + } + + ParseRSAKeyResult parse_rsa_private_key(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, decoded.error}; + } + + PkCtx pk; + int ret; + + auto& rng = get_rng(); + if (decoded.is_pem) + { + std::string pem_str(decoded.data); + pem_str.push_back('\0'); + ret = mbedtls_pk_parse_key( + &pk.ctx, + reinterpret_cast(pem_str.data()), + pem_str.size(), + nullptr, + 0, + mbedtls_ctr_drbg_random, + &rng.drbg.ctx); + } + else + { + ret = mbedtls_pk_parse_key( + &pk.ctx, + reinterpret_cast(decoded.data.data()), + decoded.data.size(), + nullptr, + 0, + mbedtls_ctr_drbg_random, + &rng.drbg.ctx); + } + + if (ret != 0 || !mbedtls_pk_can_do(&pk.ctx, MBEDTLS_PK_RSA)) + { + return {{}, "failed to parse RSA private key"}; + } + + const mbedtls_rsa_context* rsa = mbedtls_pk_rsa(pk.ctx); + + RSAPrivateKeyJWK jwk; + jwk.kty = "RSA"; + + // Export RSA components. 
mbedtls_rsa_export gives us N, P, Q, D, E + Mpi n, p, q, d, e; + mbedtls_rsa_export(rsa, &n.val, &p.val, &q.val, &d.val, &e.val); + + jwk.n = mpi_to_base64url(n.val); + jwk.e = mpi_to_base64url(e.val); + jwk.d = mpi_to_base64url(d.val); + jwk.p = mpi_to_base64url(p.val); + jwk.q = mpi_to_base64url(q.val); + + // Compute dp, dq, qi from d, p, q + Mpi dp, dq, qi, one, pm1, qm1; + mbedtls_mpi_lset(&one.val, 1); + + // dp = d mod (p-1) + mbedtls_mpi_sub_mpi(&pm1.val, &p.val, &one.val); + mbedtls_mpi_mod_mpi(&dp.val, &d.val, &pm1.val); + + // dq = d mod (q-1) + mbedtls_mpi_sub_mpi(&qm1.val, &q.val, &one.val); + mbedtls_mpi_mod_mpi(&dq.val, &d.val, &qm1.val); + + // qi = q^(-1) mod p + mbedtls_mpi_inv_mod(&qi.val, &q.val, &p.val); + + jwk.dp = mpi_to_base64url(dp.val); + jwk.dq = mpi_to_base64url(dq.val); + jwk.qi = mpi_to_base64url(qi.val); + + return {jwk, {}}; + } + + ParsePrivateKeysResult parse_private_keys(std::string_view input) + { + if (input.empty()) + { + return {{}, true, {}}; + } + + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, false, {}}; + } + + if (!decoded.is_pem) + { + return {{}, false, {}}; + } + + ParsePrivateKeysResult result; + result.is_empty_input = false; + + // Extract all PEM private key blocks + auto rsa_blocks = extract_pem_der_blocks(decoded.data, "RSA PRIVATE KEY"); + auto pkcs8_blocks = extract_pem_der_blocks(decoded.data, "PRIVATE KEY"); + + auto& rng = get_rng(); + + auto try_parse_rsa = [&](const std::string& der) { + PkCtx pk; + int ret = mbedtls_pk_parse_key( + &pk.ctx, + reinterpret_cast(der.data()), + der.size(), + nullptr, + 0, + mbedtls_ctr_drbg_random, + &rng.drbg.ctx); + + if (ret != 0 || !mbedtls_pk_can_do(&pk.ctx, MBEDTLS_PK_RSA)) + { + return; + } + + const mbedtls_rsa_context* rsa = mbedtls_pk_rsa(pk.ctx); + Mpi n, p, q, d, e; + mbedtls_rsa_export(rsa, &n.val, &p.val, &q.val, &d.val, &e.val); + + RSAPrivateKeyJWK jwk; + jwk.kty = "RSA"; + jwk.n = mpi_to_base64url(n.val); + 
jwk.e = mpi_to_base64url(e.val); + jwk.d = mpi_to_base64url(d.val); + jwk.p = mpi_to_base64url(p.val); + jwk.q = mpi_to_base64url(q.val); + + Mpi dp, dq, qi, one, pm1, qm1; + mbedtls_mpi_lset(&one.val, 1); + mbedtls_mpi_sub_mpi(&pm1.val, &p.val, &one.val); + mbedtls_mpi_mod_mpi(&dp.val, &d.val, &pm1.val); + mbedtls_mpi_sub_mpi(&qm1.val, &q.val, &one.val); + mbedtls_mpi_mod_mpi(&dq.val, &d.val, &qm1.val); + mbedtls_mpi_inv_mod(&qi.val, &q.val, &p.val); + + jwk.dp = mpi_to_base64url(dp.val); + jwk.dq = mpi_to_base64url(dq.val); + jwk.qi = mpi_to_base64url(qi.val); + + result.keys.push_back(jwk); + }; + + for (auto& der : rsa_blocks) + { + try_parse_rsa(der); + } + for (auto& der : pkcs8_blocks) + { + try_parse_rsa(der); + } + + return result; + } +} + +#endif // REGOCPP_CRYPTO_MBEDTLS diff --git a/src/builtins/crypto_openssl3.cc b/src/builtins/crypto_openssl3.cc new file mode 100644 index 00000000..c84222d4 --- /dev/null +++ b/src/builtins/crypto_openssl3.cc @@ -0,0 +1,1732 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#ifdef REGOCPP_CRYPTO_OPENSSL3 + +#include "base64/base64.h" +#include "crypto_core.hh" +#include "crypto_utils.hh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace +{ + using rego::crypto_core::to_hex; + + // Safe wrapper around BIO_new_mem_buf that rejects inputs exceeding + // INT_MAX, preventing undefined behaviour from the size_t → int cast. 
+ BIO_ptr bio_from_mem(const void* data, size_t len) + { + if (len > static_cast(std::numeric_limits::max())) + { + return {nullptr, BIO_free}; + } + return {BIO_new_mem_buf(data, static_cast(len)), BIO_free}; + } + + std::string digest_hex(const EVP_MD* md, std::string_view data) + { + unsigned char buf[EVP_MAX_MD_SIZE]; + unsigned int len = 0; + + std::unique_ptr ctx( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + if (!ctx) + { + throw std::runtime_error("EVP_MD_CTX_new failed"); + } + + if ( + EVP_DigestInit_ex(ctx.get(), md, nullptr) != 1 || + EVP_DigestUpdate(ctx.get(), data.data(), data.size()) != 1 || + EVP_DigestFinal_ex(ctx.get(), buf, &len) != 1) + { + throw std::runtime_error("EVP_Digest failed"); + } + + return to_hex(buf, len); + } + + std::string hmac_hex( + const EVP_MD* md, std::string_view key, std::string_view data) + { + unsigned char buf[EVP_MAX_MD_SIZE]; + unsigned int len = 0; + + unsigned char* result = HMAC( + md, + key.data(), + static_cast(key.size()), + reinterpret_cast(data.data()), + data.size(), + buf, + &len); + if (!result) + { + throw std::runtime_error("HMAC failed"); + } + + return to_hex(buf, len); + } +} + +namespace rego::crypto_core +{ + std::string md5_hex(std::string_view data) + { + return digest_hex(EVP_md5(), data); + } + + std::string sha1_hex(std::string_view data) + { + return digest_hex(EVP_sha1(), data); + } + + std::string sha256_hex(std::string_view data) + { + return digest_hex(EVP_sha256(), data); + } + + std::string hmac_md5_hex(std::string_view key, std::string_view data) + { + return hmac_hex(EVP_md5(), key, data); + } + + std::string hmac_sha1_hex(std::string_view key, std::string_view data) + { + return hmac_hex(EVP_sha1(), key, data); + } + + std::string hmac_sha256_hex(std::string_view key, std::string_view data) + { + return hmac_hex(EVP_sha256(), key, data); + } + + std::string hmac_sha512_hex(std::string_view key, std::string_view data) + { + return hmac_hex(EVP_sha512(), key, data); + } + + bool 
hmac_equal(std::string_view mac1, std::string_view mac2) + { + return hmac_equal_impl(mac1, mac2); + } + + // ── Base64url ── + + std::string base64url_encode_nopad(std::string_view data) + { + return base64url_encode_nopad_impl(data); + } + + std::string base64url_decode(std::string_view data) + { + return base64url_decode_impl(data); + } + + // ── Algorithm parsing ── + + Algorithm parse_algorithm(std::string_view name) + { + return parse_algorithm_impl(name); + } + + // ── Key parsing helpers ── + + using EVP_PKEY_ptr = std::unique_ptr; + using BIO_ptr = std::unique_ptr; + using BIGNUM_ptr = std::unique_ptr; + using X509_ptr = std::unique_ptr; + + EVP_PKEY_ptr pkey_from_pem_pubkey(std::string_view pem) + { + BIO_ptr bio = bio_from_mem(pem.data(), pem.size()); + if (!bio) + { + return {nullptr, EVP_PKEY_free}; + } + EVP_PKEY* key = PEM_read_bio_PUBKEY(bio.get(), nullptr, nullptr, nullptr); + return {key, EVP_PKEY_free}; + } + + EVP_PKEY_ptr pkey_from_certificate(std::string_view cert_pem) + { + BIO_ptr bio = bio_from_mem(cert_pem.data(), cert_pem.size()); + if (!bio) + { + return {nullptr, EVP_PKEY_free}; + } + X509_ptr cert( + PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr), X509_free); + if (!cert) + { + return {nullptr, EVP_PKEY_free}; + } + EVP_PKEY* key = X509_get_pubkey(cert.get()); + return {key, EVP_PKEY_free}; + } + + // Decode a base64url-encoded big integer into a BIGNUM + BIGNUM_ptr bn_from_base64url(std::string_view b64) + { + crypto_core::SecureString raw(::base64_decode(b64)); + BIGNUM* bn = BN_bin2bn( + reinterpret_cast(raw.data()), raw.size(), nullptr); + return {bn, BN_free}; + } + + // Parse a JWK RSA key (kty=RSA) into an EVP_PKEY (public key only) + // Uses OpenSSL 3.0 EVP_PKEY_fromdata API (no deprecated + // RSA_new/RSA_set0_key). 
+ EVP_PKEY_ptr pkey_from_jwk_rsa(std::string_view n_b64, std::string_view e_b64) + { + BIGNUM_ptr n = bn_from_base64url(n_b64); + BIGNUM_ptr e = bn_from_base64url(e_b64); + if (!n || !e) + { + return {nullptr, EVP_PKEY_free}; + } + + std::unique_ptr bld( + OSSL_PARAM_BLD_new(), OSSL_PARAM_BLD_free); + if (!bld) + { + return {nullptr, EVP_PKEY_free}; + } + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_RSA_N, n.get()); + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_RSA_E, e.get()); + + OSSL_PARAM* params = OSSL_PARAM_BLD_to_param(bld.get()); + if (!params) + { + return {nullptr, EVP_PKEY_free}; + } + + std::unique_ptr ctx( + EVP_PKEY_CTX_new_from_name(nullptr, "RSA", nullptr), EVP_PKEY_CTX_free); + if (!ctx || EVP_PKEY_fromdata_init(ctx.get()) != 1) + { + OSSL_PARAM_free(params); + return {nullptr, EVP_PKEY_free}; + } + + EVP_PKEY* raw_pkey = nullptr; + int rc = + EVP_PKEY_fromdata(ctx.get(), &raw_pkey, EVP_PKEY_PUBLIC_KEY, params); + OSSL_PARAM_free(params); + if (rc != 1 || !raw_pkey) + { + return {nullptr, EVP_PKEY_free}; + } + return {raw_pkey, EVP_PKEY_free}; + } + + // Map JWK curve name to OpenSSL group name for EVP_PKEY_fromdata. + const char* ec_group_name(std::string_view crv) + { + if (crv == "P-256") + return SN_X9_62_prime256v1; + if (crv == "P-384") + return SN_secp384r1; + if (crv == "P-521") + return SN_secp521r1; + return nullptr; + } + + // Parse a JWK EC key (kty=EC) into an EVP_PKEY (public key only) + // Uses OpenSSL 3.0 EVP_PKEY_fromdata API (no deprecated EC_KEY*). 
+ EVP_PKEY_ptr pkey_from_jwk_ec( + std::string_view crv, std::string_view x_b64, std::string_view y_b64) + { + const char* group_name = ec_group_name(crv); + if (!group_name) + { + return {nullptr, EVP_PKEY_free}; + } + + std::string x_raw = ::base64_decode(x_b64); + std::string y_raw = ::base64_decode(y_b64); + + // Build uncompressed point: 0x04 || x || y + std::vector point; + point.reserve(1 + x_raw.size() + y_raw.size()); + point.push_back(0x04); + point.insert( + point.end(), + reinterpret_cast(x_raw.data()), + reinterpret_cast(x_raw.data()) + x_raw.size()); + point.insert( + point.end(), + reinterpret_cast(y_raw.data()), + reinterpret_cast(y_raw.data()) + y_raw.size()); + + std::unique_ptr bld( + OSSL_PARAM_BLD_new(), OSSL_PARAM_BLD_free); + if (!bld) + { + return {nullptr, EVP_PKEY_free}; + } + OSSL_PARAM_BLD_push_utf8_string( + bld.get(), OSSL_PKEY_PARAM_GROUP_NAME, group_name, 0); + OSSL_PARAM_BLD_push_octet_string( + bld.get(), OSSL_PKEY_PARAM_PUB_KEY, point.data(), point.size()); + + OSSL_PARAM* params = OSSL_PARAM_BLD_to_param(bld.get()); + if (!params) + { + return {nullptr, EVP_PKEY_free}; + } + + std::unique_ptr ctx( + EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr), EVP_PKEY_CTX_free); + if (!ctx || EVP_PKEY_fromdata_init(ctx.get()) != 1) + { + OSSL_PARAM_free(params); + return {nullptr, EVP_PKEY_free}; + } + + EVP_PKEY* raw_pkey = nullptr; + int rc = + EVP_PKEY_fromdata(ctx.get(), &raw_pkey, EVP_PKEY_PUBLIC_KEY, params); + OSSL_PARAM_free(params); + if (rc != 1 || !raw_pkey) + { + return {nullptr, EVP_PKEY_free}; + } + return {raw_pkey, EVP_PKEY_free}; + } + + // Parse a JWK OKP key (kty=OKP, crv=Ed25519) into an EVP_PKEY + EVP_PKEY_ptr pkey_from_jwk_okp(std::string_view x_b64) + { + std::string x_raw = ::base64_decode(x_b64); + EVP_PKEY* pkey = EVP_PKEY_new_raw_public_key( + EVP_PKEY_ED25519, + nullptr, + reinterpret_cast(x_raw.data()), + x_raw.size()); + return {pkey, EVP_PKEY_free}; + } + + using crypto_core::extract_jwks_keys; + using 
crypto_core::json_select_string; + using crypto_core::parse_json; + + // Parse a JWK JSON AST into an EVP_PKEY (public key) + EVP_PKEY_ptr pkey_from_jwk_ast(const trieste::Node& ast) + { + using rego::crypto_core::MaxECComponentB64Len; + using rego::crypto_core::MaxOKPComponentB64Len; + using rego::crypto_core::MaxRSAComponentB64Len; + std::string_view kty = json_select_string(ast, "/kty"); + if (kty == "RSA") + { + std::string_view n = json_select_string(ast, "/n"); + std::string_view e = json_select_string(ast, "/e"); + if ( + n.empty() || e.empty() || n.size() > MaxRSAComponentB64Len || + e.size() > MaxRSAComponentB64Len) + { + return {nullptr, EVP_PKEY_free}; + } + return pkey_from_jwk_rsa(n, e); + } + if (kty == "EC") + { + std::string_view crv = json_select_string(ast, "/crv"); + std::string_view x = json_select_string(ast, "/x"); + std::string_view y = json_select_string(ast, "/y"); + if ( + crv.empty() || x.empty() || y.empty() || + x.size() > MaxECComponentB64Len || y.size() > MaxECComponentB64Len) + { + return {nullptr, EVP_PKEY_free}; + } + return pkey_from_jwk_ec(crv, x, y); + } + if (kty == "OKP") + { + std::string_view x = json_select_string(ast, "/x"); + if (x.empty() || x.size() > MaxOKPComponentB64Len) + { + return {nullptr, EVP_PKEY_free}; + } + return pkey_from_jwk_okp(x); + } + return {nullptr, EVP_PKEY_free}; + } + + // Parse a JWK JSON string into an EVP_PKEY (public key) + EVP_PKEY_ptr pkey_from_jwk(std::string_view jwk_json) + { + auto ast = parse_json(jwk_json); + return pkey_from_jwk_ast(ast); + } + + // Auto-detect key format and load EVP_PKEY. + // Handles: PEM cert, PEM pubkey, JWK object, JWKS set. 
+ EVP_PKEY_ptr load_public_key( + std::string_view key_or_cert, std::string_view kid = {}) + { + // PEM certificate + if ( + key_or_cert.find("-----BEGIN CERTIFICATE-----") != std::string_view::npos) + { + return pkey_from_certificate(key_or_cert); + } + + // PEM public key + if ( + key_or_cert.find("-----BEGIN PUBLIC KEY-----") != std::string_view::npos) + { + return pkey_from_pem_pubkey(key_or_cert); + } + + // Try JSON (JWK or JWKS) + auto ast = parse_json(key_or_cert); + if (!ast) + { + return {nullptr, EVP_PKEY_free}; + } + + // Try JWKS first — look for "keys" array + auto keys = extract_jwks_keys(ast); + if (!keys.empty()) + { + // If kid specified, find matching key; otherwise use first + for (auto& key_ast : keys) + { + if (kid.empty()) + { + return pkey_from_jwk_ast(key_ast); + } + std::string_view key_kid = json_select_string(key_ast, "/kid"); + if (key_kid == kid) + { + return pkey_from_jwk_ast(key_ast); + } + } + return {nullptr, EVP_PKEY_free}; + } + + // Try single JWK object + std::string_view kty = json_select_string(ast, "/kty"); + if (!kty.empty()) + { + return pkey_from_jwk_ast(ast); + } + + return {nullptr, EVP_PKEY_free}; + } + + // Get the EVP_MD for an algorithm (nullptr for EdDSA which uses no digest) + const EVP_MD* md_for_algo(Algorithm algo) + { + switch (algo) + { + case Algorithm::HS256: + case Algorithm::RS256: + case Algorithm::PS256: + case Algorithm::ES256: + return EVP_sha256(); + case Algorithm::HS384: + case Algorithm::RS384: + case Algorithm::PS384: + case Algorithm::ES384: + return EVP_sha384(); + case Algorithm::HS512: + case Algorithm::RS512: + case Algorithm::PS512: + case Algorithm::ES512: + return EVP_sha512(); + case Algorithm::EdDSA: + return nullptr; + } + return nullptr; + } + + bool verify_hmac( + Algorithm algo, + std::string_view signing_input, + std::string_view sig_bytes, + std::string_view secret) + { + const EVP_MD* md = md_for_algo(algo); + unsigned char buf[EVP_MAX_MD_SIZE]; + unsigned int len = 0; + + 
unsigned char* result = HMAC( + md, + secret.data(), + static_cast(secret.size()), + reinterpret_cast(signing_input.data()), + signing_input.size(), + buf, + &len); + if (!result) + { + return false; + } + + // Constant-time comparison + if (len != sig_bytes.size()) + { + return false; + } + return CRYPTO_memcmp(buf, sig_bytes.data(), len) == 0; + } + + bool verify_rsa_pkcs1( + const EVP_MD* md, + EVP_PKEY* pkey, + std::string_view signing_input, + std::string_view sig_bytes) + { + std::unique_ptr ctx( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + if (!ctx) + { + return false; + } + if (EVP_DigestVerifyInit(ctx.get(), nullptr, md, nullptr, pkey) != 1) + { + return false; + } + return EVP_DigestVerify( + ctx.get(), + reinterpret_cast(sig_bytes.data()), + sig_bytes.size(), + reinterpret_cast(signing_input.data()), + signing_input.size()) == 1; + } + + bool verify_rsa_pss( + const EVP_MD* md, + EVP_PKEY* pkey, + std::string_view signing_input, + std::string_view sig_bytes) + { + std::unique_ptr ctx( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + if (!ctx) + { + return false; + } + EVP_PKEY_CTX* pctx = nullptr; + if (EVP_DigestVerifyInit(ctx.get(), &pctx, md, nullptr, pkey) != 1) + { + return false; + } + if (EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) != 1) + { + return false; + } + if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, RSA_PSS_SALTLEN_AUTO) != 1) + { + return false; + } + return EVP_DigestVerify( + ctx.get(), + reinterpret_cast(sig_bytes.data()), + sig_bytes.size(), + reinterpret_cast(signing_input.data()), + signing_input.size()) == 1; + } + + // Wrapper for OPENSSL_free (which is a macro and can't have its address + // taken) + void openssl_free(unsigned char* ptr) + { + OPENSSL_free(ptr); + } + + bool verify_ecdsa( + const EVP_MD* md, + EVP_PKEY* pkey, + std::string_view signing_input, + std::string_view sig_bytes) + { + // Validate signature length against the expected curve size. 
+ // ES256 (SHA-256, P-256): 64 bytes, ES384 (SHA-384, P-384): 96 bytes, + // ES512 (SHA-512, P-521): 132 bytes. + size_t expected_sig_len = 0; + int md_nid = EVP_MD_nid(md); + switch (md_nid) + { + case NID_sha256: + expected_sig_len = 64; + break; + case NID_sha384: + expected_sig_len = 96; + break; + case NID_sha512: + expected_sig_len = 132; + break; + default: + return false; + } + if (sig_bytes.empty() || sig_bytes.size() != expected_sig_len) + { + return false; + } + + // ECDSA JWT signatures are in R||S raw format; OpenSSL expects DER. + // Convert raw (r || s) to DER-encoded signature. + size_t half = sig_bytes.size() / 2; + BIGNUM_ptr r( + BN_bin2bn( + reinterpret_cast(sig_bytes.data()), + half, + nullptr), + BN_free); + BIGNUM_ptr s( + BN_bin2bn( + reinterpret_cast(sig_bytes.data()) + half, + half, + nullptr), + BN_free); + if (!r || !s) + { + return false; + } + + ECDSA_SIG* ecdsa_sig = ECDSA_SIG_new(); + // ECDSA_SIG_set0 takes ownership + ECDSA_SIG_set0(ecdsa_sig, r.release(), s.release()); + + unsigned char* der = nullptr; + int der_len = i2d_ECDSA_SIG(ecdsa_sig, &der); + ECDSA_SIG_free(ecdsa_sig); + if (der_len <= 0) + { + return false; + } + + std::unique_ptr der_guard( + der, openssl_free); + + std::unique_ptr ctx( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + if (!ctx) + { + return false; + } + if (EVP_DigestVerifyInit(ctx.get(), nullptr, md, nullptr, pkey) != 1) + { + return false; + } + return EVP_DigestVerify( + ctx.get(), + der, + der_len, + reinterpret_cast(signing_input.data()), + signing_input.size()) == 1; + } + + bool verify_eddsa( + EVP_PKEY* pkey, std::string_view signing_input, std::string_view sig_bytes) + { + std::unique_ptr ctx( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + if (!ctx) + { + return false; + } + // EdDSA: md is NULL + if (EVP_DigestVerifyInit(ctx.get(), nullptr, nullptr, nullptr, pkey) != 1) + { + return false; + } + return EVP_DigestVerify( + ctx.get(), + reinterpret_cast(sig_bytes.data()), + sig_bytes.size(), + 
reinterpret_cast(signing_input.data()), + signing_input.size()) == 1; + } + + // ── Signature verification dispatch ── + + // Validate and classify PEM data. Returns an error string if the PEM + // is malformed, or empty string if OK (or not PEM at all). + VerifyResult verify_signature( + Algorithm algo, + std::string_view signing_input, + std::string_view signature_bytes, + std::string_view key_or_cert) + { + // HMAC algorithms use the key directly as a secret + if ( + algo == Algorithm::HS256 || algo == Algorithm::HS384 || + algo == Algorithm::HS512) + { + bool ok = verify_hmac(algo, signing_input, signature_bytes, key_or_cert); + return {ok, {}}; + } + + // Validate PEM structure before attempting to parse + std::string pem_err = validate_pem(key_or_cert); + if (!pem_err.empty()) + { + return {false, pem_err}; + } + + // Asymmetric: load the public key + EVP_PKEY_ptr pkey = load_public_key(key_or_cert); + if (!pkey) + { + // Determine what kind of key was attempted + if ( + key_or_cert.find("-----BEGIN CERTIFICATE-----") != + std::string_view::npos) + { + // Valid PEM structure (passed validate_pem) but OpenSSL couldn't parse + return {false, "failed to parse a PEM certificate"}; + } + if ( + key_or_cert.find("-----BEGIN PUBLIC KEY-----") != + std::string_view::npos) + { + return {false, "failed to parse a PEM key"}; + } + if ( + key_or_cert.find("\"kty\"") != std::string_view::npos || + key_or_cert.find("\"keys\"") != std::string_view::npos) + { + return {false, "failed to parse a JWK key (set)"}; + } + return {false, {}}; + } + + const EVP_MD* md = md_for_algo(algo); + bool ok = false; + + switch (algo) + { + case Algorithm::RS256: + case Algorithm::RS384: + case Algorithm::RS512: + ok = verify_rsa_pkcs1(md, pkey.get(), signing_input, signature_bytes); + break; + + case Algorithm::PS256: + case Algorithm::PS384: + case Algorithm::PS512: + ok = verify_rsa_pss(md, pkey.get(), signing_input, signature_bytes); + break; + + case Algorithm::ES256: + case 
Algorithm::ES384: + case Algorithm::ES512: + ok = verify_ecdsa(md, pkey.get(), signing_input, signature_bytes); + break; + + case Algorithm::EdDSA: + ok = verify_eddsa(pkey.get(), signing_input, signature_bytes); + break; + + default: + break; + } + + return {ok, {}}; + } + + VerifyResult verify_signature_any_key( + Algorithm algo, + std::string_view signing_input, + std::string_view signature_bytes, + std::string_view key_or_cert) + { + // Parse to check if this is a JWKS with multiple keys + auto ast = parse_json(key_or_cert); + if (!ast) + { + // Not JSON — use normal path (may be PEM) + return verify_signature( + algo, signing_input, signature_bytes, key_or_cert); + } + + auto keys = extract_jwks_keys(ast); + if (keys.size() <= 1) + { + // Not a JWKS or single key — use normal path + return verify_signature( + algo, signing_input, signature_bytes, key_or_cert); + } + + // Try each key; return success on first valid signature + for (auto& key_ast : keys) + { + auto pkey = pkey_from_jwk_ast(key_ast); + if (!pkey) + { + continue; + } + + const EVP_MD* md = md_for_algo(algo); + bool ok = false; + + switch (algo) + { + case Algorithm::RS256: + case Algorithm::RS384: + case Algorithm::RS512: + ok = verify_rsa_pkcs1(md, pkey.get(), signing_input, signature_bytes); + break; + case Algorithm::PS256: + case Algorithm::PS384: + case Algorithm::PS512: + ok = verify_rsa_pss(md, pkey.get(), signing_input, signature_bytes); + break; + case Algorithm::ES256: + case Algorithm::ES384: + case Algorithm::ES512: + ok = verify_ecdsa(md, pkey.get(), signing_input, signature_bytes); + break; + case Algorithm::EdDSA: + ok = verify_eddsa(pkey.get(), signing_input, signature_bytes); + break; + default: + break; + } + + if (ok) + { + return {true, {}}; + } + } + + return {false, {}}; + } + + // ── Private key loading (for signing) ── + + EVP_PKEY_ptr pkey_from_jwk_rsa_private( + std::string_view n_b64, + std::string_view e_b64, + std::string_view d_b64, + std::string_view p_b64, + 
std::string_view q_b64, + std::string_view dp_b64, + std::string_view dq_b64, + std::string_view qi_b64) + { + BIGNUM_ptr n = bn_from_base64url(n_b64); + BIGNUM_ptr e = bn_from_base64url(e_b64); + BIGNUM_ptr d = bn_from_base64url(d_b64); + BIGNUM_ptr p = bn_from_base64url(p_b64); + BIGNUM_ptr q = bn_from_base64url(q_b64); + BIGNUM_ptr dp = bn_from_base64url(dp_b64); + BIGNUM_ptr dq = bn_from_base64url(dq_b64); + BIGNUM_ptr qi = bn_from_base64url(qi_b64); + if (!n || !e || !d || !p || !q || !dp || !dq || !qi) + { + return {nullptr, EVP_PKEY_free}; + } + + std::unique_ptr bld( + OSSL_PARAM_BLD_new(), OSSL_PARAM_BLD_free); + if (!bld) + { + return {nullptr, EVP_PKEY_free}; + } + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_RSA_N, n.get()); + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_RSA_E, e.get()); + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_RSA_D, d.get()); + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_RSA_FACTOR1, p.get()); + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_RSA_FACTOR2, q.get()); + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_RSA_EXPONENT1, dp.get()); + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_RSA_EXPONENT2, dq.get()); + OSSL_PARAM_BLD_push_BN( + bld.get(), OSSL_PKEY_PARAM_RSA_COEFFICIENT1, qi.get()); + + OSSL_PARAM* params = OSSL_PARAM_BLD_to_param(bld.get()); + if (!params) + { + return {nullptr, EVP_PKEY_free}; + } + + std::unique_ptr ctx( + EVP_PKEY_CTX_new_from_name(nullptr, "RSA", nullptr), EVP_PKEY_CTX_free); + if (!ctx || EVP_PKEY_fromdata_init(ctx.get()) != 1) + { + OSSL_PARAM_free(params); + return {nullptr, EVP_PKEY_free}; + } + + EVP_PKEY* raw_pkey = nullptr; + int rc = EVP_PKEY_fromdata(ctx.get(), &raw_pkey, EVP_PKEY_KEYPAIR, params); + OSSL_PARAM_free(params); + if (rc != 1 || !raw_pkey) + { + return {nullptr, EVP_PKEY_free}; + } + return {raw_pkey, EVP_PKEY_free}; + } + + EVP_PKEY_ptr pkey_from_jwk_ec_private( + std::string_view crv, + std::string_view x_b64, + std::string_view y_b64, + 
std::string_view d_b64) + { + const char* group_name = ec_group_name(crv); + if (!group_name) + { + return {nullptr, EVP_PKEY_free}; + } + + crypto_core::SecureString x_raw(::base64_decode(x_b64)); + crypto_core::SecureString y_raw(::base64_decode(y_b64)); + crypto_core::SecureString d_raw(::base64_decode(d_b64)); + + // Build uncompressed point: 0x04 || x || y + std::vector point; + point.reserve(1 + x_raw.size() + y_raw.size()); + point.push_back(0x04); + point.insert( + point.end(), + reinterpret_cast(x_raw.data()), + reinterpret_cast(x_raw.data()) + x_raw.size()); + point.insert( + point.end(), + reinterpret_cast(y_raw.data()), + reinterpret_cast(y_raw.data()) + y_raw.size()); + + BIGNUM_ptr d_bn( + BN_bin2bn( + reinterpret_cast(d_raw.data()), + d_raw.size(), + nullptr), + BN_free); + if (!d_bn) + { + return {nullptr, EVP_PKEY_free}; + } + + std::unique_ptr bld( + OSSL_PARAM_BLD_new(), OSSL_PARAM_BLD_free); + if (!bld) + { + return {nullptr, EVP_PKEY_free}; + } + OSSL_PARAM_BLD_push_utf8_string( + bld.get(), OSSL_PKEY_PARAM_GROUP_NAME, group_name, 0); + OSSL_PARAM_BLD_push_octet_string( + bld.get(), OSSL_PKEY_PARAM_PUB_KEY, point.data(), point.size()); + OSSL_PARAM_BLD_push_BN(bld.get(), OSSL_PKEY_PARAM_PRIV_KEY, d_bn.get()); + + OSSL_PARAM* params = OSSL_PARAM_BLD_to_param(bld.get()); + if (!params) + { + return {nullptr, EVP_PKEY_free}; + } + + std::unique_ptr ctx( + EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr), EVP_PKEY_CTX_free); + if (!ctx || EVP_PKEY_fromdata_init(ctx.get()) != 1) + { + OSSL_PARAM_free(params); + return {nullptr, EVP_PKEY_free}; + } + + EVP_PKEY* raw_pkey = nullptr; + int rc = EVP_PKEY_fromdata(ctx.get(), &raw_pkey, EVP_PKEY_KEYPAIR, params); + OSSL_PARAM_free(params); + if (rc != 1 || !raw_pkey) + { + return {nullptr, EVP_PKEY_free}; + } + return {raw_pkey, EVP_PKEY_free}; + } + + EVP_PKEY_ptr pkey_from_jwk_okp_private(std::string_view d_b64) + { + crypto_core::SecureString d_raw(::base64_decode(d_b64)); + EVP_PKEY* pkey = 
EVP_PKEY_new_raw_private_key( + EVP_PKEY_ED25519, + nullptr, + reinterpret_cast(d_raw.data()), + d_raw.size()); + return {pkey, EVP_PKEY_free}; + } + + // Load a private key from a JWK JSON AST + EVP_PKEY_ptr load_private_key_ast(const trieste::Node& ast) + { + using rego::crypto_core::MaxECComponentB64Len; + using rego::crypto_core::MaxOKPComponentB64Len; + using rego::crypto_core::MaxRSAComponentB64Len; + std::string_view kty = json_select_string(ast, "/kty"); + if (kty == "RSA") + { + std::string_view n = json_select_string(ast, "/n"); + std::string_view e = json_select_string(ast, "/e"); + std::string_view d = json_select_string(ast, "/d"); + std::string_view p = json_select_string(ast, "/p"); + std::string_view q = json_select_string(ast, "/q"); + std::string_view dp = json_select_string(ast, "/dp"); + std::string_view dq = json_select_string(ast, "/dq"); + std::string_view qi = json_select_string(ast, "/qi"); + if (n.empty() || e.empty() || d.empty()) + { + return {nullptr, EVP_PKEY_free}; + } + for (auto sv : {n, e, d, p, q, dp, dq, qi}) + { + if (sv.size() > MaxRSAComponentB64Len) + { + return {nullptr, EVP_PKEY_free}; + } + } + return pkey_from_jwk_rsa_private(n, e, d, p, q, dp, dq, qi); + } + if (kty == "EC") + { + std::string_view crv = json_select_string(ast, "/crv"); + std::string_view x = json_select_string(ast, "/x"); + std::string_view y = json_select_string(ast, "/y"); + std::string_view d = json_select_string(ast, "/d"); + if (crv.empty() || x.empty() || y.empty() || d.empty()) + { + return {nullptr, EVP_PKEY_free}; + } + if ( + x.size() > MaxECComponentB64Len || y.size() > MaxECComponentB64Len || + d.size() > MaxECComponentB64Len) + { + return {nullptr, EVP_PKEY_free}; + } + return pkey_from_jwk_ec_private(crv, x, y, d); + } + if (kty == "OKP") + { + std::string_view d = json_select_string(ast, "/d"); + if (d.empty() || d.size() > MaxOKPComponentB64Len) + { + return {nullptr, EVP_PKEY_free}; + } + return pkey_from_jwk_okp_private(d); + } + return 
{nullptr, EVP_PKEY_free}; + } + + // Load a private key from a JWK JSON string + EVP_PKEY_ptr load_private_key(std::string_view jwk_json) + { + auto ast = parse_json(jwk_json); + return load_private_key_ast(ast); + } + + // ── Signing functions ── + + std::string sign_hmac( + Algorithm algo, std::string_view signing_input, std::string_view secret) + { + const EVP_MD* md = md_for_algo(algo); + unsigned char buf[EVP_MAX_MD_SIZE]; + unsigned int len = 0; + + unsigned char* result = HMAC( + md, + secret.data(), + static_cast(secret.size()), + reinterpret_cast(signing_input.data()), + signing_input.size(), + buf, + &len); + if (!result) + { + throw std::runtime_error("HMAC signing failed"); + } + return std::string(reinterpret_cast(buf), len); + } + + std::string sign_asymmetric( + const EVP_MD* md, + EVP_PKEY* pkey, + std::string_view signing_input, + int padding = 0) + { + std::unique_ptr ctx( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + if (!ctx) + { + throw std::runtime_error("EVP_MD_CTX_new failed"); + } + + EVP_PKEY_CTX* pctx = nullptr; + if (EVP_DigestSignInit(ctx.get(), &pctx, md, nullptr, pkey) != 1) + { + throw std::runtime_error("EVP_DigestSignInit failed"); + } + + if (padding == RSA_PKCS1_PSS_PADDING) + { + EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING); + EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, RSA_PSS_SALTLEN_DIGEST); + } + + // Determine signature length + size_t sig_len = 0; + if ( + EVP_DigestSign( + ctx.get(), + nullptr, + &sig_len, + reinterpret_cast(signing_input.data()), + signing_input.size()) != 1) + { + throw std::runtime_error("EVP_DigestSign (length) failed"); + } + + std::vector sig(sig_len); + if ( + EVP_DigestSign( + ctx.get(), + sig.data(), + &sig_len, + reinterpret_cast(signing_input.data()), + signing_input.size()) != 1) + { + throw std::runtime_error("EVP_DigestSign failed"); + } + sig.resize(sig_len); + return std::string(reinterpret_cast(sig.data()), sig.size()); + } + + // Convert DER-encoded ECDSA signature to raw R||S format + 
std::string ecdsa_der_to_raw( + const std::string& der_sig, size_t component_size) + { + const unsigned char* p = + reinterpret_cast(der_sig.data()); + ECDSA_SIG* ecdsa_sig = d2i_ECDSA_SIG(nullptr, &p, der_sig.size()); + if (!ecdsa_sig) + { + throw std::runtime_error("failed to decode ECDSA DER signature"); + } + + const BIGNUM* r = nullptr; + const BIGNUM* s = nullptr; + ECDSA_SIG_get0(ecdsa_sig, &r, &s); + + std::vector raw(component_size * 2, 0); + BN_bn2binpad(r, raw.data(), component_size); + BN_bn2binpad(s, raw.data() + component_size, component_size); + ECDSA_SIG_free(ecdsa_sig); + + return std::string(reinterpret_cast(raw.data()), raw.size()); + } + + size_t ecdsa_component_size(Algorithm algo) + { + switch (algo) + { + case Algorithm::ES256: + return 32; + case Algorithm::ES384: + return 48; + case Algorithm::ES512: + return 66; + default: + return 0; + } + } + + // ── sign() dispatch ── + + std::string sign( + Algorithm algo, + std::string_view signing_input, + std::string_view key_jwk_json) + { + auto ast = parse_json(key_jwk_json); + if (!ast) + { + throw std::runtime_error("failed to parse JWK JSON"); + } + + // HMAC: extract the "k" field (base64url-encoded secret) + if ( + algo == Algorithm::HS256 || algo == Algorithm::HS384 || + algo == Algorithm::HS512) + { + std::string_view k = json_select_string(ast, "/k"); + if (k.empty()) + { + throw std::runtime_error("missing 'k' in oct JWK"); + } + crypto_core::SecureString secret(::base64_decode(k)); + return sign_hmac(algo, signing_input, secret.value); + } + + // Asymmetric: load private key + EVP_PKEY_ptr pkey = load_private_key_ast(ast); + if (!pkey) + { + throw std::runtime_error("failed to load private key from JWK"); + } + + const EVP_MD* md = md_for_algo(algo); + + switch (algo) + { + case Algorithm::RS256: + case Algorithm::RS384: + case Algorithm::RS512: + return sign_asymmetric(md, pkey.get(), signing_input); + + case Algorithm::PS256: + case Algorithm::PS384: + case Algorithm::PS512: + return 
sign_asymmetric( + md, pkey.get(), signing_input, RSA_PKCS1_PSS_PADDING); + + case Algorithm::ES256: + case Algorithm::ES384: + case Algorithm::ES512: { + std::string der = sign_asymmetric(md, pkey.get(), signing_input); + return ecdsa_der_to_raw(der, ecdsa_component_size(algo)); + } + + case Algorithm::EdDSA: + return sign_asymmetric(nullptr, pkey.get(), signing_input); + + default: + throw std::runtime_error("unsupported algorithm for signing"); + } + } + + // ── X.509 Certificate Parsing ── + + using X509_STORE_ptr = + std::unique_ptr; + using X509_STORE_CTX_ptr = + std::unique_ptr; + using X509_REQ_ptr = std::unique_ptr; + + // Extract CommonName from an X509_NAME + std::string get_common_name(X509_NAME* name) + { + if (!name) + { + return {}; + } + int idx = X509_NAME_get_index_by_NID(name, NID_commonName, -1); + if (idx < 0) + { + return {}; + } + X509_NAME_ENTRY* entry = X509_NAME_get_entry(name, idx); + if (!entry) + { + return {}; + } + ASN1_STRING* data = X509_NAME_ENTRY_get_data(entry); + if (!data) + { + return {}; + } + unsigned char* utf8 = nullptr; + int len = ASN1_STRING_to_UTF8(&utf8, data); + if (len < 0) + { + return {}; + } + std::string result(reinterpret_cast(utf8), static_cast(len)); + OPENSSL_free(utf8); + return result; + } + + // Extract DNS names and URI strings from certificate SANs + void extract_sans( + X509* cert, + std::vector& dns_names, + std::vector& uri_strings) + { + GENERAL_NAMES* sans = static_cast( + X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); + if (!sans) + { + return; + } + + for (int i = 0; i < sk_GENERAL_NAME_num(sans); ++i) + { + GENERAL_NAME* gen = sk_GENERAL_NAME_value(sans, i); + if (gen->type == GEN_DNS) + { + unsigned char* utf8 = nullptr; + int len = ASN1_STRING_to_UTF8(&utf8, gen->d.dNSName); + if (len >= 0) + { + dns_names.emplace_back( + reinterpret_cast(utf8), static_cast(len)); + OPENSSL_free(utf8); + } + } + else if (gen->type == GEN_URI) + { + unsigned char* utf8 = nullptr; + int len = 
ASN1_STRING_to_UTF8(&utf8, gen->d.uniformResourceIdentifier); + if (len >= 0) + { + uri_strings.emplace_back( + reinterpret_cast(utf8), static_cast(len)); + OPENSSL_free(utf8); + } + } + } + + GENERAL_NAMES_free(sans); + } + + // Convert an X509* to base64-encoded DER + std::string cert_to_der_b64(X509* cert) + { + unsigned char* buf = nullptr; + int len = i2d_X509(cert, &buf); + if (len <= 0) + { + return {}; + } + std::string_view der( + reinterpret_cast(buf), static_cast(len)); + std::string b64 = ::base64_encode(der, false); + OPENSSL_free(buf); + return b64; + } + + // Parse a single X509* into a ParsedCertificate + ParsedCertificate cert_to_parsed(X509* cert) + { + ParsedCertificate pc; + pc.subject.common_name = get_common_name(X509_get_subject_name(cert)); + extract_sans(cert, pc.dns_names, pc.uri_strings); + pc.der_b64 = cert_to_der_b64(cert); + return pc; + } + + // Parse PEM data into X509 objects by extracting DER from PEM blocks + // manually and feeding to d2i_X509, mirroring OPA's Go approach. 
+ std::vector parse_pem_certs(std::string_view pem_data) + { + std::vector certs; + auto der_blocks = extract_pem_der_blocks(pem_data, "CERTIFICATE"); + + for (auto& der : der_blocks) + { + const unsigned char* p = + reinterpret_cast(der.data()); + X509* cert = d2i_X509(nullptr, &p, static_cast(der.size())); + if (cert) + { + certs.emplace_back(cert, X509_free); + } + else + { + ERR_clear_error(); + } + } + return certs; + } + + // Parse DER data into X509 objects (may be multiple concatenated) + std::vector parse_der_certs(std::string_view der_data) + { + std::vector certs; + const unsigned char* p = + reinterpret_cast(der_data.data()); + const unsigned char* end = p + der_data.size(); + + while (p < end) + { + const unsigned char* start = p; + X509* cert = d2i_X509(nullptr, &p, static_cast(end - start)); + if (!cert) + { + ERR_clear_error(); + break; + } + certs.emplace_back(cert, X509_free); + } + return certs; + } + + ParseCertsResult parse_certificates(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, decoded.error}; + } + + std::vector certs; + if (decoded.is_pem) + { + certs = parse_pem_certs(decoded.data); + } + else + { + certs = parse_der_certs(decoded.data); + } + + if (certs.empty()) + { + return {{}, "x509: malformed certificate"}; + } + + ParseCertsResult result; + for (auto& cert : certs) + { + result.certs.push_back(cert_to_parsed(cert.get())); + } + return result; + } + + ParseCSRResult parse_certificate_request(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, decoded.error}; + } + + X509_REQ* req = nullptr; + if (decoded.is_pem) + { + BIO_ptr bio = bio_from_mem(decoded.data.data(), decoded.data.size()); + if (bio) + { + req = PEM_read_bio_X509_REQ(bio.get(), nullptr, nullptr, nullptr); + } + } + else + { + const unsigned char* p = + reinterpret_cast(decoded.data.data()); + req = 
d2i_X509_REQ(nullptr, &p, static_cast(decoded.data.size())); + } + + if (!req) + { + ERR_clear_error(); + return {{}, "asn1: structure error"}; + } + + X509_REQ_ptr req_ptr(req, X509_REQ_free); + ParseCSRResult result; + result.subject.common_name = + get_common_name(X509_REQ_get_subject_name(req)); + return result; + } + + VerifyCertsResult parse_and_verify_certificates(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {false, {}, decoded.error}; + } + + std::vector certs; + if (decoded.is_pem) + { + certs = parse_pem_certs(decoded.data); + } + else + { + certs = parse_der_certs(decoded.data); + } + + if (certs.size() < 2) + { + // Need at least a root CA and a leaf + VerifyCertsResult result; + result.valid = false; + for (auto& cert : certs) + { + result.certs.push_back(cert_to_parsed(cert.get())); + } + return result; + } + + // Build the trust store and intermediate chain. + // OPA convention: last cert is leaf, first may be root, intermediates in + // between. We try to verify the chain by adding potential CA certs to the + // store and intermediates. 
+ X509_STORE_ptr store(X509_STORE_new(), X509_STORE_free); + if (!store) + { + return {false, {}, "failed to create X509_STORE"}; + } + + // Add all certificates except the leaf as trusted or intermediates + STACK_OF(X509)* chain = sk_X509_new_null(); + for (size_t i = 0; i + 1 < certs.size(); ++i) + { + // Check if it's self-signed (potential root) + if (X509_check_issued(certs[i].get(), certs[i].get()) == X509_V_OK) + { + X509_STORE_add_cert(store.get(), certs[i].get()); + } + else + { + sk_X509_push(chain, certs[i].get()); + } + } + + // Leaf is the last certificate + X509* leaf = certs.back().get(); + + X509_STORE_CTX_ptr ctx(X509_STORE_CTX_new(), X509_STORE_CTX_free); + if (!ctx || !X509_STORE_CTX_init(ctx.get(), store.get(), leaf, chain)) + { + sk_X509_free(chain); + return {false, {}, "failed to init X509_STORE_CTX"}; + } + + // NOTE: CRL and OCSP revocation checking is not performed (matching OPA). + // A revoked certificate will be accepted if otherwise valid. + int verify_ok = X509_verify_cert(ctx.get()); + sk_X509_free(chain); + + VerifyCertsResult result; + result.valid = (verify_ok == 1); + + if (result.valid) + { + // Return the verified chain in leaf-first order (matching OPA behavior). + // X509_STORE_CTX_get0_chain returns the chain built by OpenSSL: + // leaf, intermediates..., root. 
+ STACK_OF(X509)* verified_chain = X509_STORE_CTX_get0_chain(ctx.get()); + if (verified_chain) + { + int chain_len = sk_X509_num(verified_chain); + if (chain_len > static_cast(MaxCertChainLen)) + { + result.valid = false; + result.certs.clear(); + for (auto& cert : certs) + { + result.certs.push_back(cert_to_parsed(cert.get())); + } + return result; + } + for (int i = 0; i < chain_len; ++i) + { + result.certs.push_back( + cert_to_parsed(sk_X509_value(verified_chain, i))); + } + } + } + else + { + // On failure, return certs in input order + for (auto& cert : certs) + { + result.certs.push_back(cert_to_parsed(cert.get())); + } + } + return result; + } + + // Extract an RSA BIGNUM component and return base64url-encoded (no padding) + std::string bn_to_base64url(const BIGNUM* bn) + { + if (!bn) + { + return {}; + } + int num_bytes = BN_num_bytes(bn); + std::vector buf(static_cast(num_bytes)); + BN_bn2bin(bn, buf.data()); + std::string_view sv(reinterpret_cast(buf.data()), buf.size()); + return base64url_encode_nopad(sv); + } + + ParseRSAKeyResult parse_rsa_private_key(std::string_view input) + { + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + return {{}, decoded.error}; + } + + EVP_PKEY* pkey = nullptr; + if (decoded.is_pem) + { + BIO_ptr bio = bio_from_mem(decoded.data.data(), decoded.data.size()); + if (bio) + { + pkey = PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr); + } + } + else + { + const unsigned char* p = + reinterpret_cast(decoded.data.data()); + pkey = d2i_PrivateKey( + EVP_PKEY_RSA, nullptr, &p, static_cast(decoded.data.size())); + } + + if (!pkey) + { + ERR_clear_error(); + return {{}, "failed to parse RSA private key"}; + } + + EVP_PKEY_ptr pkey_ptr(pkey, EVP_PKEY_free); + + // Extract RSA components using EVP_PKEY_get_bn_param (OpenSSL 3.0) + BIGNUM* n_bn = nullptr; + BIGNUM* e_bn = nullptr; + BIGNUM* d_bn = nullptr; + BIGNUM* p_bn = nullptr; + BIGNUM* q_bn = nullptr; + BIGNUM* dp_bn = nullptr; + 
BIGNUM* dq_bn = nullptr; + BIGNUM* qi_bn = nullptr; + + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_N, &n_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_E, &e_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_D, &d_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_FACTOR1, &p_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_FACTOR2, &q_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_EXPONENT1, &dp_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_EXPONENT2, &dq_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_COEFFICIENT1, &qi_bn); + + RSAPrivateKeyJWK jwk; + jwk.kty = "RSA"; + jwk.n = bn_to_base64url(n_bn); + jwk.e = bn_to_base64url(e_bn); + jwk.d = bn_to_base64url(d_bn); + jwk.p = bn_to_base64url(p_bn); + jwk.q = bn_to_base64url(q_bn); + jwk.dp = bn_to_base64url(dp_bn); + jwk.dq = bn_to_base64url(dq_bn); + jwk.qi = bn_to_base64url(qi_bn); + + BN_free(n_bn); + BN_free(e_bn); + BN_free(d_bn); + BN_free(p_bn); + BN_free(q_bn); + BN_free(dp_bn); + BN_free(dq_bn); + BN_free(qi_bn); + + return {jwk, {}}; + } + + ParsePrivateKeysResult parse_private_keys(std::string_view input) + { + if (input.empty()) + { + return {{}, true, {}}; + } + + DecodedInput decoded = decode_cert_input(input); + if (!decoded.error.empty()) + { + // For parse_private_keys, invalid input returns empty array, not error + return {{}, false, {}}; + } + + if (!decoded.is_pem) + { + // Must be PEM for private keys + return {{}, false, {}}; + } + + BIO_ptr bio = bio_from_mem(decoded.data.data(), decoded.data.size()); + if (!bio) + { + return {{}, false, {}}; + } + + ParsePrivateKeysResult result; + result.is_empty_input = false; + + while (true) + { + EVP_PKEY* pkey = + PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr); + if (!pkey) + { + ERR_clear_error(); + break; + } + + EVP_PKEY_ptr pkey_ptr(pkey, EVP_PKEY_free); + + if (EVP_PKEY_is_a(pkey, "RSA")) + { + BIGNUM* n_bn = nullptr; + BIGNUM* e_bn = nullptr; + BIGNUM* d_bn = nullptr; + BIGNUM* 
p_bn = nullptr; + BIGNUM* q_bn = nullptr; + BIGNUM* dp_bn = nullptr; + BIGNUM* dq_bn = nullptr; + BIGNUM* qi_bn = nullptr; + + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_N, &n_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_E, &e_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_D, &d_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_FACTOR1, &p_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_FACTOR2, &q_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_EXPONENT1, &dp_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_EXPONENT2, &dq_bn); + EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_COEFFICIENT1, &qi_bn); + + RSAPrivateKeyJWK jwk; + jwk.kty = "RSA"; + jwk.n = bn_to_base64url(n_bn); + jwk.e = bn_to_base64url(e_bn); + jwk.d = bn_to_base64url(d_bn); + jwk.p = bn_to_base64url(p_bn); + jwk.q = bn_to_base64url(q_bn); + jwk.dp = bn_to_base64url(dp_bn); + jwk.dq = bn_to_base64url(dq_bn); + jwk.qi = bn_to_base64url(qi_bn); + + BN_free(n_bn); + BN_free(e_bn); + BN_free(d_bn); + BN_free(p_bn); + BN_free(q_bn); + BN_free(dp_bn); + BN_free(dq_bn); + BN_free(qi_bn); + + result.keys.push_back(jwk); + } + // Note: EC/OKP keys could be added here in the future + } + + return result; + } +} + +#endif // REGOCPP_CRYPTO_OPENSSL3 diff --git a/src/builtins/crypto_utils.hh b/src/builtins/crypto_utils.hh new file mode 100644 index 00000000..d006e159 --- /dev/null +++ b/src/builtins/crypto_utils.hh @@ -0,0 +1,418 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// Platform-independent crypto utility functions shared between backends. + +#pragma once + +#ifdef REGOCPP_HAS_CRYPTO + +#include "base64/base64.h" + +#include +#include +#include +#include +#include + +namespace rego::crypto_core +{ + // Maximum base64url-encoded size (in bytes) accepted for a single JWK key + // component (n, e, d, p, q, dp, dq, qi, x, y). 
This bounds memory + // allocation during key parsing to prevent resource exhaustion from + // adversarial JWK inputs. 16 KB of base64url decodes to ~12 KB, which is + // far beyond any practical RSA modulus (8192-bit = 1024 bytes). + constexpr size_t MaxJWKComponentB64Len = 16384; + + // Per-algorithm JWK component size bounds (base64url-encoded bytes). + // These are tighter than MaxJWKComponentB64Len for algorithms where the + // key component sizes are well-known. + constexpr size_t MaxRSAComponentB64Len = 2048; // RSA-8192: ~1366 bytes + constexpr size_t MaxECComponentB64Len = 128; // P-521: ~88 bytes + constexpr size_t MaxOKPComponentB64Len = 128; // Ed25519: ~44 bytes + + // Maximum number of certificates allowed in a parsed chain. Bounds + // memory allocation and loop iterations during PEM/DER cert parsing + // and after chain verification in each backend. + constexpr size_t MaxCertChainLen = 256; + + // Maximum decoded size (in bytes) of a single PEM block. Prevents + // unbounded memory allocation from adversarial base64 payloads. + constexpr size_t MaxPEMBlockSize = 10 * 1024 * 1024; // 10 MB + + // ── Hex encoding ── + + inline std::string to_hex(const unsigned char* data, size_t len) + { + static const char hex_chars[] = "0123456789abcdef"; + std::string result; + result.reserve(len * 2); + for (size_t i = 0; i < len; ++i) + { + result += hex_chars[(data[i] >> 4) & 0x0F]; + result += hex_chars[data[i] & 0x0F]; + } + return result; + } + + // ── Constant-time HMAC comparison ── + + inline bool hmac_equal_impl(std::string_view mac1, std::string_view mac2) + { + if (mac1.size() != mac2.size()) + { + return false; + } + + volatile unsigned char result = 0; + for (size_t i = 0; i < mac1.size(); ++i) + { + result |= static_cast(mac1[i]) ^ + static_cast(mac2[i]); + } + return result == 0; + } + + // ── PEM validation ── + + // Validates that a PEM string contains a well-formed certificate or public + // key block. 
Returns an empty string on success, or an error message on + // failure. Shared across all crypto backends. + inline std::string validate_pem(std::string_view pem) + { + static const std::string_view cert_begin = "-----BEGIN CERTIFICATE-----"; + static const std::string_view cert_end = "-----END CERTIFICATE-----"; + static const std::string_view key_begin = "-----BEGIN PUBLIC KEY-----"; + static const std::string_view key_end = "-----END PUBLIC KEY-----"; + + auto cert_pos = pem.find(cert_begin); + if (cert_pos != std::string_view::npos) + { + auto end_pos = pem.find(cert_end, cert_pos); + if (end_pos == std::string_view::npos) + { + return "failed to parse a PEM certificate"; + } + auto after = end_pos + cert_end.size(); + auto remainder = pem.substr(after); + while (!remainder.empty() && + (remainder.front() == '\n' || remainder.front() == '\r' || + remainder.front() == ' ' || remainder.front() == '\t')) + { + remainder.remove_prefix(1); + } + if (!remainder.empty()) + { + return "extra data after a PEM certificate block"; + } + return {}; + } + + if ( + pem.find("-----BEGIN CERT") != std::string_view::npos && + pem.find(cert_begin) == std::string_view::npos) + { + return "failed to extract a Key from the PEM certificate"; + } + + auto key_pos = pem.find(key_begin); + if (key_pos != std::string_view::npos) + { + auto end_pos = pem.find(key_end, key_pos); + if (end_pos == std::string_view::npos) + { + return "failed to parse a PEM key"; + } + return {}; + } + + return {}; + } + + // ── Secure memory erasure ── + + // Zeroes the contents of a string before clearing it. Implemented in + // crypto.cc so that the platform-specific include (windows.h) does not + // leak into this header. + void secure_erase(std::string& s); + + // RAII guard that zeroes a std::string on destruction. 
+ struct SecureString + { + std::string value; + + SecureString() = default; + explicit SecureString(std::string v) : value(std::move(v)) {} + ~SecureString() + { + secure_erase(value); + } + SecureString(const SecureString&) = delete; + SecureString& operator=(const SecureString&) = delete; + SecureString(SecureString&&) = default; + SecureString& operator=(SecureString&&) = default; + + const char* data() const + { + return value.data(); + } + size_t size() const + { + return value.size(); + } + bool empty() const + { + return value.empty(); + } + }; + + // ── Base64url ── + + inline std::string base64url_encode_nopad_impl(std::string_view data) + { + std::string encoded = ::base64_encode(data, true); + while (!encoded.empty() && encoded.back() == '=') + { + encoded.pop_back(); + } + return encoded; + } + + inline std::string base64url_decode_impl(std::string_view data) + { + return ::base64_decode(data); + } + + // ── Algorithm parsing ── + + inline Algorithm parse_algorithm_impl(std::string_view name) + { + if (name == "HS256") + return Algorithm::HS256; + if (name == "HS384") + return Algorithm::HS384; + if (name == "HS512") + return Algorithm::HS512; + if (name == "RS256") + return Algorithm::RS256; + if (name == "RS384") + return Algorithm::RS384; + if (name == "RS512") + return Algorithm::RS512; + if (name == "PS256") + return Algorithm::PS256; + if (name == "PS384") + return Algorithm::PS384; + if (name == "PS512") + return Algorithm::PS512; + if (name == "ES256") + return Algorithm::ES256; + if (name == "ES384") + return Algorithm::ES384; + if (name == "ES512") + return Algorithm::ES512; + if (name == "EdDSA") + return Algorithm::EdDSA; + throw std::invalid_argument( + std::string("unknown JWT algorithm: ") + std::string(name)); + } + + // ── PEM / DER input decoding ── + + // Decoded certificate input — result of decode_cert_input. 
+ struct DecodedInput + { + std::string data; + bool is_pem; + std::string error; + }; + + // Decode input that may be: PEM string, base64-encoded PEM, or + // base64-encoded DER. Returns the decoded bytes and format. + inline DecodedInput decode_cert_input(std::string_view input) + { + // Direct PEM string + if (input.find("-----BEGIN") != std::string_view::npos) + { + return {std::string(input), true, {}}; + } + + // Count non-whitespace chars and reject if not a multiple of 4, + // since ::base64_decode is lenient about length but OPA is not. + size_t non_ws_count = 0; + for (char c : input) + { + if (c != '\n' && c != '\r' && c != ' ' && c != '\t') + { + non_ws_count++; + } + } + + if (non_ws_count == 0 || non_ws_count % 4 != 0) + { + return {{}, false, "illegal base64"}; + } + + // Try base64 decode — pos_of_char throws on invalid characters + try + { + std::string decoded = ::base64_decode(input); + if (decoded.empty()) + { + return {{}, false, "illegal base64"}; + } + // Check if the decoded data is PEM + if (decoded.find("-----BEGIN") != std::string::npos) + { + return {std::move(decoded), true, {}}; + } + // Otherwise it's raw DER + return {std::move(decoded), false, {}}; + } + catch (const std::exception&) + { + return {{}, false, "illegal base64"}; + } + } + + // Extract DER bytes from PEM blocks manually, like Go's pem.Decode. 
+ inline std::vector extract_pem_der_blocks( + std::string_view pem_data, std::string_view block_type) + { + std::vector blocks; + std::string begin_marker = + std::string("-----BEGIN ") + std::string(block_type) + "-----"; + std::string end_marker = + std::string("-----END ") + std::string(block_type) + "-----"; + + size_t pos = 0; + while (pos < pem_data.size()) + { + if (blocks.size() >= MaxCertChainLen) + { + break; + } + + size_t begin = pem_data.find(begin_marker, pos); + if (begin == std::string_view::npos) + { + break; + } + size_t content_start = pem_data.find('\n', begin); + if (content_start == std::string_view::npos) + { + break; + } + content_start++; + + size_t end = pem_data.find(end_marker, content_start); + if (end == std::string_view::npos) + { + break; + } + + std::string b64; + for (size_t i = content_start; i < end; ++i) + { + char c = pem_data[i]; + if (c != '\n' && c != '\r' && c != ' ' && c != '\t') + { + b64 += c; + } + } + + try + { + // Check encoded size before decoding to avoid transient large + // allocations from adversarial PEM blocks. Base64 expands by ~4/3, + // so encoded size * 3/4 approximates decoded size. + if (b64.size() > MaxPEMBlockSize * 4 / 3) + { + pos = end + end_marker.size(); + continue; + } + auto decoded = ::base64_decode(b64); + if (decoded.size() > MaxPEMBlockSize) + { + // Skip oversized PEM blocks to prevent resource exhaustion + pos = end + end_marker.size(); + continue; + } + blocks.push_back(std::move(decoded)); + } + catch (const std::exception&) + { + // Skip malformed PEM blocks + } + + pos = end + end_marker.size(); + } + return blocks; + } + + // ── JSON helpers (shared across all crypto backends) ── + + inline trieste::Node parse_json(std::string_view json_str) + { + // Quick check: if the string doesn't start with '{' or '[' (after + // whitespace), it cannot be valid JSON. 
Skip the full parse to avoid + // generating hundreds of spurious error log messages when probing + // non-JSON inputs like PEM certificates. + auto pos = json_str.find_first_not_of(" \t\n\r"); + if ( + pos == std::string_view::npos || + (json_str[pos] != '{' && json_str[pos] != '[')) + { + return nullptr; + } + + std::string raw(json_str); + auto result = trieste::json::reader().synthetic(raw).read(); + if (!result.ok) + { + return nullptr; + } + return result.ast->front(); + } + + inline std::string_view json_select_string( + const trieste::Node& ast, const char* pointer) + { + if (!ast) + { + return {}; + } + auto val = trieste::json::select_string(ast, {pointer}); + if (!val.has_value()) + { + return {}; + } + return val->view(); + } + + inline std::vector extract_jwks_keys(const trieste::Node& ast) + { + std::vector result; + if (!ast) + { + return result; + } + auto keys_node = trieste::json::select(ast, {"/keys"}); + if ( + keys_node->type() == trieste::Error || + keys_node->type() != trieste::json::Array) + { + return result; + } + for (auto& elem : *keys_node) + { + if (elem->type() == trieste::json::Object) + { + result.push_back(elem); + } + } + return result; + } +} + +#endif // REGOCPP_HAS_CRYPTO diff --git a/src/builtins/internal.cc b/src/builtins/internal.cc index 6404fcc2..e44d3d83 100644 --- a/src/builtins/internal.cc +++ b/src/builtins/internal.cc @@ -1,4 +1,5 @@ #include "builtins.hh" +#include "trieste/json.h" namespace { @@ -52,7 +53,9 @@ namespace Node print(const Nodes& args) { - // TODO implement this properly + // Simplified print: does not wrap arguments in set comprehensions or + // support cross-product expansion. See GitHub issue #209 for full + // OPA-equivalent implementation. 
for (auto arg : args) { if (arg->type() == Undefined) @@ -85,6 +88,136 @@ namespace << bi::Void; return BuiltInDef::create({"internal.print"}, print_decl, ::print); } + + std::string stringify_value(const Node& node) + { + Node inner = node; + if (inner == Scalar) + { + inner = inner->front(); + } + + if (inner->type() == JSONString) + { + return get_string(inner); + } + + std::string key = to_key(inner, SetFormat::Rego, false, ", "); + // to_key renders strings with surrounding quotes (e.g. "foo") inside + // compound values. Those quotes must be escaped so the template result + // can be stored as a valid JSONString via Resolver::scalar. + std::ostringstream result; + for (char c : key) + { + if (c == '"') + { + result << '\\'; + } + result << c; + } + return result.str(); + } + + // Helper: append a string to an output stream, escaping control characters + // for valid JSON storage via Resolver::scalar. Backslashes and quotes from + // stringify_value are already correct and must not be re-escaped. + void escape_append(std::ostringstream& out, const std::string& s) + { + static const char hex_chars[] = "0123456789abcdef"; + for (char c : s) + { + switch (c) + { + case '\n': + out << '\\' << 'n'; + break; + case '\r': + out << '\\' << 'r'; + break; + case '\t': + out << '\\' << 't'; + break; + case '\b': + out << '\\' << 'b'; + break; + case '\f': + out << '\\' << 'f'; + break; + default: + if (static_cast(c) < 0x20) + { + // Escape remaining control characters as \uXXXX + out << "\\u00" + << hex_chars[(static_cast(c) >> 4) & 0x0F] + << hex_chars[static_cast(c) & 0x0F]; + } + else + { + out << c; + } + break; + } + } + } + + Node template_string(const Nodes& args) + { + Node arr = unwrap_arg( + args, UnwrapOpt(0).type(Array).func("internal.template_string")); + if (arr->type() == Error) + { + return arr; + } + + std::ostringstream buf; + for (const auto& elem : *arr) + { + // Array elements are Terms; access the inner node directly. 
+ Node item = elem->front(); + + if (item == rego::Set) + { + if (item->size() == 0) + { + buf << ""; + } + else if (item->size() == 1) + { + escape_append(buf, stringify_value(item->front()->front())); + } + else + { + return err( + arr, + "eval_conflict_error: template-strings must not produce multiple " + "outputs", + EvalConflictError); + } + } + else + { + escape_append(buf, stringify_value(item)); + } + } + + return Resolver::scalar(buf.str()); + } + + BuiltIn template_string_factory() + { + const Node template_string_decl = + bi::Decl << (bi::ArgSeq + << (bi::Arg + << (bi::Name ^ "parts") + << (bi::Description ^ "array of template string parts") + << (bi::Type + << (bi::DynamicArray << (bi::Type << bi::Any))))) + << (bi::Result << (bi::Name ^ "output") + << (bi::Description ^ "composed template string") + << (bi::Type << bi::String)); + return BuiltInDef::create( + {"internal.template_string"}, template_string_decl, template_string); + } } namespace rego @@ -107,6 +240,10 @@ namespace rego { return print_factory(); } + else if (view == "template_string") + { + return template_string_factory(); + } return nullptr; } diff --git a/src/builtins/jwt.cc b/src/builtins/jwt.cc index d3b91bad..41ff568e 100644 --- a/src/builtins/jwt.cc +++ b/src/builtins/jwt.cc @@ -1,12 +1,898 @@ #include "builtins.hh" +#ifdef REGOCPP_HAS_CRYPTO +#include "crypto_core.hh" + +#include +#include +#include +#include +#endif + namespace { using namespace rego; - namespace bi = rego::builtins; + using namespace trieste; + namespace bi = builtins; const char* Message = "JSON Web Tokens are not supported"; +#ifdef REGOCPP_HAS_CRYPTO + + // ── JWT token structure helpers ── + + struct JWTParts + { + std::string_view header_b64; + std::string_view payload_b64; + std::string_view sig_b64; + std::string_view signing_input; // "header.payload" + }; + + // Split a JWT token into its three base64url sections. + // Returns false if the token doesn't have exactly 3 sections. 
+ bool split_jwt(std::string_view token, JWTParts& parts) + { + auto dot1 = token.find('.'); + if (dot1 == std::string_view::npos) + { + return false; + } + auto dot2 = token.find('.', dot1 + 1); + if (dot2 == std::string_view::npos) + { + return false; + } + if (token.find('.', dot2 + 1) != std::string_view::npos) + { + return false; + } + + parts.header_b64 = token.substr(0, dot1); + parts.payload_b64 = token.substr(dot1 + 1, dot2 - dot1 - 1); + parts.sig_b64 = token.substr(dot2 + 1); + parts.signing_input = token.substr(0, dot2); + return true; + } + + std::size_t count_dots(std::string_view s) + { + std::size_t count = 0; + for (char c : s) + { + if (c == '.') + ++count; + } + return count; + } + + // Validate a base64url string. Returns the 0-based byte position of the + // first invalid character, or -1 if the input is valid. + int validate_base64url(std::string_view s) + { + for (int i = 0; i < static_cast(s.size()); ++i) + { + char c = s[i]; + bool valid = (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || c == '-' || c == '_' || c == '='; + if (!valid) + { + return i; + } + } + return -1; + } + + // Parse a JSON string into a JSON AST (json:: namespace types). + // Returns nullptr on parse failure. + Node parse_json(const std::string& json_str) + { + auto result = ::json::reader().synthetic(json_str).read(); + if (!result.ok) + { + return nullptr; + } + return result.ast->front(); + } + + // Parse a JSON string into a Rego Term node (rego:: namespace types). + Node parse_json_to_term(const std::string& json_str) + { + auto result = ::json::reader().synthetic(json_str).wf_check_enabled(true) >> + json_to_rego(true); + if (!result.ok) + { + return nullptr; + } + return result.ast->front(); + } + + // Convert an existing json:: AST node to a Rego Term without re-parsing. 
+ Node json_ast_to_term(const Node& json_node) + { + auto result = + json_to_rego(true).wf_check_enabled(true).rewrite(Top << json_node); + if (!result.ok) + { + return nullptr; + } + return result.ast->front(); + } + + // Convert raw bytes to hex string + std::string bytes_to_hex(std::string_view data) + { + static constexpr char hex[] = "0123456789abcdef"; + std::string result; + result.reserve(data.size() * 2); + for (unsigned char c : data) + { + result.push_back(hex[c >> 4]); + result.push_back(hex[c & 0x0f]); + } + return result; + } + + // Check if an "aud" array node contains a specific audience string + bool aud_array_contains(const Node& aud_node, std::string_view target) + { + for (auto& elem : *aud_node) + { + auto s = ::json::get_string(elem); + if (s.has_value() && s->view() == target) + { + return true; + } + } + return false; + } + + // Maximum nesting depth for nested JWT (cty: "JWT") to prevent stack + // overflow. Matches OPA's limit. + constexpr int MaxJWTNesting = 10; + + // ── io.jwt.decode implementation ── + + Node decode_impl(const Nodes& args, int depth = 0) + { + Node jwt_node = + unwrap_arg(args, UnwrapOpt(0).type(JSONString).func("io.jwt.decode")); + if (jwt_node->type() == Error) + { + return jwt_node; + } + + if (depth >= MaxJWTNesting) + { + return err(jwt_node, "nested JWT depth exceeded", EvalBuiltInError); + } + + std::string jwt_str = ::json::unescape(get_string(jwt_node)); + + // Check for no period separators + if (jwt_str.find('.') == std::string::npos) + { + return err( + jwt_node, "encoded JWT had no period separators", EvalBuiltInError); + } + + // Validate section count + std::size_t dots = count_dots(jwt_str); + if (dots != 2) + { + std::ostringstream msg; + msg << "encoded JWT must have 3 sections, found " << (dots + 1); + return err(jwt_node, msg.str(), EvalBuiltInError); + } + + JWTParts parts; + split_jwt(jwt_str, parts); + + // Validate base64url encoding of each section + int bad_byte = 
validate_base64url(parts.header_b64); + if (bad_byte >= 0) + { + std::ostringstream msg; + msg << "JWT header had invalid encoding: illegal base64 data at input " + "byte " + << bad_byte; + return err(jwt_node, msg.str(), EvalBuiltInError); + } + bad_byte = validate_base64url(parts.payload_b64); + if (bad_byte >= 0) + { + std::ostringstream msg; + msg << "JWT payload had invalid encoding: illegal base64 data at input " + "byte " + << bad_byte; + return err(jwt_node, msg.str(), EvalBuiltInError); + } + bad_byte = validate_base64url(parts.sig_b64); + if (bad_byte >= 0) + { + std::ostringstream msg; + msg << "JWT signature had invalid encoding: illegal base64 data at " + "input byte " + << bad_byte; + return err(jwt_node, msg.str(), EvalBuiltInError); + } + + // Decode and parse header + std::string header_json = crypto_core::base64url_decode(parts.header_b64); + Node header_ast = parse_json(header_json); + if (!header_ast) + { + return err(jwt_node, "failed to decode JWT header", EvalBuiltInError); + } + + // Check for JWE (encrypted JWT) + Node enc_node = ::json::select(header_ast, {"/enc"}); + if (enc_node->type() != Error) + { + return err( + jwt_node, + "JWT is a JWE object, which is not supported", + EvalBuiltInError); + } + + // Decode payload — handle nested JWT (cty: "JWT") + std::string payload_decoded = + crypto_core::base64url_decode(parts.payload_b64); + + auto cty = ::json::select_string(header_ast, {"/cty"}); + if (cty.has_value() && cty->view() == "JWT") + { + // The payload is an inner JWT token string. Recursively decode. 
+ std::string inner = payload_decoded; + if (inner.size() >= 2 && inner.front() == '"' && inner.back() == '"') + { + inner = inner.substr(1, inner.size() - 2); + } + Node inner_token = JSONString ^ inner; + Nodes inner_args = {inner_token}; + return decode_impl(inner_args, depth + 1); + } + + // Convert already-parsed JSON ASTs to Rego terms for the return value + Node header = json_ast_to_term(header_ast); + if (!header) + { + return err(jwt_node, "failed to decode JWT header", EvalBuiltInError); + } + + Node payload_ast2 = parse_json(payload_decoded); + if (!payload_ast2) + { + return err(jwt_node, "failed to decode JWT payload", EvalBuiltInError); + } + Node payload = json_ast_to_term(payload_ast2); + if (!payload) + { + return err(jwt_node, "failed to decode JWT payload", EvalBuiltInError); + } + + // Decode signature to hex + std::string sig_raw = crypto_core::base64url_decode(parts.sig_b64); + std::string sig_hex = bytes_to_hex(sig_raw); + + return array({header, payload, rego::string(sig_hex)}); + } + + // ── io.jwt.verify_* implementation ── + + crypto_core::Algorithm algo_from_verify_name(std::string_view name) + { + if (name == "verify_hs256") + return crypto_core::Algorithm::HS256; + if (name == "verify_hs384") + return crypto_core::Algorithm::HS384; + if (name == "verify_hs512") + return crypto_core::Algorithm::HS512; + if (name == "verify_rs256") + return crypto_core::Algorithm::RS256; + if (name == "verify_rs384") + return crypto_core::Algorithm::RS384; + if (name == "verify_rs512") + return crypto_core::Algorithm::RS512; + if (name == "verify_ps256") + return crypto_core::Algorithm::PS256; + if (name == "verify_ps384") + return crypto_core::Algorithm::PS384; + if (name == "verify_ps512") + return crypto_core::Algorithm::PS512; + if (name == "verify_es256") + return crypto_core::Algorithm::ES256; + if (name == "verify_es384") + return crypto_core::Algorithm::ES384; + if (name == "verify_es512") + return crypto_core::Algorithm::ES512; + if (name == 
"verify_eddsa") + return crypto_core::Algorithm::EdDSA; + throw std::invalid_argument("unknown verify function"); + } + + Node verify_impl(std::string_view func_suffix, const Nodes& args) + { + std::string func_name = "io.jwt." + std::string(func_suffix); + + Node jwt_node = + unwrap_arg(args, UnwrapOpt(0).type(JSONString).func(func_name)); + if (jwt_node->type() == Error) + { + return jwt_node; + } + Node key_node = + unwrap_arg(args, UnwrapOpt(1).type(JSONString).func(func_name)); + if (key_node->type() == Error) + { + return key_node; + } + + std::string jwt_str = ::json::unescape(get_string(jwt_node)); + std::string key_str = ::json::unescape(get_string(key_node)); + if (jwt_str.find('.') == std::string::npos) + { + return err( + jwt_node, "encoded JWT had no period separators", EvalBuiltInError); + } + + std::size_t dots = count_dots(jwt_str); + if (dots != 2) + { + std::ostringstream msg; + msg << "encoded JWT must have 3 sections, found " << (dots + 1); + return err(jwt_node, msg.str(), EvalBuiltInError); + } + + JWTParts parts; + split_jwt(jwt_str, parts); + + crypto_core::Algorithm expected_algo = algo_from_verify_name(func_suffix); + + // Parse header to check algorithm + std::string header_json = crypto_core::base64url_decode(parts.header_b64); + Node header_ast = parse_json(header_json); + if (!header_ast) + { + return boolean(false); + } + + crypto_core::Algorithm token_algo; + try + { + auto alg = ::json::select_string(header_ast, {"/alg"}); + if (!alg.has_value()) + { + return boolean(false); + } + token_algo = crypto_core::parse_algorithm(alg->view()); + } + catch (...) + { + return boolean(false); + } + + std::string sig_raw = crypto_core::base64url_decode(parts.sig_b64); + + // Always perform signature verification even on algorithm mismatch + // to avoid leaking the expected algorithm through response timing. 
+ auto result = crypto_core::verify_signature( + expected_algo, parts.signing_input, sig_raw, key_str); + + if (!result.error.empty()) + { + return err(jwt_node, result.error, EvalBuiltInError); + } + + return boolean(result.valid && token_algo == expected_algo); + } + + // ── Claim validation helpers ── + + // Check a simple string claim (iss, sub). Returns true if passes. + bool check_string_claim( + const Node& payload_ast, + const std::string& json_path, + const std::optional& constraint) + { + if (!constraint.has_value()) + { + return true; + } + auto value = ::json::select_string(payload_ast, {json_path}); + return value.has_value() && value->view() == constraint.value(); + } + + // Check the "aud" claim. Returns true if passes. + bool check_aud_claim( + const Node& payload_ast, const std::optional& constraint) + { + Node payload_aud = ::json::select(payload_ast, {"/aud"}); + bool payload_has_aud = payload_aud->type() != Error; + + if (constraint.has_value()) + { + if (!payload_has_aud) + { + return false; + } + if (payload_aud->type() == ::json::Array) + { + return aud_array_contains(payload_aud, constraint.value()); + } + auto aud_str = ::json::get_string(payload_aud); + return aud_str.has_value() && aud_str->view() == constraint.value(); + } + + // Token has aud but constraints don't specify aud — failure + return !payload_has_aud; + } + + // ── io.jwt.decode_verify implementation ── + + // Find a string value for a key in a Rego Object node (constraints) + std::optional get_object_string( + const Node& obj, const std::string& key) + { + for (auto& item : *obj) + { + if (item->type() != ObjectItem) + continue; + Node k = item->front(); + auto key_str = try_get_string(k); + if (key_str.has_value() && key_str.value() == key) + { + Node v = item->back(); + auto val_str = try_get_string(v); + if (val_str.has_value()) + { + return val_str; + } + } + } + return std::nullopt; + } + + // Find a numeric value for a key in a Rego Object node (constraints) + 
std::optional get_object_number( + const Node& obj, const std::string& key) + { + for (auto& item : *obj) + { + if (item->type() != ObjectItem) + continue; + Node k = item->front(); + auto key_str = try_get_string(k); + if (key_str.has_value() && key_str.value() == key) + { + Node v = item->back(); + auto val = try_get_double(v); + if (val.has_value()) + { + return val; + } + } + } + return std::nullopt; + } + + Node decode_verify_impl(const Nodes& args, int depth = 0) + { + Node jwt_node = unwrap_arg( + args, UnwrapOpt(0).type(JSONString).func("io.jwt.decode_verify")); + if (jwt_node->type() == Error) + { + return jwt_node; + } + + if (depth >= MaxJWTNesting) + { + return err(jwt_node, "nested JWT depth exceeded", EvalBuiltInError); + } + + Node constraints = + unwrap_arg(args, UnwrapOpt(1).type(Object).func("io.jwt.decode_verify")); + if (constraints->type() == Error) + { + return constraints; + } + + auto make_failure = []() { + return array({boolean(false), object({}), object({})}); + }; + + std::string jwt_str = ::json::unescape(get_string(jwt_node)); + + if (jwt_str.find('.') == std::string::npos) + { + return err( + jwt_node, "encoded JWT had no period separators", EvalBuiltInError); + } + + std::size_t dots = count_dots(jwt_str); + if (dots != 2) + { + std::ostringstream msg; + msg << "encoded JWT must have 3 sections, found " << (dots + 1); + return err(jwt_node, msg.str(), EvalBuiltInError); + } + + JWTParts parts; + split_jwt(jwt_str, parts); + + // Decode and parse header into JSON AST + std::string header_json = crypto_core::base64url_decode(parts.header_b64); + Node header_ast = parse_json(header_json); + if (!header_ast) + { + return err(jwt_node, "failed to decode JWT header", EvalBuiltInError); + } + + // Check for JWE + Node enc_node = ::json::select(header_ast, {"/enc"}); + if (enc_node->type() != Error) + { + return err( + jwt_node, + "JWT is a JWE object, which is not supported", + EvalBuiltInError); + } + + // Check for critical extensions — 
we don't support any + Node crit_node = ::json::select(header_ast, {"/crit"}); + if (crit_node->type() != Error) + { + return make_failure(); + } + + // Extract the algorithm from the header + crypto_core::Algorithm token_algo; + std::string alg_str; + try + { + auto alg = ::json::select_string(header_ast, {"/alg"}); + if (!alg.has_value()) + { + return make_failure(); + } + alg_str = std::string(alg->view()); + token_algo = crypto_core::parse_algorithm(alg_str); + } + catch (...) + { + return make_failure(); + } + + // Check "alg" constraint — must match header algorithm + auto alg_constraint = get_object_string(constraints, "alg"); + if (alg_constraint.has_value() && alg_constraint.value() != alg_str) + { + return make_failure(); + } + + // Get the verification key from constraints + auto cert_str = get_object_string(constraints, "cert"); + auto secret_str = get_object_string(constraints, "secret"); + + // Signature verification + std::string sig_raw = crypto_core::base64url_decode(parts.sig_b64); + crypto_core::VerifyResult vresult; + + if (cert_str.has_value()) + { + std::string key = ::json::unescape(cert_str.value()); + vresult = crypto_core::verify_signature_any_key( + token_algo, parts.signing_input, sig_raw, key); + } + else if (secret_str.has_value()) + { + std::string key = ::json::unescape(secret_str.value()); + vresult = crypto_core::verify_signature( + token_algo, parts.signing_input, sig_raw, key); + } + else + { + return make_failure(); + } + + if (!vresult.error.empty()) + { + return err(jwt_node, vresult.error, EvalBuiltInError); + } + + if (!vresult.valid) + { + return make_failure(); + } + + // Handle nested JWT (cty: "JWT") + auto cty = ::json::select_string(header_ast, {"/cty"}); + if (cty.has_value() && cty->view() == "JWT") + { + std::string payload_decoded = + crypto_core::base64url_decode(parts.payload_b64); + // Strip surrounding quotes if present + if ( + payload_decoded.size() >= 2 && payload_decoded.front() == '"' && + 
payload_decoded.back() == '"') + { + payload_decoded = payload_decoded.substr(1, payload_decoded.size() - 2); + } + Node inner_token = JSONString ^ payload_decoded; + Nodes inner_args = {inner_token, args[1]}; + return decode_verify_impl(inner_args, depth + 1); + } + + // Decode and parse payload into JSON AST for claim validation + std::string payload_json = crypto_core::base64url_decode(parts.payload_b64); + Node payload_ast = parse_json(payload_json); + if (!payload_ast) + { + return err(jwt_node, "failed to decode JWT payload", EvalBuiltInError); + } + + // ── Claim validation ── + + // Determine the "current time" in seconds. + // Constraints "time" is in nanoseconds; otherwise use wall clock. + double now_seconds; + auto time_constraint = get_object_number(constraints, "time"); + if (time_constraint.has_value()) + { + now_seconds = time_constraint.value() / 1e9; + } + else + { + auto now = std::chrono::system_clock::now(); + now_seconds = + std::chrono::duration(now.time_since_epoch()).count(); + } + + // Check "iss" and "sub" string claims + if ( + !check_string_claim( + payload_ast, "/iss", get_object_string(constraints, "iss")) || + !check_string_claim( + payload_ast, "/sub", get_object_string(constraints, "sub"))) + { + return make_failure(); + } + + // Check "aud" claim + if (!check_aud_claim(payload_ast, get_object_string(constraints, "aud"))) + { + return make_failure(); + } + + // Check "exp" — payload exp must be in the future + Node exp_node = ::json::select(payload_ast, {"/exp"}); + if (exp_node->type() != Error) + { + auto exp_val = ::json::get_number(exp_node); + if (!exp_val.has_value()) + { + return err(jwt_node, "exp value must be a number", EvalBuiltInError); + } + if (exp_val.value() <= now_seconds) + { + return make_failure(); + } + } + + // Check "nbf" — payload nbf must be in the past + Node nbf_node = ::json::select(payload_ast, {"/nbf"}); + if (nbf_node->type() != Error) + { + auto nbf_val = ::json::get_number(nbf_node); + if 
(!nbf_val.has_value()) + { + return err(jwt_node, "nbf value must be a number", EvalBuiltInError); + } + if (nbf_val.value() > now_seconds) + { + return make_failure(); + } + } + + // All checks passed — convert already-parsed JSON ASTs to Rego terms + Node header = json_ast_to_term(header_ast); + if (!header) + { + return err(jwt_node, "failed to decode JWT header", EvalBuiltInError); + } + + Node payload = json_ast_to_term(payload_ast); + if (!payload) + { + return err(jwt_node, "failed to decode JWT payload", EvalBuiltInError); + } + + return array({boolean(true), header, payload}); + } + + // ── Shared encode/sign helpers ── + + // Serialize a Rego Object node to a compact JSON string with sorted keys. + std::string rego_node_to_json(const Node& node) + { + Node term = Resolver::to_term(node); + if (term->type() == Error) + { + return {}; + } + auto result = rego_to_json().wf_check_enabled(true).rewrite(Top << term); + if (!result.ok) + { + return {}; + } + return json::to_string(result.ast); + } + + // Determine if the JWT "typ" implies the payload should be valid JSON. + // OPA treats "JWT" or empty/missing typ as requiring JSON payload. 
+ bool is_jwt_typ(std::string_view typ) + { + return typ.empty() || typ == "JWT"; + } + + // ── io.jwt.encode_sign_raw implementation ── + + Node encode_sign_raw_impl(const Nodes& args) + { + Node header_node = unwrap_arg( + args, UnwrapOpt(0).type(JSONString).func("io.jwt.encode_sign_raw")); + if (header_node->type() == Error) + return header_node; + + Node payload_node = unwrap_arg( + args, UnwrapOpt(1).type(JSONString).func("io.jwt.encode_sign_raw")); + if (payload_node->type() == Error) + return payload_node; + + Node key_node = unwrap_arg( + args, UnwrapOpt(2).type(JSONString).func("io.jwt.encode_sign_raw")); + if (key_node->type() == Error) + return key_node; + + std::string header_str = ::json::unescape(get_string(header_node)); + std::string payload_str = ::json::unescape(get_string(payload_node)); + std::string key_str = ::json::unescape(get_string(key_node)); + + // Parse header to extract and validate "alg" + Node header_ast = parse_json(header_str); + if (!header_ast) + { + if (header_str.empty()) + { + return err( + header_node, + "missing or invalid 'alg' header: cannot parse JSON: " + "cannot parse empty string", + EvalBuiltInError); + } + // If the input looks like it intended to be a JSON object, report parse + // error. Otherwise treat it as a value with no "alg" field. 
+ if (header_str.find('{') != std::string::npos) + { + return err( + header_node, + "missing or invalid 'alg' header: cannot parse JSON", + EvalBuiltInError); + } + return err( + header_node, + "missing or invalid 'alg' header: jwsbb: header \"alg\" not found", + EvalBuiltInError); + } + + auto alg_opt = ::json::select_string(header_ast, {"/alg"}); + if (!alg_opt.has_value()) + { + return err( + header_node, + "missing or invalid 'alg' header: jwsbb: header \"alg\" not found", + EvalBuiltInError); + } + + std::string_view alg_str = alg_opt->view(); + crypto_core::Algorithm algo; + try + { + algo = crypto_core::parse_algorithm(alg_str); + } + catch (const std::invalid_argument&) + { + return err( + header_node, + "unknown JWS algorithm: " + std::string(alg_str), + EvalBuiltInError); + } + + // Check if typ implies JWT; if so, payload must be valid JSON + auto typ_opt = ::json::select_string(header_ast, {"/typ"}); + std::string_view typ = typ_opt.has_value() ? typ_opt->view() : ""; + if (is_jwt_typ(typ)) + { + if (payload_str.empty()) + { + return err( + payload_node, + "type is JWT but payload is not JSON", + EvalBuiltInError); + } + Node payload_check = parse_json(payload_str); + if (!payload_check) + { + return err( + payload_node, + "type is JWT but payload is not JSON", + EvalBuiltInError); + } + } + + // Base64url-encode header and payload as raw strings + std::string header_b64 = crypto_core::base64url_encode_nopad(header_str); + std::string payload_b64 = crypto_core::base64url_encode_nopad(payload_str); + + // Build the signing input + std::string signing_input = header_b64 + "." + payload_b64; + + // Sign + std::string sig_bytes; + try + { + sig_bytes = crypto_core::sign(algo, signing_input, key_str); + } + catch (const std::exception& e) + { + return err(header_node, e.what(), EvalBuiltInError); + } + + std::string sig_b64 = crypto_core::base64url_encode_nopad(sig_bytes); + std::string token = signing_input + "." 
+ sig_b64; + return JSONString ^ token; + } + + // ── io.jwt.encode_sign implementation ── + + Node encode_sign_impl(const Nodes& args) + { + Node header_node = + unwrap_arg(args, UnwrapOpt(0).type(Object).func("io.jwt.encode_sign")); + if (header_node->type() == Error) + return header_node; + + Node payload_node = unwrap_arg( + args, UnwrapOpt(1).types({Object, Set}).func("io.jwt.encode_sign")); + if (payload_node->type() == Error) + return payload_node; + + Node key_node = + unwrap_arg(args, UnwrapOpt(2).type(Object).func("io.jwt.encode_sign")); + if (key_node->type() == Error) + return key_node; + + // Serialize the Rego objects to compact JSON with sorted keys + std::string header_json = rego_node_to_json(header_node); + std::string payload_json = rego_node_to_json(payload_node); + std::string key_json = rego_node_to_json(key_node); + + if (header_json.empty() || key_json.empty()) + { + return err( + header_node, "failed to serialize arguments", EvalBuiltInError); + } + + // Delegate to the raw implementation using the serialized JSON strings + Node h = JSONString ^ header_json; + Node p = JSONString ^ payload_json; + Node k = JSONString ^ key_json; + Nodes raw_args = {h, p, k}; + return encode_sign_raw_impl(raw_args); + } + +#endif // REGOCPP_HAS_CRYPTO + BuiltIn decode_factory() { const Node decode_decl = bi::Decl @@ -29,7 +915,14 @@ namespace << (bi::DynamicObject << (bi::Type << bi::Any) << (bi::Type << bi::Any))) << (bi::Type << bi::String)))); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder({"io.jwt.decode"}, decode_decl, Message); +#else + return BuiltInDef::create( + {"io.jwt.decode"}, decode_decl, [](const Nodes& args) { + return decode_impl(args); + }); +#endif } BuiltIn decode_verify_factory() @@ -65,8 +958,15 @@ namespace << (bi::Type << (bi::DynamicObject << (bi::Type << bi::Any) << (bi::Type << bi::Any)))))); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"io.jwt.decode_verify"}, decode_verify_decl, Message); +#else + 
return BuiltInDef::create( + {"io.jwt.decode_verify"}, decode_verify_decl, [](const Nodes& args) { + return decode_verify_impl(args); + }); +#endif } const Node verify_decl = bi::Decl @@ -108,8 +1008,13 @@ namespace << (bi::Result << (bi::Name ^ "output") << (bi::Description ^ "signed JWT") << (bi::Type << bi::String)); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"io.jwt.encode_sign"}, encode_sign_decl, Message); +#else + return BuiltInDef::create( + {"io.jwt.encode_sign"}, encode_sign_decl, encode_sign_impl); +#endif } BuiltIn encode_sign_raw_factory() @@ -127,8 +1032,13 @@ namespace << (bi::Result << (bi::Name ^ "output") << (bi::Description ^ "signed JWT") << (bi::Type << bi::String)); +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder( {"io.jwt.encode_sign_raw"}, encode_sign_raw_decl, Message); +#else + return BuiltInDef::create( + {"io.jwt.encode_sign_raw"}, encode_sign_raw_decl, encode_sign_raw_impl); +#endif } } @@ -158,7 +1068,14 @@ namespace rego view == "verify_rs256" || view == "verify_rs384" || view == "verify_rs512") { +#ifndef REGOCPP_HAS_CRYPTO return BuiltInDef::placeholder(name, verify_decl, Message); +#else + return BuiltInDef::create( + name, verify_decl, [view = std::string(view)](const Nodes& args) { + return verify_impl(view, args); + }); +#endif } if (view == "encode_sign") { diff --git a/src/builtins/numbers.cc b/src/builtins/numbers.cc index a97e94e1..ba5ac39d 100644 --- a/src/builtins/numbers.cc +++ b/src/builtins/numbers.cc @@ -99,7 +99,7 @@ namespace { return err( step_number, - "numbers.range_step: step must be a positive number above zero", + "numbers.range_step: step must be a positive integer", EvalBuiltInError); } diff --git a/src/builtins/regex.cc b/src/builtins/regex.cc index 3a7dd05b..8904811c 100644 --- a/src/builtins/regex.cc +++ b/src/builtins/regex.cc @@ -138,7 +138,7 @@ namespace try { std::regex re(pattern); - bool match = std::regex_match(value, re); + bool match = std::regex_search(value, re); 
return Resolver::scalar(match); } catch (std::regex_error& e) @@ -304,17 +304,11 @@ namespace } Node array = NodeDef::create(Array); - std::smatch match; - for (std::size_t i = 0; i < number && !value.empty(); ++i) + auto it = std::sregex_iterator(value.begin(), value.end(), re); + auto end = std::sregex_iterator(); + for (std::size_t i = 0; i < number && it != end; ++i, ++it) { - std::regex_search(value, match, re); - if (match.empty()) - { - break; - } - - array->push_back(Resolver::scalar(match.str())); - value = match.suffix(); + array->push_back(Resolver::scalar(it->str())); } return array; diff --git a/src/builtins/time.cc b/src/builtins/time.cc index 1156e47e..b051d3ba 100644 --- a/src/builtins/time.cc +++ b/src/builtins/time.cc @@ -1,5 +1,4 @@ #include "builtins.hh" -#include "re2/stringpiece.h" #include "rego.hh" #include @@ -64,7 +63,7 @@ namespace { const char* duration_re = R"((-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][-+]?[0-9]+)?)((?:ns|us|µs|ms|s|m|h)))"; - const RE2 re(duration_re); + const TRegex re(duration_re); assert(re.ok()); std::string number; @@ -73,8 +72,8 @@ namespace std::size_t start = 0; while (start < duration.size()) { - re2::StringPiece input(duration.c_str() + start, duration.size() - start); - if (RE2::PartialMatch(input, re, &number, &unit)) + std::string_view input(duration.c_str() + start, duration.size() - start); + if (TRegex::PartialMatch(input, re, &number, &unit)) { double number_d = std::stod(number); double unit_ns = duration_units.at(unit); diff --git a/src/bundle_binary.cc b/src/bundle_binary.cc index 0da8b2dc..98b73acc 100644 --- a/src/bundle_binary.cc +++ b/src/bundle_binary.cc @@ -2,6 +2,7 @@ #include "rego.hh" #include "trieste/wf.h" +#include #include #include diff --git a/src/dependency_graph.cc b/src/dependency_graph.cc index 28d104d3..b0f084b3 100644 --- a/src/dependency_graph.cc +++ b/src/dependency_graph.cc @@ -860,6 +860,12 @@ namespace rego return false; } + if (lhs == TemplateString) + { + 
add_equals(lhs_term, rhs_var); + return false; + } + if (lhs == Array) { return add_array_var(lhs, rhs_var); diff --git a/src/encoding.cc b/src/encoding.cc index 1e33a940..4a25df0d 100644 --- a/src/encoding.cc +++ b/src/encoding.cc @@ -104,7 +104,7 @@ namespace rego { std::string to_key( const Node& node, - bool set_as_array, + SetFormat set_format, bool sort_arrays, const char* list_delim) { @@ -164,14 +164,14 @@ namespace rego std::sort(keys.begin(), keys.end()); std::transform( keys.begin(), keys.end(), std::back_inserter(items), [&](auto& key) { - return to_key(key.node, set_as_array, sort_arrays, list_delim); + return to_key(key.node, set_format, sort_arrays, list_delim); }); } else { for (const auto& child : *node) { - items.push_back(to_key(child, set_as_array, sort_arrays, list_delim)); + items.push_back(to_key(child, set_format, sort_arrays, list_delim)); } } @@ -197,13 +197,22 @@ namespace rego std::sort(node_keys.begin(), node_keys.end()); - if (set_as_array) + switch (set_format) { - buf << "["; - } - else - { - buf << "<"; + case SetFormat::Square: + buf << "["; + break; + case SetFormat::Rego: + if (node_keys.empty()) + { + buf << "set()"; + return buf.str(); + } + buf << "{"; + break; + case SetFormat::Angle: + buf << "<"; + break; } join( @@ -211,20 +220,23 @@ namespace rego node_keys.begin(), node_keys.end(), list_delim, - [set_as_array, sort_arrays, list_delim]( + [set_format, sort_arrays, list_delim]( std::ostream& stream, const NodeKey& node_key) { - stream << to_key( - node_key.node, set_as_array, sort_arrays, list_delim); + stream << to_key(node_key.node, set_format, sort_arrays, list_delim); return true; }); - if (set_as_array) - { - buf << "]"; - } - else + switch (set_format) { - buf << ">"; + case SetFormat::Square: + buf << "]"; + break; + case SetFormat::Rego: + buf << "}"; + break; + case SetFormat::Angle: + buf << ">"; + break; } } else if (node->in({Object, DataObject, Bindings})) @@ -234,14 +246,13 @@ namespace rego { auto key = 
child / Key; auto value = child / Val; - std::string key_str = - to_key(key, set_as_array, sort_arrays, list_delim); + std::string key_str = to_key(key, set_format, sort_arrays, list_delim); if (!is_quoted(key_str)) { key_str = add_quotes(json::escape(key_str)); } items.insert( - {key_str, to_key(value, set_as_array, sort_arrays, list_delim)}); + {key_str, to_key(value, set_format, sort_arrays, list_delim)}); } buf << "{"; @@ -250,15 +261,28 @@ namespace rego items.begin(), items.end(), ", ", - [](std::ostream& stream, const auto& item) { - stream << item.first << ":" << item.second; + [list_delim](std::ostream& stream, const auto& item) { + stream << item.first << ":"; + // Use spaced separator when callers request spaced list delimiters + // (i.e. display format like sprintf %v and template strings). + if (list_delim[0] == ',' && list_delim[1] == ' ') + { + stream << " "; + } + stream << item.second; return true; }); buf << "}"; } else if (node->in({Scalar, Term, DataTerm})) { - return to_key(node->front(), set_as_array, sort_arrays, list_delim); + return to_key(node->front(), set_format, sort_arrays, list_delim); + } + else if (node == TemplateString) + { + // TemplateString should be lowered to an opblock call before reaching + // encoding. If we get here, something went wrong in the pipeline. 
+ buf << ""; } else if (node == Result) { @@ -268,7 +292,7 @@ namespace rego if (!terms->empty()) { buf << '"' << "expressions" << '"' << ":" - << to_key(terms, set_as_array, sort_arrays, list_delim); + << to_key(terms, set_format, sort_arrays, list_delim); if (!bindings->empty()) { buf << ", "; @@ -278,7 +302,7 @@ namespace rego if (!bindings->empty()) { buf << '"' << "bindings" << '"' << ":" - << to_key(bindings, set_as_array, sort_arrays, list_delim); + << to_key(bindings, set_format, sort_arrays, list_delim); } buf << "}"; @@ -291,9 +315,9 @@ namespace rego node->begin(), node->end(), ", ", - [set_as_array, sort_arrays, list_delim]( + [set_format, sort_arrays, list_delim]( std::ostream& stream, const Node& n) { - stream << to_key(n, set_as_array, sort_arrays, list_delim); + stream << to_key(n, set_format, sort_arrays, list_delim); return true; }); buf << ']'; @@ -310,7 +334,7 @@ namespace rego { if (node->size() == 1) { - return to_key(node->front(), set_as_array, sort_arrays, list_delim); + return to_key(node->front(), set_format, sort_arrays, list_delim); } buf << '['; @@ -319,9 +343,9 @@ namespace rego node->begin(), node->end(), ", ", - [set_as_array, sort_arrays, list_delim]( + [set_format, sort_arrays, list_delim]( std::ostream& stream, const Node& n) { - stream << to_key(n, set_as_array, sort_arrays, list_delim); + stream << to_key(n, set_format, sort_arrays, list_delim); return true; }); buf << ']'; diff --git a/src/file_to_rego.cc b/src/file_to_rego.cc index b375d41a..84249d7a 100644 --- a/src/file_to_rego.cc +++ b/src/file_to_rego.cc @@ -130,7 +130,8 @@ namespace With); const auto wf_prep_tokens = (wf_parse_tokens | Scalar | Placeholder) - - (Int | Float | JSONString | RawString | True | False | Null); + (Int | Float | JSONString | RawString | TemplateString | True | False | + Null | TemplateLiteral); // clang-format off const auto wf_prep = @@ -143,8 +144,9 @@ namespace | (Package <<= RefGroup) | (Import <<= RefGroup * Var) | (Scalar <<= Int | Float 
| String | True | False | Null) - | (String <<= JSONString | RawString) + | (String <<= JSONString | RawString | TemplateString) | (RefGroup <<= (Var | Dot | Square)++) + | (TemplateString <<= Group++) | (Group <<= wf_prep_tokens++) ; // clang-format on @@ -214,9 +216,30 @@ namespace In(Group) * T(Int, Float, True, False, Null)[Scalar] >> [](Match& _) { return Scalar << _(Scalar); }, - In(Group) * T(JSONString, RawString)[String] >> + In(Group) * T(JSONString, RawString, TemplateString)[String] >> [](Match& _) { return Scalar << (String << _(String)); }, + In(Group) * T(TemplateLiteral)[String] >> + [](Match& _) { + // Unescape \{ to { (template-specific escape) + std::string text(_(String)->location().view()); + std::string unescaped; + unescaped.reserve(text.size()); + for (size_t i = 0; i < text.size(); ++i) + { + if (text[i] == '\\' && i + 1 < text.size() && text[i + 1] == '{') + { + unescaped.push_back('{'); + ++i; + } + else + { + unescaped.push_back(text[i]); + } + } + return Scalar << (String << (JSONString ^ add_quotes(unescaped))); + }, + // errors In(Top) * T(File)[File] >> [](Match& _) { @@ -451,6 +474,10 @@ namespace In(Import) * (T(RefGroup) << (T(Ref)[Ref] * End)) >> [](Match& _) { return _(Ref); }, + In(TemplateString) * + (T(Group) << ((T(Brace)[Brace] << T(Group)[Group]) * End)) >> + [](Match& _) { return _(Group); }, + T(Brace)[Brace] << T(Group)[Group] >> [comma_groups, colon_groups, or_groups](Match& _) { NodeDef* group = _(Group).get(); @@ -1475,7 +1502,7 @@ namespace return Seq << _(Package) << version << importseq << policy; }, - In(Query) * (T(Group) << (T(Literal)++[Query] * End)) >> + In(Query, TemplateString) * (T(Group) << (T(Literal)++[Query] * End)) >> [](Match& _) { return Seq << _[Query]; }, In(RuleRef) * T(Var)[Var] >> diff --git a/src/internal.cc b/src/internal.cc index 43f346a9..97667ff4 100644 --- a/src/internal.cc +++ b/src/internal.cc @@ -2,6 +2,7 @@ #include "rego.hh" +#include #include #include @@ -261,7 +262,11 @@ namespace 
rego double floor = std::floor(value); if (value == floor) { - return BigInt(static_cast(floor)); + // Format the integer value as a decimal string to avoid undefined + // behavior when the double exceeds the range of size_t or int64_t. + char buf[32]; + std::snprintf(buf, sizeof(buf), "%.0f", floor); + return BigInt(Location(buf)); } return std::nullopt; @@ -912,6 +917,12 @@ namespace rego return Term << set; } + if (value == TemplateString) + { + return rego::err( + value, "Template string cannot be a constant", rego::RegoTypeError); + } + throw std::runtime_error("Invalid term"); } diff --git a/src/internal.hh b/src/internal.hh index f096848f..e6674dc4 100644 --- a/src/internal.hh +++ b/src/internal.hh @@ -62,7 +62,8 @@ namespace rego inline const auto wf_parse_tokens = Query | Module | wf_json | wf_arith_op | wf_bool_op | wf_bin_op | Package | Var | Brace | Square | Dot | Paren | Assign | Unify | EmptySet | Colon | RawString | Default | Some | Import | - Else | As | With | NewLine | Comma | If | IsIn | Contains | Every; + Else | As | With | NewLine | Comma | If | IsIn | Contains | Every | + TemplateString | TemplateLiteral; // clang-format off inline const auto wf_parser = @@ -74,6 +75,7 @@ namespace rego | (Square <<= Group++) | (Paren <<= Group++) | (Group <<= wf_parse_tokens++) + | (TemplateString <<= Group++) ; // clang-format on @@ -211,7 +213,7 @@ namespace rego Node m_scope; Node m_error; Node m_orderedseq; - bool m_needs_sort; + bool m_needs_sort = false; bool is_assigned(const std::string& name); bool any_unassigned(const Nodes& nodes); diff --git a/src/interpreter.cc b/src/interpreter.cc index fa4e02cf..2bbd27da 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -496,7 +496,7 @@ namespace rego else { WFContext context(rego::wf_result); - output_buf << rego::to_key(ast, true); + output_buf << rego::to_key(ast, rego::SetFormat::Square); } return output_buf.str(); diff --git a/src/json.cc b/src/json.cc index ef21e8f6..9f0e3a5b 100644 --- 
a/src/json.cc +++ b/src/json.cc @@ -3,6 +3,8 @@ #include "internal.hh" #include "rego.hh" +#include + namespace { using namespace trieste; @@ -38,6 +40,92 @@ namespace R"(\-?[[:digit:]]+\.[[:digit:]]+(?:e[+-]?[[:digit:]]+)?)"; const char* IntRE = R"(\-?[[:digit:]]+)"; + // Extract the key string view from an ObjectItem or DataObjectItem node. + // The key child is either Term >> Scalar >> JSONString or + // DataTerm >> Scalar >> String >> JSONString. + std::string_view item_key_view(const Node& item) + { + Node key = item->front(); + // Walk down through wrapper nodes to the leaf token. + while (key->size() > 0) + { + key = key->front(); + } + return key->location().view(); + } + + // Deduplicate object children, keeping the last occurrence of each key. + // This matches Go json.Unmarshal (and OPA) semantics. + // Returns a new Nodes vector with duplicates removed, or empty if no + // duplicates were found. + Nodes dedup_items(const Node& obj) + { + // Single pass: build the last-index map. + std::unordered_map last_index; + last_index.reserve(obj->size()); + for (std::size_t i = 0; i < obj->size(); ++i) + { + last_index[item_key_view(obj->at(i))] = i; + } + + if (last_index.size() == obj->size()) + { + // No duplicates + return {}; + } + + // Build result preserving original order. + Nodes result; + result.reserve(last_index.size()); + for (std::size_t i = 0; i < obj->size(); ++i) + { + if (last_index[item_key_view(obj->at(i))] == i) + { + result.push_back(obj->at(i)); + } + } + return result; + } + + // Pass that deduplicates object keys (last-wins) in Term-wrapped AST. 
+ PassDef dedup_object_keys_term() + { + return { + "dedup_object_keys", + wf_from_json_term, + dir::bottomup | dir::once, + { + In(Term) * T(Object)[Object] >> [](Match& _) -> Node { + Node obj = _(Object); + Nodes items = dedup_items(obj); + if (items.empty()) + { + return NoChange; + } + return Object << items; + }, + }}; + } + + // Pass that deduplicates object keys (last-wins) in DataTerm-wrapped AST. + PassDef dedup_object_keys_dataterm() + { + return { + "dedup_object_keys", + wf_from_json_dataterm, + dir::bottomup | dir::once, + { + In(DataTerm) * T(DataObject)[DataObject] >> [](Match& _) -> Node { + Node obj = _(DataObject); + Nodes items = dedup_items(obj); + if (items.empty()) + { + return NoChange; + } + return DataObject << items; + }, + }}; + } PassDef from_json_to_dataterm() { return { @@ -162,13 +250,37 @@ namespace }, (T(Term) << (T(Scalar) << T(True)[json::True])) >> - [](Match& _) { return json::True ^ _(json::True); }, + [](Match& _) { + Location loc = _(json::True)->location(); + if (loc.len == 0) + { + return json::True ^ "true"; + } + + return json::True ^ loc; + }, (T(Term) << (T(Scalar) << T(False)[json::False])) >> - [](Match& _) { return json::False ^ _(json::False); }, + [](Match& _) { + Location loc = _(json::False)->location(); + if (loc.len == 0) + { + return json::False ^ "false"; + } + + return json::False ^ loc; + }, (T(Term) << (T(Scalar) << T(Null)[json::Null])) >> - [](Match& _) { return json::Null ^ _(json::Null); }, + [](Match& _) { + Location loc = _(json::Null)->location(); + if (loc.len == 0) + { + return json::Null ^ "null"; + } + + return json::Null ^ loc; + }, (T(Term) << (T(Array, Set)[json::Array])) >> [](Match& _) { @@ -242,9 +354,11 @@ namespace rego Rewriter json_to_rego(bool as_term) { auto pass = as_term ? from_json_to_term() : from_json_to_dataterm(); + auto dedup = + as_term ? 
dedup_object_keys_term() : dedup_object_keys_dataterm(); return { "json_to_rego", - {pass}, + {pass, dedup}, json::wf, }; } diff --git a/src/opblock.cc b/src/opblock.cc index ae96cbc4..65771518 100644 --- a/src/opblock.cc +++ b/src/opblock.cc @@ -1,4 +1,5 @@ #include "internal.hh" +#include "rego.hh" namespace { @@ -348,6 +349,49 @@ namespace rego return OpBlock << (Operand << (IRString ^ json_string)) << Block; } + Node templatestring_to_opblock(Node templatestring) + { + Location array_name = templatestring->fresh({"tpl_array"}); + Location result_name = templatestring->fresh({"tpl_result"}); + size_t num_elements = templatestring->size(); + + Node block = Block + << (MakeArrayStmt << (Int32 ^ std::to_string(num_elements)) + << (LocalRef ^ array_name)); + + for (Node literal : *templatestring) + { + // Evaluate the expression inside a try-block and add to a set. + // If the expression is undefined, the set remains empty. + Location set_name = templatestring->fresh({"tpl_set"}); + block << (MakeSetStmt << (LocalRef ^ set_name)); + + Node tryblock = NodeDef::create(Block); + assert(literal == Literal); + Node expr = literal / Expr; + if (expr != OpBlock) + { + return err(templatestring, "Invalid template string expression"); + } + + Node expr_operand = to_operand(tryblock, expr); + tryblock << (SetAddStmt << expr_operand << (LocalRef ^ set_name)); + + block << (BlockStmt << (BlockSeq << tryblock)); + block + << (ArrayAppendStmt << (LocalRef ^ array_name) + << (Operand << (LocalRef ^ set_name))); + } + + block + << (BuiltInCallStmt << (IRString ^ "internal.template_string") + << (OperandSeq + << (Operand << (LocalRef ^ array_name))) + << (LocalRef ^ result_name)); + + return OpBlock << (Operand << (LocalRef ^ result_name)) << block; + } + Node boolean_to_opblock(Node term) { return OpBlock << (Operand << (Boolean ^ term)) << Block; @@ -609,6 +653,11 @@ namespace rego return scalar_to_opblock(value); } + if (value == TemplateString) + { + return 
templatestring_to_opblock(value); + } + if (value == Array) { return array_to_opblock(value); diff --git a/src/output.cc b/src/output.cc index ec48b418..9ea4274f 100644 --- a/src/output.cc +++ b/src/output.cc @@ -16,7 +16,7 @@ namespace rego assert(node == Results); std::ostringstream buf; WFContext context(rego::wf_result); - buf << rego::to_key(m_node, true); + buf << rego::to_key(m_node, rego::SetFormat::Square); m_json = buf.str(); } } diff --git a/src/parse.cc b/src/parse.cc index effb4a1a..6fd41778 100644 --- a/src/parse.cc +++ b/src/parse.cc @@ -73,6 +73,13 @@ namespace rego std::shared_ptr newline_mode = std::make_shared(NewlineMode::Ignore); + // Tracks nesting of template strings. Each entry is the quote character + // ('"' or '`') for the enclosing template string, so that '}' knows + // whether to return to template scanning or normal parsing. + constexpr size_t MaxTemplateNesting = 64; + std::shared_ptr> template_stack = + std::make_shared>(); + // Our starting path tries to determine if the input is a module or a query p("start", { @@ -182,9 +189,17 @@ namespace rego R"({(?:\r?\n)?)" >> [](auto& m) { m.push(Brace); }, "}" >> - [](auto& m) { + [template_stack](auto& m) { m.term(); - m.pop(Brace); + if (!template_stack->empty() && m.in(TemplateString)) + { + char quote = template_stack->back(); + m.mode(quote == '"' ? 
"template_dq" : "template_raw"); + } + else + { + m.pop(Brace); + } }, R"(\()" >> [](auto& m) { m.push(Paren); }, @@ -219,6 +234,66 @@ namespace rego "-" >> [](auto& m) { m.add(Subtract); }, + // Template string $"{ — double-quoted, has expression(s) + R"((\$")((?:[^"\\\{]|\\["\\\/bfnrt\{]|\\u[[:xdigit:]]{4})*)(\{(?:\r?\n)?))" >> + [template_stack, MaxTemplateNesting](auto& m) { + if (template_stack->size() >= MaxTemplateNesting) + { + m.error("template string nesting depth exceeded"); + return; + } + m.push(TemplateString, 1); + if (m.match(2).len > 0) + { + m.add(TemplateLiteral, 2); + m.term(); + } + template_stack->push_back('"'); + }, + + // Template string $"" — double-quoted, no expressions + R"((\$")((?:[^"\\\{]|\\["\\\/bfnrt\{]|\\u[[:xdigit:]]{4})*)("))" >> + [](auto& m) { + m.push(TemplateString, 1); + if (m.match(2).len > 0) + { + m.add(TemplateLiteral, 2); + m.term(); + } + m.term(); + m.pop(TemplateString); + }, + + // Template string $`{ — raw, has expression(s) + R"((\$\`)((?:[^\`\{\\]|\\{)*)(\{(?:\r?\n)?))" >> + [template_stack, MaxTemplateNesting](auto& m) { + if (template_stack->size() >= MaxTemplateNesting) + { + m.error("template string nesting depth exceeded"); + return; + } + m.push(TemplateString, 1); + if (m.match(2).len > 0) + { + m.add(TemplateLiteral, 2); + m.term(); + } + template_stack->push_back('`'); + }, + + // Template string $`` — raw, no expressions + R"((\$\`)((?:[^\`\{\\]|\\{)*)(\`))" >> + [](auto& m) { + m.push(TemplateString, 1); + if (m.match(2).len > 0) + { + m.add(TemplateLiteral, 2); + m.term(); + } + m.term(); + m.pop(TemplateString); + }, + // RE for a JSON string: // " : a double quote followed by either: // 1. [^"\\\x00-\x1F]+ : one or more characters that are not a double @@ -258,7 +333,82 @@ namespace rego R"(\s+)" >> [](auto&) {}, }); - p.done([](auto& m) { m.term({Module, Query}); }); + // After closing a template expression }, scan the next literal chunk + // in a double-quoted template string. 
+ p("template_dq", + { + // Literal chunk followed by another expression + R"(((?:[^"\\\{]|\\["\\\/bfnrt\{]|\\u[[:xdigit:]]{4})+)(\{(?:\r?\n)?))" >> + [](auto& m) { + m.add(TemplateLiteral, 1); + m.term(); + m.mode("main"); + }, + + // Another expression immediately (no literal between) + R"(\{(?:\r?\n)?)" >> [](auto& m) { m.mode("main"); }, + + // Literal chunk followed by end of template string + R"(((?:[^"\\\{]|\\["\\\/bfnrt\{]|\\u[[:xdigit:]]{4})+)("))" >> + [template_stack](auto& m) { + m.add(TemplateLiteral, 1); + m.term(); + template_stack->pop_back(); + m.term(); + m.pop(TemplateString); + m.mode("main"); + }, + + // End of template string immediately (no trailing literal) + R"(")" >> + [template_stack](auto& m) { + template_stack->pop_back(); + m.term(); + m.pop(TemplateString); + m.mode("main"); + }, + }); + + // Same as template_dq but for raw (backtick) template strings. + p("template_raw", + { + // Literal chunk followed by another expression. + // \\{ matches an escaped brace (literal {), not an interpolation start. + R"(((?:[^\`\{\\]|\\{)+)(\{(?:\r?\n)?))" >> + [](auto& m) { + m.add(TemplateLiteral, 1); + m.term(); + m.mode("main"); + }, + + // Another expression immediately + R"(\{(?:\r?\n)?)" >> [](auto& m) { m.mode("main"); }, + + // Literal chunk followed by end of template string. 
+ R"(((?:[^\`\{\\]|\\{)+)(\`))" >> + [template_stack](auto& m) { + m.add(TemplateLiteral, 1); + m.term(); + template_stack->pop_back(); + m.term(); + m.pop(TemplateString); + m.mode("main"); + }, + + // End of template string immediately + R"(\`)" >> + [template_stack](auto& m) { + template_stack->pop_back(); + m.term(); + m.pop(TemplateString); + m.mode("main"); + }, + }); + + p.done([template_stack](auto& m) { + template_stack->clear(); + m.term({Module, Query}); + }); p.gen({ Int >> [](auto& rnd) { return rand_int(rnd); }, @@ -274,6 +424,7 @@ namespace rego builtins::Name >> [](auto& rnd) { return rand_string(rnd); }, builtins::Description >> [](auto& rnd) { return rand_string(rnd); }, IRString >> [](auto& rnd) { return rand_string(rnd); }, + TemplateLiteral >> [](auto& rnd) { return rand_string(rnd); }, JSONString >> [](auto& rnd) { return rand_quoted(rnd, '"'); }, RawString >> [](auto& rnd) { return rand_quoted(rnd, '`'); }, Version >> diff --git a/src/rego_c.cc b/src/rego_c.cc index b6e72f42..d39eb636 100644 --- a/src/rego_c.cc +++ b/src/rego_c.cc @@ -3,6 +3,8 @@ #include "internal.hh" #include "rego.hh" +#include + namespace logging = trieste::logging; namespace @@ -24,9 +26,25 @@ namespace rego { void setError(regoInterpreter* rego, const std::string& error) { + if (rego == nullptr) + { + return; + } reinterpret_cast(rego)->c_error(error); } + regoEnum check_c_str( + regoInterpreter* rego, const char* ptr, const char* param_name) + { + if (ptr == nullptr) + { + std::string msg = std::string(param_name) + " must not be null"; + setError(rego, msg); + return REGO_ERROR; + } + return REGO_OK; + } + struct regoOutput { Output output; @@ -99,12 +117,20 @@ extern "C" { regoSize regoErrorSize(regoInterpreter* rego) { + if (rego == nullptr) + { + return 0; + } logging::Debug() << "regoErrorSize: " << rego; return reinterpret_cast(rego)->c_error().size() + 1; } regoEnum regoError(regoInterpreter* rego, char* buffer, regoSize size) { + if (rego == nullptr) + { + 
return REGO_ERROR; + } logging::Debug() << "regoGetError: " << (void*)buffer << "[" << size << "]"; const std::string& error_str = @@ -164,6 +190,11 @@ extern "C" regoEnum regoLogLevelFromString(const char* value) { + if (value == nullptr) + { + return REGO_ERROR; + } + rego::LogLevel loglevel; try { @@ -177,6 +208,10 @@ extern "C" regoEnum regoSetLogLevel(regoInterpreter* rego, regoEnum level) { + if (rego == nullptr) + { + return REGO_ERROR; + } rego::Interpreter* r = reinterpret_cast(rego); switch (level) { @@ -256,6 +291,10 @@ extern "C" regoEnum regoGetLogLevel(regoInterpreter* rego) { + if (rego == nullptr) + { + return REGO_ERROR; + } return (regoEnum) reinterpret_cast(rego)->log_level(); } @@ -274,6 +313,12 @@ extern "C" regoEnum regoAddModuleFile(regoInterpreter* rego, const char* path) { + regoEnum err = rego::check_c_str(rego, path, "path"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoAddModuleFile: " << path; try { @@ -290,6 +335,18 @@ extern "C" regoEnum regoAddModule( regoInterpreter* rego, const char* name, const char* contents) { + regoEnum err = rego::check_c_str(rego, name, "name"); + if (err != REGO_OK) + { + return err; + } + + err = rego::check_c_str(rego, contents, "contents"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoAddModule: " << name; try { @@ -305,6 +362,12 @@ extern "C" regoEnum regoAddDataJSONFile(regoInterpreter* rego, const char* path) { + regoEnum err = rego::check_c_str(rego, path, "path"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoAddDataJSONFile: " << path; try { @@ -320,6 +383,12 @@ extern "C" regoEnum regoAddDataJSON(regoInterpreter* rego, const char* contents) { + regoEnum err = rego::check_c_str(rego, contents, "contents"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoAddDataJSON: " << contents; try { @@ -335,6 +404,12 @@ extern "C" regoEnum regoSetInputJSONFile(regoInterpreter* rego, const char* path) { + regoEnum 
err = rego::check_c_str(rego, path, "path"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoSetInputJSONFile: " << path; try { @@ -358,6 +433,12 @@ extern "C" regoEnum regoSetInputTerm(regoInterpreter* rego, const char* contents) { + regoEnum err = rego::check_c_str(rego, contents, "contents"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoSetInputTerm: " << contents; try { @@ -377,6 +458,12 @@ extern "C" logging::Debug() << "regoSetInput: interp=" << rego << " input=" << input; try { + if (input == nullptr) + { + rego::setError(rego, "input must not be null"); + return REGO_ERROR; + } + rego::regoInput* ri = reinterpret_cast(input); if (ri->stack.empty()) { @@ -404,18 +491,32 @@ extern "C" void regoSetDebugEnabled(regoInterpreter* rego, regoBoolean enabled) { + if (rego == nullptr) + { + return; + } logging::Debug() << "regoSetDebugEnabled: " << enabled; reinterpret_cast(rego)->debug_enabled(enabled); } regoBoolean regoGetDebugEnabled(regoInterpreter* rego) { + if (rego == nullptr) + { + return false; + } logging::Debug() << "regoGetDebugEnabled"; return reinterpret_cast(rego)->debug_enabled(); } regoEnum regoSetDebugPath(regoInterpreter* rego, const char* path) { + regoEnum err = rego::check_c_str(rego, path, "path"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoSetDebugPath: " << path; try { @@ -432,18 +533,36 @@ extern "C" void regoSetWellFormedChecksEnabled( regoInterpreter* rego, regoBoolean enabled) { + if (rego == nullptr) + { + return; + } logging::Debug() << "regoSetWellFormedChecksEnabled: " << enabled; reinterpret_cast(rego)->wf_check_enabled(enabled); } regoBoolean regoGetWellFormedChecksEnabled(regoInterpreter* rego) { + if (rego == nullptr) + { + return false; + } logging::Debug() << "regoGetWellFormedChecksEnabled"; return reinterpret_cast(rego)->wf_check_enabled(); } regoOutput* regoQuery(regoInterpreter* rego, const char* query_expr) { + if (rego == nullptr) + { + return 
nullptr; + } + if (query_expr == nullptr) + { + rego::setError(rego, "query_expr must not be null"); + return nullptr; + } + logging::Debug() << "regoQuery: " << query_expr; try { @@ -463,6 +582,10 @@ extern "C" void regoSetStrictBuiltInErrors(regoInterpreter* rego, regoBoolean enabled) { + if (rego == nullptr) + { + return; + } logging::Debug() << "regoSetStrictBuiltInErrors: " << enabled; reinterpret_cast(rego)->builtins()->strict_errors( enabled); @@ -470,6 +593,10 @@ extern "C" regoBoolean regoGetStrictBuiltInErrors(regoInterpreter* rego) { + if (rego == nullptr) + { + return false; + } logging::Debug() << "regoGetStrictBuiltInErrors"; return reinterpret_cast(rego) ->builtins() @@ -478,6 +605,11 @@ extern "C" regoBoolean regoIsAvailableBuiltIn(regoInterpreter* rego, const char* name) { + if (rego == nullptr || name == nullptr) + { + return false; + } + logging::Debug() << "regoIsBuiltIn: " << name; rego::Location loc(name); @@ -495,6 +627,10 @@ extern "C" regoBundle* regoBuild(regoInterpreter* rego) { + if (rego == nullptr) + { + return nullptr; + } logging::Debug() << "regoBuild"; rego::regoBundle* bundle = new rego::regoBundle(); bundle->bundle = nullptr; @@ -504,6 +640,16 @@ extern "C" regoBundle* regoBundleLoad(regoInterpreter* rego, const char* dir) { + if (rego == nullptr) + { + return nullptr; + } + if (dir == nullptr) + { + rego::setError(rego, "dir must not be null"); + return nullptr; + } + logging::Debug() << "regoBundleLoad"; rego::regoBundle* bundle = new rego::regoBundle(); bundle->bundle = nullptr; @@ -514,6 +660,16 @@ extern "C" regoBundle* regoBundleLoadBinary(regoInterpreter* rego, const char* path) { + if (rego == nullptr) + { + return nullptr; + } + if (path == nullptr) + { + rego::setError(rego, "path must not be null"); + return nullptr; + } + logging::Debug() << "regoBundleLoadBinary"; rego::regoBundle* bundle = new rego::regoBundle(); bundle->bundle = nullptr; @@ -567,6 +723,12 @@ extern "C" regoEnum regoBundleSave( regoInterpreter* 
rego, const char* dir, regoBundle* bundle) { + regoEnum err = rego::check_c_str(rego, dir, "dir"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoBundleSave: " << dir; try { @@ -590,6 +752,12 @@ extern "C" regoEnum regoBundleSaveBinary( regoInterpreter* rego, const char* path, regoBundle* bundle) { + regoEnum err = rego::check_c_str(rego, path, "path"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoBundleSaveBinary"; try { @@ -613,6 +781,12 @@ extern "C" regoEnum regoSetQuery(regoInterpreter* rego, const char* query_expr) { + regoEnum err = rego::check_c_str(rego, query_expr, "query_expr"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoSetQuery: " << query_expr; try { @@ -628,6 +802,12 @@ extern "C" regoEnum regoAddEntrypoint(regoInterpreter* rego, const char* entrypoint) { + regoEnum err = rego::check_c_str(rego, entrypoint, "entrypoint"); + if (err != REGO_OK) + { + return err; + } + logging::Debug() << "regoAddEntrypoint: " << entrypoint; try { @@ -644,6 +824,10 @@ extern "C" regoOutput* regoBundleQuery(regoInterpreter* rego, regoBundle* bundle) { + if (rego == nullptr) + { + return nullptr; + } logging::Debug() << "regoBundleQuery rego(" << rego << ") bundle(" << bundle << ")"; try @@ -671,6 +855,16 @@ extern "C" regoOutput* regoBundleQueryEntrypoint( regoInterpreter* rego, regoBundle* bundle, const char* entrypoint) { + if (rego == nullptr) + { + return nullptr; + } + if (entrypoint == nullptr) + { + rego::setError(rego, "entrypoint must not be null"); + return nullptr; + } + logging::Debug() << "regoBundleQueryEntrypoint: rego(" << rego << ") bundle(" << bundle << ") " << entrypoint; try @@ -743,6 +937,11 @@ extern "C" regoNode* regoOutputBindingAtIndex( regoOutput* output, regoSize index, const char* name) { + if (name == nullptr) + { + return nullptr; + } + logging::Debug() << "regoOutputBindingAtIndex: " << name; auto val = reinterpret_cast(output)->output.binding_at( index, name); 
@@ -996,7 +1195,8 @@ extern "C" logging::Debug() << "regoNodeJSONSize"; auto node_ptr = reinterpret_cast(node); trieste::WFContext context(rego::wf_result); - std::string json = rego::to_key(node_ptr->intrusive_ptr_from_this(), true); + std::string json = rego::to_key( + node_ptr->intrusive_ptr_from_this(), rego::SetFormat::Square); return static_cast(json.size() + 1); } @@ -1006,7 +1206,8 @@ extern "C" auto node_ptr = reinterpret_cast(node); trieste::WFContext context(rego::wf_result); - std::string json = rego::to_key(node_ptr->intrusive_ptr_from_this(), true); + std::string json = rego::to_key( + node_ptr->intrusive_ptr_from_this(), rego::SetFormat::Square); if (size < json.size() + 1) { return REGO_ERROR_BUFFER_TOO_SMALL; diff --git a/src/rego_to_bundle.cc b/src/rego_to_bundle.cc index 9fee1044..826ef121 100644 --- a/src/rego_to_bundle.cc +++ b/src/rego_to_bundle.cc @@ -27,6 +27,7 @@ namespace | (RuleHeadObjDynamic <<= ExprSeq * Expr) | (RuleHeadSetDynamic <<= ExprSeq * Expr) | (Scalar <<= JSONString | Int | Float | True | False | Null) + | (Term <<= TemplateString | Ref | Var | Scalar | Array | Object | Set | Membership | ArrayCompr | ObjectCompr | SetCompr) | (Membership <<= (Key >>= Expr | Undefined) * (Val >>= Expr) * Expr) ; // clang-format on @@ -106,8 +107,18 @@ namespace wf_bundle_refheads, dir::bottomup | dir::once, { + In(Term) * + (T(Scalar) << (T(String) << T(TemplateString)[TemplateString])) >> + [](Match& _) { return _(TemplateString); }, + + In(DataTerm) * + (T(Scalar) << (T(String) << T(TemplateString)[TemplateString])) >> + [](Match& _) { + return err(_(TemplateString), "Invalid template string in data"); + }, + In(Scalar) * (T(String) << T(JSONString)[JSONString]) >> - [](Match& _) { return JSONString ^ _(JSONString); }, + [](Match& _) { return _(JSONString); }, In(Scalar) * (T(String) << T(RawString)[RawString]) >> [](Match& _) { @@ -687,6 +698,13 @@ namespace In(Query) * (T(Literal) << T(SomeDecl)[SomeDecl]) >> [](Match& _) { return 
err(_(SomeDecl), "Invalid some statement"); }, + In(TemplateString) * + (T(Literal) << (T(SomeDecl)[SomeDecl] * T(WithSeq))) >> + [](Match& _) { + return err( + _(SomeDecl), "Invalid some statement in template string"); + }, + In(Query) * (T(Literal)[Literal] << ((T(Expr) @@ -1543,7 +1561,7 @@ namespace | (ExprAssignFromArray <<= AssignVar * Var * Int) | (ExprAssignFromObject <<= AssignVar * Var * Expr) | (Literal <<= (Expr >>= ExprAssignFromArray | ExprAssignFromObject | ExprIsArray | ExprIsObject | ExprAssign | ExprUnify | ExprScan | ExprEvery | Local | Expr | NotExpr) * WithSeq) - | (Term <<= (UnifyVar | Ref | Var | Scalar | Array | Object | Set | Membership | ArrayCompr | ObjectCompr | SetCompr)) + | (Term <<= (UnifyVar | TemplateString | Ref | Var | Scalar | Array | Object | Set | Membership | ArrayCompr | ObjectCompr | SetCompr)) ; // clang-format on @@ -2793,6 +2811,12 @@ namespace // errors + In(TemplateString) * + (T(Literal) << (T(ExprUnify)[ExprUnify] * T(WithSeq))) >> + [](Match& _) { + return err(_(ExprUnify), "Invalid unification in template string"); + }, + In(ExprScan) * (T(Local) << T(Ident)[Ident]) >> [](Match& _) -> Node { std::string_view name = _(Ident)->location().view(); if (name.starts_with("scan")) diff --git a/src/virtual_machine.cc b/src/virtual_machine.cc index d249fdd4..4cb1ef2b 100644 --- a/src/virtual_machine.cc +++ b/src/virtual_machine.cc @@ -659,7 +659,7 @@ namespace rego case b::StatementType::MakeNumberRef: { const Location& num_value = m_bundle->strings[stmt.op0.index]; - if (RE2::FullMatch(num_value.view(), m_int_regex)) + if (TRegex::FullMatch(num_value.view(), m_int_regex)) { state.write_local(stmt.target, Int ^ num_value); } diff --git a/tests/bigint.yaml b/tests/bigint.yaml index 59178809..8ed291c3 100644 --- a/tests/bigint.yaml +++ b/tests/bigint.yaml @@ -1,3 +1,5 @@ +# Speculative tests for full big-integer arithmetic support. +# These are NOT part of the compliance suite and may fail on the current build. 
cases: - note: bigint/4-bit modules: diff --git a/tests/builtins.cc b/tests/builtins.cc index 4442bb73..1efd77b1 100644 --- a/tests/builtins.cc +++ b/tests/builtins.cc @@ -1,5 +1,6 @@ #include "test_case.h" +#include #include #ifndef _WIN32 @@ -29,7 +30,7 @@ namespace { const char* duration_re = R"((-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][-+]?[0-9]+)?)((?:ns|us|µs|ms|s|m|h)))"; - const RE2 re(duration_re); + const TRegex re(duration_re); assert(re.ok()); std::string number; @@ -38,8 +39,8 @@ namespace std::size_t start = 0; while (start < duration.size()) { - re2::StringPiece input(duration.c_str() + start, duration.size() - start); - if (RE2::PartialMatch(input, re, &number, &unit)) + std::string_view input(duration.c_str() + start, duration.size() - start); + if (TRegex::PartialMatch(input, re, &number, &unit)) { double number_d = std::stod(number); double unit_ns = duration_units.at(unit); diff --git a/tests/c_api.cc b/tests/c_api.cc index 17a14591..321a73d6 100644 --- a/tests/c_api.cc +++ b/tests/c_api.cc @@ -68,7 +68,7 @@ int main(void) regoSetDebugEnabled(rego, true); regoSetDebugPath(rego, "test"); - regoSetLogLevel(rego, regoLogLevelFromString("Debug")); + regoSetLogLevel(rego, regoLogLevelFromString("Warning")); err = regoAddModuleFile(rego, "examples/objects.rego"); if (err != REGO_OK) diff --git a/tests/regocpp.yaml b/tests/regocpp.yaml index 42759ed1..28992eb9 100644 --- a/tests/regocpp.yaml +++ b/tests/regocpp.yaml @@ -1311,3 +1311,50 @@ cases: - x: baz: gimel foo: aleph +- modules: + - | + package test + + greeting := "world" + msg := $"Hello {greeting}!" + note: regocpp/stringinterp + query: data.test.msg = x + want_result: + - x: Hello world! +- modules: + - | + package test + + msg := $"Hello {input.name with input as {"name": "world"}}!" + note: regocpp/stringinterp-with + query: data.test.msg = x + want_result: + - x: Hello world! 
+- modules: + - | + package test + + token := io.jwt.encode_sign({"alg": "HS256"}, {"sub": "alice"}, {"kty": "oct", "k": base64url.encode_no_pad("secret")}) + + match := io.jwt.decode_verify(token, {"secret": "secret", "alg": "HS256", "sub": "alice"}) + note: regocpp/jwt-sub-match + query: data.test.match = x + want_result: + - x: + - true + - {"alg": "HS256"} + - {"sub": "alice"} +- modules: + - | + package test + + token := io.jwt.encode_sign({"alg": "HS256"}, {"sub": "alice"}, {"kty": "oct", "k": base64url.encode_no_pad("secret")}) + + mismatch := io.jwt.decode_verify(token, {"secret": "secret", "alg": "HS256", "sub": "bob"}) + note: regocpp/jwt-sub-mismatch + query: data.test.mismatch = x + want_result: + - x: + - false + - {} + - {} diff --git a/tests/test_case.cc b/tests/test_case.cc index b8ce3da7..21a3472c 100644 --- a/tests/test_case.cc +++ b/tests/test_case.cc @@ -143,7 +143,8 @@ namespace rego_test m_want_defined(false), m_sort_bindings(false), m_strict_error(false), - m_broken(false) + m_broken(false), + m_unsupported(false) {} std::optional TestCase::maybe_get_object( @@ -319,11 +320,13 @@ namespace rego_test assert(string == JSONString); std::string key = std::string(string->location().view()); key = key.substr(1, key.size() - 2); // remove quotes - std::string value = rego::to_key(item / Val, true, m_sort_bindings); + std::string value = + rego::to_key(item / Val, rego::SetFormat::Square, m_sort_bindings); binding_map[key] = value; } - std::string binding_key = rego::to_key(object, true, m_sort_bindings); + std::string binding_key = + rego::to_key(object, rego::SetFormat::Square, m_sort_bindings); binding_maps[binding_key] = binding_map; } } @@ -343,12 +346,13 @@ namespace rego_test { std::string key = std::string((binding / rego::Key)->location().view()); - std::string value = - rego::to_key((binding / rego::Val), true, m_sort_bindings); + std::string value = rego::to_key( + (binding / rego::Val), rego::SetFormat::Square, m_sort_bindings); 
binding_map[key] = value; } - std::string binding_key = rego::to_key(bindings, true, m_sort_bindings); + std::string binding_key = + rego::to_key(bindings, rego::SetFormat::Square, m_sort_bindings); binding_maps[binding_key] = binding_map; } } @@ -611,6 +615,18 @@ namespace rego_test test_case.broken(true); } +#ifndef REGOCPP_CRYPTO_OPENSSL3 + if ( + test_case.note() == "jwtdecodeverify/EdDSA" || + test_case.note() == "jwtencodesign/EdDSA" || + test_case.note() == "jwtencodesignraw/EdDSA" || + test_case.note().find("jwtverifyeddsa/") == 0) + { + // Only the OpenSSL backend supports Ed25519/EdDSA. + test_case.unsupported(true); + } +#endif + return test_case; } catch (const std::exception& e) @@ -746,6 +762,11 @@ namespace rego_test return {Outcome::Skip, "Test Broken"}; } + if (m_unsupported) + { + return {Outcome::Skip, "Not Supported"}; + } + if (actual->type() == ErrorSeq) { if (actual->size() > 1) @@ -1072,4 +1093,15 @@ namespace rego_test m_broken = broken; return *this; } + + bool TestCase::unsupported() const + { + return m_unsupported; + } + + TestCase& TestCase::unsupported(bool unsupported) + { + m_unsupported = unsupported; + return *this; + } } \ No newline at end of file diff --git a/tests/test_case.h b/tests/test_case.h index a87dfdce..65bf8d84 100644 --- a/tests/test_case.h +++ b/tests/test_case.h @@ -113,6 +113,10 @@ namespace rego_test bool broken() const; TestCase& broken(bool broken); + /// indicates the test uses functionality not supported by this backend + bool unsupported() const; + TestCase& unsupported(bool unsupported); + /// whether to perform a serialisation round-trip before running the test RoundTrip roundtrip() const; TestCase& roundtrip(RoundTrip setting); @@ -162,6 +166,7 @@ namespace rego_test bool m_sort_bindings; bool m_strict_error; bool m_broken; + bool m_unsupported; }; } diff --git a/tools/main.cc b/tools/main.cc index eb2601ae..9b26a8ee 100644 --- a/tools/main.cc +++ b/tools/main.cc @@ -326,7 +326,8 @@ int main(int argc, 
char** argv) Timer timer("Query", timing); result = interpreter->query_bundle(bundle)->front(); } - trieste::logging::Output() << rego::to_key(result, true) << std::endl; + trieste::logging::Output() + << rego::to_key(result, rego::SetFormat::Square) << std::endl; } else { @@ -338,7 +339,8 @@ int main(int argc, char** argv) Timer timer("Endpoint", timing); result = interpreter->query_bundle(bundle, entrypoint)->front(); } - trieste::logging::Output() << rego::to_key(result, true) << std::endl; + trieste::logging::Output() + << rego::to_key(result, rego::SetFormat::Square) << std::endl; } } } diff --git a/wrappers/dotnet/Rego/Rego.csproj b/wrappers/dotnet/Rego/Rego.csproj index 4ae6f731..1c17ad65 100644 --- a/wrappers/dotnet/Rego/Rego.csproj +++ b/wrappers/dotnet/Rego/Rego.csproj @@ -10,7 +10,7 @@ Matthew Johnson Microsoft - 1.2.0 + 1.3.0 Copyright (c) Microsoft. All rights reserved. This client library provides in-process Rego query support for .NET applications. @@ -46,21 +46,24 @@ librego_shared.so $(RegoCPPBuildDir)/src/$(RegoCPPLibraryName) + mbedtls rego_shared.dll $(RegoCPPBuildDir)/src/Release/$(RegoCPPLibraryName) + bcrypt librego_shared.dylib $(RegoCPPBuildDir)/src/$(RegoCPPLibraryName) + mbedtls - + diff --git a/wrappers/python/setup.py b/wrappers/python/setup.py index a9ce92a1..c5ca7ead 100644 --- a/wrappers/python/setup.py +++ b/wrappers/python/setup.py @@ -28,7 +28,7 @@ with open("README.md", "r") as file: LONG_DESCRIPTION = file.read() -VERSION = "1.2.0" +VERSION = "1.3.0" class CMakeExtension(Extension): @@ -76,12 +76,18 @@ def build_extension(self, ext: CMakeExtension): subprocess.check_call(["git", "clone", repo, src_path]) subprocess.check_call(["git", "checkout", tag], cwd=src_path) + if platform.uname()[0] == "Windows": + crypto_backend = "bcrypt" + else: + crypto_backend = "mbedtls" + cmake_args = [f"-S {src_path}", f"-B {self.build_temp}", f"-DCMAKE_INSTALL_PREFIX={extdir}", "-DREGOCPP_BUILD_SHARED=ON", f"-DCMAKE_BUILD_TYPE={cfg}", - 
"-DSNMALLOC_ENABLE_DYNAMIC_LOADING=ON"] + "-DSNMALLOC_ENABLE_DYNAMIC_LOADING=ON", + f"-DREGOCPP_CRYPTO_BACKEND={crypto_backend}"] build_args = ["--build", self.build_temp, "--config", cfg, diff --git a/wrappers/rust/regorust/Cargo.toml b/wrappers/rust/regorust/Cargo.toml index e013368a..5c42a234 100644 --- a/wrappers/rust/regorust/Cargo.toml +++ b/wrappers/rust/regorust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "regorust" -version = "1.2.0" +version = "1.3.0" edition = "2021" description = "Rust bindings for the rego-cpp Rego compiler and interpreter" license = "MIT" diff --git a/wrappers/rust/regorust/build.rs b/wrappers/rust/regorust/build.rs index a569a958..ec36e67c 100644 --- a/wrappers/rust/regorust/build.rs +++ b/wrappers/rust/regorust/build.rs @@ -62,6 +62,7 @@ fn main() { .as_str(), "-DCMAKE_INSTALL_PREFIX=rust", "-DREGOCPP_COPY_EXAMPLES=ON", + "-DREGOCPP_CRYPTO_BACKEND=mbedtls", ]) .current_dir(®ocpp_path) .status() @@ -90,7 +91,6 @@ fn main() { let header_path_str = header_path.to_str().unwrap(); println!("cargo:rustc-link-search={}", libdir_path.to_str().unwrap()); - println!("cargo:rustc-link-lib=static:+whole-archive=re2"); println!("cargo:rustc-link-lib=static:+whole-archive=json"); println!("cargo:rustc-link-lib=static:+whole-archive=yaml"); println!("cargo:rustc-link-lib=static=rego"); @@ -104,6 +104,12 @@ fn main() { println!("cargo:rustc-link-lib=static:+whole-archive=snmalloc-new-override"); println!("cargo:rustc-link-lib=stdc++"); } + // mbedtls libraries for crypto/JWT builtins + println!("cargo:rustc-link-lib=static=mbedtls"); + println!("cargo:rustc-link-lib=static=mbedcrypto"); + println!("cargo:rustc-link-lib=static=mbedx509"); + println!("cargo:rustc-link-lib=static=everest"); + println!("cargo:rustc-link-lib=static=p256m"); println!("cargo:rerun-if-changed={}", header_path_str); // The bindgen::Builder is the main entry point