diff --git a/Cargo.toml b/Cargo.toml index 2e1cea09e..5c2491c47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,6 +111,7 @@ chrono = "0.4.41" divan = "0.1.21" hex = "0.4.3" itertools = "0.14.0" +num-bigint = "0.4" paste = "1.0.15" postcard = { version = "1.1.1", features = ["use-std"] } primitive-types = "0.13.1" diff --git a/noir-examples/embedded_curve_msm/Nargo.toml b/noir-examples/embedded_curve_msm/Nargo.toml new file mode 100644 index 000000000..ec9891616 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "embedded_curve_msm" +type = "bin" +authors = [""] +compiler_version = ">=0.22.0" + +[dependencies] diff --git a/noir-examples/embedded_curve_msm/Prover.toml b/noir-examples/embedded_curve_msm/Prover.toml new file mode 100644 index 000000000..edf585681 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover.toml @@ -0,0 +1,9 @@ +# ============================================================ +# MSM test vectors: result = s1 * G + s2 * G +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# ============================================================ +# n - 2, n - 3 (previously failing with [2]G offset) +scalar1_lo = "201385395114098847380338600778089168197" +scalar1_hi = "64323764613183177041862057485226039389" +scalar2_lo = "201385395114098847380338600778089168196" +scalar2_hi = "64323764613183177041862057485226039389" diff --git a/noir-examples/embedded_curve_msm/Prover_near_identity.toml b/noir-examples/embedded_curve_msm/Prover_near_identity.toml new file mode 100644 index 000000000..a156f96ee --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_near_identity.toml @@ -0,0 +1,6 @@ +# MSM edge case: n-2 and n-3 (previously failing with [2]G offset) +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +scalar1_lo = "201385395114098847380338600778089168197" +scalar1_hi = 
"64323764613183177041862057485226039389" +scalar2_lo = "201385395114098847380338600778089168196" +scalar2_hi = "64323764613183177041862057485226039389" diff --git a/noir-examples/embedded_curve_msm/Prover_near_order.toml b/noir-examples/embedded_curve_msm/Prover_near_order.toml new file mode 100644 index 000000000..d8ae04eb5 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_near_order.toml @@ -0,0 +1,6 @@ +# MSM edge case: near-max scalars (n-10 and n-20) +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +scalar1_lo = "201385395114098847380338600778089168189" +scalar1_hi = "64323764613183177041862057485226039389" +scalar2_lo = "201385395114098847380338600778089168179" +scalar2_hi = "64323764613183177041862057485226039389" diff --git a/noir-examples/embedded_curve_msm/Prover_single_nonzero.toml b/noir-examples/embedded_curve_msm/Prover_single_nonzero.toml new file mode 100644 index 000000000..8455db356 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_single_nonzero.toml @@ -0,0 +1,5 @@ +# MSM edge case: one zero scalar, one non-zero (0*G + 5*G = 5*G) +scalar1_lo = "0" +scalar1_hi = "0" +scalar2_lo = "5" +scalar2_hi = "0" diff --git a/noir-examples/embedded_curve_msm/Prover_zero_scalars.toml b/noir-examples/embedded_curve_msm/Prover_zero_scalars.toml new file mode 100644 index 000000000..0bd8866c7 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_zero_scalars.toml @@ -0,0 +1,5 @@ +# MSM edge case: all-zero scalars (0*G + 0*G = point at infinity) +scalar1_lo = "0" +scalar1_hi = "0" +scalar2_lo = "0" +scalar2_hi = "0" diff --git a/noir-examples/embedded_curve_msm/src/main.nr b/noir-examples/embedded_curve_msm/src/main.nr new file mode 100644 index 000000000..19a193181 --- /dev/null +++ b/noir-examples/embedded_curve_msm/src/main.nr @@ -0,0 +1,52 @@ +use std::embedded_curve_ops::{ + EmbeddedCurvePoint, + EmbeddedCurveScalar, + multi_scalar_mul, +}; + +/// Exercises the MultiScalarMul 
ACIR blackbox with 2 Grumpkin points. +/// Computes s1 * G + s2 * G where G is the Grumpkin generator. +fn main( + scalar1_lo: Field, + scalar1_hi: Field, + scalar2_lo: Field, + scalar2_hi: Field, +) { + // Grumpkin generator + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + + let s1 = EmbeddedCurveScalar { lo: scalar1_lo, hi: scalar1_hi }; + let s2 = EmbeddedCurveScalar { lo: scalar2_lo, hi: scalar2_hi }; + + // MSM: result = s1 * G + s2 * G + let result = multi_scalar_mul([g, g], [s1, s2]); + + // Prevent dead-code elimination - forces the blackbox to be retained + // Using is_infinite as return value ensures the MSM is computed + assert(result.is_infinite == (scalar1_lo + scalar1_hi + scalar2_lo + scalar2_hi == 0)); +} + +#[test] +fn test_msm() { + // 3*G on Grumpkin + let expected_x = 18660890509582237958343981571981920822503400000196279471655180441138020044621; + let expected_y = 8902249110305491597038405103722863701255802573786510474664632793109847672620; + + main(1, 0, 2, 0); + + // Verify by computing independently: 3*G should match + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + let s3 = EmbeddedCurveScalar { lo: 3, hi: 0 }; + let check = multi_scalar_mul([g], [s3]); + + assert(check.x == expected_x); + assert(check.y == expected_y); +} diff --git a/noir-examples/native_msm/Nargo.toml b/noir-examples/native_msm/Nargo.toml new file mode 100644 index 000000000..6b16fd3ae --- /dev/null +++ b/noir-examples/native_msm/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "native_msm" +type = "bin" +authors = [""] + +[dependencies] diff --git a/noir-examples/native_msm/Prover.toml b/noir-examples/native_msm/Prover.toml new file mode 100644 index 000000000..58c6933da --- /dev/null +++ b/noir-examples/native_msm/Prover.toml @@ -0,0 +1,5 @@ +# MSM: result = s1 * G + s2 * G = 1*G + 2*G = 3*G +scalar1_lo = 
"1" +scalar1_hi = "0" +scalar2_lo = "2" +scalar2_hi = "0" diff --git a/noir-examples/native_msm/src/main.nr b/noir-examples/native_msm/src/main.nr new file mode 100644 index 000000000..901722a4e --- /dev/null +++ b/noir-examples/native_msm/src/main.nr @@ -0,0 +1,357 @@ +global GRUMPKIN_GEN_Y: Field = 17631683881184975370165255887551781615748388533673675138860; + +// Hardcoded offset generators: offset = 5*G, offset_final = 2^252 * 5*G +// These are compile-time constants -- no unconstrained computation, no runtime verification needed. +global OFFSET_X: Field = 12229279139087521908560794489267966517139449915173592433539394009359081620359; +global OFFSET_Y: Field = 12096995292699515952722386974733884667125946823386040531322131902193094989869; +global OFFSET_FINAL_X: Field = 17097678145015848904467691187715743297134903912023447344174597163323183228319; +global OFFSET_FINAL_Y: Field = 14560299638432262069836824301755786319891239433999125203465365199349384123743; + +// BN254 scalar field modulus as wNAF slices (MSB first). +// Used for lexicographic range check: ensures wNAF integer < p, preventing mod-p ambiguity. 
+global MODULUS_SLICES: [u8; 64] = [ + 9, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, + 13, 12, 2, 8, 2, 2, 13, 11, 4, 0, 12, 0, 10, 12, 2, 14, + 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8, + 10, 1, 15, 0, 15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0, +]; + +struct GPoint { x: Field, y: Field } +struct GPointResult { x: Field, y: Field, is_infinity: bool } +struct Hint { lambda: Field, x3: Field, y3: Field } + +// ====== Constrained EC verification ====== + +// ~4 constraints: verify point doubling +fn c_double(p: GPoint, h: Hint) -> GPoint { + let xx = p.x * p.x; + assert(h.lambda * (p.y + p.y) == 3 * xx); // a=0 for Grumpkin + assert(h.lambda * h.lambda == h.x3 + p.x + p.x); + assert(h.lambda * (p.x - h.x3) == h.y3 + p.y); + GPoint { x: h.x3, y: h.y3 } +} + +// ~4 constraints: verify point addition (incomplete -- offset generator prevents edge cases) +fn c_add(p1: GPoint, p2: GPoint, h: Hint) -> GPoint { + assert(p1.x != p2.x); + assert(h.lambda * (p2.x - p1.x) == p2.y - p1.y); + assert(h.lambda * h.lambda == h.x3 + p1.x + p2.x); + assert(h.lambda * (p1.x - h.x3) == h.y3 + p1.y); + GPoint { x: h.x3, y: h.y3 } +} + +// ====== Unconstrained witness generation ====== + +unconstrained fn u_double(p: GPoint) -> (GPoint, Hint) { + let lambda = (3 * p.x * p.x) / (2 * p.y); + let x3 = lambda * lambda - 2 * p.x; + let y3 = lambda * (p.x - x3) - p.y; + (GPoint { x: x3, y: y3 }, Hint { lambda, x3, y3 }) +} + +unconstrained fn u_add(p1: GPoint, p2: GPoint) -> (GPoint, Hint) { + let lambda = (p2.y - p1.y) / (p2.x - p1.x); + let x3 = lambda * lambda - p1.x - p2.x; + let y3 = lambda * (p1.x - x3) - p1.y; + (GPoint { x: x3, y: y3 }, Hint { lambda, x3, y3 }) +} + +// Unconstrained complete addition hint (handles add, double, and inverse cases) +unconstrained fn u_complete_add_hint(p1: GPoint, p2: GPoint) -> Hint { + if p1.x == p2.x { + if p1.y == p2.y { + let (_, hint) = u_double(p1); + hint + } else { + Hint { lambda: 0, x3: 0, y3: 0 } + } + } else { + let (_, hint) = 
u_add(p1, p2); + hint + } +} + +// Unconstrained full addition hint (handles infinity inputs) +unconstrained fn u_full_add_hint( + p1: GPoint, p1_inf: bool, + p2: GPoint, p2_inf: bool, +) -> Hint { + if p1_inf | p2_inf { + Hint { lambda: 0, x3: 0, y3: 0 } + } else { + u_complete_add_hint(p1, p2) + } +} + +// Constrained complete addition: handles add, double, and inverse-point cases. +// Both inputs must be valid on-curve points (not identity). +// Uses `active * constraint == 0` pattern so constraints are trivially satisfied +// when the result is the identity (inverse-point case). +fn c_complete_add(p1: GPoint, p2: GPoint, h: Hint) -> GPointResult { + let x_eq: bool = p1.x == p2.x; + let y_eq: bool = p1.y == p2.y; + let is_infinity: bool = x_eq & !y_eq; + let is_double: bool = x_eq & y_eq; + let active: Field = (!is_infinity) as Field; + + let lambda_lhs = if is_double { p1.y + p1.y } else { p2.x - p1.x }; + let lambda_rhs = if is_double { 3 * p1.x * p1.x } else { p2.y - p1.y }; + assert(active * (h.lambda * lambda_lhs - lambda_rhs) == 0); + + // x3 verification: lambda^2 = x3 + x1 + x2 (same for add and double since x2=x1 when doubling) + assert(active * (h.lambda * h.lambda - h.x3 - p1.x - p2.x) == 0); + + // y3 verification: lambda * (x1 - x3) = y3 + y1 + assert(active * (h.lambda * (p1.x - h.x3) - h.y3 - p1.y) == 0); + + GPointResult { x: h.x3 * active, y: h.y3 * active, is_infinity } +} + +// Constrained full addition: handles all cases including identity inputs. +// This is used for the final MSM sum where either operand may be the identity. +fn c_full_add( + p1: GPoint, p1_inf: bool, + p2: GPoint, p2_inf: bool, + h: Hint, +) -> GPointResult { + let neither_inf = !p1_inf & !p2_inf; + let both_inf = p1_inf & p2_inf; + let only_p1_inf = p1_inf & !p2_inf; + + // EC constraints are only active when neither input is identity. + let ec_active: Field = neither_inf as Field; + + // Determine add/double/inverse case (only meaningful when neither is identity). 
+ // Guard with neither_inf so garbage coordinates from identity points don't affect predicates. + let x_eq: bool = (p1.x == p2.x) & neither_inf; + let y_eq: bool = (p1.y == p2.y) & neither_inf; + let is_inf_from_add: bool = x_eq & !y_eq; + let is_double: bool = x_eq & y_eq; + let arith_active: Field = ec_active * (!is_inf_from_add as Field); + + // Lambda, x3, y3 constraints (zeroed when inactive) + let lambda_lhs = if is_double { p1.y + p1.y } else { p2.x - p1.x }; + let lambda_rhs = if is_double { 3 * p1.x * p1.x } else { p2.y - p1.y }; + assert(arith_active * (h.lambda * lambda_lhs - lambda_rhs) == 0); + assert(arith_active * (h.lambda * h.lambda - h.x3 - p1.x - p2.x) == 0); + assert(arith_active * (h.lambda * (p1.x - h.x3) - h.y3 - p1.y) == 0); + + // Output selection + let result_is_inf: bool = both_inf | is_inf_from_add; + + let out_x = if result_is_inf { 0 } + else if only_p1_inf { p2.x } + else if p2_inf { p1.x } + else { h.x3 }; + let out_y = if result_is_inf { 0 } + else if only_p1_inf { p2.y } + else if p2_inf { p1.y } + else { h.y3 }; + + GPointResult { x: out_x, y: out_y, is_infinity: result_is_inf } +} + +unconstrained fn decompose_wnaf(scalar_lo: Field, scalar_hi: Field) -> ([u8; 64], bool) { + let lo_bytes = scalar_lo.to_le_bytes::<16>(); + let hi_bytes = scalar_hi.to_le_bytes::<16>(); + let mut nibbles: [u8; 64] = [0; 64]; + for i in 0..16 { + nibbles[2 * i] = lo_bytes[i] & 0x0F; + nibbles[2 * i + 1] = lo_bytes[i] >> 4; + } + for i in 0..16 { + nibbles[32 + 2 * i] = hi_bytes[i] & 0x0F; + nibbles[32 + 2 * i + 1] = hi_bytes[i] >> 4; + } + let skew: bool = (nibbles[0] & 1) == 0; + nibbles[0] = nibbles[0] + (skew as u8); + let mut slices: [u8; 64] = [0; 64]; + slices[63] = (nibbles[0] + 15) / 2; + for i in 1..64 { + let nibble = nibbles[i]; + slices[63 - i] = (nibble + 15) / 2; + if (nibble & 1) == 0 { + slices[63 - i] += 1; + slices[64 - i] -= 8; + } + } + (slices, skew) +} + +// 326 hints per scalar mul: 8 table + 1 init + 63*5 loop + 1 skew + 1 
final +unconstrained fn compute_transcript( + P: GPoint, slices: [u8; 64], skew: bool, + offset: GPoint, offset_final: GPoint, +) -> [Hint; 326] { + let mut h: [Hint; 326] = [Hint { lambda: 0, x3: 0, y3: 0 }; 326]; + let mut p: u32 = 0; + + // Table: 2P, then P+2P, 3P+2P, ... + let (d2, d2h) = u_double(P); + h[p] = d2h; p += 1; + let mut table: [GPoint; 16] = [GPoint { x: 0, y: 0 }; 16]; + table[8] = P; + table[7] = GPoint { x: P.x, y: 0 - P.y }; + let mut A = P; + for i in 1..8 { + let (s, sh) = u_add(A, d2); h[p] = sh; p += 1; + A = s; + table[8 + i] = A; + table[7 - i] = GPoint { x: A.x, y: 0 - A.y }; + } + + // Init: offset + T[slices[0]] + let (ir, ih) = u_add(offset, table[slices[0] as u32]); + h[p] = ih; p += 1; + let mut acc = ir; + + // 63 windows: 4 doubles + 1 add each + for _w in 1..64 { + for _ in 0..4 { let (d, dh) = u_double(acc); h[p] = dh; p += 1; acc = d; } + let tp = table[slices[_w] as u32]; + let (s, sh) = u_add(acc, tp); h[p] = sh; p += 1; acc = s; + } + + // Skew correction (always compute valid hint even if unused) + let neg_P = GPoint { x: P.x, y: 0 - P.y }; + let (sr, sh) = u_add(acc, neg_P); + h[p] = sh; p += 1; + if skew { acc = sr; } + + // Final offset subtraction (complete -- handles identity result when scalar = 0) + let neg_off = GPoint { x: offset_final.x, y: 0 - offset_final.y }; + h[p] = u_complete_add_hint(acc, neg_off); + + h +} + +// ====== Scalar range check ====== +// Lexicographic comparison: ensures wNAF slices represent an integer < field modulus. +// Without this, a prover could encode scalar + k*p (for k != 0) using valid 4-bit slices, +// since the Horner reconstruction only checks equality mod p. +// Mirrors noir_bigcurve's `compare_scalar_field_to_bignum`. 
+fn assert_slices_less_than_modulus(slices: [u8; 64]) { + let mut found_strictly_less: bool = false; + for i in 0..64 { + if !found_strictly_less { + let s = slices[i]; + let m = MODULUS_SLICES[i]; + // If we find a digit strictly less than modulus digit, scalar < modulus -- done. + if s as u8 < m { + found_strictly_less = true; + } else { + // If strictly greater at any position (without prior strictly-less), scalar >= modulus. + assert(s == m, "wNAF scalar exceeds field modulus"); + } + } + } + // If all digits equal, scalar == modulus, which is also invalid (must be strictly less). + assert(found_strictly_less, "wNAF scalar equals field modulus"); +} + +// ====== Main scalar multiplication ====== + +fn scalar_mul_wnaf(P: GPoint, scalar_lo: Field, scalar_hi: Field) -> GPointResult { + // 1. Decompose scalar into wNAF slices + // Safety: slices and skew are fully constrained below (range, reconstruction, and modulus bound) + let (slices, skew) = unsafe { decompose_wnaf(scalar_lo, scalar_hi) }; + + // Range check: each slice fits in 4 bits + for i in 0..64 { (slices[i] as Field).assert_max_bit_size::<4>(); } + + // Soundness fix #1: scalar range check -- slices represent integer < field modulus + assert_slices_less_than_modulus(slices); + + // Reconstruction check: wNAF Horner evaluation == scalar_lo + scalar_hi * 2^128 + let mut r: Field = 0; + for i in 0..64 { r = r * 16; r += (slices[i] as Field) * 2 - 15; } + r -= skew as Field; + + let lo_bits: [u1; 128] = scalar_lo.to_le_bits(); + let hi_bits: [u1; 128] = scalar_hi.to_le_bits(); + let mut expected: Field = 0; + let mut pow: Field = 1; + for i in 0..128 { expected += (lo_bits[i] as Field) * pow; pow *= 2; } + let two_128: Field = pow; + let mut hi_val: Field = 0; + pow = 1; + for i in 0..128 { hi_val += (hi_bits[i] as Field) * pow; pow *= 2; } + expected += hi_val * two_128; + assert(r == expected); + + // 2. 
Offset generators -- hardcoded compile-time constants (soundness fix #2) + let offset = GPoint { x: OFFSET_X, y: OFFSET_Y }; + let offset_final = GPoint { x: OFFSET_FINAL_X, y: OFFSET_FINAL_Y }; + + // 3. Transcript of EC operation hints + // Safety: every hint is verified by a constrained c_double or c_add call below + let hints = unsafe { compute_transcript(P, slices, skew, offset, offset_final) }; + let mut hp: u32 = 0; + + // 4. Build 16-entry lookup table: T[8]=P, T[9]=3P, ..., T[15]=15P, T[7]=-P, ..., T[0]=-15P + let d2 = c_double(P, hints[hp]); hp += 1; + let mut tx: [Field; 16] = [0; 16]; + let mut ty: [Field; 16] = [0; 16]; + tx[8] = P.x; ty[8] = P.y; + tx[7] = P.x; ty[7] = 0 - P.y; + let mut A = P; + for i in 1..8 { + A = c_add(A, d2, hints[hp]); hp += 1; + tx[8 + i] = A.x; ty[8 + i] = A.y; + tx[7 - i] = A.x; ty[7 - i] = 0 - A.y; + } + + // 5. Init accumulator: offset + T[slices[0]] + let first = GPoint { x: tx[slices[0] as u32], y: ty[slices[0] as u32] }; + let mut acc = c_add(offset, first, hints[hp]); hp += 1; + + // 6. Main wNAF loop: 63 windows * (4 doublings + 1 table add) + for _w in 1..64 { + acc = c_double(acc, hints[hp]); hp += 1; + acc = c_double(acc, hints[hp]); hp += 1; + acc = c_double(acc, hints[hp]); hp += 1; + acc = c_double(acc, hints[hp]); hp += 1; + let tp = GPoint { x: tx[slices[_w] as u32], y: ty[slices[_w] as u32] }; + acc = c_add(acc, tp, hints[hp]); hp += 1; + } + + // 7. Skew correction: if scalar was even, subtract P + let neg_P = GPoint { x: P.x, y: 0 - P.y }; + let skew_r = c_add(acc, neg_P, hints[hp]); hp += 1; + acc = if skew { skew_r } else { acc }; + + // 8. 
Subtract accumulated offset: result = acc - 2^252 * offset + // Uses complete addition to handle the identity result (scalar = 0 mod group_order) + let neg_off = GPoint { x: offset_final.x, y: 0 - offset_final.y }; + c_complete_add(acc, neg_off, hints[hp]) +} + +/// 2-point MSM on Grumpkin: s1*G + s2*G +fn main( + scalar1_lo: pub Field, scalar1_hi: pub Field, + scalar2_lo: pub Field, scalar2_hi: pub Field, +) -> pub (Field, Field, bool) { + let g = GPoint { x: 1, y: GRUMPKIN_GEN_Y }; + + let r1 = scalar_mul_wnaf(g, scalar1_lo, scalar1_hi); + let r2 = scalar_mul_wnaf(g, scalar2_lo, scalar2_hi); + + // Full addition: handles r1 == r2 (doubling), r1 == -r2 (identity), and identity inputs + let add_hint = unsafe { + u_full_add_hint( + GPoint { x: r1.x, y: r1.y }, r1.is_infinity, + GPoint { x: r2.x, y: r2.y }, r2.is_infinity, + ) + }; + let result = c_full_add( + GPoint { x: r1.x, y: r1.y }, r1.is_infinity, + GPoint { x: r2.x, y: r2.y }, r2.is_infinity, + add_hint, + ); + + // Verify result is on Grumpkin (skip for identity) + let on_curve = result.y * result.y - (result.x * result.x * result.x - 17); + assert((!result.is_infinity as Field) * on_curve == 0); + + (result.x, result.y, result.is_infinity) +} diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index f7cf80db2..4c9950064 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -17,10 +17,12 @@ pub use { binops::{BINOP_ATOMIC_BITS, BINOP_BITS, NUM_DIGITS}, digits::{decompose_into_digits, DigitalDecompositionWitnesses}, ram::{SpiceMemoryOperation, SpiceWitnesses}, - scheduling::{Layer, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders}, + scheduling::{ + Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders, + }, witness_builder::{ - CombinedTableEntryInverseData, ConstantTerm, ProductLinearTerm, SumTerm, WitnessBuilder, - WitnessCoefficient, + CombinedTableEntryInverseData, ConstantTerm, 
NonNativeEcOp, ProductLinearTerm, SumTerm, + WitnessBuilder, WitnessCoefficient, }, witness_generator::NoirWitnessGenerator, }; diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index a5cbaefd6..0794744ae 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -1,7 +1,7 @@ use { crate::witness::{ - ConstantOrR1CSWitness, ConstantTerm, ProductLinearTerm, SumTerm, WitnessBuilder, - WitnessCoefficient, + ConstantOrR1CSWitness, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm, + WitnessBuilder, WitnessCoefficient, }, std::collections::HashMap, }; @@ -78,7 +78,10 @@ impl DependencyInfo { WitnessBuilder::Sum(_, ops) => ops.iter().map(|SumTerm(_, idx)| *idx).collect(), WitnessBuilder::Product(_, a, b) => vec![*a, *b], WitnessBuilder::MultiplicitiesForRange(_, _, values) => values.clone(), - WitnessBuilder::Inverse(_, x) => vec![*x], + WitnessBuilder::Inverse(_, x) + | WitnessBuilder::SafeInverse(_, x) + | WitnessBuilder::ModularInverse(_, x, _) + | WitnessBuilder::IntegerQuotient(_, x, _) => vec![*x], WitnessBuilder::IndexedLogUpDenominator( _, sz, @@ -152,6 +155,28 @@ impl DependencyInfo { } v } + WitnessBuilder::MultiLimbMulModHint { + a_limbs, b_limbs, .. + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbModularInverse { a_limbs, .. } => a_limbs.clone(), + WitnessBuilder::MultiLimbAddQuotient { + a_limbs, b_limbs, .. + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbSubBorrow { + a_limbs, b_limbs, .. + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } WitnessBuilder::BytePartition { x, .. } => vec![*x], WitnessBuilder::U32AdditionMulti(_, _, inputs) => inputs @@ -198,6 +223,32 @@ impl DependencyInfo { data.rs_cubed, ] } + WitnessBuilder::EcDoubleHint { px, py, .. 
} => vec![*px, *py], + WitnessBuilder::EcAddHint { x1, y1, x2, y2, .. } => vec![*x1, *y1, *x2, *y2], + WitnessBuilder::NonNativeEcHint { inputs, .. } => { + inputs.iter().flatten().copied().collect() + } + WitnessBuilder::FakeGLVHint { s_lo, s_hi, .. } => vec![*s_lo, *s_hi], + WitnessBuilder::EcScalarMulHint { + px_limbs, + py_limbs, + s_lo, + s_hi, + .. + } => px_limbs + .iter() + .chain(py_limbs.iter()) + .copied() + .chain([*s_lo, *s_hi]) + .collect(), + WitnessBuilder::SelectWitness { + flag, + on_false, + on_true, + .. + } => vec![*flag, *on_false, *on_true], + WitnessBuilder::BooleanOr { a, b, .. } => vec![*a, *b], + WitnessBuilder::SignedBitHint { scalar, .. } => vec![*scalar], WitnessBuilder::ChunkDecompose { packed, .. } => vec![*packed], WitnessBuilder::SpreadWitness(_, input) => vec![*input], WitnessBuilder::SpreadBitExtract { sum_terms, .. } => { @@ -240,6 +291,9 @@ impl DependencyInfo { | WitnessBuilder::Challenge(idx) | WitnessBuilder::IndexedLogUpDenominator(idx, ..) | WitnessBuilder::Inverse(idx, _) + | WitnessBuilder::SafeInverse(idx, _) + | WitnessBuilder::ModularInverse(idx, ..) + | WitnessBuilder::IntegerQuotient(idx, ..) | WitnessBuilder::ProductLinearOperation(idx, ..) | WitnessBuilder::LogUpDenominator(idx, ..) | WitnessBuilder::LogUpInverse(idx, ..) @@ -254,6 +308,13 @@ impl DependencyInfo { | WitnessBuilder::SpreadWitness(idx, ..) | WitnessBuilder::SpreadLookupDenominator(idx, ..) | WitnessBuilder::SpreadTableQuotient { idx, .. } => vec![*idx], + WitnessBuilder::SelectWitness { output, .. } + | WitnessBuilder::BooleanOr { output, .. } => vec![*output], + WitnessBuilder::SignedBitHint { + output_start, + num_bits, + .. 
+ } => (*output_start..*output_start + *num_bits + 1).collect(), WitnessBuilder::MultiplicitiesForRange(start, range, _) => { (*start..*start + *range).collect() @@ -282,6 +343,47 @@ impl DependencyInfo { let n = 1usize << *num_bits; (*start..*start + n).collect() } + WitnessBuilder::MultiLimbMulModHint { + output_start, + num_limbs, + .. + } => { + let count = (4 * *num_limbs - 2) as usize; + (*output_start..*output_start + count).collect() + } + WitnessBuilder::MultiLimbModularInverse { + output_start, + num_limbs, + .. + } => (*output_start..*output_start + *num_limbs as usize).collect(), + WitnessBuilder::EcDoubleHint { output_start, .. } => { + (*output_start..*output_start + 3).collect() + } + WitnessBuilder::EcAddHint { output_start, .. } => { + (*output_start..*output_start + 3).collect() + } + WitnessBuilder::NonNativeEcHint { + output_start, + num_limbs, + op, + .. + } => { + let count = match op { + NonNativeEcOp::Double | NonNativeEcOp::Add => (12 * *num_limbs - 6) as usize, + NonNativeEcOp::OnCurve => (7 * *num_limbs - 4) as usize, + }; + (*output_start..*output_start + count).collect() + } + WitnessBuilder::FakeGLVHint { output_start, .. } => { + (*output_start..*output_start + 4).collect() + } + WitnessBuilder::EcScalarMulHint { + output_start, + num_limbs, + .. + } => (*output_start..*output_start + 2 * *num_limbs as usize).collect(), + WitnessBuilder::MultiLimbAddQuotient { output, .. } => vec![*output], + WitnessBuilder::MultiLimbSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) 
=> { vec![*result_idx, *carry_idx] } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 9503847a3..dd63190c2 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -115,6 +115,15 @@ impl WitnessIndexRemapper { WitnessBuilder::Inverse(idx, operand) => { WitnessBuilder::Inverse(self.remap(*idx), self.remap(*operand)) } + WitnessBuilder::SafeInverse(idx, operand) => { + WitnessBuilder::SafeInverse(self.remap(*idx), self.remap(*operand)) + } + WitnessBuilder::ModularInverse(idx, operand, modulus) => { + WitnessBuilder::ModularInverse(self.remap(*idx), self.remap(*operand), *modulus) + } + WitnessBuilder::IntegerQuotient(idx, dividend, divisor) => { + WitnessBuilder::IntegerQuotient(self.remap(*idx), self.remap(*dividend), *divisor) + } WitnessBuilder::ProductLinearOperation( idx, ProductLinearTerm(x, a, b), @@ -215,6 +224,64 @@ impl WitnessIndexRemapper { .collect(), ) } + WitnessBuilder::MultiLimbMulModHint { + output_start, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbMulModHint { + output_start: self.remap(*output_start), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::MultiLimbModularInverse { + output_start, + a_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbModularInverse { + output_start: self.remap(*output_start), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::MultiLimbAddQuotient { + output, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbAddQuotient { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| 
self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::MultiLimbSubBorrow { + output, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbSubBorrow { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, WitnessBuilder::BytePartition { lo, hi, x, k } => WitnessBuilder::BytePartition { lo: self.remap(*lo), hi: self.remap(*hi), @@ -299,6 +366,113 @@ impl WitnessIndexRemapper { }, ) } + WitnessBuilder::EcDoubleHint { + output_start, + px, + py, + curve_a, + field_modulus_p, + } => WitnessBuilder::EcDoubleHint { + output_start: self.remap(*output_start), + px: self.remap(*px), + py: self.remap(*py), + curve_a: *curve_a, + field_modulus_p: *field_modulus_p, + }, + WitnessBuilder::EcAddHint { + output_start, + x1, + y1, + x2, + y2, + field_modulus_p, + } => WitnessBuilder::EcAddHint { + output_start: self.remap(*output_start), + x1: self.remap(*x1), + y1: self.remap(*y1), + x2: self.remap(*x2), + y2: self.remap(*y2), + field_modulus_p: *field_modulus_p, + }, + WitnessBuilder::NonNativeEcHint { + output_start, + op, + inputs, + curve_a, + curve_b, + field_modulus_p, + limb_bits, + num_limbs, + } => WitnessBuilder::NonNativeEcHint { + output_start: self.remap(*output_start), + op: op.clone(), + inputs: inputs + .iter() + .map(|v| v.iter().map(|&w| self.remap(w)).collect()) + .collect(), + curve_a: *curve_a, + curve_b: *curve_b, + field_modulus_p: *field_modulus_p, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => WitnessBuilder::FakeGLVHint { + output_start: self.remap(*output_start), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + 
curve_order: *curve_order, + }, + WitnessBuilder::EcScalarMulHint { + output_start, + px_limbs, + py_limbs, + s_lo, + s_hi, + curve_a, + field_modulus_p, + num_limbs, + limb_bits, + } => WitnessBuilder::EcScalarMulHint { + output_start: self.remap(*output_start), + px_limbs: px_limbs.iter().map(|&w| self.remap(w)).collect(), + py_limbs: py_limbs.iter().map(|&w| self.remap(w)).collect(), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + curve_a: *curve_a, + field_modulus_p: *field_modulus_p, + num_limbs: *num_limbs, + limb_bits: *limb_bits, + }, + WitnessBuilder::SelectWitness { + output, + flag, + on_false, + on_true, + } => WitnessBuilder::SelectWitness { + output: self.remap(*output), + flag: self.remap(*flag), + on_false: self.remap(*on_false), + on_true: self.remap(*on_true), + }, + WitnessBuilder::BooleanOr { output, a, b } => WitnessBuilder::BooleanOr { + output: self.remap(*output), + a: self.remap(*a), + b: self.remap(*b), + }, + WitnessBuilder::SignedBitHint { + output_start, + scalar, + num_bits, + } => WitnessBuilder::SignedBitHint { + output_start: self.remap(*output_start), + scalar: self.remap(*scalar), + num_bits: *num_bits, + }, WitnessBuilder::ChunkDecompose { output_start, packed, diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 0628fc2e3..8bcbc3d09 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -54,9 +54,21 @@ pub struct CombinedTableEntryInverseData { pub xor_out: FieldElement, } +/// Operation type for the unified non-native EC hint. 
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum NonNativeEcOp { + /// Point doubling: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 12N-6 + Double, + /// Point addition: inputs = \[\[x1_limbs\], \[y1_limbs\], \[x2_limbs\], + /// \[y2_limbs\]\], outputs 12N-6 + Add, + /// On-curve check: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 7N-4 + OnCurve, +} + /// Indicates how to solve for a collection of R1CS witnesses in terms of /// earlier (i.e. already solved for) R1CS witnesses and/or ACIR witness values. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum WitnessBuilder { /// Constant value, used for the constant one witness & e.g. static lookups /// (witness index, constant value) @@ -88,6 +100,23 @@ pub enum WitnessBuilder { /// The inverse of the value at a specified witness index /// (witness index, operand witness index) Inverse(usize, usize), + /// Safe inverse: like Inverse but handles zero by outputting 0. + /// Used by compute_is_zero where the input may be zero. Solved in the + /// Other layer (not batch-inverted), so zero inputs don't poison the batch. + /// (witness index, operand witness index) + SafeInverse(usize, usize), + /// The modular inverse of the value at a specified witness index, modulo + /// a given prime modulus. Computes a^{-1} mod m using Fermat's little + /// theorem (a^{m-2} mod m). Unlike Inverse (BN254 field inverse), this + /// operates as integer modular arithmetic. + /// (witness index, operand witness index, modulus) + ModularInverse(usize, usize, #[serde(with = "serde_ark")] FieldElement), + /// The integer quotient floor(dividend / divisor). Used by reduce_mod to + /// compute k = floor(v / m) so that v = k*m + result with 0 <= result < m. + /// Unlike field multiplication by the inverse, this performs true integer + /// division on the BigInteger representation. 
+ /// (witness index, dividend witness index, divisor constant) + IntegerQuotient(usize, usize, #[serde(with = "serde_ark")] FieldElement), /// Products with linear operations on the witness indices. /// Fields are ProductLinearOperation(witness_idx, (index, a, b), (index, c, /// d)) such that we wish to compute (ax + b) * (cx + d). @@ -189,6 +218,61 @@ pub enum WitnessBuilder { /// Inverse of combined lookup table entry denominator (constant operands). /// Computes: 1 / (sz - lhs - rs*rhs - rs²*and_out - rs³*xor_out) CombinedTableEntryInverse(CombinedTableEntryInverseData), + /// Prover hint for multi-limb modular multiplication: (a * b) mod p. + /// Given inputs a and b as N-limb vectors (each limb `limb_bits` wide), + /// and a constant 256-bit modulus p, computes quotient q, remainder r, + /// and carry witnesses for schoolbook column verification. + /// + /// Outputs (4*num_limbs - 2) witnesses starting at output_start: + /// [0..N) q limbs (quotient) + /// [N..2N) r limbs (remainder) — OUTPUT + /// [2N..4N-2) carry witnesses (unsigned-offset) + MultiLimbMulModHint { + output_start: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Prover hint for multi-limb modular inverse: a^{-1} mod p. + /// Given input a as N-limb vector and constant modulus p, + /// computes the inverse via Fermat's little theorem (a^{p-2} mod p). + /// + /// Outputs num_limbs witnesses at output_start: inv limbs. + MultiLimbModularInverse { + output_start: usize, + a_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Prover hint for multi-limb addition quotient: q = floor((a + b) / p). + /// Given inputs a and b as N-limb vectors, and a constant modulus p, + /// computes q ∈ {0, 1}. + /// + /// Outputs 1 witness at output: q. 
+ MultiLimbAddQuotient { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Prover hint for multi-limb subtraction borrow: q = (a < b) ? 1 : 0. + /// Given inputs a and b as N-limb vectors, and a constant modulus p, + /// computes q ∈ {0, 1} indicating whether a borrow (adding p) is needed. + /// + /// Outputs 1 witness at output: q. + MultiLimbSubBorrow { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, /// Decomposes a packed value into chunks of specified bit-widths. /// Given packed value and chunk_bits = [b0, b1, ..., bn]: /// packed = c0 + c1 * 2^b0 + c2 * 2^(b0+b1) + ... @@ -198,6 +282,118 @@ pub enum WitnessBuilder { packed: usize, chunk_bits: Vec, }, + /// Prover hint for FakeGLV scalar decomposition. + /// Given scalar s (from s_lo + s_hi * 2^128) and curve order n, + /// computes half_gcd(s, n) → (|s1|, |s2|, neg1, neg2) such that: + /// (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n) + /// + /// Outputs 4 witnesses starting at output_start: + /// \[0\] |s1| (128-bit field element) + /// \[1\] |s2| (128-bit field element) + /// \[2\] neg1 (boolean: 0 or 1) + /// \[3\] neg2 (boolean: 0 or 1) + FakeGLVHint { + output_start: usize, + s_lo: usize, + s_hi: usize, + curve_order: [u64; 4], + }, + /// Prover hint for EC scalar multiplication: computes R = \[s\]P. + /// Given point P = (px, py) and scalar s = s_lo + s_hi * 2^128, + /// computes R = \[s\]P on the curve with parameter `curve_a` and + /// field modulus `field_modulus_p`. + /// + /// When `num_limbs == 1`: inputs are single witnesses, outputs 2 + /// witnesses (R_x, R_y) as native field elements. + /// When `num_limbs >= 2`: inputs are limb witnesses, outputs + /// `2 * num_limbs` witnesses (R_x limbs then R_y limbs). 
+ EcScalarMulHint { + output_start: usize, + px_limbs: Vec, + py_limbs: Vec, + s_lo: usize, + s_hi: usize, + curve_a: [u64; 4], + field_modulus_p: [u64; 4], + num_limbs: u32, + limb_bits: u32, + }, + /// Prover hint for EC point doubling on native field. + /// Given P = (px, py) and curve parameter `a`, computes: + /// lambda = (3*px^2 + a) / (2*py) mod p + /// x3 = lambda^2 - 2*px mod p + /// y3 = lambda * (px - x3) - py mod p + /// + /// Outputs 3 witnesses at output_start: lambda, x3, y3. + EcDoubleHint { + output_start: usize, + px: usize, + py: usize, + curve_a: [u64; 4], + field_modulus_p: [u64; 4], + }, + /// Prover hint for EC point addition on native field. + /// Given P1 = (x1, y1) and P2 = (x2, y2), computes: + /// lambda = (y2 - y1) / (x2 - x1) mod p + /// x3 = lambda^2 - x1 - x2 mod p + /// y3 = lambda * (x1 - x3) - y1 mod p + /// + /// Outputs 3 witnesses at output_start: lambda, x3, y3. + EcAddHint { + output_start: usize, + x1: usize, + y1: usize, + x2: usize, + y2: usize, + field_modulus_p: [u64; 4], + }, + /// Conditional select: output = on_false + flag * (on_true - on_false). + /// When flag=0, output=on_false; when flag=1, output=on_true. + /// (output, flag, on_false, on_true) + SelectWitness { + output: usize, + flag: usize, + on_false: usize, + on_true: usize, + }, + /// Boolean OR: output = a + b - a*b = 1 - (1-a)*(1-b). + /// (output, a, b) + BooleanOr { + output: usize, + a: usize, + b: usize, + }, + /// Unified prover hint for non-native EC operations (multi-limb). 
+ /// + /// `op` selects the operation: + /// - `Double`: inputs = \[\[px\], \[py\]\], outputs 15N-6 witnesses + /// - `Add`: inputs = \[\[x1\], \[y1\], \[x2\], \[y2\]\], outputs 15N-6 + /// witnesses + /// - `OnCurve`: inputs = \[\[px\], \[py\]\], outputs 9N-4 witnesses + NonNativeEcHint { + output_start: usize, + op: NonNativeEcOp, + inputs: Vec>, + curve_a: [u64; 4], + curve_b: [u64; 4], + field_modulus_p: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, + /// Signed-bit decomposition hint for wNAF scalar multiplication. + /// Given scalar s with num_bits bits, computes sign-bits b_0..b_{n-1} + /// and skew ∈ {0,1} such that: + /// s + skew + (2^n - 1) = Σ b_i * 2^{i+1} + /// where d_i = 2*b_i - 1 ∈ {-1, +1}. + /// + /// Outputs (num_bits + 1) witnesses at output_start: + /// [0..num_bits) b_i sign bits + /// \[num_bits\] skew (0 if s is odd, 1 if s is even) + SignedBitHint { + output_start: usize, + scalar: usize, + num_bits: usize, + }, /// Computes spread(input): interleave bits with zeros. /// Output: 0 b_{n-1} 0 b_{n-2} ... 0 b_1 0 b_0 /// (witness index of output, witness index of input) @@ -260,6 +456,17 @@ impl WitnessBuilder { WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, + WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => (4 * *num_limbs - 2) as usize, + WitnessBuilder::MultiLimbModularInverse { num_limbs, .. } => *num_limbs as usize, + WitnessBuilder::SignedBitHint { num_bits, .. } => *num_bits + 1, + WitnessBuilder::EcDoubleHint { .. } => 3, + WitnessBuilder::EcAddHint { .. } => 3, + WitnessBuilder::NonNativeEcHint { op, num_limbs, .. } => match op { + NonNativeEcOp::Double | NonNativeEcOp::Add => (15 * *num_limbs - 6) as usize, + NonNativeEcOp::OnCurve => (9 * *num_limbs - 4) as usize, + }, + WitnessBuilder::FakeGLVHint { .. 
} => 4, + WitnessBuilder::EcScalarMulHint { num_limbs, .. } => 2 * *num_limbs as usize, _ => 1, } diff --git a/provekit/prover/Cargo.toml b/provekit/prover/Cargo.toml index a4390623d..0b80faef9 100644 --- a/provekit/prover/Cargo.toml +++ b/provekit/prover/Cargo.toml @@ -33,6 +33,7 @@ whir.workspace = true # 3rd party anyhow.workspace = true +num-bigint.workspace = true postcard.workspace = true rand.workspace = true rayon.workspace = true diff --git a/provekit/prover/src/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs new file mode 100644 index 000000000..85c0d471f --- /dev/null +++ b/provekit/prover/src/bigint_mod.rs @@ -0,0 +1,1343 @@ +/// BigInteger modular arithmetic on [u64; 4] limbs (256-bit). +/// +/// These helpers compute modular inverse via Fermat's little theorem: +/// a^{-1} = a^{m-2} mod m, using schoolbook multiplication and +/// square-and-multiply exponentiation. +use { + ark_ff::PrimeField, + num_bigint::{BigInt, Sign}, + provekit_common::FieldElement, +}; + +/// Schoolbook multiplication: 4×4 limbs → 8 limbs (256-bit × 256-bit → +/// 512-bit). +pub fn widening_mul(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let product = (a[i] as u128) * (b[j] as u128) + (result[i + j] as u128) + carry; + result[i + j] = product as u64; + carry = product >> 64; + } + result[i + 4] = carry as u64; + } + result +} + +/// Compare 8-limb value with 4-limb value (zero-extended to 8 limbs). +/// Returns Ordering::Greater if wide > narrow, etc. 
#[cfg(test)]
fn cmp_wide_narrow(wide: &[u64; 8], narrow: &[u64; 4]) -> std::cmp::Ordering {
    // Any set bit in the upper four limbs makes `wide` strictly larger than
    // the zero-extended `narrow`.
    if wide[4..].iter().any(|&limb| limb != 0) {
        return std::cmp::Ordering::Greater;
    }
    // Otherwise compare the low halves limb-by-limb, most significant first.
    for idx in (0..4).rev() {
        let ord = wide[idx].cmp(&narrow[idx]);
        if ord != std::cmp::Ordering::Equal {
            return ord;
        }
    }
    std::cmp::Ordering::Equal
}

/// Modular reduction of a 512-bit value by a 256-bit modulus.
///
/// Bit-serial long division: dividend bits are shifted into a 256-bit
/// accumulator that is conditionally reduced by the modulus at every step,
/// so the accumulator stays below the modulus throughout.
fn reduce_wide(wide: &[u64; 8], modulus: &[u64; 4]) -> [u64; 4] {
    // Locate the most significant set bit of the dividend.
    let mut top_bit = 0;
    for idx in (0..8).rev() {
        if wide[idx] != 0 {
            top_bit = idx * 64 + (64 - wide[idx].leading_zeros() as usize);
            break;
        }
    }
    if top_bit == 0 {
        return [0u64; 4];
    }

    let mut acc = [0u64; 4];
    for bit_pos in (0..top_bit).rev() {
        // acc = acc * 2 + next dividend bit.
        let overflow = shift_left_one(&mut acc);
        acc[0] |= (wide[bit_pos / 64] >> (bit_pos % 64)) & 1;

        // An overflow bit means the true value is 2^256 + acc, which always
        // exceeds a 256-bit modulus, so the subtraction is mandatory; the
        // borrow out of the top limb cancels that overflow bit.
        if overflow != 0 || cmp_4limb(&acc, modulus) != std::cmp::Ordering::Less {
            let mut borrow = 0u64;
            for idx in 0..4 {
                let (d1, b1) = acc[idx].overflowing_sub(modulus[idx]);
                let (d2, b2) = d1.overflowing_sub(borrow);
                acc[idx] = d2;
                borrow = (b1 as u64) + (b2 as u64);
            }
        }
    }

    acc
}

/// Shift a 4-limb value left by one bit; returns the bit shifted out of the
/// top limb.
fn shift_left_one(a: &mut [u64; 4]) -> u64 {
    let mut carry_in = 0u64;
    for limb in a.iter_mut() {
        let carry_out = *limb >> 63;
        *limb = (*limb << 1) | carry_in;
        carry_in = carry_out;
    }
    carry_in
}

/// Lexicographic comparison of two 4-limb values, most significant limb
/// first.
pub fn cmp_4limb(a: &[u64; 4], b: &[u64; 4]) -> std::cmp::Ordering {
    for idx in (0..4).rev() {
        let ord = a[idx].cmp(&b[idx]);
        if ord != std::cmp::Ordering::Equal {
            return ord;
        }
    }
    std::cmp::Ordering::Equal
}

/// In-place subtraction `a -= b`; the caller guarantees `a >= b`
/// (checked in debug builds).
fn sub_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) {
    let mut borrow = 0u64;
    for idx in 0..4 {
        let (d1, b1) = a[idx].overflowing_sub(b[idx]);
        let (d2, b2) = d1.overflowing_sub(borrow);
        a[idx] = d2;
        borrow = (b1 as u64) + (b2 as u64);
    }
    debug_assert_eq!(borrow, 0, "subtraction underflow: a < b");
}

/// Modular multiplication: (a * b) mod m.
///
/// A full 512-bit schoolbook product followed by long-division reduction.
pub fn mul_mod(a: &[u64; 4], b: &[u64; 4], m: &[u64; 4]) -> [u64; 4] {
    // Schoolbook 4x4-limb multiply into an 8-limb (512-bit) product.
    let mut wide = [0u64; 8];
    for i in 0..4 {
        let mut carry = 0u128;
        for j in 0..4 {
            let t = (a[i] as u128) * (b[j] as u128) + (wide[i + j] as u128) + carry;
            wide[i + j] = t as u64;
            carry = t >> 64;
        }
        wide[i + 4] = carry as u64;
    }
    reduce_wide(&wide, m)
}

/// Modular exponentiation: base^exp mod m, left-to-right
/// square-and-multiply.
pub fn mod_pow(base: &[u64; 4], exp: &[u64; 4], m: &[u64; 4]) -> [u64; 4] {
    // Locate the most significant set bit of the exponent.
    let mut top_bit = 0;
    for idx in (0..4).rev() {
        if exp[idx] != 0 {
            top_bit = idx * 64 + (64 - exp[idx].leading_zeros() as usize);
            break;
        }
    }
    if top_bit == 0 {
        // exp == 0: by convention the result is 1 (valid for m > 1).
        return [1, 0, 0, 0];
    }

    let mut acc = [1u64, 0, 0, 0];
    for bit_pos in (0..top_bit).rev() {
        // Square every step, multiply in the base when the bit is set.
        acc = mul_mod(&acc, &acc, m);
        if (exp[bit_pos / 64] >> (bit_pos % 64)) & 1 == 1 {
            acc = mul_mod(&acc, base, m);
        }
    }

    acc
}

// (continues below) Integer division with remainder:
// dividend = quotient * divisor + remainder, where 0 <= remainder < divisor.
// Uses bit-by-bit long division.
pub fn divmod(dividend: &[u64; 4], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) {
    // Locate the most significant set bit of the dividend.
    let mut top_bit = 0;
    for idx in (0..4).rev() {
        if dividend[idx] != 0 {
            top_bit = idx * 64 + (64 - dividend[idx].leading_zeros() as usize);
            break;
        }
    }
    if top_bit == 0 {
        return ([0u64; 4], [0u64; 4]);
    }

    let mut quotient = [0u64; 4];
    let mut remainder = [0u64; 4];

    for bit_pos in (0..top_bit).rev() {
        // remainder = remainder * 2 + next dividend bit.
        let mut carry = 0u64;
        for limb in remainder.iter_mut() {
            let next = *limb >> 63;
            *limb = (*limb << 1) | carry;
            carry = next;
        }
        debug_assert_eq!(carry, 0, "remainder overflow during shift");

        let limb_idx = bit_pos / 64;
        let bit_idx = bit_pos % 64;
        remainder[0] |= (dividend[limb_idx] >> bit_idx) & 1;

        // Compare remainder against divisor (most significant limb first).
        let is_less = remainder
            .iter()
            .zip(divisor.iter())
            .rev()
            .find_map(|(r, d)| if r != d { Some(r < d) } else { None })
            .unwrap_or(false);

        // remainder >= divisor: subtract and record a quotient bit.
        if !is_less {
            let mut borrow = 0u64;
            for idx in 0..4 {
                let (d1, b1) = remainder[idx].overflowing_sub(divisor[idx]);
                let (d2, b2) = d1.overflowing_sub(borrow);
                remainder[idx] = d2;
                borrow = (b1 as u64) + (b2 as u64);
            }
            debug_assert_eq!(borrow, 0, "subtraction underflow: a < b");
            quotient[limb_idx] |= 1u64 << bit_idx;
        }
    }

    (quotient, remainder)
}

/// Subtract a small u64 value from a 4-limb number. Assumes a >= small.
pub fn sub_u64(a: &[u64; 4], small: u64) -> [u64; 4] {
    let mut out = *a;
    let (low, initial_borrow) = out[0].overflowing_sub(small);
    out[0] = low;
    let mut borrow = initial_borrow;
    let mut idx = 1;
    // Ripple the borrow upward until it is absorbed.
    while borrow && idx < 4 {
        let (d, b) = out[idx].overflowing_sub(1);
        out[idx] = d;
        borrow = b;
        idx += 1;
    }
    out
}

/// Add two 4-limb (256-bit) numbers, returning a 5-limb result whose top
/// limb holds the carry-out.
pub fn add_4limb(a: &[u64; 4], b: &[u64; 4]) -> [u64; 5] {
    let mut out = [0u64; 5];
    let mut carry = 0u64;
    for idx in 0..4 {
        let (s1, c1) = a[idx].overflowing_add(b[idx]);
        let (s2, c2) = s1.overflowing_add(carry);
        out[idx] = s2;
        carry = (c1 as u64) + (c2 as u64);
    }
    out[4] = carry;
    out
}

/// Add two 4-limb numbers in-place: a += b. Returns the carry-out.
pub fn add_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) -> u64 {
    let mut carry = 0u64;
    for idx in 0..4 {
        let (s1, c1) = a[idx].overflowing_add(b[idx]);
        let (s2, c2) = s1.overflowing_add(carry);
        a[idx] = s2;
        carry = (c1 as u64) + (c2 as u64);
    }
    carry
}

/// Subtract b from a in-place, returning true if a >= b (no underflow).
/// If a < b the stored result is the wrapping difference a - b + 2^256 and
/// false is returned.
pub fn sub_4limb_checked(a: &mut [u64; 4], b: &[u64; 4]) -> bool {
    let mut borrow = 0u64;
    for idx in 0..4 {
        let (d1, b1) = a[idx].overflowing_sub(b[idx]);
        let (d2, b2) = d1.overflowing_sub(borrow);
        a[idx] = d2;
        borrow = (b1 as u64) + (b2 as u64);
    }
    borrow == 0
}

/// Returns true if val == 0.
pub fn is_zero(val: &[u64; 4]) -> bool {
    val.iter().all(|&limb| limb == 0)
}

/// Compute the bit mask for a limb of the given width.
pub fn limb_mask(limb_bits: u32) -> u128 {
    // Guard the shift: 1u128 << 128 would overflow.
    match limb_bits {
        0..=127 => (1u128 << limb_bits) - 1,
        _ => u128::MAX,
    }
}

/// Right-shift a 4-limb (256-bit) value by `bits` positions.
pub fn shr_256(val: &[u64; 4], bits: u32) -> [u64; 4] {
    if bits >= 256 {
        return [0; 4];
    }
    let word_shift = (bits / 64) as usize;
    let bit_shift = bits % 64;
    let mut out = [0u64; 4];
    for idx in 0..4 {
        if idx + word_shift < 4 {
            out[idx] = val[idx + word_shift] >> bit_shift;
            // Pull the spill-over bits from the next-higher limb; the
            // bit_shift > 0 guard avoids a shift by 64.
            if bit_shift > 0 && idx + word_shift + 1 < 4 {
                out[idx] |= val[idx + word_shift + 1] << (64 - bit_shift);
            }
        }
    }
    out
}

// (continues below) Decompose a 256-bit value into `num_limbs` limbs of
// `limb_bits` width; returns u128 limb values (each < 2^limb_bits).
+pub fn decompose_to_u128_limbs(val: &[u64; 4], num_limbs: usize, limb_bits: u32) -> Vec { + let mask = limb_mask(limb_bits); + let mut limbs = Vec::with_capacity(num_limbs); + let mut remaining = *val; + for _ in 0..num_limbs { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + limbs.push(lo & mask); + remaining = shr_256(&remaining, limb_bits); + } + limbs +} + +/// Convert u128 limbs to i128 limbs (for carry computation linear terms). +pub fn to_i128_limbs(limbs: &[u128]) -> Vec { + limbs.iter().map(|&v| v as i128).collect() +} + +/// Convert a `[u64; 8]` wide value to a `BigInt`. +fn wide_to_bigint(v: &[u64; 8]) -> BigInt { + let mut bytes = [0u8; 64]; + for (i, &limb) in v.iter().enumerate() { + bytes[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes()); + } + BigInt::from_bytes_le(Sign::Plus, &bytes) +} + +/// Convert a `[u64; 4]` to a `BigInt`. +fn u256_to_bigint(v: &[u64; 4]) -> BigInt { + let mut bytes = [0u8; 32]; + for (i, &limb) in v.iter().enumerate() { + bytes[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes()); + } + BigInt::from_bytes_le(Sign::Plus, &bytes) +} + +/// Convert a non-negative `BigInt` to `u128`. Panics if negative or too large. +fn bigint_to_u128(v: &BigInt) -> u128 { + assert!(v.sign() != Sign::Minus, "bigint_to_u128: negative value"); + let (_, bytes) = v.to_bytes_le(); + assert!(bytes.len() <= 16, "bigint_to_u128: value exceeds 128 bits"); + let mut buf = [0u8; 16]; + buf[..bytes.len()].copy_from_slice(&bytes); + u128::from_le_bytes(buf) +} + +/// Convert a non-negative `BigInt` to `[u64; 4]`. Panics if it doesn't fit. 
+fn bigint_to_u256(v: &BigInt) -> [u64; 4] { + assert!(v.sign() != Sign::Minus, "bigint_to_u256: negative value"); + let (_, bytes) = v.to_bytes_le(); + let mut result = [0u64; 4]; + for (i, chunk) in bytes.chunks(8).enumerate() { + if i >= 4 { + assert!( + chunk.iter().all(|&b| b == 0), + "bigint_to_u256: value exceeds 256 bits" + ); + break; + } + let mut buf = [0u8; 8]; + buf[..chunk.len()].copy_from_slice(chunk); + result[i] = u64::from_le_bytes(buf); + } + result +} + +/// Compute signed quotient q such that: +/// Σ lhs_products\[i\] * coeff_i + Σ lhs_linear\[j\] * coeff_j +/// - Σ rhs_products\[i\] * coeff_i - Σ rhs_linear\[j\] * coeff_j ≡ 0 (mod p) +/// +/// Returns (|q| limbs, is_negative) where q = (LHS - RHS) / p. +pub fn signed_quotient_wide( + lhs_products: &[(&[u64; 4], &[u64; 4], u64)], + rhs_products: &[(&[u64; 4], &[u64; 4], u64)], + lhs_linear: &[(&[u64; 4], u64)], + rhs_linear: &[(&[u64; 4], u64)], + p: &[u64; 4], + n: usize, + w: u32, +) -> (Vec, bool) { + fn accumulate_wide_products(terms: &[(&[u64; 4], &[u64; 4], u64)]) -> BigInt { + let mut acc = BigInt::from(0); + for &(a, b, coeff) in terms { + let prod = widening_mul(a, b); + acc += wide_to_bigint(&prod) * BigInt::from(coeff); + } + acc + } + + fn accumulate_wide_linear(terms: &[(&[u64; 4], u64)]) -> BigInt { + let mut acc = BigInt::from(0); + for &(val, coeff) in terms { + acc += u256_to_bigint(val) * BigInt::from(coeff); + } + acc + } + + let lhs = accumulate_wide_products(lhs_products) + accumulate_wide_linear(lhs_linear); + let rhs = accumulate_wide_products(rhs_products) + accumulate_wide_linear(rhs_linear); + + let diff = lhs - rhs; + let p_big = u256_to_bigint(p); + + let q_big = &diff / &p_big; + let rem = &diff - &q_big * &p_big; + debug_assert_eq!( + rem, + BigInt::from(0), + "signed_quotient_wide: non-zero remainder" + ); + + let is_neg = q_big.sign() == Sign::Minus; + let q_abs_big = if is_neg { -&q_big } else { q_big }; + + // Decompose directly from BigInt into u128 limbs at 
`w` bits each, + // since the quotient may exceed 256 bits. + let limb_mask = (BigInt::from(1u64) << w) - 1; + let mut limbs = Vec::with_capacity(n); + let mut remaining = q_abs_big; + for _ in 0..n { + let limb_val = &remaining & &limb_mask; + limbs.push(bigint_to_u128(&limb_val)); + remaining >>= w; + } + debug_assert_eq!( + remaining, + BigInt::from(0), + "quotient doesn't fit in {n} limbs at {w} bits" + ); + + (limbs, is_neg) +} + +/// Reconstruct a 256-bit value from u128 limb values packed at `limb_bits` +/// boundaries. +pub fn reconstruct_from_u128_limbs(limb_values: &[u128], limb_bits: u32) -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_u128 in limb_values.iter() { + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + if word_start + 2 < 4 && bit_within > 0 { + let upper = limb_u128 >> (128 - bit_within); + if upper > 0 { + val[word_start + 2] |= upper as u64; + } + } + } + bit_offset += limb_bits; + } + val +} + +/// Compute schoolbook carries for a*b = p*q + r verification in base +/// 2^limb_bits. Returns unsigned-offset carries ready to be written as +/// witnesses. 
+pub fn compute_mul_mod_carries( + a_limbs: &[u128], + b_limbs: &[u128], + p_limbs: &[u128], + q_limbs: &[u128], + r_limbs: &[u128], + limb_bits: u32, +) -> Vec { + let n = a_limbs.len(); + let w = limb_bits; + let num_carries = 2 * n - 2; + let carry_offset = BigInt::from(1u64) << (w + ((n as f64).log2().ceil() as u32) + 1); + let mut carries = Vec::with_capacity(num_carries); + let mut carry = BigInt::from(0); + + for k in 0..(2 * n - 1) { + let mut col_value = BigInt::from(0); + + // a*b products + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + col_value += BigInt::from(a_limbs[i]) * BigInt::from(b_limbs[j as usize]); + } + } + + // Subtract p*q + r + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + col_value -= BigInt::from(p_limbs[i]) * BigInt::from(q_limbs[j as usize]); + } + } + if k < n { + col_value -= BigInt::from(r_limbs[k]); + } + + col_value += &carry; + + if k < 2 * n - 2 { + let mask = (BigInt::from(1u64) << w) - 1; + debug_assert_eq!( + &col_value & &mask, + BigInt::from(0), + "non-zero remainder at column {k}" + ); + carry = &col_value >> w; + let stored = &carry + &carry_offset; + carries.push(bigint_to_u128(&stored)); + } + } + + carries +} + +/// Compute the number of bits needed for the half-GCD sub-scalars. +/// Returns `ceil(order_bits / 2)` where `order_bits` is the bit length of `n`. +pub fn half_gcd_bits(n: &[u64; 4]) -> u32 { + let mut order_bits = 0u32; + for i in (0..4).rev() { + if n[i] != 0 { + order_bits = (i as u32) * 64 + (64 - n[i].leading_zeros()); + break; + } + } + (order_bits + 1) / 2 +} + +/// Build the threshold value `2^half_bits` as a `[u64; 4]`. +fn build_threshold(half_bits: u32) -> [u64; 4] { + assert!(half_bits <= 255, "half_bits must be <= 255"); + let mut threshold = [0u64; 4]; + let word = (half_bits / 64) as usize; + let bit = half_bits % 64; + threshold[word] = 1u64 << bit; + threshold +} + +/// Half-GCD scalar decomposition for FakeGLV. 
+/// +/// Given scalar `s` and curve order `n`, finds `(|s1|, |s2|, neg1, neg2)` such +/// that: `(-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n)` +/// +/// Uses the extended GCD on `(n, s)`, stopping when the remainder drops below +/// `2^half_bits` where `half_bits = ceil(order_bits / 2)`. +/// Returns `(val1, val2, neg1, neg2)` where both fit in `half_bits` bits. +pub fn half_gcd(s: &[u64; 4], n: &[u64; 4]) -> ([u64; 4], [u64; 4], bool, bool) { + // Extended GCD on (n, s): + // We track: r_{i} = r_{i-2} - q_i * r_{i-1} + // t_{i} = t_{i-2} - q_i * t_{i-1} + // Starting: r_0 = n, r_1 = s, t_0 = 0, t_1 = 1 + // + // We want: t_i * s ≡ r_i (mod n) [up to sign] + // More precisely: t_i * s ≡ (-1)^{i+1} * r_i (mod n) + // + // The relation we verify is: sign_r * |r_i| + sign_t * |t_i| * s ≡ 0 (mod n) + + // Threshold: 2^half_bits where half_bits = ceil(order_bits / 2) + let half_bits = half_gcd_bits(n); + let threshold = build_threshold(half_bits); + + // r_prev = n, r_curr = s + let mut r_prev = *n; + let mut r_curr = *s; + + // t_prev = 0, t_curr = 1 + let mut t_prev = [0u64; 4]; + let mut t_curr = [1u64, 0, 0, 0]; + + // Track sign of t: t_prev_neg=false (t_0=0, positive), t_curr_neg=false (t_1=1, + // positive) + let mut t_prev_neg = false; + let mut t_curr_neg = false; + + let mut iteration = 0u32; + + loop { + // Check if r_curr < threshold + if cmp_4limb(&r_curr, &threshold) == std::cmp::Ordering::Less { + break; + } + + if is_zero(&r_curr) { + break; + } + + // q = r_prev / r_curr, new_r = r_prev % r_curr + let (q, new_r) = divmod(&r_prev, &r_curr); + + // new_t = t_prev + q * t_curr (in terms of absolute values and signs) + // Since the GCD recurrence is: t_{i} = t_{i-2} - q_i * t_{i-1} + // In terms of absolute values with sign tracking: + // If t_prev and q*t_curr have the same sign → subtract magnitudes + // If they have different signs → add magnitudes + // new_t = |t_prev| +/- q * |t_curr|, with sign flips each + // iteration. 
+ // + // The standard extended GCD recurrence gives: + // t_i = t_{i-2} - q_i * t_{i-1} + // We track magnitudes and sign bits separately. + + // Compute q * t_curr + let qt = mul_mod_no_reduce(&q, &t_curr); + + // new_t magnitude and sign: + // In the standard recurrence: new_t_val = t_prev_val - q * t_curr_val + // where t_prev_val = (-1)^t_prev_neg * |t_prev|, etc. + // + // But it's simpler to just track: alternating signs. + // In the half-GCD: t values alternate in sign. So: + // new_t = t_prev + q * t_curr (absolute addition since signs alternate) + let mut new_t = qt; + add_4limb_inplace(&mut new_t, &t_prev); + let new_t_neg = !t_curr_neg; + + r_prev = r_curr; + r_curr = new_r; + t_prev = t_curr; + t_prev_neg = t_curr_neg; + t_curr = new_t; + t_curr_neg = new_t_neg; + iteration += 1; + } + + // At this point: r_curr < 2^half_bits and t_curr < ~2^half_bits (half-GCD + // property). + // + // From the extended GCD identity: t_i * s ≡ r_i (mod n) + // Rearranging: -r_i + t_i * s ≡ 0 (mod n) + // + // The circuit checks: (-1)^neg1 * |r_i| + (-1)^neg2 * |t_i| * s ≡ 0 (mod n) + // Since r_i is always non-negative, neg1 must always be true (negate r_i). + // neg2 must match the actual sign of t_i so that (-1)^neg2 * |t_i| = t_i. + + let val1 = r_curr; // |s1| = |r_i| + let val2 = t_curr; // |s2| = |t_i| + + let neg1 = true; // always negate r_i: -r_i + t_i * s ≡ 0 (mod n) + let neg2 = t_curr_neg; + + (val1, val2, neg1, neg2) +} + +/// Multiply two 4-limb values without modular reduction. +/// Returns the lower 4 limbs (ignoring overflow beyond 256 bits). +/// Used internally by half_gcd for q * t_curr where the result is known to fit. 
+fn mul_mod_no_reduce(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let wide = widening_mul(a, b); + [wide[0], wide[1], wide[2], wide[3]] +} + +// --------------------------------------------------------------------------- +// Conversion helpers +// --------------------------------------------------------------------------- + +/// Convert a `[u64; 4]` bigint to a `FieldElement`. +pub fn bigint_to_fe(val: &[u64; 4]) -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt(*val)).unwrap() +} + +/// Read a `FieldElement` witness as a `[u64; 4]` bigint. +pub fn fe_to_bigint(fe: FieldElement) -> [u64; 4] { + fe.into_bigint().0 +} + +/// Reconstruct a 256-bit scalar from two 128-bit halves: `scalar = lo + hi * +/// 2^128`. +pub fn reconstruct_from_halves(lo: &[u64; 4], hi: &[u64; 4]) -> [u64; 4] { + [lo[0], lo[1], hi[0], hi[1]] +} + +// --------------------------------------------------------------------------- +// Modular arithmetic helpers for EC operations (prover-side) +// --------------------------------------------------------------------------- + +/// Modular addition: (a + b) mod p. +pub fn mod_add(a: &[u64; 4], b: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let sum = add_4limb(a, b); + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + if sum[4] > 0 || cmp_4limb(&sum4, p) != std::cmp::Ordering::Less { + // sum >= p, subtract p (borrow absorbs carry bit if sum[4] > 0) + let mut result = sum4; + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = result[i].overflowing_sub(p[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + result[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } + result + } else { + sum4 + } +} + +/// Modular subtraction: (a - b) mod p. 
+pub fn mod_sub(a: &[u64; 4], b: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let mut result = *a; + let no_borrow = sub_4limb_checked(&mut result, b); + if no_borrow { + result + } else { + // a < b, add p to get (a - b + p) + add_4limb_inplace(&mut result, p); + result + } +} + +/// Modular inverse: a^{p-2} mod p (Fermat's little theorem). +pub fn mod_inverse(a: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let exp = sub_u64(p, 2); + mod_pow(a, &exp, p) +} + +/// EC point doubling with lambda exposed: returns (lambda, x3, y3). +/// +/// Used by the `EcDoubleHint` prover which needs lambda as a witness. +pub fn ec_point_double_with_lambda( + px: &[u64; 4], + py: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4], [u64; 4]) { + let x_sq = mul_mod(px, px, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let numerator = mod_add(&three_x_sq, a, p); + let two_y = mod_add(py, py, p); + let denom_inv = mod_inverse(&two_y, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + let lambda_sq = mul_mod(&lambda, &lambda, p); + let two_x = mod_add(px, px, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + let x_minus_x3 = mod_sub(px, &x3, p); + let lambda_dx = mul_mod(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, py, p); + + (lambda, x3, y3) +} + +/// EC point doubling in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = 2*(px, py). +pub fn ec_point_double( + px: &[u64; 4], + py: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + let (_, x3, y3) = ec_point_double_with_lambda(px, py, a, p); + (x3, y3) +} + +/// EC point addition with lambda exposed: returns (lambda, x3, y3). +/// +/// Used by the `EcAddHint` prover which needs lambda as a witness. 
+pub fn ec_point_add_with_lambda( + p1x: &[u64; 4], + p1y: &[u64; 4], + p2x: &[u64; 4], + p2y: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4], [u64; 4]) { + let numerator = mod_sub(p2y, p1y, p); + let denominator = mod_sub(p2x, p1x, p); + let denom_inv = mod_inverse(&denominator, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + let lambda_sq = mul_mod(&lambda, &lambda, p); + let x1_plus_x2 = mod_add(p1x, p2x, p); + let x3 = mod_sub(&lambda_sq, &x1_plus_x2, p); + + let x1_minus_x3 = mod_sub(p1x, &x3, p); + let lambda_dx = mul_mod(&lambda, &x1_minus_x3, p); + let y3 = mod_sub(&lambda_dx, p1y, p); + + (lambda, x3, y3) +} + +/// EC point addition in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = (p1x, p1y) + (p2x, p2y). Requires p1x != p2x. +pub fn ec_point_add( + p1x: &[u64; 4], + p1y: &[u64; 4], + p2x: &[u64; 4], + p2y: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + let (_, x3, y3) = ec_point_add_with_lambda(p1x, p1y, p2x, p2y, p); + (x3, y3) +} + +/// EC scalar multiplication via double-and-add: returns \[scalar\]*P. 
pub fn ec_scalar_mul(
    px: &[u64; 4],
    py: &[u64; 4],
    scalar: &[u64; 4],
    a: &[u64; 4],
    p: &[u64; 4],
) -> ([u64; 4], [u64; 4]) {
    // Find highest set bit in scalar (1-based position; 0 means scalar == 0)
    let mut highest_bit = 0;
    for i in (0..4).rev() {
        if scalar[i] != 0 {
            highest_bit = i * 64 + (64 - scalar[i].leading_zeros() as usize);
            break;
        }
    }
    if highest_bit == 0 {
        // scalar == 0 → point at infinity (not representable in affine)
        panic!("ec_scalar_mul: scalar is zero");
    }

    // Start from the MSB-1 and double-and-add: the accumulator is seeded with
    // P itself (which consumes the most significant set bit), so the loop
    // processes the remaining highest_bit-1 lower bits.
    let mut rx = *px;
    let mut ry = *py;

    for bit_pos in (0..highest_bit - 1).rev() {
        // Double
        let (dx, dy) = ec_point_double(&rx, &ry, a, p);
        rx = dx;
        ry = dy;

        // Add if bit is set
        let limb_idx = bit_pos / 64;
        let bit_idx = bit_pos % 64;
        if (scalar[limb_idx] >> bit_idx) & 1 == 1 {
            let (ax, ay) = ec_point_add(&rx, &ry, px, py, p);
            rx = ax;
            ry = ay;
        }
    }

    (rx, ry)
}

/// Integer division of a 512-bit dividend by a 256-bit divisor.
/// Returns (quotient, remainder) where both fit in 256 bits.
/// Panics if the quotient would exceed 256 bits.
pub fn divmod_wide(dividend: &[u64; 8], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) {
    // Locate the highest set bit of the dividend; zero dividend short-circuits.
    let mut highest_bit = 0;
    for i in (0..8).rev() {
        if dividend[i] != 0 {
            highest_bit = i * 64 + (64 - dividend[i].leading_zeros() as usize);
            break;
        }
    }
    if highest_bit == 0 {
        return ([0u64; 4], [0u64; 4]);
    }

    let mut quotient = [0u64; 4];
    let mut remainder = [0u64; 4];

    // Classic bit-at-a-time restoring long division, MSB first:
    // shift the running remainder left, bring in the next dividend bit,
    // and subtract the divisor whenever the remainder reaches it.
    for bit_pos in (0..highest_bit).rev() {
        let shift_carry = shift_left_one(&mut remainder);

        let limb_idx = bit_pos / 64;
        let bit_idx = bit_pos % 64;
        remainder[0] |= (dividend[limb_idx] >> bit_idx) & 1;

        // If shift_carry is set, the effective remainder is 2^256 + remainder,
        // which is always > any 256-bit divisor, so we must subtract.
        if shift_carry != 0 || cmp_4limb(&remainder, divisor) != std::cmp::Ordering::Less {
            // Subtract divisor with inline borrow tracking (handles the case
            // where remainder < divisor but shift_carry provides the extra bit).
            let mut borrow = 0u64;
            for i in 0..4 {
                let (d1, b1) = remainder[i].overflowing_sub(divisor[i]);
                let (d2, b2) = d1.overflowing_sub(borrow);
                remainder[i] = d2;
                borrow = (b1 as u64) + (b2 as u64);
            }
            // When shift_carry was set, the borrow absorbs it (they cancel out).
            debug_assert_eq!(
                borrow, shift_carry,
                "unexpected borrow in divmod_wide at bit_pos {}",
                bit_pos
            );

            assert!(bit_pos < 256, "quotient exceeds 256 bits");
            quotient[bit_pos / 64] |= 1u64 << (bit_pos % 64);
        }
    }

    (quotient, remainder)
}

/// Compute unsigned-offset carries for a general merged column equation.
///
/// Each `product_set` entry is (a_limbs, b_limbs, coefficient):
///   LHS_terms = Σ coeff * Σ_{i+j=k} a\[i\]*b\[j\]
///
/// Each `linear_set` entry is (limb_values, coefficient) for non-product terms:
///   LHS_terms += Σ coeff * val\[k\] (for k < val.len())
///
/// The equation verified is:
///   LHS + Σ p\[i\]*q_neg\[j\] = RHS + Σ p\[i\]*q_pos\[j\] + carry_chain
///
/// `q_pos_limbs` and `q_neg_limbs` are both non-negative; at most one is
/// non-zero.
// NOTE(review): the generic arguments on `linear_terms` and the return type
// were missing (`Vec` with no type parameter does not compile). They are
// restored from the call sites: linear terms are produced by
// `to_i128_limbs(..)` (so `Vec<i128>`), and the carries collected come from
// `bigint_to_u128(..)` (so `Vec<u128>`).
pub fn compute_ec_verification_carries(
    product_sets: &[(&[u128], &[u128], i64)],
    linear_terms: &[(Vec<i128>, i64)], // (limb_values extended to 2N-1, coefficient)
    p_limbs: &[u128],
    q_pos_limbs: &[u128],
    q_neg_limbs: &[u128],
    n: usize,
    limb_bits: u32,
    max_coeff_sum: u64,
) -> Vec<u128> {
    let w = limb_bits;
    let num_columns = 2 * n - 1;
    let num_carries = num_columns - 1;

    // Carries are stored with an unsigned offset so possibly-negative signed
    // carries fit in a u128 witness value; the offset must dominate the
    // column magnitude, hence the extra headroom bits.
    let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1;
    let carry_offset_bits = w + extra_bits;
    let carry_offset = BigInt::from(1u64) << carry_offset_bits;

    let mut carries = Vec::with_capacity(num_carries);
    let mut carry = BigInt::from(0);

    for k in 0..num_columns {
        let mut col_value = BigInt::from(0);

        // Product terms
        for &(a, b, coeff) in product_sets {
            for i in 0..n {
                let j = k as isize - i as isize;
                if j >= 0 && (j as usize) < n {
                    col_value +=
                        BigInt::from(coeff) * BigInt::from(a[i]) * BigInt::from(b[j as usize]);
                }
            }
        }

        // Linear terms
        for (vals, coeff) in linear_terms {
            if k < vals.len() {
                col_value += BigInt::from(*coeff) * BigInt::from(vals[k]);
            }
        }

        // p*q_neg on positive side, p*q_pos on negative side
        for i in 0..n {
            let j = k as isize - i as isize;
            if j >= 0 && (j as usize) < n {
                col_value += BigInt::from(p_limbs[i]) * BigInt::from(q_neg_limbs[j as usize]);
                col_value -= BigInt::from(p_limbs[i]) * BigInt::from(q_pos_limbs[j as usize]);
            }
        }

        col_value += &carry;

        if k < num_carries {
            // Each non-final column must be an exact multiple of 2^w; the
            // carried-out high part feeds the next column.
            let mask = (BigInt::from(1u64) << w) - 1;
            debug_assert_eq!(
                &col_value & &mask,
                BigInt::from(0),
                "non-zero remainder at column {k}: col_value={col_value}"
            );
            carry = &col_value >> w;
            let stored = &carry + &carry_offset;
            carries.push(bigint_to_u128(&stored));
        } else {
            // The final column must close the equation exactly.
            debug_assert_eq!(
                col_value,
                BigInt::from(0),
                "non-zero final column value: {col_value}"
            );
        }
    }

    carries
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_widening_mul_small() {
        //
3 * 7 = 21 + let a = [3, 0, 0, 0]; + let b = [7, 0, 0, 0]; + let result = widening_mul(&a, &b); + assert_eq!(result[0], 21); + assert_eq!(result[1..], [0; 7]); + } + + #[test] + fn test_widening_mul_overflow() { + // u64::MAX * u64::MAX = (2^64-1)^2 = 2^128 - 2^65 + 1 + let a = [u64::MAX, 0, 0, 0]; + let b = [u64::MAX, 0, 0, 0]; + let result = widening_mul(&a, &b); + // (2^64-1)^2 = 0xFFFFFFFFFFFFFFFE_0000000000000001 + assert_eq!(result[0], 1); + assert_eq!(result[1], u64::MAX - 1); + assert_eq!(result[2..], [0; 6]); + } + + #[test] + fn test_reduce_wide_no_reduction() { + // 5 mod 7 = 5 + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [5, 0, 0, 0]); + } + + #[test] + fn test_reduce_wide_basic() { + // 10 mod 7 = 3 + let wide = [10, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [3, 0, 0, 0]); + } + + #[test] + fn test_mul_mod_small() { + // (5 * 3) mod 7 = 15 mod 7 = 1 + let a = [5, 0, 0, 0]; + let b = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mul_mod(&a, &b, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_mod_pow_small() { + // 3^4 mod 7 = 81 mod 7 = 4 + let base = [3, 0, 0, 0]; + let exp = [4, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [4, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_small() { + // Inverse of 3 mod 7: 3^{7-2} = 3^5 mod 7 = 243 mod 7 = 5 + // Check: 3 * 5 = 15 = 2*7 + 1 ≡ 1 (mod 7) ✓ + let a = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + let exp = sub_u64(&m, 2); // m - 2 = 5 + let inv = mod_pow(&a, &exp, &m); + assert_eq!(inv, [5, 0, 0, 0]); + // Verify: a * inv mod m = 1 + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_prime_23() { + // Inverse of 5 mod 23: 5^{21} mod 23 + // 5^{-1} mod 23 = 14 (because 5*14 = 70 = 3*23 + 1) + let a = [5, 0, 0, 0]; + let m = [23, 0, 0, 0]; + let exp = sub_u64(&m, 2); + let inv = mod_pow(&a, &exp, &m); + assert_eq!(inv, 
[14, 0, 0, 0]); + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_basic() { + assert_eq!(sub_u64(&[10, 0, 0, 0], 3), [7, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_borrow() { + // [0, 1, 0, 0] = 2^64; subtract 1 → [u64::MAX, 0, 0, 0] + assert_eq!(sub_u64(&[0, 1, 0, 0], 1), [u64::MAX, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_large_prime() { + // Use a 128-bit prime: p = 2^127 - 1 = 170141183460469231731687303715884105727 + // In limbs: [u64::MAX, 2^63 - 1, 0, 0] + let p = [u64::MAX, (1u64 << 63) - 1, 0, 0]; + + // a = 42 + let a = [42, 0, 0, 0]; + let exp = sub_u64(&p, 2); + let inv = mod_pow(&a, &exp, &p); + + // Verify: a * inv mod p = 1 + assert_eq!(mul_mod(&a, &inv, &p), [1, 0, 0, 0]); + } + + #[test] + fn test_cmp_wide_narrow() { + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let narrow = [5, 0, 0, 0]; + assert_eq!(cmp_wide_narrow(&wide, &narrow), std::cmp::Ordering::Equal); + + let wide_greater = [0, 0, 0, 0, 1, 0, 0, 0]; + assert_eq!( + cmp_wide_narrow(&wide_greater, &narrow), + std::cmp::Ordering::Greater + ); + } + + #[test] + fn test_mod_pow_zero_exp() { + // a^0 mod m = 1 + let base = [42, 0, 0, 0]; + let exp = [0, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_mod_pow_one_exp() { + // a^1 mod m = a mod m + let base = [10, 0, 0, 0]; + let exp = [1, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [3, 0, 0, 0]); + } + + #[test] + fn test_divmod_exact() { + // 21 / 7 = 3 remainder 0 + let (q, r) = divmod(&[21, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [3, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_with_remainder() { + // 17 / 7 = 2 remainder 3 + let (q, r) = divmod(&[17, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [2, 0, 0, 0]); + assert_eq!(r, [3, 0, 0, 0]); + } + + #[test] + fn test_divmod_smaller_dividend() { + // 5 / 7 = 0 remainder 5 + let (q, r) = divmod(&[5, 0, 0, 0], &[7, 0, 0, 0]); + 
assert_eq!(q, [0, 0, 0, 0]); + assert_eq!(r, [5, 0, 0, 0]); + } + + #[test] + fn test_divmod_zero_dividend() { + let (q, r) = divmod(&[0, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [0, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_large() { + // 2^64 / 3 = 6148914691236517205 remainder 1 + // 2^64 in limbs: [0, 1, 0, 0] + let (q, r) = divmod(&[0, 1, 0, 0], &[3, 0, 0, 0]); + assert_eq!(q, [6148914691236517205, 0, 0, 0]); + assert_eq!(r, [1, 0, 0, 0]); + // Verify: q * 3 + 1 = 2^64 + assert_eq!(6148914691236517205u64.wrapping_mul(3).wrapping_add(1), 0u64); + // wraps to 0 in u64 = 2^64 + } + + #[test] + fn test_divmod_consistency() { + // Verify dividend = quotient * divisor + remainder for various inputs + let cases: Vec<([u64; 4], [u64; 4])> = vec![ + ([100, 0, 0, 0], [7, 0, 0, 0]), + ([u64::MAX, 0, 0, 0], [1000, 0, 0, 0]), + ([0, 1, 0, 0], [u64::MAX, 0, 0, 0]), // 2^64 / (2^64 - 1) + ]; + for (dividend, divisor) in cases { + let (q, r) = divmod(÷nd, &divisor); + // Verify: q * divisor + r = dividend + let product = widening_mul(&q, &divisor); + // Add remainder to product + let mut sum = product; + let mut carry = 0u128; + for i in 0..4 { + let s = (sum[i] as u128) + (r[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = (sum[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + // sum should equal dividend (zero-extended to 8 limbs) + let mut expected = [0u64; 8]; + expected[..4].copy_from_slice(÷nd); + assert_eq!(sum, expected, "dividend={dividend:?} divisor={divisor:?}"); + } + } + + #[test] + fn test_divmod_wide_small() { + // 21 / 7 = 3 remainder 0 (512-bit dividend) + let dividend = [21, 0, 0, 0, 0, 0, 0, 0]; + let divisor = [7, 0, 0, 0]; + let (q, r) = divmod_wide(÷nd, &divisor); + assert_eq!(q, [3, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_large() { + // Compute a * b where a, b are 256-bit, then divide by a + // Should give quotient = b, 
remainder = 0 + let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; // secp256r1 p + let b = [42, 0, 0, 0]; + let product = widening_mul(&a, &b); + let (q, r) = divmod_wide(&product, &a); + assert_eq!(q, b); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_with_remainder() { + // (a * b + 5) / a = b remainder 5 + let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let b = [100, 0, 0, 0]; + let mut product = widening_mul(&a, &b); + // Add 5 + let (sum, overflow) = product[0].overflowing_add(5); + product[0] = sum; + if overflow { + for i in 1..8 { + let (s, o) = product[i].overflowing_add(1); + product[i] = s; + if !o { + break; + } + } + } + let (q, r) = divmod_wide(&product, &a); + assert_eq!(q, b); + assert_eq!(r, [5, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_consistency() { + // Verify: q * divisor + r = dividend + let a = [ + 0x123456789abcdef0, + 0xfedcba9876543210, + 0x1111111111111111, + 0x2222222222222222, + ]; + let b = [0xaabbccdd, 0x11223344, 0x55667788, 0x99001122]; + let product = widening_mul(&a, &b); + let divisor = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let (q, r) = divmod_wide(&product, &divisor); + + // Verify: q * divisor + r = product + let qd = widening_mul(&q, &divisor); + let mut sum = qd; + let mut carry = 0u128; + for i in 0..4 { + let s = (sum[i] as u128) + (r[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = (sum[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + assert_eq!(sum, product); + } + + #[test] + fn test_half_gcd_small() { + // s = 42, n = 101 + let s = [42, 0, 0, 0]; + let n = [101, 0, 0, 0]; + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + let sign1: i128 = if neg1 { -1 } else { 1 }; + let sign2: i128 = if neg2 { -1 } else { 1 }; + let v1 = val1[0] as i128; + let v2 = val2[0] as i128; + let s_val = s[0] as i128; + let 
n_val = n[0] as i128; + let lhs = ((sign1 * v1 + sign2 * v2 * s_val) % n_val + n_val) % n_val; + assert_eq!(lhs, 0, "half_gcd relation failed for small values"); + } + + #[test] + fn test_half_gcd_grumpkin_order() { + // Grumpkin curve order (BN254 base field order) + let n = [ + 0x3c208c16d87cfd47_u64, + 0x97816a916871ca8d_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ]; + // Some scalar + let s = [ + 0x123456789abcdef0_u64, + 0xfedcba9876543210_u64, + 0x1111111111111111_u64, + 0x2222222222222222_u64, + ]; + + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // val1 and val2 should be < 2^128 + assert_eq!(val1[2], 0, "val1 should be < 2^128"); + assert_eq!(val1[3], 0, "val1 should be < 2^128"); + assert_eq!(val2[2], 0, "val2 should be < 2^128"); + assert_eq!(val2[3], 0, "val2 should be < 2^128"); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + // Use big integer arithmetic + let term2_full = widening_mul(&val2, &s); + let (_, term2_mod_n) = divmod_wide(&term2_full, &n); + + // Compute: sign1 * val1 + sign2 * term2_mod_n (mod n) + let effective1 = if neg1 { + // n - val1 + let mut result = n; + sub_4limb_checked(&mut result, &val1); + result + } else { + val1 + }; + let effective2 = if neg2 { + let mut result = n; + sub_4limb_checked(&mut result, &term2_mod_n); + result + } else { + term2_mod_n + }; + + let sum = add_4limb(&effective1, &effective2); + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + // sum might be >= n, so reduce + let (_, remainder) = if sum[4] > 0 { + // Sum overflows 256 bits, need wide divmod + let wide = [sum[0], sum[1], sum[2], sum[3], sum[4], 0, 0, 0]; + divmod_wide(&wide, &n) + } else { + divmod(&sum4, &n) + }; + assert_eq!( + remainder, + [0, 0, 0, 0], + "half_gcd relation failed for Grumpkin order" + ); + } +} diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 85586ac1f..de84a360c 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -21,8 +21,9 @@ use { 
whir::transcript::{codecs::Empty, ProverState, VerifierMessage}, }; +pub mod bigint_mod; pub mod input_utils; -mod r1cs; +pub mod r1cs; mod whir_r1cs; mod witness; diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index db91e5e0a..a895b1c80 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -1,13 +1,22 @@ use { - crate::witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, + crate::{ + bigint_mod::{ + add_4limb, bigint_to_fe, cmp_4limb, compute_ec_verification_carries, + compute_mul_mod_carries, decompose_to_u128_limbs, divmod, divmod_wide, + ec_point_add_with_lambda, ec_point_double_with_lambda, ec_scalar_mul, fe_to_bigint, + half_gcd, mod_pow, mul_mod, reconstruct_from_halves, reconstruct_from_u128_limbs, + signed_quotient_wide, sub_u64, to_i128_limbs, widening_mul, + }, + witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, + }, acir::native_types::WitnessMap, - ark_ff::{BigInteger, PrimeField}, + ark_ff::{BigInteger, Field, PrimeField}, ark_std::Zero, provekit_common::{ utils::noir_to_native, witness::{ - compute_spread, ConstantOrR1CSWitness, ConstantTerm, ProductLinearTerm, SumTerm, - WitnessBuilder, WitnessCoefficient, + compute_spread, ConstantOrR1CSWitness, ConstantTerm, NonNativeEcOp, ProductLinearTerm, + SumTerm, WitnessBuilder, WitnessCoefficient, }, FieldElement, NoirElement, TranscriptSponge, }, @@ -23,6 +32,42 @@ pub trait WitnessBuilderSolver { ); } +/// Resolve a ConstantOrR1CSWitness to its FieldElement value. +fn resolve(witness: &[Option], v: &ConstantOrR1CSWitness) -> FieldElement { + match v { + ConstantOrR1CSWitness::Constant(c) => *c, + ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), + } +} + +/// Convert a u128 value to a FieldElement. 
+fn u128_to_fe(val: u128) -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt([val as u64, (val >> 64) as u64, 0, 0])).unwrap() +} + +/// Read witness limbs and reconstruct as [u64; 4]. +fn read_witness_limbs( + witness: &[Option], + indices: &[usize], + limb_bits: u32, +) -> [u64; 4] { + let limb_values: Vec = indices + .iter() + .map(|&idx| { + let bigint = witness[idx].unwrap().into_bigint().0; + bigint[0] as u128 | ((bigint[1] as u128) << 64) + }) + .collect(); + reconstruct_from_u128_limbs(&limb_values, limb_bits) +} + +/// Write u128 limb values as FieldElement witnesses starting at `start`. +fn write_limbs(witness: &mut [Option], start: usize, vals: &[u128]) { + for (i, &val) in vals.iter().enumerate() { + witness[start + i] = Some(u128_to_fe(val)); + } +} + impl WitnessBuilderSolver for WitnessBuilder { fn solve( &self, @@ -65,6 +110,26 @@ impl WitnessBuilderSolver for WitnessBuilder { "Inverse/LogUpInverse should not be called - handled by batch inversion" ) } + WitnessBuilder::SafeInverse(witness_idx, operand_idx) => { + let val = witness[*operand_idx].unwrap(); + witness[*witness_idx] = Some(if val == FieldElement::zero() { + FieldElement::zero() + } else { + val.inverse().unwrap() + }); + } + WitnessBuilder::ModularInverse(witness_idx, operand_idx, modulus) => { + let a_limbs = fe_to_bigint(witness[*operand_idx].unwrap()); + let m_limbs = modulus.into_bigint().0; + let exp = sub_u64(&m_limbs, 2); + witness[*witness_idx] = Some(bigint_to_fe(&mod_pow(&a_limbs, &exp, &m_limbs))); + } + WitnessBuilder::IntegerQuotient(witness_idx, dividend_idx, divisor) => { + let d_limbs = fe_to_bigint(witness[*dividend_idx].unwrap()); + let m_limbs = divisor.into_bigint().0; + let (quotient, _) = divmod(&d_limbs, &m_limbs); + witness[*witness_idx] = Some(bigint_to_fe("ient)); + } WitnessBuilder::IndexedLogUpDenominator( witness_idx, sz_challenge, @@ -145,18 +210,9 @@ impl WitnessBuilderSolver for WitnessBuilder { rhs, output, ) => { - let lhs = match lhs { - 
ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let output = match output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); + let output = resolve(witness, output); witness[*witness_idx] = Some( witness[*sz_challenge].unwrap() - (lhs @@ -175,22 +231,10 @@ impl WitnessBuilderSolver for WitnessBuilder { and_output, xor_output, ) => { - let lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let and_out = match and_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let xor_out = match xor_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); + let and_out = resolve(witness, and_output); + let xor_out = resolve(witness, xor_output); // Encoding: sz - (lhs + rs*rhs + rs²*and_out + rs³*xor_out) witness[*witness_idx] = Some( witness[*sz_challenge].unwrap() @@ -203,18 +247,8 @@ impl WitnessBuilderSolver for WitnessBuilder { WitnessBuilder::MultiplicitiesForBinOp(witness_idx, atomic_bits, operands) => { let mut multiplicities = vec![0u32; 2usize.pow(2 * *atomic_bits)]; for (lhs, rhs) in operands { - let lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - 
ConstantOrR1CSWitness::Witness(witness_idx) => { - witness[*witness_idx].unwrap() - } - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => { - witness[*witness_idx].unwrap() - } - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); let index = (lhs.into_bigint().0[0] << *atomic_bits) + rhs.into_bigint().0[0]; multiplicities[index as usize] += 1; } @@ -223,14 +257,8 @@ impl WitnessBuilderSolver for WitnessBuilder { } } WitnessBuilder::U32Addition(result_witness_idx, carry_witness_idx, a, b) => { - let a_val = match a { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), - }; - let b_val = match b { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), - }; + let a_val = resolve(witness, a); + let b_val = resolve(witness, b); assert!( a_val.into_bigint().num_bits() <= 32, "a_val must be less than or equal to 32 bits, got {}", @@ -258,12 +286,7 @@ impl WitnessBuilderSolver for WitnessBuilder { // Sum all inputs as u64 to handle overflow. 
let mut sum: u64 = 0; for input in inputs { - let val = match input { - ConstantOrR1CSWitness::Constant(c) => c.into_bigint().0[0], - ConstantOrR1CSWitness::Witness(idx) => { - witness[*idx].unwrap().into_bigint().0[0] - } - }; + let val = resolve(witness, input).into_bigint().0[0]; assert!(val < (1u64 << 32), "input must be 32-bit"); sum += val; } @@ -274,14 +297,8 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*carry_witness_idx] = Some(FieldElement::from(quotient)); } WitnessBuilder::And(result_witness_idx, lh, rh) => { - let lh_val = match lh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rh_val = match rh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lh_val = resolve(witness, lh); + let rh_val = resolve(witness, rh); assert!( lh_val.into_bigint().num_bits() <= 32, "lh_val must be less than or equal to 32 bits, got {}", @@ -297,14 +314,8 @@ impl WitnessBuilderSolver for WitnessBuilder { )); } WitnessBuilder::Xor(result_witness_idx, lh, rh) => { - let lh_val = match lh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rh_val = match rh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lh_val = resolve(witness, lh); + let rh_val = resolve(witness, rh); assert!( lh_val.into_bigint().num_bits() <= 32, "lh_val must be less than or equal to 32 bits, got {}", @@ -319,6 +330,101 @@ impl WitnessBuilderSolver for WitnessBuilder { lh_val.into_bigint() ^ rh_val.into_bigint(), )); } + WitnessBuilder::MultiLimbMulModHint { + output_start, + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => { + let n = *num_limbs as usize; + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, 
w); + let b_val = read_witness_limbs(witness, b_limbs, w); + + let product = widening_mul(&a_val, &b_val); + let (q_val, r_val) = divmod_wide(&product, modulus); + + let q_limbs_vals = decompose_to_u128_limbs(&q_val, n, w); + let r_limbs_vals = decompose_to_u128_limbs(&r_val, n, w); + + let carries = compute_mul_mod_carries( + &decompose_to_u128_limbs(&a_val, n, w), + &decompose_to_u128_limbs(&b_val, n, w), + &decompose_to_u128_limbs(modulus, n, w), + &q_limbs_vals, + &r_limbs_vals, + w, + ); + + write_limbs(witness, *output_start, &q_limbs_vals); + write_limbs(witness, *output_start + n, &r_limbs_vals); + write_limbs(witness, *output_start + 2 * n, &carries); + } + WitnessBuilder::MultiLimbModularInverse { + output_start, + a_limbs, + modulus, + limb_bits, + num_limbs, + } => { + let n = *num_limbs as usize; + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let exp = sub_u64(modulus, 2); + let inv = mod_pow(&a_val, &exp, modulus); + write_limbs(witness, *output_start, &decompose_to_u128_limbs(&inv, n, w)); + } + WitnessBuilder::MultiLimbAddQuotient { + output, + a_limbs, + b_limbs, + modulus, + limb_bits, + .. + } => { + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); + + let sum = add_4limb(&a_val, &b_val); + let q = if sum[4] > 0 { + 1u64 + } else { + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + if cmp_4limb(&sum4, modulus) != std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + } + }; + + witness[*output] = Some(FieldElement::from(q)); + } + WitnessBuilder::MultiLimbSubBorrow { + output, + a_limbs, + b_limbs, + limb_bits, + .. 
+ } => { + let w = *limb_bits; + + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); + + let q = if cmp_4limb(&a_val, &b_val) == std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + }; + + witness[*output] = Some(FieldElement::from(q)); + } WitnessBuilder::BytePartition { lo, hi, x, k } => { let x_val = witness[*x].unwrap().into_bigint().0[0]; debug_assert!(x_val < 256, "BytePartition input must be 8-bit"); @@ -330,6 +436,472 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*lo] = Some(FieldElement::from(lo_val)); witness[*hi] = Some(FieldElement::from(hi_val)); } + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => { + let s_val = reconstruct_from_halves( + &fe_to_bigint(witness[*s_lo].unwrap()), + &fe_to_bigint(witness[*s_hi].unwrap()), + ); + + let (val1, val2, neg1, neg2) = half_gcd(&s_val, curve_order); + + witness[*output_start] = Some(bigint_to_fe(&val1)); + witness[*output_start + 1] = Some(bigint_to_fe(&val2)); + witness[*output_start + 2] = Some(FieldElement::from(neg1 as u64)); + witness[*output_start + 3] = Some(FieldElement::from(neg2 as u64)); + } + WitnessBuilder::EcDoubleHint { + output_start, + px, + py, + curve_a, + field_modulus_p, + } => { + let px_val = fe_to_bigint(witness[*px].unwrap()); + let py_val = fe_to_bigint(witness[*py].unwrap()); + + let (lambda, x3, y3) = + ec_point_double_with_lambda(&px_val, &py_val, curve_a, field_modulus_p); + + witness[*output_start] = Some(bigint_to_fe(&lambda)); + witness[*output_start + 1] = Some(bigint_to_fe(&x3)); + witness[*output_start + 2] = Some(bigint_to_fe(&y3)); + } + WitnessBuilder::EcAddHint { + output_start, + x1, + y1, + x2, + y2, + field_modulus_p, + } => { + let x1_val = fe_to_bigint(witness[*x1].unwrap()); + let y1_val = fe_to_bigint(witness[*y1].unwrap()); + let x2_val = fe_to_bigint(witness[*x2].unwrap()); + let y2_val = fe_to_bigint(witness[*y2].unwrap()); + + let (lambda, x3, y3) = + 
ec_point_add_with_lambda(&x1_val, &y1_val, &x2_val, &y2_val, field_modulus_p); + + witness[*output_start] = Some(bigint_to_fe(&lambda)); + witness[*output_start + 1] = Some(bigint_to_fe(&x3)); + witness[*output_start + 2] = Some(bigint_to_fe(&y3)); + } + WitnessBuilder::NonNativeEcHint { + output_start, + op, + inputs, + curve_a, + curve_b, + field_modulus_p, + limb_bits, + num_limbs, + } => { + let n = *num_limbs as usize; + let w = *limb_bits; + let os = *output_start; + + let p_l = decompose_to_u128_limbs(field_modulus_p, n, w); + + // Helper to split signed quotient into (q_pos, q_neg). + fn split_quotient( + q_abs: Vec, + is_neg: bool, + n: usize, + ) -> (Vec, Vec) { + if is_neg { + (vec![0u128; n], q_abs) + } else { + (q_abs, vec![0u128; n]) + } + } + + match op { + NonNativeEcOp::Double => { + let px_val = read_witness_limbs(witness, &inputs[0], w); + let py_val = read_witness_limbs(witness, &inputs[1], w); + let (lam, x3v, y3v) = + ec_point_double_with_lambda(&px_val, &py_val, curve_a, field_modulus_p); + let ll = decompose_to_u128_limbs(&lam, n, w); + let xl = decompose_to_u128_limbs(&x3v, n, w); + let yl = decompose_to_u128_limbs(&y3v, n, w); + let pl = decompose_to_u128_limbs(&px_val, n, w); + let pyl = decompose_to_u128_limbs(&py_val, n, w); + let a_l = decompose_to_u128_limbs(curve_a, n, w); + write_limbs(witness, os, &ll); + write_limbs(witness, os + n, &xl); + write_limbs(witness, os + 2 * n, &yl); + + // Per-equation max_coeff_sum must match compiler + let mcs_eq1 = 6 + 2 * n as u64; // 2+3+1+2n + let mcs_eq2 = 4 + 2 * n as u64; // 1+1+2+2n + let mcs_eq3 = 4 + 2 * n as u64; // 1+1+1+1+2n + + // Layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 + + // Eq1: 2*λ*py - 3*px² - a = q1*p + let (q1_abs, q1_neg) = signed_quotient_wide( + &[(&lam, &py_val, 2)], + &[(&px_val, &px_val, 3)], + &[], + &[(curve_a, 1)], + field_modulus_p, + n, + w, + ); 
+ let (q1_pos, q1_neg) = split_quotient(q1_abs, q1_neg, n); + write_limbs(witness, os + 3 * n, &q1_pos); + write_limbs(witness, os + 4 * n, &q1_neg); + let c1 = compute_ec_verification_carries( + &[(&ll, &pyl, 2), (&pl, &pl, -3)], + &[(to_i128_limbs(&a_l), -1)], + &p_l, + &q1_pos, + &q1_neg, + n, + w, + mcs_eq1, + ); + write_limbs(witness, os + 5 * n, &c1); + + // Eq2: λ² - x3 - 2*px = q2*p + let (q2_abs, q2_neg) = signed_quotient_wide( + &[(&lam, &lam, 1)], + &[], + &[], + &[(&x3v, 1), (&px_val, 2)], + field_modulus_p, + n, + w, + ); + let (q2_pos, q2_neg) = split_quotient(q2_abs, q2_neg, n); + write_limbs(witness, os + 7 * n - 2, &q2_pos); + write_limbs(witness, os + 8 * n - 2, &q2_neg); + let c2 = compute_ec_verification_carries( + &[(&ll, &ll, 1)], + &[(to_i128_limbs(&xl), -1), (to_i128_limbs(&pl), -2)], + &p_l, + &q2_pos, + &q2_neg, + n, + w, + mcs_eq2, + ); + write_limbs(witness, os + 9 * n - 2, &c2); + + // Eq3: λ*px - λ*x3 - y3 - py = q3*p + let (q3_abs, q3_neg) = signed_quotient_wide( + &[(&lam, &px_val, 1)], + &[(&lam, &x3v, 1)], + &[], + &[(&y3v, 1), (&py_val, 1)], + field_modulus_p, + n, + w, + ); + let (q3_pos, q3_neg) = split_quotient(q3_abs, q3_neg, n); + write_limbs(witness, os + 11 * n - 4, &q3_pos); + write_limbs(witness, os + 12 * n - 4, &q3_neg); + let c3 = compute_ec_verification_carries( + &[(&ll, &pl, 1), (&ll, &xl, -1)], + &[(to_i128_limbs(&yl), -1), (to_i128_limbs(&pyl), -1)], + &p_l, + &q3_pos, + &q3_neg, + n, + w, + mcs_eq3, + ); + write_limbs(witness, os + 13 * n - 4, &c3); + } + NonNativeEcOp::Add => { + let x1v = read_witness_limbs(witness, &inputs[0], w); + let y1v = read_witness_limbs(witness, &inputs[1], w); + let x2v = read_witness_limbs(witness, &inputs[2], w); + let y2v = read_witness_limbs(witness, &inputs[3], w); + let (lam, x3v, y3v) = + ec_point_add_with_lambda(&x1v, &y1v, &x2v, &y2v, field_modulus_p); + let ll = decompose_to_u128_limbs(&lam, n, w); + let xl = decompose_to_u128_limbs(&x3v, n, w); + let yl = 
decompose_to_u128_limbs(&y3v, n, w); + let x1l = decompose_to_u128_limbs(&x1v, n, w); + let y1l = decompose_to_u128_limbs(&y1v, n, w); + let x2l = decompose_to_u128_limbs(&x2v, n, w); + let y2l = decompose_to_u128_limbs(&y2v, n, w); + write_limbs(witness, os, &ll); + write_limbs(witness, os + n, &xl); + write_limbs(witness, os + 2 * n, &yl); + + // Must match compiler's max_coeff_sum: 1+1+1+1 + 2*n + let mcs = 4 + 2 * n as u64; + + // Layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 + + // Eq1: λ*x2 - λ*x1 + y1 - y2 = q1*p + let (q1_abs, q1_neg) = signed_quotient_wide( + &[(&lam, &x2v, 1)], + &[(&lam, &x1v, 1)], + &[(&y1v, 1)], + &[(&y2v, 1)], + field_modulus_p, + n, + w, + ); + let (q1_pos, q1_neg) = split_quotient(q1_abs, q1_neg, n); + write_limbs(witness, os + 3 * n, &q1_pos); + write_limbs(witness, os + 4 * n, &q1_neg); + let c1 = compute_ec_verification_carries( + &[(&ll, &x2l, 1), (&ll, &x1l, -1)], + &[(to_i128_limbs(&y2l), -1), (to_i128_limbs(&y1l), 1)], + &p_l, + &q1_pos, + &q1_neg, + n, + w, + mcs, + ); + write_limbs(witness, os + 5 * n, &c1); + + // Eq2: λ² - x3 - x1 - x2 = q2*p + let (q2_abs, q2_neg) = signed_quotient_wide( + &[(&lam, &lam, 1)], + &[], + &[], + &[(&x3v, 1), (&x1v, 1), (&x2v, 1)], + field_modulus_p, + n, + w, + ); + let (q2_pos, q2_neg) = split_quotient(q2_abs, q2_neg, n); + write_limbs(witness, os + 7 * n - 2, &q2_pos); + write_limbs(witness, os + 8 * n - 2, &q2_neg); + let c2 = compute_ec_verification_carries( + &[(&ll, &ll, 1)], + &[ + (to_i128_limbs(&xl), -1), + (to_i128_limbs(&x1l), -1), + (to_i128_limbs(&x2l), -1), + ], + &p_l, + &q2_pos, + &q2_neg, + n, + w, + mcs, + ); + write_limbs(witness, os + 9 * n - 2, &c2); + + // Eq3: λ*x1 - λ*x3 - y3 - y1 = q3*p + let (q3_abs, q3_neg) = signed_quotient_wide( + &[(&lam, &x1v, 1)], + &[(&lam, &x3v, 1)], + &[], + &[(&y3v, 1), (&y1v, 1)], + field_modulus_p, + n, + w, + ); + let 
(q3_pos, q3_neg) = split_quotient(q3_abs, q3_neg, n); + write_limbs(witness, os + 11 * n - 4, &q3_pos); + write_limbs(witness, os + 12 * n - 4, &q3_neg); + let c3 = compute_ec_verification_carries( + &[(&ll, &x1l, 1), (&ll, &xl, -1)], + &[(to_i128_limbs(&yl), -1), (to_i128_limbs(&y1l), -1)], + &p_l, + &q3_pos, + &q3_neg, + n, + w, + mcs, + ); + write_limbs(witness, os + 13 * n - 4, &c3); + } + NonNativeEcOp::OnCurve => { + let px_val = read_witness_limbs(witness, &inputs[0], w); + let py_val = read_witness_limbs(witness, &inputs[1], w); + let x_sq_val = mul_mod(&px_val, &px_val, field_modulus_p); + let xsl = decompose_to_u128_limbs(&x_sq_val, n, w); + let pl = decompose_to_u128_limbs(&px_val, n, w); + let pyl = decompose_to_u128_limbs(&py_val, n, w); + write_limbs(witness, os, &xsl); + + let a_is_zero = curve_a.iter().all(|&v| v == 0); + // Per-equation max_coeff_sum must match compiler + let mcs_eq1: u64 = 2 + 2 * n as u64; // 1+1+2n + let mcs_eq2: u64 = if a_is_zero { + 3 + 2 * n as u64 // 1+1+1+2n + } else { + 4 + 2 * n as u64 // 1+1+1+1+2n + }; + + // Layout: [x_sq(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2)] + // Total: 9N-4 + + // Eq1: px·px - x_sq = q1·p + let (q1_abs, q1_neg) = signed_quotient_wide( + &[(&px_val, &px_val, 1)], + &[], + &[], + &[(&x_sq_val, 1)], + field_modulus_p, + n, + w, + ); + let (q1_pos, q1_neg) = split_quotient(q1_abs, q1_neg, n); + write_limbs(witness, os + n, &q1_pos); + write_limbs(witness, os + 2 * n, &q1_neg); + let c1 = compute_ec_verification_carries( + &[(&pl, &pl, 1)], + &[(to_i128_limbs(&xsl), -1)], + &p_l, + &q1_pos, + &q1_neg, + n, + w, + mcs_eq1, + ); + write_limbs(witness, os + 3 * n, &c1); + + // Eq2: py·py - x_sq·px - a·px - b = q2·p + let a_l = decompose_to_u128_limbs(curve_a, n, w); + let b_l = decompose_to_u128_limbs(curve_b, n, w); + + let mut rhs_prods: Vec<(&[u64; 4], &[u64; 4], u64)> = + vec![(&x_sq_val, &px_val, 1)]; + if !a_is_zero { + rhs_prods.push((curve_a, &px_val, 1)); + } 
+ let (q2_abs, q2_neg) = signed_quotient_wide( + &[(&py_val, &py_val, 1)], + &rhs_prods, + &[], + &[(curve_b, 1)], + field_modulus_p, + n, + w, + ); + + let (q2_pos, q2_neg) = split_quotient(q2_abs, q2_neg, n); + write_limbs(witness, os + 5 * n - 2, &q2_pos); + write_limbs(witness, os + 6 * n - 2, &q2_neg); + + let mut prod_sets: Vec<(&[u128], &[u128], i64)> = + vec![(&pyl, &pyl, 1), (&xsl, &pl, -1)]; + if !a_is_zero { + prod_sets.push((&a_l, &pl, -1)); + } + let c2 = compute_ec_verification_carries( + &prod_sets, + &[(to_i128_limbs(&b_l), -1)], + &p_l, + &q2_pos, + &q2_neg, + n, + w, + mcs_eq2, + ); + write_limbs(witness, os + 7 * n - 2, &c2); + } + } + } + WitnessBuilder::EcScalarMulHint { + output_start, + px_limbs, + py_limbs, + s_lo, + s_hi, + curve_a, + field_modulus_p, + num_limbs, + limb_bits, + } => { + let n = *num_limbs as usize; + let scalar = reconstruct_from_halves( + &fe_to_bigint(witness[*s_lo].unwrap()), + &fe_to_bigint(witness[*s_hi].unwrap()), + ); + + let px_val = if n == 1 { + fe_to_bigint(witness[px_limbs[0]].unwrap()) + } else { + read_witness_limbs(witness, px_limbs, *limb_bits) + }; + let py_val = if n == 1 { + fe_to_bigint(witness[py_limbs[0]].unwrap()) + } else { + read_witness_limbs(witness, py_limbs, *limb_bits) + }; + + let (rx, ry) = ec_scalar_mul(&px_val, &py_val, &scalar, curve_a, field_modulus_p); + + if n == 1 { + witness[*output_start] = Some(bigint_to_fe(&rx)); + witness[*output_start + 1] = Some(bigint_to_fe(&ry)); + } else { + let rx_limbs = decompose_to_u128_limbs(&rx, n, *limb_bits); + let ry_limbs = decompose_to_u128_limbs(&ry, n, *limb_bits); + write_limbs(witness, *output_start, &rx_limbs); + write_limbs(witness, *output_start + n, &ry_limbs); + } + } + WitnessBuilder::SelectWitness { + output, + flag, + on_false, + on_true, + } => { + let f = witness[*flag].unwrap(); + let a = witness[*on_false].unwrap(); + let b = witness[*on_true].unwrap(); + witness[*output] = Some(a + f * (b - a)); + } + WitnessBuilder::BooleanOr { 
output, a, b } => { + let a_val = witness[*a].unwrap(); + let b_val = witness[*b].unwrap(); + witness[*output] = Some(a_val + b_val - a_val * b_val); + } + WitnessBuilder::SignedBitHint { + output_start, + scalar, + num_bits, + } => { + let s_fe = witness[*scalar].unwrap(); + let s_big = s_fe.into_bigint().0; + // NOTE: Only reads lower 128 bits. Safe for FakeGLV half-scalars + // (≤128 bits for 256-bit curves) but would silently truncate + // larger values. The R1CS reconstruction constraint catches this. + let s_val: u128 = s_big[0] as u128 | ((s_big[1] as u128) << 64); + let n = *num_bits; + let skew: u128 = if s_val & 1 == 0 { 1 } else { 0 }; + let s_adj = s_val + skew; + // t = (s_adj + 2^n - 1) / 2 + // Both s_adj and 2^n-1 are odd, so sum is even. + // To avoid u128 overflow when n >= 128, rewrite as: + // t = (s_adj - 1) / 2 + (2^n - 1 + 1) / 2 = (s_adj - 1) / 2 + 2^(n-1) + let t = if n == 0 { + s_adj / 2 + } else { + (s_adj - 1) / 2 + (1u128 << (n - 1)) + }; + for i in 0..n { + witness[*output_start + i] = Some(FieldElement::from(((t >> i) & 1) as u64)); + } + witness[*output_start + n] = Some(FieldElement::from(skew as u64)); + } WitnessBuilder::CombinedTableEntryInverse(..) 
=> { unreachable!( "CombinedTableEntryInverse should not be called - handled by batch inversion" @@ -393,12 +965,7 @@ impl WitnessBuilderSolver for WitnessBuilder { let table_size = 1usize << *num_bits; let mut multiplicities = vec![0u32; table_size]; for query in queries { - let val = match query { - ConstantOrR1CSWitness::Constant(c) => c.into_bigint().0[0], - ConstantOrR1CSWitness::Witness(w) => { - witness[*w].unwrap().into_bigint().0[0] - } - }; + let val = resolve(witness, query).into_bigint().0[0]; multiplicities[val as usize] += 1; } for (i, count) in multiplicities.iter().enumerate() { @@ -408,14 +975,8 @@ impl WitnessBuilderSolver for WitnessBuilder { WitnessBuilder::SpreadLookupDenominator(idx, sz, rs, input, spread_output) => { let sz_val = witness[*sz].unwrap(); let rs_val = witness[*rs].unwrap(); - let input_val = match input { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(w) => witness[*w].unwrap(), - }; - let spread_val = match spread_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(w) => witness[*w].unwrap(), - }; + let input_val = resolve(witness, input); + let spread_val = resolve(witness, spread_output); // sz - (input + rs * spread_output) witness[*idx] = Some(sz_val - (input_val + rs_val * spread_val)); } diff --git a/provekit/r1cs-compiler/src/constraint_helpers.rs b/provekit/r1cs-compiler/src/constraint_helpers.rs new file mode 100644 index 000000000..9561fe7f3 --- /dev/null +++ b/provekit/r1cs-compiler/src/constraint_helpers.rs @@ -0,0 +1,133 @@ +//! General-purpose R1CS constraint helpers. + +use { + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::{AdditiveGroup, Field}, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, +}; + +/// Constrains `flag` to be boolean: `flag * flag = flag`. 
+pub(crate) fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + ); +} + +/// Single-witness conditional select: `out = on_false + flag * (on_true - +/// on_false)`. +/// +/// Uses a single witness + single R1CS constraint: +/// flag * (on_true - on_false) = result - on_false +pub(crate) fn select_witness( + compiler: &mut NoirToR1CSCompiler, + flag: usize, + on_false: usize, + on_true: usize, +) -> usize { + // When both branches are the same witness, result is trivially that witness. + if on_false == on_true { + return on_false; + } + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SelectWitness { + output: result, + flag, + on_false, + on_true, + }); + // flag * (on_true - on_false) = result - on_false + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, on_true), (-FieldElement::ONE, on_false)], + &[(FieldElement::ONE, result), (-FieldElement::ONE, on_false)], + ); + result +} + +/// Packs bit witnesses into a digit: `d = Σ bits\[i\] * 2^i`. +pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize]) -> usize { + let terms: Vec<SumTerm> = bits + .iter() + .enumerate() + .map(|(i, &bit)| SumTerm(Some(FieldElement::from(1u128 << i)), bit)) + .collect(); + compiler.add_sum(terms) +} + +/// Computes `a OR b` for two boolean witnesses: `1 - (1 - a)(1 - b)`. +/// Does NOT constrain a or b to be boolean — caller must ensure that. 
+/// +/// Uses a single witness + single R1CS constraint: +/// (1 - a) * (1 - b) = 1 - result +pub(crate) fn compute_boolean_or(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) -> usize { + let one = compiler.witness_one(); + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::BooleanOr { + output: result, + a, + b, + }); + // (1 - a) * (1 - b) = 1 - result + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, one), (-FieldElement::ONE, a)], + &[(FieldElement::ONE, one), (-FieldElement::ONE, b)], + &[(FieldElement::ONE, one), (-FieldElement::ONE, result)], + ); + result +} + +/// Creates a constant witness with the given value, pinned by an R1CS +/// constraint so that a malicious prover cannot set it to an arbitrary value. +pub(crate) fn add_constant_witness( + compiler: &mut NoirToR1CSCompiler, + value: FieldElement, +) -> usize { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value))); + // Pin: 1 * w = value * 1 (embeds the constant into the constraint matrix) + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(value, compiler.witness_one())], + ); + w +} + +/// Constrains a witness to equal a known constant value. +/// Uses the constant as an R1CS coefficient — no witness needed for the +/// expected value. Use this for identity checks where the witness must equal +/// a compile-time-known value. +pub(crate) fn constrain_to_constant( + compiler: &mut NoirToR1CSCompiler, + witness: usize, + value: FieldElement, +) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, witness)], + &[(value, compiler.witness_one())], + ); +} + +/// Constrains two witnesses to be equal: `a - b = 0`. 
+pub(crate) fn constrain_equal(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} + +/// Constrains a witness to be zero: `w = 0`. +pub(crate) fn constrain_zero(compiler: &mut NoirToR1CSCompiler, w: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} diff --git a/provekit/r1cs-compiler/src/digits.rs b/provekit/r1cs-compiler/src/digits.rs index 91c4e4128..657f7bd78 100644 --- a/provekit/r1cs-compiler/src/digits.rs +++ b/provekit/r1cs-compiler/src/digits.rs @@ -1,5 +1,6 @@ use { crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::Field, ark_std::One, provekit_common::{ witness::{DigitalDecompositionWitnesses, WitnessBuilder}, @@ -66,7 +67,8 @@ pub(crate) fn add_digital_decomposition( // Add the constraints for the digital recomposition let mut digit_multipliers = vec![FieldElement::one()]; for log_base in log_bases[..log_bases.len() - 1].iter() { - let multiplier = *digit_multipliers.last().unwrap() * FieldElement::from(1u64 << *log_base); + let multiplier = + *digit_multipliers.last().unwrap() * FieldElement::from(2u64).pow([*log_base as u64]); digit_multipliers.push(multiplier); } dd_struct diff --git a/provekit/r1cs-compiler/src/lib.rs b/provekit/r1cs-compiler/src/lib.rs index 7de8f899b..0b9890a8a 100644 --- a/provekit/r1cs-compiler/src/lib.rs +++ b/provekit/r1cs-compiler/src/lib.rs @@ -1,10 +1,12 @@ mod binops; +mod constraint_helpers; mod digits; mod memory; +pub mod msm; mod noir_proof_scheme; -mod noir_to_r1cs; +pub mod noir_to_r1cs; mod poseidon2; -mod range_check; +pub mod range_check; mod sha256_compression; mod spread; mod uints; diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs 
new file mode 100644 index 000000000..de5edda41 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -0,0 +1,712 @@ +//! Analytical cost model for MSM parameter optimization. +//! +//! Follows the SHA256 pattern (`spread.rs:get_optimal_spread_width`): +//! `calculate_msm_witness_cost` estimates total cost, `get_optimal_msm_params` +//! searches the parameter space for the minimum. + +use std::collections::BTreeMap; + +/// The 256-bit scalar is split into two halves (s_lo, s_hi) because it doesn't +/// fit in the native field. +const SCALAR_HALF_BITS: usize = 128; + +/// Per-point overhead witnesses shared across all non-native paths. +/// - detect_skip: 2×is_zero(3W) + product(1W) + or(1W) = 8W +/// - sanitize: 4 select_witness = 4W +/// - ec_hint: EcScalarMulHint(2W) + 2W selects = 4W +/// - glv_hint: s1, s2, neg1, neg2 = 4W +const DETECT_SKIP_WIT: usize = 8; +const SANITIZE_WIT: usize = 4; +const EC_HINT_WIT: usize = 4; +const GLV_HINT_WIT: usize = 4; + +fn ceil_div(a: usize, b: usize) -> usize { + (a + b - 1) / b +} + +/// Table building ops: (doubles, adds) for constructing a signed-digit table. +/// Each table has `half_table_size` entries of odd multiples. +fn table_build_ops(half_table_size: usize) -> (usize, usize) { + if half_table_size >= 2 { + (1, half_table_size - 1) + } else { + (0, 0) + } +} + +/// Per-point overhead witnesses common to all non-native paths. 
+fn per_point_overhead(half_bits: usize, num_limbs: usize, sr_witnesses: usize) -> usize { + let scalar_bit_decomp = 2 * (half_bits + 1); + let point_decomp = if num_limbs > 1 { 4 * num_limbs } else { 0 }; + scalar_bit_decomp + + DETECT_SKIP_WIT + + SANITIZE_WIT + + EC_HINT_WIT + + GLV_HINT_WIT + + point_decomp + + sr_witnesses +} + +// --------------------------------------------------------------------------- +// Field op cost helpers (used by generic single-limb path + scalar relation) +// --------------------------------------------------------------------------- + +/// Total witnesses produced by N-limb field operations. +/// +/// Per-op witness counts by configuration: +/// - Native (N=1, same field): 1 per op (direct R1CS) +/// - Single-limb non-native (N=1): 5 per add/sub/mul, 6 per inv (reduce_mod_p +/// pattern) +/// - Multi-limb (N>1): add/sub = 1+6N, mul = N²+7N-2, inv = N²+8N-2 (schoolbook +/// multiplication + quotient/remainder) +fn field_op_witnesses( + n_add: usize, + n_sub: usize, + n_mul: usize, + n_inv: usize, + num_limbs: usize, + is_native: bool, +) -> usize { + if is_native { + n_add + n_sub + n_mul + n_inv + } else if num_limbs == 1 { + (n_add + n_sub + n_mul) * 5 + n_inv * 6 + } else { + let n = num_limbs; + let w_as = 1 + 6 * n; + let w_m = n * n + 7 * n - 2; + let w_i = n * n + 8 * n - 2; + (n_add + n_sub) * w_as + n_mul * w_m + n_inv * w_i + } +} + +/// Aggregate range checks from field ops into a map. 
+/// +/// - Native: no range checks +/// - Single-limb non-native: 1 check at `modulus_bits` per add/sub/mul, 2 per +/// inv +/// - Multi-limb: `limb_bits`-wide checks from less_than_p, plus +/// `carry_bits`-wide checks from schoolbook column carries +fn add_field_op_range_checks( + n_add: usize, + n_sub: usize, + n_mul: usize, + n_inv: usize, + num_limbs: usize, + limb_bits: u32, + modulus_bits: u32, + is_native: bool, + rc_map: &mut BTreeMap<u32, usize>, +) { + if is_native { + return; + } + if num_limbs == 1 { + *rc_map.entry(modulus_bits).or_default() += n_add + n_sub + n_mul + 2 * n_inv; + } else { + let n = num_limbs; + let ceil_log2_n = (n as f64).log2().ceil() as u32; + let carry_bits = limb_bits + ceil_log2_n + 2; + *rc_map.entry(limb_bits).or_default() += + (n_add + n_sub) * 2 * n + n_mul * 3 * n + n_inv * 4 * n; + *rc_map.entry(carry_bits).or_default() += (n_mul + n_inv) * (2 * n - 2); + } +} + +// --------------------------------------------------------------------------- +// Hint-verified EC op cost model (non-native, num_limbs >= 2) +// --------------------------------------------------------------------------- + +/// Witness and range check costs for a single hint-verified EC operation. 
+struct HintVerifiedEcCost { + witnesses: usize, + rc_limb: usize, + rc_carry: usize, + carry_bits: u32, +} + +impl HintVerifiedEcCost { + /// point_double: (15N-6)W hint + 5N² products + N constants + 3×3N ltp + fn point_double(n: usize, limb_bits: u32) -> Self { + let wit = (15 * n - 6) + 5 * n * n + n + 9 * n; + Self { + witnesses: wit, + rc_limb: 9 * n + 6 * n, // 9N hint limbs (3+6 q_pos/q_neg) + 3×2N ltp limbs + rc_carry: 3 * (2 * n - 2), // 3 equations × (2N-2) carries + carry_bits: hint_carry_bits(limb_bits, 6 + 2 * n as u64, n), + } + } + + /// point_add: (15N-6)W hint + 4N² products + 3×3N ltp + fn point_add(n: usize, limb_bits: u32) -> Self { + let wit = (15 * n - 6) + 4 * n * n + 9 * n; + Self { + witnesses: wit, + rc_limb: 9 * n + 6 * n, + rc_carry: 3 * (2 * n - 2), + carry_bits: hint_carry_bits(limb_bits, 4 + 2 * n as u64, n), + } + } + + /// on_curve (worst case, a != 0): (9N-4)W hint + 4N² products + 2N + /// constants + 3N ltp + fn on_curve(n: usize, limb_bits: u32) -> Self { + let wit = (9 * n - 4) + 4 * n * n + 2 * n + 3 * n; + Self { + witnesses: wit, + rc_limb: 5 * n + 2 * n, // 5N hint limbs (1+4 q_pos/q_neg) + 2N ltp limbs + rc_carry: 2 * (2 * n - 2), // 2 equations × (2N-2) carries + carry_bits: hint_carry_bits(limb_bits, 5 + 2 * n as u64, n), + } + } + + /// Accumulate `count` of this op's range checks into `rc_map`. + fn add_range_checks(&self, count: usize, limb_bits: u32, rc_map: &mut BTreeMap<u32, usize>) { + *rc_map.entry(limb_bits).or_default() += count * self.rc_limb; + *rc_map.entry(self.carry_bits).or_default() += count * self.rc_carry; + } +} + +/// Carry range check bits for hint-verified EC column equations. 
+fn hint_carry_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + limb_bits + extra_bits + 1 +} + +// --------------------------------------------------------------------------- +// Scalar relation cost +// --------------------------------------------------------------------------- + +/// Witnesses and range checks for scalar relation verification. +/// +/// Verifies `(-1)^neg1*|s1| + (-1)^neg2*|s2|*s ≡ 0 (mod n)` using multi-limb +/// arithmetic with the curve order as modulus. Components: +/// - Scalar decomposition (DD digits for s_lo, s_hi) +/// - Half-scalar decomposition (DD digits for s1, s2) +/// - One mul + one add + one sub for sign handling +/// - XOR witnesses (2) + select (num_limbs) +/// - s2 non-zero check: compute_is_zero(3W) + constrain_zero +fn scalar_relation_cost( + native_field_bits: u32, + scalar_bits: usize, +) -> (usize, BTreeMap<u32, usize>) { + let limb_bits = scalar_relation_limb_bits(native_field_bits, scalar_bits); + let n = ceil_div(scalar_bits, limb_bits as usize); + let half_bits = (scalar_bits + 1) / 2; + let half_limbs = ceil_div(half_bits, limb_bits as usize); + let scalar_half_limbs = ceil_div(SCALAR_HALF_BITS, limb_bits as usize); + + let has_cross = n > 1 && SCALAR_HALF_BITS % limb_bits as usize != 0; + let witnesses = 2 * scalar_half_limbs + + has_cross as usize + + 2 * n + + field_op_witnesses(1, 1, 1, 0, n, false) + + 2 + + n + + 3; // compute_is_zero(s2): inv + product + is_zero + + // Only n limbs worth of scalar DD digits get range checks; unused digits + // are zero-constrained instead (soundness fix for small curves). 
+ let scalar_dd_rcs = n.min(2 * scalar_half_limbs); + let mut rc_map = BTreeMap::new(); + *rc_map.entry(limb_bits).or_default() += scalar_dd_rcs + 2 * half_limbs; + add_field_op_range_checks( + 1, + 1, + 1, + 0, + n, + limb_bits, + scalar_bits as u32, + false, + &mut rc_map, + ); + + (witnesses, rc_map) +} + +// --------------------------------------------------------------------------- +// MSM cost entry point +// --------------------------------------------------------------------------- + +/// Total estimated witness cost for an MSM. +/// +/// Accounts for three categories of witnesses: +/// 1. **Inline witnesses** — field ops, selects, is_zero, hints, DDs +/// 2. **Range check resolution** — LogUp/naive cost for all range checks +/// 3. **Per-point overhead** — detect_skip, sanitization, point decomposition +pub fn calculate_msm_witness_cost( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, + window_size: usize, + limb_bits: u32, + is_native: bool, +) -> usize { + if is_native { + return calculate_msm_witness_cost_native(native_field_bits, n_points, scalar_bits); + } + + let n = ceil_div(curve_modulus_bits as usize, limb_bits as usize); + let half_bits = (scalar_bits + 1) / 2; + let w = window_size; + let half_table_size = 1usize << (w - 1); + let num_windows = ceil_div(half_bits, w); + + if n >= 2 { + calculate_msm_witness_cost_hint_verified( + native_field_bits, + n_points, + scalar_bits, + w, + limb_bits, + n, + half_bits, + half_table_size, + num_windows, + ) + } else { + calculate_msm_witness_cost_generic( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + w, + limb_bits, + n, + half_bits, + half_table_size, + num_windows, + ) + } +} + +// --------------------------------------------------------------------------- +// Hint-verified (multi-limb) non-native cost +// --------------------------------------------------------------------------- + +/// Hint-verified non-native MSM cost (num_limbs 
>= 2). +#[allow(clippy::too_many_arguments)] +fn calculate_msm_witness_cost_hint_verified( + native_field_bits: u32, + n_points: usize, + scalar_bits: usize, + w: usize, + limb_bits: u32, + n: usize, + half_bits: usize, + half_table_size: usize, + num_windows: usize, +) -> usize { + let ec_double = HintVerifiedEcCost::point_double(n, limb_bits); + let ec_add = HintVerifiedEcCost::point_add(n, limb_bits); + let ec_oncurve = HintVerifiedEcCost::on_curve(n, limb_bits); + let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits); + let (tbl_d, tbl_a) = table_build_ops(half_table_size); + + // negate_mod_p_multi: 3N witnesses, N range checks (no less_than_p) + let negate_wit = 3 * n; + + // --- Per-point EC witnesses (each point has its own doubling chain) --- + let pp_doubles = num_windows * w; // per-point doublings (no longer shared) + let pp_doubles_ec = pp_doubles * ec_double.witnesses; + let pp_offset_constants = 2 * n; // offset limbs per point + let pp_table_ec = 2 * (tbl_d * ec_double.witnesses + tbl_a * ec_add.witnesses); + let pp_loop_ec = num_windows * 2 * ec_add.witnesses; + let pp_skew_ec = 2 * ec_add.witnesses; + let pp_oncurve = 2 * ec_oncurve.witnesses; + let pp_y_negate = 2 * (negate_wit + n); // 2 × (negate + select) + let pp_signed_lookup_negate = num_windows * 2 * (negate_wit + n); + let pp_skew_negate = 2 * negate_wit; + let pp_skew_selects = 2 * 2 * n; // 2 branches × 2N + let pp_table_selects = num_windows * 2 * half_table_size.saturating_sub(1) * 2 * n; + let pp_xor = num_windows * 2 * 2 * w.saturating_sub(1); + + let per_point = pp_doubles_ec + + pp_offset_constants + + pp_table_ec + + pp_loop_ec + + pp_skew_ec + + pp_oncurve + + pp_y_negate + + pp_signed_lookup_negate + + pp_skew_negate + + pp_skew_selects + + pp_table_selects + + pp_xor + + per_point_overhead(half_bits, n, sr_witnesses); + + // --- Shared constants --- + let shared_constants = 3; // gen_x, gen_y, zero + + // --- Point accumulation --- + let 
accum = n_points * (ec_add.witnesses + 2 * n) // per-point add + skip select + + n_points.saturating_sub(1) // all_skipped products + + ec_add.witnesses + 4 * n + 2 * n // offset sub + constants + selects + + 2 + 2; // mask + recompose + + // --- Range checks --- + let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new(); + + // Per-point: loop doublings + table doubles + table/loop/skew adds + on-curve + let pp_doubles_count = pp_doubles + 2 * tbl_d; + let pp_adds_count = 2 * tbl_a + num_windows * 2 + 2; + ec_double.add_range_checks(n_points * pp_doubles_count, limb_bits, &mut rc_map); + ec_add.add_range_checks(n_points * pp_adds_count, limb_bits, &mut rc_map); + ec_oncurve.add_range_checks(n_points * 2, limb_bits, &mut rc_map); + + // Accumulation adds + ec_add.add_range_checks(n_points + 1, limb_bits, &mut rc_map); + + // Negate range checks: N limb checks per negate + let negate_count_pp = 2 + num_windows * 2 + 2; // y-negate + signed_lookup + skew + *rc_map.entry(limb_bits).or_default() += n_points * negate_count_pp * n; + + // Point decomposition + *rc_map.entry(limb_bits).or_default() += n_points * 4 * n; + + // Scalar relation + for (&bits, &count) in &sr_range_checks { + *rc_map.entry(bits).or_default() += n_points * count; + } + + let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); + shared_constants + n_points * per_point + accum + range_check_cost +} + +// --------------------------------------------------------------------------- +// Generic (single-limb) non-native cost +// --------------------------------------------------------------------------- + +/// Generic (single-limb) non-native MSM cost using MultiLimbOps field op +/// chains. Uses per-point accumulators (no shared doublings). 
+#[allow(clippy::too_many_arguments)] +fn calculate_msm_witness_cost_generic( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, + w: usize, + limb_bits: u32, + n: usize, + half_bits: usize, + half_table_size: usize, + num_windows: usize, +) -> usize { + // point_double: (5 add, 3 sub, 4 mul, 1 inv) + // point_add: (1 add, 5 sub, 3 mul, 1 inv) + let (tbl_d, tbl_a) = table_build_ops(half_table_size); + let pp_loop_doubles = num_windows * w; + + // --- Per-point field ops: loop doubles + tables + loop adds + skew + on-curve + // + y-negate --- + let pp_add = pp_loop_doubles * 5 + 2 * (tbl_d * 5 + tbl_a) + num_windows * 2 + 2 + 4; + let pp_sub = + pp_loop_doubles * 3 + 2 * (tbl_d * 3 + tbl_a * 5) + num_windows * (2 * 5 + 2) + 2 * 6 + 2; + let pp_mul = + pp_loop_doubles * 4 + 2 * (tbl_d * 4 + tbl_a * 3) + num_windows * 2 * 3 + 2 * 3 + 8; + let pp_inv = pp_loop_doubles + 2 * (tbl_d + tbl_a) + num_windows * 2 + 2; + + let pp_field_ops = field_op_witnesses(pp_add, pp_sub, pp_mul, pp_inv, n, false); + + let pp_doubles = pp_loop_doubles + 2 * tbl_d; + let pp_negate_zeros = (4 + 2 * num_windows) * n; + let pp_constants = pp_doubles * n + 4 * n + pp_negate_zeros + 2 * n; // +2N for offset limbs + let pp_table_selects = num_windows * 2 * half_table_size.saturating_sub(1) * 2 * n; + let pp_xor = num_windows * 2 * 2 * w.saturating_sub(1); + let pp_signed_y_selects = num_windows * 2 * n; + let pp_y_negate_selects = 2 * n; + let pp_skew_selects = 2 * 2 * n; + let pp_selects = + pp_table_selects + pp_xor + pp_signed_y_selects + pp_y_negate_selects + pp_skew_selects; + + let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits); + + let per_point = + pp_field_ops + pp_constants + pp_selects + per_point_overhead(half_bits, n, sr_witnesses); + + let shared_constants = 3; // gen_x, gen_y, zero + + // --- Point accumulation --- + let pa_cost = field_op_witnesses(1, 5, 3, 1, n, false); + let accum = n_points * 
(pa_cost + 2 * n) + + n_points.saturating_sub(1) + + pa_cost + + 4 * n + + 2 * n + + 2 + + if n > 1 { 2 } else { 0 }; + + // --- Range checks --- + let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new(); + add_field_op_range_checks( + n_points * pp_add, + n_points * pp_sub, + n_points * pp_mul, + n_points * pp_inv, + n, + limb_bits, + curve_modulus_bits, + false, + &mut rc_map, + ); + if n > 1 { + *rc_map.entry(limb_bits).or_default() += n_points * 4 * n; + } + for (&bits, &count) in &sr_range_checks { + *rc_map.entry(bits).or_default() += n_points * count; + } + add_field_op_range_checks( + n_points + 1, + (n_points + 1) * 5, + (n_points + 1) * 3, + n_points + 1, + n, + limb_bits, + curve_modulus_bits, + false, + &mut rc_map, + ); + + let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); + + n_points * per_point + shared_constants + accum + range_check_cost +} + +// --------------------------------------------------------------------------- +// Native-field cost +// --------------------------------------------------------------------------- + +/// Native-field MSM cost: hint-verified EC ops with signed-bit wNAF (w=1). +/// +/// The native path uses prover hints verified via raw R1CS constraints: +/// - `point_double_verified_native`: 4W (3 hint + 1 product) +/// - `point_add_verified_native`: 3W (3 hint) +/// - `verify_on_curve_native`: 2W (2 products) +/// - No multi-limb arithmetic → zero EC-related range checks +/// +/// Uses per-point accumulators: each point has its own doubling chain and +/// identity check for soundness. 
+fn calculate_msm_witness_cost_native( + native_field_bits: u32, + n_points: usize, + scalar_bits: usize, +) -> usize { + let half_bits = (scalar_bits + 1) / 2; + + let on_curve = 4; // 2 × verify_on_curve_native (2W each) + let y_negate = 6; // 2 × 3W (neg_y, y_eff, neg_y_eff) + let (sr_wit, sr_rc) = scalar_relation_cost(native_field_bits, scalar_bits); + + // Per point per bit: 4W (double) + 2×(1W select + 3W add) = 12W + let ec_loop_pp = half_bits * 12; + // Skew correction: 2 branches × (3W add + 2W select) = 10W per point + let skew_pp = 10; + // Offset constants per point + let offset_pp = 2; + + let per_point = on_curve + y_negate + + 2 * (half_bits + 1) // scalar bit decomposition + + DETECT_SKIP_WIT + SANITIZE_WIT + EC_HINT_WIT + GLV_HINT_WIT + + sr_wit + + ec_loop_pp + skew_pp + offset_pp; + + let shared_constants = 3; // gen_x, gen_y, zero + + let accum = 2 // initial acc constants + + n_points * 5 // add(3W) + skip_select(2W) + + n_points.saturating_sub(1) // all_skipped products + + 10; // offset sub: 3 const + 2 sel + 3 add + 2 mask + + // Range checks (only from scalar relation for native) + let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new(); + for (&bits, &count) in &sr_rc { + *rc_map.entry(bits).or_default() += n_points * count; + } + let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); + + n_points * per_point + shared_constants + accum + range_check_cost +} + +// --------------------------------------------------------------------------- +// Parameter search +// --------------------------------------------------------------------------- + +/// Picks the widest limb size for scalar-relation multi-limb arithmetic that +/// fits inside the native field without overflow. +/// +/// For BN254 (254-bit native field, ~254-bit order): N=3 @ 85-bit limbs. +/// For small curves where half_scalar × full_scalar fits natively: N=1. 
+pub(super) fn scalar_relation_limb_bits(native_field_bits: u32, order_bits: usize) -> u32 { + let half_bits = (order_bits + 1) / 2; + + // N=1 is valid only if mul product fits in the native field. + if half_bits + order_bits < native_field_bits as usize { + return order_bits as u32; + } + + for n in 2..=super::MAX_LIMBS { + let lb = ((order_bits + n - 1) / n) as u32; + if column_equation_fits_native_field(native_field_bits, lb, n) { + return lb; + } + } + + panic!("native field too small for scalar relation verification"); +} + +/// Check whether schoolbook column equation values fit in the native field. +/// +/// The maximum integer value in any column equation is bounded by +/// `2^(2W + ceil(log2(N)) + 3)` where W = limb_bits, N = num_limbs. +/// This must be less than the native field modulus (~`2^native_field_bits`). +pub fn column_equation_fits_native_field( + native_field_bits: u32, + limb_bits: u32, + num_limbs: usize, +) -> bool { + if num_limbs <= 1 { + return true; + } + let ceil_log2_n = (num_limbs as f64).log2().ceil() as u32; + 2 * limb_bits + ceil_log2_n + 3 < native_field_bits +} + +/// Search for optimal (limb_bits, window_size) minimizing witness cost. +/// +/// Searches limb_bits ∈ \[8..max\] and window_size ∈ \[2..8\]. +/// Each candidate is checked for column equation soundness. +pub fn get_optimal_msm_params( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, + is_native: bool, +) -> (u32, usize) { + if is_native { + // Native path uses signed-bit wNAF (w=1), no limb decomposition. + // Window size is unused; return a default. 
+ let cost = calculate_msm_witness_cost_native(native_field_bits, n_points, scalar_bits); + let _ = cost; // cost is the same regardless of window_size + return (native_field_bits, 4); + } + + let max_limb_bits = (native_field_bits.saturating_sub(4)) / 2; + let mut best_cost = usize::MAX; + let mut best_limb_bits = max_limb_bits.min(86); + let mut best_window = 4; + + for lb in 8..=max_limb_bits { + let num_limbs = ceil_div(curve_modulus_bits as usize, lb as usize); + if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { + continue; + } + for ws in 2..=8usize { + let cost = calculate_msm_witness_cost( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + ws, + lb, + false, + ); + if cost < best_cost { + best_cost = cost; + best_limb_bits = lb; + best_window = ws; + } + } + } + + (best_limb_bits, best_window) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_optimal_params_bn254_native() { + let (limb_bits, window_size) = get_optimal_msm_params(254, 254, 1, 256, true); + assert_eq!(limb_bits, 254); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_optimal_params_secp256r1() { + let (limb_bits, window_size) = get_optimal_msm_params(254, 256, 1, 256, false); + let num_limbs = ceil_div(256, limb_bits as usize); + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_optimal_params_goldilocks() { + let (limb_bits, window_size) = get_optimal_msm_params(254, 64, 1, 64, false); + let num_limbs = ceil_div(64, limb_bits as usize); + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_column_equation_soundness_boundary() { + 
assert!(column_equation_fits_native_field(254, 124, 3)); + assert!(!column_equation_fits_native_field(254, 125, 3)); + assert!(!column_equation_fits_native_field(254, 126, 3)); + } + + #[test] + fn test_secp256r1_limb_bits_not_126() { + let (limb_bits, _) = get_optimal_msm_params(254, 256, 1, 256, false); + assert!( + limb_bits <= 124, + "secp256r1 limb_bits={limb_bits} exceeds safe maximum 124" + ); + } + + #[test] + fn test_scalar_relation_cost_grumpkin() { + let (sr, rc) = scalar_relation_cost(254, 256); + assert!(sr > 50 && sr < 200, "unexpected scalar_relation={sr}"); + let total_rc: usize = rc.values().sum(); + assert!(total_rc > 30, "too few range checks: {total_rc}"); + assert!(total_rc < 200, "too many range checks: {total_rc}"); + } + + #[test] + fn test_scalar_relation_cost_small_curve() { + let (sr, _) = scalar_relation_cost(254, 64); + assert!( + sr < 100, + "64-bit curve scalar_relation={sr} should be < 100" + ); + } + + #[test] + fn test_field_op_witnesses_single_limb() { + // inv_mod_p_single: a_inv(1) + mul_mod_p_single(5) = 6 + assert_eq!(field_op_witnesses(0, 0, 0, 1, 1, false), 6); + // add_mod_p_single: 5 + assert_eq!(field_op_witnesses(1, 0, 0, 0, 1, false), 5); + } + + #[test] + fn test_estimate_range_check_cost_basic() { + use crate::range_check::estimate_range_check_cost; + + assert_eq!(estimate_range_check_cost(&BTreeMap::new()), 0); + + let mut checks = BTreeMap::new(); + checks.insert(8u32, 100usize); + let cost = estimate_range_check_cost(&checks); + assert!(cost > 0, "expected nonzero cost for 100 8-bit checks"); + } +} diff --git a/provekit/r1cs-compiler/src/msm/curve/mod.rs b/provekit/r1cs-compiler/src/msm/curve/mod.rs new file mode 100644 index 000000000..3f2c75d20 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/curve/mod.rs @@ -0,0 +1,458 @@ +use { + ark_ff::{Field, PrimeField}, + provekit_common::FieldElement, +}; + +mod u256_arith; + +pub struct CurveParams { + pub field_modulus_p: [u64; 4], + pub curve_order_n: [u64; 4], + 
pub curve_a: [u64; 4], + pub curve_b: [u64; 4], + pub generator: ([u64; 4], [u64; 4]), + pub offset_point: ([u64; 4], [u64; 4]), +} + +impl CurveParams { + /// Decompose the field modulus p into `num_limbs` limbs of `limb_bits` + /// width each. + pub fn p_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> { + decompose_to_limbs(&self.field_modulus_p, limb_bits, num_limbs) + } + + /// Decompose (p - 1) into `num_limbs` limbs of `limb_bits` width each. + pub fn p_minus_1_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> { + let p_minus_1 = sub_one_u64_4(&self.field_modulus_p); + decompose_to_limbs(&p_minus_1, limb_bits, num_limbs) + } + + /// Decompose the curve parameter `a` into `num_limbs` limbs of `limb_bits` + /// width. + pub fn curve_a_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> { + decompose_to_limbs(&self.curve_a, limb_bits, num_limbs) + } + + pub fn curve_b_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> { + decompose_to_limbs(&self.curve_b, limb_bits, num_limbs) + } + + /// Number of bits in the field modulus. + pub fn modulus_bits(&self) -> u32 { + if self.is_native_field() { + FieldElement::MODULUS_BIT_SIZE + } else { + let p = &self.field_modulus_p; + for i in (0..4).rev() { + if p[i] != 0 { + return (i as u32) * 64 + (64 - p[i].leading_zeros()); + } + } + 0 + } + } + + /// Returns true if the curve's base field modulus equals the native BN254 + /// scalar field modulus. + pub fn is_native_field(&self) -> bool { + let native_mod = FieldElement::MODULUS; + self.field_modulus_p == native_mod.0 + } + + /// Convert modulus to a native field element (only valid when p < native + /// modulus). + pub fn p_native_fe(&self) -> FieldElement { + curve_native_point_fe(&self.field_modulus_p) + } + + /// Decompose the curve order n into `num_limbs` limbs of `limb_bits` width + /// each.
+ pub fn curve_order_n_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> { + decompose_to_limbs(&self.curve_order_n, limb_bits, num_limbs) + } + + /// Decompose (curve_order_n - 1) into `num_limbs` limbs of `limb_bits` + /// width each. + pub fn curve_order_n_minus_1_limbs( + &self, + limb_bits: u32, + num_limbs: usize, + ) -> Vec<FieldElement> { + let n_minus_1 = sub_one_u64_4(&self.curve_order_n); + decompose_to_limbs(&n_minus_1, limb_bits, num_limbs) + } + + /// Number of bits in the curve order n. + pub fn curve_order_bits(&self) -> u32 { + // Compute bit length directly from raw limbs to avoid reduction + // mod the native field (curve_order_n may exceed the native modulus). + let n = &self.curve_order_n; + for i in (0..4).rev() { + if n[i] != 0 { + return (i as u32) * 64 + (64 - n[i].leading_zeros()); + } + } + 0 + } + + /// Number of bits for the GLV half-scalar: `ceil(order_bits / 2)`. + /// This determines the bit width of the sub-scalars s1, s2 from half-GCD. + pub fn glv_half_bits(&self) -> u32 { + (self.curve_order_bits() + 1) / 2 + } + + /// Decompose the generator x-coordinate into limbs. + pub fn generator_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> { + decompose_to_limbs(&self.generator.0, limb_bits, num_limbs) + } + + /// Decompose the offset point x-coordinate into limbs. + pub fn offset_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> { + decompose_to_limbs(&self.offset_point.0, limb_bits, num_limbs) + } + + /// Decompose the offset point y-coordinate into limbs. + pub fn offset_y_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> { + decompose_to_limbs(&self.offset_point.1, limb_bits, num_limbs) + } + + /// Compute `[2^n_doublings] * offset_point` on the curve (compile-time + /// only). + /// + /// Used to compute the accumulated offset after the scalar_mul_glv loop: + /// since the accumulator starts at R and gets doubled n times total, the + /// offset to subtract is `[2^n]*R`, not just `R`.
+ pub fn accumulated_offset(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + if self.is_native_field() { + self.accumulated_offset_native(n_doublings) + } else { + self.accumulated_offset_generic(n_doublings) + } + } + + /// Compute accumulated offset using FieldElement arithmetic (native field). + fn accumulated_offset_native(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + let mut x = curve_native_point_fe(&self.offset_point.0); + let mut y = curve_native_point_fe(&self.offset_point.1); + let a = curve_native_point_fe(&self.curve_a); + + for _ in 0..n_doublings { + let x_sq = x * x; + let num = x_sq + x_sq + x_sq + a; + let denom_inv = (y + y).inverse().unwrap(); + let lambda = num * denom_inv; + let x3 = lambda * lambda - x - x; + let y3 = lambda * (x - x3) - y; + x = x3; + y = y3; + } + + (x.into_bigint().0, y.into_bigint().0) + } + + /// Compute accumulated offset using generic 256-bit arithmetic (non-native + /// field). + fn accumulated_offset_generic(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + let p = &self.field_modulus_p; + let mut x = self.offset_point.0; + let mut y = self.offset_point.1; + let a = &self.curve_a; + + for _ in 0..n_doublings { + let (x3, y3) = u256_arith::ec_point_double(&x, &y, a, p); + x = x3; + y = y3; + } + + (x, y) + } +} + +/// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width each, +/// returned as FieldElements. +pub fn decompose_to_limbs(val: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> { + // Special case: when a single limb needs > 128 bits, FieldElement::from(u128) + // would truncate. Use from_sign_and_limbs to preserve the full value.
+ if num_limbs == 1 && limb_bits > 128 { + return vec![curve_native_point_fe(val)]; + } + + let mask: u128 = if limb_bits >= 128 { + u128::MAX + } else { + (1u128 << limb_bits) - 1 + }; + let mut result = vec![FieldElement::from(0u64); num_limbs]; + let mut remaining = *val; + for item in result.iter_mut() { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + *item = FieldElement::from(lo & mask); + // Shift remaining right by limb_bits + if limb_bits >= 256 { + remaining = [0; 4]; + } else { + let mut shifted = [0u64; 4]; + let word_shift = (limb_bits / 64) as usize; + let bit_shift = limb_bits % 64; + for i in 0..4 { + if i + word_shift < 4 { + shifted[i] = remaining[i + word_shift] >> bit_shift; + if bit_shift > 0 && i + word_shift + 1 < 4 { + shifted[i] |= remaining[i + word_shift + 1] << (64 - bit_shift); + } + } + } + remaining = shifted; + } + } + result +} + +/// Subtract 1 from a [u64; 4] value. +fn sub_one_u64_4(val: &[u64; 4]) -> [u64; 4] { + let mut result = *val; + for limb in result.iter_mut() { + if *limb > 0 { + *limb -= 1; + return result; + } + *limb = u64::MAX; // borrow + } + result +} + +/// Converts a 256-bit value ([u64; 4]) into a single native field element. +pub fn curve_native_point_fe(val: &[u64; 4]) -> FieldElement { + FieldElement::from_sign_and_limbs(true, val) +} + +/// Negate a field element: compute `-val mod p` (i.e., `p - val`). +/// Returns `[0; 4]` when `val` is zero. +pub fn negate_field_element(val: &[u64; 4], modulus: &[u64; 4]) -> [u64; 4] { + if *val == [0u64; 4] { + return [0u64; 4]; + } + // val is in [1, p-1], so p - val is in [1, p-1] — no borrow. + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = modulus[i].overflowing_sub(val[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + debug_assert!(!borrow, "negate_field_element: val >= modulus"); + result +} + +/// Grumpkin curve parameters. 
+/// +/// Grumpkin is a cycle-companion curve for BN254: its base field is the BN254 +/// scalar field, and its order is the BN254 base field order. +/// +/// Equation: y² = x³ − 17 (a = 0, b = −17 mod p) +pub fn grumpkin_params() -> CurveParams { + CurveParams { + // BN254 scalar field modulus + field_modulus_p: [ + 0x43e1f593f0000001_u64, + 0x2833e84879b97091_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + // BN254 base field modulus + curve_order_n: [ + 0x3c208c16d87cfd47_u64, + 0x97816a916871ca8d_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + curve_a: [0; 4], + // b = −17 mod p + curve_b: [ + 0x43e1f593effffff0_u64, + 0x2833e84879b97091_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + // Generator G = (1, sqrt(−16) mod p) + generator: ([1, 0, 0, 0], [ + 0x833fc48d823f272c_u64, + 0x2d270d45f1181294_u64, + 0xcf135e7506a45d63_u64, + 0x0000000000000002_u64, + ]), + // Offset point = [2^128]G (large offset avoids collisions with small multiples of G) + offset_point: ( + [ + 0x626578b496650e95_u64, + 0x8678dcf264df6c01_u64, + 0xf0b3eb7e6d02aba8_u64, + 0x223748a4c4edde75_u64, + ], + [ + 0xb75fb4c26bcd4f35_u64, + 0x4d4ba4d97d5f99d9_u64, + 0xccab35fdbf52368a_u64, + 0x25b41c5f56f8472b_u64, + ], + ), + } +} + +pub fn secp256r1_params() -> CurveParams { + CurveParams { + field_modulus_p: [ + 0xffffffffffffffff_u64, + 0xffffffff_u64, + 0x0_u64, + 0xffffffff00000001_u64, + ], + curve_order_n: [ + 0xf3b9cac2fc632551_u64, + 0xbce6faada7179e84_u64, + 0xffffffffffffffff_u64, + 0xffffffff00000000_u64, + ], + curve_a: [ + 0xfffffffffffffffc_u64, + 0x00000000ffffffff_u64, + 0x0000000000000000_u64, + 0xffffffff00000001_u64, + ], + curve_b: [ + 0x3bce3c3e27d2604b_u64, + 0x651d06b0cc53b0f6_u64, + 0xb3ebbd55769886bc_u64, + 0x5ac635d8aa3a93e7_u64, + ], + generator: ( + [ + 0xf4a13945d898c296_u64, + 0x77037d812deb33a0_u64, + 0xf8bce6e563a440f2_u64, + 0x6b17d1f2e12c4247_u64, + ], + [ + 0xcbb6406837bf51f5_u64, + 0x2bce33576b315ece_u64, + 
0x8ee7eb4a7c0f9e16_u64, + 0x4fe342e2fe1a7f9b_u64, + ], + ), + // Offset point = [2^128]G (large offset avoids collisions with small multiples of G) + offset_point: ( + [ + 0x57c84fc9d789bd85_u64, + 0xfc35ff7dc297eac3_u64, + 0xfb982fd588c6766e_u64, + 0x447d739beedb5e67_u64, + ], + [ + 0x0c7e33c972e25b32_u64, + 0x3d349b95a7fae500_u64, + 0xe12e9d953a4aaff7_u64, + 0x2d4825ab834131ee_u64, + ], + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_offset_point_on_curve_grumpkin() { + let c = grumpkin_params(); + let x = curve_native_point_fe(&c.offset_point.0); + let y = curve_native_point_fe(&c.offset_point.1); + let b = curve_native_point_fe(&c.curve_b); + // Grumpkin: y^2 = x^3 + b (a=0) + assert_eq!(y * y, x * x * x + b, "offset point not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_single_double_grumpkin() { + let c = grumpkin_params(); + let (x4, y4) = c.accumulated_offset(1); + let x = curve_native_point_fe(&x4); + let y = curve_native_point_fe(&y4); + let b = curve_native_point_fe(&c.curve_b); + // Should still be on curve + assert_eq!(y * y, x * x * x + b, "[2]*offset not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_native_vs_generic() { + let c = grumpkin_params(); + // Both paths should give the same result + let native = c.accumulated_offset_native(10); + let generic = c.accumulated_offset_generic(10); + assert_eq!(native, generic, "native vs generic mismatch for n=10"); + } + + #[test] + fn test_accumulated_offset_256_on_curve() { + let c = grumpkin_params(); + let (x, y) = c.accumulated_offset(256); + let xfe = curve_native_point_fe(&x); + let yfe = curve_native_point_fe(&y); + let b = curve_native_point_fe(&c.curve_b); + assert_eq!(yfe * yfe, xfe * xfe * xfe + b, "[2^257]G not on Grumpkin"); + } + + #[test] + fn test_offset_point_on_curve_secp256r1() { + let c = secp256r1_params(); + let p = &c.field_modulus_p; + let x = &c.offset_point.0; + let y = &c.offset_point.1; + let a = &c.curve_a; + let b = 
&c.curve_b; + // y^2 = x^3 + a*x + b (mod p) + let y_sq = u256_arith::mod_mul(y, y, p); + let x_sq = u256_arith::mod_mul(x, x, p); + let x_cubed = u256_arith::mod_mul(&x_sq, x, p); + let ax = u256_arith::mod_mul(a, x, p); + let x3_plus_ax = u256_arith::mod_add(&x_cubed, &ax, p); + let rhs = u256_arith::mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "offset point not on secp256r1"); + } + + #[test] + fn test_accumulated_offset_secp256r1() { + let c = secp256r1_params(); + let p = &c.field_modulus_p; + let a = &c.curve_a; + let b = &c.curve_b; + let (x, y) = c.accumulated_offset(256); + // Verify the accumulated offset is on the curve + let y_sq = u256_arith::mod_mul(&y, &y, p); + let x_sq = u256_arith::mod_mul(&x, &x, p); + let x_cubed = u256_arith::mod_mul(&x_sq, &x, p); + let ax = u256_arith::mod_mul(a, &x, p); + let x3_plus_ax = u256_arith::mod_add(&x_cubed, &ax, p); + let rhs = u256_arith::mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "accumulated offset not on secp256r1"); + } + + #[test] + fn test_fe_roundtrip() { + // Verify from_sign_and_limbs / into_bigint roundtrip + let val: [u64; 4] = [42, 0, 0, 0]; + let fe = curve_native_point_fe(&val); + let back = fe.into_bigint().0; + assert_eq!(val, back, "roundtrip failed for small value"); + + let val2: [u64; 4] = [ + 0x6d8bc688cdbffffe, + 0x19a74caa311e13d4, + 0xddeb49cdaa36306d, + 0x06ce1b0827aafa85, + ]; + let fe2 = curve_native_point_fe(&val2); + let back2 = fe2.into_bigint().0; + assert_eq!(val2, back2, "roundtrip failed for offset x"); + } +} diff --git a/provekit/r1cs-compiler/src/msm/curve/u256_arith.rs b/provekit/r1cs-compiler/src/msm/curve/u256_arith.rs new file mode 100644 index 000000000..e2bcea5e3 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/curve/u256_arith.rs @@ -0,0 +1,176 @@ +//! 256-bit modular arithmetic for compile-time EC point computations. +//! Only used to precompute accumulated offset points; not performance-critical. 
+ +pub(super) type U256 = [u64; 4]; + +/// Returns true if a >= b. +fn gte(a: &U256, b: &U256) -> bool { + for i in (0..4).rev() { + if a[i] > b[i] { + return true; + } + if a[i] < b[i] { + return false; + } + } + true // equal +} + +/// a + b, returns (result, carry). +fn add(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut carry = 0u128; + for i in 0..4 { + carry += a[i] as u128 + b[i] as u128; + result[i] = carry as u64; + carry >>= 64; + } + (result, carry != 0) +} + +/// a - b, returns (result, borrow). +fn sub(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + (result, borrow) +} + +/// (a + b) mod p. +pub fn mod_add(a: &U256, b: &U256, p: &U256) -> U256 { + let (s, overflow) = add(a, b); + if overflow || gte(&s, p) { + sub(&s, p).0 + } else { + s + } +} + +/// (a - b) mod p. +fn mod_sub(a: &U256, b: &U256, p: &U256) -> U256 { + let (d, borrow) = sub(a, b); + if borrow { + add(&d, p).0 + } else { + d + } +} + +/// Schoolbook multiplication producing 512-bit result. +fn mul_wide(a: &U256, b: &U256) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let prod = (a[i] as u128) * (b[j] as u128) + result[i + j] as u128 + carry; + result[i + j] = prod as u64; + carry = prod >> 64; + } + result[i + 4] = result[i + 4].wrapping_add(carry as u64); + } + result +} + +/// Reduce a 512-bit value mod a 256-bit prime using bit-by-bit long +/// division. 
+fn mod_reduce_wide(a: &[u64; 8], p: &U256) -> U256 { + let mut total_bits = 0; + for i in (0..8).rev() { + if a[i] != 0 { + total_bits = i * 64 + (64 - a[i].leading_zeros() as usize); + break; + } + } + if total_bits == 0 { + return [0; 4]; + } + + let mut r = [0u64; 4]; + for bit_idx in (0..total_bits).rev() { + // Left shift r by 1 + let overflow = r[3] >> 63; + for j in (1..4).rev() { + r[j] = (r[j] << 1) | (r[j - 1] >> 63); + } + r[0] <<= 1; + + // Insert current bit of a + let word = bit_idx / 64; + let bit = bit_idx % 64; + r[0] |= (a[word] >> bit) & 1; + + // If r >= p (or overflow from shift), subtract p + if overflow != 0 || gte(&r, p) { + r = sub(&r, p).0; + } + } + r +} + +/// (a * b) mod p. +pub fn mod_mul(a: &U256, b: &U256, p: &U256) -> U256 { + let wide = mul_wide(a, b); + mod_reduce_wide(&wide, p) +} + +/// a^exp mod p using square-and-multiply. +fn mod_pow(base: &U256, exp: &U256, p: &U256) -> U256 { + let mut highest_bit = 0; + for i in (0..4).rev() { + if exp[i] != 0 { + highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return [1, 0, 0, 0]; + } + + let mut result: U256 = [1, 0, 0, 0]; + let mut base = *base; + for bit_idx in 0..highest_bit { + let word = bit_idx / 64; + let bit = bit_idx % 64; + if (exp[word] >> bit) & 1 == 1 { + result = mod_mul(&result, &base, p); + } + base = mod_mul(&base, &base, p); + } + result +} + +/// a^(-1) mod p via Fermat's little theorem: a^(p-2) mod p. +fn mod_inv(a: &U256, p: &U256) -> U256 { + let two: U256 = [2, 0, 0, 0]; + let exp = sub(p, &two).0; + mod_pow(a, &exp, p) +} + +/// EC point doubling on y^2 = x^3 + ax + b. 
+pub fn ec_point_double(x: &U256, y: &U256, a: &U256, p: &U256) -> (U256, U256) { + // lambda = (3*x^2 + a) / (2*y) + let x_sq = mod_mul(x, x, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let num = mod_add(&three_x_sq, a, p); + let two_y = mod_add(y, y, p); + let denom_inv = mod_inv(&two_y, p); + let lambda = mod_mul(&num, &denom_inv, p); + + // x3 = lambda^2 - 2*x + let lambda_sq = mod_mul(&lambda, &lambda, p); + let two_x = mod_add(x, x, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + // y3 = lambda * (x - x3) - y + let x_minus_x3 = mod_sub(x, &x3, p); + let lambda_dx = mod_mul(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, y, p); + + (x3, y3) +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/generic.rs b/provekit/r1cs-compiler/src/msm/ec_points/generic.rs new file mode 100644 index 000000000..c0236fe8f --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/generic.rs @@ -0,0 +1,104 @@ +//! Generic point operations using `MultiLimbOps` field arithmetic. +//! +//! These work for any field (native or non-native) by going through the +//! `MultiLimbOps` abstraction layer. + +use crate::msm::{multi_limb_ops::MultiLimbOps, Limbs}; + +/// Generic point doubling on y^2 = x^3 + ax + b. +/// +/// Given P = (x1, y1), computes 2P = (x3, y3): +/// lambda = (3 * x1^2 + a) / (2 * y1) +/// x3 = lambda^2 - 2 * x1 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge case — y1 = 0 (point of order 2): +/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. +/// The result should be the point at infinity (identity element). +/// This function does NOT handle that case — the constraint system will +/// be unsatisfiable if y1 = 0 (the inverse verification will fail to +/// verify 0 * inv = 1 mod p). The caller must check y1 = 0 using +/// compute_is_zero and conditionally select the point-at-infinity +/// result before calling this function. 
+pub fn point_double(ops: &mut MultiLimbOps, x1: Limbs, y1: Limbs) -> (Limbs, Limbs) { + let a = ops.curve_a(); + + // Computing numerator = 3 * x1^2 + a + let x1_sq = ops.mul(x1, x1); + let two_x1_sq = ops.add(x1_sq, x1_sq); + let three_x1_sq = ops.add(two_x1_sq, x1_sq); + let numerator = ops.add(three_x1_sq, a); + + // Computing denominator = 2 * y1 + let denominator = ops.add(y1, y1); + + // Computing lambda = numerator * denominator^(-1) + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - 2 * x1 + let lambda_sq = ops.mul(lambda, lambda); + let two_x1 = ops.add(x1, x1); + let x3 = ops.sub(lambda_sq, two_x1); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Generic point addition on y^2 = x^3 + ax + b. +/// +/// Given P1 = (x1, y1) and P2 = (x2, y2), computes P1 + P2 = (x3, y3): +/// lambda = (y2 - y1) / (x2 - x1) +/// x3 = lambda^2 - x1 - x2 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge cases — x1 = x2: +/// When x1 = x2, the denominator (x2 - x1) = 0 and the inverse does +/// not exist. This covers two cases: +/// - P1 = P2 (same point): use `point_double` instead. +/// - P1 = -P2 (y1 = -y2): the result is the point at infinity. +/// This function does NOT handle either case — the constraint system +/// will be unsatisfiable if x1 = x2. The caller must detect this +/// and branch accordingly. 
+pub fn point_add( + ops: &mut MultiLimbOps, + x1: Limbs, + y1: Limbs, + x2: Limbs, + y2: Limbs, +) -> (Limbs, Limbs) { + // Computing lambda = (y2 - y1) / (x2 - x1) + let numerator = ops.sub(y2, y1); + let denominator = ops.sub(x2, x1); + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - x1 - x2 + let lambda_sq = ops.mul(lambda, lambda); + let x1_plus_x2 = ops.add(x1, x2); + let x3 = ops.sub(lambda_sq, x1_plus_x2); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Conditional point select without boolean constraint on `flag`. +/// Caller must ensure `flag` is already constrained boolean. +pub fn point_select_unchecked( + ops: &mut MultiLimbOps, + flag: usize, + on_false: (Limbs, Limbs), + on_true: (Limbs, Limbs), +) -> (Limbs, Limbs) { + let x = ops.select_unchecked(flag, on_false.0, on_true.0); + let y = ops.select_unchecked(flag, on_false.1, on_true.1); + (x, y) +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs b/provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs new file mode 100644 index 000000000..b7ddf20a6 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs @@ -0,0 +1,161 @@ +//! Native-field hint-verified EC operations. +//! +//! These operate on single native field element witnesses (no multi-limb). +//! Each EC op allocates a hint for (lambda, x3, y3) and verifies via raw +//! R1CS constraints, eliminating expensive field inversions from the circuit. + +use { + crate::{msm::curve::CurveParams, noir_to_r1cs::NoirToR1CSCompiler}, + ark_ff::{Field, PrimeField}, + provekit_common::{witness::WitnessBuilder, FieldElement}, +}; + +/// Hint-verified point doubling for native field. +/// +/// Allocates EcDoubleHint → (lambda, x3, y3) = 3W hint. +/// Verification via 4 R1CS constraints: +/// 1. 
x_sq = px * px (1W+1C via add_product) +/// 2. lambda · 2·py = 3·x_sq + a (1C) +/// 3. lambda² = x3 + 2·px (1C) +/// 4. lambda · (px - x3) = y3 + py (1C) +/// +/// Total: 4W + 4C. +pub fn point_double_verified_native( + compiler: &mut NoirToR1CSCompiler, + px: usize, + py: usize, + curve: &CurveParams, +) -> (usize, usize) { + // Allocate hint witnesses + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcDoubleHint { + output_start: hint_start, + px, + py, + curve_a: curve.curve_a, + field_modulus_p: curve.field_modulus_p, + }); + let lambda = hint_start; + let x3 = hint_start + 1; + let y3 = hint_start + 2; + + // x_sq = px * px (1W + 1C) + let x_sq = compiler.add_product(px, px); + + // Constraint: lambda * (2 * py) = 3 * x_sq + a + // A = [lambda], B = [2*py], C = [3*x_sq + a_const] + let a_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_a)).unwrap(); + let three = FieldElement::from(3u64); + let two = FieldElement::from(2u64); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, lambda)], &[(two, py)], &[ + (three, x_sq), + (a_fe, compiler.witness_one()), + ]); + + // Constraint: lambda^2 = x3 + 2*px + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x3), (two, px)], + ); + + // Constraint: lambda * (px - x3) = y3 + py + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, px), (-FieldElement::ONE, x3)], + &[(FieldElement::ONE, y3), (FieldElement::ONE, py)], + ); + + (x3, y3) +} + +/// Hint-verified point addition for native field. +/// +/// Allocates EcAddHint → (lambda, x3, y3) = 3W. +/// Verification constraints (3C): +/// 1. lambda * (x2 - x1) = y2 - y1 (1C raw) +/// 2. lambda^2 = x3 + x1 + x2 (1C raw) +/// 3. lambda * (x1 - x3) = y3 + y1 (1C raw) +/// +/// Total: 3W + 3C. 
+pub fn point_add_verified_native( + compiler: &mut NoirToR1CSCompiler, + x1: usize, + y1: usize, + x2: usize, + y2: usize, + curve: &CurveParams, +) -> (usize, usize) { + // Allocate hint witnesses + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcAddHint { + output_start: hint_start, + x1, + y1, + x2, + y2, + field_modulus_p: curve.field_modulus_p, + }); + let lambda = hint_start; + let x3 = hint_start + 1; + let y3 = hint_start + 2; + + // Constraint: lambda * (x2 - x1) = y2 - y1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x2), (-FieldElement::ONE, x1)], + &[(FieldElement::ONE, y2), (-FieldElement::ONE, y1)], + ); + + // Constraint: lambda^2 = x3 + x1 + x2 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, lambda)], + &[ + (FieldElement::ONE, x3), + (FieldElement::ONE, x1), + (FieldElement::ONE, x2), + ], + ); + + // Constraint: lambda * (x1 - x3) = y3 + y1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x1), (-FieldElement::ONE, x3)], + &[(FieldElement::ONE, y3), (FieldElement::ONE, y1)], + ); + + (x3, y3) +} + +/// On-curve check for native field: y² = x³ + a·x + b. +/// +/// Verification via 3 R1CS constraints: +/// 1. x_sq = x · x (1W+1C via add_product) +/// 2. x_cu = x_sq · x (1W+1C via add_product) +/// 3. y · y = x_cu + a·x + b (1C) +/// +/// Total: 2W + 3C. 
+pub fn verify_on_curve_native( + compiler: &mut NoirToR1CSCompiler, + x: usize, + y: usize, + curve: &CurveParams, +) { + let x_sq = compiler.add_product(x, x); + let x_cu = compiler.add_product(x_sq, x); + + let a_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_a)).unwrap(); + let b_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_b)).unwrap(); + + // y * y = x_cu + a*x + b + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, y)], &[(FieldElement::ONE, y)], &[ + (FieldElement::ONE, x_cu), + (a_fe, x), + (b_fe, compiler.witness_one()), + ]); +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs b/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs new file mode 100644 index 000000000..836e98339 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs @@ -0,0 +1,651 @@ +//! Non-native hint-verified EC operations (multi-limb schoolbook). +//! +//! These replace the step-by-step MultiLimbOps chain with prover hints verified +//! via schoolbook column equations. Each bilinear mod-p equation is checked by: +//! 1. Pre-computing product witnesses a\[i\]*b\[j\] +//! 2. Column equations: Σ coeff·prod\[k\] + linear\[k\] + carry_in + offset = Σ +//! p\[i\]*q\[j\] + carry_out * W +//! Since p is constant, p\[i\]*q\[j\] terms are linear in q (no product +//! witness). + +use { + crate::{ + msm::{multi_limb_arith::less_than_p_check_multi, multi_limb_ops::MultiLimbParams, Limbs}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{Field, PrimeField}, + provekit_common::{ + witness::{NonNativeEcOp, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// Collect witness indices from `start..start+len`. +fn witness_range(start: usize, len: usize) -> Vec<usize> { + (start..start + len).collect() +} + +/// Allocate N×N product witnesses for `a\[i\]*b\[j\]`.
+fn make_products(compiler: &mut NoirToR1CSCompiler, a: &[usize], b: &[usize]) -> Vec<Vec<usize>> { + let n = a.len(); + debug_assert_eq!(n, b.len()); + let mut prods = vec![vec![0usize; n]; n]; + for i in 0..n { + for j in 0..n { + prods[i][j] = compiler.add_product(a[i], b[j]); + } + } + prods +} + +/// Allocate pinned constant witnesses from pre-decomposed `FieldElement` limbs. +fn allocate_pinned_constant_limbs( + compiler: &mut NoirToR1CSCompiler, + limb_values: &[FieldElement], +) -> Vec<usize> { + limb_values + .iter() + .map(|&val| { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(w, val), + )); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(val, compiler.witness_one())], + ); + w + }) + .collect() +} + +/// Range-check limb witnesses at `limb_bits` and carry witnesses at +/// `carry_range_bits`. +fn range_check_limbs_and_carries( + range_checks: &mut BTreeMap<u32, Vec<usize>>, + limb_vecs: &[&[usize]], + carry_vecs: &[&[usize]], + limb_bits: u32, + carry_range_bits: u32, +) { + for limbs in limb_vecs { + for &w in *limbs { + range_checks.entry(limb_bits).or_default().push(w); + } + } + for carries in carry_vecs { + for &c in *carries { + range_checks.entry(carry_range_bits).or_default().push(c); + } + } +} + +/// Convert `Vec<usize>` to `Limbs` and do a less-than-p check. +fn less_than_p_check_vec( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap<u32, Vec<usize>>, + v: &[usize], + params: &MultiLimbParams, +) { + let n = v.len(); + let mut limbs = Limbs::new(n); + for i in 0..n { + limbs[i] = v[i]; + } + less_than_p_check_multi( + compiler, + range_checks, + limbs, + &params.p_minus_1_limbs, + params.two_pow_w, + params.limb_bits, + ); +} + +/// Compute carry range bits for hint-verified column equations.
+fn carry_range_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + limb_bits + extra_bits + 1 +} + +/// Soundness check: verify that merged column equations fit the native field. +fn check_column_equation_fits(limb_bits: u32, max_coeff_sum: u64, n: usize, op_name: &str) { + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + let max_bits = 2 * limb_bits + extra_bits + 1; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "{op_name} column equation overflow: limb_bits={limb_bits}, n={n}, needs {max_bits} bits", + ); +} + +/// Merge terms with the same witness index by summing their coefficients. +fn merge_terms(terms: &[(FieldElement, usize)]) -> Vec<(FieldElement, usize)> { + use {ark_ff::Zero, std::collections::HashMap}; + let mut map: HashMap<usize, FieldElement> = HashMap::new(); + for &(coeff, idx) in terms { + *map.entry(idx).or_insert_with(FieldElement::zero) += coeff; + } + let mut result: Vec<(FieldElement, usize)> = map.into_iter().map(|(idx, c)| (c, idx)).collect(); + result.sort_by_key(|&(_, idx)| idx); + result +} + +/// Emit schoolbook column equations for a merged verification equation. +/// +/// Verifies: Σ (coeff_i × A_i ⊗ B_i) + Σ linear_k + Σ p\[i\]*q_neg\[j\] +/// = Σ p\[i\]*q_pos\[j\] + carry_chain (as integers) +/// +/// `product_sets`: each (products_2d, coefficient) where products_2d\[i\]\[j\] +/// is the witness index for a\[i\]*b\[j\]. +/// `linear_limbs`: each (limb_witnesses, coefficient) for non-product terms. +/// `q_pos_witnesses`, `q_neg_witnesses`: split quotient limbs (N entries each). +/// `carry_witnesses`: unsigned-offset carry witnesses (2N-2 entries).
+fn emit_schoolbook_column_equations( + compiler: &mut NoirToR1CSCompiler, + product_sets: &[(&[Vec], FieldElement)], // (products[i][j], coeff) + linear_limbs: &[(&[usize], FieldElement)], // (limb_witnesses, coeff) + q_pos_witnesses: &[usize], + q_neg_witnesses: &[usize], + carry_witnesses: &[usize], + p_limbs: &[FieldElement], + n: usize, + limb_bits: u32, + max_coeff_sum: u64, +) { + let w1 = compiler.witness_one(); + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + + // Carry offset scaled for the merged equation's larger coefficients + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + let carry_offset_bits = limb_bits + extra_bits; + let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); + let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); + let offset_w_minus_carry = offset_w - carry_offset_fe; + + let num_columns = 2 * n - 1; + + for k in 0..num_columns { + // LHS: Σ coeff * products[i][j] for i+j=k + Σ p[i]*q_neg[j] + carry_in + offset + let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + + for &(products, coeff) in product_sets { + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((coeff, products[i][j_val as usize])); + } + } + } + + // Add linear terms (for k < N only, since linear_limbs are N-length) + for &(limbs, coeff) in linear_limbs { + if k < limbs.len() { + lhs_terms.push((coeff, limbs[k])); + } + } + + // Add p*q_neg on the LHS (positive side) + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((p_limbs[i], q_neg_witnesses[j_val as usize])); + } + } + + // Add carry_in and offset + if k > 0 { + lhs_terms.push((FieldElement::ONE, carry_witnesses[k - 1])); + lhs_terms.push((offset_w_minus_carry, w1)); + } else { + lhs_terms.push((offset_w, w1)); + } + + // RHS: Σ p[i]*q_pos[j] for i+j=k + carry_out * W 
(or offset at last column) + let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + rhs_terms.push((p_limbs[i], q_pos_witnesses[j_val as usize])); + } + } + + if k < num_columns - 1 { + rhs_terms.push((two_pow_w, carry_witnesses[k])); + } else { + // Last column: balance with offset_w (no outgoing carry) + rhs_terms.push((offset_w, w1)); + } + + // Merge terms with the same witness index (products may share cached witnesses) + let lhs_merged = merge_terms(&lhs_terms); + let rhs_merged = merge_terms(&rhs_terms); + compiler + .r1cs + .add_constraint(&lhs_merged, &[(FieldElement::ONE, w1)], &rhs_merged); + } +} + +/// Helper to convert witness Vec to Limbs. +fn vec_to_limbs(v: &[usize]) -> Limbs { + let n = v.len(); + let mut limbs = Limbs::new(n); + for i in 0..n { + limbs[i] = v[i]; + } + limbs +} + +/// Hint-verified on-curve check for non-native field (multi-limb). +/// +/// Verifies y² = x³ + ax + b (mod p) via two schoolbook column equations: +/// Eq1: x·x - x_sq = q1·p (x_sq correctness) +/// Eq2: y·y - x_sq·x - a·x - b = q2·p (on-curve) +/// +/// Total: (7N-4)W hint + 3N² products (or 2N² when a=0) + 2×(2N-1) column +/// constraints + 1 less-than-p check. 
+pub fn verify_on_curve_non_native( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + px: Limbs, + py: Limbs, + params: &MultiLimbParams, +) { + let n = params.num_limbs; + assert!(n >= 2, "hint-verified on-curve check requires n >= 2"); + + let a_is_zero = params.curve_a_raw.iter().all(|&v| v == 0); + + let max_coeff_sum: u64 = if a_is_zero { + 4 + 2 * n as u64 + } else { + 5 + 2 * n as u64 + }; + check_column_equation_fits(params.limb_bits, max_coeff_sum, n, "On-curve"); + + // Allocate hint + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { + output_start: os, + op: NonNativeEcOp::OnCurve, + inputs: vec![px.as_slice()[..n].to_vec(), py.as_slice()[..n].to_vec()], + curve_a: params.curve_a_raw, + curve_b: params.curve_b_raw, + field_modulus_p: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + + // Parse hint layout: [x_sq(N), q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2)] + // Total: 9N-4 + let x_sq = witness_range(os, n); + let q1_pos = witness_range(os + n, n); + let q1_neg = witness_range(os + 2 * n, n); + let c1 = witness_range(os + 3 * n, 2 * n - 2); + let q2_pos = witness_range(os + 5 * n - 2, n); + let q2_neg = witness_range(os + 6 * n - 2, n); + let c2 = witness_range(os + 7 * n - 2, 2 * n - 2); + + // Eq1: px·px - x_sq = q1·p + let prod_px_px = make_products(compiler, &px.as_slice()[..n], &px.as_slice()[..n]); + + let max_coeff_eq1: u64 = 1 + 1 + 2 * n as u64; + emit_schoolbook_column_equations( + compiler, + &[(&prod_px_px, FieldElement::ONE)], + &[(&x_sq, -FieldElement::ONE)], + &q1_pos, + &q1_neg, + &c1, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff_eq1, + ); + + // Eq2: py·py - x_sq·px - a·px - b = q2·p + let prod_py_py = make_products(compiler, &py.as_slice()[..n], &py.as_slice()[..n]); + let prod_xsq_px = make_products(compiler, &x_sq, &px.as_slice()[..n]); + let b_limbs = allocate_pinned_constant_limbs(compiler, 
¶ms.curve_b_limbs[..n]); + + if a_is_zero { + let max_coeff_eq2: u64 = 1 + 1 + 1 + 2 * n as u64; + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_py_py, FieldElement::ONE), + (&prod_xsq_px, -FieldElement::ONE), + ], + &[(&b_limbs, -FieldElement::ONE)], + &q2_pos, + &q2_neg, + &c2, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff_eq2, + ); + } else { + let a_limbs = allocate_pinned_constant_limbs(compiler, ¶ms.curve_a_limbs[..n]); + let prod_a_px = make_products(compiler, &a_limbs, &px.as_slice()[..n]); + + let max_coeff_eq2: u64 = 1 + 1 + 1 + 1 + 2 * n as u64; + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_py_py, FieldElement::ONE), + (&prod_xsq_px, -FieldElement::ONE), + (&prod_a_px, -FieldElement::ONE), + ], + &[(&b_limbs, -FieldElement::ONE)], + &q2_pos, + &q2_neg, + &c2, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff_eq2, + ); + } + + // Range checks on hint outputs + let crb = carry_range_bits(params.limb_bits, max_coeff_sum, n); + range_check_limbs_and_carries( + range_checks, + &[&x_sq, &q1_pos, &q1_neg, &q2_pos, &q2_neg], + &[&c1, &c2], + params.limb_bits, + crb, + ); + + // Less-than-p check for x_sq + less_than_p_check_vec(compiler, range_checks, &x_sq, params); +} + +/// Hint-verified point doubling for non-native field (multi-limb). +/// +/// Allocates NonNativeEcDoubleHint → (lambda, x3, y3, q1, c1, q2, c2, q3, c3). +/// Verifies via schoolbook column equations on 3 EC equations: +/// Eq1: 2·lambda·py - 3·px² - a = q1·p (2N² products: lam·py, px·px) +/// Eq2: lambda² - x3 - 2·px = q2·p (1N² products: lam·lam) +/// Eq3: lambda·(px - x3) - y3 - py = q3·p (2N² products: lam·px, lam·x3) +/// +/// Total: (12N-6)W hint + 5N²+N products + 3×(2N-1) column constraints +/// + 3 less-than-p checks. 
+pub fn point_double_verified_non_native( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + px: Limbs, + py: Limbs, + params: &MultiLimbParams, +) -> (Limbs, Limbs) { + let n = params.num_limbs; + assert!(n >= 2, "hint-verified non-native requires n >= 2"); + + let max_coeff_sum: u64 = 2 + 3 + 1 + 2 * n as u64; // λy(2) + xx(3) + a(1) + pq_pos(N) + pq_neg(N) + check_column_equation_fits(params.limb_bits, max_coeff_sum, n, "Merged EC double"); + + // Allocate hint + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { + output_start: os, + op: NonNativeEcOp::Double, + inputs: vec![px.as_slice()[..n].to_vec(), py.as_slice()[..n].to_vec()], + curve_a: params.curve_a_raw, + curve_b: [0; 4], // unused for double + field_modulus_p: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + + // Parse hint layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 + let lambda = witness_range(os, n); + let x3 = witness_range(os + n, n); + let y3 = witness_range(os + 2 * n, n); + let q1_pos = witness_range(os + 3 * n, n); + let q1_neg = witness_range(os + 4 * n, n); + let c1 = witness_range(os + 5 * n, 2 * n - 2); + let q2_pos = witness_range(os + 7 * n - 2, n); + let q2_neg = witness_range(os + 8 * n - 2, n); + let c2 = witness_range(os + 9 * n - 2, 2 * n - 2); + let q3_pos = witness_range(os + 11 * n - 4, n); + let q3_neg = witness_range(os + 12 * n - 4, n); + let c3 = witness_range(os + 13 * n - 4, 2 * n - 2); + + let px_s = &px.as_slice()[..n]; + let py_s = &py.as_slice()[..n]; + + // Eq1: 2*lambda*py - 3*px*px - a = q1*p + let prod_lam_py = make_products(compiler, &lambda, py_s); + let prod_px_px = make_products(compiler, px_s, px_s); + let a_limbs = allocate_pinned_constant_limbs(compiler, ¶ms.curve_a_limbs[..n]); + + let max_coeff_eq1: u64 = 2 + 3 + 1 + 2 * n as u64; + 
emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_py, FieldElement::from(2u64)), + (&prod_px_px, -FieldElement::from(3u64)), + ], + &[(&a_limbs, -FieldElement::ONE)], + &q1_pos, + &q1_neg, + &c1, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff_eq1, + ); + + // Eq2: lambda² - x3 - 2*px = q2*p + let prod_lam_lam = make_products(compiler, &lambda, &lambda); + + let max_coeff_eq2: u64 = 1 + 1 + 2 + 2 * n as u64; + emit_schoolbook_column_equations( + compiler, + &[(&prod_lam_lam, FieldElement::ONE)], + &[(&x3, -FieldElement::ONE), (px_s, -FieldElement::from(2u64))], + &q2_pos, + &q2_neg, + &c2, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff_eq2, + ); + + // Eq3: lambda*px - lambda*x3 - y3 - py = q3*p + let prod_lam_px = make_products(compiler, &lambda, px_s); + let prod_lam_x3 = make_products(compiler, &lambda, &x3); + + let max_coeff_eq3: u64 = 1 + 1 + 1 + 1 + 2 * n as u64; + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_px, FieldElement::ONE), + (&prod_lam_x3, -FieldElement::ONE), + ], + &[(&y3, -FieldElement::ONE), (py_s, -FieldElement::ONE)], + &q3_pos, + &q3_neg, + &c3, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff_eq3, + ); + + // Range checks on hint outputs + // max_coeff across eqs: Eq1 = 6+2N, Eq2 = 4+2N, Eq3 = 4+2N → worst = 6+2N + let max_coeff_carry = 6u64 + 2 * n as u64; + let crb = carry_range_bits(params.limb_bits, max_coeff_carry, n); + range_check_limbs_and_carries( + range_checks, + &[ + &lambda, &x3, &y3, &q1_pos, &q1_neg, &q2_pos, &q2_neg, &q3_pos, &q3_neg, + ], + &[&c1, &c2, &c3], + params.limb_bits, + crb, + ); + + // Less-than-p checks for lambda, x3, y3 + less_than_p_check_vec(compiler, range_checks, &lambda, params); + less_than_p_check_vec(compiler, range_checks, &x3, params); + less_than_p_check_vec(compiler, range_checks, &y3, params); + + (vec_to_limbs(&x3), vec_to_limbs(&y3)) +} + +/// Hint-verified point addition for non-native field (multi-limb). 
+/// +/// Same approach as `point_double_verified_non_native` but verifies: +/// Eq1: lambda·(x2-x1) - (y2-y1) = q1·p (2N² products: lam·x2, lam·x1) +/// Eq2: lambda² - x3 - x1 - x2 = q2·p (1N² products: lam·lam) +/// Eq3: lambda·(x1-x3) - y3 - y1 = q3·p (1N² products: lam·x3; lam·x1 +/// reused) +/// +/// Total: (12N-6)W hint + 4N² products + 3×(2N-1) column constraints +/// + 3 less-than-p checks. +pub fn point_add_verified_non_native( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + x1: Limbs, + y1: Limbs, + x2: Limbs, + y2: Limbs, + params: &MultiLimbParams, +) -> (Limbs, Limbs) { + let n = params.num_limbs; + assert!(n >= 2, "hint-verified non-native requires n >= 2"); + + let max_coeff: u64 = 1 + 1 + 1 + 1 + 2 * n as u64; // all 3 eqs: 1+1+1+1+2N + check_column_equation_fits(params.limb_bits, max_coeff, n, "EC add"); + + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { + output_start: os, + op: NonNativeEcOp::Add, + inputs: vec![ + x1.as_slice()[..n].to_vec(), + y1.as_slice()[..n].to_vec(), + x2.as_slice()[..n].to_vec(), + y2.as_slice()[..n].to_vec(), + ], + curve_a: [0; 4], // unused for add + curve_b: [0; 4], // unused for add + field_modulus_p: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + + // Parse hint layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 + let lambda = witness_range(os, n); + let x3 = witness_range(os + n, n); + let y3 = witness_range(os + 2 * n, n); + let q1_pos = witness_range(os + 3 * n, n); + let q1_neg = witness_range(os + 4 * n, n); + let c1 = witness_range(os + 5 * n, 2 * n - 2); + let q2_pos = witness_range(os + 7 * n - 2, n); + let q2_neg = witness_range(os + 8 * n - 2, n); + let c2 = witness_range(os + 9 * n - 2, 2 * n - 2); + let q3_pos = witness_range(os + 11 * n - 4, n); + let q3_neg = witness_range(os + 12 * 
n - 4, n); + let c3 = witness_range(os + 13 * n - 4, 2 * n - 2); + + let x1_s = &x1.as_slice()[..n]; + let y1_s = &y1.as_slice()[..n]; + let x2_s = &x2.as_slice()[..n]; + let y2_s = &y2.as_slice()[..n]; + + // Eq1: lambda*x2 - lambda*x1 - y2 + y1 = q1*p + let prod_lam_x2 = make_products(compiler, &lambda, x2_s); + let prod_lam_x1 = make_products(compiler, &lambda, x1_s); + + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_x2, FieldElement::ONE), + (&prod_lam_x1, -FieldElement::ONE), + ], + &[(y2_s, -FieldElement::ONE), (y1_s, FieldElement::ONE)], + &q1_pos, + &q1_neg, + &c1, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff, + ); + + // Eq2: lambda² - x3 - x1 - x2 = q2*p + let prod_lam_lam = make_products(compiler, &lambda, &lambda); + + emit_schoolbook_column_equations( + compiler, + &[(&prod_lam_lam, FieldElement::ONE)], + &[ + (&x3, -FieldElement::ONE), + (x1_s, -FieldElement::ONE), + (x2_s, -FieldElement::ONE), + ], + &q2_pos, + &q2_neg, + &c2, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff, + ); + + // Eq3: lambda*x1 - lambda*x3 - y3 - y1 = q3*p + // Reuse prod_lam_x1 from Eq1 + let prod_lam_x3 = make_products(compiler, &lambda, &x3); + + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_x1, FieldElement::ONE), + (&prod_lam_x3, -FieldElement::ONE), + ], + &[(&y3, -FieldElement::ONE), (y1_s, -FieldElement::ONE)], + &q3_pos, + &q3_neg, + &c3, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff, + ); + + // Range checks + // max_coeff across all 3 eqs = 4+2N + let max_coeff_carry = 4u64 + 2 * n as u64; + let crb = carry_range_bits(params.limb_bits, max_coeff_carry, n); + range_check_limbs_and_carries( + range_checks, + &[ + &lambda, &x3, &y3, &q1_pos, &q1_neg, &q2_pos, &q2_neg, &q3_pos, &q3_neg, + ], + &[&c1, &c2, &c3], + params.limb_bits, + crb, + ); + + // Less-than-p checks + less_than_p_check_vec(compiler, range_checks, &lambda, params); + less_than_p_check_vec(compiler, range_checks, &x3, params); + 
less_than_p_check_vec(compiler, range_checks, &y3, params); + + (vec_to_limbs(&x3), vec_to_limbs(&y3)) +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/mod.rs b/provekit/r1cs-compiler/src/msm/ec_points/mod.rs new file mode 100644 index 000000000..debfb3c14 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/mod.rs @@ -0,0 +1,51 @@ +//! Elliptic curve point operations for MSM. +//! +//! Submodules: +//! - `generic` — generic EC ops via `MultiLimbOps` abstraction +//! - `tables` — point table construction, lookup, and merged GLV loop +//! - `hints_native` — native-field hint-verified EC ops +//! - `hints_non_native` — non-native hint-verified EC ops (schoolbook) + +mod generic; +mod hints_native; +mod hints_non_native; +mod tables; + +// Re-exports +use super::{multi_limb_ops::MultiLimbOps, Limbs}; +pub use { + generic::{point_add, point_double, point_select_unchecked}, + hints_native::{ + point_add_verified_native, point_double_verified_native, verify_on_curve_native, + }, + hints_non_native::{ + point_add_verified_non_native, point_double_verified_non_native, verify_on_curve_non_native, + }, + tables::{scalar_mul_merged_glv, MergedGlvPoint}, +}; + +/// Dispatching point doubling: uses hint-verified for multi-limb non-native, +/// generic field-ops otherwise. +pub fn point_double_dispatch(ops: &mut MultiLimbOps, x1: Limbs, y1: Limbs) -> (Limbs, Limbs) { + if ops.params.num_limbs >= 2 && !ops.params.is_native { + point_double_verified_non_native(ops.compiler, ops.range_checks, x1, y1, ops.params) + } else { + point_double(ops, x1, y1) + } +} + +/// Dispatching point addition: uses hint-verified for multi-limb non-native, +/// generic field-ops otherwise. 
+pub fn point_add_dispatch( + ops: &mut MultiLimbOps, + x1: Limbs, + y1: Limbs, + x2: Limbs, + y2: Limbs, +) -> (Limbs, Limbs) { + if ops.params.num_limbs >= 2 && !ops.params.is_native { + point_add_verified_non_native(ops.compiler, ops.range_checks, x1, y1, x2, y2, ops.params) + } else { + point_add(ops, x1, y1, x2, y2) + } +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/tables.rs b/provekit/r1cs-compiler/src/msm/ec_points/tables.rs new file mode 100644 index 000000000..27f2f7e5a --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/tables.rs @@ -0,0 +1,259 @@ +//! Point table construction and lookup for windowed scalar multiplication. +//! +//! Builds tables of point multiples and performs lookups using bit witnesses +//! for both unsigned and signed-digit windowed approaches. + +use { + super::{generic::point_select_unchecked, point_add_dispatch, point_double_dispatch}, + crate::msm::{multi_limb_ops::MultiLimbOps, Limbs}, + ark_ff::Field, + provekit_common::{witness::SumTerm, FieldElement}, +}; + +/// Builds a signed point table of odd multiples for signed-digit windowed +/// scalar multiplication. +/// +/// T\[0\] = P, T\[1\] = 3P, T\[2\] = 5P, ..., T\[k-1\] = (2k-1)P +/// where k = `half_table_size` = 2^(w-1). +/// +/// Build cost: 1 point_double (for 2P) + (k-1) point_adds when k >= 2. +fn build_signed_point_table( + ops: &mut MultiLimbOps, + px: Limbs, + py: Limbs, + half_table_size: usize, +) -> Vec<(Limbs, Limbs)> { + assert!(half_table_size >= 1); + let mut table = Vec::with_capacity(half_table_size); + table.push((px, py)); // T[0] = 1*P + if half_table_size >= 2 { + let two_p = point_double_dispatch(ops, px, py); // 2P + for i in 1..half_table_size { + let prev = table[i - 1]; + table.push(point_add_dispatch(ops, prev.0, prev.1, two_p.0, two_p.1)); + } + } + table +} + +/// Selects T\[d\] from a point table using bit witnesses, where `d = Σ +/// bits\[i\] * 2^i`. 
+/// +/// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, +/// halving the candidate set at each level. Total: `(2^w - 1)` point selects +/// for a table of `2^w` entries. +/// +/// When `constrain_bits` is true, each bit is constrained boolean exactly +/// once. When false, bits are assumed already constrained (e.g. XOR'd bits +/// derived from boolean-constrained inputs). +fn table_lookup( + ops: &mut MultiLimbOps, + table: &[(Limbs, Limbs)], + bits: &[usize], + constrain_bits: bool, +) -> (Limbs, Limbs) { + assert_eq!(table.len(), 1 << bits.len()); + let mut current: Vec<(Limbs, Limbs)> = table.to_vec(); + // Process bits from MSB to LSB + for &bit in bits.iter().rev() { + if constrain_bits { + ops.constrain_flag(bit); + } + let half = current.len() / 2; + let mut next = Vec::with_capacity(half); + for i in 0..half { + next.push(point_select_unchecked( + ops, + bit, + current[i], + current[i + half], + )); + } + current = next; + } + current[0] +} + +/// Signed-digit table lookup: selects from a half-size table using XOR'd +/// index bits, then conditionally negates y based on the sign bit. +/// +/// For a w-bit window with bits \[b_0, ..., b_{w-1}\] (LSB first): +/// - sign_bit = b_{w-1} (MSB): 1 = positive digit, 0 = negative digit +/// - index_bits = \[b_0, ..., b_{w-2}\] (lower w-1 bits) +/// - When positive: table index = lower bits as-is +/// - When negative: table index = bitwise complement of lower bits, and y is +/// negated +/// +/// The XOR'd bits are computed as: `idx_i = 1 - b_i - MSB + 2*b_i*MSB`, +/// which equals `b_i` when MSB=1, and `1-b_i` when MSB=0. +/// +/// # Precondition +/// `sign_bit` must be boolean-constrained by the caller. This function uses +/// it in `select_unchecked` without re-constraining. Currently satisfied: +/// `decompose_signed_bits` boolean-constrains all bits including the MSB +/// used as `sign_bit`. 
+fn signed_table_lookup( + ops: &mut MultiLimbOps, + table: &[(Limbs, Limbs)], + index_bits: &[usize], + sign_bit: usize, +) -> (Limbs, Limbs) { + let (x, y) = if index_bits.is_empty() { + // w=1: single entry, no lookup needed + assert_eq!(table.len(), 1); + table[0] + } else { + // Compute XOR'd index bits: idx_i = 1 - b_i - MSB + 2*b_i*MSB + let one_w = ops.compiler.witness_one(); + let two = FieldElement::from(2u64); + let xor_bits: Vec = index_bits + .iter() + .map(|&bit| { + let prod = ops.compiler.add_product(bit, sign_bit); + ops.compiler.add_sum(vec![ + SumTerm(Some(FieldElement::ONE), one_w), + SumTerm(Some(-FieldElement::ONE), bit), + SumTerm(Some(-FieldElement::ONE), sign_bit), + SumTerm(Some(two), prod), + ]) + }) + .collect(); + + // XOR'd bits are boolean by construction (product of two booleans + // combined linearly), so skip redundant boolean constraints. + table_lookup(ops, table, &xor_bits, false) + }; + + // Conditionally negate y: sign_bit=0 (negative) → -y, sign_bit=1 (positive) → y + let neg_y = ops.negate(y); + let eff_y = ops.select_unchecked(sign_bit, neg_y, y); + // select_unchecked(flag, on_false, on_true): + // sign_bit=0 → on_false=neg_y (negative digit, negate y) ✓ + // sign_bit=1 → on_true=y (positive digit, keep y) ✓ + + (x, eff_y) +} + +/// Per-point data for merged multi-point GLV scalar multiplication. 
+pub struct MergedGlvPoint { + /// Point P x-coordinate (limbs) + pub px: Limbs, + /// Point P y-coordinate (effective, post-negation) + pub py: Limbs, + /// Signed-bit decomposition of |s1| (half-scalar for P), LSB first + pub s1_bits: Vec, + /// Skew correction witness for s1 branch (boolean) + pub s1_skew: usize, + /// Point R x-coordinate (limbs) + pub rx: Limbs, + /// Point R y-coordinate (effective, post-negation) + pub ry: Limbs, + /// Signed-bit decomposition of |s2| (half-scalar for R), LSB first + pub s2_bits: Vec, + /// Skew correction witness for s2 branch (boolean) + pub s2_skew: usize, +} + +/// Merged multi-point GLV scalar multiplication with shared doublings +/// and signed-digit windows. +/// +/// Uses signed-digit encoding: each w-bit window produces a signed odd digit +/// d ∈ {±1, ±3, ..., ±(2^w - 1)}, eliminating zero-digit handling. +/// Tables store odd multiples \[P, 3P, 5P, ..., (2^w-1)P\] with only +/// 2^(w-1) entries (half the unsigned table size). +/// +/// After the main loop, applies skew corrections: if skew=1, subtracts P +/// (or R) to account for the signed decomposition bias. +/// +/// Returns the final accumulator `(x, y)`. 
+pub fn scalar_mul_merged_glv( + ops: &mut MultiLimbOps, + points: &[MergedGlvPoint], + window_size: usize, + offset_x: Limbs, + offset_y: Limbs, +) -> (Limbs, Limbs) { + assert!(!points.is_empty()); + let n = points[0].s1_bits.len(); + let w = window_size; + let half_table_size = 1usize << (w - 1); + + // Build signed point tables (odd multiples) for all points upfront + let tables: Vec<(Vec<(Limbs, Limbs)>, Vec<(Limbs, Limbs)>)> = points + .iter() + .map(|pt| { + let tp = build_signed_point_table(ops, pt.px, pt.py, half_table_size); + let tr = build_signed_point_table(ops, pt.rx, pt.ry, half_table_size); + (tp, tr) + }) + .collect(); + + let num_windows = (n + w - 1) / w; + let mut acc = (offset_x, offset_y); + + // Process all windows from MSB down to LSB + for i in (0..num_windows).rev() { + let bit_start = i * w; + let bit_end = std::cmp::min(bit_start + w, n); + let actual_w = bit_end - bit_start; + + // w shared doublings on the accumulator (shared across ALL points) + let mut doubled_acc = acc; + for _ in 0..w { + doubled_acc = point_double_dispatch(ops, doubled_acc.0, doubled_acc.1); + } + + let mut cur = doubled_acc; + + // For each point: P branch + R branch (signed-digit lookup) + for (pt, (table_p, table_r)) in points.iter().zip(tables.iter()) { + // --- P branch (s1 window) --- + let s1_window_bits = &pt.s1_bits[bit_start..bit_end]; + let sign_bit_p = s1_window_bits[actual_w - 1]; // MSB + let index_bits_p = &s1_window_bits[..actual_w - 1]; // lower bits + let actual_table_p = if actual_w < w { + &table_p[..1 << (actual_w - 1)] + } else { + &table_p[..] 
+ }; + let looked_up_p = signed_table_lookup(ops, actual_table_p, index_bits_p, sign_bit_p); + // All signed digits are non-zero — no is_zero check needed + cur = point_add_dispatch(ops, cur.0, cur.1, looked_up_p.0, looked_up_p.1); + + // --- R branch (s2 window) --- + let s2_window_bits = &pt.s2_bits[bit_start..bit_end]; + let sign_bit_r = s2_window_bits[actual_w - 1]; // MSB + let index_bits_r = &s2_window_bits[..actual_w - 1]; // lower bits + let actual_table_r = if actual_w < w { + &table_r[..1 << (actual_w - 1)] + } else { + &table_r[..] + }; + let looked_up_r = signed_table_lookup(ops, actual_table_r, index_bits_r, sign_bit_r); + cur = point_add_dispatch(ops, cur.0, cur.1, looked_up_r.0, looked_up_r.1); + } + + acc = cur; + } + + // Skew corrections: subtract P (or R) if skew=1 for each point. + // The signed decomposition gives: scalar = Σ d_i * 2^i - skew, + // so the main loop computed (scalar + skew) * P. If skew=1, subtract P. + for pt in points { + // P branch skew + let neg_py = ops.negate(pt.py); + let (sub_px, sub_py) = point_add_dispatch(ops, acc.0, acc.1, pt.px, neg_py); + let new_x = ops.select_unchecked(pt.s1_skew, acc.0, sub_px); + let new_y = ops.select_unchecked(pt.s1_skew, acc.1, sub_py); + acc = (new_x, new_y); + + // R branch skew + let neg_ry = ops.negate(pt.ry); + let (sub_rx, sub_ry) = point_add_dispatch(ops, acc.0, acc.1, pt.rx, neg_ry); + let new_x = ops.select_unchecked(pt.s2_skew, acc.0, sub_rx); + let new_y = ops.select_unchecked(pt.s2_skew, acc.1, sub_ry); + acc = (new_x, new_y); + } + + acc +} diff --git a/provekit/r1cs-compiler/src/msm/limbs.rs b/provekit/r1cs-compiler/src/msm/limbs.rs new file mode 100644 index 000000000..46d3bfdb5 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/limbs.rs @@ -0,0 +1,96 @@ +//! `Limbs`: fixed-capacity, `Copy` array of witness indices. 
+ +// --------------------------------------------------------------------------- +// Limbs: fixed-capacity, Copy array of witness indices +// --------------------------------------------------------------------------- + +/// Maximum number of limbs supported. Covers all practical field sizes +/// (e.g. a 512-bit modulus with 16-bit limbs = 32 limbs). +pub const MAX_LIMBS: usize = 32; + +/// A fixed-capacity array of witness indices, indexed by limb position. +/// +/// This type is `Copy`, so it can be passed by value without requiring +/// const generics or dispatch macros. The runtime `len` field tracks how +/// many limbs are actually in use. +#[derive(Clone, Copy)] +pub struct Limbs { + data: [usize; MAX_LIMBS], + len: usize, +} + +impl Limbs { + /// Sentinel value for uninitialized limb slots. Using `usize::MAX` + /// ensures accidental use of an unfilled slot indexes an absurdly + /// large witness, causing an immediate out-of-bounds panic. + const UNINIT: usize = usize::MAX; + + /// Create a new `Limbs` with `len` limbs, all initialized to `UNINIT`. + pub fn new(len: usize) -> Self { + assert!( + len > 0 && len <= MAX_LIMBS, + "limb count must be 1..={MAX_LIMBS}, got {len}" + ); + Self { + data: [Self::UNINIT; MAX_LIMBS], + len, + } + } + + /// Create a single-limb `Limbs` wrapping one witness index. + pub fn single(value: usize) -> Self { + let mut l = Self { + data: [Self::UNINIT; MAX_LIMBS], + len: 1, + }; + l.data[0] = value; + l + } + + /// View the active limbs as a slice. + pub fn as_slice(&self) -> &[usize] { + &self.data[..self.len] + } + + /// Number of active limbs. 
+ #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.len + } +} + +impl std::fmt::Debug for Limbs { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_list().entries(self.as_slice().iter()).finish() + } +} + +impl PartialEq for Limbs { + fn eq(&self, other: &Self) -> bool { + self.len == other.len && self.data[..self.len] == other.data[..other.len] + } +} +impl Eq for Limbs {} + +impl std::ops::Index for Limbs { + type Output = usize; + fn index(&self, i: usize) -> &usize { + debug_assert!( + i < self.len, + "Limbs index {i} out of bounds (len={})", + self.len + ); + &self.data[i] + } +} + +impl std::ops::IndexMut for Limbs { + fn index_mut(&mut self, i: usize) -> &mut usize { + debug_assert!( + i < self.len, + "Limbs index {i} out of bounds (len={})", + self.len + ); + &mut self.data[i] + } +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs new file mode 100644 index 000000000..19c161a83 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -0,0 +1,303 @@ +pub mod cost_model; +pub mod curve; +pub(crate) mod ec_points; +mod limbs; +pub(crate) mod multi_limb_arith; +pub(crate) mod multi_limb_ops; +mod native; +mod non_native; +mod sanitize; +mod scalar_relation; +#[cfg(test)] +mod tests; + +pub use limbs::{Limbs, MAX_LIMBS}; +use { + crate::{ + constraint_helpers::{add_constant_witness, constrain_boolean}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field, PrimeField}, + curve::CurveParams, + provekit_common::{ + witness::{ConstantOrR1CSWitness, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +// --------------------------------------------------------------------------- +// MSM entry point +// --------------------------------------------------------------------------- + +/// Processes all deferred MSM operations. 
+/// +/// Internally selects the optimal (limb_bits, window_size) via cost model +/// and uses Grumpkin curve parameters. +/// +/// Each entry is `(points, scalars, (out_x, out_y, out_inf))` where: +/// - `points` has layout `[x1, y1, inf1, x2, y2, inf2, ...]` (3 per point) +/// - `scalars` has layout `[s1_lo, s1_hi, s2_lo, s2_hi, ...]` (2 per scalar) +/// - outputs are the R1CS witness indices for the result point +/// Grumpkin-specific MSM entry point (used by the Noir `MultiScalarMul` black +/// box). +pub fn add_msm( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + (usize, usize, usize), + )>, + range_checks: &mut BTreeMap>, +) { + let curve = curve::grumpkin_params(); + add_msm_with_curve(compiler, msm_ops, range_checks, &curve); +} + +/// Curve-agnostic MSM: compiles MSM operations for any curve described by +/// `curve`. +pub fn add_msm_with_curve( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + (usize, usize, usize), + )>, + range_checks: &mut BTreeMap>, + curve: &CurveParams, +) { + if msm_ops.is_empty() { + return; + } + + let native_bits = provekit_common::FieldElement::MODULUS_BIT_SIZE; + let curve_bits = curve.modulus_bits(); + let is_native = curve.is_native_field(); + let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / 3).sum(); + let (limb_bits, window_size) = + cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256, is_native); + + for (points, scalars, outputs) in msm_ops { + add_single_msm( + compiler, + &points, + &scalars, + outputs, + limb_bits, + window_size, + range_checks, + curve, + ); + } +} + +/// Processes a single MSM operation. 
+fn add_single_msm( + compiler: &mut NoirToR1CSCompiler, + points: &[ConstantOrR1CSWitness], + scalars: &[ConstantOrR1CSWitness], + outputs: (usize, usize, usize), + limb_bits: u32, + window_size: usize, + range_checks: &mut BTreeMap>, + curve: &CurveParams, +) { + assert!( + points.len() % 3 == 0, + "points length must be a multiple of 3" + ); + let n = points.len() / 3; + assert_eq!( + scalars.len(), + 2 * n, + "scalars length must be 2x the number of points" + ); + + // Resolve all inputs to witness indices + let point_wits: Vec = points.iter().map(|p| resolve_input(compiler, p)).collect(); + let scalar_wits: Vec = scalars.iter().map(|s| resolve_input(compiler, s)).collect(); + + let is_native = curve.is_native_field(); + let num_limbs = if is_native { + 1 + } else { + (curve.modulus_bits() as usize + limb_bits as usize - 1) / limb_bits as usize + }; + + let n_points = point_wits.len() / 3; + if curve.is_native_field() { + native::process_multi_point_native( + compiler, + &point_wits, + &scalar_wits, + outputs, + n_points, + range_checks, + curve, + ); + } else { + non_native::process_multi_point_non_native( + compiler, + &point_wits, + &scalar_wits, + outputs, + n_points, + num_limbs, + limb_bits, + window_size, + range_checks, + curve, + ); + } +} + +/// Resolves a `ConstantOrR1CSWitness` to a witness index. +fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitness) -> usize { + match input { + ConstantOrR1CSWitness::Witness(idx) => *idx, + ConstantOrR1CSWitness::Constant(value) => add_constant_witness(compiler, *value), + } +} + +// --------------------------------------------------------------------------- +// Multi-limb MSM interface (for non-native curves with coords > BN254_Fr) +// --------------------------------------------------------------------------- + +/// MSM outputs when coordinates are in multi-limb form. 
+pub struct MsmLimbedOutputs { + pub out_x_limbs: Vec, + pub out_y_limbs: Vec, + pub out_inf: usize, +} + +/// Multi-limb MSM entry point for non-native curves. +/// +/// Point coordinates are provided as limbs, avoiding truncation when +/// values exceed BN254_Fr. Each point uses stride `2*num_limbs + 1`: +/// `[px_l0..px_lN-1, py_l0..py_lN-1, inf]`. +/// +/// Scalars remain as `[s_lo, s_hi]` pairs (128-bit halves fit in BN254_Fr). +/// Outputs are per-limb: `MsmLimbedOutputs` with N limbs for each coordinate. +pub fn add_msm_with_curve_limbed( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + MsmLimbedOutputs, + )>, + range_checks: &mut BTreeMap>, + curve: &CurveParams, + num_limbs: usize, +) { + assert!( + !curve.is_native_field(), + "limbed MSM is only for non-native curves" + ); + if msm_ops.is_empty() { + return; + } + + let native_bits = provekit_common::FieldElement::MODULUS_BIT_SIZE; + let curve_bits = curve.modulus_bits(); + let stride = 2 * num_limbs + 1; + let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / stride).sum(); + let (limb_bits, window_size) = + cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256, false); + + // Verify num_limbs matches what cost model produces + let expected_num_limbs = (curve_bits as usize + limb_bits as usize - 1) / limb_bits as usize; + assert_eq!( + num_limbs, expected_num_limbs, + "num_limbs mismatch: caller passed {num_limbs}, cost model expects {expected_num_limbs}" + ); + + for (points, scalars, outputs) in msm_ops { + assert!( + points.len() % stride == 0, + "points length must be a multiple of {stride} (2*{num_limbs}+1)" + ); + let n = points.len() / stride; + assert_eq!(scalars.len(), 2 * n, "scalars length must be 2x n_points"); + assert_eq!(outputs.out_x_limbs.len(), num_limbs); + assert_eq!(outputs.out_y_limbs.len(), num_limbs); + + let point_wits: Vec = points.iter().map(|p| resolve_input(compiler, p)).collect(); + let scalar_wits: Vec = 
scalars.iter().map(|s| resolve_input(compiler, s)).collect(); + + non_native::process_multi_point_non_native_limbed( + compiler, + &point_wits, + &scalar_wits, + &outputs, + n, + num_limbs, + limb_bits, + window_size, + range_checks, + curve, + ); + } +} + +// --------------------------------------------------------------------------- +// Signed-bit decomposition (shared by native and non-native paths) +// --------------------------------------------------------------------------- + +/// Signed-bit decomposition for wNAF scalar multiplication. +/// +/// Decomposes `scalar` into `num_bits` sign-bits b_i ∈ {0,1} and a skew ∈ {0,1} +/// such that the signed digits d_i = 2*b_i - 1 ∈ {-1, +1} satisfy: +/// scalar = Σ d_i * 2^i - skew +/// +/// Reconstruction constraint (1 linear R1CS): +/// scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} +/// +/// All bits and skew are boolean-constrained. +/// +/// # Limitation +/// The prover's `SignedBitHint` solver reads the scalar as a `u128` (lower +/// 128 bits of the field element). This is correct for FakeGLV half-scalars +/// (≤128 bits for 256-bit curves) but would silently truncate if `num_bits` +/// exceeds 128. The R1CS reconstruction constraint would then fail. 
+pub(crate) fn decompose_signed_bits( + compiler: &mut NoirToR1CSCompiler, + scalar: usize, + num_bits: usize, +) -> (Vec, usize) { + let start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SignedBitHint { + output_start: start, + scalar, + num_bits, + }); + let bits: Vec = (start..start + num_bits).collect(); + let skew = start + num_bits; + + // Boolean-constrain each bit and skew + for &b in &bits { + constrain_boolean(compiler, b); + } + constrain_boolean(compiler, skew); + + // Reconstruction: scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} + // Rearranged as: scalar + skew + (2^n - 1) - Σ b_i * 2^{i+1} = 0 + let one = compiler.witness_one(); + let two = FieldElement::from(2u64); + let constant = two.pow([num_bits as u64]) - FieldElement::ONE; + let mut b_terms: Vec<(FieldElement, usize)> = bits + .iter() + .enumerate() + .map(|(i, &b)| (-two.pow([(i + 1) as u64]), b)) + .collect(); + b_terms.push((FieldElement::ONE, scalar)); + b_terms.push((FieldElement::ONE, skew)); + b_terms.push((constant, one)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, one)], &b_terms, &[( + FieldElement::ZERO, + one, + )]); + + (bits, skew) +} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs new file mode 100644 index 000000000..f0163b88e --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -0,0 +1,686 @@ +//! N-limb modular arithmetic for EC field operations. +//! +//! Replaces both `ec_ops.rs` (N=1 path) and `wide_ops.rs` (N>1 path) with +//! unified multi-limb operations using `Limbs` (runtime-sized, Copy). 
+ +use { + super::Limbs, + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField}, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +// --------------------------------------------------------------------------- +// N=1 single-limb path (moved from ec_ops.rs) +// --------------------------------------------------------------------------- + +/// Reduce the value to given modulus (N=1 path). +/// Computes v = k*m + result, where 0 <= result < m. +pub fn reduce_mod_p( + compiler: &mut NoirToR1CSCompiler, + value: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let m = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(m, modulus), + )); + let k = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(k, value, modulus)); + + let k_mul_m = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product(k_mul_m, k, m)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, k)], &[(FieldElement::ONE, m)], &[( + FieldElement::ONE, + k_mul_m, + )]); + + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum(result, vec![ + SumTerm(Some(FieldElement::ONE), value), + SumTerm(Some(-FieldElement::ONE), k_mul_m), + ])); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, k_mul_m), (FieldElement::ONE, result)], + &[(FieldElement::ONE, value)], + ); + + let modulus_bits = modulus.into_bigint().num_bits(); + range_checks.entry(modulus_bits).or_default().push(result); + + result +} + +/// a + b mod p (N=1 path) +pub fn add_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_add_b = compiler.num_witnesses(); + 
compiler.add_witness_builder(WitnessBuilder::Sum(a_add_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(FieldElement::ONE), b), + ])); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, a), (FieldElement::ONE, b)], + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a_add_b)], + ); + reduce_mod_p(compiler, a_add_b, modulus, range_checks) +} + +/// a * b mod p (N=1 path) +pub fn mul_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_mul_b = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product(a_mul_b, a, b)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( + FieldElement::ONE, + a_mul_b, + )]); + reduce_mod_p(compiler, a_mul_b, modulus, range_checks) +} + +/// (a - b) mod p (N=1 path) +pub fn sub_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_sub_b = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum(a_sub_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(-FieldElement::ONE), b), + ])); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ONE, a_sub_b)], + ); + reduce_mod_p(compiler, a_sub_b, modulus, range_checks) +} + +/// a^(-1) mod p (N=1 path) +pub fn inv_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::ModularInverse(a_inv, a, modulus)); + + let reduced = mul_mod_p_single(compiler, a, a_inv, modulus, range_checks); + + // Constrain reduced = 1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[ + 
(FieldElement::ONE, reduced), + (-FieldElement::ONE, compiler.witness_one()), + ], + &[(FieldElement::ZERO, compiler.witness_one())], + ); + + let mod_bits = modulus.into_bigint().num_bits(); + range_checks.entry(mod_bits).or_default().push(a_inv); + + a_inv +} + +/// Checks if value is zero or not (used by all N values). +/// Returns a boolean witness: 1 if zero, 0 if non-zero. +/// +/// Uses SafeInverse (not Inverse) because the input value may be zero. +/// SafeInverse outputs 0 when the input is 0, and is solved in the Other +/// layer (not batch-inverted), so zero inputs don't poison the batch. +pub fn compute_is_zero(compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { + let value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SafeInverse(value_inv, value)); + + let value_mul_value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product( + value_mul_value_inv, + value, + value_inv, + )); + + let is_zero = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum(is_zero, vec![ + SumTerm(Some(FieldElement::ONE), compiler.witness_one()), + SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), + ])); + + // v × v^(-1) = 1 - is_zero + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, value_inv)], + &[ + (FieldElement::ONE, compiler.witness_one()), + (-FieldElement::ONE, is_zero), + ], + ); + // v × is_zero = 0 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, is_zero)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); + + is_zero +} + +// --------------------------------------------------------------------------- +// N≥2 multi-limb path (generalization of wide_ops.rs) +// --------------------------------------------------------------------------- + +/// Shared core for `add_mod_p_multi` and `sub_mod_p_multi`. 
+/// +/// Both operations follow the same carry-chain structure; only the per-limb +/// formula and quotient witness builder differ: +/// add: v\[i\] = a\[i\] + b\[i\] + 2^W - q·p\[i\] + carry (q via +/// MultiLimbAddQuotient) sub: v\[i\] = a\[i\] - b\[i\] + q·p\[i\] + 2^W + +/// carry (q via MultiLimbSubBorrow) +/// +/// When `carry_prev` is present, the w1 coefficient absorbs a `-1` term +/// (i.e. `w1_coeff = ... + 2^W - 1`) so that `carry_prev` contributes +/// `+1 · carry_prev` without creating duplicate (row, col) entries. +fn add_sub_mod_p_core( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + is_add: bool, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "add/sub_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + // Witness: q ∈ {0, 1} + let q = compiler.num_witnesses(); + if is_add { + compiler.add_witness_builder(WitnessBuilder::MultiLimbAddQuotient { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + } else { + compiler.add_witness_builder(WitnessBuilder::MultiLimbSubBorrow { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + } + // q is boolean + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); + + let mut r = Limbs::new(n); + let mut carry_prev: Option = None; + + for i in 0..n { + // When carry_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
+ let w1_coeff = if carry_prev.is_some() { + two_pow_w - FieldElement::ONE + } else { + two_pow_w + }; + // add: a[i] + b[i] + w1_coeff*1 - p[i]*q + carry + // sub: a[i] - b[i] + p[i]*q + w1_coeff*1 + carry + let mut terms = vec![SumTerm(None, a[i])]; + if is_add { + terms.push(SumTerm(None, b[i])); + terms.push(SumTerm(Some(w1_coeff), w1)); + terms.push(SumTerm(Some(-p_limbs[i]), q)); + } else { + terms.push(SumTerm(Some(-FieldElement::ONE), b[i])); + terms.push(SumTerm(Some(p_limbs[i]), q)); + terms.push(SumTerm(Some(w1_coeff), w1)); + } + if let Some(carry) = carry_prev { + terms.push(SumTerm(None, carry)); + } + let v_offset = compiler.add_sum(terms); + + // carry = floor(v_offset / 2^W) + let carry = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_offset, two_pow_w)); + // r[i] = v_offset - carry * 2^W + r[i] = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_pow_w), carry), + ]); + carry_prev = Some(carry); + } + + less_than_p_check_multi( + compiler, + range_checks, + r, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); + + r +} + +/// (a + b) mod p for multi-limb values. +/// +/// Per limb i: v_i = a\[i\] + b\[i\] + 2^W - q*p\[i\] + carry_{i-1} +/// carry_i = floor(v_i / 2^W) +/// r\[i\] = v_i - carry_i * 2^W +pub fn add_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + add_sub_mod_p_core( + compiler, + range_checks, + true, + a, + b, + p_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + modulus_raw, + ) +} + +/// Negate a multi-limb value: computes `p - y` directly via borrow chain. +/// +/// Since inputs are already verified `y ∈ [0, p)` (from less_than_p on +/// prior operations), the result `p - y` is in `(0, p]`. For `y > 0`, +/// the result is in `(0, p)` — canonical. 
For `y = 0`, the result is `p ≡ 0`, +/// which has valid limbs (each < 2^limb_bits) and is correct modulo p. +/// +/// This avoids the generic `sub_mod_p_multi` pathway which allocates +/// N zero-constant witnesses, a borrow quotient, and a less_than_p check. +/// +/// Witnesses: 3N (N v-sums + N borrows + N result limbs). +/// Constraints: N (one `add_sum` per limb, each producing 1W+1C). +/// Range checks: N at limb_bits. +pub fn negate_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + y: Limbs, + p_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, +) -> Limbs { + let n = y.len(); + assert!(n >= 2, "negate_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + let mut r = Limbs::new(n); + let mut borrow_prev: Option = None; + + for i in 0..n { + // v[i] = p[i] + 2^W - y[i] + borrow_{i-1} + // The 2^W offset ensures v[i] >= 0 (since p[i] >= 0, y[i] < 2^W). + // When borrow_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix. + let w1_coeff = if borrow_prev.is_some() { + p_limbs[i] + two_pow_w - FieldElement::ONE + } else { + p_limbs[i] + two_pow_w + }; + let mut terms = vec![ + SumTerm(Some(w1_coeff), w1), + SumTerm(Some(-FieldElement::ONE), y[i]), + ]; + if let Some(borrow) = borrow_prev { + terms.push(SumTerm(None, borrow)); + } + let v = compiler.add_sum(terms); + + // borrow[i] = floor(v[i] / 2^W) + let borrow = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(borrow, v, two_pow_w)); + + // r[i] = v[i] - borrow[i] * 2^W + r[i] = compiler.add_sum(vec![SumTerm(None, v), SumTerm(Some(-two_pow_w), borrow)]); + + // Range check r[i] — ensures borrow is uniquely determined + range_checks.entry(limb_bits).or_default().push(r[i]); + + borrow_prev = Some(borrow); + } + + r +} + +/// (a - b) mod p for multi-limb values. 
+pub fn sub_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + add_sub_mod_p_core( + compiler, + range_checks, + false, + a, + b, + p_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + modulus_raw, + ) +} + +/// (a * b) mod p for multi-limb values using schoolbook multiplication. +/// +/// Verifies: a·b = p·q + r in base W = 2^limb_bits. +/// Column k: Σ_{i+j=k} a\[i\]*b\[j\] + carry_{k-1} + OFFSET +/// = Σ_{i+j=k} p\[i\]*q\[j\] + r\[k\] + carry_k * W +pub fn mul_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "mul_mod_p_multi requires n >= 2, got n={n}"); + + // Soundness check: column equation values must not overflow the native field. + // The maximum value across either side of any column equation is bounded by + // 2^(2*limb_bits + ceil(log2(n)) + 3). This must be strictly less than the + // native field modulus p >= 2^(MODULUS_BIT_SIZE - 1). + { + let ceil_log2_n = (n as f64).log2().ceil() as u32; + let max_bits = 2 * limb_bits + ceil_log2_n + 3; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "Schoolbook column equation overflow: limb_bits={limb_bits}, n={n} limbs requires \ + {max_bits} bits, but native field is only {} bits. 
Use smaller limb_bits.", + FieldElement::MODULUS_BIT_SIZE, + ); + } + + let w1 = compiler.witness_one(); + let num_carries = 2 * n - 2; + // Carry offset: 2^(limb_bits + ceil(log2(n)) + 1) + let extra_bits = ((n as f64).log2().ceil() as u32) + 1; + let carry_offset_bits = limb_bits + extra_bits; + let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); + // offset_w = carry_offset * 2^limb_bits + let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); + // offset_w_minus_carry = offset_w - carry_offset = carry_offset * (2^limb_bits + // - 1) + let offset_w_minus_carry = offset_w - carry_offset_fe; + + // Step 1: Allocate hint witnesses (q limbs, r limbs, carries) + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbMulModHint { + output_start: os, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + + // q[0..n), r[n..2n), carries[2n..4n-2) + let q: Vec = (0..n).map(|i| os + i).collect(); + let r_indices: Vec = (0..n).map(|i| os + n + i).collect(); + let cu: Vec = (0..num_carries).map(|i| os + 2 * n + i).collect(); + + // Step 2: Product witnesses for a[i]*b[j] (n² R1CS constraints) + let mut ab_products = vec![vec![0usize; n]; n]; + for i in 0..n { + for j in 0..n { + ab_products[i][j] = compiler.add_product(a[i], b[j]); + } + } + + // Step 3: Column equations (2n-1 R1CS constraints) + for k in 0..(2 * n - 1) { + // LHS: Σ_{i+j=k} a[i]*b[j] + carry_{k-1} + OFFSET + let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((FieldElement::ONE, ab_products[i][j_val as usize])); + } + } + // Add carry_{k-1} + if k > 0 { + lhs_terms.push((FieldElement::ONE, cu[k - 1])); + // Add offset_w - carry_offset for subsequent columns + lhs_terms.push((offset_w_minus_carry, w1)); + } 
else { + // First column: add offset_w + lhs_terms.push((offset_w, w1)); + } + + // RHS: Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W + let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + rhs_terms.push((p_limbs[i], q[j_val as usize])); + } + } + if k < n { + rhs_terms.push((FieldElement::ONE, r_indices[k])); + } + if k < 2 * n - 2 { + rhs_terms.push((two_pow_w, cu[k])); + } else { + // Last column: RHS includes offset_w to balance the LHS offset + // LHS has: carry[k-1] + offset_w_minus_carry = true_carry + offset_w + // RHS needs: sum_pq[k] + offset_w (no outgoing carry at last column) + rhs_terms.push((offset_w, w1)); + } + + compiler + .r1cs + .add_constraint(&lhs_terms, &[(FieldElement::ONE, w1)], &rhs_terms); + } + + // Step 4: less-than-p check and range checks on r + let mut r_limbs = Limbs::new(n); + for (i, &ri) in r_indices.iter().enumerate() { + r_limbs[i] = ri; + } + less_than_p_check_multi( + compiler, + range_checks, + r_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); + + // Step 5: Range checks for q limbs and carries + for i in 0..n { + range_checks.entry(limb_bits).or_default().push(q[i]); + } + // Carry range: limb_bits + extra_bits + 1 (carry_offset_bits + 1) + let carry_range_bits = carry_offset_bits + 1; + for &c in &cu { + range_checks.entry(carry_range_bits).or_default().push(c); + } + + r_limbs +} + +/// a^(-1) mod p for multi-limb values. +/// Uses MultiLimbModularInverse hint, verifies via mul_mod_p(a, inv) = [1, 0, +/// ..., 0]. 
+pub fn inv_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "inv_mod_p_multi requires n >= 2, got n={n}"); + + // Hint: compute inverse + let inv_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbModularInverse { + output_start: inv_start, + a_limbs: a.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + let mut inv = Limbs::new(n); + for i in 0..n { + inv[i] = inv_start + i; + range_checks.entry(limb_bits).or_default().push(inv[i]); + } + + // Verify: a * inv mod p = [1, 0, ..., 0] + let product = mul_mod_p_multi( + compiler, + range_checks, + a, + inv, + p_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + modulus_raw, + ); + + // Constrain product[0] = 1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product[0])], + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, compiler.witness_one())], + ); + // Constrain product[1..n] = 0 + for i in 1..n { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product[i])], + &[(FieldElement::ONE, compiler.witness_one())], + &[], + ); + } + + inv +} + +/// Proves r < p by decomposing (p-1) - r into non-negative multi-limb values. +/// Uses borrow propagation: d\[i\] = (p-1)\[i\] + 2^W - r\[i\] + borrow_in - +/// borrow_out · 2^W. The 2^W offset keeps intermediate values non-negative. 
+pub fn less_than_p_check_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + r: Limbs, + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, +) { + let n = r.len(); + let w1 = compiler.witness_one(); + let mut borrow_prev: Option = None; + for i in 0..n { + // v_diff = (p-1)[i] + 2^W - r[i] + borrow_prev + // When borrow_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). + let w1_coeff = if borrow_prev.is_some() { + p_minus_1_limbs[i] + two_pow_w - FieldElement::ONE + } else { + p_minus_1_limbs[i] + two_pow_w + }; + let mut terms = vec![ + SumTerm(Some(w1_coeff), w1), + SumTerm(Some(-FieldElement::ONE), r[i]), + ]; + if let Some(borrow) = borrow_prev { + terms.push(SumTerm(None, borrow)); + } + let v_diff = compiler.add_sum(terms); + + // borrow = floor(v_diff / 2^W) + let borrow = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(borrow, v_diff, two_pow_w)); + // d[i] = v_diff - borrow * 2^W + let d_i = compiler.add_sum(vec![ + SumTerm(None, v_diff), + SumTerm(Some(-two_pow_w), borrow), + ]); + + // Range check r[i] and d[i] + range_checks.entry(limb_bits).or_default().push(r[i]); + range_checks.entry(limb_bits).or_default().push(d_i); + + borrow_prev = Some(borrow); + } + + // Constrain final carry = 1: the 2^W offset at each limb propagates + // a carry of 1 through the chain. For valid r < p, the final carry + // must be exactly 1. If r >= p, the carry chain underflows and the + // final carry is 0. 
+ if let Some(final_borrow) = borrow_prev { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, final_borrow)], + &[(FieldElement::ONE, compiler.witness_one())], + ); + } +} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs new file mode 100644 index 000000000..3fccdbd5c --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -0,0 +1,365 @@ +//! `MultiLimbOps` — field arithmetic parameterized by runtime limb count. +//! +//! Uses `Limbs` (a fixed-capacity Copy type) as the element representation, +//! enabling arbitrary limb counts without const generics or dispatch macros. + +use { + super::{multi_limb_arith, Limbs}, + crate::{ + constraint_helpers::{constrain_boolean, pack_bits_helper, select_witness}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field}, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// Parameters for multi-limb field arithmetic. +pub struct MultiLimbParams { + pub num_limbs: usize, + pub limb_bits: u32, + pub p_limbs: Vec, + pub p_minus_1_limbs: Vec, + pub two_pow_w: FieldElement, + pub modulus_raw: [u64; 4], + pub curve_a_limbs: Vec, + /// Raw curve_a value as [u64; 4] for hint-verified EC ops + pub curve_a_raw: [u64; 4], + pub curve_b_limbs: Vec, + /// Raw curve_b value as [u64; 4] for hint-verified on-curve checks + pub curve_b_raw: [u64; 4], + /// p = native field → skip mod reduction + pub is_native: bool, + /// For N=1 non-native: the modulus as a single FieldElement + pub modulus_fe: Option, +} + +impl MultiLimbParams { + /// Build params for EC field operations (mod field_modulus_p). 
+ pub fn for_field_modulus( + num_limbs: usize, + limb_bits: u32, + curve: &super::curve::CurveParams, + ) -> Self { + let is_native = curve.is_native_field(); + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let modulus_fe = if !is_native { + Some(curve.p_native_fe()) + } else { + None + }; + Self { + num_limbs, + limb_bits, + p_limbs: curve.p_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.p_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.field_modulus_p, + curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), + curve_a_raw: curve.curve_a, + curve_b_limbs: curve.curve_b_limbs(limb_bits, num_limbs), + curve_b_raw: curve.curve_b, + is_native, + modulus_fe, + } + } + + /// Build params for scalar relation verification (mod curve_order_n). + pub fn for_curve_order( + num_limbs: usize, + limb_bits: u32, + curve: &super::curve::CurveParams, + ) -> Self { + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let modulus_fe = if num_limbs == 1 { + Some(super::curve::curve_native_point_fe(&curve.curve_order_n)) + } else { + None + }; + Self { + num_limbs, + limb_bits, + p_limbs: curve.curve_order_n_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.curve_order_n_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.curve_order_n, + curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused + curve_a_raw: [0u64; 4], // unused for scalar relation + curve_b_limbs: vec![FieldElement::from(0u64); num_limbs], // unused + curve_b_raw: [0u64; 4], // unused for scalar relation + is_native: false, /* always non-native for + * scalar relation */ + modulus_fe, + } + } +} + +/// Unified field operations struct parameterized by runtime limb count. 
+pub struct MultiLimbOps<'a, 'p> { + pub compiler: &'a mut NoirToR1CSCompiler, + pub range_checks: &'a mut BTreeMap>, + pub params: &'p MultiLimbParams, +} + +impl MultiLimbOps<'_, '_> { + fn is_native_single(&self) -> bool { + self.params.num_limbs == 1 && self.params.is_native + } + + fn is_non_native_single(&self) -> bool { + self.params.num_limbs == 1 && !self.params.is_native + } + + fn n(&self) -> usize { + self.params.num_limbs + } + + /// Negate a multi-limb value: computes `p - value (mod p)`. + /// + /// For multi-limb non-native (N≥2), uses a dedicated borrow chain that + /// skips zero-constant allocation, the borrow-quotient hint, and the + /// less_than_p check — saving (4N+1) witnesses per call. + pub fn negate(&mut self, value: Limbs) -> Limbs { + if self.params.num_limbs >= 2 && !self.params.is_native { + multi_limb_arith::negate_mod_p_multi( + self.compiler, + self.range_checks, + value, + &self.params.p_limbs, + self.params.two_pow_w, + self.params.limb_bits, + ) + } else { + let zero_vals = vec![FieldElement::from(0u64); self.params.num_limbs]; + let zero = self.constant_limbs(&zero_vals); + self.sub(zero, value) + } + } + + pub fn add(&mut self, a: Limbs, b: Limbs) -> Limbs { + debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); + debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs"); + if self.is_native_single() { + // When both operands are the same witness, merge into a single + // term with coefficient 2 to avoid duplicate column indices in + // the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
+            // (tail of `add`, native single-limb case) a[0] == b[0] means a + a:
+            // emit 2·a as a single term so the sum has no duplicate column
+            // indices (mirrors the a - a case in `sub` below).
+            let r = if a[0] == b[0] {
+                self.compiler
+                    .add_sum(vec![SumTerm(Some(FieldElement::from(2u64)), a[0])])
+            } else {
+                self.compiler
+                    .add_sum(vec![SumTerm(None, a[0]), SumTerm(None, b[0])])
+            };
+            Limbs::single(r)
+        } else if self.is_non_native_single() {
+            let modulus = self.params.modulus_fe.unwrap();
+            let r = multi_limb_arith::add_mod_p_single(
+                self.compiler,
+                a[0],
+                b[0],
+                modulus,
+                self.range_checks,
+            );
+            Limbs::single(r)
+        } else {
+            multi_limb_arith::add_mod_p_multi(
+                self.compiler,
+                self.range_checks,
+                a,
+                b,
+                &self.params.p_limbs,
+                &self.params.p_minus_1_limbs,
+                self.params.two_pow_w,
+                self.params.limb_bits,
+                &self.params.modulus_raw,
+            )
+        }
+    }
+
+    /// Modular subtraction `a - b` with the same three-way dispatch as `add`:
+    /// native single limb (free sum), non-native single limb, multi-limb.
+    pub fn sub(&mut self, a: Limbs, b: Limbs) -> Limbs {
+        debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs");
+        debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs");
+        if self.is_native_single() {
+            // When both operands are the same witness, a - a = 0. Use a
+            // single zero-coefficient term to avoid duplicate column indices.
+            let r = if a[0] == b[0] {
+                self.compiler
+                    .add_sum(vec![SumTerm(Some(FieldElement::ZERO), a[0])])
+            } else {
+                self.compiler.add_sum(vec![
+                    SumTerm(None, a[0]),
+                    SumTerm(Some(-FieldElement::ONE), b[0]),
+                ])
+            };
+            Limbs::single(r)
+        } else if self.is_non_native_single() {
+            let modulus = self.params.modulus_fe.unwrap();
+            let r = multi_limb_arith::sub_mod_p_single(
+                self.compiler,
+                a[0],
+                b[0],
+                modulus,
+                self.range_checks,
+            );
+            Limbs::single(r)
+        } else {
+            multi_limb_arith::sub_mod_p_multi(
+                self.compiler,
+                self.range_checks,
+                a,
+                b,
+                &self.params.p_limbs,
+                &self.params.p_minus_1_limbs,
+                self.params.two_pow_w,
+                self.params.limb_bits,
+                &self.params.modulus_raw,
+            )
+        }
+    }
+
+    /// Modular multiplication `a * b`; native single limb is one product
+    /// constraint, otherwise dispatches to the mod-p helpers.
+    pub fn mul(&mut self, a: Limbs, b: Limbs) -> Limbs {
+        debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs");
+        debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs");
+        if self.is_native_single() {
+            let r = self.compiler.add_product(a[0], b[0]);
+            Limbs::single(r)
+        } else if self.is_non_native_single() {
+            let modulus = self.params.modulus_fe.unwrap();
+            let r = multi_limb_arith::mul_mod_p_single(
+                self.compiler,
+                a[0],
+                b[0],
+                modulus,
+                self.range_checks,
+            );
+            Limbs::single(r)
+        } else {
+            multi_limb_arith::mul_mod_p_multi(
+                self.compiler,
+                self.range_checks,
+                a,
+                b,
+                &self.params.p_limbs,
+                &self.params.p_minus_1_limbs,
+                self.params.two_pow_w,
+                self.params.limb_bits,
+                &self.params.modulus_raw,
+            )
+        }
+    }
+
+    /// Modular inverse of `a`.
+    pub fn inv(&mut self, a: Limbs) -> Limbs {
+        debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs");
+        if self.is_native_single() {
+            // Prover-supplied inverse hint; soundness comes from the
+            // a * a_inv = 1 constraint below, which is unsatisfiable when a = 0.
+            let a_inv = self.compiler.num_witnesses();
+            self.compiler
+                .add_witness_builder(WitnessBuilder::Inverse(a_inv, a[0]));
+            // a * a_inv = 1
+            self.compiler.r1cs.add_constraint(
+                &[(FieldElement::ONE, a[0])],
+                &[(FieldElement::ONE, a_inv)],
+                &[(FieldElement::ONE, self.compiler.witness_one())],
+            );
+            Limbs::single(a_inv)
+        } else if self.is_non_native_single() {
+            let modulus = self.params.modulus_fe.unwrap();
+            let r =
+                multi_limb_arith::inv_mod_p_single(self.compiler, a[0], modulus, self.range_checks);
+            Limbs::single(r)
+        } else {
+            multi_limb_arith::inv_mod_p_multi(
+                self.compiler,
+                self.range_checks,
+                a,
+                &self.params.p_limbs,
+                &self.params.p_minus_1_limbs,
+                self.params.two_pow_w,
+                self.params.limb_bits,
+                &self.params.modulus_raw,
+            )
+        }
+    }
+
+    /// Returns the curve's `a` coefficient as pinned constant limbs.
+    pub fn curve_a(&mut self) -> Limbs {
+        let n = self.n();
+        let mut out = Limbs::new(n);
+        for i in 0..n {
+            let w = self.compiler.num_witnesses();
+            let value = self.params.curve_a_limbs[i];
+            self.compiler
+                .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value)));
+            // Pin: prevent malicious prover from choosing a different curve_a
+            // (constraint shape: 1 * w = value).
+            self.compiler.r1cs.add_constraint(
+                &[(FieldElement::ONE, self.compiler.witness_one())],
+                &[(FieldElement::ONE, w)],
+                &[(value, self.compiler.witness_one())],
+            );
+            out[i] = w;
+        }
+        out
+    }
+
+    /// Constrains `flag` to be boolean (`flag * flag = flag`).
+    pub fn constrain_flag(&mut self, flag: usize) {
+        constrain_boolean(self.compiler, flag);
+    }
+
+    /// Conditional select without boolean constraint on `flag`.
+    /// Caller must ensure `flag` is already constrained boolean.
+    pub fn select_unchecked(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs {
+        let n = self.n();
+        let mut out = Limbs::new(n);
+        for i in 0..n {
+            out[i] = select_witness(self.compiler, flag, on_false[i], on_true[i]);
+        }
+        out
+    }
+
+    /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if
+    /// `flag` is 0. Constrains `flag` to be boolean.
+    pub fn select(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs {
+        self.constrain_flag(flag);
+        self.select_unchecked(flag, on_false, on_true)
+    }
+
+    /// Checks if a native witness value is zero.
+    /// Returns a boolean witness: 1 if zero, 0 if non-zero.
+    pub fn is_zero(&mut self, value: usize) -> usize {
+        // Delegates to the shared single-witness helper.
+        multi_limb_arith::compute_is_zero(self.compiler, value)
+    }
+
+    /// Packs bit witnesses into a single digit witness: `d = Σ bits\[i\] *
+    /// 2^i`. Does NOT constrain bits to be boolean — caller must ensure
+    /// that.
+    pub fn pack_bits(&mut self, bits: &[usize]) -> usize {
+        pack_bits_helper(self.compiler, bits)
+    }
+
+    /// Returns a constant field element from its limb decomposition.
+    pub fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Limbs {
+        let n = self.n();
+        assert_eq!(
+            limbs.len(),
+            n,
+            "constant_limbs: expected {n} limbs, got {}",
+            limbs.len()
+        );
+        let mut out = Limbs::new(n);
+        for i in 0..n {
+            let w = self.compiler.num_witnesses();
+            self.compiler
+                .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, limbs[i])));
+            // Pin: prevent malicious prover from altering constant values
+            // (1 * w = limbs[i] forces the prover-filled witness to the constant).
+            self.compiler.r1cs.add_constraint(
+                &[(FieldElement::ONE, self.compiler.witness_one())],
+                &[(FieldElement::ONE, w)],
+                &[(limbs[i], self.compiler.witness_one())],
+            );
+            out[i] = w;
+        }
+        out
+    }
+}
diff --git a/provekit/r1cs-compiler/src/msm/native.rs b/provekit/r1cs-compiler/src/msm/native.rs
new file mode 100644
index 000000000..62f9374e2
--- /dev/null
+++ b/provekit/r1cs-compiler/src/msm/native.rs
@@ -0,0 +1,311 @@
+//! Native-field MSM path: hint-verified EC ops with signed-bit wNAF.
+//!
+//! Used when `curve.is_native_field()` — the curve's base field matches the
+//! R1CS native field (e.g. Grumpkin over BN254). Field operations are free
+//! (single native witness each), so EC operations use prover hints verified
+//! via raw R1CS constraints instead of expensive field inversions.
+//!
+//! ## Key techniques
+//!
+//! - **Hint-verified EC ops**: point_double (4W+4C), point_add (3W+3C), and
+//!   on-curve checks (2W+3C) use prover-supplied (lambda, x3, y3) hints
+//!   verified by the EC equations directly as R1CS constraints.
+//!
- **FakeGLV**: each 256-bit scalar is decomposed into two ~128-bit +//! half-scalars (s1, s2) via half-GCD, verified by a scalar relation mod the +//! curve order. +//! - **Signed-bit wNAF (w=1)**: each half-scalar is decomposed into signed bits +//! d_i ∈ {-1, +1}, eliminating zero-digit handling. A skew bit corrects the +//! bias post-loop. +//! - **Merged doubling**: all points share a single doubling per bit, saving 4W +//! per extra point per bit. +//! +//! ## Phases +//! +//! 1. **Preprocessing**: per-point sanitization (degenerate-case replacement), +//! on-curve checks, FakeGLV decomposition, signed-bit decomposition, +//! y-negation based on sign flags. +//! 2. **Merged scalar mul**: single loop from MSB to LSB with one shared +//! doubling + per-point P/R branch adds. Skew corrections after loop. +//! Identity check: final accumulator must equal the known offset. +//! 3. **Scalar relations**: per-point verification that (-1)^neg1·|s1| + +//! (-1)^neg2·|s2|·s ≡ 0 (mod curve_order). +//! 4. **Accumulation**: adds each point's scalar-mul result to an accumulator +//! (skipping degenerate points), subtracts the offset, constrains outputs. + +use { + super::{ + curve, ec_points, + sanitize::{ + emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, negate_y_signed_native, + sanitize_point_scalar, + }, + scalar_relation, + }, + crate::{ + constraint_helpers::{ + add_constant_witness, constrain_equal, constrain_to_constant, select_witness, + }, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::AdditiveGroup, + curve::CurveParams, + provekit_common::FieldElement, + std::collections::BTreeMap, +}; + +/// Per-point preprocessed data for the merged native scalar mul loop. +/// +/// Holds the inputs needed by `scalar_mul_merged_native_wnaf` to process +/// one point's P and R branches inside the shared-doubling loop. 
+struct NativePointData {
+    // Sanitized point P = (px, py_eff) and its scalar-mul hint R = (rx, ry_eff),
+    // each with the pre-negated y for the opposite wNAF sign.
+    px: usize,
+    py_eff: usize,
+    neg_py_eff: usize,
+    s1_bits: Vec<usize>,
+    s1_skew: usize,
+    rx: usize,
+    ry_eff: usize,
+    neg_ry_eff: usize,
+    s2_bits: Vec<usize>,
+    s2_skew: usize,
+}
+
+/// Native-field MSM with merged-loop optimization.
+///
+/// All points share a single doubling per bit, saving 4*(n-1) constraints
+/// per bit of the half-scalar.
+pub(super) fn process_multi_point_native(
+    compiler: &mut NoirToR1CSCompiler,
+    point_wits: &[usize],
+    scalar_wits: &[usize],
+    outputs: (usize, usize, usize),
+    n_points: usize,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+    curve: &CurveParams,
+) {
+    let (out_x, out_y, out_inf) = outputs;
+    let one = compiler.witness_one();
+
+    // Generator constants for sanitization
+    let gen_x_fe = curve::curve_native_point_fe(&curve.generator.0);
+    let gen_y_fe = curve::curve_native_point_fe(&curve.generator.1);
+    let gen_x_witness = add_constant_witness(compiler, gen_x_fe);
+    let gen_y_witness = add_constant_witness(compiler, gen_y_fe);
+    let zero_witness = add_constant_witness(compiler, FieldElement::ZERO);
+
+    let mut all_skipped: Option<usize> = None;
+    let mut native_points: Vec<NativePointData> = Vec::new();
+    let mut scalar_rel_inputs: Vec<(usize, usize, usize, usize, usize, usize)> = Vec::new();
+    let mut accum_inputs: Vec<(usize, usize, usize)> = Vec::new();
+
+    // Phase 1: Per-point preprocessing
+    for i in 0..n_points {
+        // Points are packed as (x, y, inf) triples; scalars as (lo, hi) pairs.
+        let san = sanitize_point_scalar(
+            compiler,
+            point_wits[3 * i],
+            point_wits[3 * i + 1],
+            scalar_wits[2 * i],
+            scalar_wits[2 * i + 1],
+            point_wits[3 * i + 2],
+            gen_x_witness,
+            gen_y_witness,
+            zero_witness,
+            one,
+        );
+
+        // all_skipped = AND over every is_skip, built as a product chain.
+        all_skipped = Some(match all_skipped {
+            None => san.is_skip,
+            Some(prev) => compiler.add_product(prev, san.is_skip),
+        });
+
+        let (sanitized_rx, sanitized_ry) = emit_ec_scalar_mul_hint_and_sanitize(
+            compiler,
+            &san,
+            gen_x_witness,
+            gen_y_witness,
+            curve,
+        );
+
+        // On-curve checks
+        ec_points::verify_on_curve_native(compiler, san.px, san.py, curve);
+        ec_points::verify_on_curve_native(compiler, sanitized_rx, sanitized_ry, curve);
+
+        // FakeGLV decomposition + signed-bit decomposition
+        let (s1, s2, neg1, neg2) = emit_fakeglv_hint(compiler, san.s_lo, san.s_hi, curve);
+        let half_bits = curve.glv_half_bits() as usize;
+        let (s1_bits, s1_skew) = super::decompose_signed_bits(compiler, s1, half_bits);
+        let (s2_bits, s2_skew) = super::decompose_signed_bits(compiler, s2, half_bits);
+
+        // Y-negation
+        let (py_eff, neg_py_eff) = negate_y_signed_native(compiler, neg1, san.py);
+        let (ry_eff, neg_ry_eff) = negate_y_signed_native(compiler, neg2, sanitized_ry);
+
+        native_points.push(NativePointData {
+            px: san.px,
+            py_eff,
+            neg_py_eff,
+            s1_bits,
+            s1_skew,
+            rx: sanitized_rx,
+            ry_eff,
+            neg_ry_eff,
+            s2_bits,
+            s2_skew,
+        });
+
+        scalar_rel_inputs.push((san.s_lo, san.s_hi, s1, s2, neg1, neg2));
+        accum_inputs.push((sanitized_rx, sanitized_ry, san.is_skip));
+    }
+
+    // Phase 2: Per-point scalar mul verification
+    //
+    // Each point gets its own accumulator and identity check. This ensures
+    // per-point soundness: b_i * (R_i - scalar_i * P_i) = O with b_i ≠ 0
+    // implies R_i = scalar_i * P_i for each point independently.
+    let half_bits = curve.glv_half_bits() as usize;
+    let offset_x_fe = curve::curve_native_point_fe(&curve.offset_point.0);
+    let offset_y_fe = curve::curve_native_point_fe(&curve.offset_point.1);
+    let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(half_bits);
+    let acc_off_x_fe = curve::curve_native_point_fe(&acc_off_x_raw);
+    let acc_off_y_fe = curve::curve_native_point_fe(&acc_off_y_raw);
+
+    for pt in &native_points {
+        // NOTE(review): fresh offset constants are allocated per point; they
+        // could be hoisted, but the witness layout may be intentional — confirm.
+        let offset_x = add_constant_witness(compiler, offset_x_fe);
+        let offset_y = add_constant_witness(compiler, offset_y_fe);
+
+        let (ver_acc_x, ver_acc_y) = scalar_mul_merged_native_wnaf(
+            compiler,
+            std::slice::from_ref(pt),
+            offset_x,
+            offset_y,
+            curve,
+        );
+
+        constrain_to_constant(compiler, ver_acc_x, acc_off_x_fe);
+        constrain_to_constant(compiler, ver_acc_y, acc_off_y_fe);
+    }
+
+    // Phase 3: Per-point scalar relations
+    for &(s_lo, s_hi, s1, s2, neg1, neg2) in &scalar_rel_inputs {
+        scalar_relation::verify_scalar_relation(
+            compiler,
+            range_checks,
+            s_lo,
+            s_hi,
+            s1,
+            s2,
+            neg1,
+            neg2,
+            curve,
+        );
+    }
+
+    // Phase 4: Accumulation (same offset-based logic)
+    let all_skipped = all_skipped.expect("MSM must have at least one point");
+
+    let mut acc_x = add_constant_witness(compiler, offset_x_fe);
+    let mut acc_y = add_constant_witness(compiler, offset_y_fe);
+
+    for &(sanitized_rx, sanitized_ry, is_skip) in &accum_inputs {
+        let (cand_x, cand_y) = ec_points::point_add_verified_native(
+            compiler,
+            acc_x,
+            acc_y,
+            sanitized_rx,
+            sanitized_ry,
+            curve,
+        );
+        // Keep the old accumulator when this point is degenerate (is_skip = 1).
+        acc_x = select_witness(compiler, is_skip, cand_x, acc_x);
+        acc_y = select_witness(compiler, is_skip, cand_y, acc_y);
+    }
+
+    // Offset subtraction and output constraining
+    let neg_offset_y_raw =
+        curve::negate_field_element(&curve.offset_point.1, &curve.field_modulus_p);
+    let neg_offset_y_fe = curve::curve_native_point_fe(&neg_offset_y_raw);
+    let neg_gen_y_raw = curve::negate_field_element(&curve.generator.1, &curve.field_modulus_p);
+    let neg_gen_y_fe = curve::curve_native_point_fe(&neg_gen_y_raw);
+
+    // NOTE(review): when all_skipped = 1 the accumulator still equals the bare
+    // offset, so subtracting -offset here would hit the degenerate P + (-P)
+    // case of the verified add; subtracting G instead keeps the add
+    // well-defined, and the result is masked to zero below anyway — confirm.
+    let sub_x_off = add_constant_witness(compiler, offset_x_fe);
+    let sub_x = select_witness(compiler, all_skipped, sub_x_off, gen_x_witness);
+
+    let neg_off_y_w = add_constant_witness(compiler, neg_offset_y_fe);
+    let neg_g_y_w = add_constant_witness(compiler, neg_gen_y_fe);
+    let sub_y = select_witness(compiler, all_skipped, neg_off_y_w, neg_g_y_w);
+
+    let (result_x, result_y) =
+        ec_points::point_add_verified_native(compiler, acc_x, acc_y, sub_x, sub_y, curve);
+
+    let masked_result_x = select_witness(compiler, all_skipped, result_x, zero_witness);
+    let masked_result_y = select_witness(compiler, all_skipped, result_y, zero_witness);
+    constrain_equal(compiler, out_x, masked_result_x);
+    constrain_equal(compiler, out_y, masked_result_y);
+    constrain_equal(compiler, out_inf, all_skipped);
+}
+
+/// Multi-point scalar multiplication for native field using signed-bit wNAF
+/// (w=1).
+///
+/// Called once per point with a single-element slice for per-point
+/// soundness (each point gets its own accumulator and identity check).
+/// Each bit costs: 4C (double) + 2 × (1C select + 3C add) = 12C per point.
+fn scalar_mul_merged_native_wnaf(
+    compiler: &mut NoirToR1CSCompiler,
+    points: &[NativePointData],
+    offset_x: usize,
+    offset_y: usize,
+    curve: &CurveParams,
+) -> (usize, usize) {
+    let n = points[0].s1_bits.len();
+    let mut acc_x = offset_x;
+    let mut acc_y = offset_y;
+
+    // wNAF loop: MSB to LSB, shared doubling
+    for i in (0..n).rev() {
+        // Single shared double
+        let (dx, dy) = ec_points::point_double_verified_native(compiler, acc_x, acc_y, curve);
+        let mut cur_x = dx;
+        let mut cur_y = dy;
+
+        // For each point: P branch + R branch
+        for pt in points {
+            // Signed bit chooses ±P: select(flag, on_false = neg_py_eff,
+            // on_true = py_eff), i.e. bit = 1 adds +P, bit = 0 adds -P.
+            let sel_py = select_witness(compiler, pt.s1_bits[i], pt.neg_py_eff, pt.py_eff);
+            let (ax, ay) =
+                ec_points::point_add_verified_native(compiler, cur_x, cur_y, pt.px, sel_py, curve);
+
+            let sel_ry = select_witness(compiler, pt.s2_bits[i], pt.neg_ry_eff, pt.ry_eff);
+            (cur_x, cur_y) =
+                ec_points::point_add_verified_native(compiler, ax, ay, pt.rx, sel_ry, curve);
+        }
+
+        acc_x = cur_x;
+        acc_y = cur_y;
+    }
+
+    // Skew corrections for all points: skew = 1 means the signed-bit form
+    // over-counted by one copy of the base point, so conditionally subtract it.
+    for pt in points {
+        let (sub_px, sub_py) = ec_points::point_add_verified_native(
+            compiler,
+            acc_x,
+            acc_y,
+            pt.px,
+            pt.neg_py_eff,
+            curve,
+        );
+        acc_x = select_witness(compiler, pt.s1_skew, acc_x, sub_px);
+        acc_y = select_witness(compiler, pt.s1_skew, acc_y, sub_py);
+
+        let (sub_rx, sub_ry) = ec_points::point_add_verified_native(
+            compiler,
+            acc_x,
+            acc_y,
+            pt.rx,
+            pt.neg_ry_eff,
+            curve,
+        );
+        acc_x = select_witness(compiler, pt.s2_skew, acc_x, sub_rx);
+        acc_y = select_witness(compiler, pt.s2_skew, acc_y, sub_ry);
+    }
+
+    (acc_x, acc_y)
+}
diff --git a/provekit/r1cs-compiler/src/msm/non_native.rs b/provekit/r1cs-compiler/src/msm/non_native.rs
new file mode 100644
index 000000000..e72e232cc
--- /dev/null
+++ b/provekit/r1cs-compiler/src/msm/non_native.rs
@@ -0,0 +1,552 @@
+//! Non-native (generic multi-limb) MSM path.
+//!
+//! Used when `!curve.is_native_field()` — the curve's base field differs from
+//! the R1CS native field (e.g.
SECP256R1 over BN254). Field elements are
+//! represented as multi-limb values (N limbs of `limb_bits` each), and
+//! arithmetic is verified via schoolbook column equations.
+//!
+//! ## Key techniques
+//!
+//! - **Multi-limb arithmetic**: field elements split into N limbs; add/sub use
+//!   carry chains with boolean quotients, mul uses schoolbook column equations,
+//!   all verified mod the curve's base field.
+//! - **Hint-verified EC ops** (N ≥ 2): point_double, point_add, and on-curve
+//!   checks use prover hints verified via schoolbook column equations, avoiding
+//!   the step-by-step MultiLimbOps chain (which requires field inversions).
+//! - **FakeGLV**: same half-GCD scalar decomposition as the native path.
+//! - **Signed-digit windows**: w-bit windows produce signed odd digits d ∈ {±1,
+//!   ±3, ..., ±(2^w-1)}, eliminating zero-digit handling and halving the lookup
+//!   table to 2^(w-1) entries. Skew correction applied post-loop.
+//! - **Merged doublings**: all points share w doublings per window, saving w ×
+//!   (n_points - 1) doublings per window.
+//!
+//! ## Phases
+//!
+//! 1. **Preprocessing**: per-point sanitization, limb decomposition of point
+//!    coordinates, on-curve checks, FakeGLV decomposition, signed-bit
+//!    decomposition, y-negation via `negate_mod_p_multi`.
+//! 2. **Merged scalar mul**: `scalar_mul_merged_glv` runs a single windowed
+//!    loop with shared doublings + per-point signed table lookups. Identity
+//!    check: final accumulator must equal the known offset.
+//! 3. **Scalar relations**: per-point verification that (-1)^neg1·|s1| +
+//!    (-1)^neg2·|s2|·s ≡ 0 (mod curve_order).
+//! 4. **Accumulation**: adds each point's scalar-mul result (via dispatch to
+//!    hint-verified or generic add), subtracts offset, constrains outputs.
+//!
+//! ## I/O modes
+//!
+//! Two entry points share a single core implementation via `NonNativeIo`:
+//! - `process_multi_point_non_native`: point coordinates as single field
+//!   elements (decomposed to limbs internally); outputs as single witnesses.
+//! - `process_multi_point_non_native_limbed`: point coordinates pre-decomposed
+//!   as limbs; outputs constrained per-limb.
+
+use {
+    super::{
+        curve, ec_points,
+        multi_limb_ops::{MultiLimbOps, MultiLimbParams},
+        sanitize::{
+            emit_ec_scalar_mul_hint_and_sanitize_multi_limb, emit_fakeglv_hint,
+            sanitize_point_scalar_multi_limb,
+        },
+        scalar_relation, Limbs, MsmLimbedOutputs,
+    },
+    crate::{
+        constraint_helpers::{
+            add_constant_witness, constrain_equal, constrain_to_constant, select_witness,
+        },
+        digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder},
+        noir_to_r1cs::NoirToR1CSCompiler,
+    },
+    ark_ff::{AdditiveGroup, Field},
+    curve::{decompose_to_limbs as decompose_to_limbs_pub, CurveParams},
+    provekit_common::{witness::SumTerm, FieldElement},
+    std::collections::BTreeMap,
+};
+
+// ---------------------------------------------------------------------------
+// IO mode: single field-element outputs vs. per-limb outputs
+// ---------------------------------------------------------------------------
+
+/// Distinguishes the two non-native MSM I/O modes.
+///
+/// - `SingleFe`: point coordinates come as single field elements (decomposed to
+///   limbs internally); outputs are single witnesses `(out_x, out_y, out_inf)`.
+/// - `Limbed`: point coordinates arrive pre-decomposed as limbs (stride
+///   `2*num_limbs + 1` per point); outputs are per-limb witnesses.
+enum NonNativeIo<'a> {
+    // Single-witness outputs (x, y, inf).
+    SingleFe { outputs: (usize, usize, usize) },
+    // Borrowed per-limb output descriptor; caller retains ownership.
+    Limbed { outputs: &'a MsmLimbedOutputs },
+}
+
+// ---------------------------------------------------------------------------
+// Public entry points (thin wrappers around the shared core)
+// ---------------------------------------------------------------------------
+
+/// Multi-point non-native MSM with single field-element I/O.
+///
+/// All points share a single set of doublings per window, saving
+/// `w × (n_points - 1)` doublings per window compared to separate loops.
+///
+/// Thin wrapper: forwards to `process_non_native_core` with
+/// `NonNativeIo::SingleFe`.
+pub(super) fn process_multi_point_non_native<'a>(
+    compiler: &'a mut NoirToR1CSCompiler,
+    point_wits: &[usize],
+    scalar_wits: &[usize],
+    outputs: (usize, usize, usize),
+    n_points: usize,
+    num_limbs: usize,
+    limb_bits: u32,
+    window_size: usize,
+    range_checks: &'a mut BTreeMap<u32, Vec<usize>>,
+    curve: &CurveParams,
+) {
+    process_non_native_core(
+        compiler,
+        point_wits,
+        scalar_wits,
+        NonNativeIo::SingleFe { outputs },
+        n_points,
+        num_limbs,
+        limb_bits,
+        window_size,
+        range_checks,
+        curve,
+    );
+}
+
+/// Multi-point non-native MSM with per-limb I/O.
+///
+/// Point coordinates are provided as limbs (stride `2*num_limbs + 1` per
+/// point), avoiding the single-field-element bottleneck. Output coordinates
+/// are constrained per-limb rather than recomposed.
+///
+/// Thin wrapper: forwards to `process_non_native_core` with
+/// `NonNativeIo::Limbed`.
+pub(super) fn process_multi_point_non_native_limbed<'a>(
+    compiler: &'a mut NoirToR1CSCompiler,
+    point_wits: &[usize],
+    scalar_wits: &[usize],
+    outputs: &MsmLimbedOutputs,
+    n_points: usize,
+    num_limbs: usize,
+    limb_bits: u32,
+    window_size: usize,
+    range_checks: &'a mut BTreeMap<u32, Vec<usize>>,
+    curve: &CurveParams,
+) {
+    process_non_native_core(
+        compiler,
+        point_wits,
+        scalar_wits,
+        NonNativeIo::Limbed { outputs },
+        n_points,
+        num_limbs,
+        limb_bits,
+        window_size,
+        range_checks,
+        curve,
+    );
+}
+
+// ---------------------------------------------------------------------------
+// Core implementation
+// ---------------------------------------------------------------------------
+
+/// Unified non-native MSM implementation.
+///
+/// The only differences between single-FE and limbed I/O modes are:
+/// 1. **Phase 1**: how point coordinates are extracted from `point_wits`
+///    (decompose from single FE vs. read pre-decomposed limbs).
+/// 2. **Phase 4**: how output coordinates are constrained (recompose to single
+///    FE vs. constrain per-limb).
+///
+/// Phases 2 (merged scalar mul) and 3 (scalar relations) are identical.
+fn process_non_native_core<'a>(
+    // `compiler` / `range_checks` are rebindable `mut` references so they can
+    // be moved into temporary `MultiLimbOps` bundles and recovered afterwards
+    // via `ops.compiler` / `ops.range_checks`.
+    mut compiler: &'a mut NoirToR1CSCompiler,
+    point_wits: &[usize],
+    scalar_wits: &[usize],
+    io: NonNativeIo<'_>,
+    n_points: usize,
+    num_limbs: usize,
+    limb_bits: u32,
+    window_size: usize,
+    mut range_checks: &'a mut BTreeMap<u32, Vec<usize>>,
+    curve: &CurveParams,
+) {
+    let one = compiler.witness_one();
+    let zero_witness = add_constant_witness(compiler, FieldElement::ZERO);
+
+    // Generator as limbs — avoids truncation when generator coords > BN254_Fr
+    let gen_x_fe_limbs = decompose_to_limbs_pub(&curve.generator.0, limb_bits, num_limbs);
+    let gen_y_fe_limbs = decompose_to_limbs_pub(&curve.generator.1, limb_bits, num_limbs);
+    let gen_x_limb_wits: Vec<usize> = gen_x_fe_limbs
+        .iter()
+        .map(|&v| add_constant_witness(compiler, v))
+        .collect();
+    let gen_y_limb_wits: Vec<usize> = gen_y_fe_limbs
+        .iter()
+        .map(|&v| add_constant_witness(compiler, v))
+        .collect();
+
+    // Build params once for all multi-limb ops
+    let params = MultiLimbParams::for_field_modulus(num_limbs, limb_bits, curve);
+    let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs);
+
+    // Offset point as limbs for accumulation
+    let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs);
+    let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs);
+
+    // Track all_skipped = product of all is_skip flags
+    let mut all_skipped: Option<usize> = None;
+    let mut merged_points: Vec<ec_points::MergedGlvPoint> = Vec::new();
+    let mut scalar_rel_inputs: Vec<(usize, usize, usize, usize, usize, usize)> = Vec::new();
+    let mut accum_inputs: Vec<(Limbs, Limbs, usize)> = Vec::new();
+
+    // Phase 1: Per-point preprocessing
+    for i in 0..n_points {
+        // Extract point limbs and inf flag — differs by IO mode
+        let (px_limbs, py_limbs, inf_flag) = match &io {
+            NonNativeIo::SingleFe { .. } => {
+                let (px, py) = decompose_point_to_limbs(
+                    compiler,
+                    point_wits[3 * i],
+                    point_wits[3 * i + 1],
+                    num_limbs,
+                    limb_bits,
+                    range_checks,
+                );
+                (px, py, point_wits[3 * i + 2])
+            }
+            NonNativeIo::Limbed { .. } => {
+                let stride = 2 * num_limbs + 1;
+                let base = i * stride;
+                let mut px = Limbs::new(num_limbs);
+                let mut py = Limbs::new(num_limbs);
+                for j in 0..num_limbs {
+                    px[j] = point_wits[base + j];
+                    py[j] = point_wits[base + num_limbs + j];
+                    // Caller-provided limbs are untrusted: range-check each one.
+                    range_checks.entry(limb_bits).or_default().push(px[j]);
+                    range_checks.entry(limb_bits).or_default().push(py[j]);
+                }
+                (px, py, point_wits[base + 2 * num_limbs])
+            }
+        };
+
+        // Sanitize at the limb level — per-limb select between input and generator
+        let san = sanitize_point_scalar_multi_limb(
+            compiler,
+            px_limbs,
+            py_limbs,
+            scalar_wits[2 * i],
+            scalar_wits[2 * i + 1],
+            inf_flag,
+            &gen_x_limb_wits,
+            &gen_y_limb_wits,
+            zero_witness,
+            one,
+        );
+
+        // Track all_skipped
+        all_skipped = Some(match all_skipped {
+            None => san.is_skip,
+            Some(prev) => compiler.add_product(prev, san.is_skip),
+        });
+
+        // EcScalarMulHint with multi-limb inputs/outputs
+        let (rx, ry) = emit_ec_scalar_mul_hint_and_sanitize_multi_limb(
+            compiler,
+            &san,
+            &gen_x_limb_wits,
+            &gen_y_limb_wits,
+            num_limbs,
+            limb_bits,
+            range_checks,
+            curve,
+        );
+
+        // Sanitized px/py are already in Limbs form
+        let px = san.px_limbs;
+        let py = san.py_limbs;
+
+        // On-curve checks: use hint-verified for multi-limb, generic for single-limb
+        if num_limbs >= 2 {
+            ec_points::verify_on_curve_non_native(compiler, range_checks, px, py, &params);
+            ec_points::verify_on_curve_non_native(compiler, range_checks, rx, ry, &params);
+        } else {
+            let mut ops = MultiLimbOps {
+                compiler,
+                range_checks,
+                params: &params,
+            };
+            verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs);
+            verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs);
+            compiler = ops.compiler;
+            range_checks = ops.range_checks;
+        }
+
+        // FakeGLV, bit decomposition, y-negation (via MultiLimbOps)
+        let (s1_witness, s2_witness, neg1_witness, neg2_witness);
+        let (py_effective, ry_effective, s1_bits, s2_bits, s1_skew, s2_skew);
+        {
+            let mut ops = MultiLimbOps {
+                compiler,
+                range_checks,
+                params: &params,
+            };
+
+            // FakeGLVHint → |s1|, |s2|, neg1, neg2
+            (s1_witness, s2_witness, neg1_witness, neg2_witness) =
+                emit_fakeglv_hint(ops.compiler, san.s_lo, san.s_hi, curve);
+
+            // Signed-bit decomposition of |s1|, |s2| — produces signed digits
+            // d_i = 2*b_i - 1 ∈ {-1, +1} with skew correction, matching the
+            // native path's wNAF approach.
+            let half_bits = curve.glv_half_bits() as usize;
+            (s1_bits, s1_skew) = super::decompose_signed_bits(ops.compiler, s1_witness, half_bits);
+            (s2_bits, s2_skew) = super::decompose_signed_bits(ops.compiler, s2_witness, half_bits);
+
+            // Conditionally negate y-coordinates
+            let neg_py = ops.negate(py);
+            let neg_ry = ops.negate(ry);
+            py_effective = ops.select(neg1_witness, py, neg_py);
+            ry_effective = ops.select(neg2_witness, ry, neg_ry);
+
+            compiler = ops.compiler;
+            range_checks = ops.range_checks;
+        }
+
+        merged_points.push(ec_points::MergedGlvPoint {
+            px,
+            py: py_effective,
+            s1_bits,
+            s1_skew,
+            rx,
+            ry: ry_effective,
+            s2_bits,
+            s2_skew,
+        });
+
+        scalar_rel_inputs.push((
+            san.s_lo,
+            san.s_hi,
+            s1_witness,
+            s2_witness,
+            neg1_witness,
+            neg2_witness,
+        ));
+        accum_inputs.push((rx, ry, san.is_skip));
+    }
+
+    // Phase 2: Per-point scalar mul verification
+    //
+    // Each point gets its own accumulator and identity check. This ensures
+    // per-point soundness: b_i * (R_i - scalar_i * P_i) = O with b_i ≠ 0
+    // implies R_i = scalar_i * P_i for each point independently.
+    let half_bits = curve.glv_half_bits() as usize;
+    {
+        let mut ops = MultiLimbOps {
+            compiler,
+            range_checks,
+            params: &params,
+        };
+
+        // Precompute the expected accumulated offset (same for all points)
+        // glv_num_windows = ceil(half_bits / window_size); each window costs
+        // window_size doublings.
+        let glv_num_windows = (half_bits + window_size - 1) / window_size;
+        let glv_n_doublings = glv_num_windows * window_size;
+        let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings);
+        let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs);
+        let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs);
+
+        for pt in &merged_points {
+            let offset_x = ops.constant_limbs(&offset_x_values);
+            let offset_y = ops.constant_limbs(&offset_y_values);
+
+            let glv_acc = ec_points::scalar_mul_merged_glv(
+                &mut ops,
+                std::slice::from_ref(pt),
+                window_size,
+                offset_x,
+                offset_y,
+            );
+
+            // Per-point identity check
+            for j in 0..num_limbs {
+                constrain_to_constant(ops.compiler, glv_acc.0[j], acc_off_x_values[j]);
+                constrain_to_constant(ops.compiler, glv_acc.1[j], acc_off_y_values[j]);
+            }
+        }
+
+        compiler = ops.compiler;
+        range_checks = ops.range_checks;
+    }
+
+    // Phase 3: Per-point scalar relations
+    for &(s_lo, s_hi, s1, s2, neg1, neg2) in &scalar_rel_inputs {
+        scalar_relation::verify_scalar_relation(
+            compiler,
+            range_checks,
+            s_lo,
+            s_hi,
+            s1,
+            s2,
+            neg1,
+            neg2,
+            curve,
+        );
+    }
+
+    // Phase 4: Accumulation + output constraining
+    let all_skipped = all_skipped.expect("MSM must have at least one point");
+
+    let mut ops = MultiLimbOps {
+        compiler,
+        range_checks,
+        params: &params,
+    };
+    let mut acc_x = ops.constant_limbs(&offset_x_values);
+    let mut acc_y = ops.constant_limbs(&offset_y_values);
+
+    for &(rx, ry, is_skip) in &accum_inputs {
+        let (cand_x, cand_y) = ec_points::point_add_dispatch(&mut ops, acc_x, acc_y, rx, ry);
+        // Keep the old accumulator when this point is degenerate (is_skip = 1).
+        let (new_acc_x, new_acc_y) =
+            ec_points::point_select_unchecked(&mut ops, is_skip, (cand_x, cand_y), (acc_x, acc_y));
+        acc_x = new_acc_x;
+        acc_y = new_acc_y;
+    }
+
+    // Offset subtraction
+    let neg_offset_y_raw =
+        curve::negate_field_element(&curve.offset_point.1, &curve.field_modulus_p);
+    let neg_offset_y_values = curve::decompose_to_limbs(&neg_offset_y_raw, limb_bits, num_limbs);
+
+    let gen_x_limb_values = curve.generator_x_limbs(limb_bits, num_limbs);
+    let neg_gen_y_raw = curve::negate_field_element(&curve.generator.1, &curve.field_modulus_p);
+    let neg_gen_y_values = curve::decompose_to_limbs(&neg_gen_y_raw, limb_bits, num_limbs);
+
+    // NOTE(review): as in the native path, when all_skipped = 1 the subtrahend
+    // is switched to -G so the add stays non-degenerate; output is masked to
+    // zero below — confirm.
+    let sub_x = {
+        let off_x = ops.constant_limbs(&offset_x_values);
+        let g_x = ops.constant_limbs(&gen_x_limb_values);
+        ops.select(all_skipped, off_x, g_x)
+    };
+    let sub_y = {
+        let neg_off_y = ops.constant_limbs(&neg_offset_y_values);
+        let neg_g_y = ops.constant_limbs(&neg_gen_y_values);
+        ops.select(all_skipped, neg_off_y, neg_g_y)
+    };
+
+    let (result_x, result_y) = ec_points::point_add_dispatch(&mut ops, acc_x, acc_y, sub_x, sub_y);
+    compiler = ops.compiler;
+
+    // Output constraining — differs by IO mode
+    match &io {
+        NonNativeIo::SingleFe {
+            outputs: (out_x, out_y, out_inf),
+        } => {
+            if num_limbs == 1 {
+                let masked_x = select_witness(compiler, all_skipped, result_x[0], zero_witness);
+                let masked_y = select_witness(compiler, all_skipped, result_y[0], zero_witness);
+                constrain_equal(compiler, *out_x, masked_x);
+                constrain_equal(compiler, *out_y, masked_y);
+            } else {
+                let recomposed_x = recompose_limbs(compiler, result_x.as_slice(), limb_bits);
+                let recomposed_y = recompose_limbs(compiler, result_y.as_slice(), limb_bits);
+                let masked_x = select_witness(compiler, all_skipped, recomposed_x, zero_witness);
+                let masked_y = select_witness(compiler, all_skipped, recomposed_y, zero_witness);
+                constrain_equal(compiler, *out_x, masked_x);
+                constrain_equal(compiler, *out_y, masked_y);
+            }
+            constrain_equal(compiler, *out_inf, all_skipped);
+        }
+        NonNativeIo::Limbed { outputs } => {
+            let zero_limb_wits: Vec<usize> = (0..num_limbs)
+                .map(|_| add_constant_witness(compiler, FieldElement::ZERO))
+                .collect();
+            for j in 0..num_limbs {
+                let masked_x =
+                    select_witness(compiler, all_skipped, result_x[j], zero_limb_wits[j]);
+                let masked_y =
+                    select_witness(compiler, all_skipped, result_y[j], zero_limb_wits[j]);
+                constrain_equal(compiler, outputs.out_x_limbs[j], masked_x);
+                constrain_equal(compiler, outputs.out_y_limbs[j], masked_y);
+            }
+            constrain_equal(compiler, outputs.out_inf, all_skipped);
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+/// On-curve check: verifies y^2 = x^3 + a*x + b for a single point.
+fn verify_on_curve(
+    ops: &mut MultiLimbOps,
+    x: Limbs,
+    y: Limbs,
+    b_limb_values: &[FieldElement],
+    num_limbs: usize,
+) {
+    let y_sq = ops.mul(y, y);
+    let x_sq = ops.mul(x, x);
+    let x_cubed = ops.mul(x_sq, x);
+    let a = ops.curve_a();
+    let ax = ops.mul(a, x);
+    let x3_plus_ax = ops.add(x_cubed, ax);
+    let b = ops.constant_limbs(b_limb_values);
+    let rhs = ops.add(x3_plus_ax, b);
+    // Equality is enforced limb-wise; both sides are canonical mod-p results.
+    for i in 0..num_limbs {
+        constrain_equal(ops.compiler, y_sq[i], rhs[i]);
+    }
+}
+
+/// Decompose a point (px_witness, py_witness) into Limbs.
+fn decompose_point_to_limbs(
+    compiler: &mut NoirToR1CSCompiler,
+    px_witness: usize,
+    py_witness: usize,
+    num_limbs: usize,
+    limb_bits: u32,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+) -> (Limbs, Limbs) {
+    if num_limbs == 1 {
+        // Single-limb case: the witness itself is the limb; no decomposition.
+        (Limbs::single(px_witness), Limbs::single(py_witness))
+    } else {
+        let px_limbs =
+            decompose_witness_to_limbs(compiler, px_witness, limb_bits, num_limbs, range_checks);
+        let py_limbs =
+            decompose_witness_to_limbs(compiler, py_witness, limb_bits, num_limbs, range_checks);
+        (px_limbs, py_limbs)
+    }
+}
+
+/// Decompose a single witness into `num_limbs` limbs using digital
+/// decomposition.
+fn decompose_witness_to_limbs(
+    compiler: &mut NoirToR1CSCompiler,
+    witness: usize,
+    limb_bits: u32,
+    num_limbs: usize,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+) -> Limbs {
+    let log_bases = vec![limb_bits as usize; num_limbs];
+    let dd = add_digital_decomposition(compiler, log_bases, vec![witness]);
+    let mut limbs = Limbs::new(num_limbs);
+    for i in 0..num_limbs {
+        limbs[i] = dd.get_digit_witness_index(i, 0);
+        // Range-check each decomposed limb to [0, 2^limb_bits).
+        // add_digital_decomposition constrains the recomposition but does
+        // NOT range-check individual digits.
+        range_checks.entry(limb_bits).or_default().push(limbs[i]);
+    }
+    limbs
+}
+
+/// Recompose limbs back into a single witness: val = Σ limb\[i\] *
+/// 2^(i*limb_bits)
+fn recompose_limbs(compiler: &mut NoirToR1CSCompiler, limbs: &[usize], limb_bits: u32) -> usize {
+    // NOTE(review): assumes the weighted sum does not wrap the native field;
+    // holds when limbs are range-checked and num_limbs * limb_bits stays below
+    // the field size — confirm at call sites.
+    let terms: Vec<SumTerm> = limbs
+        .iter()
+        .enumerate()
+        .map(|(i, &limb)| {
+            // coeff = 2^(i * limb_bits), computed by field exponentiation.
+            let coeff = FieldElement::from(2u64).pow([(i as u64) * (limb_bits as u64)]);
+            SumTerm(Some(coeff), limb)
+        })
+        .collect();
+    compiler.add_sum(terms)
+}
diff --git a/provekit/r1cs-compiler/src/msm/sanitize.rs b/provekit/r1cs-compiler/src/msm/sanitize.rs
new file mode 100644
index 000000000..10ece47fd
--- /dev/null
+++ b/provekit/r1cs-compiler/src/msm/sanitize.rs
@@ -0,0 +1,213 @@
+//! Degenerate-case detection and sanitization helpers for MSM point-scalar
+//! pairs.
+
+use {
+    super::{curve::CurveParams, Limbs},
+    crate::{
+        constraint_helpers::{compute_boolean_or, constrain_boolean, select_witness},
+        msm::multi_limb_arith::compute_is_zero,
+        noir_to_r1cs::NoirToR1CSCompiler,
+    },
+    ark_ff::Field,
+    provekit_common::{
+        witness::{SumTerm, WitnessBuilder},
+        FieldElement,
+    },
+};
+
+/// Detects whether a point-scalar pair is degenerate (scalar=0 or point at
+/// infinity). Constrains `inf_flag` to boolean. Returns `is_skip` (1 if
+/// degenerate).
fn detect_skip(
    compiler: &mut NoirToR1CSCompiler,
    s_lo: usize,
    s_hi: usize,
    inf_flag: usize,
) -> usize {
    // The infinity flag is caller-supplied; pin it to {0, 1} before use.
    constrain_boolean(compiler, inf_flag);
    // scalar == 0 iff both 128-bit halves are zero.
    let is_zero_s_lo = compute_is_zero(compiler, s_lo);
    let is_zero_s_hi = compute_is_zero(compiler, s_hi);
    let s_is_zero = compiler.add_product(is_zero_s_lo, is_zero_s_hi);
    // skip = (scalar == 0) OR (point at infinity).
    compute_boolean_or(compiler, s_is_zero, inf_flag)
}

/// Sanitized point-scalar inputs after degenerate-case detection.
pub(super) struct SanitizedInputs {
    pub px: usize,      // x-coordinate witness (original or generator)
    pub py: usize,      // y-coordinate witness (original or generator)
    pub s_lo: usize,    // low 128 bits of the scalar (original or 1)
    pub s_hi: usize,    // high 128 bits of the scalar (original or 0)
    pub is_skip: usize, // boolean witness: 1 if the pair was degenerate
}

/// Detects degenerate cases (scalar=0 or point at infinity) and replaces
/// the point with the generator G and scalar with 1 when degenerate.
pub(super) fn sanitize_point_scalar(
    compiler: &mut NoirToR1CSCompiler,
    px: usize,
    py: usize,
    s_lo: usize,
    s_hi: usize,
    inf_flag: usize,
    gen_x: usize,
    gen_y: usize,
    zero: usize,
    one: usize,
) -> SanitizedInputs {
    let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag);
    // Degenerate pairs become (G, 1): a well-defined dummy operand whose
    // contribution is later masked out via `is_skip`.
    SanitizedInputs {
        px: select_witness(compiler, is_skip, px, gen_x),
        py: select_witness(compiler, is_skip, py, gen_y),
        s_lo: select_witness(compiler, is_skip, s_lo, one),
        s_hi: select_witness(compiler, is_skip, s_hi, zero),
        is_skip,
    }
}

/// Negate a y-coordinate and conditionally select based on a sign flag.
/// Returns `(y_eff, neg_y_eff)` where:
/// - if `neg_flag=0`: `y_eff = y`, `neg_y_eff = -y`
/// - if `neg_flag=1`: `y_eff = -y`, `neg_y_eff = y`
pub(super) fn negate_y_signed_native(
    compiler: &mut NoirToR1CSCompiler,
    neg_flag: usize,
    y: usize,
) -> (usize, usize) {
    // `neg_flag` typically comes from an untrusted hint; force it boolean.
    constrain_boolean(compiler, neg_flag);
    // -y as a single-term weighted sum in the native field.
    let neg_y = compiler.add_sum(vec![SumTerm(Some(-FieldElement::ONE), y)]);
    let y_eff = select_witness(compiler, neg_flag, y, neg_y);
    let neg_y_eff = select_witness(compiler, neg_flag, neg_y, y);
    (y_eff, neg_y_eff)
}

/// Emit an `EcScalarMulHint` and sanitize the result point.
/// When `is_skip=1`, the result is swapped to the generator point.
/// Used by the native path where coordinates fit in a single field element.
pub(super) fn emit_ec_scalar_mul_hint_and_sanitize(
    compiler: &mut NoirToR1CSCompiler,
    san: &SanitizedInputs,
    gen_x_witness: usize,
    gen_y_witness: usize,
    curve: &CurveParams,
) -> (usize, usize) {
    // Hint layout: rx at `hint_start`, ry at `hint_start + 1`.
    let hint_start = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint {
        output_start: hint_start,
        px_limbs: vec![san.px],
        py_limbs: vec![san.py],
        s_lo: san.s_lo,
        s_hi: san.s_hi,
        curve_a: curve.curve_a,
        field_modulus_p: curve.field_modulus_p,
        // Native path: a single "limb" holding the whole coordinate.
        num_limbs: 1,
        limb_bits: 0,
    });
    let rx = select_witness(compiler, san.is_skip, hint_start, gen_x_witness);
    let ry = select_witness(compiler, san.is_skip, hint_start + 1, gen_y_witness);
    (rx, ry)
}

/// Allocates a FakeGLV hint and returns `(s1, s2, neg1, neg2)` witness indices.
pub(super) fn emit_fakeglv_hint(
    compiler: &mut NoirToR1CSCompiler,
    s_lo: usize,
    s_hi: usize,
    curve: &CurveParams,
) -> (usize, usize, usize, usize) {
    // The hint writes |s1|, |s2|, neg1, neg2 into four consecutive slots.
    // None of these are constrained here — verify_scalar_relation does that.
    let glv_start = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::FakeGLVHint {
        output_start: glv_start,
        s_lo,
        s_hi,
        curve_order: curve.curve_order_n,
    });
    (glv_start, glv_start + 1, glv_start + 2, glv_start + 3)
}

/// Multi-limb sanitized inputs for non-native MSM.
pub(super) struct SanitizedInputsMultiLimb {
    pub px_limbs: Limbs, // x-coordinate limbs (original or generator)
    pub py_limbs: Limbs, // y-coordinate limbs (original or generator)
    pub s_lo: usize,     // low 128 bits of the scalar (original or 1)
    pub s_hi: usize,     // high 128 bits of the scalar (original or 0)
    pub is_skip: usize,  // boolean witness: 1 if the pair was degenerate
}

/// Sanitize a non-native point-scalar pair at the limb level.
///
/// Detects degenerate cases and replaces the point with the generator
/// (as limbs) and scalar with 1 when degenerate. This avoids truncating
/// generator coordinates that exceed BN254_Fr.
pub(super) fn sanitize_point_scalar_multi_limb(
    compiler: &mut NoirToR1CSCompiler,
    px_limbs: Limbs,
    py_limbs: Limbs,
    s_lo: usize,
    s_hi: usize,
    inf_flag: usize,
    gen_x_limb_wits: &[usize],
    gen_y_limb_wits: &[usize],
    zero: usize,
    one: usize,
) -> SanitizedInputsMultiLimb {
    let n = px_limbs.len();
    let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag);

    // Swap each coordinate limb for the matching generator limb when the
    // pair is degenerate, limb by limb.
    let mut san_px = Limbs::new(n);
    let mut san_py = Limbs::new(n);
    for i in 0..n {
        san_px[i] = select_witness(compiler, is_skip, px_limbs[i], gen_x_limb_wits[i]);
        san_py[i] = select_witness(compiler, is_skip, py_limbs[i], gen_y_limb_wits[i]);
    }

    SanitizedInputsMultiLimb {
        px_limbs: san_px,
        py_limbs: san_py,
        s_lo: select_witness(compiler, is_skip, s_lo, one),
        s_hi: select_witness(compiler, is_skip, s_hi, zero),
        is_skip,
    }
}

/// Emit an `EcScalarMulHint` with multi-limb inputs/outputs and sanitize.
///
/// When `is_skip=1`, each output limb is replaced with the corresponding
/// generator limb. Returns `(rx_limbs, ry_limbs)` as `Limbs`.
pub(super) fn emit_ec_scalar_mul_hint_and_sanitize_multi_limb(
    compiler: &mut NoirToR1CSCompiler,
    san: &SanitizedInputsMultiLimb,
    gen_x_limb_wits: &[usize],
    gen_y_limb_wits: &[usize],
    num_limbs: usize,
    limb_bits: u32,
    range_checks: &mut std::collections::BTreeMap<u32, Vec<usize>>,
    curve: &CurveParams,
) -> (Limbs, Limbs) {
    // Hint layout: rx limbs at [hint_start, hint_start + num_limbs),
    // ry limbs in the following num_limbs slots.
    let hint_start = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint {
        output_start: hint_start,
        px_limbs: san.px_limbs.as_slice().to_vec(),
        py_limbs: san.py_limbs.as_slice().to_vec(),
        s_lo: san.s_lo,
        s_hi: san.s_hi,
        curve_a: curve.curve_a,
        field_modulus_p: curve.field_modulus_p,
        num_limbs: num_limbs as u32,
        limb_bits,
    });

    let mut rx = Limbs::new(num_limbs);
    let mut ry = Limbs::new(num_limbs);
    for i in 0..num_limbs {
        let rx_hint = hint_start + i;
        let ry_hint = hint_start + num_limbs + i;
        // Range-check hint output limbs: the hint is untrusted prover input,
        // so each limb must be bounded to [0, 2^limb_bits).
        range_checks.entry(limb_bits).or_default().push(rx_hint);
        range_checks.entry(limb_bits).or_default().push(ry_hint);
        // Sanitize: select between hint output and generator
        rx[i] = select_witness(compiler, san.is_skip, rx_hint, gen_x_limb_wits[i]);
        ry[i] = select_witness(compiler, san.is_skip, ry_hint, gen_y_limb_wits[i]);
    }

    (rx, ry)
}
diff --git a/provekit/r1cs-compiler/src/msm/scalar_relation.rs b/provekit/r1cs-compiler/src/msm/scalar_relation.rs
new file mode 100644
index 000000000..0b6c1e29b
--- /dev/null
+++ b/provekit/r1cs-compiler/src/msm/scalar_relation.rs
@@ -0,0 +1,254 @@
//! Scalar relation verification: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0
//! (mod n).
//!
//! Shared by both the native and non-native MSM paths.
+ +use { + super::{ + cost_model, curve, + multi_limb_arith::compute_is_zero, + multi_limb_ops::{MultiLimbOps, MultiLimbParams}, + Limbs, + }, + crate::{ + constraint_helpers::constrain_zero, + digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{Field, PrimeField}, + curve::CurveParams, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// The 256-bit scalar is split into two 128-bit halves (s_lo, s_hi) because the +/// full value doesn't fit in the native field. +const SCALAR_HALF_BITS: usize = 128; + +/// Compute digit widths for decomposing `total_bits` into chunks of at most +/// `max_width` bits. The last chunk may be smaller. +fn limb_widths(total_bits: usize, max_width: u32) -> Vec { + let n = (total_bits + max_width as usize - 1) / max_width as usize; + (0..n) + .map(|i| { + let remaining = total_bits - i * max_width as usize; + remaining.min(max_width as usize) + }) + .collect() +} + +/// Verifies the scalar relation: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 +/// (mod n). +/// +/// Uses multi-limb arithmetic with curve_order_n as the modulus. 
pub(super) fn verify_scalar_relation(
    compiler: &mut NoirToR1CSCompiler,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
    s_lo: usize,
    s_hi: usize,
    s1_witness: usize,
    s2_witness: usize,
    neg1_witness: usize,
    neg2_witness: usize,
    curve: &CurveParams,
) {
    // Limb geometry: enough limbs of `limb_bits` bits to cover the curve
    // order, chosen by the cost model for the native field size.
    let order_bits = curve.curve_order_bits() as usize;
    let limb_bits =
        cost_model::scalar_relation_limb_bits(FieldElement::MODULUS_BIT_SIZE, order_bits);
    let num_limbs = (order_bits + limb_bits as usize - 1) / limb_bits as usize;
    let half_bits = curve.glv_half_bits() as usize;

    let params = MultiLimbParams::for_curve_order(num_limbs, limb_bits, curve);
    let mut ops = MultiLimbOps {
        compiler,
        range_checks,
        params: &params,
    };

    // Limb-decompose the full scalar s and the two FakeGLV half-scalars.
    let s_limbs = decompose_scalar_from_halves(&mut ops, s_lo, s_hi, num_limbs, limb_bits);
    let s1_limbs = decompose_half_scalar(&mut ops, s1_witness, num_limbs, half_bits, limb_bits);
    let s2_limbs = decompose_half_scalar(&mut ops, s2_witness, num_limbs, half_bits, limb_bits);

    let product = ops.mul(s2_limbs, s_limbs);

    // Sign handling: when signs match check s1+product=0, otherwise s1-product=0.
    // XOR = neg1 + neg2 - 2*neg1*neg2 gives 0 for same signs, 1 for different.
    let sum = ops.add(s1_limbs, product);
    let diff = ops.sub(s1_limbs, product);

    let xor_prod = ops.compiler.add_product(neg1_witness, neg2_witness);
    let xor = ops.compiler.add_sum(vec![
        SumTerm(None, neg1_witness),
        SumTerm(None, neg2_witness),
        SumTerm(Some(-FieldElement::from(2u64)), xor_prod),
    ]);

    // effective = sum when xor=0 (same signs), diff when xor=1.
    let effective = ops.select_unchecked(xor, sum, diff);
    for i in 0..num_limbs {
        constrain_zero(ops.compiler, effective[i]);
    }

    // Soundness: s2 must be non-zero. If s2=0 the relation degenerates to
    // s1≡0 (mod n) which is trivially satisfiable with s1=0, leaving the
    // hint-supplied result point R unconstrained.
    let s2_is_zero = compute_is_zero(ops.compiler, s2_witness);
    constrain_zero(ops.compiler, s2_is_zero);
}

/// Decompose a 256-bit scalar from two 128-bit halves into `num_limbs` limbs.
///
/// When `limb_bits` divides 128 (e.g. 64), limb boundaries align with the
/// s_lo/s_hi split. Otherwise (e.g. 85-bit limbs), one limb straddles bit 128
/// and is assembled from a partial s_lo digit and a partial s_hi digit.
///
/// For small curves where `num_limbs * limb_bits < 256`, the digits beyond the
/// used limbs are constrained to zero. This ensures the scalar fits in the
/// representation and prevents truncation attacks.
fn decompose_scalar_from_halves(
    ops: &mut MultiLimbOps,
    s_lo: usize,
    s_hi: usize,
    num_limbs: usize,
    limb_bits: u32,
) -> Limbs {
    // Bits of s_lo left over after whole limb_bits chunks.
    let lo_tail = SCALAR_HALF_BITS % limb_bits as usize;

    if lo_tail == 0 {
        // Aligned case: limb boundaries coincide with the 128-bit split, so
        // both halves use the same digit widths.
        let widths = limb_widths(SCALAR_HALF_BITS, limb_bits);
        let dd_lo = add_digital_decomposition(ops.compiler, widths.clone(), vec![s_lo]);
        let dd_hi = add_digital_decomposition(ops.compiler, widths.clone(), vec![s_hi]);
        let mut limbs = Limbs::new(num_limbs);
        let from_lo = widths.len().min(num_limbs);
        for (i, &w) in widths.iter().enumerate().take(from_lo) {
            limbs[i] = dd_lo.get_digit_witness_index(i, 0);
            ops.range_checks.entry(w as u32).or_default().push(limbs[i]);
        }
        let from_hi = (num_limbs - from_lo).min(widths.len());
        for (i, &w) in widths.iter().enumerate().take(from_hi) {
            limbs[from_lo + i] = dd_hi.get_digit_witness_index(i, 0);
            ops.range_checks
                .entry(w as u32)
                .or_default()
                .push(limbs[from_lo + i]);
        }

        // Constrain unused dd_lo digits to zero (small curves where num_limbs
        // covers fewer than 128 bits of s_lo).
        for i in from_lo..widths.len() {
            constrain_zero(ops.compiler, dd_lo.get_digit_witness_index(i, 0));
        }
        // Constrain unused dd_hi digits to zero (small curves where the upper
        // half is partially or entirely unused).
        for i in from_hi..widths.len() {
            constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(i, 0));
        }

        limbs
    } else {
        // Example: 85-bit limbs, 254-bit order →
        //   s_lo DD [85, 43], s_hi DD [42, 86]
        //   L0 = s_lo[0..85), L1 = s_lo[85..128) | s_hi[0..42), L2 = s_hi[42..128)
        // NOTE(review): in this example hi_rest = 86 > limb_bits = 85, so the
        // top limb is range-checked one bit wider than the nominal limb width
        // — confirm the multi-limb arithmetic tolerates this slack.
        let hi_head = limb_bits as usize - lo_tail;
        let hi_rest = SCALAR_HALF_BITS - hi_head;
        let lo_full = SCALAR_HALF_BITS / limb_bits as usize;

        let lo_widths = limb_widths(SCALAR_HALF_BITS, limb_bits);
        let hi_widths = vec![hi_head, hi_rest];

        let dd_lo = add_digital_decomposition(ops.compiler, lo_widths.clone(), vec![s_lo]);
        let dd_hi = add_digital_decomposition(ops.compiler, hi_widths, vec![s_hi]);
        let mut limbs = Limbs::new(num_limbs);

        // Full-width limbs taken entirely from s_lo.
        let lo_used = lo_full.min(num_limbs);
        for i in 0..lo_used {
            limbs[i] = dd_lo.get_digit_witness_index(i, 0);
            ops.range_checks
                .entry(limb_bits)
                .or_default()
                .push(limbs[i]);
        }

        // Cross-boundary limb and hi_rest, only if num_limbs needs them.
        let needs_cross = num_limbs > lo_full;
        let needs_hi_rest = num_limbs > lo_full + 1 && hi_rest > 0;

        if needs_cross {
            // Straddling limb = (lo_tail bits of s_lo) + 2^lo_tail * (hi_head
            // bits of s_hi); both partial digits are range-checked at their
            // exact widths so the sum cannot overflow the limb.
            let shift = FieldElement::from(2u64).pow([lo_tail as u64]);
            let lo_digit = dd_lo.get_digit_witness_index(lo_full, 0);
            let hi_digit = dd_hi.get_digit_witness_index(0, 0);
            limbs[lo_full] = ops.compiler.add_sum(vec![
                SumTerm(None, lo_digit),
                SumTerm(Some(shift), hi_digit),
            ]);
            ops.range_checks
                .entry(lo_tail as u32)
                .or_default()
                .push(lo_digit);
            ops.range_checks
                .entry(hi_head as u32)
                .or_default()
                .push(hi_digit);
        }

        if needs_hi_rest {
            limbs[lo_full + 1] = dd_hi.get_digit_witness_index(1, 0);
            ops.range_checks
                .entry(hi_rest as u32)
                .or_default()
                .push(limbs[lo_full + 1]);
        }

        // Constrain unused digits to zero for small curves.
        // dd_lo: digits beyond lo_used that aren't part of the cross-boundary.
        if !needs_cross {
            // The tail digit of dd_lo and all dd_hi digits are unused.
            for i in lo_used..lo_widths.len() {
                constrain_zero(ops.compiler, dd_lo.get_digit_witness_index(i, 0));
            }
            constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(0, 0));
            if hi_rest > 0 {
                constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(1, 0));
            }
        } else if !needs_hi_rest && hi_rest > 0 {
            // Cross-boundary is used but hi_rest digit is unused.
            constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(1, 0));
        }

        limbs
    }
}

/// Decompose a half-scalar witness into `num_limbs` limbs, zero-padding the
/// upper limbs beyond `half_bits`.
fn decompose_half_scalar(
    ops: &mut MultiLimbOps,
    witness: usize,
    num_limbs: usize,
    half_bits: usize,
    limb_bits: u32,
) -> Limbs {
    let widths = limb_widths(half_bits, limb_bits);
    let dd = add_digital_decomposition(ops.compiler, widths.clone(), vec![witness]);
    let mut limbs = Limbs::new(num_limbs);

    for (i, &w) in widths.iter().enumerate() {
        limbs[i] = dd.get_digit_witness_index(i, 0);
        ops.range_checks.entry(w as u32).or_default().push(limbs[i]);
    }

    // Pad the remaining high limbs with constant-zero witnesses so the
    // result always has exactly num_limbs entries.
    for i in widths.len()..num_limbs {
        let w = ops.compiler.num_witnesses();
        ops.compiler
            .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(
                w,
                FieldElement::from(0u64),
            )));
        limbs[i] = w;
        constrain_zero(ops.compiler, limbs[i]);
    }

    limbs
}
diff --git a/provekit/r1cs-compiler/src/msm/tests.rs b/provekit/r1cs-compiler/src/msm/tests.rs
new file mode 100644
index 000000000..ea326ac15
--- /dev/null
+++ b/provekit/r1cs-compiler/src/msm/tests.rs
@@ -0,0 +1,107 @@
use {
    super::*, crate::noir_to_r1cs::NoirToR1CSCompiler,
    provekit_common::witness::ConstantOrR1CSWitness, std::collections::BTreeMap,
};

/// Verify that the non-native (SECP256R1) single-point MSM path generates
/// constraints without panicking. This exercises multi-limb arithmetic,
/// range checks, and FakeGLV verification — the entire non-native code path
/// that has no Noir e2e coverage for now.
#[test]
fn test_secp256r1_single_point_msm_generates_constraints() {
    let mut compiler = NoirToR1CSCompiler::new();
    let curve = curve::secp256r1_params();
    let mut range_checks: BTreeMap<u32, Vec<usize>> = BTreeMap::new();

    // Allocate witness slots for: px, py, inf, s_lo, s_hi, out_x, out_y, out_inf
    // (witness 0 is the constant-one witness)
    let base = compiler.num_witnesses();
    compiler.r1cs.add_witnesses(8);
    let px = base;
    let py = base + 1;
    let inf = base + 2;
    let s_lo = base + 3;
    let s_hi = base + 4;
    let out_x = base + 5;
    let out_y = base + 6;
    let out_inf = base + 7;

    let points = vec![
        ConstantOrR1CSWitness::Witness(px),
        ConstantOrR1CSWitness::Witness(py),
        ConstantOrR1CSWitness::Witness(inf),
    ];
    let scalars = vec![
        ConstantOrR1CSWitness::Witness(s_lo),
        ConstantOrR1CSWitness::Witness(s_hi),
    ];
    let msm_ops = vec![(points, scalars, (out_x, out_y, out_inf))];

    add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve);

    let n_constraints = compiler.r1cs.num_constraints();
    let n_witnesses = compiler.num_witnesses();

    // Smoke-test thresholds: exact counts are implementation-defined, but a
    // non-native MSM circuit must be substantial.
    assert!(
        n_constraints > 100,
        "expected substantial constraints for non-native MSM, got {n_constraints}"
    );
    assert!(
        n_witnesses > 100,
        "expected substantial witnesses for non-native MSM, got {n_witnesses}"
    );
    assert!(
        !range_checks.is_empty(),
        "non-native MSM should produce range checks"
    );
}

/// Verify that the non-native multi-point MSM path (2 points, SECP256R1)
/// generates constraints. Exercises the multi-point accumulation and offset
/// subtraction logic for the non-native path.
#[test]
fn test_secp256r1_multi_point_msm_generates_constraints() {
    let mut compiler = NoirToR1CSCompiler::new();
    let curve = curve::secp256r1_params();
    let mut range_checks: BTreeMap<u32, Vec<usize>> = BTreeMap::new();

    // 2 points: px1, py1, inf1, px2, py2, inf2, s1_lo, s1_hi, s2_lo, s2_hi,
    // out_x, out_y, out_inf
    let base = compiler.num_witnesses();
    compiler.r1cs.add_witnesses(13);

    let points = vec![
        ConstantOrR1CSWitness::Witness(base), // px1
        ConstantOrR1CSWitness::Witness(base + 1), // py1
        ConstantOrR1CSWitness::Witness(base + 2), // inf1
        ConstantOrR1CSWitness::Witness(base + 3), // px2
        ConstantOrR1CSWitness::Witness(base + 4), // py2
        ConstantOrR1CSWitness::Witness(base + 5), // inf2
    ];
    let scalars = vec![
        ConstantOrR1CSWitness::Witness(base + 6), // s1_lo
        ConstantOrR1CSWitness::Witness(base + 7), // s1_hi
        ConstantOrR1CSWitness::Witness(base + 8), // s2_lo
        ConstantOrR1CSWitness::Witness(base + 9), // s2_hi
    ];
    let out_x = base + 10;
    let out_y = base + 11;
    let out_inf = base + 12;

    let msm_ops = vec![(points, scalars, (out_x, out_y, out_inf))];

    add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve);

    let n_constraints = compiler.r1cs.num_constraints();
    let n_witnesses = compiler.num_witnesses();

    // Multi-point should produce more constraints than single-point
    assert!(
        n_constraints > 200,
        "expected substantial constraints for 2-point non-native MSM, got {n_constraints}"
    );
    assert!(
        n_witnesses > 200,
        "expected substantial witnesses for 2-point non-native MSM, got {n_witnesses}"
    );
}
diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs b/provekit/r1cs-compiler/src/noir_to_r1cs.rs
index 189eb4693..961edb248 100644
--- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs
+++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs
@@ -3,6 +3,7 @@ use {
     binops::add_combined_binop_constraints,
     digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder},
     memory::{add_ram_checking,
add_rom_checking, MemoryBlock, MemoryOperation}, + msm::add_msm, poseidon2::add_poseidon2_permutation, range_check::add_range_checks, sha256_compression::add_sha256_compression, @@ -88,12 +89,17 @@ pub struct R1CSBreakdown { pub poseidon2_constraints: usize, /// Witnesses from Poseidon2 permutation pub poseidon2_witnesses: usize, + + /// Constraints from multi-scalar multiplication + pub msm_constraints: usize, + /// Witnesses from multi-scalar multiplication + pub msm_witnesses: usize, } /// Compiles an ACIR circuit into an [R1CS] instance, comprising of the A, B, /// and C R1CS matrices, along with the witness vector. -pub(crate) struct NoirToR1CSCompiler { - pub(crate) r1cs: R1CS, +pub struct NoirToR1CSCompiler { + pub r1cs: R1CS, /// Indicates how to solve for each R1CS witness pub witness_builders: Vec, @@ -136,7 +142,7 @@ pub fn noir_to_r1cs_with_breakdown( } impl NoirToR1CSCompiler { - pub(crate) fn new() -> Self { + pub fn new() -> Self { let mut r1cs = R1CS::new(); // Grow the matrices to account for the constant one witness. 
r1cs.add_witnesses(1); @@ -457,6 +463,7 @@ impl NoirToR1CSCompiler { let mut xor_ops = vec![]; let mut sha256_compression_ops = vec![]; let mut poseidon2_ops = vec![]; + let mut msm_ops = vec![]; let mut breakdown = R1CSBreakdown::default(); @@ -627,6 +634,24 @@ impl NoirToR1CSCompiler { output_witnesses, )); } + BlackBoxFuncCall::MultiScalarMul { + points, + scalars, + outputs, + } => { + let point_wits: Vec = points + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let scalar_wits: Vec = scalars + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let out_x = self.fetch_r1cs_witness_index(outputs.0); + let out_y = self.fetch_r1cs_witness_index(outputs.1); + let out_inf = self.fetch_r1cs_witness_index(outputs.2); + msm_ops.push((point_wits, scalar_wits, (out_x, out_y, out_inf))); + } _ => { unimplemented!("Other black box function: {:?}", black_box_func_call); } @@ -718,6 +743,12 @@ impl NoirToR1CSCompiler { breakdown.poseidon2_constraints = self.r1cs.num_constraints() - constraints_before_poseidon; breakdown.poseidon2_witnesses = self.num_witnesses() - witnesses_before_poseidon; + let constraints_before_msm = self.r1cs.num_constraints(); + let witnesses_before_msm = self.num_witnesses(); + add_msm(self, msm_ops, &mut range_checks); + breakdown.msm_constraints = self.r1cs.num_constraints() - constraints_before_msm; + breakdown.msm_witnesses = self.num_witnesses() - witnesses_before_msm; + breakdown.range_ops_total = range_checks.values().map(|v| v.len()).sum(); let constraints_before_range = self.r1cs.num_constraints(); let witnesses_before_range = self.num_witnesses(); diff --git a/provekit/r1cs-compiler/src/range_check.rs b/provekit/r1cs-compiler/src/range_check.rs index f76fe94c3..763576ac4 100644 --- a/provekit/r1cs-compiler/src/range_check.rs +++ b/provekit/r1cs-compiler/src/range_check.rs @@ -139,13 +139,46 @@ fn get_optimal_base_width(collected: &[RangeCheckRequest]) -> u32 { 
optimal_width } +/// Estimates total witness cost for resolving range checks without +/// constructing actual R1CS constraints. +/// +/// Takes a map of `bit_width → count` (number of witnesses needing that +/// range check). Uses the same optimal-base-width search and +/// LogUp-vs-naive cost model as [`add_range_checks`], but operates on +/// aggregate counts rather than concrete witness indices. +pub(crate) fn estimate_range_check_cost(checks: &BTreeMap) -> usize { + if checks.is_empty() { + return 0; + } + + // Create synthetic RangeCheckRequests with unique dummy indices. + let mut collected: Vec = Vec::new(); + let mut dummy_idx = 0usize; + for (&bits, &count) in checks { + for _ in 0..count { + collected.push(RangeCheckRequest { + witness_idx: dummy_idx, + bits, + }); + dummy_idx += 1; + } + } + + if collected.is_empty() { + return 0; + } + + let base_width = get_optimal_base_width(&collected); + calculate_witness_cost(base_width, &collected) +} + /// Add witnesses and constraints that ensure that the values of the witness /// belong to a range 0..2^k (for some k). /// /// Uses dynamic base width optimization: all range check requests are /// collected, and the optimal decomposition base width is determined by /// minimizing the total witness count (memory cost). The search evaluates -/// every base width from [MIN_BASE_WIDTH] to [MAX_BASE_WIDTH]. For each +/// every base width from \[MIN_BASE_WIDTH\] to \[MAX_BASE_WIDTH\]. For each /// candidate, the cost model picks the cheaper of LogUp and naive for /// every atomic bucket. /// @@ -156,7 +189,7 @@ fn get_optimal_base_width(collected: &[RangeCheckRequest]) -> u32 { /// /// `range_checks` is a map from the number of bits k to the vector of /// witness indices that are to be constrained within the range [0..2^k]. 
-pub(crate) fn add_range_checks( +pub fn add_range_checks( r1cs: &mut NoirToR1CSCompiler, range_checks: BTreeMap>, ) -> Option { diff --git a/tooling/provekit-bench/Cargo.toml b/tooling/provekit-bench/Cargo.toml index 3b1993b2a..8791ea151 100644 --- a/tooling/provekit-bench/Cargo.toml +++ b/tooling/provekit-bench/Cargo.toml @@ -16,6 +16,7 @@ provekit-r1cs-compiler.workspace = true provekit-verifier.workspace = true # Noir language +acir.workspace = true nargo.workspace = true nargo_cli.workspace = true nargo_toml.workspace = true diff --git a/tooling/provekit-bench/tests/compiler.rs b/tooling/provekit-bench/tests/compiler.rs index 4746ffe04..0f621293d 100644 --- a/tooling/provekit-bench/tests/compiler.rs +++ b/tooling/provekit-bench/tests/compiler.rs @@ -23,7 +23,7 @@ struct NargoTomlPackage { name: String, } -fn test_noir_compiler(test_case_path: impl AsRef) { +fn test_noir_compiler(test_case_path: impl AsRef, witness_file: &str) { let test_case_path = test_case_path.as_ref(); compile_workspace(test_case_path).expect("Compiling workspace"); @@ -36,7 +36,7 @@ fn test_noir_compiler(test_case_path: impl AsRef) { let package_name = nargo_toml.package.name; let circuit_path = test_case_path.join(format!("target/{package_name}.json")); - let witness_file_path = test_case_path.join("Prover.toml"); + let witness_file_path = test_case_path.join(witness_file); let schema = NoirCompiler::from_file(&circuit_path, provekit_common::HashConfig::default()) .expect("Reading proof scheme"); @@ -69,19 +69,57 @@ pub fn compile_workspace(workspace_path: impl AsRef) -> Result Ok(workspace) } -#[test_case("../../noir-examples/noir-r1cs-test-programs/acir_assert_zero")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/simplest-read-only-memory")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/read-only-memory")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/range-check-u8")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/range-check-u16")] 
-#[test_case("../../noir-examples/noir-r1cs-test-programs/range-check-mixed-bases")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/read-write-memory")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/conditional-write")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/bin-opcode")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/small-sha")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/bounded-vec")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/brillig-unconstrained")] -#[test_case("../../noir-examples/noir-passport-monolithic/complete_age_check"; "complete_age_check")] -fn case_noir(path: &str) { - test_noir_compiler(path); +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/acir_assert_zero", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/simplest-read-only-memory", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/read-only-memory", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/range-check-u8", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/range-check-u16", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/range-check-mixed-bases", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/read-write-memory", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/conditional-write", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/bin-opcode", + "Prover.toml" +)] +#[test_case("../../noir-examples/noir-r1cs-test-programs/small-sha", "Prover.toml")] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/bounded-vec", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/brillig-unconstrained", + "Prover.toml" +)] +#[test_case("../../noir-examples/noir-passport-monolithic/complete_age_check", "Prover.toml"; 
"complete_age_check")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover.toml"; "embedded_curve_msm")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_zero_scalars.toml"; "msm_zero_scalars")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_single_nonzero.toml"; "msm_single_nonzero")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_near_order.toml"; "msm_near_order")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_near_identity.toml"; "msm_near_identity")] +fn case_noir(path: &str, witness_file: &str) { + test_noir_compiler(path, witness_file); } diff --git a/tooling/provekit-bench/tests/msm_witness_solving.rs b/tooling/provekit-bench/tests/msm_witness_solving.rs new file mode 100644 index 000000000..f8787c0f5 --- /dev/null +++ b/tooling/provekit-bench/tests/msm_witness_solving.rs @@ -0,0 +1,512 @@ +//! End-to-end MSM witness solving tests for non-native curves (secp256r1). +//! +//! These tests verify that the full pipeline works correctly: +//! 1. Compile MSM circuit (R1CS + witness builders) +//! 2. Set initial witness values (point coordinates as limbs + scalar) +//! 3. Solve all derived witnesses via the witness builder layer scheduler +//! 4. Check R1CS satisfaction: A·w ⊙ B·w = C·w for all constraints +//! +//! All tests use the **limbed API** (`add_msm_with_curve_limbed`) where +//! point coordinates are multi-limb witnesses, supporting arbitrary +//! secp256r1 coordinates (including those exceeding BN254 Fr). 
use {
    acir::native_types::WitnessMap,
    ark_ff::{PrimeField, Zero},
    provekit_common::{
        witness::{ConstantOrR1CSWitness, LayerScheduler, WitnessBuilder},
        FieldElement, NoirElement, TranscriptSponge,
    },
    provekit_prover::{bigint_mod::ec_scalar_mul, r1cs::solve_witness_vec},
    provekit_r1cs_compiler::{
        msm::{
            add_msm_with_curve_limbed,
            cost_model::get_optimal_msm_params,
            curve::{decompose_to_limbs, secp256r1_params},
            MsmLimbedOutputs,
        },
        noir_to_r1cs::NoirToR1CSCompiler,
        range_check::add_range_checks,
    },
    std::collections::BTreeMap,
    whir::transcript::{codecs::Empty, DomainSeparator, ProverState},
};

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Convert a [u64; 4] to a FieldElement. Panics if value exceeds BN254 Fr.
/// Only used for scalars (128-bit halves that always fit).
fn u256_to_fe(v: &[u64; 4]) -> FieldElement {
    // from_bigint returns None when the integer is >= the field modulus.
    FieldElement::from_bigint(ark_ff::BigInt(*v))
        .unwrap_or_else(|| panic!("Value exceeds BN254 Fr: {v:?}"))
}

/// Split a 256-bit scalar into (lo_128, hi_128) as [u64; 4] values.
fn split_scalar(s: &[u64; 4]) -> ([u64; 4], [u64; 4]) {
    // Words are little-endian: s[0..2] hold the low 128 bits.
    let lo = [s[0], s[1], 0, 0];
    let hi = [s[2], s[3], 0, 0];
    (lo, hi)
}

/// Verify R1CS satisfaction: for each constraint row, A·w * B·w == C·w.
fn check_r1cs_satisfaction(
    r1cs: &provekit_common::R1CS,
    witness: &[FieldElement],
) -> anyhow::Result<()> {
    use anyhow::ensure;

    ensure!(
        witness.len() == r1cs.num_witnesses(),
        "witness size {} != expected {}",
        witness.len(),
        r1cs.num_witnesses()
    );

    // Matrix-vector products over the full witness vector; row i of each
    // result is the evaluation of constraint i's linear combination.
    let a = r1cs.a() * witness;
    let b = r1cs.b() * witness;
    let c = r1cs.c() * witness;
    for (row, ((a_val, b_val), c_val)) in a.into_iter().zip(b).zip(c).enumerate() {
        ensure!(
            a_val * b_val == c_val,
            "Constraint {row} failed: a={a_val:?}, b={b_val:?}, a*b={:?}, c={c_val:?}",
            a_val * b_val
        );
    }
    Ok(())
}

/// Create a dummy transcript for witness solving (no challenges needed).
fn dummy_transcript() -> ProverState {
    let ds = DomainSeparator::protocol(&()).instance(&Empty);
    ProverState::new(&ds, TranscriptSponge::default())
}

/// Solve all witness builders given initial witness values.
///
/// `initial_values` are (index, value) pairs for the circuit inputs; every
/// remaining witness must be derivable by some builder — otherwise this
/// panics with the index of the first unsolved witness.
fn solve_witnesses(
    builders: &[WitnessBuilder],
    num_witnesses: usize,
    initial_values: &[(usize, FieldElement)],
) -> Vec<FieldElement> {
    let layers = LayerScheduler::new(builders).build_layers();
    let mut witness: Vec<Option<FieldElement>> = vec![None; num_witnesses];

    for &(idx, val) in initial_values {
        witness[idx] = Some(val);
    }

    // No ACIR-sourced witnesses and no real transcript challenges are
    // needed for these tests.
    let acir_map = WitnessMap::<NoirElement>::new();
    let mut transcript = dummy_transcript();
    solve_witness_vec(&mut witness, layers, &acir_map, &mut transcript);

    witness
        .into_iter()
        .enumerate()
        .map(|(i, w)| w.unwrap_or_else(|| panic!("Witness {i} was not solved")))
        .collect()
}

/// Compute the (num_limbs, limb_bits) that the compiler will use for this
/// curve, so the test can decompose coordinates the same way.
+fn msm_params_for_curve( + curve: &provekit_r1cs_compiler::msm::curve::CurveParams, + n_points: usize, +) -> (usize, u32) { + let native_bits = FieldElement::MODULUS_BIT_SIZE; + let curve_bits = curve.modulus_bits(); + let (limb_bits, _window_size) = + get_optimal_msm_params(native_bits, curve_bits, n_points, 256, false); + let num_limbs = (curve_bits as usize + limb_bits as usize - 1) / limb_bits as usize; + (num_limbs, limb_bits) +} + +/// Decompose a [u64; 4] value into field-element limbs. +fn u256_to_limb_fes(v: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(v, limb_bits, num_limbs) +} + +// --------------------------------------------------------------------------- +// Single-point limbed MSM test runner +// --------------------------------------------------------------------------- + +/// Compile and solve a single-point MSM circuit using the limbed API. +/// +/// When `expected_inf` is true, the expected output is point at infinity +/// (all output limbs zero, out_inf = 1). 
+fn run_single_point_msm_test_limbed( + px: &[u64; 4], + py: &[u64; 4], + inf: bool, + scalar: &[u64; 4], + expected_x: &[u64; 4], + expected_y: &[u64; 4], + expected_inf: bool, +) { + let curve = secp256r1_params(); + let (num_limbs, limb_bits) = msm_params_for_curve(&curve, 1); + let (s_lo, s_hi) = split_scalar(scalar); + let stride = 2 * num_limbs + 1; + + let px_fes = u256_to_limb_fes(px, limb_bits, num_limbs); + let py_fes = u256_to_limb_fes(py, limb_bits, num_limbs); + let ex_fes = u256_to_limb_fes(expected_x, limb_bits, num_limbs); + let ey_fes = u256_to_limb_fes(expected_y, limb_bits, num_limbs); + + let mut compiler = NoirToR1CSCompiler::new(); + let mut range_checks: BTreeMap> = BTreeMap::new(); + + let base = compiler.num_witnesses(); + let total_input_wits = stride + 2 + stride; + compiler.r1cs.add_witnesses(total_input_wits); + + let points: Vec = (0..stride) + .map(|j| ConstantOrR1CSWitness::Witness(base + j)) + .collect(); + let slo_w = base + stride; + let shi_w = base + stride + 1; + let scalars = vec![ + ConstantOrR1CSWitness::Witness(slo_w), + ConstantOrR1CSWitness::Witness(shi_w), + ]; + + let out_base = base + stride + 2; + let out_x_limbs: Vec = (0..num_limbs).map(|j| out_base + j).collect(); + let out_y_limbs: Vec = (0..num_limbs).map(|j| out_base + num_limbs + j).collect(); + let out_inf = out_base + 2 * num_limbs; + + let outputs = MsmLimbedOutputs { + out_x_limbs: out_x_limbs.clone(), + out_y_limbs: out_y_limbs.clone(), + out_inf, + }; + let msm_ops = vec![(points, scalars, outputs)]; + add_msm_with_curve_limbed(&mut compiler, msm_ops, &mut range_checks, &curve, num_limbs); + add_range_checks(&mut compiler, range_checks); + + let num_witnesses = compiler.num_witnesses(); + + // Set initial witness values + let mut initial_values = vec![(0, FieldElement::from(1u64))]; + for (j, fe) in px_fes.iter().enumerate() { + initial_values.push((base + j, *fe)); + } + for (j, fe) in py_fes.iter().enumerate() { + initial_values.push((base + num_limbs + 
j, *fe)); + } + let inf_fe = if inf { + FieldElement::from(1u64) + } else { + FieldElement::zero() + }; + initial_values.push((base + 2 * num_limbs, inf_fe)); + initial_values.push((slo_w, u256_to_fe(&s_lo))); + initial_values.push((shi_w, u256_to_fe(&s_hi))); + for (j, fe) in ex_fes.iter().enumerate() { + initial_values.push((out_x_limbs[j], *fe)); + } + for (j, fe) in ey_fes.iter().enumerate() { + initial_values.push((out_y_limbs[j], *fe)); + } + let out_inf_fe = if expected_inf { + FieldElement::from(1u64) + } else { + FieldElement::zero() + }; + initial_values.push((out_inf, out_inf_fe)); + + let witness = solve_witnesses(&compiler.witness_builders, num_witnesses, &initial_values); + + check_r1cs_satisfaction(&compiler.r1cs, &witness) + .expect("R1CS satisfaction check failed (limbed)"); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +/// Single-point MSM using the secp256r1 generator directly. +/// The generator's x-coordinate exceeds BN254 Fr. +#[test] +fn test_single_point_generator() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let scalar: [u64; 4] = [7, 0, 0, 0]; + let (ex, ey) = ec_scalar_mul(&gx, &gy, &scalar, &curve.curve_a, &curve.field_modulus_p); + + run_single_point_msm_test_limbed(&gx, &gy, false, &scalar, &ex, &ey, false); +} + +/// Scalar = 1: result should equal the input point. +#[test] +fn test_scalar_one() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let scalar: [u64; 4] = [1, 0, 0, 0]; + + // 1·G = G + run_single_point_msm_test_limbed(&gx, &gy, false, &scalar, &gx, &gy, false); +} + +/// Large scalar spanning both lo and hi halves of the 256-bit representation. 
+#[test] +fn test_large_scalar() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let scalar: [u64; 4] = [0xcafebabe, 0x12345678, 0x42, 0]; + let (ex, ey) = ec_scalar_mul(&gx, &gy, &scalar, &curve.curve_a, &curve.field_modulus_p); + + run_single_point_msm_test_limbed(&gx, &gy, false, &scalar, &ex, &ey, false); +} + +/// Zero scalar: result should be point at infinity. +#[test] +fn test_zero_scalar() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let zero_scalar: [u64; 4] = [0, 0, 0, 0]; + let zero_point: [u64; 4] = [0, 0, 0, 0]; + + run_single_point_msm_test_limbed( + &gx, + &gy, + false, + &zero_scalar, + &zero_point, + &zero_point, + true, + ); +} + +/// Point at infinity as input: result should be point at infinity regardless +/// of scalar. +#[test] +fn test_point_at_infinity_input() { + let curve = secp256r1_params(); + // Use generator coords as placeholder (they're ignored due to inf=1 select) + let gx = curve.generator.0; + let gy = curve.generator.1; + let scalar: [u64; 4] = [42, 0, 0, 0]; + let zero_point: [u64; 4] = [0, 0, 0, 0]; + + run_single_point_msm_test_limbed(&gx, &gy, true, &scalar, &zero_point, &zero_point, true); +} + +/// Non-trivial point (2·G) with a moderate scalar, verifying the full +/// wNAF + FakeGLV pipeline. +#[test] +fn test_arbitrary_point_and_scalar() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let a = &curve.curve_a; + let p = &curve.field_modulus_p; + + // P = 2·G + let (px, py) = ec_scalar_mul(&gx, &gy, &[2, 0, 0, 0], a, p); + let scalar: [u64; 4] = [17, 0, 0, 0]; + // Expected: 17·(2G) = 34G + let (ex, ey) = ec_scalar_mul(&gx, &gy, &[34, 0, 0, 0], a, p); + + run_single_point_msm_test_limbed(&px, &py, false, &scalar, &ex, &ey, false); +} + +/// Two-point MSM: s1·P1 + s2·P2 with arbitrary coordinates. 
+#[test] +fn test_two_point_msm() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let a = &curve.curve_a; + let p = &curve.field_modulus_p; + let (num_limbs, limb_bits) = msm_params_for_curve(&curve, 2); + let stride = 2 * num_limbs + 1; + + // P1 = 3·G, P2 = 5·G + let (p1x, p1y) = ec_scalar_mul(&gx, &gy, &[3, 0, 0, 0], a, p); + let (p2x, p2y) = ec_scalar_mul(&gx, &gy, &[5, 0, 0, 0], a, p); + let s1: [u64; 4] = [2, 0, 0, 0]; + let s2: [u64; 4] = [3, 0, 0, 0]; + // Expected: 2·(3G) + 3·(5G) = 6G + 15G = 21G + let (ex, ey) = ec_scalar_mul(&gx, &gy, &[21, 0, 0, 0], a, p); + + let (s1_lo, s1_hi) = split_scalar(&s1); + let (s2_lo, s2_hi) = split_scalar(&s2); + + let p1x_fes = u256_to_limb_fes(&p1x, limb_bits, num_limbs); + let p1y_fes = u256_to_limb_fes(&p1y, limb_bits, num_limbs); + let p2x_fes = u256_to_limb_fes(&p2x, limb_bits, num_limbs); + let p2y_fes = u256_to_limb_fes(&p2y, limb_bits, num_limbs); + let ex_fes = u256_to_limb_fes(&ex, limb_bits, num_limbs); + let ey_fes = u256_to_limb_fes(&ey, limb_bits, num_limbs); + + let mut compiler = NoirToR1CSCompiler::new(); + let mut range_checks: BTreeMap> = BTreeMap::new(); + + let base = compiler.num_witnesses(); + let total = 2 * stride + 4 + stride; + compiler.r1cs.add_witnesses(total); + + let points: Vec = (0..2 * stride) + .map(|j| ConstantOrR1CSWitness::Witness(base + j)) + .collect(); + let scalar_base = base + 2 * stride; + let scalars = vec![ + ConstantOrR1CSWitness::Witness(scalar_base), + ConstantOrR1CSWitness::Witness(scalar_base + 1), + ConstantOrR1CSWitness::Witness(scalar_base + 2), + ConstantOrR1CSWitness::Witness(scalar_base + 3), + ]; + let out_base = scalar_base + 4; + let out_x_limbs: Vec = (0..num_limbs).map(|j| out_base + j).collect(); + let out_y_limbs: Vec = (0..num_limbs).map(|j| out_base + num_limbs + j).collect(); + let out_inf = out_base + 2 * num_limbs; + + let outputs = MsmLimbedOutputs { + out_x_limbs: out_x_limbs.clone(), + out_y_limbs: 
out_y_limbs.clone(), + out_inf, + }; + let msm_ops = vec![(points, scalars, outputs)]; + add_msm_with_curve_limbed(&mut compiler, msm_ops, &mut range_checks, &curve, num_limbs); + add_range_checks(&mut compiler, range_checks); + + let num_witnesses = compiler.num_witnesses(); + + let mut initial_values = vec![(0, FieldElement::from(1u64))]; + for (j, fe) in p1x_fes.iter().enumerate() { + initial_values.push((base + j, *fe)); + } + for (j, fe) in p1y_fes.iter().enumerate() { + initial_values.push((base + num_limbs + j, *fe)); + } + initial_values.push((base + 2 * num_limbs, FieldElement::zero())); + let p2_base = base + stride; + for (j, fe) in p2x_fes.iter().enumerate() { + initial_values.push((p2_base + j, *fe)); + } + for (j, fe) in p2y_fes.iter().enumerate() { + initial_values.push((p2_base + num_limbs + j, *fe)); + } + initial_values.push((p2_base + 2 * num_limbs, FieldElement::zero())); + initial_values.push((scalar_base, u256_to_fe(&s1_lo))); + initial_values.push((scalar_base + 1, u256_to_fe(&s1_hi))); + initial_values.push((scalar_base + 2, u256_to_fe(&s2_lo))); + initial_values.push((scalar_base + 3, u256_to_fe(&s2_hi))); + for (j, fe) in ex_fes.iter().enumerate() { + initial_values.push((out_x_limbs[j], *fe)); + } + for (j, fe) in ey_fes.iter().enumerate() { + initial_values.push((out_y_limbs[j], *fe)); + } + initial_values.push((out_inf, FieldElement::zero())); + + let witness = solve_witnesses(&compiler.witness_builders, num_witnesses, &initial_values); + + check_r1cs_satisfaction(&compiler.r1cs, &witness) + .expect("R1CS satisfaction check failed for two-point MSM"); +} + +/// Two-point MSM where one scalar is zero — only the non-zero point +/// should contribute. 
#[test]
fn test_two_point_one_zero_scalar() {
    let curve = secp256r1_params();
    let gx = curve.generator.0;
    let gy = curve.generator.1;
    let a = &curve.curve_a;
    let p = &curve.field_modulus_p;
    let (num_limbs, limb_bits) = msm_params_for_curve(&curve, 2);
    // Per-point witness stride: x limbs + y limbs + 1 infinity flag.
    let stride = 2 * num_limbs + 1;

    // P1 = G (scalar=5), P2 = 2G (scalar=0)
    let (p2x, p2y) = ec_scalar_mul(&gx, &gy, &[2, 0, 0, 0], a, p);
    let s1: [u64; 4] = [5, 0, 0, 0];
    let s2: [u64; 4] = [0, 0, 0, 0];
    // Expected: 5·G + 0·(2G) = 5G
    let (ex, ey) = ec_scalar_mul(&gx, &gy, &[5, 0, 0, 0], a, p);

    let (s1_lo, s1_hi) = split_scalar(&s1);
    let (s2_lo, s2_hi) = split_scalar(&s2);

    let p1x_fes = u256_to_limb_fes(&gx, limb_bits, num_limbs);
    let p1y_fes = u256_to_limb_fes(&gy, limb_bits, num_limbs);
    let p2x_fes = u256_to_limb_fes(&p2x, limb_bits, num_limbs);
    let p2y_fes = u256_to_limb_fes(&p2y, limb_bits, num_limbs);
    let ex_fes = u256_to_limb_fes(&ex, limb_bits, num_limbs);
    let ey_fes = u256_to_limb_fes(&ey, limb_bits, num_limbs);

    let mut compiler = NoirToR1CSCompiler::new();
    // NOTE(review): type parameters reconstructed (angle brackets stripped in
    // the reviewed copy) — confirm against `add_range_checks`' signature.
    let mut range_checks: BTreeMap<u32, Vec<usize>> = BTreeMap::new();

    let base = compiler.num_witnesses();
    // Two limbed points, four 128-bit scalar halves, one limbed output point.
    let total = 2 * stride + 4 + stride;
    compiler.r1cs.add_witnesses(total);

    let points: Vec<ConstantOrR1CSWitness> = (0..2 * stride)
        .map(|j| ConstantOrR1CSWitness::Witness(base + j))
        .collect();
    let scalar_base = base + 2 * stride;
    let scalars = vec![
        ConstantOrR1CSWitness::Witness(scalar_base),
        ConstantOrR1CSWitness::Witness(scalar_base + 1),
        ConstantOrR1CSWitness::Witness(scalar_base + 2),
        ConstantOrR1CSWitness::Witness(scalar_base + 3),
    ];
    let out_base = scalar_base + 4;
    let out_x_limbs: Vec<usize> = (0..num_limbs).map(|j| out_base + j).collect();
    let out_y_limbs: Vec<usize> = (0..num_limbs).map(|j| out_base + num_limbs + j).collect();
    let out_inf = out_base + 2 * num_limbs;

    let outputs = MsmLimbedOutputs {
        out_x_limbs: out_x_limbs.clone(),
        out_y_limbs: out_y_limbs.clone(),
        out_inf,
    };
    let msm_ops = vec![(points, scalars, outputs)];
    add_msm_with_curve_limbed(&mut compiler, msm_ops, &mut range_checks, &curve, num_limbs);
    add_range_checks(&mut compiler, range_checks);

    let num_witnesses = compiler.num_witnesses();

    // Witness 0 is the constant-one wire.
    let mut initial_values = vec![(0, FieldElement::from(1u64))];
    // P1 limbs (generator)
    for (j, fe) in p1x_fes.iter().enumerate() {
        initial_values.push((base + j, *fe));
    }
    for (j, fe) in p1y_fes.iter().enumerate() {
        initial_values.push((base + num_limbs + j, *fe));
    }
    initial_values.push((base + 2 * num_limbs, FieldElement::zero()));
    // P2 limbs
    let p2_base = base + stride;
    for (j, fe) in p2x_fes.iter().enumerate() {
        initial_values.push((p2_base + j, *fe));
    }
    for (j, fe) in p2y_fes.iter().enumerate() {
        initial_values.push((p2_base + num_limbs + j, *fe));
    }
    initial_values.push((p2_base + 2 * num_limbs, FieldElement::zero()));
    // Scalars
    initial_values.push((scalar_base, u256_to_fe(&s1_lo)));
    initial_values.push((scalar_base + 1, u256_to_fe(&s1_hi)));
    initial_values.push((scalar_base + 2, u256_to_fe(&s2_lo)));
    initial_values.push((scalar_base + 3, u256_to_fe(&s2_hi)));
    // Expected output limbs
    for (j, fe) in ex_fes.iter().enumerate() {
        initial_values.push((out_x_limbs[j], *fe));
    }
    for (j, fe) in ey_fes.iter().enumerate() {
        initial_values.push((out_y_limbs[j], *fe));
    }
    initial_values.push((out_inf, FieldElement::zero()));

    let witness = solve_witnesses(&compiler.witness_builders, num_witnesses, &initial_values);

    check_r1cs_satisfaction(&compiler.r1cs, &witness)
        .expect("R1CS satisfaction check failed for two-point MSM with one zero scalar");
}