From ad0b5005b229d2d48a02fa43dc60d0f47ef9e775 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 26 Feb 2026 17:57:56 +0530 Subject: [PATCH 01/19] feat : ec_ops added with proper constraints for field element --- .../src/witness/scheduling/dependency.rs | 6 +- .../common/src/witness/scheduling/remapper.rs | 6 + .../common/src/witness/witness_builder.rs | 12 + provekit/prover/src/witness/bigint_mod.rs | 422 ++++++++++++++++++ provekit/prover/src/witness/mod.rs | 1 + .../prover/src/witness/witness_builder.rs | 18 + provekit/r1cs-compiler/src/lib.rs | 1 + provekit/r1cs-compiler/src/msm/curve.rs | 71 +++ provekit/r1cs-compiler/src/msm/ec_ops.rs | 305 +++++++++++++ provekit/r1cs-compiler/src/msm/mod.rs | 2 + provekit/r1cs-compiler/src/noir_to_r1cs.rs | 6 + 11 files changed, 849 insertions(+), 1 deletion(-) create mode 100644 provekit/prover/src/witness/bigint_mod.rs create mode 100644 provekit/r1cs-compiler/src/msm/curve.rs create mode 100644 provekit/r1cs-compiler/src/msm/ec_ops.rs create mode 100644 provekit/r1cs-compiler/src/msm/mod.rs diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index a5cbaefd6..f2d480377 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -78,7 +78,9 @@ impl DependencyInfo { WitnessBuilder::Sum(_, ops) => ops.iter().map(|SumTerm(_, idx)| *idx).collect(), WitnessBuilder::Product(_, a, b) => vec![*a, *b], WitnessBuilder::MultiplicitiesForRange(_, _, values) => values.clone(), - WitnessBuilder::Inverse(_, x) => vec![*x], + WitnessBuilder::Inverse(_, x) + | WitnessBuilder::ModularInverse(_, x, _) + | WitnessBuilder::IntegerQuotient(_, x, _) => vec![*x], WitnessBuilder::IndexedLogUpDenominator( _, sz, @@ -240,6 +242,8 @@ impl DependencyInfo { | WitnessBuilder::Challenge(idx) | WitnessBuilder::IndexedLogUpDenominator(idx, ..) | WitnessBuilder::Inverse(idx, _) + | WitnessBuilder::ModularInverse(idx, ..) 
+ | WitnessBuilder::IntegerQuotient(idx, ..) | WitnessBuilder::ProductLinearOperation(idx, ..) | WitnessBuilder::LogUpDenominator(idx, ..) | WitnessBuilder::LogUpInverse(idx, ..) diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 9503847a3..76c30bd23 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -115,6 +115,12 @@ impl WitnessIndexRemapper { WitnessBuilder::Inverse(idx, operand) => { WitnessBuilder::Inverse(self.remap(*idx), self.remap(*operand)) } + WitnessBuilder::ModularInverse(idx, operand, modulus) => { + WitnessBuilder::ModularInverse(self.remap(*idx), self.remap(*operand), *modulus) + } + WitnessBuilder::IntegerQuotient(idx, dividend, divisor) => { + WitnessBuilder::IntegerQuotient(self.remap(*idx), self.remap(*dividend), *divisor) + } WitnessBuilder::ProductLinearOperation( idx, ProductLinearTerm(x, a, b), diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 0628fc2e3..5719ee99e 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -88,6 +88,18 @@ pub enum WitnessBuilder { /// The inverse of the value at a specified witness index /// (witness index, operand witness index) Inverse(usize, usize), + /// The modular inverse of the value at a specified witness index, modulo + /// a given prime modulus. Computes a^{-1} mod m using Fermat's little + /// theorem (a^{m-2} mod m). Unlike Inverse (BN254 field inverse), this + /// operates as integer modular arithmetic. + /// (witness index, operand witness index, modulus) + ModularInverse(usize, usize, #[serde(with = "serde_ark")] FieldElement), + /// The integer quotient floor(dividend / divisor). Used by reduce_mod to + /// compute k = floor(v / m) so that v = k*m + result with 0 <= result < m. 
+ /// Unlike field multiplication by the inverse, this performs true integer + /// division on the BigInteger representation. + /// (witness index, dividend witness index, divisor constant) + IntegerQuotient(usize, usize, #[serde(with = "serde_ark")] FieldElement), /// Products with linear operations on the witness indices. /// Fields are ProductLinearOperation(witness_idx, (index, a, b), (index, c, /// d)) such that we wish to compute (ax + b) * (cx + d). diff --git a/provekit/prover/src/witness/bigint_mod.rs b/provekit/prover/src/witness/bigint_mod.rs new file mode 100644 index 000000000..3252aeea2 --- /dev/null +++ b/provekit/prover/src/witness/bigint_mod.rs @@ -0,0 +1,422 @@ +/// BigInteger modular arithmetic on [u64; 4] limbs (256-bit). +/// +/// These helpers compute modular inverse via Fermat's little theorem: +/// a^{-1} = a^{m-2} mod m, using schoolbook multiplication and +/// square-and-multiply exponentiation. + +/// Schoolbook multiplication: 4×4 limbs → 8 limbs (256-bit × 256-bit → +/// 512-bit). +fn widening_mul(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let product = (a[i] as u128) * (b[j] as u128) + (result[i + j] as u128) + carry; + result[i + j] = product as u64; + carry = product >> 64; + } + result[i + 4] = carry as u64; + } + result +} + +/// Compare 8-limb value with 4-limb value (zero-extended to 8 limbs). +/// Returns Ordering::Greater if wide > narrow, etc. 
+#[cfg(test)] +fn cmp_wide_narrow(wide: &[u64; 8], narrow: &[u64; 4]) -> std::cmp::Ordering { + // Check high limbs of wide (must all be zero for equality/less) + for i in (4..8).rev() { + if wide[i] != 0 { + return std::cmp::Ordering::Greater; + } + } + // Compare the low 4 limbs + for i in (0..4).rev() { + match wide[i].cmp(&narrow[i]) { + std::cmp::Ordering::Equal => continue, + other => return other, + } + } + std::cmp::Ordering::Equal +} + +/// Modular reduction of a 512-bit value by a 256-bit modulus. +/// Uses bit-by-bit long division. +fn reduce_wide(wide: &[u64; 8], modulus: &[u64; 4]) -> [u64; 4] { + // Find the highest set bit in wide + let mut highest_bit = 0; + for i in (0..8).rev() { + if wide[i] != 0 { + highest_bit = i * 64 + (64 - wide[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return [0u64; 4]; + } + + // Bit-by-bit long division + // remainder starts at 0, we shift in bits from the dividend + let mut remainder = [0u64; 4]; + for bit_pos in (0..highest_bit).rev() { + // Left-shift remainder by 1 + let carry = shift_left_one(&mut remainder); + debug_assert_eq!(carry, 0, "remainder overflow during shift"); + + // Bring in the next bit from wide + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + let bit = (wide[limb_idx] >> bit_idx) & 1; + remainder[0] |= bit; + + // If remainder >= modulus, subtract + if cmp_4limb(&remainder, modulus) != std::cmp::Ordering::Less { + sub_4limb_inplace(&mut remainder, modulus); + } + } + + remainder +} + +/// Left-shift a 4-limb number by 1 bit. Returns the carry-out bit. +fn shift_left_one(a: &mut [u64; 4]) -> u64 { + let mut carry = 0u64; + for limb in a.iter_mut() { + let new_carry = *limb >> 63; + *limb = (*limb << 1) | carry; + carry = new_carry; + } + carry +} + +/// Compare two 4-limb numbers. 
+fn cmp_4limb(a: &[u64; 4], b: &[u64; 4]) -> std::cmp::Ordering { + for i in (0..4).rev() { + match a[i].cmp(&b[i]) { + std::cmp::Ordering::Equal => continue, + other => return other, + } + } + std::cmp::Ordering::Equal +} + +/// Subtract b from a in-place (a -= b). Assumes a >= b. +fn sub_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) { + let mut borrow = 0u64; + for i in 0..4 { + let (diff, borrow1) = a[i].overflowing_sub(b[i]); + let (diff2, borrow2) = diff.overflowing_sub(borrow); + a[i] = diff2; + borrow = (borrow1 as u64) + (borrow2 as u64); + } + debug_assert_eq!(borrow, 0, "subtraction underflow: a < b"); +} + +/// Modular multiplication: (a * b) mod m. +pub fn mul_mod(a: &[u64; 4], b: &[u64; 4], m: &[u64; 4]) -> [u64; 4] { + let wide = widening_mul(a, b); + reduce_wide(&wide, m) +} + +/// Modular exponentiation: base^exp mod m using square-and-multiply. +pub fn mod_pow(base: &[u64; 4], exp: &[u64; 4], m: &[u64; 4]) -> [u64; 4] { + // Find highest set bit in exp + let mut highest_bit = 0; + for i in (0..4).rev() { + if exp[i] != 0 { + highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + // exp == 0 → result = 1 (for m > 1) + return [1, 0, 0, 0]; + } + + let mut result = [1u64, 0, 0, 0]; // 1 + for bit_pos in (0..highest_bit).rev() { + // Square + result = mul_mod(&result, &result, m); + // Multiply if bit is set + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + if (exp[limb_idx] >> bit_idx) & 1 == 1 { + result = mul_mod(&result, base, m); + } + } + + result +} + +/// Integer division with remainder: dividend = quotient * divisor + remainder, +/// where 0 <= remainder < divisor. Uses bit-by-bit long division. 
+pub fn divmod(dividend: &[u64; 4], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) { + // Find the highest set bit in dividend + let mut highest_bit = 0; + for i in (0..4).rev() { + if dividend[i] != 0 { + highest_bit = i * 64 + (64 - dividend[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return ([0u64; 4], [0u64; 4]); + } + + let mut quotient = [0u64; 4]; + let mut remainder = [0u64; 4]; + + for bit_pos in (0..highest_bit).rev() { + // Left-shift remainder by 1 + let carry = shift_left_one(&mut remainder); + debug_assert_eq!(carry, 0, "remainder overflow during shift"); + + // Bring in the next bit from dividend + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + remainder[0] |= (dividend[limb_idx] >> bit_idx) & 1; + + // If remainder >= divisor, subtract and set quotient bit + if cmp_4limb(&remainder, divisor) != std::cmp::Ordering::Less { + sub_4limb_inplace(&mut remainder, divisor); + quotient[limb_idx] |= 1u64 << bit_idx; + } + } + + (quotient, remainder) +} + +/// Subtract a small u64 value from a 4-limb number. Assumes a >= small. 
+pub fn sub_u64(a: &[u64; 4], small: u64) -> [u64; 4] { + let mut result = *a; + let (diff, borrow) = result[0].overflowing_sub(small); + result[0] = diff; + if borrow { + for limb in result[1..].iter_mut() { + let (d, b) = limb.overflowing_sub(1); + *limb = d; + if !b { + break; + } + } + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_widening_mul_small() { + // 3 * 7 = 21 + let a = [3, 0, 0, 0]; + let b = [7, 0, 0, 0]; + let result = widening_mul(&a, &b); + assert_eq!(result[0], 21); + assert_eq!(result[1..], [0; 7]); + } + + #[test] + fn test_widening_mul_overflow() { + // u64::MAX * u64::MAX = (2^64-1)^2 = 2^128 - 2^65 + 1 + let a = [u64::MAX, 0, 0, 0]; + let b = [u64::MAX, 0, 0, 0]; + let result = widening_mul(&a, &b); + // (2^64-1)^2 = 0xFFFFFFFFFFFFFFFE_0000000000000001 + assert_eq!(result[0], 1); + assert_eq!(result[1], u64::MAX - 1); + assert_eq!(result[2..], [0; 6]); + } + + #[test] + fn test_reduce_wide_no_reduction() { + // 5 mod 7 = 5 + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [5, 0, 0, 0]); + } + + #[test] + fn test_reduce_wide_basic() { + // 10 mod 7 = 3 + let wide = [10, 0, 0, 0, 0, 0, 0, 0]; + let modulus = [7, 0, 0, 0]; + assert_eq!(reduce_wide(&wide, &modulus), [3, 0, 0, 0]); + } + + #[test] + fn test_mul_mod_small() { + // (5 * 3) mod 7 = 15 mod 7 = 1 + let a = [5, 0, 0, 0]; + let b = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mul_mod(&a, &b, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_mod_pow_small() { + // 3^4 mod 7 = 81 mod 7 = 4 + let base = [3, 0, 0, 0]; + let exp = [4, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [4, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_small() { + // Inverse of 3 mod 7: 3^{7-2} = 3^5 mod 7 = 243 mod 7 = 5 + // Check: 3 * 5 = 15 = 2*7 + 1 ≡ 1 (mod 7) ✓ + let a = [3, 0, 0, 0]; + let m = [7, 0, 0, 0]; + let exp = sub_u64(&m, 2); // m - 2 = 5 + let inv = mod_pow(&a, &exp, 
&m); + assert_eq!(inv, [5, 0, 0, 0]); + // Verify: a * inv mod m = 1 + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_prime_23() { + // Inverse of 5 mod 23: 5^{21} mod 23 + // 5^{-1} mod 23 = 14 (because 5*14 = 70 = 3*23 + 1) + let a = [5, 0, 0, 0]; + let m = [23, 0, 0, 0]; + let exp = sub_u64(&m, 2); + let inv = mod_pow(&a, &exp, &m); + assert_eq!(inv, [14, 0, 0, 0]); + assert_eq!(mul_mod(&a, &inv, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_basic() { + assert_eq!(sub_u64(&[10, 0, 0, 0], 3), [7, 0, 0, 0]); + } + + #[test] + fn test_sub_u64_borrow() { + // [0, 1, 0, 0] = 2^64; subtract 1 → [u64::MAX, 0, 0, 0] + assert_eq!(sub_u64(&[0, 1, 0, 0], 1), [u64::MAX, 0, 0, 0]); + } + + #[test] + fn test_fermat_inverse_large_prime() { + // Use a 128-bit prime: p = 2^127 - 1 = 170141183460469231731687303715884105727 + // In limbs: [u64::MAX, 2^63 - 1, 0, 0] + let p = [u64::MAX, (1u64 << 63) - 1, 0, 0]; + + // a = 42 + let a = [42, 0, 0, 0]; + let exp = sub_u64(&p, 2); + let inv = mod_pow(&a, &exp, &p); + + // Verify: a * inv mod p = 1 + assert_eq!(mul_mod(&a, &inv, &p), [1, 0, 0, 0]); + } + + #[test] + fn test_cmp_wide_narrow() { + let wide = [5, 0, 0, 0, 0, 0, 0, 0]; + let narrow = [5, 0, 0, 0]; + assert_eq!(cmp_wide_narrow(&wide, &narrow), std::cmp::Ordering::Equal); + + let wide_greater = [0, 0, 0, 0, 1, 0, 0, 0]; + assert_eq!( + cmp_wide_narrow(&wide_greater, &narrow), + std::cmp::Ordering::Greater + ); + } + + #[test] + fn test_mod_pow_zero_exp() { + // a^0 mod m = 1 + let base = [42, 0, 0, 0]; + let exp = [0, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [1, 0, 0, 0]); + } + + #[test] + fn test_mod_pow_one_exp() { + // a^1 mod m = a mod m + let base = [10, 0, 0, 0]; + let exp = [1, 0, 0, 0]; + let m = [7, 0, 0, 0]; + assert_eq!(mod_pow(&base, &exp, &m), [3, 0, 0, 0]); + } + + #[test] + fn test_divmod_exact() { + // 21 / 7 = 3 remainder 0 + let (q, r) = divmod(&[21, 0, 0, 0], &[7, 0, 0, 0]); 
+ assert_eq!(q, [3, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_with_remainder() { + // 17 / 7 = 2 remainder 3 + let (q, r) = divmod(&[17, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [2, 0, 0, 0]); + assert_eq!(r, [3, 0, 0, 0]); + } + + #[test] + fn test_divmod_smaller_dividend() { + // 5 / 7 = 0 remainder 5 + let (q, r) = divmod(&[5, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [0, 0, 0, 0]); + assert_eq!(r, [5, 0, 0, 0]); + } + + #[test] + fn test_divmod_zero_dividend() { + let (q, r) = divmod(&[0, 0, 0, 0], &[7, 0, 0, 0]); + assert_eq!(q, [0, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_large() { + // 2^64 / 3 = 6148914691236517205 remainder 1 + // 2^64 in limbs: [0, 1, 0, 0] + let (q, r) = divmod(&[0, 1, 0, 0], &[3, 0, 0, 0]); + assert_eq!(q, [6148914691236517205, 0, 0, 0]); + assert_eq!(r, [1, 0, 0, 0]); + // Verify: q * 3 + 1 = 2^64 + assert_eq!(6148914691236517205u64 * 3 + 1, 0u64); // wraps to 0 in u64 = + // 2^64 + } + + #[test] + fn test_divmod_consistency() { + // Verify dividend = quotient * divisor + remainder for various inputs + let cases: Vec<([u64; 4], [u64; 4])> = vec![ + ([100, 0, 0, 0], [7, 0, 0, 0]), + ([u64::MAX, 0, 0, 0], [1000, 0, 0, 0]), + ([0, 1, 0, 0], [u64::MAX, 0, 0, 0]), // 2^64 / (2^64 - 1) + ]; + for (dividend, divisor) in cases { + let (q, r) = divmod(÷nd, &divisor); + // Verify: q * divisor + r = dividend + let product = widening_mul(&q, &divisor); + // Add remainder to product + let mut sum = product; + let mut carry = 0u128; + for i in 0..4 { + let s = (sum[i] as u128) + (r[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = (sum[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + // sum should equal dividend (zero-extended to 8 limbs) + let mut expected = [0u64; 8]; + expected[..4].copy_from_slice(÷nd); + assert_eq!(sum, expected, "dividend={dividend:?} divisor={divisor:?}"); + } + } +} diff --git 
a/provekit/prover/src/witness/mod.rs b/provekit/prover/src/witness/mod.rs index 5f5de8f0b..fb5072440 100644 --- a/provekit/prover/src/witness/mod.rs +++ b/provekit/prover/src/witness/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod bigint_mod; mod digits; mod ram; pub(crate) mod witness_builder; diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index db91e5e0a..115637da6 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -65,6 +65,24 @@ impl WitnessBuilderSolver for WitnessBuilder { "Inverse/LogUpInverse should not be called - handled by batch inversion" ) } + WitnessBuilder::ModularInverse(witness_idx, operand_idx, modulus) => { + let a = witness[*operand_idx].unwrap(); + let a_limbs = a.into_bigint().0; + let m_limbs = modulus.into_bigint().0; + // Fermat's little theorem: a^{-1} = a^{m-2} mod m + let exp = crate::witness::bigint_mod::sub_u64(&m_limbs, 2); + let result_limbs = crate::witness::bigint_mod::mod_pow(&a_limbs, &exp, &m_limbs); + witness[*witness_idx] = + Some(FieldElement::from_bigint(ark_ff::BigInt(result_limbs)).unwrap()); + } + WitnessBuilder::IntegerQuotient(witness_idx, dividend_idx, divisor) => { + let dividend = witness[*dividend_idx].unwrap(); + let d_limbs = dividend.into_bigint().0; + let m_limbs = divisor.into_bigint().0; + let (quotient, _remainder) = crate::witness::bigint_mod::divmod(&d_limbs, &m_limbs); + witness[*witness_idx] = + Some(FieldElement::from_bigint(ark_ff::BigInt(quotient)).unwrap()); + } WitnessBuilder::IndexedLogUpDenominator( witness_idx, sz_challenge, diff --git a/provekit/r1cs-compiler/src/lib.rs b/provekit/r1cs-compiler/src/lib.rs index 7de8f899b..a1ee6b1ce 100644 --- a/provekit/r1cs-compiler/src/lib.rs +++ b/provekit/r1cs-compiler/src/lib.rs @@ -1,6 +1,7 @@ mod binops; mod digits; mod memory; +mod msm; mod noir_proof_scheme; mod noir_to_r1cs; mod poseidon2; diff --git 
a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs new file mode 100644 index 000000000..cbfd1bf25 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -0,0 +1,71 @@ +use provekit_common::FieldElement; + +// TODO : remove Option<> form both the params if comes in use +// otherwise we delete the params from struct +pub struct CurveParams { + pub field_modulus_p: FieldElement, + pub curve_order_n: FieldElement, + pub curve_a: FieldElement, + pub curve_b: FieldElement, + pub generator: (FieldElement, FieldElement), + pub coordinate_bits: Option, +} + +pub fn secp256r1_params() -> CurveParams { + CurveParams { + field_modulus_p: FieldElement::from_sign_and_limbs( + true, + [ + 0xffffffffffffffff_u64, + 0xffffffff_u64, + 0x0_u64, + 0xffffffff00000001_u64, + ] + .as_slice(), + ), + curve_order_n: FieldElement::from_sign_and_limbs( + true, + [ + 0xf3b9cac2fc632551_u64, + 0xbce6faada7179e84_u64, + 0xffffffffffffffff_u64, + 0xffffffff00000000_u64, + ] + .as_slice(), + ), + curve_a: FieldElement::from(-3), + curve_b: FieldElement::from_sign_and_limbs( + true, + [ + 0x3bce3c3e27d2604b_u64, + 0x651d06b0cc53b0f6_u64, + 0xb3ebbd55769886bc_u64, + 0x5ac635d8aa3a93e7_u64, + ] + .as_slice(), + ), + generator: ( + FieldElement::from_sign_and_limbs( + true, + [ + 0xf4a13945d898c296_u64, + 0x77037d812deb33a0_u64, + 0xf8bce6e563a440f2_u64, + 0x6b17d1f2e12c4247_u64, + ] + .as_slice(), + ), + FieldElement::from_sign_and_limbs( + true, + [ + 0xcbb6406837bf51f5_u64, + 0x2bce33576b315ece_u64, + 0x8ee7eb4a7c0f9e16_u64, + 0x4fe342e2fe1a7f9b_u64, + ] + .as_slice(), + ), + ), + coordinate_bits: None, + } +} diff --git a/provekit/r1cs-compiler/src/msm/ec_ops.rs b/provekit/r1cs-compiler/src/msm/ec_ops.rs new file mode 100644 index 000000000..1bd96e69f --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_ops.rs @@ -0,0 +1,305 @@ +use { + crate::{msm::curve::CurveParams, noir_to_r1cs::NoirToR1CSCompiler}, + ark_ff::{AdditiveGroup, BigInteger, Field, 
PrimeField}, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// Reduce the value to given modulus +pub fn reduce_mod( + r1cs_compiler: &mut NoirToR1CSCompiler, + value: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + // Reduce mod algorithm : + // v = k * m + result, where 0 <= result < m + // k = floor(v / m) (integer division) + // result = v - k * m + + // Computing k = floor(v / m) + // ----------------------------------------------------------- + // computing m (constant witness for use in constraints) + let m = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(m, modulus), + )); + // computing k via integer division + let k = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(k, value, modulus)); + + // Computing result = v - k * m + // ----------------------------------------------------------- + // computing k * m + let k_mul_m = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Product(k_mul_m, k, m)); + // constraint: k * m = k_mul_m + r1cs_compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, k)], &[(FieldElement::ONE, m)], &[( + FieldElement::ONE, + k_mul_m, + )]); + // computing result = v - k * m + let result = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(result, vec![ + SumTerm(Some(FieldElement::ONE), value), + SumTerm(Some(-FieldElement::ONE), k_mul_m), + ])); + // constraint: 1 * (k_mul_m + result) = value + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[(FieldElement::ONE, k_mul_m), (FieldElement::ONE, result)], + &[(FieldElement::ONE, value)], + ); + // range check to prove 0 <= result < m + let modulus_bits = modulus.into_bigint().num_bits(); + range_checks + .entry(modulus_bits) + 
.or_insert_with(Vec::new) + .push(result); + + result +} + +/// a * b mod m +pub fn compute_field_mul( + r1cs_compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_mul_b = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Product(a_mul_b, a, b)); + // constraint: a * b = a_mul_b + r1cs_compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( + FieldElement::ONE, + a_mul_b, + )]); + reduce_mod(r1cs_compiler, a_mul_b, modulus, range_checks) +} + +/// (a - b) mod m +pub fn compute_field_sub( + r1cs_compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_sub_b = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(a_sub_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(-FieldElement::ONE), b), + ])); + // constraint: 1 * (a - b) = a_sub_b + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ONE, a_sub_b)], + ); + reduce_mod(r1cs_compiler, a_sub_b, modulus, range_checks) +} + +/// a^(-1) mod m +/// +/// CRITICAL: secp256r1's field_modulus_p (~2^256) > BN254 scalar field +/// (~2^254). Coordinates and the modulus do not fit in a single +/// FieldElement. Either use multi-limb representation or target a +/// curve that fits (e.g. Grumpkin, BabyJubJub). 
+pub fn compute_field_inv( + r1cs_compiler: &mut NoirToR1CSCompiler, + a: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + // Computing a^(-1) mod m + // ----------------------------------------------------------- + // computing a_inv (the F_m inverse of a) via Fermat's little theorem + let a_inv = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::ModularInverse(a_inv, a, modulus)); + + // Verifying a * a_inv mod m = 1 + // ----------------------------------------------------------- + // computing a * a_inv + let product_raw = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Product(product_raw, a, a_inv)); + // constraint: a * a_inv = product_raw + r1cs_compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, a_inv)], &[ + (FieldElement::ONE, product_raw), + ]); + // reducing a * a_inv mod m — should give 1 if a_inv is correct + let reduced = reduce_mod(r1cs_compiler, product_raw, modulus, range_checks); + + // constraint: reduced = 1 + // (reduced - 1) * 1 = 0 + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[ + (FieldElement::ONE, reduced), + (-FieldElement::ONE, r1cs_compiler.witness_one()), + ], + &[(FieldElement::ZERO, r1cs_compiler.witness_one())], + ); + + // range check: a_inv in [0, 2^bits(m)) + let mod_bits = modulus.into_bigint().num_bits(); + range_checks + .entry(mod_bits) + .or_insert_with(Vec::new) + .push(a_inv); + + a_inv +} + +/// Point doubling on y^2 = x^3 + ax + b (mod p) using affine lambda formula. +/// +/// Given P = (x1, y1), computes 2P = (x3, y3): +/// lambda = (3 * x1^2 + a) / (2 * y1) (mod p) +/// x3 = lambda^2 - 2 * x1 (mod p) +/// y3 = lambda * (x1 - x3) - y1 (mod p) +/// +/// Edge case — y1 = 0 (point of order 2): +/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. +/// The result should be the point at infinity (identity element). 
+/// This function does NOT handle that case — the constraint system will +/// be unsatisfiable if y1 = 0 (compute_field_inv will fail to verify +/// 0 * inv = 1 mod p). The caller must check y1 = 0 using +/// compute_is_zero and conditionally select the point-at-infinity +/// result before calling this function. +pub fn point_double( + r1cs_compiler: &mut NoirToR1CSCompiler, + x1: usize, + y1: usize, + curve_params: &CurveParams, + range_checks: &mut BTreeMap>, +) -> (usize, usize) { + let p = curve_params.field_modulus_p; + + // Computing numerator = 3 * x1^2 + a (mod p) + // ----------------------------------------------------------- + // computing x1^2 mod p + let x1_sq = compute_field_mul(r1cs_compiler, x1, x1, p, range_checks); + // computing 3 * x1_sq + a + let a_witness = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(a_witness, curve_params.curve_a), + )); + let num_raw = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(num_raw, vec![ + SumTerm(Some(FieldElement::from(3u64)), x1_sq), + SumTerm(Some(FieldElement::ONE), a_witness), + ])); + // constraint: 1 * (3 * x1_sq + a) = num_raw + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[ + (FieldElement::from(3u64), x1_sq), + (FieldElement::ONE, a_witness), + ], + &[(FieldElement::ONE, num_raw)], + ); + let numerator = reduce_mod(r1cs_compiler, num_raw, p, range_checks); + + // Computing denominator = 2 * y1 (mod p) + // ----------------------------------------------------------- + // computing 2 * y1 + let denom_raw = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(denom_raw, vec![SumTerm( + Some(FieldElement::from(2u64)), + y1, + )])); + // constraint: 1 * (2 * y1) = denom_raw + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[(FieldElement::from(2u64), y1)], 
+ &[(FieldElement::ONE, denom_raw)], + ); + let denominator = reduce_mod(r1cs_compiler, denom_raw, p, range_checks); + + // Computing lambda = numerator * denominator^(-1) (mod p) + // ----------------------------------------------------------- + // computing denominator^(-1) mod p + let denom_inv = compute_field_inv(r1cs_compiler, denominator, p, range_checks); + // computing lambda = numerator * denom_inv mod p + let lambda = compute_field_mul(r1cs_compiler, numerator, denom_inv, p, range_checks); + + // Computing x3 = lambda^2 - 2 * x1 (mod p) + // ----------------------------------------------------------- + // computing lambda^2 mod p + let lambda_sq = compute_field_mul(r1cs_compiler, lambda, lambda, p, range_checks); + // computing lambda^2 - 2 * x1 + let x3_raw = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(x3_raw, vec![ + SumTerm(Some(FieldElement::ONE), lambda_sq), + SumTerm(Some(-FieldElement::from(2u64)), x1), + ])); + // constraint: 1 * (lambda^2 - 2 * x1) = x3_raw + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[ + (FieldElement::ONE, lambda_sq), + (-FieldElement::from(2u64), x1), + ], + &[(FieldElement::ONE, x3_raw)], + ); + let x3 = reduce_mod(r1cs_compiler, x3_raw, p, range_checks); + + // Computing y3 = lambda * (x1 - x3) - y1 (mod p) + // ----------------------------------------------------------- + // computing x1 - x3 mod p + let x1_minus_x3 = compute_field_sub(r1cs_compiler, x1, x3, p, range_checks); + // computing lambda * (x1 - x3) mod p + let lambda_dx = compute_field_mul(r1cs_compiler, lambda, x1_minus_x3, p, range_checks); + // computing lambda * (x1 - x3) - y1 mod p + let y3 = compute_field_sub(r1cs_compiler, lambda_dx, y1, p, range_checks); + + (x3, y3) +} + +/// checks if value is zero or not +pub fn compute_is_zero(r1cs_compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { + // calculating v^(-1) + let value_inv = 
r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Inverse(value_inv, value)); + // calculating v * v^(-1) + let value_mul_value_inv = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Product( + value_mul_value_inv, + value, + value_inv, + )); + // calculate is_zero = 1 - (v * v^(-1)) + let is_zero = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(provekit_common::witness::WitnessBuilder::Sum( + is_zero, + vec![ + provekit_common::witness::SumTerm(Some(FieldElement::ONE), r1cs_compiler.witness_one()), + provekit_common::witness::SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), + ], + )); + // constraint: v × v^(-1) = 1 - is_zero + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, value_inv)], + &[ + (FieldElement::ONE, r1cs_compiler.witness_one()), + (-FieldElement::ONE, is_zero), + ], + ); + // constraint: v × is_zero = 0 + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, is_zero)], + &[(FieldElement::ZERO, r1cs_compiler.witness_one())], + ); + is_zero +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs new file mode 100644 index 000000000..3844d1466 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -0,0 +1,2 @@ +pub mod curve; +pub mod ec_ops; diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs b/provekit/r1cs-compiler/src/noir_to_r1cs.rs index 189eb4693..2d4245636 100644 --- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs +++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs @@ -16,6 +16,7 @@ use { Circuit, Opcode, }, native_types::{Expression, Witness as NoirWitness}, + BlackBoxFunc, }, anyhow::{bail, Result}, ark_ff::PrimeField, @@ -627,6 +628,11 @@ impl NoirToR1CSCompiler { output_witnesses, )); } + BlackBoxFuncCall::MultiScalarMul { + points, + scalars, + outputs, + } => {} _ => { unimplemented!("Other black box function: {:?}", 
black_box_func_call); } From e64cf9ed2cdd8608656874b6517c430803fda8da Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 28 Feb 2026 06:36:12 +0530 Subject: [PATCH 02/19] feat : added wide field ops for ec operations and added trait generics --- .../src/witness/scheduling/dependency.rs | 30 + .../common/src/witness/scheduling/remapper.rs | 54 ++ .../common/src/witness/witness_builder.rs | 59 ++ provekit/prover/src/witness/bigint_mod.rs | 437 +++++++++++++- .../prover/src/witness/witness_builder.rs | 194 ++++++ provekit/r1cs-compiler/src/msm/curve.rs | 165 +++-- provekit/r1cs-compiler/src/msm/ec_ops.rs | 165 ++--- provekit/r1cs-compiler/src/msm/ec_points.rs | 101 ++++ provekit/r1cs-compiler/src/msm/mod.rs | 141 +++++ provekit/r1cs-compiler/src/msm/wide_ops.rs | 563 ++++++++++++++++++ 10 files changed, 1714 insertions(+), 195 deletions(-) create mode 100644 provekit/r1cs-compiler/src/msm/ec_points.rs create mode 100644 provekit/r1cs-compiler/src/msm/wide_ops.rs diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index f2d480377..956f79b56 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -154,6 +154,28 @@ impl DependencyInfo { } v } + WitnessBuilder::MulModHint { + a_lo, + a_hi, + b_lo, + b_hi, + .. + } => vec![*a_lo, *a_hi, *b_lo, *b_hi], + WitnessBuilder::WideModularInverse { a_lo, a_hi, .. } => vec![*a_lo, *a_hi], + WitnessBuilder::WideAddQuotient { + a_lo, + a_hi, + b_lo, + b_hi, + .. + } => vec![*a_lo, *a_hi, *b_lo, *b_hi], + WitnessBuilder::WideSubBorrow { + a_lo, + a_hi, + b_lo, + b_hi, + .. + } => vec![*a_lo, *a_hi, *b_lo, *b_hi], WitnessBuilder::BytePartition { x, .. } => vec![*x], WitnessBuilder::U32AdditionMulti(_, _, inputs) => inputs @@ -286,6 +308,14 @@ impl DependencyInfo { let n = 1usize << *num_bits; (*start..*start + n).collect() } + WitnessBuilder::MulModHint { output_start, .. 
} => { + (*output_start..*output_start + 20).collect() + } + WitnessBuilder::WideModularInverse { output_start, .. } => { + (*output_start..*output_start + 2).collect() + } + WitnessBuilder::WideAddQuotient { output, .. } => vec![*output], + WitnessBuilder::WideSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) => { vec![*result_idx, *carry_idx] } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 76c30bd23..47490a6ce 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -221,6 +221,60 @@ impl WitnessIndexRemapper { .collect(), ) } + WitnessBuilder::MulModHint { + output_start, + a_lo, + a_hi, + b_lo, + b_hi, + modulus, + } => WitnessBuilder::MulModHint { + output_start: self.remap(*output_start), + a_lo: self.remap(*a_lo), + a_hi: self.remap(*a_hi), + b_lo: self.remap(*b_lo), + b_hi: self.remap(*b_hi), + modulus: *modulus, + }, + WitnessBuilder::WideModularInverse { + output_start, + a_lo, + a_hi, + modulus, + } => WitnessBuilder::WideModularInverse { + output_start: self.remap(*output_start), + a_lo: self.remap(*a_lo), + a_hi: self.remap(*a_hi), + modulus: *modulus, + }, + WitnessBuilder::WideAddQuotient { + output, + a_lo, + a_hi, + b_lo, + b_hi, + modulus, + } => WitnessBuilder::WideAddQuotient { + output: self.remap(*output), + a_lo: self.remap(*a_lo), + a_hi: self.remap(*a_hi), + b_lo: self.remap(*b_lo), + b_hi: self.remap(*b_hi), + modulus: *modulus, + }, + WitnessBuilder::WideSubBorrow { + output, + a_lo, + a_hi, + b_lo, + b_hi, + } => WitnessBuilder::WideSubBorrow { + output: self.remap(*output), + a_lo: self.remap(*a_lo), + a_hi: self.remap(*a_hi), + b_lo: self.remap(*b_lo), + b_hi: self.remap(*b_hi), + }, WitnessBuilder::BytePartition { lo, hi, x, k } => WitnessBuilder::BytePartition { lo: self.remap(*lo), hi: self.remap(*hi), diff --git 
a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 5719ee99e..6d11d17cf 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -201,6 +201,63 @@ pub enum WitnessBuilder { /// Inverse of combined lookup table entry denominator (constant operands). /// Computes: 1 / (sz - lhs - rs*rhs - rs²*and_out - rs³*xor_out) CombinedTableEntryInverse(CombinedTableEntryInverseData), + /// Prover hint for multi-limb modular multiplication: (a * b) mod p. + /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// and a constant 256-bit modulus p, computes quotient q, remainder r, + /// their 86-bit decompositions, and carry witnesses. + /// + /// Outputs 20 witnesses starting at output_start: + /// [0..2) q_lo, q_hi (quotient in 128-bit limbs) + /// [2..4) r_lo, r_hi (remainder in 128-bit limbs) + /// [4..7) a_86_0, a_86_1, a_86_2 (a in 86-bit limbs) + /// [7..10) b_86_0, b_86_1, b_86_2 (b in 86-bit limbs) + /// [10..13) q_86_0, q_86_1, q_86_2 (q in 86-bit limbs) + /// [13..16) r_86_0, r_86_1, r_86_2 (r in 86-bit limbs) + /// [16..20) c0, c1, c2, c3 (carry witnesses, unsigned-offset) + MulModHint { + output_start: usize, + a_lo: usize, + a_hi: usize, + b_lo: usize, + b_hi: usize, + modulus: [u64; 4], + }, + /// Prover hint for wide modular inverse: a^{-1} mod p. + /// Given input a = (a_lo, a_hi) as 128-bit limbs and constant modulus p, + /// computes the inverse via Fermat's little theorem (a^{p-2} mod p). + /// + /// Outputs 2 witnesses at output_start: inv_lo, inv_hi (128-bit limbs). + WideModularInverse { + output_start: usize, + a_lo: usize, + a_hi: usize, + modulus: [u64; 4], + }, + /// Prover hint for wide addition quotient: q = floor((a + b) / p). + /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// and a constant 256-bit modulus p, computes q ∈ {0, 1}. + /// + /// Outputs 1 witness at output: q. 
+ WideAddQuotient { + output: usize, + a_lo: usize, + a_hi: usize, + b_lo: usize, + b_hi: usize, + modulus: [u64; 4], + }, + /// Prover hint for wide subtraction borrow: q = (a < b) ? 1 : 0. + /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// computes q ∈ {0, 1} indicating whether a borrow (adding p) is needed. + /// + /// Outputs 1 witness at output: q. + WideSubBorrow { + output: usize, + a_lo: usize, + a_hi: usize, + b_lo: usize, + b_hi: usize, + }, /// Decomposes a packed value into chunks of specified bit-widths. /// Given packed value and chunk_bits = [b0, b1, ..., bn]: /// packed = c0 + c1 * 2^b0 + c2 * 2^(b0+b1) + ... @@ -272,6 +329,8 @@ impl WitnessBuilder { WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, + WitnessBuilder::MulModHint { .. } => 20, + WitnessBuilder::WideModularInverse { .. } => 2, _ => 1, } diff --git a/provekit/prover/src/witness/bigint_mod.rs b/provekit/prover/src/witness/bigint_mod.rs index 3252aeea2..a41f47ff3 100644 --- a/provekit/prover/src/witness/bigint_mod.rs +++ b/provekit/prover/src/witness/bigint_mod.rs @@ -6,7 +6,7 @@ /// Schoolbook multiplication: 4×4 limbs → 8 limbs (256-bit × 256-bit → /// 512-bit). -fn widening_mul(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] { +pub fn widening_mul(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] { let mut result = [0u64; 8]; for i in 0..4 { let mut carry = 0u128; @@ -90,7 +90,7 @@ fn shift_left_one(a: &mut [u64; 4]) -> u64 { } /// Compare two 4-limb numbers. 
-fn cmp_4limb(a: &[u64; 4], b: &[u64; 4]) -> std::cmp::Ordering { +pub fn cmp_4limb(a: &[u64; 4], b: &[u64; 4]) -> std::cmp::Ordering { for i in (0..4).rev() { match a[i].cmp(&b[i]) { std::cmp::Ordering::Equal => continue, @@ -203,6 +203,235 @@ pub fn sub_u64(a: &[u64; 4], small: u64) -> [u64; 4] { result } +/// Add two 4-limb (256-bit) numbers, returning a 5-limb result with carry. +pub fn add_4limb(a: &[u64; 4], b: &[u64; 4]) -> [u64; 5] { + let mut result = [0u64; 5]; + let mut carry = 0u64; + for i in 0..4 { + let (s1, c1) = a[i].overflowing_add(b[i]); + let (s2, c2) = s1.overflowing_add(carry); + result[i] = s2; + carry = (c1 as u64) + (c2 as u64); + } + result[4] = carry; + result +} + +/// Offset added to signed carries to make them non-negative for range checking. +/// Carries are bounded by |c| < 2^88, so adding 2^88 ensures c_unsigned >= 0. +pub const CARRY_OFFSET: u128 = 1u128 << 88; + +/// Integer division of a 512-bit dividend by a 256-bit divisor. +/// Returns (quotient, remainder) where both fit in 256 bits. +/// Panics if the quotient would exceed 256 bits. +pub fn divmod_wide(dividend: &[u64; 8], divisor: &[u64; 4]) -> ([u64; 4], [u64; 4]) { + let mut highest_bit = 0; + for i in (0..8).rev() { + if dividend[i] != 0 { + highest_bit = i * 64 + (64 - dividend[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return ([0u64; 4], [0u64; 4]); + } + + let mut quotient = [0u64; 4]; + let mut remainder = [0u64; 4]; + + for bit_pos in (0..highest_bit).rev() { + let shift_carry = shift_left_one(&mut remainder); + + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + remainder[0] |= (dividend[limb_idx] >> bit_idx) & 1; + + // If shift_carry is set, the effective remainder is 2^256 + remainder, + // which is always > any 256-bit divisor, so we must subtract. 
+ if shift_carry != 0 || cmp_4limb(&remainder, divisor) != std::cmp::Ordering::Less { + // Subtract divisor with inline borrow tracking (handles the case + // where remainder < divisor but shift_carry provides the extra bit). + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = remainder[i].overflowing_sub(divisor[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + remainder[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } + // When shift_carry was set, the borrow absorbs it (they cancel out). + debug_assert_eq!( + borrow, shift_carry, + "unexpected borrow in divmod_wide at bit_pos {}", + bit_pos + ); + + assert!(bit_pos < 256, "quotient exceeds 256 bits"); + quotient[bit_pos / 64] |= 1u64 << (bit_pos % 64); + } + } + + (quotient, remainder) +} + +/// Split a 256-bit value into two 128-bit halves: (lo, hi). +pub fn decompose_128(val: &[u64; 4]) -> (u128, u128) { + let lo = val[0] as u128 | ((val[1] as u128) << 64); + let hi = val[2] as u128 | ((val[3] as u128) << 64); + (lo, hi) +} + +/// Split a 256-bit value into three 86-bit limbs: (l0, l1, l2). +/// l0 = bits [0..86), l1 = bits [86..172), l2 = bits [172..256). +pub fn decompose_86(val: &[u64; 4]) -> (u128, u128, u128) { + let mask_86: u128 = (1u128 << 86) - 1; + let lo128 = val[0] as u128 | ((val[1] as u128) << 64); + let hi128 = val[2] as u128 | ((val[3] as u128) << 64); + + let l0 = lo128 & mask_86; + // l1 spans bits [86..172): 42 bits from lo128, 44 bits from hi128 + let l1 = ((lo128 >> 86) | (hi128 << 42)) & mask_86; + // l2 = bits [172..256): 84 bits from hi128 + let l2 = hi128 >> 44; + + (l0, l1, l2) +} + +/// Compute carry values c0..c3 from the 86-bit schoolbook column equations +/// for the identity a*b = p*q + r (base W = 2^86). 
+/// +/// Column equations: +/// col0: a0*b0 - p0*q0 - r0 = c0*W +/// col1: a0*b1 + a1*b0 - p0*q1 - p1*q0 - r1 + c0 = c1*W +/// col2: a0*b2 + a1*b1 + a2*b0 - p0*q2 - p1*q1 - p2*q0 - r2 + c1 = c2*W +/// col3: a1*b2 + a2*b1 - p1*q2 - p2*q1 + c2 = c3*W +/// col4: a2*b2 - p2*q2 + c3 = 0 +pub fn compute_carries_86( + a: [u128; 3], + b: [u128; 3], + p: [u128; 3], + q: [u128; 3], + r: [u128; 3], +) -> [i128; 4] { + // Helper: convert u128 to [u64; 4] + fn to4(v: u128) -> [u64; 4] { + [v as u64, (v >> 64) as u64, 0, 0] + } + + // Helper: multiply two 86-bit values → [u64; 4] (result < 2^172) + fn mul86(x: u128, y: u128) -> [u64; 4] { + let w = widening_mul(&to4(x), &to4(y)); + [w[0], w[1], w[2], w[3]] + } + + // Helper: add two [u64; 4] values + fn add4(a: [u64; 4], b: [u64; 4]) -> [u64; 4] { + let mut r = [0u64; 4]; + let mut carry = 0u128; + for i in 0..4 { + let s = a[i] as u128 + b[i] as u128 + carry; + r[i] = s as u64; + carry = s >> 64; + } + r + } + + // Helper: subtract two [u64; 4] values (assumes a >= b) + fn sub4(a: [u64; 4], b: [u64; 4]) -> [u64; 4] { + let mut r = [0u64; 4]; + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + r[i] = d2; + borrow = b1 as u64 + b2 as u64; + } + r + } + + // Helper: right-shift [u64; 4] by 86 bits (= 64 + 22) + fn shr86(a: [u64; 4]) -> [u64; 4] { + let s = [a[1], a[2], a[3], 0u64]; + [ + (s[0] >> 22) | (s[1] << 42), + (s[1] >> 22) | (s[2] << 42), + s[2] >> 22, + 0, + ] + } + + // Positive column sums (a_i * b_j terms) + let pos = [ + mul86(a[0], b[0]), + add4(mul86(a[0], b[1]), mul86(a[1], b[0])), + add4( + add4(mul86(a[0], b[2]), mul86(a[1], b[1])), + mul86(a[2], b[0]), + ), + add4(mul86(a[1], b[2]), mul86(a[2], b[1])), + mul86(a[2], b[2]), + ]; + + // Negative column sums (p_i * q_j + r_i terms) + let neg = [ + add4(mul86(p[0], q[0]), to4(r[0])), + add4(add4(mul86(p[0], q[1]), mul86(p[1], q[0])), to4(r[1])), + add4( + add4( + add4(mul86(p[0], 
q[2]), mul86(p[1], q[1])), + mul86(p[2], q[0]), + ), + to4(r[2]), + ), + add4(mul86(p[1], q[2]), mul86(p[2], q[1])), + mul86(p[2], q[2]), + ]; + + let mut carries = [0i128; 4]; + let mut carry_pos = [0u64; 4]; + let mut carry_neg = [0u64; 4]; + + for col in 0..4 { + let total_pos = add4(pos[col], carry_pos); + let total_neg = add4(neg[col], carry_neg); + + let (is_neg, diff) = if cmp_4limb(&total_pos, &total_neg) != std::cmp::Ordering::Less { + (false, sub4(total_pos, total_neg)) + } else { + (true, sub4(total_neg, total_pos)) + }; + + // Lower 86 bits must be zero (divisibility check) + let mask_86 = (1u128 << 86) - 1; + let low86 = (diff[0] as u128 | ((diff[1] as u128) << 64)) & mask_86; + debug_assert_eq!(low86, 0, "column {} not divisible by W=2^86", col); + + let carry_mag = shr86(diff); + debug_assert_eq!(carry_mag[2], 0, "carry overflow in column {}", col); + debug_assert_eq!(carry_mag[3], 0, "carry overflow in column {}", col); + + let carry_val = carry_mag[0] as i128 | ((carry_mag[1] as i128) << 64); + carries[col] = if is_neg { -carry_val } else { carry_val }; + + if is_neg { + carry_pos = [0; 4]; + carry_neg = carry_mag; + } else { + carry_pos = carry_mag; + carry_neg = [0; 4]; + } + } + + // Verify column 4 balances + let final_pos = add4(pos[4], carry_pos); + let final_neg = add4(neg[4], carry_neg); + debug_assert_eq!( + final_pos, final_neg, + "column 4 should balance: a2*b2 - p2*q2 + c3 = 0" + ); + + carries +} + #[cfg(test)] mod tests { use super::*; @@ -384,8 +613,8 @@ mod tests { assert_eq!(q, [6148914691236517205, 0, 0, 0]); assert_eq!(r, [1, 0, 0, 0]); // Verify: q * 3 + 1 = 2^64 - assert_eq!(6148914691236517205u64 * 3 + 1, 0u64); // wraps to 0 in u64 = - // 2^64 + assert_eq!(6148914691236517205u64.wrapping_mul(3).wrapping_add(1), 0u64); + // wraps to 0 in u64 = 2^64 } #[test] @@ -419,4 +648,204 @@ mod tests { assert_eq!(sum, expected, "dividend={dividend:?} divisor={divisor:?}"); } } + + #[test] + fn test_divmod_wide_small() { + // 21 / 7 = 3 
remainder 0 (512-bit dividend) + let dividend = [21, 0, 0, 0, 0, 0, 0, 0]; + let divisor = [7, 0, 0, 0]; + let (q, r) = divmod_wide(&dividend, &divisor); + assert_eq!(q, [3, 0, 0, 0]); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_large() { + // Compute a * b where a, b are 256-bit, then divide by a + // Should give quotient = b, remainder = 0 + let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; // secp256r1 p + let b = [42, 0, 0, 0]; + let product = widening_mul(&a, &b); + let (q, r) = divmod_wide(&product, &a); + assert_eq!(q, b); + assert_eq!(r, [0, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_with_remainder() { + // (a * b + 5) / a = b remainder 5 + let a = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let b = [100, 0, 0, 0]; + let mut product = widening_mul(&a, &b); + // Add 5 + let (sum, overflow) = product[0].overflowing_add(5); + product[0] = sum; + if overflow { + for i in 1..8 { + let (s, o) = product[i].overflowing_add(1); + product[i] = s; + if !o { + break; + } + } + } + let (q, r) = divmod_wide(&product, &a); + assert_eq!(q, b); + assert_eq!(r, [5, 0, 0, 0]); + } + + #[test] + fn test_divmod_wide_consistency() { + // Verify: q * divisor + r = dividend + let a = [ + 0x123456789abcdef0, + 0xfedcba9876543210, + 0x1111111111111111, + 0x2222222222222222, + ]; + let b = [0xaabbccdd, 0x11223344, 0x55667788, 0x99001122]; + let product = widening_mul(&a, &b); + let divisor = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let (q, r) = divmod_wide(&product, &divisor); + + // Verify: q * divisor + r = product + let qd = widening_mul(&q, &divisor); + let mut sum = qd; + let mut carry = 0u128; + for i in 0..4 { + let s = (sum[i] as u128) + (r[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = (sum[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + assert_eq!(sum, product); + } + + #[test] + fn test_decompose_128_roundtrip() { + let val = [ + 
0x123456789abcdef0, + 0xfedcba9876543210, + 0x1111111111111111, + 0x2222222222222222, + ]; + let (lo, hi) = decompose_128(&val); + // Roundtrip + assert_eq!(lo as u64, val[0]); + assert_eq!((lo >> 64) as u64, val[1]); + assert_eq!(hi as u64, val[2]); + assert_eq!((hi >> 64) as u64, val[3]); + } + + #[test] + fn test_decompose_86_roundtrip() { + let val = [ + 0x123456789abcdef0, + 0xfedcba9876543210, + 0x1111111111111111, + 0x2222222222222222, + ]; + let (l0, l1, l2) = decompose_86(&val); + + // Each limb should be < 2^86 + assert!(l0 < (1u128 << 86)); + assert!(l1 < (1u128 << 86)); + // l2 has at most 84 bits (256 - 172) + assert!(l2 < (1u128 << 84)); + + // Roundtrip: l0 + l1 * 2^86 + l2 * 2^172 should equal val + // Build from limbs back to [u64; 4] + let mut reconstructed = [0u128; 2]; // lo128, hi128 + reconstructed[0] = l0; + // l1 starts at bit 86 + reconstructed[0] |= (l1 & ((1u128 << 42) - 1)) << 86; // lower 42 bits of l1 into lo128 + reconstructed[1] = l1 >> 42; // upper 44 bits of l1 + // l2 starts at bit 172 = 128 + 44 + reconstructed[1] |= l2 << 44; + + assert_eq!(reconstructed[0] as u64, val[0]); + assert_eq!((reconstructed[0] >> 64) as u64, val[1]); + assert_eq!(reconstructed[1] as u64, val[2]); + assert_eq!((reconstructed[1] >> 64) as u64, val[3]); + } + + #[test] + fn test_decompose_86_secp256r1_p() { + // secp256r1 field modulus + let p = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let (l0, l1, l2) = decompose_86(&p); + assert!(l0 < (1u128 << 86)); + assert!(l1 < (1u128 << 86)); + assert!(l2 < (1u128 << 84)); + } + + #[test] + fn test_compute_carries_86_simple() { + // Test with small values: a=3, b=5, p=7 + // a*b = 15, 15 / 7 = 2 remainder 1 + // So q=2, r=1 + let a_val = [3u64, 0, 0, 0]; + let b_val = [5, 0, 0, 0]; + let p_val = [7, 0, 0, 0]; + let product = widening_mul(&a_val, &b_val); + let (q_val, r_val) = divmod_wide(&product, &p_val); + assert_eq!(q_val, [2, 0, 0, 0]); + assert_eq!(r_val, [1, 0, 0, 0]); + + let (a0, a1, 
a2) = decompose_86(&a_val); + let (b0, b1, b2) = decompose_86(&b_val); + let (p0, p1, p2) = decompose_86(&p_val); + let (q0, q1, q2) = decompose_86(&q_val); + let (r0, r1, r2) = decompose_86(&r_val); + + let carries = compute_carries_86([a0, a1, a2], [b0, b1, b2], [p0, p1, p2], [q0, q1, q2], [ + r0, r1, r2, + ]); + // For small values, all carries should be 0 + assert_eq!(carries, [0, 0, 0, 0]); + } + + #[test] + fn test_compute_carries_86_secp256r1() { + // Test with secp256r1-sized values + let p = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; + let a_val = [0x123456789abcdef0, 0xfedcba9876543210, 0x0, 0x0]; // < p + let b_val = [0xaabbccddeeff0011, 0x1122334455667788, 0x0, 0x0]; // < p + + let product = widening_mul(&a_val, &b_val); + let (q_val, r_val) = divmod_wide(&product, &p); + + // Verify a*b = p*q + r + let pq = widening_mul(&p, &q_val); + let mut sum = pq; + let mut carry = 0u128; + for i in 0..4 { + let s = sum[i] as u128 + r_val[i] as u128 + carry; + sum[i] = s as u64; + carry = s >> 64; + } + for i in 4..8 { + let s = sum[i] as u128 + carry; + sum[i] = s as u64; + carry = s >> 64; + } + assert_eq!(sum, product); + + // Compute 86-bit decompositions + let (a0, a1, a2) = decompose_86(&a_val); + let (b0, b1, b2) = decompose_86(&b_val); + let (p0, p1, p2) = decompose_86(&p); + let (q0, q1, q2) = decompose_86(&q_val); + let (r0, r1, r2) = decompose_86(&r_val); + + // This should not panic + let _carries = + compute_carries_86([a0, a1, a2], [b0, b1, b2], [p0, p1, p2], [q0, q1, q2], [ + r0, r1, r2, + ]); + } } diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index 115637da6..ae49cfcd5 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -337,6 +337,200 @@ impl WitnessBuilderSolver for WitnessBuilder { lh_val.into_bigint() ^ rh_val.into_bigint(), )); } + WitnessBuilder::MulModHint { + output_start, + a_lo, + a_hi, + b_lo, + 
b_hi, + modulus, + } => { + use crate::witness::bigint_mod::{ + compute_carries_86, decompose_128, decompose_86, divmod_wide, widening_mul, + CARRY_OFFSET, + }; + + // Read inputs: a and b as 128-bit limb pairs + let a_lo_fe = witness[*a_lo].unwrap(); + let a_hi_fe = witness[*a_hi].unwrap(); + let b_lo_fe = witness[*b_lo].unwrap(); + let b_hi_fe = witness[*b_hi].unwrap(); + + // Reconstruct a, b as [u64; 4] + let a_lo_limbs = a_lo_fe.into_bigint().0; + let a_hi_limbs = a_hi_fe.into_bigint().0; + let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + + let b_lo_limbs = b_lo_fe.into_bigint().0; + let b_hi_limbs = b_hi_fe.into_bigint().0; + let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + + // Compute product and divmod + let product = widening_mul(&a_val, &b_val); + let (q_val, r_val) = divmod_wide(&product, modulus); + + // Decompose into 128-bit limbs + let (q_lo, q_hi) = decompose_128(&q_val); + let (r_lo, r_hi) = decompose_128(&r_val); + + // Decompose into 86-bit limbs + let (a86_0, a86_1, a86_2) = decompose_86(&a_val); + let (b86_0, b86_1, b86_2) = decompose_86(&b_val); + let (q86_0, q86_1, q86_2) = decompose_86(&q_val); + let (r86_0, r86_1, r86_2) = decompose_86(&r_val); + + // Compute carries + let carries = compute_carries_86( + [a86_0, a86_1, a86_2], + [b86_0, b86_1, b86_2], + { + let (p0, p1, p2) = decompose_86(modulus); + [p0, p1, p2] + }, + [q86_0, q86_1, q86_2], + [r86_0, r86_1, r86_2], + ); + + // Helper: convert u128 to FieldElement + let u128_to_fe = |val: u128| -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt([ + val as u64, + (val >> 64) as u64, + 0, + 0, + ])) + .unwrap() + }; + + // Write outputs: [0..2) q_lo, q_hi + witness[*output_start] = Some(u128_to_fe(q_lo)); + witness[*output_start + 1] = Some(u128_to_fe(q_hi)); + // [2..4) r_lo, r_hi + witness[*output_start + 2] = Some(u128_to_fe(r_lo)); + witness[*output_start + 3] = Some(u128_to_fe(r_hi)); + // [4..7) a_86 limbs + 
witness[*output_start + 4] = Some(u128_to_fe(a86_0)); + witness[*output_start + 5] = Some(u128_to_fe(a86_1)); + witness[*output_start + 6] = Some(u128_to_fe(a86_2)); + // [7..10) b_86 limbs + witness[*output_start + 7] = Some(u128_to_fe(b86_0)); + witness[*output_start + 8] = Some(u128_to_fe(b86_1)); + witness[*output_start + 9] = Some(u128_to_fe(b86_2)); + // [10..13) q_86 limbs + witness[*output_start + 10] = Some(u128_to_fe(q86_0)); + witness[*output_start + 11] = Some(u128_to_fe(q86_1)); + witness[*output_start + 12] = Some(u128_to_fe(q86_2)); + // [13..16) r_86 limbs + witness[*output_start + 13] = Some(u128_to_fe(r86_0)); + witness[*output_start + 14] = Some(u128_to_fe(r86_1)); + witness[*output_start + 15] = Some(u128_to_fe(r86_2)); + // [16..20) carries (unsigned-offset) + for i in 0..4 { + let c_unsigned = (carries[i] + CARRY_OFFSET as i128) as u128; + witness[*output_start + 16 + i] = Some(u128_to_fe(c_unsigned)); + } + } + WitnessBuilder::WideModularInverse { + output_start, + a_lo, + a_hi, + modulus, + } => { + use crate::witness::bigint_mod::{decompose_128, mod_pow, sub_u64}; + + // Read input a as 128-bit limb pair + let a_lo_fe = witness[*a_lo].unwrap(); + let a_hi_fe = witness[*a_hi].unwrap(); + + let a_lo_limbs = a_lo_fe.into_bigint().0; + let a_hi_limbs = a_hi_fe.into_bigint().0; + let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + + // Compute inverse: a^{p-2} mod p (Fermat's little theorem) + let exp = sub_u64(modulus, 2); + let inv = mod_pow(&a_val, &exp, modulus); + + // Decompose into 128-bit limbs + let (inv_lo, inv_hi) = decompose_128(&inv); + + let u128_to_fe = |val: u128| -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt([ + val as u64, + (val >> 64) as u64, + 0, + 0, + ])) + .unwrap() + }; + + witness[*output_start] = Some(u128_to_fe(inv_lo)); + witness[*output_start + 1] = Some(u128_to_fe(inv_hi)); + } + WitnessBuilder::WideAddQuotient { + output, + a_lo, + a_hi, + b_lo, + b_hi, + modulus, + } => { + 
use crate::witness::bigint_mod::{add_4limb, cmp_4limb}; + + let a_lo_fe = witness[*a_lo].unwrap(); + let a_hi_fe = witness[*a_hi].unwrap(); + let b_lo_fe = witness[*b_lo].unwrap(); + let b_hi_fe = witness[*b_hi].unwrap(); + + let a_lo_limbs = a_lo_fe.into_bigint().0; + let a_hi_limbs = a_hi_fe.into_bigint().0; + let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + + let b_lo_limbs = b_lo_fe.into_bigint().0; + let b_hi_limbs = b_hi_fe.into_bigint().0; + let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + + let sum = add_4limb(&a_val, &b_val); + // q = 1 if sum >= p, else 0 + let q = if sum[4] > 0 { + // sum > 2^256 > any 256-bit modulus + 1u64 + } else { + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + if cmp_4limb(&sum4, modulus) != std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + } + }; + + witness[*output] = Some(FieldElement::from(q)); + } + WitnessBuilder::WideSubBorrow { + output, + a_lo, + a_hi, + b_lo, + b_hi, + } => { + use crate::witness::bigint_mod::cmp_4limb; + + let a_lo_limbs = witness[*a_lo].unwrap().into_bigint().0; + let a_hi_limbs = witness[*a_hi].unwrap().into_bigint().0; + let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + + let b_lo_limbs = witness[*b_lo].unwrap().into_bigint().0; + let b_hi_limbs = witness[*b_hi].unwrap().into_bigint().0; + let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + + // q = 1 if a < b (need to add p to make result non-negative) + let q = if cmp_4limb(&a_val, &b_val) == std::cmp::Ordering::Less { + 1u64 + } else { + 0u64 + }; + + witness[*output] = Some(FieldElement::from(q)); + } WitnessBuilder::BytePartition { lo, hi, x, k } => { let x_val = witness[*x].unwrap().into_bigint().0[0]; debug_assert!(x_val < 256, "BytePartition input must be 8-bit"); diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index cbfd1bf25..d4d0d247b 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs 
+++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -1,71 +1,116 @@ -use provekit_common::FieldElement; +use { + crate::noir_to_r1cs::NoirToR1CSCompiler, + provekit_common::{ + witness::{ConstantTerm, WitnessBuilder}, + FieldElement, + }, +}; -// TODO : remove Option<> form both the params if comes in use -// otherwise we delete the params from struct pub struct CurveParams { - pub field_modulus_p: FieldElement, - pub curve_order_n: FieldElement, - pub curve_a: FieldElement, - pub curve_b: FieldElement, - pub generator: (FieldElement, FieldElement), - pub coordinate_bits: Option, + pub field_modulus_p: [u64; 4], + pub curve_order_n: [u64; 4], + pub curve_a: [u64; 4], + pub curve_b: [u64; 4], + pub generator: ([u64; 4], [u64; 4]), +} + +impl CurveParams { + pub fn p_lo_fe(&self) -> FieldElement { + decompose_128(self.field_modulus_p).0 + } + pub fn p_hi_fe(&self) -> FieldElement { + decompose_128(self.field_modulus_p).1 + } + pub fn p_86_limbs(&self) -> [FieldElement; 3] { + let mask_86: u128 = (1u128 << 86) - 1; + let lo128 = self.field_modulus_p[0] as u128 | ((self.field_modulus_p[1] as u128) << 64); + let hi128 = self.field_modulus_p[2] as u128 | ((self.field_modulus_p[3] as u128) << 64); + let l0 = lo128 & mask_86; + // l1 spans bits [86..172): 42 bits from lo128, 44 bits from hi128 + let l1 = ((lo128 >> 86) | (hi128 << 42)) & mask_86; + // l2 = bits [172..256): 84 bits from hi128 + let l2 = hi128 >> 44; + [ + FieldElement::from(l0), + FieldElement::from(l1), + FieldElement::from(l2), + ] + } + pub fn p_native_fe(&self) -> FieldElement { + curve_native_point_fe(&self.field_modulus_p) + } +} + +/// Splits a 256-bit value ([u64; 4]) into two 128-bit field elements (lo, hi). +fn decompose_128(val: [u64; 4]) -> (FieldElement, FieldElement) { + ( + FieldElement::from((val[0] as u128) | ((val[1] as u128) << 64)), + FieldElement::from((val[2] as u128) | ((val[3] as u128) << 64)), + ) +} + +/// Converts a 256-bit value ([u64; 4]) into a single native field element. 
+pub fn curve_native_point_fe(val: &[u64; 4]) -> FieldElement { + FieldElement::from_sign_and_limbs(true, val) +} + +#[derive(Clone, Copy, Debug)] +pub struct Limb2 { + pub lo: usize, + pub hi: usize, +} + +pub fn limb2_constant(r1cs_compiler: &mut NoirToR1CSCompiler, value: [u64; 4]) -> Limb2 { + let (lo, hi) = decompose_128(value); + let lo_idx = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(lo_idx, lo))); + let hi_idx = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(hi_idx, hi))); + Limb2 { + lo: lo_idx, + hi: hi_idx, + } } pub fn secp256r1_params() -> CurveParams { CurveParams { - field_modulus_p: FieldElement::from_sign_and_limbs( - true, - [ - 0xffffffffffffffff_u64, - 0xffffffff_u64, - 0x0_u64, - 0xffffffff00000001_u64, - ] - .as_slice(), - ), - curve_order_n: FieldElement::from_sign_and_limbs( - true, + field_modulus_p: [ + 0xffffffffffffffff_u64, + 0xffffffff_u64, + 0x0_u64, + 0xffffffff00000001_u64, + ], + curve_order_n: [ + 0xf3b9cac2fc632551_u64, + 0xbce6faada7179e84_u64, + 0xffffffffffffffff_u64, + 0xffffffff00000000_u64, + ], + curve_a: [ + 0xfffffffffffffffc_u64, + 0x00000000ffffffff_u64, + 0x0000000000000000_u64, + 0xffffffff00000001_u64, + ], + curve_b: [ + 0x3bce3c3e27d2604b_u64, + 0x651d06b0cc53b0f6_u64, + 0xb3ebbd55769886bc_u64, + 0x5ac635d8aa3a93e7_u64, + ], + generator: ( [ - 0xf3b9cac2fc632551_u64, - 0xbce6faada7179e84_u64, - 0xffffffffffffffff_u64, - 0xffffffff00000000_u64, - ] - .as_slice(), - ), - curve_a: FieldElement::from(-3), - curve_b: FieldElement::from_sign_and_limbs( - true, + 0xf4a13945d898c296_u64, + 0x77037d812deb33a0_u64, + 0xf8bce6e563a440f2_u64, + 0x6b17d1f2e12c4247_u64, + ], [ - 0x3bce3c3e27d2604b_u64, - 0x651d06b0cc53b0f6_u64, - 0xb3ebbd55769886bc_u64, - 0x5ac635d8aa3a93e7_u64, - ] - .as_slice(), - ), - generator: ( - FieldElement::from_sign_and_limbs( - true, - [ - 0xf4a13945d898c296_u64, - 
0x77037d812deb33a0_u64, - 0xf8bce6e563a440f2_u64, - 0x6b17d1f2e12c4247_u64, - ] - .as_slice(), - ), - FieldElement::from_sign_and_limbs( - true, - [ - 0xcbb6406837bf51f5_u64, - 0x2bce33576b315ece_u64, - 0x8ee7eb4a7c0f9e16_u64, - 0x4fe342e2fe1a7f9b_u64, - ] - .as_slice(), - ), + 0xcbb6406837bf51f5_u64, + 0x2bce33576b315ece_u64, + 0x8ee7eb4a7c0f9e16_u64, + 0x4fe342e2fe1a7f9b_u64, + ], ), - coordinate_bits: None, } } diff --git a/provekit/r1cs-compiler/src/msm/ec_ops.rs b/provekit/r1cs-compiler/src/msm/ec_ops.rs index 1bd96e69f..985937821 100644 --- a/provekit/r1cs-compiler/src/msm/ec_ops.rs +++ b/provekit/r1cs-compiler/src/msm/ec_ops.rs @@ -1,5 +1,5 @@ use { - crate::{msm::curve::CurveParams, noir_to_r1cs::NoirToR1CSCompiler}, + crate::noir_to_r1cs::NoirToR1CSCompiler, ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField}, provekit_common::{ witness::{SumTerm, WitnessBuilder}, @@ -9,7 +9,7 @@ use { }; /// Reduce the value to given modulus -pub fn reduce_mod( +pub fn reduce_mod_p( r1cs_compiler: &mut NoirToR1CSCompiler, value: usize, modulus: FieldElement, @@ -65,8 +65,30 @@ pub fn reduce_mod( result } -/// a * b mod m -pub fn compute_field_mul( +/// a + b mod p +pub fn add_mod_p( + r1cs_compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_add_b = r1cs_compiler.num_witnesses(); + r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(a_add_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(FieldElement::ONE), b), + ])); + // constraint: a + b = a_add_b + r1cs_compiler.r1cs.add_constraint( + &[(FieldElement::ONE, a), (FieldElement::ONE, b)], + &[(FieldElement::ONE, r1cs_compiler.witness_one())], + &[(FieldElement::ONE, a_add_b)], + ); + reduce_mod_p(r1cs_compiler, a_add_b, modulus, range_checks) +} + +/// a * b mod p +pub fn mul_mod_p( r1cs_compiler: &mut NoirToR1CSCompiler, a: usize, b: usize, @@ -82,11 +104,11 @@ pub fn compute_field_mul( FieldElement::ONE, a_mul_b, )]); 
- reduce_mod(r1cs_compiler, a_mul_b, modulus, range_checks) + reduce_mod_p(r1cs_compiler, a_mul_b, modulus, range_checks) } -/// (a - b) mod m -pub fn compute_field_sub( +/// (a - b) mod p +pub fn sub_mod_p( r1cs_compiler: &mut NoirToR1CSCompiler, a: usize, b: usize, @@ -104,16 +126,11 @@ pub fn compute_field_sub( &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], &[(FieldElement::ONE, a_sub_b)], ); - reduce_mod(r1cs_compiler, a_sub_b, modulus, range_checks) + reduce_mod_p(r1cs_compiler, a_sub_b, modulus, range_checks) } -/// a^(-1) mod m -/// -/// CRITICAL: secp256r1's field_modulus_p (~2^256) > BN254 scalar field -/// (~2^254). Coordinates and the modulus do not fit in a single -/// FieldElement. Either use multi-limb representation or target a -/// curve that fits (e.g. Grumpkin, BabyJubJub). -pub fn compute_field_inv( +/// a^(-1) mod p +pub fn inv_mod_p( r1cs_compiler: &mut NoirToR1CSCompiler, a: usize, modulus: FieldElement, @@ -127,17 +144,8 @@ pub fn compute_field_inv( // Verifying a * a_inv mod m = 1 // ----------------------------------------------------------- - // computing a * a_inv - let product_raw = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Product(product_raw, a, a_inv)); - // constraint: a * a_inv = product_raw - r1cs_compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, a_inv)], &[ - (FieldElement::ONE, product_raw), - ]); - // reducing a * a_inv mod m — should give 1 if a_inv is correct - let reduced = reduce_mod(r1cs_compiler, product_raw, modulus, range_checks); + // computing a * a_inv mod m + let reduced = mul_mod_p(r1cs_compiler, a, a_inv, modulus, range_checks); // constraint: reduced = 1 // (reduced - 1) * 1 = 0 @@ -160,111 +168,6 @@ pub fn compute_field_inv( a_inv } -/// Point doubling on y^2 = x^3 + ax + b (mod p) using affine lambda formula. 
-/// -/// Given P = (x1, y1), computes 2P = (x3, y3): -/// lambda = (3 * x1^2 + a) / (2 * y1) (mod p) -/// x3 = lambda^2 - 2 * x1 (mod p) -/// y3 = lambda * (x1 - x3) - y1 (mod p) -/// -/// Edge case — y1 = 0 (point of order 2): -/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. -/// The result should be the point at infinity (identity element). -/// This function does NOT handle that case — the constraint system will -/// be unsatisfiable if y1 = 0 (compute_field_inv will fail to verify -/// 0 * inv = 1 mod p). The caller must check y1 = 0 using -/// compute_is_zero and conditionally select the point-at-infinity -/// result before calling this function. -pub fn point_double( - r1cs_compiler: &mut NoirToR1CSCompiler, - x1: usize, - y1: usize, - curve_params: &CurveParams, - range_checks: &mut BTreeMap>, -) -> (usize, usize) { - let p = curve_params.field_modulus_p; - - // Computing numerator = 3 * x1^2 + a (mod p) - // ----------------------------------------------------------- - // computing x1^2 mod p - let x1_sq = compute_field_mul(r1cs_compiler, x1, x1, p, range_checks); - // computing 3 * x1_sq + a - let a_witness = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Constant( - provekit_common::witness::ConstantTerm(a_witness, curve_params.curve_a), - )); - let num_raw = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(num_raw, vec![ - SumTerm(Some(FieldElement::from(3u64)), x1_sq), - SumTerm(Some(FieldElement::ONE), a_witness), - ])); - // constraint: 1 * (3 * x1_sq + a) = num_raw - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[ - (FieldElement::from(3u64), x1_sq), - (FieldElement::ONE, a_witness), - ], - &[(FieldElement::ONE, num_raw)], - ); - let numerator = reduce_mod(r1cs_compiler, num_raw, p, range_checks); - - // Computing denominator = 2 * y1 (mod p) - // 
----------------------------------------------------------- - // computing 2 * y1 - let denom_raw = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(denom_raw, vec![SumTerm( - Some(FieldElement::from(2u64)), - y1, - )])); - // constraint: 1 * (2 * y1) = denom_raw - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[(FieldElement::from(2u64), y1)], - &[(FieldElement::ONE, denom_raw)], - ); - let denominator = reduce_mod(r1cs_compiler, denom_raw, p, range_checks); - - // Computing lambda = numerator * denominator^(-1) (mod p) - // ----------------------------------------------------------- - // computing denominator^(-1) mod p - let denom_inv = compute_field_inv(r1cs_compiler, denominator, p, range_checks); - // computing lambda = numerator * denom_inv mod p - let lambda = compute_field_mul(r1cs_compiler, numerator, denom_inv, p, range_checks); - - // Computing x3 = lambda^2 - 2 * x1 (mod p) - // ----------------------------------------------------------- - // computing lambda^2 mod p - let lambda_sq = compute_field_mul(r1cs_compiler, lambda, lambda, p, range_checks); - // computing lambda^2 - 2 * x1 - let x3_raw = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(x3_raw, vec![ - SumTerm(Some(FieldElement::ONE), lambda_sq), - SumTerm(Some(-FieldElement::from(2u64)), x1), - ])); - // constraint: 1 * (lambda^2 - 2 * x1) = x3_raw - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[ - (FieldElement::ONE, lambda_sq), - (-FieldElement::from(2u64), x1), - ], - &[(FieldElement::ONE, x3_raw)], - ); - let x3 = reduce_mod(r1cs_compiler, x3_raw, p, range_checks); - - // Computing y3 = lambda * (x1 - x3) - y1 (mod p) - // ----------------------------------------------------------- - // computing x1 - x3 mod p - let x1_minus_x3 = compute_field_sub(r1cs_compiler, x1, x3, p, range_checks); - // computing lambda * (x1 - x3) 
mod p - let lambda_dx = compute_field_mul(r1cs_compiler, lambda, x1_minus_x3, p, range_checks); - // computing lambda * (x1 - x3) - y1 mod p - let y3 = compute_field_sub(r1cs_compiler, lambda_dx, y1, p, range_checks); - - (x3, y3) -} - /// checks if value is zero or not pub fn compute_is_zero(r1cs_compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { // calculating v^(-1) diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs new file mode 100644 index 000000000..d607d25ff --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -0,0 +1,101 @@ +use super::FieldOps; + +/// Generic point doubling on y^2 = x^3 + ax + b. +/// +/// Given P = (x1, y1), computes 2P = (x3, y3): +/// lambda = (3 * x1^2 + a) / (2 * y1) +/// x3 = lambda^2 - 2 * x1 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge case — y1 = 0 (point of order 2): +/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. +/// The result should be the point at infinity (identity element). +/// This function does NOT handle that case — the constraint system will +/// be unsatisfiable if y1 = 0 (the inverse verification will fail to +/// verify 0 * inv = 1 mod p). The caller must check y1 = 0 using +/// compute_is_zero and conditionally select the point-at-infinity +/// result before calling this function. 
+pub fn point_double(ops: &mut F, x1: F::Elem, y1: F::Elem) -> (F::Elem, F::Elem) { + let a = ops.curve_a(); + + // Computing numerator = 3 * x1^2 + a + let x1_sq = ops.mul(x1, x1); + let two_x1_sq = ops.add(x1_sq, x1_sq); + let three_x1_sq = ops.add(two_x1_sq, x1_sq); + let numerator = ops.add(three_x1_sq, a); + + // Computing denominator = 2 * y1 + let denominator = ops.add(y1, y1); + + // Computing lambda = numerator * denominator^(-1) + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - 2 * x1 + let lambda_sq = ops.mul(lambda, lambda); + let two_x1 = ops.add(x1, x1); + let x3 = ops.sub(lambda_sq, two_x1); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Generic point addition on y^2 = x^3 + ax + b. +/// +/// Given P1 = (x1, y1) and P2 = (x2, y2), computes P1 + P2 = (x3, y3): +/// lambda = (y2 - y1) / (x2 - x1) +/// x3 = lambda^2 - x1 - x2 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge cases — x1 = x2: +/// When x1 = x2, the denominator (x2 - x1) = 0 and the inverse does +/// not exist. This covers two cases: +/// - P1 = P2 (same point): use `point_double` instead. +/// - P1 = -P2 (y1 = -y2): the result is the point at infinity. +/// This function does NOT handle either case — the constraint system +/// will be unsatisfiable if x1 = x2. The caller must detect this +/// and branch accordingly. 
+pub fn point_add( + ops: &mut F, + x1: F::Elem, + y1: F::Elem, + x2: F::Elem, + y2: F::Elem, +) -> (F::Elem, F::Elem) { + // Computing lambda = (y2 - y1) / (x2 - x1) + let numerator = ops.sub(y2, y1); + let denominator = ops.sub(x2, x1); + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - x1 - x2 + let lambda_sq = ops.mul(lambda, lambda); + let x1_plus_x2 = ops.add(x1, x2); + let x3 = ops.sub(lambda_sq, x1_plus_x2); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Conditional point select: returns `on_true` if `flag` is 1, `on_false` if +/// `flag` is 0. +/// +/// Constrains `flag` to be boolean (`flag * flag = flag`). +pub fn point_select( + ops: &mut F, + flag: usize, + on_false: (F::Elem, F::Elem), + on_true: (F::Elem, F::Elem), +) -> (F::Elem, F::Elem) { + let x = ops.select(flag, on_false.0, on_true.0); + let y = ops.select(flag, on_false.1, on_true.1); + (x, y) +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 3844d1466..a155a6def 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -1,2 +1,143 @@ pub mod curve; pub mod ec_ops; +pub mod ec_points; +pub mod wide_ops; + +use { + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::Field, + curve::{curve_native_point_fe, limb2_constant, CurveParams, Limb2}, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +pub trait FieldOps { + type Elem: Copy; + + fn add(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn sub(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn mul(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; + fn inv(&mut self, a: Self::Elem) -> Self::Elem; + fn curve_a(&mut self) -> Self::Elem; + + /// 
Conditional select: returns `on_true` if `flag` is 1, `on_false` if + /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`). + fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem; +} + +/// Narrow field operations for curves where p fits in BN254's scalar field. +/// Operates on single witness indices (`usize`). +pub struct NarrowOps<'a> { + pub compiler: &'a mut NoirToR1CSCompiler, + pub range_checks: &'a mut BTreeMap>, + pub modulus: FieldElement, + pub params: &'a CurveParams, +} + +impl FieldOps for NarrowOps<'_> { + type Elem = usize; + + fn add(&mut self, a: usize, b: usize) -> usize { + ec_ops::add_mod_p(self.compiler, a, b, self.modulus, self.range_checks) + } + + fn sub(&mut self, a: usize, b: usize) -> usize { + ec_ops::sub_mod_p(self.compiler, a, b, self.modulus, self.range_checks) + } + + fn mul(&mut self, a: usize, b: usize) -> usize { + ec_ops::mul_mod_p(self.compiler, a, b, self.modulus, self.range_checks) + } + + fn inv(&mut self, a: usize) -> usize { + ec_ops::inv_mod_p(self.compiler, a, self.modulus, self.range_checks) + } + + fn curve_a(&mut self) -> usize { + let a_fe = curve_native_point_fe(&self.params.curve_a); + let w = self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, a_fe))); + w + } + + fn select(&mut self, flag: usize, on_false: usize, on_true: usize) -> usize { + constrain_boolean(self.compiler, flag); + select_witness(self.compiler, flag, on_false, on_true) + } +} + +/// Wide field operations for curves where p > BN254_r (e.g. secp256r1). +/// Operates on `Limb2` (two 128-bit limbs). 
+pub struct WideOps<'a> { + pub compiler: &'a mut NoirToR1CSCompiler, + pub range_checks: &'a mut BTreeMap>, + pub params: &'a CurveParams, +} + +impl FieldOps for WideOps<'_> { + type Elem = Limb2; + + fn add(&mut self, a: Limb2, b: Limb2) -> Limb2 { + wide_ops::add_mod_p(self.compiler, self.range_checks, a, b, self.params) + } + + fn sub(&mut self, a: Limb2, b: Limb2) -> Limb2 { + wide_ops::sub_mod_p(self.compiler, self.range_checks, a, b, self.params) + } + + fn mul(&mut self, a: Limb2, b: Limb2) -> Limb2 { + wide_ops::mul_mod_p(self.compiler, self.range_checks, a, b, self.params) + } + + fn inv(&mut self, a: Limb2) -> Limb2 { + wide_ops::inv_mod_p(self.compiler, self.range_checks, a, self.params) + } + + fn curve_a(&mut self) -> Limb2 { + limb2_constant(self.compiler, self.params.curve_a) + } + + fn select(&mut self, flag: usize, on_false: Limb2, on_true: Limb2) -> Limb2 { + constrain_boolean(self.compiler, flag); + Limb2 { + lo: select_witness(self.compiler, flag, on_false.lo, on_true.lo), + hi: select_witness(self.compiler, flag, on_false.hi, on_true.hi), + } + } +} + +// --------------------------------------------------------------------------- +// Private helpers +// --------------------------------------------------------------------------- + +/// Constrains `flag` to be boolean: `flag * flag = flag`. +fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + ); +} + +/// Single-witness conditional select: `out = on_false + flag * (on_true - +/// on_false)`. +/// +/// Produces 3 witnesses and 3 R1CS constraints (diff, flag*diff, out). +/// Does NOT constrain `flag` to be boolean — caller must do that separately. 
+fn select_witness( + compiler: &mut NoirToR1CSCompiler, + flag: usize, + on_false: usize, + on_true: usize, +) -> usize { + let diff = compiler.add_sum(vec![ + SumTerm(None, on_true), + SumTerm(Some(-FieldElement::ONE), on_false), + ]); + let flag_diff = compiler.add_product(flag, diff); + compiler.add_sum(vec![SumTerm(None, on_false), SumTerm(None, flag_diff)]) +} diff --git a/provekit/r1cs-compiler/src/msm/wide_ops.rs b/provekit/r1cs-compiler/src/msm/wide_ops.rs new file mode 100644 index 000000000..167d6f986 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/wide_ops.rs @@ -0,0 +1,563 @@ +use { + crate::{ + msm::curve::{CurveParams, Limb2}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::Field, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// (a + b) mod p for 256-bit values in two 128-bit limbs. +/// +/// Equation: a + b = q * p + r, where q ∈ {0, 1}, 0 ≤ r < p. +/// +/// Uses the offset trick to avoid negative intermediate values: +/// v_offset = a_lo + b_lo + 2^128 - q * p_lo (always ≥ 0) +/// carry_offset = floor(v_offset / 2^128) ∈ {0, 1, 2} +/// r_lo = v_offset - carry_offset * 2^128 +/// r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi +/// +/// Less-than-p check (proves r < p): +/// d_lo + d_hi * 2^128 = (p - 1) - r (all components ≥ 0) +/// +/// Constraints (7 total): +/// 1. q is boolean: q * q = q +/// 2-3. Column 0: v_offset defined, then r_lo = v_offset - carry_offset * +/// 2^128 +/// 4. Column 1: r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi +/// 5-6. LT check: v_diff defined, then d_lo = v_diff - borrow_compl * 2^128 +/// 7. 
LT check: d_hi = (p_hi - 1) + borrow_compl - r_hi +/// +/// Range checks: r_lo, r_hi, d_lo, d_hi (128-bit each) +pub fn add_mod_p( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limb2, + b: Limb2, + params: &CurveParams, +) -> Limb2 { + let two_128 = FieldElement::from(2u64).pow([128u64]); + let p_lo_fe = params.p_lo_fe(); + let p_hi_fe = params.p_hi_fe(); + let w1 = compiler.witness_one(); + + // Witness: q = floor((a + b) / p) ∈ {0, 1} + // ----------------------------------------------------------- + let q = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::WideAddQuotient { + output: q, + a_lo: a.lo, + a_hi: a.hi, + b_lo: b.lo, + b_hi: b.hi, + modulus: params.field_modulus_p, + }); + // constraining q to be boolean + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); + + // Computing r_lo: lower 128 bits of result + // ----------------------------------------------------------- + // v_offset = a_lo + b_lo + 2^128 - q * p_lo + // (2^128 offset ensures v_offset is always non-negative) + let v_offset = compiler.add_sum(vec![ + SumTerm(None, a.lo), + SumTerm(None, b.lo), + SumTerm(Some(two_128), w1), + SumTerm(Some(-p_lo_fe), q), + ]); + // computing carry_offset = floor(v_offset / 2^128) + let carry_offset = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + carry_offset, + v_offset, + two_128, + )); + // computing r_lo = v_offset - carry_offset * 2^128 + let r_lo = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_128), carry_offset), + ]); + + // Computing r_hi: upper 128 bits of result + // ----------------------------------------------------------- + // r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi + // (-1 compensates for the 2^128 offset added in the low column) + let r_hi = compiler.add_sum(vec![ + SumTerm(None, a.hi), + SumTerm(None, b.hi), + SumTerm(None, carry_offset), 
+ SumTerm(Some(-FieldElement::ONE), w1), + SumTerm(Some(-p_hi_fe), q), + ]); + + less_than_p_check(compiler, range_checks, r_lo, r_hi, params); + + Limb2 { lo: r_lo, hi: r_hi } +} + +/// (a - b) mod p for 256-bit values in two 128-bit limbs. +/// +/// Equation: a - b + q * p = r, where q ∈ {0, 1}, 0 ≤ r < p. +/// q = 0 if a ≥ b (result is non-negative without correction) +/// q = 1 if a < b (add p to make result non-negative) +/// +/// Uses the offset trick to avoid negative intermediate values: +/// v_offset = a_lo - b_lo + q * p_lo + 2^128 (always ≥ 0) +/// carry_offset = floor(v_offset / 2^128) ∈ {0, 1, 2} +/// r_lo = v_offset - carry_offset * 2^128 +/// r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 +/// +/// Less-than-p check (proves r < p): +/// d_lo + d_hi * 2^128 = (p - 1) - r (all components ≥ 0) +/// +/// Constraints (7 total): +/// 1. q is boolean: q * q = q +/// 2-3. Column 0: v_offset defined, then r_lo = v_offset - carry_offset * +/// 2^128 +/// 4. Column 1: r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 +/// 5-6. LT check: v_diff defined, then d_lo = v_diff - borrow_compl * 2^128 +/// 7. LT check: d_hi = (p_hi - 1) + borrow_compl - r_hi +/// +/// Range checks: r_lo, r_hi, d_lo, d_hi (128-bit each) +pub fn sub_mod_p( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limb2, + b: Limb2, + params: &CurveParams, +) -> Limb2 { + let two_128 = FieldElement::from(2u64).pow([128u64]); + let p_lo_fe = params.p_lo_fe(); + let p_hi_fe = params.p_hi_fe(); + let w1 = compiler.witness_one(); + + // Witness: q = (a < b) ? 
1 : 0 + // ----------------------------------------------------------- + let q = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::WideSubBorrow { + output: q, + a_lo: a.lo, + a_hi: a.hi, + b_lo: b.lo, + b_hi: b.hi, + }); + // constraining q to be boolean + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); + + // Computing r_lo: lower 128 bits of result + // ----------------------------------------------------------- + // v_offset = a_lo - b_lo + q * p_lo + 2^128 + // (2^128 offset ensures v_offset is always non-negative) + let v_offset = compiler.add_sum(vec![ + SumTerm(None, a.lo), + SumTerm(Some(-FieldElement::ONE), b.lo), + SumTerm(Some(p_lo_fe), q), + SumTerm(Some(two_128), w1), + ]); + // computing carry_offset = floor(v_offset / 2^128) + let carry_offset = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + carry_offset, + v_offset, + two_128, + )); + // computing r_lo = v_offset - carry_offset * 2^128 + let r_lo = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_128), carry_offset), + ]); + + // Computing r_hi: upper 128 bits of result + // ----------------------------------------------------------- + // r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 + // (-1 compensates for the 2^128 offset added in the low column) + let r_hi = compiler.add_sum(vec![ + SumTerm(None, a.hi), + SumTerm(Some(-FieldElement::ONE), b.hi), + SumTerm(Some(p_hi_fe), q), + SumTerm(None, carry_offset), + SumTerm(Some(-FieldElement::ONE), w1), + ]); + + less_than_p_check(compiler, range_checks, r_lo, r_hi, params); + + Limb2 { lo: r_lo, hi: r_hi } +} + +/// (a × b) mod p for 256-bit values in two 128-bit limbs. 
+/// +/// Verifies the integer identity `a * b = p * q + r` using schoolbook +/// multiplication in base W = 2^86 (86-bit limbs ensure all column +/// products < 2^172 ≪ BN254_r ≈ 2^254, so field equations = integer equations). +/// +/// Three layers of verification: +/// 1. Decomposition links: prove 86-bit witnesses match the 128-bit +/// inputs/outputs +/// 2. Column equations: prove a86 * b86 = p86 * q86 + r86 (integer) +/// 3. Less-than-p check: prove r < p +/// +/// Witness layout (MulModHint, 20 witnesses at output_start): +/// [0..2) q_lo, q_hi — quotient 128-bit limbs (unconstrained) +/// [2..4) r_lo, r_hi — remainder 128-bit limbs (OUTPUT) +/// [4..7) a86_0..2 — a in 86-bit limbs +/// [7..10) b86_0..2 — b in 86-bit limbs +/// [10..13) q86_0..2 — q in 86-bit limbs +/// [13..16) r86_0..2 — r in 86-bit limbs +/// [16..20) c0u..c3u — unsigned-offset carries (c_signed + 2^88) +/// +/// Constraints (26 total): +/// 9 decomposition links (a, b, r × 3 each) +/// 9 product witnesses (a_i × b_j) +/// 5 column equations +/// 3 less-than-p check +/// +/// Range checks (23 total): +/// 128-bit: r_lo, r_hi, d_lo, d_hi +/// 86-bit: a86_0, a86_1, b86_0, b86_1, q86_0, q86_1, r86_0, r86_1 +/// 84-bit: a86_2, b86_2, q86_2, r86_2 +/// 89-bit: c0u, c1u, c2u, c3u +/// 44-bit: carry_a, carry_b, carry_r +pub fn mul_mod_p( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limb2, + b: Limb2, + params: &CurveParams, +) -> Limb2 { + let two_44 = FieldElement::from(2u64).pow([44u64]); + let two_86 = FieldElement::from(2u64).pow([86u64]); + let two_128 = FieldElement::from(2u64).pow([128u64]); + let offset_fe = FieldElement::from(2u64).pow([88u64]); // CARRY_OFFSET + let offset_w = FieldElement::from(2u64).pow([174u64]); // 2^88 * 2^86 + let offset_w_minus_1 = offset_w - offset_fe; // 2^88 * (2^86 - 1) + let [p0, p1, p2] = params.p_86_limbs(); + let w1 = compiler.witness_one(); + + // Step 1: Allocate MulModHint (20 witnesses) + // 
----------------------------------------------------------- + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MulModHint { + output_start: os, + a_lo: a.lo, + a_hi: a.hi, + b_lo: b.lo, + b_hi: b.hi, + modulus: params.field_modulus_p, + }); + + // Witness indices + let r_lo = os + 2; + let r_hi = os + 3; + let a86 = [os + 4, os + 5, os + 6]; + let b86 = [os + 7, os + 8, os + 9]; + let q86 = [os + 10, os + 11, os + 12]; + let r86 = [os + 13, os + 14, os + 15]; + let cu = [os + 16, os + 17, os + 18, os + 19]; + + // Step 2: Decomposition consistency for a, b, r + // ----------------------------------------------------------- + decompose_check( + compiler, + range_checks, + a.lo, + a.hi, + a86, + two_86, + two_44, + two_128, + w1, + ); + decompose_check( + compiler, + range_checks, + b.lo, + b.hi, + b86, + two_86, + two_44, + two_128, + w1, + ); + decompose_check( + compiler, + range_checks, + r_lo, + r_hi, + r86, + two_86, + two_44, + two_128, + w1, + ); + + // Step 3: Product witnesses (9 R1CS constraints) + // ----------------------------------------------------------- + let ab00 = compiler.add_product(a86[0], b86[0]); + let ab01 = compiler.add_product(a86[0], b86[1]); + let ab10 = compiler.add_product(a86[1], b86[0]); + let ab02 = compiler.add_product(a86[0], b86[2]); + let ab11 = compiler.add_product(a86[1], b86[1]); + let ab20 = compiler.add_product(a86[2], b86[0]); + let ab12 = compiler.add_product(a86[1], b86[2]); + let ab21 = compiler.add_product(a86[2], b86[1]); + let ab22 = compiler.add_product(a86[2], b86[2]); + + // Step 4: Column equations (5 R1CS constraints) + // ----------------------------------------------------------- + // Identity: a*b = p*q + r in base W=2^86. + // Carries stored with unsigned offset: cu_i = c_i + 2^88. 
+ // + // col0: ab00 + 2^174 = p0*q0 + r0 + W*cu0 + // col1: ab01 + ab10 + cu0 + (2^174-2^88) = p0*q1 + p1*q0 + r1 + W*cu1 + // col2: ab02+ab11+ab20 + cu1 + (2^174-2^88) = p0*q2+p1*q1+p2*q0 + r2 + W*cu2 + // col3: ab12 + ab21 + cu2 + (2^174-2^88) = p1*q2 + p2*q1 + W*cu3 + // col4: ab22 + cu3 = p2*q2 + 2^88 + + // col0 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, ab00), (offset_w, w1)], + &[(FieldElement::ONE, w1)], + &[(p0, q86[0]), (FieldElement::ONE, r86[0]), (two_86, cu[0])], + ); + + // col1 + compiler.r1cs.add_constraint( + &[ + (FieldElement::ONE, ab01), + (FieldElement::ONE, ab10), + (FieldElement::ONE, cu[0]), + (offset_w_minus_1, w1), + ], + &[(FieldElement::ONE, w1)], + &[ + (p0, q86[1]), + (p1, q86[0]), + (FieldElement::ONE, r86[1]), + (two_86, cu[1]), + ], + ); + + // col2 + compiler.r1cs.add_constraint( + &[ + (FieldElement::ONE, ab02), + (FieldElement::ONE, ab11), + (FieldElement::ONE, ab20), + (FieldElement::ONE, cu[1]), + (offset_w_minus_1, w1), + ], + &[(FieldElement::ONE, w1)], + &[ + (p0, q86[2]), + (p1, q86[1]), + (p2, q86[0]), + (FieldElement::ONE, r86[2]), + (two_86, cu[2]), + ], + ); + + // col3 + compiler.r1cs.add_constraint( + &[ + (FieldElement::ONE, ab12), + (FieldElement::ONE, ab21), + (FieldElement::ONE, cu[2]), + (offset_w_minus_1, w1), + ], + &[(FieldElement::ONE, w1)], + &[(p1, q86[2]), (p2, q86[1]), (two_86, cu[3])], + ); + + // col4 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, ab22), (FieldElement::ONE, cu[3])], + &[(FieldElement::ONE, w1)], + &[(p2, q86[2]), (offset_fe, w1)], + ); + + // Step 5: Less-than-p check (r < p) + 128-bit range checks on r_lo, r_hi + // ----------------------------------------------------------- + less_than_p_check(compiler, range_checks, r_lo, r_hi, params); + + // Step 6: Range checks (mul-specific) + // ----------------------------------------------------------- + // 86-bit: limbs 0 and 1 of a, b, q, r + for &idx in &[ + a86[0], a86[1], b86[0], b86[1], q86[0], q86[1], r86[0], 
r86[1], + ] { + range_checks.entry(86).or_default().push(idx); + } + + // 84-bit: limb 2 of a, b, q, r (bits [172..256) = 84 bits) + for &idx in &[a86[2], b86[2], q86[2], r86[2]] { + range_checks.entry(84).or_default().push(idx); + } + + // 89-bit: unsigned-offset carries (|c_signed| < 2^88, so c_unsigned ∈ [0, + // 2^89)) + for &idx in &cu { + range_checks.entry(89).or_default().push(idx); + } + + Limb2 { lo: r_lo, hi: r_hi } +} + +/// a^(-1) mod p for 256-bit values in two 128-bit limbs. +/// +/// Hint-and-verify pattern: +/// 1. Prover computes inv = a^(p-2) mod p (Fermat's little theorem) +/// 2. Circuit verifies a * inv mod p = 1 +/// +/// Constraints: 26 from mul_mod_p + 2 equality checks = 28 total. +pub fn inv_mod_p( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + value: Limb2, + params: &CurveParams, +) -> Limb2 { + // Witness: inv = a^(-1) mod p (2 witnesses: lo, hi) + // ----------------------------------------------------------- + let value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::WideModularInverse { + output_start: value_inv, + a_lo: value.lo, + a_hi: value.hi, + modulus: params.field_modulus_p, + }); + let inv = Limb2 { + lo: value_inv, + hi: value_inv + 1, + }; + + // Verifying a * inv mod p = 1 + // ----------------------------------------------------------- + // computing product = value * inv mod p + let product = mul_mod_p(compiler, range_checks, value, inv, params); + // constraining product_lo = 1 (because 1 = 1 + 0 * 2^128) + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product.lo)], + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, compiler.witness_one())], + ); + // constraining product_hi = 0 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product.hi)], + &[(FieldElement::ONE, compiler.witness_one())], + &[], + ); + + inv +} + +/// Verify that 128-bit limbs (v_lo, v_hi) decompose into 86-bit limbs (v86). 
+/// +/// Equations: +/// v_lo = v86_0 + v86_1 * 2^86 - carry * 2^128 +/// v_hi = carry + v86_2 * 2^44 +/// +/// All intermediate values < 2^172 ≪ BN254_r, so field equations = integer +/// equations. +/// +/// Creates: 1 intermediate witness (v_sum), 1 carry witness (IntegerQuotient). +/// Adds: 3 R1CS constraints (v_sum definition + 2 decomposition checks). +/// Range checks: carry (44-bit). +/// Proves r < p by decomposing (p - 1) - r into non-negative 128-bit limbs. +/// +/// If d_lo, d_hi >= 0 then (p - 1) - r >= 0, i.e. r <= p - 1 < p. +/// Uses the 2^128 offset trick to avoid negative intermediate values. +/// +/// Range checks r_lo, r_hi, d_lo, d_hi (128-bit each). +fn less_than_p_check( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + r_lo: usize, + r_hi: usize, + params: &CurveParams, +) { + let two_128 = FieldElement::from(2u64).pow([128u64]); + let p_lo_fe = params.p_lo_fe(); + let p_hi_fe = params.p_hi_fe(); + let w1 = compiler.witness_one(); + + // v_diff = (p_lo - 1) + 2^128 - r_lo + // (2^128 offset ensures v_diff is always non-negative) + let p_lo_minus_1_plus_offset = p_lo_fe - FieldElement::ONE + two_128; + let v_diff = compiler.add_sum(vec![ + SumTerm(Some(p_lo_minus_1_plus_offset), w1), + SumTerm(Some(-FieldElement::ONE), r_lo), + ]); + // borrow_compl = floor(v_diff / 2^128) + let borrow_compl = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + borrow_compl, + v_diff, + two_128, + )); + // d_lo = v_diff - borrow_compl * 2^128 + let d_lo = compiler.add_sum(vec![ + SumTerm(None, v_diff), + SumTerm(Some(-two_128), borrow_compl), + ]); + // d_hi = (p_hi - 1) + borrow_compl - r_hi + let d_hi = compiler.add_sum(vec![ + SumTerm(Some(p_hi_fe - FieldElement::ONE), w1), + SumTerm(None, borrow_compl), + SumTerm(Some(-FieldElement::ONE), r_hi), + ]); + + // Range checks (128-bit) + range_checks.entry(128).or_default().push(r_lo); + range_checks.entry(128).or_default().push(r_hi); + 
range_checks.entry(128).or_default().push(d_lo); + range_checks.entry(128).or_default().push(d_hi); +} + +fn decompose_check( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + v_lo: usize, + v_hi: usize, + v86: [usize; 3], + two_86: FieldElement, + two_44: FieldElement, + two_128: FieldElement, + w1: usize, +) { + // v_sum = v86_0 + v86_1 * 2^86 (intermediate for IntegerQuotient) + let v_sum = compiler.add_sum(vec![SumTerm(None, v86[0]), SumTerm(Some(two_86), v86[1])]); + + // carry = floor(v_sum / 2^128) ∈ [0, 2^44) + let carry = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_sum, two_128)); + + // Low check: v_sum - carry * 2^128 = v_lo + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, v_sum), (-two_128, carry)], + &[(FieldElement::ONE, w1)], + &[(FieldElement::ONE, v_lo)], + ); + + // High check: carry + v86_2 * 2^44 = v_hi + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, carry), (two_44, v86[2])], + &[(FieldElement::ONE, w1)], + &[(FieldElement::ONE, v_hi)], + ); + + // Range check carry (44-bit) + range_checks.entry(44).or_default().push(carry); +} From ac142d6ab68125bcc74b24391ce291b01d6f9301 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Tue, 3 Mar 2026 04:13:48 +0530 Subject: [PATCH 03/19] feat : added dynamic multi limb approach with cost model for msm black box --- noir-examples/embedded_curve_msm/Nargo.toml | 7 + noir-examples/embedded_curve_msm/Prover.toml | 5 + noir-examples/embedded_curve_msm/src/main.nr | 51 ++ noir-examples/native_msm/Nargo.toml | 7 + noir-examples/native_msm/Prover.toml | 5 + noir-examples/native_msm/src/main.nr | 104 +++ .../src/witness/scheduling/dependency.rs | 67 +- .../common/src/witness/scheduling/remapper.rs | 85 +-- .../common/src/witness/witness_builder.rs | 87 +-- provekit/prover/src/lib.rs | 52 ++ .../prover/src/witness/witness_builder.rs | 353 +++++++---- provekit/r1cs-compiler/src/msm/cost_model.rs | 362 +++++++++++ 
provekit/r1cs-compiler/src/msm/curve.rs | 163 +++-- provekit/r1cs-compiler/src/msm/ec_ops.rs | 208 ------ provekit/r1cs-compiler/src/msm/ec_points.rs | 183 ++++++ provekit/r1cs-compiler/src/msm/mod.rs | 510 ++++++++++++--- .../r1cs-compiler/src/msm/multi_limb_arith.rs | 591 ++++++++++++++++++ .../r1cs-compiler/src/msm/multi_limb_ops.rs | 275 ++++++++ provekit/r1cs-compiler/src/msm/wide_ops.rs | 563 ----------------- provekit/r1cs-compiler/src/noir_to_r1cs.rs | 39 +- 20 files changed, 2600 insertions(+), 1117 deletions(-) create mode 100644 noir-examples/embedded_curve_msm/Nargo.toml create mode 100644 noir-examples/embedded_curve_msm/Prover.toml create mode 100644 noir-examples/embedded_curve_msm/src/main.nr create mode 100644 noir-examples/native_msm/Nargo.toml create mode 100644 noir-examples/native_msm/Prover.toml create mode 100644 noir-examples/native_msm/src/main.nr create mode 100644 provekit/r1cs-compiler/src/msm/cost_model.rs delete mode 100644 provekit/r1cs-compiler/src/msm/ec_ops.rs create mode 100644 provekit/r1cs-compiler/src/msm/multi_limb_arith.rs create mode 100644 provekit/r1cs-compiler/src/msm/multi_limb_ops.rs delete mode 100644 provekit/r1cs-compiler/src/msm/wide_ops.rs diff --git a/noir-examples/embedded_curve_msm/Nargo.toml b/noir-examples/embedded_curve_msm/Nargo.toml new file mode 100644 index 000000000..ec9891616 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "embedded_curve_msm" +type = "bin" +authors = [""] +compiler_version = ">=0.22.0" + +[dependencies] diff --git a/noir-examples/embedded_curve_msm/Prover.toml b/noir-examples/embedded_curve_msm/Prover.toml new file mode 100644 index 000000000..58c6933da --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover.toml @@ -0,0 +1,5 @@ +# MSM: result = s1 * G + s2 * G = 1*G + 2*G = 3*G +scalar1_lo = "1" +scalar1_hi = "0" +scalar2_lo = "2" +scalar2_hi = "0" diff --git a/noir-examples/embedded_curve_msm/src/main.nr 
b/noir-examples/embedded_curve_msm/src/main.nr new file mode 100644 index 000000000..cf0704211 --- /dev/null +++ b/noir-examples/embedded_curve_msm/src/main.nr @@ -0,0 +1,51 @@ +use std::embedded_curve_ops::{ + EmbeddedCurvePoint, + EmbeddedCurveScalar, + multi_scalar_mul, +}; + +/// Exercises the MultiScalarMul ACIR blackbox with 2 Grumpkin points. +/// Computes s1 * G + s2 * G where G is the Grumpkin generator. +fn main( + scalar1_lo: Field, + scalar1_hi: Field, + scalar2_lo: Field, + scalar2_hi: Field, +) { + // Grumpkin generator + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + + let s1 = EmbeddedCurveScalar { lo: scalar1_lo, hi: scalar1_hi }; + let s2 = EmbeddedCurveScalar { lo: scalar2_lo, hi: scalar2_hi }; + + // MSM: result = s1 * G + s2 * G + let result = multi_scalar_mul([g, g], [s1, s2]); + + // Prevent dead-code elimination - forces the blackbox to be retained + assert(!result.is_infinite); +} + +#[test] +fn test_msm() { + // 3*G on Grumpkin + let expected_x = 18660890509582237958343981571981920822503400000196279471655180441138020044621; + let expected_y = 8902249110305491597038405103722863701255802573786510474664632793109847672620; + + main(1, 0, 2, 0); + + // Verify by computing independently: 3*G should match + let g = EmbeddedCurvePoint { + x: 1, + y: 17631683881184975370165255887551781615748388533673675138860, + is_infinite: false, + }; + let s3 = EmbeddedCurveScalar { lo: 3, hi: 0 }; + let check = multi_scalar_mul([g], [s3]); + + assert(check.x == expected_x); + assert(check.y == expected_y); +} diff --git a/noir-examples/native_msm/Nargo.toml b/noir-examples/native_msm/Nargo.toml new file mode 100644 index 000000000..5ff116db7 --- /dev/null +++ b/noir-examples/native_msm/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "native_msm" +type = "bin" +authors = [""] +compiler_version = ">=0.22.0" + +[dependencies] diff --git a/noir-examples/native_msm/Prover.toml 
b/noir-examples/native_msm/Prover.toml new file mode 100644 index 000000000..58c6933da --- /dev/null +++ b/noir-examples/native_msm/Prover.toml @@ -0,0 +1,5 @@ +# MSM: result = s1 * G + s2 * G = 1*G + 2*G = 3*G +scalar1_lo = "1" +scalar1_hi = "0" +scalar2_lo = "2" +scalar2_hi = "0" diff --git a/noir-examples/native_msm/src/main.nr b/noir-examples/native_msm/src/main.nr new file mode 100644 index 000000000..80cfd3d0f --- /dev/null +++ b/noir-examples/native_msm/src/main.nr @@ -0,0 +1,104 @@ +// Grumpkin generator y-coordinate +global GRUMPKIN_GEN_Y: Field = 17631683881184975370165255887551781615748388533673675138860; + +struct Point { + x: Field, + y: Field, + is_infinite: bool, +} + +fn point_double(p: Point) -> Point { + if p.is_infinite | (p.y == 0) { + Point { x: 0, y: 0, is_infinite: true } + } else { + // Grumpkin has a=0, so lambda = 3*x1^2 / (2*y1) + let lambda = (3 * p.x * p.x) / (2 * p.y); + let x3 = lambda * lambda - 2 * p.x; + let y3 = lambda * (p.x - x3) - p.y; + Point { x: x3, y: y3, is_infinite: false } + } +} + +fn point_add(p1: Point, p2: Point) -> Point { + if p1.is_infinite { + p2 + } else if p2.is_infinite { + p1 + } else if (p1.x == p2.x) & (p1.y == p2.y) { + point_double(p1) + } else if (p1.x == p2.x) & (p1.y == (0 - p2.y)) { + Point { x: 0, y: 0, is_infinite: true } + } else { + let lambda = (p2.y - p1.y) / (p2.x - p1.x); + let x3 = lambda * lambda - p1.x - p2.x; + let y3 = lambda * (p1.x - x3) - p1.y; + Point { x: x3, y: y3, is_infinite: false } + } +} + +fn scalar_mul(p: Point, scalar_lo: Field, scalar_hi: Field) -> Point { + let lo_bits: [u1; 128] = scalar_lo.to_le_bits(); + let hi_bits: [u1; 128] = scalar_hi.to_le_bits(); + + // Combine into a single 256-bit array (lo first, then hi) + let mut bits: [u1; 256] = [0; 256]; + for i in 0..128 { + bits[i] = lo_bits[i]; + bits[128 + i] = hi_bits[i]; + } + + // Find the highest set bit + let mut top = 0; + for i in 0..256 { + if bits[i] == 1 { + top = i; + } + } + + // Double-and-add from MSB 
down to bit 0 + let mut acc = Point { x: 0, y: 0, is_infinite: true }; + for j in 0..256 { + let i = 255 - j; + acc = point_double(acc); + if bits[i] == 1 { + acc = point_add(acc, p); + } + } + + acc +} + +/// Native MSM: computes s1 * G + s2 * G using pure Noir field operations. +/// No blackbox functions -- all EC arithmetic is done natively over Grumpkin's +/// base field (= BN254 scalar field = Noir's native Field). +fn main( + scalar1_lo: Field, + scalar1_hi: Field, + scalar2_lo: Field, + scalar2_hi: Field, +) { + let g = Point { x: 1, y: GRUMPKIN_GEN_Y, is_infinite: false }; + + let r1 = scalar_mul(g, scalar1_lo, scalar1_hi); + let r2 = scalar_mul(g, scalar2_lo, scalar2_hi); + let result = point_add(r1, r2); + + // Prevent dead-code elimination + assert(!result.is_infinite); +} + +#[test] +fn test_msm() { + // 3*G on Grumpkin (known coordinates) + let expected_x = 18660890509582237958343981571981920822503400000196279471655180441138020044621; + let expected_y = 8902249110305491597038405103722863701255802573786510474664632793109847672620; + + main(1, 0, 2, 0); + + // Verify 1*G + 2*G = 3*G by computing 3*G directly + let g = Point { x: 1, y: GRUMPKIN_GEN_Y, is_infinite: false }; + let three_g = scalar_mul(g, 3, 0); + + assert(three_g.x == expected_x); + assert(three_g.y == expected_y); +} diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index 956f79b56..9f92afd75 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -79,6 +79,7 @@ impl DependencyInfo { WitnessBuilder::Product(_, a, b) => vec![*a, *b], WitnessBuilder::MultiplicitiesForRange(_, _, values) => values.clone(), WitnessBuilder::Inverse(_, x) + | WitnessBuilder::SafeInverse(_, x) | WitnessBuilder::ModularInverse(_, x, _) | WitnessBuilder::IntegerQuotient(_, x, _) => vec![*x], WitnessBuilder::IndexedLogUpDenominator( @@ -154,28 +155,34 @@ impl DependencyInfo { 
} v } - WitnessBuilder::MulModHint { - a_lo, - a_hi, - b_lo, - b_hi, + WitnessBuilder::MultiLimbMulModHint { + a_limbs, + b_limbs, .. - } => vec![*a_lo, *a_hi, *b_lo, *b_hi], - WitnessBuilder::WideModularInverse { a_lo, a_hi, .. } => vec![*a_lo, *a_hi], - WitnessBuilder::WideAddQuotient { - a_lo, - a_hi, - b_lo, - b_hi, + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbModularInverse { a_limbs, .. } => a_limbs.clone(), + WitnessBuilder::MultiLimbAddQuotient { + a_limbs, + b_limbs, .. - } => vec![*a_lo, *a_hi, *b_lo, *b_hi], - WitnessBuilder::WideSubBorrow { - a_lo, - a_hi, - b_lo, - b_hi, + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } + WitnessBuilder::MultiLimbSubBorrow { + a_limbs, + b_limbs, .. - } => vec![*a_lo, *a_hi, *b_lo, *b_hi], + } => { + let mut v = a_limbs.clone(); + v.extend(b_limbs); + v + } WitnessBuilder::BytePartition { x, .. } => vec![*x], WitnessBuilder::U32AdditionMulti(_, _, inputs) => inputs @@ -264,6 +271,7 @@ impl DependencyInfo { | WitnessBuilder::Challenge(idx) | WitnessBuilder::IndexedLogUpDenominator(idx, ..) | WitnessBuilder::Inverse(idx, _) + | WitnessBuilder::SafeInverse(idx, _) | WitnessBuilder::ModularInverse(idx, ..) | WitnessBuilder::IntegerQuotient(idx, ..) | WitnessBuilder::ProductLinearOperation(idx, ..) @@ -308,14 +316,21 @@ impl DependencyInfo { let n = 1usize << *num_bits; (*start..*start + n).collect() } - WitnessBuilder::MulModHint { output_start, .. } => { - (*output_start..*output_start + 20).collect() - } - WitnessBuilder::WideModularInverse { output_start, .. } => { - (*output_start..*output_start + 2).collect() + WitnessBuilder::MultiLimbMulModHint { + output_start, + num_limbs, + .. + } => { + let count = (4 * *num_limbs - 2) as usize; + (*output_start..*output_start + count).collect() } - WitnessBuilder::WideAddQuotient { output, .. } => vec![*output], - WitnessBuilder::WideSubBorrow { output, .. 
} => vec![*output], + WitnessBuilder::MultiLimbModularInverse { + output_start, + num_limbs, + .. + } => (*output_start..*output_start + *num_limbs as usize).collect(), + WitnessBuilder::MultiLimbAddQuotient { output, .. } => vec![*output], + WitnessBuilder::MultiLimbSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) => { vec![*result_idx, *carry_idx] } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 47490a6ce..334b5f401 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -115,6 +115,9 @@ impl WitnessIndexRemapper { WitnessBuilder::Inverse(idx, operand) => { WitnessBuilder::Inverse(self.remap(*idx), self.remap(*operand)) } + WitnessBuilder::SafeInverse(idx, operand) => { + WitnessBuilder::SafeInverse(self.remap(*idx), self.remap(*operand)) + } WitnessBuilder::ModularInverse(idx, operand, modulus) => { WitnessBuilder::ModularInverse(self.remap(*idx), self.remap(*operand), *modulus) } @@ -221,59 +224,63 @@ impl WitnessIndexRemapper { .collect(), ) } - WitnessBuilder::MulModHint { + WitnessBuilder::MultiLimbMulModHint { output_start, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, modulus, - } => WitnessBuilder::MulModHint { + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbMulModHint { output_start: self.remap(*output_start), - a_lo: self.remap(*a_lo), - a_hi: self.remap(*a_hi), - b_lo: self.remap(*b_lo), - b_hi: self.remap(*b_hi), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, }, - WitnessBuilder::WideModularInverse { + WitnessBuilder::MultiLimbModularInverse { output_start, - a_lo, - a_hi, + a_limbs, modulus, - } => WitnessBuilder::WideModularInverse { + limb_bits, + num_limbs, + } => 
WitnessBuilder::MultiLimbModularInverse { output_start: self.remap(*output_start), - a_lo: self.remap(*a_lo), - a_hi: self.remap(*a_hi), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, }, - WitnessBuilder::WideAddQuotient { + WitnessBuilder::MultiLimbAddQuotient { output, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, modulus, - } => WitnessBuilder::WideAddQuotient { - output: self.remap(*output), - a_lo: self.remap(*a_lo), - a_hi: self.remap(*a_hi), - b_lo: self.remap(*b_lo), - b_hi: self.remap(*b_hi), - modulus: *modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbAddQuotient { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, }, - WitnessBuilder::WideSubBorrow { + WitnessBuilder::MultiLimbSubBorrow { output, - a_lo, - a_hi, - b_lo, - b_hi, - } => WitnessBuilder::WideSubBorrow { - output: self.remap(*output), - a_lo: self.remap(*a_lo), - a_hi: self.remap(*a_hi), - b_lo: self.remap(*b_lo), - b_hi: self.remap(*b_hi), + a_limbs, + b_limbs, + modulus, + limb_bits, + num_limbs, + } => WitnessBuilder::MultiLimbSubBorrow { + output: self.remap(*output), + a_limbs: a_limbs.iter().map(|&w| self.remap(w)).collect(), + b_limbs: b_limbs.iter().map(|&w| self.remap(w)).collect(), + modulus: *modulus, + limb_bits: *limb_bits, + num_limbs: *num_limbs, }, WitnessBuilder::BytePartition { lo, hi, x, k } => WitnessBuilder::BytePartition { lo: self.remap(*lo), diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 6d11d17cf..28d6d775c 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -88,6 +88,11 @@ pub enum WitnessBuilder { /// The inverse of the value at a specified witness index /// 
(witness index, operand witness index) Inverse(usize, usize), + /// Safe inverse: like Inverse but handles zero by outputting 0. + /// Used by compute_is_zero where the input may be zero. Solved in the + /// Other layer (not batch-inverted), so zero inputs don't poison the batch. + /// (witness index, operand witness index) + SafeInverse(usize, usize), /// The modular inverse of the value at a specified witness index, modulo /// a given prime modulus. Computes a^{-1} mod m using Fermat's little /// theorem (a^{m-2} mod m). Unlike Inverse (BN254 field inverse), this @@ -202,61 +207,59 @@ pub enum WitnessBuilder { /// Computes: 1 / (sz - lhs - rs*rhs - rs²*and_out - rs³*xor_out) CombinedTableEntryInverse(CombinedTableEntryInverseData), /// Prover hint for multi-limb modular multiplication: (a * b) mod p. - /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// Given inputs a and b as N-limb vectors (each limb `limb_bits` wide), /// and a constant 256-bit modulus p, computes quotient q, remainder r, - /// their 86-bit decompositions, and carry witnesses. + /// and carry witnesses for schoolbook column verification. 
/// - /// Outputs 20 witnesses starting at output_start: - /// [0..2) q_lo, q_hi (quotient in 128-bit limbs) - /// [2..4) r_lo, r_hi (remainder in 128-bit limbs) - /// [4..7) a_86_0, a_86_1, a_86_2 (a in 86-bit limbs) - /// [7..10) b_86_0, b_86_1, b_86_2 (b in 86-bit limbs) - /// [10..13) q_86_0, q_86_1, q_86_2 (q in 86-bit limbs) - /// [13..16) r_86_0, r_86_1, r_86_2 (r in 86-bit limbs) - /// [16..20) c0, c1, c2, c3 (carry witnesses, unsigned-offset) - MulModHint { + /// Outputs (4*num_limbs - 2) witnesses starting at output_start: + /// [0..N) q limbs (quotient) + /// [N..2N) r limbs (remainder) — OUTPUT + /// [2N..4N-2) carry witnesses (unsigned-offset) + MultiLimbMulModHint { output_start: usize, - a_lo: usize, - a_hi: usize, - b_lo: usize, - b_hi: usize, + a_limbs: Vec, + b_limbs: Vec, modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, }, - /// Prover hint for wide modular inverse: a^{-1} mod p. - /// Given input a = (a_lo, a_hi) as 128-bit limbs and constant modulus p, + /// Prover hint for multi-limb modular inverse: a^{-1} mod p. + /// Given input a as N-limb vector and constant modulus p, /// computes the inverse via Fermat's little theorem (a^{p-2} mod p). /// - /// Outputs 2 witnesses at output_start: inv_lo, inv_hi (128-bit limbs). - WideModularInverse { + /// Outputs num_limbs witnesses at output_start: inv limbs. + MultiLimbModularInverse { output_start: usize, - a_lo: usize, - a_hi: usize, + a_limbs: Vec, modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, }, - /// Prover hint for wide addition quotient: q = floor((a + b) / p). - /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, - /// and a constant 256-bit modulus p, computes q ∈ {0, 1}. + /// Prover hint for multi-limb addition quotient: q = floor((a + b) / p). + /// Given inputs a and b as N-limb vectors, and a constant modulus p, + /// computes q ∈ {0, 1}. /// /// Outputs 1 witness at output: q. 
- WideAddQuotient { - output: usize, - a_lo: usize, - a_hi: usize, - b_lo: usize, - b_hi: usize, - modulus: [u64; 4], + MultiLimbAddQuotient { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, }, - /// Prover hint for wide subtraction borrow: q = (a < b) ? 1 : 0. - /// Given inputs a = (a_lo, a_hi) and b = (b_lo, b_hi) as 128-bit limbs, + /// Prover hint for multi-limb subtraction borrow: q = (a < b) ? 1 : 0. + /// Given inputs a and b as N-limb vectors, and a constant modulus p, /// computes q ∈ {0, 1} indicating whether a borrow (adding p) is needed. /// /// Outputs 1 witness at output: q. - WideSubBorrow { - output: usize, - a_lo: usize, - a_hi: usize, - b_lo: usize, - b_hi: usize, + MultiLimbSubBorrow { + output: usize, + a_limbs: Vec, + b_limbs: Vec, + modulus: [u64; 4], + limb_bits: u32, + num_limbs: u32, }, /// Decomposes a packed value into chunks of specified bit-widths. /// Given packed value and chunk_bits = [b0, b1, ..., bn]: @@ -329,8 +332,10 @@ impl WitnessBuilder { WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, - WitnessBuilder::MulModHint { .. } => 20, - WitnessBuilder::WideModularInverse { .. } => 2, + WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => { + (4 * *num_limbs - 2) as usize + } + WitnessBuilder::MultiLimbModularInverse { num_limbs, .. 
} => *num_limbs as usize, _ => 1, } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 44cd0ca07..7aafa0d3c 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -196,6 +196,58 @@ impl Prove for NoirProver { .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) .collect::>>()?; + // DEBUG: Check R1CS constraint satisfaction with ALL witnesses solved + { + use ark_ff::Zero; + let debug_r1cs = r1cs.clone(); + let interner = &debug_r1cs.interner; + let ha = debug_r1cs.a.hydrate(interner); + let hb = debug_r1cs.b.hydrate(interner); + let hc = debug_r1cs.c.hydrate(interner); + let mut fail_count = 0usize; + for row in 0..debug_r1cs.num_constraints() { + let eval = |hm: &provekit_common::sparse_matrix::HydratedSparseMatrix, r: usize| -> FieldElement { + let mut sum = FieldElement::zero(); + for (col, coeff) in hm.iter_row(r) { + sum += coeff * full_witness[col]; + } + sum + }; + let a_val = eval(&ha, row); + let b_val = eval(&hb, row); + let c_val = eval(&hc, row); + if a_val * b_val != c_val { + if fail_count < 10 { + eprintln!( + "CONSTRAINT {} FAILED: A={:?} B={:?} C={:?} A*B={:?}", + row, a_val, b_val, c_val, a_val * b_val + ); + eprint!(" A terms:"); + for (col, coeff) in ha.iter_row(row) { + eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); + } + eprintln!(); + eprint!(" B terms:"); + for (col, coeff) in hb.iter_row(row) { + eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); + } + eprintln!(); + eprint!(" C terms:"); + for (col, coeff) in hc.iter_row(row) { + eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); + } + eprintln!(); + } + fail_count += 1; + } + } + if fail_count > 0 { + eprintln!("TOTAL FAILING CONSTRAINTS: {fail_count} / {}", debug_r1cs.num_constraints()); + } else { + eprintln!("ALL {} CONSTRAINTS SATISFIED", debug_r1cs.num_constraints()); + } + } + let whir_r1cs_proof = self .whir_for_witness .prove_noir(merlin, r1cs, commitments, full_witness, 
&public_inputs) diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index ae49cfcd5..d3479331b 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -1,7 +1,7 @@ use { crate::witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, acir::native_types::WitnessMap, - ark_ff::{BigInteger, PrimeField}, + ark_ff::{BigInteger, Field, PrimeField}, ark_std::Zero, provekit_common::{ utils::noir_to_native, @@ -65,6 +65,14 @@ impl WitnessBuilderSolver for WitnessBuilder { "Inverse/LogUpInverse should not be called - handled by batch inversion" ) } + WitnessBuilder::SafeInverse(witness_idx, operand_idx) => { + let val = witness[*operand_idx].unwrap(); + witness[*witness_idx] = Some(if val == FieldElement::zero() { + FieldElement::zero() + } else { + val.inverse().unwrap() + }); + } WitnessBuilder::ModularInverse(witness_idx, operand_idx, modulus) => { let a = witness[*operand_idx].unwrap(); let a_limbs = a.into_bigint().0; @@ -337,61 +345,135 @@ impl WitnessBuilderSolver for WitnessBuilder { lh_val.into_bigint() ^ rh_val.into_bigint(), )); } - WitnessBuilder::MulModHint { + WitnessBuilder::MultiLimbMulModHint { output_start, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, modulus, + limb_bits, + num_limbs, } => { - use crate::witness::bigint_mod::{ - compute_carries_86, decompose_128, decompose_86, divmod_wide, widening_mul, - CARRY_OFFSET, + use crate::witness::bigint_mod::{divmod_wide, widening_mul}; + let n = *num_limbs as usize; + let w = *limb_bits; + let limb_mask: u128 = if w >= 128 { + u128::MAX + } else { + (1u128 << w) - 1 }; - // Read inputs: a and b as 128-bit limb pairs - let a_lo_fe = witness[*a_lo].unwrap(); - let a_hi_fe = witness[*a_hi].unwrap(); - let b_lo_fe = witness[*b_lo].unwrap(); - let b_hi_fe = witness[*b_hi].unwrap(); - - // Reconstruct a, b as [u64; 4] - let a_lo_limbs = a_lo_fe.into_bigint().0; - let 
a_hi_limbs = a_hi_fe.into_bigint().0; - let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + // Reconstruct a, b as [u64; 4] from N limbs + let reconstruct = |limbs: &[usize]| -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_idx in limbs.iter() { + let limb_val = witness[limb_idx].unwrap().into_bigint().0; + let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); + // Place into val at bit_offset + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + if word_start + 2 < 4 && bit_within > 0 { + let upper = limb_u128 >> (128 - bit_within); + if upper > 0 { + val[word_start + 2] |= upper as u64; + } + } + } + bit_offset += w; + } + val + }; - let b_lo_limbs = b_lo_fe.into_bigint().0; - let b_hi_limbs = b_hi_fe.into_bigint().0; - let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + let a_val = reconstruct(a_limbs); + let b_val = reconstruct(b_limbs); // Compute product and divmod let product = widening_mul(&a_val, &b_val); let (q_val, r_val) = divmod_wide(&product, modulus); - // Decompose into 128-bit limbs - let (q_lo, q_hi) = decompose_128(&q_val); - let (r_lo, r_hi) = decompose_128(&r_val); + // Decompose a [u64;4] into N limbs of limb_bits width. 
+ let decompose_n_from_u64 = |val: &[u64; 4]| -> Vec { + let mut limbs = Vec::with_capacity(n); + let mut remaining = *val; + for _ in 0..n { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + limbs.push(lo & limb_mask); + // Shift right by w bits + if w >= 256 { + remaining = [0; 4]; + } else { + let mut shifted = [0u64; 4]; + let word_shift = (w / 64) as usize; + let bit_shift = w % 64; + for i in 0..4 { + if i + word_shift < 4 { + shifted[i] = remaining[i + word_shift] >> bit_shift; + if bit_shift > 0 && i + word_shift + 1 < 4 { + shifted[i] |= + remaining[i + word_shift + 1] << (64 - bit_shift); + } + } + } + remaining = shifted; + } + } + limbs + }; + + let q_limbs_vals = decompose_n_from_u64(&q_val); + let r_limbs_vals = decompose_n_from_u64(&r_val); - // Decompose into 86-bit limbs - let (a86_0, a86_1, a86_2) = decompose_86(&a_val); - let (b86_0, b86_1, b86_2) = decompose_86(&b_val); - let (q86_0, q86_1, q86_2) = decompose_86(&q_val); - let (r86_0, r86_1, r86_2) = decompose_86(&r_val); + // Compute carries for schoolbook verification: + // a·b = p·q + r in base W = 2^limb_bits + // For each column k (0..2N-2): + // lhs_k = Σ_{i+j=k} a[i]*b[j] + carry_{k-1} + // rhs_k = Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W + let p_limbs_vals = decompose_n_from_u64(modulus); + let a_limbs_vals = decompose_n_from_u64(&a_val); + let b_limbs_vals = decompose_n_from_u64(&b_val); - // Compute carries - let carries = compute_carries_86( - [a86_0, a86_1, a86_2], - [b86_0, b86_1, b86_2], - { - let (p0, p1, p2) = decompose_86(modulus); - [p0, p1, p2] - }, - [q86_0, q86_1, q86_2], - [r86_0, r86_1, r86_2], - ); + let w_val = 1u128 << w; + let num_carries = 2 * n - 2; + let carry_offset = 1u128 << (w + ((n as f64).log2().ceil() as u32) + 1); + let mut carries = Vec::with_capacity(num_carries); + let mut running: i128 = 0; + + for k in 0..(2 * n - 1) { + // Sum a[i]*b[j] for i+j=k + let mut ab_sum: i128 = 0; + for i in 0..n { + let j = k as isize - i as isize; + if j 
>= 0 && (j as usize) < n { + ab_sum += + a_limbs_vals[i] as i128 * b_limbs_vals[j as usize] as i128; + } + } + // Sum p[i]*q[j] for i+j=k + let mut pq_sum: i128 = 0; + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + pq_sum += + p_limbs_vals[i] as i128 * q_limbs_vals[j as usize] as i128; + } + } + let r_k = if k < n { r_limbs_vals[k] as i128 } else { 0 }; + + // column: ab_sum + carry_prev = pq_sum + r_k + carry_next * W + // carry_next = (ab_sum + carry_prev - pq_sum - r_k) / W + running += ab_sum - pq_sum - r_k; + if k < 2 * n - 2 { + let carry = running / w_val as i128; + carries.push(carry); + running -= carry * w_val as i128; + } + } - // Helper: convert u128 to FieldElement let u128_to_fe = |val: u128| -> FieldElement { FieldElement::from_bigint(ark_ff::BigInt([ val as u64, @@ -402,57 +484,59 @@ impl WitnessBuilderSolver for WitnessBuilder { .unwrap() }; - // Write outputs: [0..2) q_lo, q_hi - witness[*output_start] = Some(u128_to_fe(q_lo)); - witness[*output_start + 1] = Some(u128_to_fe(q_hi)); - // [2..4) r_lo, r_hi - witness[*output_start + 2] = Some(u128_to_fe(r_lo)); - witness[*output_start + 3] = Some(u128_to_fe(r_hi)); - // [4..7) a_86 limbs - witness[*output_start + 4] = Some(u128_to_fe(a86_0)); - witness[*output_start + 5] = Some(u128_to_fe(a86_1)); - witness[*output_start + 6] = Some(u128_to_fe(a86_2)); - // [7..10) b_86 limbs - witness[*output_start + 7] = Some(u128_to_fe(b86_0)); - witness[*output_start + 8] = Some(u128_to_fe(b86_1)); - witness[*output_start + 9] = Some(u128_to_fe(b86_2)); - // [10..13) q_86 limbs - witness[*output_start + 10] = Some(u128_to_fe(q86_0)); - witness[*output_start + 11] = Some(u128_to_fe(q86_1)); - witness[*output_start + 12] = Some(u128_to_fe(q86_2)); - // [13..16) r_86 limbs - witness[*output_start + 13] = Some(u128_to_fe(r86_0)); - witness[*output_start + 14] = Some(u128_to_fe(r86_1)); - witness[*output_start + 15] = Some(u128_to_fe(r86_2)); - // [16..20) carries 
(unsigned-offset) - for i in 0..4 { - let c_unsigned = (carries[i] + CARRY_OFFSET as i128) as u128; - witness[*output_start + 16 + i] = Some(u128_to_fe(c_unsigned)); + // Write q limbs + for i in 0..n { + witness[*output_start + i] = Some(u128_to_fe(q_limbs_vals[i])); + } + // Write r limbs + for i in 0..n { + witness[*output_start + n + i] = Some(u128_to_fe(r_limbs_vals[i])); + } + // Write carries (unsigned-offset) + for i in 0..num_carries { + let c_unsigned = (carries[i] + carry_offset as i128) as u128; + witness[*output_start + 2 * n + i] = Some(u128_to_fe(c_unsigned)); } } - WitnessBuilder::WideModularInverse { + WitnessBuilder::MultiLimbModularInverse { output_start, - a_lo, - a_hi, + a_limbs, modulus, + limb_bits, + num_limbs, } => { - use crate::witness::bigint_mod::{decompose_128, mod_pow, sub_u64}; - - // Read input a as 128-bit limb pair - let a_lo_fe = witness[*a_lo].unwrap(); - let a_hi_fe = witness[*a_hi].unwrap(); + use crate::witness::bigint_mod::{mod_pow, sub_u64}; + let n = *num_limbs as usize; + let w = *limb_bits; + let limb_mask: u128 = if w >= 128 { + u128::MAX + } else { + (1u128 << w) - 1 + }; - let a_lo_limbs = a_lo_fe.into_bigint().0; - let a_hi_limbs = a_hi_fe.into_bigint().0; - let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + // Reconstruct a as [u64; 4] from N limbs + let mut a_val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_idx in a_limbs.iter() { + let limb_val = witness[limb_idx].unwrap().into_bigint().0; + let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + a_val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + a_val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + } + bit_offset += w; + } - // Compute inverse: a^{p-2} mod p (Fermat's little theorem) + // Compute inverse: a^{p-2} mod p let exp = sub_u64(modulus, 2); let inv = 
mod_pow(&a_val, &exp, modulus); - // Decompose into 128-bit limbs - let (inv_lo, inv_hi) = decompose_128(&inv); - + // Decompose into N limbs + let mut remaining = inv; let u128_to_fe = |val: u128| -> FieldElement { FieldElement::from_bigint(ark_ff::BigInt([ val as u64, @@ -462,37 +546,60 @@ impl WitnessBuilderSolver for WitnessBuilder { ])) .unwrap() }; - - witness[*output_start] = Some(u128_to_fe(inv_lo)); - witness[*output_start + 1] = Some(u128_to_fe(inv_hi)); + for i in 0..n { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + witness[*output_start + i] = Some(u128_to_fe(lo & limb_mask)); + // Shift right by w bits + let mut shifted = [0u64; 4]; + let word_shift = (w / 64) as usize; + let bit_shift = w % 64; + for j in 0..4 { + if j + word_shift < 4 { + shifted[j] = remaining[j + word_shift] >> bit_shift; + if bit_shift > 0 && j + word_shift + 1 < 4 { + shifted[j] |= remaining[j + word_shift + 1] << (64 - bit_shift); + } + } + } + remaining = shifted; + } } - WitnessBuilder::WideAddQuotient { + WitnessBuilder::MultiLimbAddQuotient { output, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, modulus, + limb_bits, + .. 
} => { use crate::witness::bigint_mod::{add_4limb, cmp_4limb}; + let w = *limb_bits; - let a_lo_fe = witness[*a_lo].unwrap(); - let a_hi_fe = witness[*a_hi].unwrap(); - let b_lo_fe = witness[*b_lo].unwrap(); - let b_hi_fe = witness[*b_hi].unwrap(); - - let a_lo_limbs = a_lo_fe.into_bigint().0; - let a_hi_limbs = a_hi_fe.into_bigint().0; - let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + // Reconstruct from N limbs + let reconstruct = |limbs: &[usize]| -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_idx in limbs.iter() { + let limb_val = witness[limb_idx].unwrap().into_bigint().0; + let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + } + bit_offset += w; + } + val + }; - let b_lo_limbs = b_lo_fe.into_bigint().0; - let b_hi_limbs = b_hi_fe.into_bigint().0; - let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + let a_val = reconstruct(a_limbs); + let b_val = reconstruct(b_limbs); let sum = add_4limb(&a_val, &b_val); - // q = 1 if sum >= p, else 0 let q = if sum[4] > 0 { - // sum > 2^256 > any 256-bit modulus 1u64 } else { let sum4 = [sum[0], sum[1], sum[2], sum[3]]; @@ -505,24 +612,38 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*output] = Some(FieldElement::from(q)); } - WitnessBuilder::WideSubBorrow { + WitnessBuilder::MultiLimbSubBorrow { output, - a_lo, - a_hi, - b_lo, - b_hi, + a_limbs, + b_limbs, + limb_bits, + .. 
} => { use crate::witness::bigint_mod::cmp_4limb; + let w = *limb_bits; - let a_lo_limbs = witness[*a_lo].unwrap().into_bigint().0; - let a_hi_limbs = witness[*a_hi].unwrap().into_bigint().0; - let a_val = [a_lo_limbs[0], a_lo_limbs[1], a_hi_limbs[0], a_hi_limbs[1]]; + let reconstruct = |limbs: &[usize]| -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_idx in limbs.iter() { + let limb_val = witness[limb_idx].unwrap().into_bigint().0; + let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + } + bit_offset += w; + } + val + }; - let b_lo_limbs = witness[*b_lo].unwrap().into_bigint().0; - let b_hi_limbs = witness[*b_hi].unwrap().into_bigint().0; - let b_val = [b_lo_limbs[0], b_lo_limbs[1], b_hi_limbs[0], b_hi_limbs[1]]; + let a_val = reconstruct(a_limbs); + let b_val = reconstruct(b_limbs); - // q = 1 if a < b (need to add p to make result non-negative) let q = if cmp_4limb(&a_val, &b_val) == std::cmp::Ordering::Less { 1u64 } else { diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs new file mode 100644 index 000000000..234623a31 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -0,0 +1,362 @@ +//! Analytical cost model for MSM parameter optimization. +//! +//! Follows the SHA256 pattern (`spread.rs:get_optimal_spread_width`): +//! pure analytical estimator → exhaustive search → pick optimal (limb_bits, window_size). + +/// Type of field operation for cost estimation. +#[derive(Clone, Copy)] +pub enum FieldOpType { + Add, + Sub, + Mul, + Inv, +} + +/// Count field ops in scalar_mul for given parameters. +/// Traces through ec_points::scalar_mul logic analytically. 
+/// +/// Returns (n_add, n_sub, n_mul, n_inv) per single scalar multiplication. +fn count_scalar_mul_field_ops(scalar_bits: usize, window_size: usize) -> (usize, usize, usize, usize) { + let w = window_size; + let table_size = 1 << w; + let num_windows = (scalar_bits + w - 1) / w; + + // Build point table: T[0]=P (free), T[1]=P (free), T[2]=2P (1 double), + // T[3..table_size] = point_add each + let table_doubles = if table_size > 2 { 1 } else { 0 }; + let table_adds = if table_size > 2 { table_size - 3 } else { 0 }; + + // point_double costs: 5 mul, 4 add, 2 sub, 1 inv + let double_ops = (4usize, 2usize, 5usize, 1usize); // (add, sub, mul, inv) + // point_add costs: 2 add, 2 sub, 3 mul, 1 inv + let add_ops = (2usize, 2usize, 3usize, 1usize); + + // Table construction + let mut total_add = table_doubles * double_ops.0 + table_adds * add_ops.0; + let mut total_sub = table_doubles * double_ops.1 + table_adds * add_ops.1; + let mut total_mul = table_doubles * double_ops.2 + table_adds * add_ops.2; + let mut total_inv = table_doubles * double_ops.3 + table_adds * add_ops.3; + + // Table lookups: each uses (2^w - 1) point_selects + // point_select = 2 selects = 2 * (3 witnesses: diff, flag*diff, out) per coordinate + // But select is not a field op — it's cheaper (just `select` calls) + // We count it as 2 selects per point_select = 2 sub + 2 mul per select + // Actually select = flag*(on_true - on_false) + on_false: 1 sub, 1 mul, 1 add per elem + // Per point (x,y): 2 sub, 2 mul, 2 add for select + let selects_per_lookup = table_size - 1; // 2^w - 1 point_selects + let select_ops_per_point = (2usize, 2usize, 2usize, 0usize); // (add, sub, mul, inv) + + // MSB window: 1 table lookup (possibly smaller table) + let msb_bits = scalar_bits - (num_windows - 1) * w; + let msb_table_size = 1 << msb_bits; + let msb_selects = msb_table_size - 1; + total_add += msb_selects * select_ops_per_point.0; + total_sub += msb_selects * select_ops_per_point.1; + total_mul += msb_selects * 
select_ops_per_point.2; + + // Remaining windows: for each of (num_windows - 1) windows: + // - w doublings + // - 1 pack_bits (cheap) + // - 1 is_zero (1 inv + some adds) + // - 1 table lookup + // - 1 sub (for denom) + // - 1 elem_is_zero + // - 1 point_double (for x_eq case) + // - 1 safe_point_add (like point_add but with select on denom) + // - 2 point_selects (x_eq and digit_is_zero) + let remaining = if num_windows > 1 { num_windows - 1 } else { 0 }; + + for _ in 0..remaining { + // w doublings + total_add += w * double_ops.0; + total_sub += w * double_ops.1; + total_mul += w * double_ops.2; + total_inv += w * double_ops.3; + + // table lookup + total_add += selects_per_lookup * select_ops_per_point.0; + total_sub += selects_per_lookup * select_ops_per_point.1; + total_mul += selects_per_lookup * select_ops_per_point.2; + + // denom = sub(looked_up.x, acc.x) + total_sub += 1; + + // elem_is_zero(denom) = is_zero per limb + products + // For N limbs: N * (1 inv + some arith) + (N-1) products + // Simplified: 1 inv + 3 witnesses + total_inv += 1; + total_add += 1; + total_mul += 1; + + // point_double for x_eq case + total_add += double_ops.0; + total_sub += double_ops.1; + total_mul += double_ops.2; + total_inv += double_ops.3; + + // safe_point_add: like point_add + 1 select on denom + total_add += add_ops.0 + select_ops_per_point.0 / 2; // 1 select + total_sub += add_ops.1 + select_ops_per_point.1 / 2; + total_mul += add_ops.2 + select_ops_per_point.2 / 2; + total_inv += add_ops.3; + + // 2 point_selects + total_add += 2 * select_ops_per_point.0; + total_sub += 2 * select_ops_per_point.1; + total_mul += 2 * select_ops_per_point.2; + + // is_zero(digit) + total_inv += 1; + total_add += 1; + total_mul += 1; + } + + (total_add, total_sub, total_mul, total_inv) +} + +/// Witnesses per single N-limb field operation. 
+fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize { + if is_native { + // Native: no range checks, just standard R1CS witnesses + match op { + FieldOpType::Add => 1, // sum witness + FieldOpType::Sub => 1, // sum witness + FieldOpType::Mul => 1, // product witness + FieldOpType::Inv => 1, // inverse witness + } + } else if num_limbs == 1 { + // Single-limb non-native: reduce_mod_p pattern + match op { + FieldOpType::Add => 5, // a+b, m const, k, k*m, result + FieldOpType::Sub => 5, // same + FieldOpType::Mul => 5, // a*b, m const, k, k*m, result + FieldOpType::Inv => 7, // a_inv + mul_mod_p(5) + range_check + } + } else { + // Multi-limb: N-limb operations + let n = num_limbs; + match op { + // add/sub: q + N*(v_offset, carry, r_limb) + N*(v_diff, borrow, d_limb) + FieldOpType::Add | FieldOpType::Sub => 1 + 3 * n + 3 * n, + // mul: hint(4N-2) + N² products + 2N-1 column constraints + lt_check + FieldOpType::Mul => (4 * n - 2) + n * n + 3 * n, + // inv: hint(N) + mul costs + FieldOpType::Inv => n + (4 * n - 2) + n * n + 3 * n, + } + } +} + +/// Total estimated witness cost for one scalar_mul. 
+pub fn calculate_msm_witness_cost( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, + window_size: usize, + limb_bits: u32, +) -> usize { + let is_native = curve_modulus_bits == native_field_bits; + let num_limbs = if is_native { + 1 + } else { + ((curve_modulus_bits as usize) + (limb_bits as usize) - 1) / (limb_bits as usize) + }; + + let (n_add, n_sub, n_mul, n_inv) = count_scalar_mul_field_ops(scalar_bits, window_size); + + let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, is_native); + let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, is_native); + let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, is_native); + let wit_inv = witnesses_per_op(num_limbs, FieldOpType::Inv, is_native); + + let per_scalarmul = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + + // Scalar decomposition: 256 bits (bit witnesses + digital decomposition overhead) + let scalar_decomp = 256 + 10; + + // Point accumulation: (n_points - 1) point_adds + let accum_per_point = if n_points > 1 { + let accum_adds = n_points - 1; + accum_adds * (witnesses_per_op(num_limbs, FieldOpType::Add, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Sub, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Mul, is_native) * 3 + + witnesses_per_op(num_limbs, FieldOpType::Inv, is_native)) + } else { + 0 + }; + + n_points * (per_scalarmul + scalar_decomp) + accum_per_point +} + +/// Check whether schoolbook column equation values fit in the native field. +/// +/// In `mul_mod_p_multi`, the schoolbook multiplication verifies `a·b = p·q + r` via +/// column equations that include product sums, carry offsets, and outgoing carries. +/// Both sides of each column equation must evaluate to less than the native field +/// modulus as **integers** — if they overflow, the field's modular reduction makes +/// `LHS ≡ RHS (mod p)` weaker than `LHS = RHS`, breaking soundness. 
+/// +/// The maximum integer value across either side of any column equation is bounded by: +/// +/// `2^(2W + ceil(log2(N)) + 3)` +/// +/// where `W = limb_bits` and `N = num_limbs`. This accounts for: +/// - Up to N cross-products per column, each < `2^(2W)` +/// - The carry offset `2^(2W + ceil(log2(N)) + 1)` (dominant term) +/// - Outgoing carry term `2^W * offset_carry` on the RHS +/// +/// Since the native field modulus satisfies `p >= 2^(native_field_bits - 1)`, the +/// conservative soundness condition is: +/// +/// `2 * limb_bits + ceil(log2(num_limbs)) + 3 < native_field_bits` +pub fn column_equation_fits_native_field( + native_field_bits: u32, + limb_bits: u32, + num_limbs: usize, +) -> bool { + if num_limbs <= 1 { + return true; // Single-limb path has no column equations. + } + let ceil_log2_n = (num_limbs as f64).log2().ceil() as u32; + // Max column value < 2^(2*limb_bits + ceil_log2_n + 3). + // Need this < p_native >= 2^(native_field_bits - 1). + 2 * limb_bits + ceil_log2_n + 3 < native_field_bits +} + +/// Search for optimal (limb_bits, window_size) minimizing witness cost. +/// +/// Searches limb_bits ∈ [8..max] and window_size ∈ [2..8]. +/// Each candidate is checked for column equation soundness: the schoolbook +/// multiplication's intermediate values must fit in the native field without +/// modular wraparound (see [`column_equation_fits_native_field`]). +pub fn get_optimal_msm_params( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, +) -> (u32, usize) { + let is_native = curve_modulus_bits == native_field_bits; + if is_native { + // For native field, limb_bits doesn't matter (no multi-limb decomposition). + // Just optimize window_size. 
+ let mut best_cost = usize::MAX; + let mut best_window = 4; + for ws in 2..=8 { + let cost = calculate_msm_witness_cost( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + ws, + native_field_bits, + ); + if cost < best_cost { + best_cost = cost; + best_window = ws; + } + } + return (native_field_bits, best_window); + } + + // Upper bound on search: even with N=2 (best case), we need + // 2*lb + ceil(log2(2)) + 3 < native_field_bits => lb < (native_field_bits - 4) / 2. + // The per-candidate soundness check below is the actual gate. + let max_limb_bits = (native_field_bits.saturating_sub(4)) / 2; + let mut best_cost = usize::MAX; + let mut best_limb_bits = max_limb_bits.min(86); + let mut best_window = 4; + + // Search space + for lb in (8..=max_limb_bits).step_by(2) { + let num_limbs = + ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); + if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { + continue; + } + for ws in 2..=8usize { + let cost = calculate_msm_witness_cost( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + ws, + lb, + ); + if cost < best_cost { + best_cost = cost; + best_limb_bits = lb; + best_window = ws; + } + } + } + + (best_limb_bits, best_window) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_optimal_params_bn254_native() { + // Grumpkin over BN254: native field + let (limb_bits, window_size) = get_optimal_msm_params(254, 254, 1, 256); + assert_eq!(limb_bits, 254); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_optimal_params_secp256r1() { + // secp256r1 over BN254: 256-bit modulus, non-native + let (limb_bits, window_size) = get_optimal_msm_params(254, 256, 1, 256); + let num_limbs = ((256 + limb_bits - 1) / limb_bits) as usize; + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && 
window_size <= 8); + } + + #[test] + fn test_optimal_params_goldilocks() { + // Hypothetical 64-bit field over BN254 + let (limb_bits, window_size) = get_optimal_msm_params(254, 64, 1, 64); + let num_limbs = ((64 + limb_bits - 1) / limb_bits) as usize; + assert!( + column_equation_fits_native_field(254, limb_bits, num_limbs), + "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" + ); + assert!(window_size >= 2 && window_size <= 8); + } + + #[test] + fn test_count_field_ops_sanity() { + let (add, sub, mul, inv) = count_scalar_mul_field_ops(256, 4); + assert!(add > 0); + assert!(sub > 0); + assert!(mul > 0); + assert!(inv > 0); + } + + #[test] + fn test_column_equation_soundness_boundary() { + // For BN254 (254 bits) with N=3: max safe limb_bits is 124. + // 2*124 + ceil(log2(3)) + 3 = 248 + 2 + 3 = 253 < 254 ✓ + assert!(column_equation_fits_native_field(254, 124, 3)); + // 2*125 + ceil(log2(3)) + 3 = 250 + 2 + 3 = 255 ≥ 254 ✗ + assert!(!column_equation_fits_native_field(254, 125, 3)); + // 2*126 + ceil(log2(3)) + 3 = 252 + 2 + 3 = 257 ≥ 254 ✗ + assert!(!column_equation_fits_native_field(254, 126, 3)); + } + + #[test] + fn test_secp256r1_limb_bits_not_126() { + // Regression: limb_bits=126 with N=3 causes offset_w = 2^255 > p_BN254, + // making the schoolbook column equations unsound. 
+ let (limb_bits, _) = get_optimal_msm_params(254, 256, 1, 256); + assert!( + limb_bits <= 124, + "secp256r1 limb_bits={limb_bits} exceeds safe maximum 124" + ); + } +} diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index d4d0d247b..53a1340f8 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -1,9 +1,6 @@ use { - crate::noir_to_r1cs::NoirToR1CSCompiler, - provekit_common::{ - witness::{ConstantTerm, WitnessBuilder}, - FieldElement, - }, + ark_ff::{BigInteger, PrimeField}, + provekit_common::FieldElement, }; pub struct CurveParams { @@ -15,38 +12,91 @@ pub struct CurveParams { } impl CurveParams { - pub fn p_lo_fe(&self) -> FieldElement { - decompose_128(self.field_modulus_p).0 + /// Decompose the field modulus p into `num_limbs` limbs of `limb_bits` width each. + pub fn p_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.field_modulus_p, limb_bits, num_limbs) } - pub fn p_hi_fe(&self) -> FieldElement { - decompose_128(self.field_modulus_p).1 + + /// Decompose (p - 1) into `num_limbs` limbs of `limb_bits` width each. + pub fn p_minus_1_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + let p_minus_1 = sub_one_u64_4(&self.field_modulus_p); + decompose_to_limbs(&p_minus_1, limb_bits, num_limbs) + } + + /// Decompose the curve parameter `a` into `num_limbs` limbs of `limb_bits` width. 
+ pub fn curve_a_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_a, limb_bits, num_limbs) } - pub fn p_86_limbs(&self) -> [FieldElement; 3] { - let mask_86: u128 = (1u128 << 86) - 1; - let lo128 = self.field_modulus_p[0] as u128 | ((self.field_modulus_p[1] as u128) << 64); - let hi128 = self.field_modulus_p[2] as u128 | ((self.field_modulus_p[3] as u128) << 64); - let l0 = lo128 & mask_86; - // l1 spans bits [86..172): 42 bits from lo128, 44 bits from hi128 - let l1 = ((lo128 >> 86) | (hi128 << 42)) & mask_86; - // l2 = bits [172..256): 84 bits from hi128 - let l2 = hi128 >> 44; - [ - FieldElement::from(l0), - FieldElement::from(l1), - FieldElement::from(l2), - ] + + /// Number of bits in the field modulus. + pub fn modulus_bits(&self) -> u32 { + if self.is_native_field() { + // p mod p = 0 as a field element, so we use the constant directly. + FieldElement::MODULUS_BIT_SIZE + } else { + let fe = curve_native_point_fe(&self.field_modulus_p); + fe.into_bigint().num_bits() + } + } + + /// Returns true if the curve's base field modulus equals the native BN254 + /// scalar field modulus. + pub fn is_native_field(&self) -> bool { + let native_mod = FieldElement::MODULUS; + self.field_modulus_p == native_mod.0 } + + /// Convert modulus to a native field element (only valid when p < native modulus). pub fn p_native_fe(&self) -> FieldElement { curve_native_point_fe(&self.field_modulus_p) } } -/// Splits a 256-bit value ([u64; 4]) into two 128-bit field elements (lo, hi). -fn decompose_128(val: [u64; 4]) -> (FieldElement, FieldElement) { - ( - FieldElement::from((val[0] as u128) | ((val[1] as u128) << 64)), - FieldElement::from((val[2] as u128) | ((val[3] as u128) << 64)), - ) +/// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width each, +/// returned as FieldElements. 
+fn decompose_to_limbs(val: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec { + let mask: u128 = if limb_bits >= 128 { + u128::MAX + } else { + (1u128 << limb_bits) - 1 + }; + let mut result = vec![FieldElement::from(0u64); num_limbs]; + let mut remaining = *val; + for item in result.iter_mut() { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + *item = FieldElement::from(lo & mask); + // Shift remaining right by limb_bits + if limb_bits >= 256 { + remaining = [0; 4]; + } else { + let mut shifted = [0u64; 4]; + let word_shift = (limb_bits / 64) as usize; + let bit_shift = limb_bits % 64; + for i in 0..4 { + if i + word_shift < 4 { + shifted[i] = remaining[i + word_shift] >> bit_shift; + if bit_shift > 0 && i + word_shift + 1 < 4 { + shifted[i] |= remaining[i + word_shift + 1] << (64 - bit_shift); + } + } + } + remaining = shifted; + } + } + result +} + +/// Subtract 1 from a [u64; 4] value. +fn sub_one_u64_4(val: &[u64; 4]) -> [u64; 4] { + let mut result = *val; + for limb in result.iter_mut() { + if *limb > 0 { + *limb -= 1; + return result; + } + *limb = u64::MAX; // borrow + } + result } /// Converts a 256-bit value ([u64; 4]) into a single native field element. @@ -54,21 +104,46 @@ pub fn curve_native_point_fe(val: &[u64; 4]) -> FieldElement { FieldElement::from_sign_and_limbs(true, val) } -#[derive(Clone, Copy, Debug)] -pub struct Limb2 { - pub lo: usize, - pub hi: usize, -} - -pub fn limb2_constant(r1cs_compiler: &mut NoirToR1CSCompiler, value: [u64; 4]) -> Limb2 { - let (lo, hi) = decompose_128(value); - let lo_idx = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(lo_idx, lo))); - let hi_idx = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(hi_idx, hi))); - Limb2 { - lo: lo_idx, - hi: hi_idx, +/// Grumpkin curve parameters. 
+/// +/// Grumpkin is a cycle-companion curve for BN254: its base field is the BN254 +/// scalar field, and its order is the BN254 base field order. +/// +/// Equation: y² = x³ − 17 (a = 0, b = −17 mod p) +pub fn grumpkin_params() -> CurveParams { + CurveParams { + // BN254 scalar field modulus + field_modulus_p: [ + 0x43e1f593f0000001_u64, + 0x2833e84879b97091_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + // BN254 base field modulus + curve_order_n: [ + 0x3c208c16d87cfd47_u64, + 0x97816a916871ca8d_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + curve_a: [0; 4], + // b = −17 mod p + curve_b: [ + 0x43e1f593effffff0_u64, + 0x2833e84879b97091_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ], + // Generator G = (1, sqrt(−16) mod p) + generator: ( + [1, 0, 0, 0], + [ + 0x833fc48d823f272c_u64, + 0x2d270d45f1181294_u64, + 0xcf135e7506a45d63_u64, + 0x0000000000000002_u64, + ], + ), } } diff --git a/provekit/r1cs-compiler/src/msm/ec_ops.rs b/provekit/r1cs-compiler/src/msm/ec_ops.rs deleted file mode 100644 index 985937821..000000000 --- a/provekit/r1cs-compiler/src/msm/ec_ops.rs +++ /dev/null @@ -1,208 +0,0 @@ -use { - crate::noir_to_r1cs::NoirToR1CSCompiler, - ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField}, - provekit_common::{ - witness::{SumTerm, WitnessBuilder}, - FieldElement, - }, - std::collections::BTreeMap, -}; - -/// Reduce the value to given modulus -pub fn reduce_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - value: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - // Reduce mod algorithm : - // v = k * m + result, where 0 <= result < m - // k = floor(v / m) (integer division) - // result = v - k * m - - // Computing k = floor(v / m) - // ----------------------------------------------------------- - // computing m (constant witness for use in constraints) - let m = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Constant( - 
provekit_common::witness::ConstantTerm(m, modulus), - )); - // computing k via integer division - let k = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(k, value, modulus)); - - // Computing result = v - k * m - // ----------------------------------------------------------- - // computing k * m - let k_mul_m = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Product(k_mul_m, k, m)); - // constraint: k * m = k_mul_m - r1cs_compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, k)], &[(FieldElement::ONE, m)], &[( - FieldElement::ONE, - k_mul_m, - )]); - // computing result = v - k * m - let result = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(result, vec![ - SumTerm(Some(FieldElement::ONE), value), - SumTerm(Some(-FieldElement::ONE), k_mul_m), - ])); - // constraint: 1 * (k_mul_m + result) = value - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[(FieldElement::ONE, k_mul_m), (FieldElement::ONE, result)], - &[(FieldElement::ONE, value)], - ); - // range check to prove 0 <= result < m - let modulus_bits = modulus.into_bigint().num_bits(); - range_checks - .entry(modulus_bits) - .or_insert_with(Vec::new) - .push(result); - - result -} - -/// a + b mod p -pub fn add_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - a: usize, - b: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - let a_add_b = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(a_add_b, vec![ - SumTerm(Some(FieldElement::ONE), a), - SumTerm(Some(FieldElement::ONE), b), - ])); - // constraint: a + b = a_add_b - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, a), (FieldElement::ONE, b)], - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[(FieldElement::ONE, a_add_b)], - ); - reduce_mod_p(r1cs_compiler, a_add_b, modulus, range_checks) -} - -/// a * b 
mod p -pub fn mul_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - a: usize, - b: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - let a_mul_b = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Product(a_mul_b, a, b)); - // constraint: a * b = a_mul_b - r1cs_compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( - FieldElement::ONE, - a_mul_b, - )]); - reduce_mod_p(r1cs_compiler, a_mul_b, modulus, range_checks) -} - -/// (a - b) mod p -pub fn sub_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - a: usize, - b: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - let a_sub_b = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Sum(a_sub_b, vec![ - SumTerm(Some(FieldElement::ONE), a), - SumTerm(Some(-FieldElement::ONE), b), - ])); - // constraint: 1 * (a - b) = a_sub_b - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, r1cs_compiler.witness_one())], - &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], - &[(FieldElement::ONE, a_sub_b)], - ); - reduce_mod_p(r1cs_compiler, a_sub_b, modulus, range_checks) -} - -/// a^(-1) mod p -pub fn inv_mod_p( - r1cs_compiler: &mut NoirToR1CSCompiler, - a: usize, - modulus: FieldElement, - range_checks: &mut BTreeMap>, -) -> usize { - // Computing a^(-1) mod m - // ----------------------------------------------------------- - // computing a_inv (the F_m inverse of a) via Fermat's little theorem - let a_inv = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::ModularInverse(a_inv, a, modulus)); - - // Verifying a * a_inv mod m = 1 - // ----------------------------------------------------------- - // computing a * a_inv mod m - let reduced = mul_mod_p(r1cs_compiler, a, a_inv, modulus, range_checks); - - // constraint: reduced = 1 - // (reduced - 1) * 1 = 0 - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, 
r1cs_compiler.witness_one())], - &[ - (FieldElement::ONE, reduced), - (-FieldElement::ONE, r1cs_compiler.witness_one()), - ], - &[(FieldElement::ZERO, r1cs_compiler.witness_one())], - ); - - // range check: a_inv in [0, 2^bits(m)) - let mod_bits = modulus.into_bigint().num_bits(); - range_checks - .entry(mod_bits) - .or_insert_with(Vec::new) - .push(a_inv); - - a_inv -} - -/// checks if value is zero or not -pub fn compute_is_zero(r1cs_compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { - // calculating v^(-1) - let value_inv = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Inverse(value_inv, value)); - // calculating v * v^(-1) - let value_mul_value_inv = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(WitnessBuilder::Product( - value_mul_value_inv, - value, - value_inv, - )); - // calculate is_zero = 1 - (v * v^(-1)) - let is_zero = r1cs_compiler.num_witnesses(); - r1cs_compiler.add_witness_builder(provekit_common::witness::WitnessBuilder::Sum( - is_zero, - vec![ - provekit_common::witness::SumTerm(Some(FieldElement::ONE), r1cs_compiler.witness_one()), - provekit_common::witness::SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), - ], - )); - // constraint: v × v^(-1) = 1 - is_zero - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, value)], - &[(FieldElement::ONE, value_inv)], - &[ - (FieldElement::ONE, r1cs_compiler.witness_one()), - (-FieldElement::ONE, is_zero), - ], - ); - // constraint: v × is_zero = 0 - r1cs_compiler.r1cs.add_constraint( - &[(FieldElement::ONE, value)], - &[(FieldElement::ONE, is_zero)], - &[(FieldElement::ZERO, r1cs_compiler.witness_one())], - ); - is_zero -} diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index d607d25ff..14712c78c 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -99,3 +99,186 @@ pub fn point_select( let y = ops.select(flag, 
on_false.1, on_true.1); (x, y) } + +/// Point addition with safe denominator for the `x1 = x2` edge case. +/// +/// When `x_eq = 1`, the denominator `(x2 - x1)` is zero and cannot be +/// inverted. This function replaces it with 1, producing a satisfiable +/// but meaningless result. The caller MUST discard this result via +/// `point_select` when `x_eq = 1`. +/// +/// The `denom` parameter is the precomputed `x2 - x1`. +fn safe_point_add( + ops: &mut F, + x1: F::Elem, + y1: F::Elem, + x2: F::Elem, + y2: F::Elem, + denom: F::Elem, + x_eq: usize, +) -> (F::Elem, F::Elem) { + let numerator = ops.sub(y2, y1); + + // When x_eq=1 (denom=0), substitute with 1 to keep inv satisfiable + let one = ops.constant_one(); + let safe_denom = ops.select(x_eq, denom, one); + + let denom_inv = ops.inv(safe_denom); + let lambda = ops.mul(numerator, denom_inv); + + let lambda_sq = ops.mul(lambda, lambda); + let x1_plus_x2 = ops.add(x1, x2); + let x3 = ops.sub(lambda_sq, x1_plus_x2); + + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Builds a point table for windowed scalar multiplication. +/// +/// T[0] = P (dummy entry, used when window digit = 0) +/// T[1] = P, T[2] = 2P, T[i] = T[i-1] + P for i >= 3. +fn build_point_table( + ops: &mut F, + px: F::Elem, + py: F::Elem, + table_size: usize, +) -> Vec<(F::Elem, F::Elem)> { + assert!(table_size >= 2); + let mut table = Vec::with_capacity(table_size); + table.push((px, py)); // T[0] = P (dummy) + table.push((px, py)); // T[1] = P + if table_size > 2 { + table.push(point_double(ops, px, py)); // T[2] = 2P + for i in 3..table_size { + let prev = table[i - 1]; + table.push(point_add(ops, prev.0, prev.1, px, py)); + } + } + table +} + +/// Selects T[d] from a point table using bit witnesses, where `d = Σ bits[i] * 2^i`. +/// +/// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, +/// halving the candidate set at each level. 
Total: `(2^w - 1)` point selects +/// for a table of `2^w` entries. +fn table_lookup( + ops: &mut F, + table: &[(F::Elem, F::Elem)], + bits: &[usize], +) -> (F::Elem, F::Elem) { + assert_eq!(table.len(), 1 << bits.len()); + let mut current: Vec<(F::Elem, F::Elem)> = table.to_vec(); + // Process bits from MSB to LSB + for &bit in bits.iter().rev() { + let half = current.len() / 2; + let mut next = Vec::with_capacity(half); + for i in 0..half { + next.push(point_select(ops, bit, current[i], current[i + half])); + } + current = next; + } + current[0] +} + +/// Windowed scalar multiplication: computes `[scalar] * P`. +/// +/// Takes pre-decomposed scalar bits (LSB first, `scalar_bits[0]` is the +/// least significant bit) and a window size `w`. Precomputes a table of +/// `2^w` point multiples and processes the scalar in `w`-bit windows from +/// MSB to LSB. +/// +/// Handles two edge cases: +/// 1. **MSB window digit = 0**: The accumulator is initialized from T[0] +/// (a dummy copy of P). An `acc_is_identity` flag tracks that no real +/// point has been accumulated yet. When the first non-zero window digit +/// is encountered, the looked-up point becomes the new accumulator. +/// 2. **x-coordinate collision** (`acc.x == looked_up.x`): Uses +/// `point_double` instead of `point_add`, with `safe_point_add` +/// guarding the zero denominator. +/// +/// The inverse-point case (`acc = -looked_up`, result is infinity) cannot +/// be represented in affine coordinates and remains unsupported — this has +/// negligible probability (~2^{-256}) for random scalars. 
+pub fn scalar_mul( + ops: &mut F, + px: F::Elem, + py: F::Elem, + scalar_bits: &[usize], + window_size: usize, +) -> (F::Elem, F::Elem) { + let n = scalar_bits.len(); + let w = window_size; + let table_size = 1 << w; + + // Build point table: T[i] = [i]P, with T[0] = P as dummy + let table = build_point_table(ops, px, py, table_size); + + // Number of windows (ceiling division) + let num_windows = (n + w - 1) / w; + + // Process MSB window first (may be shorter than w bits if n % w != 0) + let msb_start = (num_windows - 1) * w; + let msb_bits = &scalar_bits[msb_start..n]; + let msb_table = &table[..1 << msb_bits.len()]; + let mut acc = table_lookup(ops, msb_table, msb_bits); + + // Track whether acc represents the identity (no real point yet). + // When MSB digit = 0, T[0] = P is loaded as a dummy — we must not + // double or add it until the first non-zero window digit appears. + let msb_digit = ops.pack_bits(msb_bits); + let mut acc_is_identity = ops.is_zero(msb_digit); + + // Process remaining windows from MSB-1 down to LSB + for i in (0..num_windows - 1).rev() { + // w doublings — only meaningful when acc is a real point. + // When acc_is_identity=1, the doubling result is garbage but will + // be discarded by the point_select below. 
+ let mut doubled_acc = acc; + for _ in 0..w { + doubled_acc = point_double(ops, doubled_acc.0, doubled_acc.1); + } + // If acc is identity, keep dummy; otherwise use doubled result + acc = point_select(ops, acc_is_identity, doubled_acc, acc); + + // Table lookup for this window's digit + let window_bits = &scalar_bits[i * w..(i + 1) * w]; + let digit = ops.pack_bits(window_bits); + let digit_is_zero = ops.is_zero(digit); + + let looked_up = table_lookup(ops, &table, window_bits); + + // Detect x-coordinate collision: acc.x == looked_up.x + let denom = ops.sub(looked_up.0, acc.0); + let x_eq = ops.elem_is_zero(denom); + + // point_double handles the acc == looked_up case (same point) + let doubled = point_double(ops, acc.0, acc.1); + + // Safe point_add (substitutes denominator when x_eq=1) + let added = safe_point_add( + ops, acc.0, acc.1, looked_up.0, looked_up.1, denom, x_eq, + ); + + // x_eq=0 => use add result, x_eq=1 => use double result + let combined = point_select(ops, x_eq, added, doubled); + + // Four cases based on (acc_is_identity, digit_is_zero): + // (0, 0) => combined — normal add/double + // (0, 1) => acc — keep accumulator + // (1, 0) => looked_up — first real point + // (1, 1) => acc — still identity + let normal_result = point_select(ops, digit_is_zero, combined, acc); + let identity_result = point_select(ops, digit_is_zero, looked_up, acc); + acc = point_select(ops, acc_is_identity, normal_result, identity_result); + + // Update: acc is identity only if it was identity AND digit is zero + acc_is_identity = ops.bool_and(acc_is_identity, digit_is_zero); + } + + acc +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index a155a6def..dda1e064a 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -1,113 +1,166 @@ +pub mod cost_model; pub mod curve; -pub mod ec_ops; pub mod ec_points; -pub mod wide_ops; +pub mod multi_limb_arith; +pub mod multi_limb_ops; use { - 
crate::noir_to_r1cs::NoirToR1CSCompiler, - ark_ff::Field, - curve::{curve_native_point_fe, limb2_constant, CurveParams, Limb2}, + crate::{ + digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field}, + curve::CurveParams, + multi_limb_ops::{MultiLimbOps, MultiLimbParams}, provekit_common::{ - witness::{ConstantTerm, SumTerm, WitnessBuilder}, + witness::{ConstantOrR1CSWitness, ConstantTerm, SumTerm, WitnessBuilder}, FieldElement, }, std::collections::BTreeMap, }; -pub trait FieldOps { - type Elem: Copy; +// --------------------------------------------------------------------------- +// Limbs: fixed-capacity, Copy array of witness indices +// --------------------------------------------------------------------------- - fn add(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; - fn sub(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; - fn mul(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; - fn inv(&mut self, a: Self::Elem) -> Self::Elem; - fn curve_a(&mut self) -> Self::Elem; +/// Maximum number of limbs supported. Covers all practical field sizes +/// (e.g. a 512-bit modulus with 16-bit limbs = 32 limbs). +pub const MAX_LIMBS: usize = 32; - /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if - /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`). - fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem; -} - -/// Narrow field operations for curves where p fits in BN254's scalar field. -/// Operates on single witness indices (`usize`). -pub struct NarrowOps<'a> { - pub compiler: &'a mut NoirToR1CSCompiler, - pub range_checks: &'a mut BTreeMap>, - pub modulus: FieldElement, - pub params: &'a CurveParams, +/// A fixed-capacity array of witness indices, indexed by limb position. 
///
/// This type is `Copy`, so it can be used as `FieldOps::Elem` without
/// requiring const generics or dispatch macros. The runtime `len` field
/// tracks how many limbs are actually in use; slots at `len..MAX_LIMBS`
/// hold the `UNINIT` sentinel and must never be read.
#[derive(Clone, Copy)]
pub struct Limbs {
    // Fixed-capacity backing storage of witness indices (limb 0 = least
    // significant). Only `data[..len]` is meaningful.
    data: [usize; MAX_LIMBS],
    // Number of limbs actually in use (1..=MAX_LIMBS).
    len: usize,
}

impl Limbs {
    /// Sentinel value for uninitialized limb slots. Using `usize::MAX`
    /// ensures accidental use of an unfilled slot indexes an absurdly
    /// large witness, causing an immediate out-of-bounds panic.
    const UNINIT: usize = usize::MAX;

    /// Create a new `Limbs` with `len` limbs, all initialized to `UNINIT`.
    ///
    /// Panics if `len` is 0 or exceeds `MAX_LIMBS`. Callers are expected
    /// to fill every slot before use.
    pub fn new(len: usize) -> Self {
        assert!(
            len > 0 && len <= MAX_LIMBS,
            "limb count must be 1..={MAX_LIMBS}, got {len}"
        );
        Self {
            data: [Self::UNINIT; MAX_LIMBS],
            len,
        }
    }

    /// Create a single-limb `Limbs` wrapping one witness index.
    pub fn single(value: usize) -> Self {
        let mut l = Self {
            data: [Self::UNINIT; MAX_LIMBS],
            len: 1,
        };
        l.data[0] = value;
        l
    }

    /// Create `Limbs` from a slice of witness indices.
    ///
    /// Panics if the slice is empty or longer than `MAX_LIMBS`.
    pub fn from_slice(s: &[usize]) -> Self {
        assert!(
            !s.is_empty() && s.len() <= MAX_LIMBS,
            "slice length must be 1..={MAX_LIMBS}, got {}",
            s.len()
        );
        let mut data = [Self::UNINIT; MAX_LIMBS];
        data[..s.len()].copy_from_slice(s);
        Self { data, len: s.len() }
    }

    /// View the active limbs as a slice.
    pub fn as_slice(&self) -> &[usize] {
        &self.data[..self.len]
    }

    /// Number of active limbs.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.len
    }
}

impl std::fmt::Debug for Limbs {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only print the active limbs, not the UNINIT tail.
        f.debug_list().entries(self.as_slice().iter()).finish()
    }
}

impl PartialEq for Limbs {
    fn eq(&self, other: &Self) -> bool {
        // Compare only the active prefixes; the UNINIT tail is irrelevant.
        self.len == other.len && self.data[..self.len] == other.data[..other.len]
    }
}
impl Eq for Limbs {}

impl std::ops::Index<usize> for Limbs {
    type Output = usize;
    fn index(&self, i: usize) -> &usize {
        // Checked only in debug builds; release indexing within MAX_LIMBS
        // but beyond `len` would silently yield the UNINIT sentinel.
        debug_assert!(
            i < self.len,
            "Limbs index {i} out of bounds (len={})",
            self.len
        );
        &self.data[i]
    }
}

impl std::ops::IndexMut<usize> for Limbs {
    fn index_mut(&mut self, i: usize) -> &mut usize {
        debug_assert!(
            i < self.len,
            "Limbs index {i} out of bounds (len={})",
            self.len
        );
        &mut self.data[i]
    }
}
// ---------------------------------------------------------------------------
// FieldOps trait
// ---------------------------------------------------------------------------

/// Abstract base-field operations used by the EC-point gadgets.
/// Implementations emit R1CS constraints as a side effect; `Elem` is the
/// circuit-side representation of a base-field element (a single witness
/// index or a multi-limb `Limbs` array).
pub trait FieldOps {
    type Elem: Copy;

    fn add(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem;
    fn sub(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem;
    fn mul(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem;
    fn inv(&mut self, a: Self::Elem) -> Self::Elem;
    /// Returns the curve's `a` coefficient as a circuit element.
    fn curve_a(&mut self) -> Self::Elem;

    /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if
    /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`).
    fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem;

    /// Checks if a BN254 native witness value is zero.
    /// Returns a boolean witness: 1 if zero, 0 if non-zero.
    fn is_zero(&mut self, value: usize) -> usize;

    /// Packs bit witnesses into a single digit witness: `d = Σ bits[i] * 2^i`.
    /// Does NOT constrain bits to be boolean — caller must ensure that.
    fn pack_bits(&mut self, bits: &[usize]) -> usize;

    /// Checks if a field element (in the curve's base field) is zero.
    /// Returns a boolean witness: 1 if zero, 0 if non-zero.
    fn elem_is_zero(&mut self, value: Self::Elem) -> usize;

    /// Returns the constant field element 1.
    fn constant_one(&mut self) -> Self::Elem;

    /// Computes a * b for two boolean (0/1) native witnesses.
    /// Used for boolean AND on flags in scalar_mul.
    fn bool_and(&mut self, a: usize, b: usize) -> usize;
}

// ---------------------------------------------------------------------------
// Shared helpers
// ---------------------------------------------------------------------------

/// Constrains `flag` to be boolean: `flag * flag = flag`.
pub(crate) fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) {
    compiler.r1cs.add_constraint(
        &[(FieldElement::ONE, flag)],
        &[(FieldElement::ONE, flag)],
        &[(FieldElement::ONE, flag)],
    );
}

/// Single-witness conditional select: `out = on_false + flag * (on_true -
/// on_false)`.
///
/// Produces witnesses for the difference, the flag-gated difference, and
/// the result. Does NOT constrain `flag` to be boolean — callers must do
/// that separately (see `FieldOps::select`).
pub(crate) fn select_witness(
    compiler: &mut NoirToR1CSCompiler,
    flag: usize,
    on_false: usize,
    on_true: usize,
) -> usize {
    // When both branches are the same witness, result is trivially that witness.
    // Avoids duplicate column indices in R1CS from `on_true - on_false` when
    // both share the same witness index.
    if on_false == on_true {
        return on_false;
    }
    let diff = compiler.add_sum(vec![
        SumTerm(None, on_true),
        SumTerm(Some(-FieldElement::ONE), on_false),
    ]);
    let flag_diff = compiler.add_product(flag, diff);
    compiler.add_sum(vec![SumTerm(None, on_false), SumTerm(None, flag_diff)])
}

/// Packs bit witnesses into a digit: `d = Σ bits[i] * 2^i`.
+pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize]) -> usize { + let terms: Vec = bits + .iter() + .enumerate() + .map(|(i, &bit)| SumTerm(Some(FieldElement::from(1u128 << i)), bit)) + .collect(); + compiler.add_sum(terms) +} + +// --------------------------------------------------------------------------- +// Params builder (runtime num_limbs, no const generics) +// --------------------------------------------------------------------------- + +/// Build `MultiLimbParams` for a given runtime `num_limbs`. +fn build_params(num_limbs: usize, limb_bits: u32, curve: &CurveParams) -> MultiLimbParams { + let is_native = curve.is_native_field(); + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let modulus_fe = if !is_native { + Some(curve.p_native_fe()) + } else { + None + }; + MultiLimbParams { + num_limbs, + limb_bits, + p_limbs: curve.p_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.p_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.field_modulus_p, + curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), + modulus_bits: curve.modulus_bits(), + is_native, + modulus_fe, + } +} + +// --------------------------------------------------------------------------- +// MSM entry point +// --------------------------------------------------------------------------- + +/// Processes all deferred MSM operations. 
+/// +/// Each entry is `(points, scalars, (out_x, out_y, out_inf))` where: +/// - `points` has layout `[x1, y1, inf1, x2, y2, inf2, ...]` (3 per point) +/// - `scalars` has layout `[s1_lo, s1_hi, s2_lo, s2_hi, ...]` (2 per scalar) +/// - outputs are the R1CS witness indices for the result point +pub fn add_msm( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + (usize, usize, usize), + )>, + limb_bits: u32, + window_size: usize, + range_checks: &mut BTreeMap>, + curve: &CurveParams, +) { + for (points, scalars, outputs) in msm_ops { + add_single_msm( + compiler, + &points, + &scalars, + outputs, + limb_bits, + window_size, + range_checks, + curve, + ); + } +} + +/// Processes a single MSM operation. +fn add_single_msm( + compiler: &mut NoirToR1CSCompiler, + points: &[ConstantOrR1CSWitness], + scalars: &[ConstantOrR1CSWitness], + outputs: (usize, usize, usize), + limb_bits: u32, + window_size: usize, + range_checks: &mut BTreeMap>, + curve: &CurveParams, +) { + assert!( + points.len() % 3 == 0, + "points length must be a multiple of 3" + ); + let n = points.len() / 3; + assert_eq!( + scalars.len(), + 2 * n, + "scalars length must be 2x the number of points" + ); + + // Resolve all inputs to witness indices + let point_wits: Vec = points.iter().map(|p| resolve_input(compiler, p)).collect(); + let scalar_wits: Vec = scalars.iter().map(|s| resolve_input(compiler, s)).collect(); + + let is_native = curve.is_native_field(); + let num_limbs = if is_native { + 1 + } else { + (curve.modulus_bits() as usize + limb_bits as usize - 1) / limb_bits as usize + }; + + process_single_msm( + compiler, + &point_wits, + &scalar_wits, + outputs, + num_limbs, + limb_bits, + window_size, + range_checks, + curve, + ); +} + +/// Process a full single-MSM with runtime `num_limbs`. +/// +/// Handles coordinate decomposition, scalar_mul, accumulation, and +/// output constraining. 
fn process_single_msm<'a>(
    mut compiler: &'a mut NoirToR1CSCompiler,
    point_wits: &[usize],
    scalar_wits: &[usize],
    outputs: (usize, usize, usize),
    num_limbs: usize,
    limb_bits: u32,
    window_size: usize,
    mut range_checks: &'a mut BTreeMap<u32, Vec<usize>>,
    curve: &CurveParams,
) {
    let n_points = point_wits.len() / 3;
    // Running accumulator over per-point scalar_mul results; None until the
    // first result lands.
    let mut acc: Option<(Limbs, Limbs)> = None;

    for i in 0..n_points {
        let px_witness = point_wits[3 * i];
        let py_witness = point_wits[3 * i + 1];
        // NOTE(review): the per-point infinity flag (point_wits[3 * i + 2])
        // is never read here — confirm callers guarantee non-identity
        // input points.

        let s_lo = scalar_wits[2 * i];
        let s_hi = scalar_wits[2 * i + 1];
        let scalar_bits = decompose_scalar_bits(compiler, s_lo, s_hi);

        // Build coordinates as Limbs
        let (px, py) = if num_limbs == 1 {
            // Single-limb: wrap witness directly
            (Limbs::single(px_witness), Limbs::single(py_witness))
        } else {
            // Multi-limb: decompose single witness into num_limbs limbs
            let px_limbs = decompose_witness_to_limbs(
                compiler,
                px_witness,
                limb_bits,
                num_limbs,
                range_checks,
            );
            let py_limbs = decompose_witness_to_limbs(
                compiler,
                py_witness,
                limb_bits,
                num_limbs,
                range_checks,
            );
            (px_limbs, py_limbs)
        };

        // MultiLimbOps takes both &mut borrows by value; they are handed
        // back after each gadget call so the loop can keep using them
        // (manual reborrow threading).
        let params = build_params(num_limbs, limb_bits, curve);
        let mut ops = MultiLimbOps {
            compiler,
            range_checks,
            params,
        };
        let result = ec_points::scalar_mul(&mut ops, px, py, &scalar_bits, window_size);
        compiler = ops.compiler;
        range_checks = ops.range_checks;

        acc = Some(match acc {
            None => result,
            Some((ax, ay)) => {
                let params = build_params(num_limbs, limb_bits, curve);
                let mut ops = MultiLimbOps {
                    compiler,
                    range_checks,
                    params,
                };
                let sum = ec_points::point_add(&mut ops, ax, ay, result.0, result.1);
                compiler = ops.compiler;
                range_checks = ops.range_checks;
                sum
            }
        });
    }

    let (computed_x, computed_y) = acc.expect("MSM must have at least one point");
    let (out_x, out_y, out_inf) = outputs;

    if num_limbs == 1 {
        constrain_equal(compiler, out_x, computed_x[0]);
        constrain_equal(compiler, out_y, computed_y[0]);
    } else {
        // Multi-limb result: recompose limbs back into single witnesses
        // before binding them to the caller-provided output witnesses.
        let recomposed_x = recompose_limbs(compiler, computed_x.as_slice(), limb_bits);
        let recomposed_y = recompose_limbs(compiler, computed_y.as_slice(), limb_bits);
        constrain_equal(compiler, out_x, recomposed_x);
        constrain_equal(compiler, out_y, recomposed_y);
    }
    // The result is asserted to be a finite point (infinity flag = 0).
    constrain_zero(compiler, out_inf);
}

/// Decompose a single witness into `num_limbs` limbs using digital
/// decomposition.
fn decompose_witness_to_limbs(
    compiler: &mut NoirToR1CSCompiler,
    witness: usize,
    limb_bits: u32,
    num_limbs: usize,
    range_checks: &mut BTreeMap<u32, Vec<usize>>,
) -> Limbs {
    let log_bases = vec![limb_bits as usize; num_limbs];
    let dd = add_digital_decomposition(compiler, log_bases, vec![witness]);
    let mut limbs = Limbs::new(num_limbs);
    for i in 0..num_limbs {
        limbs[i] = dd.get_digit_witness_index(i, 0);
        // Range-check each decomposed limb to [0, 2^limb_bits).
        // add_digital_decomposition constrains the recomposition but does
        // NOT range-check individual digits.
        range_checks.entry(limb_bits).or_default().push(limbs[i]);
    }
    limbs
}

/// Recompose limbs back into a single witness: val = Σ limb[i] *
/// 2^(i*limb_bits)
fn recompose_limbs(compiler: &mut NoirToR1CSCompiler, limbs: &[usize], limb_bits: u32) -> usize {
    let terms: Vec<SumTerm> = limbs
        .iter()
        .enumerate()
        .map(|(i, &limb)| {
            let coeff = FieldElement::from(2u64).pow([(i as u64) * (limb_bits as u64)]);
            SumTerm(Some(coeff), limb)
        })
        .collect();
    compiler.add_sum(terms)
}

/// Resolves a `ConstantOrR1CSWitness` to a witness index, materializing
/// constants as fresh constant witnesses.
fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitness) -> usize {
    match input {
        ConstantOrR1CSWitness::Witness(idx) => *idx,
        ConstantOrR1CSWitness::Constant(value) => {
            let w = compiler.num_witnesses();
            compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, *value)));
            w
        }
    }
}

/// Decomposes a scalar given as two 128-bit limbs into 256 bit witnesses (LSB
/// first).
+fn decompose_scalar_bits( + compiler: &mut NoirToR1CSCompiler, + s_lo: usize, + s_hi: usize, +) -> Vec { + let log_bases_128 = vec![1usize; 128]; + + let dd_lo = add_digital_decomposition(compiler, log_bases_128.clone(), vec![s_lo]); + let dd_hi = add_digital_decomposition(compiler, log_bases_128, vec![s_hi]); + + let mut bits = Vec::with_capacity(256); + for bit_idx in 0..128 { + bits.push(dd_lo.get_digit_witness_index(bit_idx, 0)); + } + for bit_idx in 0..128 { + bits.push(dd_hi.get_digit_witness_index(bit_idx, 0)); + } + bits +} + +/// Constrains two witnesses to be equal: `a - b = 0`. +fn constrain_equal(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} + +/// Constrains a witness to be zero: `w = 0`. +fn constrain_zero(compiler: &mut NoirToR1CSCompiler, w: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs new file mode 100644 index 000000000..ab84fc9b7 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -0,0 +1,591 @@ +//! N-limb modular arithmetic for EC field operations. +//! +//! Replaces both `ec_ops.rs` (N=1 path) and `wide_ops.rs` (N>1 path) with +//! unified multi-limb operations using `Limbs` (runtime-sized, Copy). 
+ +use { + super::Limbs, + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField}, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +// --------------------------------------------------------------------------- +// N=1 single-limb path (moved from ec_ops.rs) +// --------------------------------------------------------------------------- + +/// Reduce the value to given modulus (N=1 path). +/// Computes v = k*m + result, where 0 <= result < m. +pub fn reduce_mod_p( + compiler: &mut NoirToR1CSCompiler, + value: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let m = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(m, modulus), + )); + let k = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(k, value, modulus)); + + let k_mul_m = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product(k_mul_m, k, m)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, k)], &[(FieldElement::ONE, m)], &[( + FieldElement::ONE, + k_mul_m, + )]); + + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum(result, vec![ + SumTerm(Some(FieldElement::ONE), value), + SumTerm(Some(-FieldElement::ONE), k_mul_m), + ])); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, k_mul_m), (FieldElement::ONE, result)], + &[(FieldElement::ONE, value)], + ); + + let modulus_bits = modulus.into_bigint().num_bits(); + range_checks + .entry(modulus_bits) + .or_default() + .push(result); + + result +} + +/// a + b mod p (N=1 path) +pub fn add_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_add_b = compiler.num_witnesses(); + 
compiler.add_witness_builder(WitnessBuilder::Sum(a_add_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(FieldElement::ONE), b), + ])); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, a), (FieldElement::ONE, b)], + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a_add_b)], + ); + reduce_mod_p(compiler, a_add_b, modulus, range_checks) +} + +/// a * b mod p (N=1 path) +pub fn mul_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_mul_b = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product(a_mul_b, a, b)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( + FieldElement::ONE, + a_mul_b, + )]); + reduce_mod_p(compiler, a_mul_b, modulus, range_checks) +} + +/// (a - b) mod p (N=1 path) +pub fn sub_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + b: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_sub_b = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum(a_sub_b, vec![ + SumTerm(Some(FieldElement::ONE), a), + SumTerm(Some(-FieldElement::ONE), b), + ])); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ONE, a_sub_b)], + ); + reduce_mod_p(compiler, a_sub_b, modulus, range_checks) +} + +/// a^(-1) mod p (N=1 path) +pub fn inv_mod_p_single( + compiler: &mut NoirToR1CSCompiler, + a: usize, + modulus: FieldElement, + range_checks: &mut BTreeMap>, +) -> usize { + let a_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::ModularInverse(a_inv, a, modulus)); + + let reduced = mul_mod_p_single(compiler, a, a_inv, modulus, range_checks); + + // Constrain reduced = 1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[ + 
(FieldElement::ONE, reduced), + (-FieldElement::ONE, compiler.witness_one()), + ], + &[(FieldElement::ZERO, compiler.witness_one())], + ); + + let mod_bits = modulus.into_bigint().num_bits(); + range_checks.entry(mod_bits).or_default().push(a_inv); + + a_inv +} + +/// Checks if value is zero or not (used by all N values). +/// Returns a boolean witness: 1 if zero, 0 if non-zero. +/// +/// Uses SafeInverse (not Inverse) because the input value may be zero. +/// SafeInverse outputs 0 when the input is 0, and is solved in the Other +/// layer (not batch-inverted), so zero inputs don't poison the batch. +pub fn compute_is_zero(compiler: &mut NoirToR1CSCompiler, value: usize) -> usize { + let value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SafeInverse(value_inv, value)); + + let value_mul_value_inv = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Product( + value_mul_value_inv, + value, + value_inv, + )); + + let is_zero = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Sum( + is_zero, + vec![ + SumTerm(Some(FieldElement::ONE), compiler.witness_one()), + SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), + ], + )); + + // v × v^(-1) = 1 - is_zero + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, value_inv)], + &[ + (FieldElement::ONE, compiler.witness_one()), + (-FieldElement::ONE, is_zero), + ], + ); + // v × is_zero = 0 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, value)], + &[(FieldElement::ONE, is_zero)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); + + is_zero +} + +// --------------------------------------------------------------------------- +// N≥2 multi-limb path (generalization of wide_ops.rs) +// --------------------------------------------------------------------------- + +/// (a + b) mod p for multi-limb values. 
+/// +/// Per limb i: v_i = a[i] + b[i] + 2^W - q*p[i] + carry_{i-1} +/// carry_i = floor(v_i / 2^W) +/// r[i] = v_i - carry_i * 2^W +pub fn add_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "add_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + // Witness: q = floor((a + b) / p) ∈ {0, 1} + let q = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbAddQuotient { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + // q is boolean + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, q)], + &[(FieldElement::ONE, q)], + &[(FieldElement::ONE, q)], + ); + + let mut r = Limbs::new(n); + let mut carry_prev: Option = None; + + for i in 0..n { + // v_offset = a[i] + b[i] + 2^W - q*p[i] + carry_{i-1} + let mut terms = vec![ + SumTerm(None, a[i]), + SumTerm(None, b[i]), + SumTerm(Some(two_pow_w), w1), + SumTerm(Some(-p_limbs[i]), q), + ]; + if let Some(carry) = carry_prev { + terms.push(SumTerm(None, carry)); + // Compensate for previous 2^W offset + terms.push(SumTerm(Some(-FieldElement::ONE), w1)); + } + let v_offset = compiler.add_sum(terms); + + // carry = floor(v_offset / 2^W) + let carry = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + carry, v_offset, two_pow_w, + )); + // r[i] = v_offset - carry * 2^W + r[i] = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_pow_w), carry), + ]); + carry_prev = Some(carry); + } + + less_than_p_check_multi(compiler, range_checks, r, p_minus_1_limbs, two_pow_w, limb_bits); + + r +} + +/// (a - b) mod p for multi-limb values. 
+pub fn sub_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "sub_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + // Witness: q = (a < b) ? 1 : 0 + let q = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbSubBorrow { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + // q is boolean + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, q)], + &[(FieldElement::ONE, q)], + &[(FieldElement::ONE, q)], + ); + + let mut r = Limbs::new(n); + let mut carry_prev: Option = None; + + for i in 0..n { + // v_offset = a[i] - b[i] + q*p[i] + 2^W + carry_{i-1} + let mut terms = vec![ + SumTerm(None, a[i]), + SumTerm(Some(-FieldElement::ONE), b[i]), + SumTerm(Some(p_limbs[i]), q), + SumTerm(Some(two_pow_w), w1), + ]; + if let Some(carry) = carry_prev { + terms.push(SumTerm(None, carry)); + terms.push(SumTerm(Some(-FieldElement::ONE), w1)); + } + let v_offset = compiler.add_sum(terms); + + let carry = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + carry, v_offset, two_pow_w, + )); + r[i] = compiler.add_sum(vec![ + SumTerm(None, v_offset), + SumTerm(Some(-two_pow_w), carry), + ]); + carry_prev = Some(carry); + } + + less_than_p_check_multi(compiler, range_checks, r, p_minus_1_limbs, two_pow_w, limb_bits); + + r +} + +/// (a * b) mod p for multi-limb values using schoolbook multiplication. +/// +/// Verifies: a·b = p·q + r in base W = 2^limb_bits. 
+/// Column k: Σ_{i+j=k} a[i]*b[j] + carry_{k-1} + OFFSET +/// = Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W +pub fn mul_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "mul_mod_p_multi requires n >= 2, got n={n}"); + + // Soundness check: column equation values must not overflow the native field. + // The maximum value across either side of any column equation is bounded by + // 2^(2*limb_bits + ceil(log2(n)) + 3). This must be strictly less than the + // native field modulus p >= 2^(MODULUS_BIT_SIZE - 1). + { + let ceil_log2_n = (n as f64).log2().ceil() as u32; + let max_bits = 2 * limb_bits + ceil_log2_n + 3; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "Schoolbook column equation overflow: limb_bits={limb_bits}, n={n} limbs \ + requires {max_bits} bits, but native field is only {} bits. 
\ + Use smaller limb_bits.", + FieldElement::MODULUS_BIT_SIZE, + ); + } + + let w1 = compiler.witness_one(); + let num_carries = 2 * n - 2; + // Carry offset: 2^(limb_bits + ceil(log2(n)) + 1) + let extra_bits = ((n as f64).log2().ceil() as u32) + 1; + let carry_offset_bits = limb_bits + extra_bits; + let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); + // offset_w = carry_offset * 2^limb_bits + let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); + // offset_w_minus_carry = offset_w - carry_offset = carry_offset * (2^limb_bits - 1) + let offset_w_minus_carry = offset_w - carry_offset_fe; + + // Step 1: Allocate hint witnesses (q limbs, r limbs, carries) + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbMulModHint { + output_start: os, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + + // q[0..n), r[n..2n), carries[2n..4n-2) + let q: Vec = (0..n).map(|i| os + i).collect(); + let r_indices: Vec = (0..n).map(|i| os + n + i).collect(); + let cu: Vec = (0..num_carries).map(|i| os + 2 * n + i).collect(); + + // Step 2: Product witnesses for a[i]*b[j] (n² R1CS constraints) + let mut ab_products = vec![vec![0usize; n]; n]; + for i in 0..n { + for j in 0..n { + ab_products[i][j] = compiler.add_product(a[i], b[j]); + } + } + + // Step 3: Column equations (2n-1 R1CS constraints) + for k in 0..(2 * n - 1) { + // LHS: Σ_{i+j=k} a[i]*b[j] + carry_{k-1} + OFFSET + let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((FieldElement::ONE, ab_products[i][j_val as usize])); + } + } + // Add carry_{k-1} + if k > 0 { + lhs_terms.push((FieldElement::ONE, cu[k - 1])); + // Add offset_w - carry_offset for subsequent columns + lhs_terms.push((offset_w_minus_carry, w1)); + } else 
{ + // First column: add offset_w + lhs_terms.push((offset_w, w1)); + } + + // RHS: Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W + let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + rhs_terms.push((p_limbs[i], q[j_val as usize])); + } + } + if k < n { + rhs_terms.push((FieldElement::ONE, r_indices[k])); + } + if k < 2 * n - 2 { + rhs_terms.push((two_pow_w, cu[k])); + } else { + // Last column: RHS includes offset_w to balance the LHS offset + // LHS has: carry[k-1] + offset_w_minus_carry = true_carry + offset_w + // RHS needs: sum_pq[k] + offset_w (no outgoing carry at last column) + rhs_terms.push((offset_w, w1)); + } + + compiler + .r1cs + .add_constraint(&lhs_terms, &[(FieldElement::ONE, w1)], &rhs_terms); + } + + // Step 4: less-than-p check and range checks on r + let mut r_limbs = Limbs::new(n); + for (i, &ri) in r_indices.iter().enumerate() { + r_limbs[i] = ri; + } + less_than_p_check_multi(compiler, range_checks, r_limbs, p_minus_1_limbs, two_pow_w, limb_bits); + + // Step 5: Range checks for q limbs and carries + for i in 0..n { + range_checks.entry(limb_bits).or_default().push(q[i]); + } + // Carry range: limb_bits + extra_bits + 1 (carry_offset_bits + 1) + let carry_range_bits = carry_offset_bits + 1; + for &c in &cu { + range_checks.entry(carry_range_bits).or_default().push(c); + } + + r_limbs +} + +/// a^(-1) mod p for multi-limb values. +/// Uses MultiLimbModularInverse hint, verifies via mul_mod_p(a, inv) = [1, 0, ..., 0]. 
+pub fn inv_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + let n = a.len(); + assert!(n >= 2, "inv_mod_p_multi requires n >= 2, got n={n}"); + + // Hint: compute inverse + let inv_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::MultiLimbModularInverse { + output_start: inv_start, + a_limbs: a.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + let mut inv = Limbs::new(n); + for i in 0..n { + inv[i] = inv_start + i; + } + + // Verify: a * inv mod p = [1, 0, ..., 0] + let product = mul_mod_p_multi( + compiler, + range_checks, + a, + inv, + p_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + modulus_raw, + ); + + // Constrain product[0] = 1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product[0])], + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, compiler.witness_one())], + ); + // Constrain product[1..n] = 0 + for i in 1..n { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, product[i])], + &[(FieldElement::ONE, compiler.witness_one())], + &[], + ); + } + + inv +} + +/// Proves r < p by decomposing (p-1) - r into non-negative multi-limb values. 
+/// Uses borrow propagation: d[i] = (p-1)[i] - r[i] + borrow_in - borrow_out * 2^W +fn less_than_p_check_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + r: Limbs, + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, +) { + let n = r.len(); + let w1 = compiler.witness_one(); + let mut borrow_prev: Option = None; + + for i in 0..n { + // v_diff = (p-1)[i] + 2^W - r[i] + borrow_prev + let p_minus_1_plus_offset = p_minus_1_limbs[i] + two_pow_w; + let mut terms = vec![ + SumTerm(Some(p_minus_1_plus_offset), w1), + SumTerm(Some(-FieldElement::ONE), r[i]), + ]; + if let Some(borrow) = borrow_prev { + terms.push(SumTerm(None, borrow)); + terms.push(SumTerm(Some(-FieldElement::ONE), w1)); + } + let v_diff = compiler.add_sum(terms); + + // borrow = floor(v_diff / 2^W) + let borrow = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( + borrow, v_diff, two_pow_w, + )); + // d[i] = v_diff - borrow * 2^W + let d_i = compiler.add_sum(vec![ + SumTerm(None, v_diff), + SumTerm(Some(-two_pow_w), borrow), + ]); + + // Range check r[i] and d[i] + range_checks.entry(limb_bits).or_default().push(r[i]); + range_checks.entry(limb_bits).or_default().push(d_i); + + borrow_prev = Some(borrow); + } + + // Constrain final borrow = 0: if borrow_out != 0, then r > p-1 (i.e. r >= p), + // which would mean the result is not properly reduced. + if let Some(final_borrow) = borrow_prev { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, final_borrow)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); + } +} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs new file mode 100644 index 000000000..4f1c45448 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -0,0 +1,275 @@ +//! `MultiLimbOps` — unified FieldOps implementation parameterized by runtime limb count. 
+//! +//! Uses `Limbs` (a fixed-capacity Copy type) as `FieldOps::Elem`, enabling +//! arbitrary limb counts without const generics or dispatch macros. + +use { + super::{ + multi_limb_arith, + Limbs, + FieldOps, + }, + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::{AdditiveGroup, Field}, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// Parameters for multi-limb field arithmetic. +pub struct MultiLimbParams { + pub num_limbs: usize, + pub limb_bits: u32, + pub p_limbs: Vec, + pub p_minus_1_limbs: Vec, + pub two_pow_w: FieldElement, + pub modulus_raw: [u64; 4], + pub curve_a_limbs: Vec, + pub modulus_bits: u32, + /// p = native field → skip mod reduction + pub is_native: bool, + /// For N=1 non-native: the modulus as a single FieldElement + pub modulus_fe: Option, +} + +/// Unified field operations struct parameterized by runtime limb count. +pub struct MultiLimbOps<'a> { + pub compiler: &'a mut NoirToR1CSCompiler, + pub range_checks: &'a mut BTreeMap>, + pub params: MultiLimbParams, +} + +impl MultiLimbOps<'_> { + fn is_native_single(&self) -> bool { + self.params.num_limbs == 1 && self.params.is_native + } + + fn is_non_native_single(&self) -> bool { + self.params.num_limbs == 1 && !self.params.is_native + } + + fn n(&self) -> usize { + self.params.num_limbs + } +} + +impl FieldOps for MultiLimbOps<'_> { + type Elem = Limbs; + + fn add(&mut self, a: Limbs, b: Limbs) -> Limbs { + debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); + debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs"); + if self.is_native_single() { + // When both operands are the same witness, merge into a single + // term with coefficient 2 to avoid duplicate column indices in + // the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
+ let r = if a[0] == b[0] { + self.compiler.add_sum(vec![ + SumTerm(Some(FieldElement::from(2u64)), a[0]), + ]) + } else { + self.compiler.add_sum(vec![ + SumTerm(None, a[0]), + SumTerm(None, b[0]), + ]) + }; + Limbs::single(r) + } else if self.is_non_native_single() { + let modulus = self.params.modulus_fe.unwrap(); + let r = multi_limb_arith::add_mod_p_single( + self.compiler, a[0], b[0], modulus, self.range_checks, + ); + Limbs::single(r) + } else { + multi_limb_arith::add_mod_p_multi( + self.compiler, + self.range_checks, + a, + b, + &self.params.p_limbs, + &self.params.p_minus_1_limbs, + self.params.two_pow_w, + self.params.limb_bits, + &self.params.modulus_raw, + ) + } + } + + fn sub(&mut self, a: Limbs, b: Limbs) -> Limbs { + debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); + debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs"); + if self.is_native_single() { + // When both operands are the same witness, a - a = 0. Use a + // single zero-coefficient term to avoid duplicate column indices. 
+ let r = if a[0] == b[0] { + self.compiler.add_sum(vec![ + SumTerm(Some(FieldElement::ZERO), a[0]), + ]) + } else { + self.compiler.add_sum(vec![ + SumTerm(None, a[0]), + SumTerm(Some(-FieldElement::ONE), b[0]), + ]) + }; + Limbs::single(r) + } else if self.is_non_native_single() { + let modulus = self.params.modulus_fe.unwrap(); + let r = multi_limb_arith::sub_mod_p_single( + self.compiler, a[0], b[0], modulus, self.range_checks, + ); + Limbs::single(r) + } else { + multi_limb_arith::sub_mod_p_multi( + self.compiler, + self.range_checks, + a, + b, + &self.params.p_limbs, + &self.params.p_minus_1_limbs, + self.params.two_pow_w, + self.params.limb_bits, + &self.params.modulus_raw, + ) + } + } + + fn mul(&mut self, a: Limbs, b: Limbs) -> Limbs { + debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); + debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs"); + if self.is_native_single() { + let r = self.compiler.add_product(a[0], b[0]); + Limbs::single(r) + } else if self.is_non_native_single() { + let modulus = self.params.modulus_fe.unwrap(); + let r = multi_limb_arith::mul_mod_p_single( + self.compiler, a[0], b[0], modulus, self.range_checks, + ); + Limbs::single(r) + } else { + multi_limb_arith::mul_mod_p_multi( + self.compiler, + self.range_checks, + a, + b, + &self.params.p_limbs, + &self.params.p_minus_1_limbs, + self.params.two_pow_w, + self.params.limb_bits, + &self.params.modulus_raw, + ) + } + } + + fn inv(&mut self, a: Limbs) -> Limbs { + debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); + if self.is_native_single() { + let a_inv = self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Inverse(a_inv, a[0])); + // a * a_inv = 1 + self.compiler.r1cs.add_constraint( + &[(FieldElement::ONE, a[0])], + &[(FieldElement::ONE, a_inv)], + &[(FieldElement::ONE, self.compiler.witness_one())], + ); + Limbs::single(a_inv) + } else if self.is_non_native_single() { + let modulus = self.params.modulus_fe.unwrap(); + 
let r = multi_limb_arith::inv_mod_p_single( + self.compiler, a[0], modulus, self.range_checks, + ); + Limbs::single(r) + } else { + multi_limb_arith::inv_mod_p_multi( + self.compiler, + self.range_checks, + a, + &self.params.p_limbs, + &self.params.p_minus_1_limbs, + self.params.two_pow_w, + self.params.limb_bits, + &self.params.modulus_raw, + ) + } + } + + fn curve_a(&mut self) -> Limbs { + let n = self.n(); + let mut out = Limbs::new(n); + for i in 0..n { + let w = self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( + w, + self.params.curve_a_limbs[i], + ))); + out[i] = w; + } + out + } + + fn select( + &mut self, + flag: usize, + on_false: Limbs, + on_true: Limbs, + ) -> Limbs { + super::constrain_boolean(self.compiler, flag); + let n = self.n(); + let mut out = Limbs::new(n); + for i in 0..n { + out[i] = super::select_witness(self.compiler, flag, on_false[i], on_true[i]); + } + out + } + + fn is_zero(&mut self, value: usize) -> usize { + multi_limb_arith::compute_is_zero(self.compiler, value) + } + + fn pack_bits(&mut self, bits: &[usize]) -> usize { + super::pack_bits_helper(self.compiler, bits) + } + + fn elem_is_zero(&mut self, value: Limbs) -> usize { + let n = self.n(); + if n == 1 { + multi_limb_arith::compute_is_zero(self.compiler, value[0]) + } else { + // Check each limb is zero and AND the results together + let mut result = multi_limb_arith::compute_is_zero(self.compiler, value[0]); + for i in 1..n { + let limb_zero = multi_limb_arith::compute_is_zero(self.compiler, value[i]); + result = self.compiler.add_product(result, limb_zero); + } + result + } + } + + fn constant_one(&mut self) -> Limbs { + let n = self.n(); + let mut out = Limbs::new(n); + // limb[0] = 1 + let w0 = self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w0, FieldElement::ONE))); + out[0] = w0; + // limb[1..n] = 0 + for i in 1..n { + let w = 
self.compiler.num_witnesses(); + self.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( + w, + FieldElement::ZERO, + ))); + out[i] = w; + } + out + } + + fn bool_and(&mut self, a: usize, b: usize) -> usize { + self.compiler.add_product(a, b) + } +} diff --git a/provekit/r1cs-compiler/src/msm/wide_ops.rs b/provekit/r1cs-compiler/src/msm/wide_ops.rs deleted file mode 100644 index 167d6f986..000000000 --- a/provekit/r1cs-compiler/src/msm/wide_ops.rs +++ /dev/null @@ -1,563 +0,0 @@ -use { - crate::{ - msm::curve::{CurveParams, Limb2}, - noir_to_r1cs::NoirToR1CSCompiler, - }, - ark_ff::Field, - provekit_common::{ - witness::{SumTerm, WitnessBuilder}, - FieldElement, - }, - std::collections::BTreeMap, -}; - -/// (a + b) mod p for 256-bit values in two 128-bit limbs. -/// -/// Equation: a + b = q * p + r, where q ∈ {0, 1}, 0 ≤ r < p. -/// -/// Uses the offset trick to avoid negative intermediate values: -/// v_offset = a_lo + b_lo + 2^128 - q * p_lo (always ≥ 0) -/// carry_offset = floor(v_offset / 2^128) ∈ {0, 1, 2} -/// r_lo = v_offset - carry_offset * 2^128 -/// r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi -/// -/// Less-than-p check (proves r < p): -/// d_lo + d_hi * 2^128 = (p - 1) - r (all components ≥ 0) -/// -/// Constraints (7 total): -/// 1. q is boolean: q * q = q -/// 2-3. Column 0: v_offset defined, then r_lo = v_offset - carry_offset * -/// 2^128 -/// 4. Column 1: r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi -/// 5-6. LT check: v_diff defined, then d_lo = v_diff - borrow_compl * 2^128 -/// 7. 
LT check: d_hi = (p_hi - 1) + borrow_compl - r_hi -/// -/// Range checks: r_lo, r_hi, d_lo, d_hi (128-bit each) -pub fn add_mod_p( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - a: Limb2, - b: Limb2, - params: &CurveParams, -) -> Limb2 { - let two_128 = FieldElement::from(2u64).pow([128u64]); - let p_lo_fe = params.p_lo_fe(); - let p_hi_fe = params.p_hi_fe(); - let w1 = compiler.witness_one(); - - // Witness: q = floor((a + b) / p) ∈ {0, 1} - // ----------------------------------------------------------- - let q = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::WideAddQuotient { - output: q, - a_lo: a.lo, - a_hi: a.hi, - b_lo: b.lo, - b_hi: b.hi, - modulus: params.field_modulus_p, - }); - // constraining q to be boolean - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( - FieldElement::ONE, - q, - )]); - - // Computing r_lo: lower 128 bits of result - // ----------------------------------------------------------- - // v_offset = a_lo + b_lo + 2^128 - q * p_lo - // (2^128 offset ensures v_offset is always non-negative) - let v_offset = compiler.add_sum(vec![ - SumTerm(None, a.lo), - SumTerm(None, b.lo), - SumTerm(Some(two_128), w1), - SumTerm(Some(-p_lo_fe), q), - ]); - // computing carry_offset = floor(v_offset / 2^128) - let carry_offset = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - carry_offset, - v_offset, - two_128, - )); - // computing r_lo = v_offset - carry_offset * 2^128 - let r_lo = compiler.add_sum(vec![ - SumTerm(None, v_offset), - SumTerm(Some(-two_128), carry_offset), - ]); - - // Computing r_hi: upper 128 bits of result - // ----------------------------------------------------------- - // r_hi = a_hi + b_hi + carry_offset - 1 - q * p_hi - // (-1 compensates for the 2^128 offset added in the low column) - let r_hi = compiler.add_sum(vec![ - SumTerm(None, a.hi), - SumTerm(None, b.hi), - SumTerm(None, carry_offset), 
- SumTerm(Some(-FieldElement::ONE), w1), - SumTerm(Some(-p_hi_fe), q), - ]); - - less_than_p_check(compiler, range_checks, r_lo, r_hi, params); - - Limb2 { lo: r_lo, hi: r_hi } -} - -/// (a - b) mod p for 256-bit values in two 128-bit limbs. -/// -/// Equation: a - b + q * p = r, where q ∈ {0, 1}, 0 ≤ r < p. -/// q = 0 if a ≥ b (result is non-negative without correction) -/// q = 1 if a < b (add p to make result non-negative) -/// -/// Uses the offset trick to avoid negative intermediate values: -/// v_offset = a_lo - b_lo + q * p_lo + 2^128 (always ≥ 0) -/// carry_offset = floor(v_offset / 2^128) ∈ {0, 1, 2} -/// r_lo = v_offset - carry_offset * 2^128 -/// r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 -/// -/// Less-than-p check (proves r < p): -/// d_lo + d_hi * 2^128 = (p - 1) - r (all components ≥ 0) -/// -/// Constraints (7 total): -/// 1. q is boolean: q * q = q -/// 2-3. Column 0: v_offset defined, then r_lo = v_offset - carry_offset * -/// 2^128 -/// 4. Column 1: r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 -/// 5-6. LT check: v_diff defined, then d_lo = v_diff - borrow_compl * 2^128 -/// 7. LT check: d_hi = (p_hi - 1) + borrow_compl - r_hi -/// -/// Range checks: r_lo, r_hi, d_lo, d_hi (128-bit each) -pub fn sub_mod_p( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - a: Limb2, - b: Limb2, - params: &CurveParams, -) -> Limb2 { - let two_128 = FieldElement::from(2u64).pow([128u64]); - let p_lo_fe = params.p_lo_fe(); - let p_hi_fe = params.p_hi_fe(); - let w1 = compiler.witness_one(); - - // Witness: q = (a < b) ? 
1 : 0 - // ----------------------------------------------------------- - let q = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::WideSubBorrow { - output: q, - a_lo: a.lo, - a_hi: a.hi, - b_lo: b.lo, - b_hi: b.hi, - }); - // constraining q to be boolean - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( - FieldElement::ONE, - q, - )]); - - // Computing r_lo: lower 128 bits of result - // ----------------------------------------------------------- - // v_offset = a_lo - b_lo + q * p_lo + 2^128 - // (2^128 offset ensures v_offset is always non-negative) - let v_offset = compiler.add_sum(vec![ - SumTerm(None, a.lo), - SumTerm(Some(-FieldElement::ONE), b.lo), - SumTerm(Some(p_lo_fe), q), - SumTerm(Some(two_128), w1), - ]); - // computing carry_offset = floor(v_offset / 2^128) - let carry_offset = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - carry_offset, - v_offset, - two_128, - )); - // computing r_lo = v_offset - carry_offset * 2^128 - let r_lo = compiler.add_sum(vec![ - SumTerm(None, v_offset), - SumTerm(Some(-two_128), carry_offset), - ]); - - // Computing r_hi: upper 128 bits of result - // ----------------------------------------------------------- - // r_hi = a_hi - b_hi + q * p_hi + carry_offset - 1 - // (-1 compensates for the 2^128 offset added in the low column) - let r_hi = compiler.add_sum(vec![ - SumTerm(None, a.hi), - SumTerm(Some(-FieldElement::ONE), b.hi), - SumTerm(Some(p_hi_fe), q), - SumTerm(None, carry_offset), - SumTerm(Some(-FieldElement::ONE), w1), - ]); - - less_than_p_check(compiler, range_checks, r_lo, r_hi, params); - - Limb2 { lo: r_lo, hi: r_hi } -} - -/// (a × b) mod p for 256-bit values in two 128-bit limbs. 
-/// -/// Verifies the integer identity `a * b = p * q + r` using schoolbook -/// multiplication in base W = 2^86 (86-bit limbs ensure all column -/// products < 2^172 ≪ BN254_r ≈ 2^254, so field equations = integer equations). -/// -/// Three layers of verification: -/// 1. Decomposition links: prove 86-bit witnesses match the 128-bit -/// inputs/outputs -/// 2. Column equations: prove a86 * b86 = p86 * q86 + r86 (integer) -/// 3. Less-than-p check: prove r < p -/// -/// Witness layout (MulModHint, 20 witnesses at output_start): -/// [0..2) q_lo, q_hi — quotient 128-bit limbs (unconstrained) -/// [2..4) r_lo, r_hi — remainder 128-bit limbs (OUTPUT) -/// [4..7) a86_0..2 — a in 86-bit limbs -/// [7..10) b86_0..2 — b in 86-bit limbs -/// [10..13) q86_0..2 — q in 86-bit limbs -/// [13..16) r86_0..2 — r in 86-bit limbs -/// [16..20) c0u..c3u — unsigned-offset carries (c_signed + 2^88) -/// -/// Constraints (26 total): -/// 9 decomposition links (a, b, r × 3 each) -/// 9 product witnesses (a_i × b_j) -/// 5 column equations -/// 3 less-than-p check -/// -/// Range checks (23 total): -/// 128-bit: r_lo, r_hi, d_lo, d_hi -/// 86-bit: a86_0, a86_1, b86_0, b86_1, q86_0, q86_1, r86_0, r86_1 -/// 84-bit: a86_2, b86_2, q86_2, r86_2 -/// 89-bit: c0u, c1u, c2u, c3u -/// 44-bit: carry_a, carry_b, carry_r -pub fn mul_mod_p( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - a: Limb2, - b: Limb2, - params: &CurveParams, -) -> Limb2 { - let two_44 = FieldElement::from(2u64).pow([44u64]); - let two_86 = FieldElement::from(2u64).pow([86u64]); - let two_128 = FieldElement::from(2u64).pow([128u64]); - let offset_fe = FieldElement::from(2u64).pow([88u64]); // CARRY_OFFSET - let offset_w = FieldElement::from(2u64).pow([174u64]); // 2^88 * 2^86 - let offset_w_minus_1 = offset_w - offset_fe; // 2^88 * (2^86 - 1) - let [p0, p1, p2] = params.p_86_limbs(); - let w1 = compiler.witness_one(); - - // Step 1: Allocate MulModHint (20 witnesses) - // 
----------------------------------------------------------- - let os = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::MulModHint { - output_start: os, - a_lo: a.lo, - a_hi: a.hi, - b_lo: b.lo, - b_hi: b.hi, - modulus: params.field_modulus_p, - }); - - // Witness indices - let r_lo = os + 2; - let r_hi = os + 3; - let a86 = [os + 4, os + 5, os + 6]; - let b86 = [os + 7, os + 8, os + 9]; - let q86 = [os + 10, os + 11, os + 12]; - let r86 = [os + 13, os + 14, os + 15]; - let cu = [os + 16, os + 17, os + 18, os + 19]; - - // Step 2: Decomposition consistency for a, b, r - // ----------------------------------------------------------- - decompose_check( - compiler, - range_checks, - a.lo, - a.hi, - a86, - two_86, - two_44, - two_128, - w1, - ); - decompose_check( - compiler, - range_checks, - b.lo, - b.hi, - b86, - two_86, - two_44, - two_128, - w1, - ); - decompose_check( - compiler, - range_checks, - r_lo, - r_hi, - r86, - two_86, - two_44, - two_128, - w1, - ); - - // Step 3: Product witnesses (9 R1CS constraints) - // ----------------------------------------------------------- - let ab00 = compiler.add_product(a86[0], b86[0]); - let ab01 = compiler.add_product(a86[0], b86[1]); - let ab10 = compiler.add_product(a86[1], b86[0]); - let ab02 = compiler.add_product(a86[0], b86[2]); - let ab11 = compiler.add_product(a86[1], b86[1]); - let ab20 = compiler.add_product(a86[2], b86[0]); - let ab12 = compiler.add_product(a86[1], b86[2]); - let ab21 = compiler.add_product(a86[2], b86[1]); - let ab22 = compiler.add_product(a86[2], b86[2]); - - // Step 4: Column equations (5 R1CS constraints) - // ----------------------------------------------------------- - // Identity: a*b = p*q + r in base W=2^86. - // Carries stored with unsigned offset: cu_i = c_i + 2^88. 
- // - // col0: ab00 + 2^174 = p0*q0 + r0 + W*cu0 - // col1: ab01 + ab10 + cu0 + (2^174-2^88) = p0*q1 + p1*q0 + r1 + W*cu1 - // col2: ab02+ab11+ab20 + cu1 + (2^174-2^88) = p0*q2+p1*q1+p2*q0 + r2 + W*cu2 - // col3: ab12 + ab21 + cu2 + (2^174-2^88) = p1*q2 + p2*q1 + W*cu3 - // col4: ab22 + cu3 = p2*q2 + 2^88 - - // col0 - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, ab00), (offset_w, w1)], - &[(FieldElement::ONE, w1)], - &[(p0, q86[0]), (FieldElement::ONE, r86[0]), (two_86, cu[0])], - ); - - // col1 - compiler.r1cs.add_constraint( - &[ - (FieldElement::ONE, ab01), - (FieldElement::ONE, ab10), - (FieldElement::ONE, cu[0]), - (offset_w_minus_1, w1), - ], - &[(FieldElement::ONE, w1)], - &[ - (p0, q86[1]), - (p1, q86[0]), - (FieldElement::ONE, r86[1]), - (two_86, cu[1]), - ], - ); - - // col2 - compiler.r1cs.add_constraint( - &[ - (FieldElement::ONE, ab02), - (FieldElement::ONE, ab11), - (FieldElement::ONE, ab20), - (FieldElement::ONE, cu[1]), - (offset_w_minus_1, w1), - ], - &[(FieldElement::ONE, w1)], - &[ - (p0, q86[2]), - (p1, q86[1]), - (p2, q86[0]), - (FieldElement::ONE, r86[2]), - (two_86, cu[2]), - ], - ); - - // col3 - compiler.r1cs.add_constraint( - &[ - (FieldElement::ONE, ab12), - (FieldElement::ONE, ab21), - (FieldElement::ONE, cu[2]), - (offset_w_minus_1, w1), - ], - &[(FieldElement::ONE, w1)], - &[(p1, q86[2]), (p2, q86[1]), (two_86, cu[3])], - ); - - // col4 - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, ab22), (FieldElement::ONE, cu[3])], - &[(FieldElement::ONE, w1)], - &[(p2, q86[2]), (offset_fe, w1)], - ); - - // Step 5: Less-than-p check (r < p) + 128-bit range checks on r_lo, r_hi - // ----------------------------------------------------------- - less_than_p_check(compiler, range_checks, r_lo, r_hi, params); - - // Step 6: Range checks (mul-specific) - // ----------------------------------------------------------- - // 86-bit: limbs 0 and 1 of a, b, q, r - for &idx in &[ - a86[0], a86[1], b86[0], b86[1], q86[0], q86[1], r86[0], 
r86[1], - ] { - range_checks.entry(86).or_default().push(idx); - } - - // 84-bit: limb 2 of a, b, q, r (bits [172..256) = 84 bits) - for &idx in &[a86[2], b86[2], q86[2], r86[2]] { - range_checks.entry(84).or_default().push(idx); - } - - // 89-bit: unsigned-offset carries (|c_signed| < 2^88, so c_unsigned ∈ [0, - // 2^89)) - for &idx in &cu { - range_checks.entry(89).or_default().push(idx); - } - - Limb2 { lo: r_lo, hi: r_hi } -} - -/// a^(-1) mod p for 256-bit values in two 128-bit limbs. -/// -/// Hint-and-verify pattern: -/// 1. Prover computes inv = a^(p-2) mod p (Fermat's little theorem) -/// 2. Circuit verifies a * inv mod p = 1 -/// -/// Constraints: 26 from mul_mod_p + 2 equality checks = 28 total. -pub fn inv_mod_p( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - value: Limb2, - params: &CurveParams, -) -> Limb2 { - // Witness: inv = a^(-1) mod p (2 witnesses: lo, hi) - // ----------------------------------------------------------- - let value_inv = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::WideModularInverse { - output_start: value_inv, - a_lo: value.lo, - a_hi: value.hi, - modulus: params.field_modulus_p, - }); - let inv = Limb2 { - lo: value_inv, - hi: value_inv + 1, - }; - - // Verifying a * inv mod p = 1 - // ----------------------------------------------------------- - // computing product = value * inv mod p - let product = mul_mod_p(compiler, range_checks, value, inv, params); - // constraining product_lo = 1 (because 1 = 1 + 0 * 2^128) - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, product.lo)], - &[(FieldElement::ONE, compiler.witness_one())], - &[(FieldElement::ONE, compiler.witness_one())], - ); - // constraining product_hi = 0 - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, product.hi)], - &[(FieldElement::ONE, compiler.witness_one())], - &[], - ); - - inv -} - -/// Verify that 128-bit limbs (v_lo, v_hi) decompose into 86-bit limbs (v86). 
-/// -/// Equations: -/// v_lo = v86_0 + v86_1 * 2^86 - carry * 2^128 -/// v_hi = carry + v86_2 * 2^44 -/// -/// All intermediate values < 2^172 ≪ BN254_r, so field equations = integer -/// equations. -/// -/// Creates: 1 intermediate witness (v_sum), 1 carry witness (IntegerQuotient). -/// Adds: 3 R1CS constraints (v_sum definition + 2 decomposition checks). -/// Range checks: carry (44-bit). -/// Proves r < p by decomposing (p - 1) - r into non-negative 128-bit limbs. -/// -/// If d_lo, d_hi >= 0 then (p - 1) - r >= 0, i.e. r <= p - 1 < p. -/// Uses the 2^128 offset trick to avoid negative intermediate values. -/// -/// Range checks r_lo, r_hi, d_lo, d_hi (128-bit each). -fn less_than_p_check( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - r_lo: usize, - r_hi: usize, - params: &CurveParams, -) { - let two_128 = FieldElement::from(2u64).pow([128u64]); - let p_lo_fe = params.p_lo_fe(); - let p_hi_fe = params.p_hi_fe(); - let w1 = compiler.witness_one(); - - // v_diff = (p_lo - 1) + 2^128 - r_lo - // (2^128 offset ensures v_diff is always non-negative) - let p_lo_minus_1_plus_offset = p_lo_fe - FieldElement::ONE + two_128; - let v_diff = compiler.add_sum(vec![ - SumTerm(Some(p_lo_minus_1_plus_offset), w1), - SumTerm(Some(-FieldElement::ONE), r_lo), - ]); - // borrow_compl = floor(v_diff / 2^128) - let borrow_compl = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - borrow_compl, - v_diff, - two_128, - )); - // d_lo = v_diff - borrow_compl * 2^128 - let d_lo = compiler.add_sum(vec![ - SumTerm(None, v_diff), - SumTerm(Some(-two_128), borrow_compl), - ]); - // d_hi = (p_hi - 1) + borrow_compl - r_hi - let d_hi = compiler.add_sum(vec![ - SumTerm(Some(p_hi_fe - FieldElement::ONE), w1), - SumTerm(None, borrow_compl), - SumTerm(Some(-FieldElement::ONE), r_hi), - ]); - - // Range checks (128-bit) - range_checks.entry(128).or_default().push(r_lo); - range_checks.entry(128).or_default().push(r_hi); - 
range_checks.entry(128).or_default().push(d_lo); - range_checks.entry(128).or_default().push(d_hi); -} - -fn decompose_check( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - v_lo: usize, - v_hi: usize, - v86: [usize; 3], - two_86: FieldElement, - two_44: FieldElement, - two_128: FieldElement, - w1: usize, -) { - // v_sum = v86_0 + v86_1 * 2^86 (intermediate for IntegerQuotient) - let v_sum = compiler.add_sum(vec![SumTerm(None, v86[0]), SumTerm(Some(two_86), v86[1])]); - - // carry = floor(v_sum / 2^128) ∈ [0, 2^44) - let carry = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_sum, two_128)); - - // Low check: v_sum - carry * 2^128 = v_lo - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, v_sum), (-two_128, carry)], - &[(FieldElement::ONE, w1)], - &[(FieldElement::ONE, v_lo)], - ); - - // High check: carry + v86_2 * 2^44 = v_hi - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, carry), (two_44, v86[2])], - &[(FieldElement::ONE, w1)], - &[(FieldElement::ONE, v_hi)], - ); - - // Range check carry (44-bit) - range_checks.entry(44).or_default().push(carry); -} diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs b/provekit/r1cs-compiler/src/noir_to_r1cs.rs index 2d4245636..18bc22ddc 100644 --- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs +++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs @@ -3,6 +3,7 @@ use { binops::add_combined_binop_constraints, digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, memory::{add_ram_checking, add_rom_checking, MemoryBlock, MemoryOperation}, + msm::add_msm, poseidon2::add_poseidon2_permutation, range_check::add_range_checks, sha256_compression::add_sha256_compression, @@ -16,7 +17,6 @@ use { Circuit, Opcode, }, native_types::{Expression, Witness as NoirWitness}, - BlackBoxFunc, }, anyhow::{bail, Result}, ark_ff::PrimeField, @@ -89,6 +89,11 @@ pub struct R1CSBreakdown { pub poseidon2_constraints: usize, /// Witnesses from Poseidon2 
permutation pub poseidon2_witnesses: usize, + + /// Constraints from multi-scalar multiplication + pub msm_constraints: usize, + /// Witnesses from multi-scalar multiplication + pub msm_witnesses: usize, } /// Compiles an ACIR circuit into an [R1CS] instance, comprising of the A, B, @@ -458,6 +463,7 @@ impl NoirToR1CSCompiler { let mut xor_ops = vec![]; let mut sha256_compression_ops = vec![]; let mut poseidon2_ops = vec![]; + let mut msm_ops = vec![]; let mut breakdown = R1CSBreakdown::default(); @@ -632,7 +638,20 @@ impl NoirToR1CSCompiler { points, scalars, outputs, - } => {} + } => { + let point_wits: Vec = points + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let scalar_wits: Vec = scalars + .iter() + .map(|inp| self.fetch_constant_or_r1cs_witness(inp.input())) + .collect(); + let out_x = self.fetch_r1cs_witness_index(outputs.0); + let out_y = self.fetch_r1cs_witness_index(outputs.1); + let out_inf = self.fetch_r1cs_witness_index(outputs.2); + msm_ops.push((point_wits, scalar_wits, (out_x, out_y, out_inf))); + } _ => { unimplemented!("Other black box function: {:?}", black_box_func_call); } @@ -724,6 +743,22 @@ impl NoirToR1CSCompiler { breakdown.poseidon2_constraints = self.r1cs.num_constraints() - constraints_before_poseidon; breakdown.poseidon2_witnesses = self.num_witnesses() - witnesses_before_poseidon; + let constraints_before_msm = self.r1cs.num_constraints(); + let witnesses_before_msm = self.num_witnesses(); + // Cost model: pick optimal (limb_bits, window_size) for MSM + let curve = crate::msm::curve::grumpkin_params(); + let native_bits = FieldElement::MODULUS_BIT_SIZE; + let curve_bits = curve.modulus_bits(); + let (msm_limb_bits, msm_window_size) = if !msm_ops.is_empty() { + let n_points: usize = msm_ops.iter().map(|(pts, _, _)| pts.len() / 3).sum(); + crate::msm::cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256) + } else { + (native_bits, 4) + }; + add_msm(self, msm_ops, 
msm_limb_bits, msm_window_size, &mut range_checks, &curve); + breakdown.msm_constraints = self.r1cs.num_constraints() - constraints_before_msm; + breakdown.msm_witnesses = self.num_witnesses() - witnesses_before_msm; + breakdown.range_ops_total = range_checks.values().map(|v| v.len()).sum(); let constraints_before_range = self.r1cs.num_constraints(); let witnesses_before_range = self.num_witnesses(); From 0f769a767a39fb4f2efe9b906de0ded2ef6dfd81 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 5 Mar 2026 04:48:33 +0530 Subject: [PATCH 04/19] feat : added gnark optimisations for msm --- .../src/witness/scheduling/dependency.rs | 22 +- .../common/src/witness/scheduling/remapper.rs | 28 + .../common/src/witness/witness_builder.rs | 37 +- provekit/prover/src/lib.rs | 15 +- provekit/prover/src/witness/bigint_mod.rs | 392 ++++++++++++- .../prover/src/witness/witness_builder.rs | 66 ++- provekit/r1cs-compiler/src/digits.rs | 4 +- provekit/r1cs-compiler/src/msm/cost_model.rs | 221 ++++---- provekit/r1cs-compiler/src/msm/curve.rs | 488 +++++++++++++++- provekit/r1cs-compiler/src/msm/ec_points.rs | 185 +++--- provekit/r1cs-compiler/src/msm/mod.rs | 529 +++++++++++++++--- .../r1cs-compiler/src/msm/multi_limb_arith.rs | 163 +++--- .../r1cs-compiler/src/msm/multi_limb_ops.rs | 115 ++-- provekit/r1cs-compiler/src/noir_to_r1cs.rs | 11 +- tooling/provekit-bench/tests/compiler.rs | 1 + 15 files changed, 1796 insertions(+), 481 deletions(-) diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index 9f92afd75..98ae1368f 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -156,9 +156,7 @@ impl DependencyInfo { v } WitnessBuilder::MultiLimbMulModHint { - a_limbs, - b_limbs, - .. + a_limbs, b_limbs, .. 
} => { let mut v = a_limbs.clone(); v.extend(b_limbs); @@ -166,18 +164,14 @@ impl DependencyInfo { } WitnessBuilder::MultiLimbModularInverse { a_limbs, .. } => a_limbs.clone(), WitnessBuilder::MultiLimbAddQuotient { - a_limbs, - b_limbs, - .. + a_limbs, b_limbs, .. } => { let mut v = a_limbs.clone(); v.extend(b_limbs); v } WitnessBuilder::MultiLimbSubBorrow { - a_limbs, - b_limbs, - .. + a_limbs, b_limbs, .. } => { let mut v = a_limbs.clone(); v.extend(b_limbs); @@ -229,6 +223,10 @@ impl DependencyInfo { data.rs_cubed, ] } + WitnessBuilder::FakeGLVHint { s_lo, s_hi, .. } => vec![*s_lo, *s_hi], + WitnessBuilder::EcScalarMulHint { + px, py, s_lo, s_hi, .. + } => vec![*px, *py, *s_lo, *s_hi], WitnessBuilder::ChunkDecompose { packed, .. } => vec![*packed], WitnessBuilder::SpreadWitness(_, input) => vec![*input], WitnessBuilder::SpreadBitExtract { sum_terms, .. } => { @@ -329,6 +327,12 @@ impl DependencyInfo { num_limbs, .. } => (*output_start..*output_start + *num_limbs as usize).collect(), + WitnessBuilder::FakeGLVHint { + output_start, .. + } => (*output_start..*output_start + 4).collect(), + WitnessBuilder::EcScalarMulHint { + output_start, .. + } => (*output_start..*output_start + 2).collect(), WitnessBuilder::MultiLimbAddQuotient { output, .. } => vec![*output], WitnessBuilder::MultiLimbSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) 
=> { diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 334b5f401..696144113 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -366,6 +366,34 @@ impl WitnessIndexRemapper { }, ) } + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => WitnessBuilder::FakeGLVHint { + output_start: self.remap(*output_start), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + curve_order: *curve_order, + }, + WitnessBuilder::EcScalarMulHint { + output_start, + px, + py, + s_lo, + s_hi, + curve_a, + field_modulus_p, + } => WitnessBuilder::EcScalarMulHint { + output_start: self.remap(*output_start), + px: self.remap(*px), + py: self.remap(*py), + s_lo: self.remap(*s_lo), + s_hi: self.remap(*s_hi), + curve_a: *curve_a, + field_modulus_p: *field_modulus_p, + }, WitnessBuilder::ChunkDecompose { output_start, packed, diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 28d6d775c..2b7cd2f30 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -270,6 +270,37 @@ pub enum WitnessBuilder { packed: usize, chunk_bits: Vec, }, + /// Prover hint for FakeGLV scalar decomposition. + /// Given scalar s (from s_lo + s_hi * 2^128) and curve order n, + /// computes half_gcd(s, n) → (|s1|, |s2|, neg1, neg2) such that: + /// (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n) + /// + /// Outputs 4 witnesses starting at output_start: + /// [0] |s1| (128-bit field element) + /// [1] |s2| (128-bit field element) + /// [2] neg1 (boolean: 0 or 1) + /// [3] neg2 (boolean: 0 or 1) + FakeGLVHint { + output_start: usize, + s_lo: usize, + s_hi: usize, + curve_order: [u64; 4], + }, + /// Prover hint for EC scalar multiplication: computes R = [s]P. 
+ /// Given point P = (px, py) and scalar s = s_lo + s_hi * 2^128, + /// computes R = [s]P on the curve with parameter `curve_a` and + /// field modulus `field_modulus_p`. + /// + /// Outputs 2 witnesses at output_start: R_x, R_y. + EcScalarMulHint { + output_start: usize, + px: usize, + py: usize, + s_lo: usize, + s_hi: usize, + curve_a: [u64; 4], + field_modulus_p: [u64; 4], + }, /// Computes spread(input): interleave bits with zeros. /// Output: 0 b_{n-1} 0 b_{n-2} ... 0 b_1 0 b_0 /// (witness index of output, witness index of input) @@ -332,10 +363,10 @@ impl WitnessBuilder { WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, - WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => { - (4 * *num_limbs - 2) as usize - } + WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => (4 * *num_limbs - 2) as usize, WitnessBuilder::MultiLimbModularInverse { num_limbs, .. } => *num_limbs as usize, + WitnessBuilder::FakeGLVHint { .. } => 4, + WitnessBuilder::EcScalarMulHint { .. 
} => 2, _ => 1, } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 7aafa0d3c..b78fea288 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -206,7 +206,9 @@ impl Prove for NoirProver { let hc = debug_r1cs.c.hydrate(interner); let mut fail_count = 0usize; for row in 0..debug_r1cs.num_constraints() { - let eval = |hm: &provekit_common::sparse_matrix::HydratedSparseMatrix, r: usize| -> FieldElement { + let eval = |hm: &provekit_common::sparse_matrix::HydratedSparseMatrix, + r: usize| + -> FieldElement { let mut sum = FieldElement::zero(); for (col, coeff) in hm.iter_row(r) { sum += coeff * full_witness[col]; @@ -220,7 +222,11 @@ impl Prove for NoirProver { if fail_count < 10 { eprintln!( "CONSTRAINT {} FAILED: A={:?} B={:?} C={:?} A*B={:?}", - row, a_val, b_val, c_val, a_val * b_val + row, + a_val, + b_val, + c_val, + a_val * b_val ); eprint!(" A terms:"); for (col, coeff) in ha.iter_row(row) { @@ -242,7 +248,10 @@ impl Prove for NoirProver { } } if fail_count > 0 { - eprintln!("TOTAL FAILING CONSTRAINTS: {fail_count} / {}", debug_r1cs.num_constraints()); + eprintln!( + "TOTAL FAILING CONSTRAINTS: {fail_count} / {}", + debug_r1cs.num_constraints() + ); } else { eprintln!("ALL {} CONSTRAINTS SATISFIED", debug_r1cs.num_constraints()); } diff --git a/provekit/prover/src/witness/bigint_mod.rs b/provekit/prover/src/witness/bigint_mod.rs index a41f47ff3..2874d49a3 100644 --- a/provekit/prover/src/witness/bigint_mod.rs +++ b/provekit/prover/src/witness/bigint_mod.rs @@ -217,9 +217,317 @@ pub fn add_4limb(a: &[u64; 4], b: &[u64; 4]) -> [u64; 5] { result } -/// Offset added to signed carries to make them non-negative for range checking. -/// Carries are bounded by |c| < 2^88, so adding 2^88 ensures c_unsigned >= 0. -pub const CARRY_OFFSET: u128 = 1u128 << 88; +/// Add two 4-limb numbers in-place: a += b. Returns the carry-out. 
+pub fn add_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) -> u64 { + let mut carry = 0u64; + for i in 0..4 { + let (s1, c1) = a[i].overflowing_add(b[i]); + let (s2, c2) = s1.overflowing_add(carry); + a[i] = s2; + carry = (c1 as u64) + (c2 as u64); + } + carry +} + +/// Subtract b from a in-place, returning true if a >= b (no underflow). +/// If a < b, the result is a += 2^256 - b (wrapping subtraction) and returns false. +pub fn sub_4limb_checked(a: &mut [u64; 4], b: &[u64; 4]) -> bool { + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + a[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } + borrow == 0 +} + +/// Returns true if val == 0. +pub fn is_zero(val: &[u64; 4]) -> bool { + val[0] == 0 && val[1] == 0 && val[2] == 0 && val[3] == 0 +} + +/// Compute the number of bits needed for the half-GCD sub-scalars. +/// Returns `ceil(order_bits / 2)` where `order_bits` is the bit length of `n`. +pub fn half_gcd_bits(n: &[u64; 4]) -> u32 { + let mut order_bits = 0u32; + for i in (0..4).rev() { + if n[i] != 0 { + order_bits = (i as u32) * 64 + (64 - n[i].leading_zeros()); + break; + } + } + (order_bits + 1) / 2 +} + +/// Build the threshold value `2^half_bits` as a `[u64; 4]`. +fn build_threshold(half_bits: u32) -> [u64; 4] { + assert!(half_bits <= 255, "half_bits must be <= 255"); + let mut threshold = [0u64; 4]; + let word = (half_bits / 64) as usize; + let bit = half_bits % 64; + threshold[word] = 1u64 << bit; + threshold +} + +/// Half-GCD scalar decomposition for FakeGLV. +/// +/// Given scalar `s` and curve order `n`, finds `(|s1|, |s2|, neg1, neg2)` such that: +/// `(-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n)` +/// +/// Uses the extended GCD on `(n, s)`, stopping when the remainder drops below +/// `2^half_bits` where `half_bits = ceil(order_bits / 2)`. +/// Returns `(val1, val2, neg1, neg2)` where both fit in `half_bits` bits. 
+pub fn half_gcd( + s: &[u64; 4], + n: &[u64; 4], +) -> ([u64; 4], [u64; 4], bool, bool) { + // Extended GCD on (n, s): + // We track: r_{i} = r_{i-2} - q_i * r_{i-1} + // t_{i} = t_{i-2} - q_i * t_{i-1} + // Starting: r_0 = n, r_1 = s, t_0 = 0, t_1 = 1 + // + // We want: t_i * s ≡ r_i (mod n) [up to sign] + // More precisely: t_i * s ≡ (-1)^{i+1} * r_i (mod n) + // + // The relation we verify is: sign_r * |r_i| + sign_t * |t_i| * s ≡ 0 (mod n) + + // Threshold: 2^half_bits where half_bits = ceil(order_bits / 2) + let half_bits = half_gcd_bits(n); + let threshold = build_threshold(half_bits); + + // r_prev = n, r_curr = s + let mut r_prev = *n; + let mut r_curr = *s; + + // t_prev = 0, t_curr = 1 + let mut t_prev = [0u64; 4]; + let mut t_curr = [1u64, 0, 0, 0]; + + // Track sign of t: t_prev_neg=false (t_0=0, positive), t_curr_neg=false (t_1=1, positive) + let mut t_prev_neg = false; + let mut t_curr_neg = false; + + let mut iteration = 0u32; + + loop { + // Check if r_curr < threshold + if cmp_4limb(&r_curr, &threshold) == std::cmp::Ordering::Less { + break; + } + + if is_zero(&r_curr) { + break; + } + + // q = r_prev / r_curr, new_r = r_prev % r_curr + let (q, new_r) = divmod(&r_prev, &r_curr); + + // new_t = t_prev + q * t_curr (in terms of absolute values and signs) + // Since the GCD recurrence is: t_{i} = t_{i-2} - q_i * t_{i-1} + // In terms of absolute values with sign tracking: + // If t_prev and q*t_curr have the same sign → subtract magnitudes + // If they have different signs → add magnitudes + // But actually: new_t = |t_prev| +/- q * |t_curr|, with sign flips each iteration. + // + // The standard extended GCD recurrence gives: + // t_i = t_{i-2} - q_i * t_{i-1} + // We track magnitudes and sign bits separately. + + // Compute q * t_curr + let qt = mul_mod_no_reduce(&q, &t_curr); + + // new_t magnitude and sign: + // In the standard recurrence: new_t_val = t_prev_val - q * t_curr_val + // where t_prev_val = (-1)^t_prev_neg * |t_prev|, etc. 
+ //
+ // But it's simpler to just track: alternating signs.
+ // In the half-GCD: t values alternate in sign. So:
+ // new_t = t_prev + q * t_curr (absolute addition since signs alternate)
+ let mut new_t = qt;
+ add_4limb_inplace(&mut new_t, &t_prev);
+ let new_t_neg = !t_curr_neg;
+
+ r_prev = r_curr;
+ r_curr = new_r;
+ t_prev = t_curr;
+ t_prev_neg = t_curr_neg;
+ t_curr = new_t;
+ t_curr_neg = new_t_neg;
+ iteration += 1;
+ }
+
+ // At this point: r_curr < 2^half_bits and t_curr < ~2^half_bits (half-GCD property)
+ // The relation is: (-1)^(iteration) * r_curr + t_curr * s ≡ 0 (mod n)
+ // Or equivalently: r_curr ≡ (-1)^(iteration+1) * t_curr * s (mod n)
+
+ let val1 = r_curr; // |s1| = |r_i|
+ let val2 = t_curr; // |s2| = |t_i|
+
+ // Determine signs:
+ // We need: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n)
+ // From the extended GCD: r_i = (-1)^i * (... some relation with t_i * s mod n)
+ // The exact sign relationship:
+ // t_i * s ≡ (-1)^(i+1) * r_i (mod n)
+ // So: (-1)^(i+1) * r_i + t_i * s ≡ 0 (mod n)
+ //
+ // If iteration is even: (-1)^(even+1) = -1, so: -r_i + t_i * s ≡ 0
+ // → neg1=true (negate r_i), neg2=t_curr_neg
+ // If iteration is odd: (-1)^(odd+1) = 1, so: r_i + t_i * s ≡ 0
+ // → neg1=false, neg2=t_curr_neg
+
+ let neg1 = iteration % 2 == 0; // negate val1 when iteration is even
+ let neg2 = t_curr_neg;
+
+ (val1, val2, neg1, neg2)
+}
+
+/// Multiply two 4-limb values without modular reduction.
+/// Returns the lower 4 limbs (ignoring overflow beyond 256 bits).
+/// Used internally by half_gcd for q * t_curr where the result is known to fit.
+fn mul_mod_no_reduce(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] {
+    let wide = widening_mul(a, b);
+    [wide[0], wide[1], wide[2], wide[3]]
+}
+
+// ---------------------------------------------------------------------------
+// Modular arithmetic helpers for EC operations (prover-side)
+// ---------------------------------------------------------------------------
+
+/// Modular addition: (a + b) mod p. 
+pub fn mod_add(a: &[u64; 4], b: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let sum = add_4limb(a, b); + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + if sum[4] > 0 || cmp_4limb(&sum4, p) != std::cmp::Ordering::Less { + // sum >= p, subtract p + let mut result = sum4; + sub_4limb_inplace(&mut result, p); + result + } else { + sum4 + } +} + +/// Modular subtraction: (a - b) mod p. +pub fn mod_sub(a: &[u64; 4], b: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let mut result = *a; + let no_borrow = sub_4limb_checked(&mut result, b); + if no_borrow { + result + } else { + // a < b, add p to get (a - b + p) + add_4limb_inplace(&mut result, p); + result + } +} + +/// Modular inverse: a^{p-2} mod p (Fermat's little theorem). +pub fn mod_inverse(a: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { + let exp = sub_u64(p, 2); + mod_pow(a, &exp, p) +} + +/// EC point doubling in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = 2*(px, py). +pub fn ec_point_double( + px: &[u64; 4], + py: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // lambda = (3*x^2 + a) / (2*y) + let x_sq = mul_mod(px, px, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let numerator = mod_add(&three_x_sq, a, p); + let two_y = mod_add(py, py, p); + let denom_inv = mod_inverse(&two_y, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + // x3 = lambda^2 - 2*x + let lambda_sq = mul_mod(&lambda, &lambda, p); + let two_x = mod_add(px, px, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + // y3 = lambda * (x - x3) - y + let x_minus_x3 = mod_sub(px, &x3, p); + let lambda_dx = mul_mod(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, py, p); + + (x3, y3) +} + +/// EC point addition in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = (p1x, p1y) + (p2x, p2y). Requires p1x != p2x. 
+pub fn ec_point_add( + p1x: &[u64; 4], + p1y: &[u64; 4], + p2x: &[u64; 4], + p2y: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // lambda = (y2 - y1) / (x2 - x1) + let numerator = mod_sub(p2y, p1y, p); + let denominator = mod_sub(p2x, p1x, p); + let denom_inv = mod_inverse(&denominator, p); + let lambda = mul_mod(&numerator, &denom_inv, p); + + // x3 = lambda^2 - x1 - x2 + let lambda_sq = mul_mod(&lambda, &lambda, p); + let x1_plus_x2 = mod_add(p1x, p2x, p); + let x3 = mod_sub(&lambda_sq, &x1_plus_x2, p); + + // y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = mod_sub(p1x, &x3, p); + let lambda_dx = mul_mod(&lambda, &x1_minus_x3, p); + let y3 = mod_sub(&lambda_dx, p1y, p); + + (x3, y3) +} + +/// EC scalar multiplication via double-and-add: returns [scalar]*P. +pub fn ec_scalar_mul( + px: &[u64; 4], + py: &[u64; 4], + scalar: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + // Find highest set bit in scalar + let mut highest_bit = 0; + for i in (0..4).rev() { + if scalar[i] != 0 { + highest_bit = i * 64 + (64 - scalar[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + // scalar == 0 → point at infinity (not representable in affine) + panic!("ec_scalar_mul: scalar is zero"); + } + + // Start from the MSB-1 and double-and-add + let mut rx = *px; + let mut ry = *py; + + for bit_pos in (0..highest_bit - 1).rev() { + // Double + let (dx, dy) = ec_point_double(&rx, &ry, a, p); + rx = dx; + ry = dy; + + // Add if bit is set + let limb_idx = bit_pos / 64; + let bit_idx = bit_pos % 64; + if (scalar[limb_idx] >> bit_idx) & 1 == 1 { + let (ax, ay) = ec_point_add(&rx, &ry, px, py, p); + rx = ax; + ry = ay; + } + } + + (rx, ry) +} /// Integer division of a 512-bit dividend by a 256-bit divisor. /// Returns (quotient, remainder) where both fit in 256 bits. 
@@ -848,4 +1156,82 @@ mod tests { r0, r1, r2, ]); } + + #[test] + fn test_half_gcd_small() { + // s = 42, n = 101 + let s = [42, 0, 0, 0]; + let n = [101, 0, 0, 0]; + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + let sign1: i128 = if neg1 { -1 } else { 1 }; + let sign2: i128 = if neg2 { -1 } else { 1 }; + let v1 = val1[0] as i128; + let v2 = val2[0] as i128; + let s_val = s[0] as i128; + let n_val = n[0] as i128; + let lhs = ((sign1 * v1 + sign2 * v2 * s_val) % n_val + n_val) % n_val; + assert_eq!(lhs, 0, "half_gcd relation failed for small values"); + } + + #[test] + fn test_half_gcd_grumpkin_order() { + // Grumpkin curve order (BN254 base field order) + let n = [ + 0x3c208c16d87cfd47_u64, + 0x97816a916871ca8d_u64, + 0xb85045b68181585d_u64, + 0x30644e72e131a029_u64, + ]; + // Some scalar + let s = [ + 0x123456789abcdef0_u64, + 0xfedcba9876543210_u64, + 0x1111111111111111_u64, + 0x2222222222222222_u64, + ]; + + let (val1, val2, neg1, neg2) = half_gcd(&s, &n); + + // val1 and val2 should be < 2^128 + assert_eq!(val1[2], 0, "val1 should be < 2^128"); + assert_eq!(val1[3], 0, "val1 should be < 2^128"); + assert_eq!(val2[2], 0, "val2 should be < 2^128"); + assert_eq!(val2[3], 0, "val2 should be < 2^128"); + + // Verify: (-1)^neg1 * val1 + (-1)^neg2 * val2 * s ≡ 0 (mod n) + // Use big integer arithmetic + let term2_full = widening_mul(&val2, &s); + let (_, term2_mod_n) = divmod_wide(&term2_full, &n); + + // Compute: sign1 * val1 + sign2 * term2_mod_n (mod n) + let effective1 = if neg1 { + // n - val1 + let mut result = n; + sub_4limb_checked(&mut result, &val1); + result + } else { + val1 + }; + let effective2 = if neg2 { + let mut result = n; + sub_4limb_checked(&mut result, &term2_mod_n); + result + } else { + term2_mod_n + }; + + let sum = add_4limb(&effective1, &effective2); + let sum4 = [sum[0], sum[1], sum[2], sum[3]]; + // sum might be >= n, so reduce + let (_, remainder) = if sum[4] > 0 { + 
// Sum overflows 256 bits, need wide divmod + let wide = [sum[0], sum[1], sum[2], sum[3], sum[4], 0, 0, 0]; + divmod_wide(&wide, &n) + } else { + divmod(&sum4, &n) + }; + assert_eq!(remainder, [0, 0, 0, 0], "half_gcd relation failed for Grumpkin order"); + } } diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index d3479331b..ec3d61b06 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -449,8 +449,7 @@ impl WitnessBuilderSolver for WitnessBuilder { for i in 0..n { let j = k as isize - i as isize; if j >= 0 && (j as usize) < n { - ab_sum += - a_limbs_vals[i] as i128 * b_limbs_vals[j as usize] as i128; + ab_sum += a_limbs_vals[i] as i128 * b_limbs_vals[j as usize] as i128; } } // Sum p[i]*q[j] for i+j=k @@ -458,8 +457,7 @@ impl WitnessBuilderSolver for WitnessBuilder { for i in 0..n { let j = k as isize - i as isize; if j >= 0 && (j as usize) < n { - pq_sum += - p_limbs_vals[i] as i128 * q_limbs_vals[j as usize] as i128; + pq_sum += p_limbs_vals[i] as i128 * q_limbs_vals[j as usize] as i128; } } let r_k = if k < n { r_limbs_vals[k] as i128 } else { 0 }; @@ -663,6 +661,66 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*lo] = Some(FieldElement::from(lo_val)); witness[*hi] = Some(FieldElement::from(hi_val)); } + WitnessBuilder::FakeGLVHint { + output_start, + s_lo, + s_hi, + curve_order, + } => { + // Reconstruct s = s_lo + s_hi * 2^128 + let s_lo_val = witness[*s_lo].unwrap().into_bigint().0; + let s_hi_val = witness[*s_hi].unwrap().into_bigint().0; + let s_val: [u64; 4] = [ + s_lo_val[0], + s_lo_val[1], + s_hi_val[0], + s_hi_val[1], + ]; + + let (val1, val2, neg1, neg2) = + crate::witness::bigint_mod::half_gcd(&s_val, curve_order); + + witness[*output_start] = + Some(FieldElement::from_bigint(ark_ff::BigInt(val1)).unwrap()); + witness[*output_start + 1] = + Some(FieldElement::from_bigint(ark_ff::BigInt(val2)).unwrap()); + 
witness[*output_start + 2] = + Some(FieldElement::from(neg1 as u64)); + witness[*output_start + 3] = + Some(FieldElement::from(neg2 as u64)); + } + WitnessBuilder::EcScalarMulHint { + output_start, + px, + py, + s_lo, + s_hi, + curve_a, + field_modulus_p, + } => { + // Reconstruct scalar s = s_lo + s_hi * 2^128 + let s_lo_val = witness[*s_lo].unwrap().into_bigint().0; + let s_hi_val = witness[*s_hi].unwrap().into_bigint().0; + let scalar: [u64; 4] = [s_lo_val[0], s_lo_val[1], s_hi_val[0], s_hi_val[1]]; + + // Reconstruct point P + let px_val = witness[*px].unwrap().into_bigint().0; + let py_val = witness[*py].unwrap().into_bigint().0; + + // Compute R = [s]P + let (rx, ry) = crate::witness::bigint_mod::ec_scalar_mul( + &px_val, + &py_val, + &scalar, + curve_a, + field_modulus_p, + ); + + witness[*output_start] = + Some(FieldElement::from_bigint(ark_ff::BigInt(rx)).unwrap()); + witness[*output_start + 1] = + Some(FieldElement::from_bigint(ark_ff::BigInt(ry)).unwrap()); + } WitnessBuilder::CombinedTableEntryInverse(..) 
=> { unreachable!( "CombinedTableEntryInverse should not be called - handled by batch inversion" diff --git a/provekit/r1cs-compiler/src/digits.rs b/provekit/r1cs-compiler/src/digits.rs index 91c4e4128..3d8917d55 100644 --- a/provekit/r1cs-compiler/src/digits.rs +++ b/provekit/r1cs-compiler/src/digits.rs @@ -1,5 +1,6 @@ use { crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::Field, ark_std::One, provekit_common::{ witness::{DigitalDecompositionWitnesses, WitnessBuilder}, @@ -66,7 +67,8 @@ pub(crate) fn add_digital_decomposition( // Add the constraints for the digital recomposition let mut digit_multipliers = vec![FieldElement::one()]; for log_base in log_bases[..log_bases.len() - 1].iter() { - let multiplier = *digit_multipliers.last().unwrap() * FieldElement::from(1u64 << *log_base); + let multiplier = *digit_multipliers.last().unwrap() + * FieldElement::from(2u64).pow([*log_base as u64]); digit_multipliers.push(multiplier); } dd_struct diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index 234623a31..b896dc043 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -1,7 +1,8 @@ //! Analytical cost model for MSM parameter optimization. //! //! Follows the SHA256 pattern (`spread.rs:get_optimal_spread_width`): -//! pure analytical estimator → exhaustive search → pick optimal (limb_bits, window_size). +//! pure analytical estimator → exhaustive search → pick optimal (limb_bits, +//! window_size). /// Type of field operation for cost estimation. #[derive(Clone, Copy)] @@ -12,105 +13,75 @@ pub enum FieldOpType { Inv, } -/// Count field ops in scalar_mul for given parameters. -/// Traces through ec_points::scalar_mul logic analytically. +/// Count field ops in scalar_mul_glv for given parameters. /// -/// Returns (n_add, n_sub, n_mul, n_inv) per single scalar multiplication. 
-fn count_scalar_mul_field_ops(scalar_bits: usize, window_size: usize) -> (usize, usize, usize, usize) { +/// The GLV approach does interleaved two-point scalar mul with half-width scalars. +/// Per window: w shared doubles + 2 table lookups + 2 point_adds + 2 is_zero + 2 point_selects +/// Plus: 2 table builds, on-curve check, scalar relation overhead. +fn count_glv_field_ops( + scalar_bits: usize, // half_bits = ceil(order_bits / 2) + window_size: usize, +) -> (usize, usize, usize, usize) { let w = window_size; let table_size = 1 << w; let num_windows = (scalar_bits + w - 1) / w; - // Build point table: T[0]=P (free), T[1]=P (free), T[2]=2P (1 double), - // T[3..table_size] = point_add each + let double_ops = (4usize, 2usize, 5usize, 1usize); + let add_ops = (2usize, 2usize, 3usize, 1usize); + let select_ops_per_point = (2usize, 2usize, 2usize, 0usize); + + // Two tables (one for P, one for R) let table_doubles = if table_size > 2 { 1 } else { 0 }; let table_adds = if table_size > 2 { table_size - 3 } else { 0 }; - // point_double costs: 5 mul, 4 add, 2 sub, 1 inv - let double_ops = (4usize, 2usize, 5usize, 1usize); // (add, sub, mul, inv) - // point_add costs: 2 add, 2 sub, 3 mul, 1 inv - let add_ops = (2usize, 2usize, 3usize, 1usize); + let mut total_add = 2 * (table_doubles * double_ops.0 + table_adds * add_ops.0); + let mut total_sub = 2 * (table_doubles * double_ops.1 + table_adds * add_ops.1); + let mut total_mul = 2 * (table_doubles * double_ops.2 + table_adds * add_ops.2); + let mut total_inv = 2 * (table_doubles * double_ops.3 + table_adds * add_ops.3); + + for win_idx in (0..num_windows).rev() { + let bit_start = win_idx * w; + let bit_end = std::cmp::min(bit_start + w, scalar_bits); + let actual_w = bit_end - bit_start; + let actual_selects = (1 << actual_w) - 1; - // Table construction - let mut total_add = table_doubles * double_ops.0 + table_adds * add_ops.0; - let mut total_sub = table_doubles * double_ops.1 + table_adds * add_ops.1; - let mut 
total_mul = table_doubles * double_ops.2 + table_adds * add_ops.2; - let mut total_inv = table_doubles * double_ops.3 + table_adds * add_ops.3; - - // Table lookups: each uses (2^w - 1) point_selects - // point_select = 2 selects = 2 * (3 witnesses: diff, flag*diff, out) per coordinate - // But select is not a field op — it's cheaper (just `select` calls) - // We count it as 2 selects per point_select = 2 sub + 2 mul per select - // Actually select = flag*(on_true - on_false) + on_false: 1 sub, 1 mul, 1 add per elem - // Per point (x,y): 2 sub, 2 mul, 2 add for select - let selects_per_lookup = table_size - 1; // 2^w - 1 point_selects - let select_ops_per_point = (2usize, 2usize, 2usize, 0usize); // (add, sub, mul, inv) - - // MSB window: 1 table lookup (possibly smaller table) - let msb_bits = scalar_bits - (num_windows - 1) * w; - let msb_table_size = 1 << msb_bits; - let msb_selects = msb_table_size - 1; - total_add += msb_selects * select_ops_per_point.0; - total_sub += msb_selects * select_ops_per_point.1; - total_mul += msb_selects * select_ops_per_point.2; - - // Remaining windows: for each of (num_windows - 1) windows: - // - w doublings - // - 1 pack_bits (cheap) - // - 1 is_zero (1 inv + some adds) - // - 1 table lookup - // - 1 sub (for denom) - // - 1 elem_is_zero - // - 1 point_double (for x_eq case) - // - 1 safe_point_add (like point_add but with select on denom) - // - 2 point_selects (x_eq and digit_is_zero) - let remaining = if num_windows > 1 { num_windows - 1 } else { 0 }; - - for _ in 0..remaining { - // w doublings + // w shared doublings total_add += w * double_ops.0; total_sub += w * double_ops.1; total_mul += w * double_ops.2; total_inv += w * double_ops.3; - // table lookup - total_add += selects_per_lookup * select_ops_per_point.0; - total_sub += selects_per_lookup * select_ops_per_point.1; - total_mul += selects_per_lookup * select_ops_per_point.2; - - // denom = sub(looked_up.x, acc.x) - total_sub += 1; - - // elem_is_zero(denom) = 
is_zero per limb + products - // For N limbs: N * (1 inv + some arith) + (N-1) products - // Simplified: 1 inv + 3 witnesses - total_inv += 1; - total_add += 1; - total_mul += 1; - - // point_double for x_eq case - total_add += double_ops.0; - total_sub += double_ops.1; - total_mul += double_ops.2; - total_inv += double_ops.3; - - // safe_point_add: like point_add + 1 select on denom - total_add += add_ops.0 + select_ops_per_point.0 / 2; // 1 select - total_sub += add_ops.1 + select_ops_per_point.1 / 2; - total_mul += add_ops.2 + select_ops_per_point.2 / 2; - total_inv += add_ops.3; - - // 2 point_selects - total_add += 2 * select_ops_per_point.0; - total_sub += 2 * select_ops_per_point.1; - total_mul += 2 * select_ops_per_point.2; - - // is_zero(digit) - total_inv += 1; - total_add += 1; - total_mul += 1; + // Two table lookups + two point_adds + two is_zeros + two point_selects + for _ in 0..2 { + total_add += actual_selects * select_ops_per_point.0; + total_sub += actual_selects * select_ops_per_point.1; + total_mul += actual_selects * select_ops_per_point.2; + + total_add += add_ops.0; + total_sub += add_ops.1; + total_mul += add_ops.2; + total_inv += add_ops.3; + + total_inv += 1; // is_zero + total_add += 1; + total_mul += 1; + + total_add += select_ops_per_point.0; + total_sub += select_ops_per_point.1; + total_mul += select_ops_per_point.2; + } } + // On-curve checks for P and R: each needs 1 mul (y^2), 2 mul (x^2, x^3), 1 mul (a*x), 2 add + total_mul += 8; + total_add += 4; + + // Conditional y-negation: 2 sub + 2 select (for P.y and R.y) + total_sub += 2; + total_add += 2 * select_ops_per_point.0; + total_sub += 2 * select_ops_per_point.1; + total_mul += 2 * select_ops_per_point.2; + (total_add, total_sub, total_mul, total_inv) } @@ -119,18 +90,18 @@ fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize if is_native { // Native: no range checks, just standard R1CS witnesses match op { - FieldOpType::Add => 1, // sum witness - 
FieldOpType::Sub => 1, // sum witness - FieldOpType::Mul => 1, // product witness - FieldOpType::Inv => 1, // inverse witness + FieldOpType::Add => 1, // sum witness + FieldOpType::Sub => 1, // sum witness + FieldOpType::Mul => 1, // product witness + FieldOpType::Inv => 1, // inverse witness } } else if num_limbs == 1 { // Single-limb non-native: reduce_mod_p pattern match op { - FieldOpType::Add => 5, // a+b, m const, k, k*m, result - FieldOpType::Sub => 5, // same - FieldOpType::Mul => 5, // a*b, m const, k, k*m, result - FieldOpType::Inv => 7, // a_inv + mul_mod_p(5) + range_check + FieldOpType::Add => 5, // a+b, m const, k, k*m, result + FieldOpType::Sub => 5, // same + FieldOpType::Mul => 5, // a*b, m const, k, k*m, result + FieldOpType::Inv => 7, // a_inv + mul_mod_p(5) + range_check } } else { // Multi-limb: N-limb operations @@ -162,41 +133,53 @@ pub fn calculate_msm_witness_cost( ((curve_modulus_bits as usize) + (limb_bits as usize) - 1) / (limb_bits as usize) }; - let (n_add, n_sub, n_mul, n_inv) = count_scalar_mul_field_ops(scalar_bits, window_size); - let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, is_native); let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, is_native); let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, is_native); let wit_inv = witnesses_per_op(num_limbs, FieldOpType::Inv, is_native); - let per_scalarmul = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + // FakeGLV path for ALL points: half-width interleaved scalar mul + let half_bits = (scalar_bits + 1) / 2; + let (n_add, n_sub, n_mul, n_inv) = count_glv_field_ops(half_bits, window_size); + let glv_scalarmul = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + + // Per-point overhead: scalar decomposition (2 × half_bits for s1, s2) + + // scalar relation (~150 witnesses) + FakeGLVHint (4 witnesses) + let scalar_decomp = 2 * half_bits + 10; + let scalar_relation = 150; + let glv_hint = 4; - // Scalar 
decomposition: 256 bits (bit witnesses + digital decomposition overhead) - let scalar_decomp = 256 + 10; + // EcScalarMulHint: 2 witnesses per point (only for n_points > 1) + let ec_hint = if n_points > 1 { 2 } else { 0 }; + + let per_point = glv_scalarmul + scalar_decomp + scalar_relation + glv_hint + ec_hint; // Point accumulation: (n_points - 1) point_adds - let accum_per_point = if n_points > 1 { + let accum = if n_points > 1 { let accum_adds = n_points - 1; - accum_adds * (witnesses_per_op(num_limbs, FieldOpType::Add, is_native) * 2 - + witnesses_per_op(num_limbs, FieldOpType::Sub, is_native) * 2 - + witnesses_per_op(num_limbs, FieldOpType::Mul, is_native) * 3 - + witnesses_per_op(num_limbs, FieldOpType::Inv, is_native)) + accum_adds + * (witnesses_per_op(num_limbs, FieldOpType::Add, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Sub, is_native) * 2 + + witnesses_per_op(num_limbs, FieldOpType::Mul, is_native) * 3 + + witnesses_per_op(num_limbs, FieldOpType::Inv, is_native)) } else { 0 }; - n_points * (per_scalarmul + scalar_decomp) + accum_per_point + n_points * per_point + accum } /// Check whether schoolbook column equation values fit in the native field. /// -/// In `mul_mod_p_multi`, the schoolbook multiplication verifies `a·b = p·q + r` via -/// column equations that include product sums, carry offsets, and outgoing carries. -/// Both sides of each column equation must evaluate to less than the native field -/// modulus as **integers** — if they overflow, the field's modular reduction makes -/// `LHS ≡ RHS (mod p)` weaker than `LHS = RHS`, breaking soundness. +/// In `mul_mod_p_multi`, the schoolbook multiplication verifies `a·b = p·q + r` +/// via column equations that include product sums, carry offsets, and outgoing +/// carries. 
Both sides of each column equation must evaluate to less than the +/// native field modulus as **integers** — if they overflow, the field's modular +/// reduction makes `LHS ≡ RHS (mod p)` weaker than `LHS = RHS`, breaking +/// soundness. /// -/// The maximum integer value across either side of any column equation is bounded by: +/// The maximum integer value across either side of any column equation is +/// bounded by: /// /// `2^(2W + ceil(log2(N)) + 3)` /// @@ -205,8 +188,8 @@ pub fn calculate_msm_witness_cost( /// - The carry offset `2^(2W + ceil(log2(N)) + 1)` (dominant term) /// - Outgoing carry term `2^W * offset_carry` on the RHS /// -/// Since the native field modulus satisfies `p >= 2^(native_field_bits - 1)`, the -/// conservative soundness condition is: +/// Since the native field modulus satisfies `p >= 2^(native_field_bits - 1)`, +/// the conservative soundness condition is: /// /// `2 * limb_bits + ceil(log2(num_limbs)) + 3 < native_field_bits` pub fn column_equation_fits_native_field( @@ -259,8 +242,8 @@ pub fn get_optimal_msm_params( } // Upper bound on search: even with N=2 (best case), we need - // 2*lb + ceil(log2(2)) + 3 < native_field_bits => lb < (native_field_bits - 4) / 2. - // The per-candidate soundness check below is the actual gate. + // 2*lb + ceil(log2(2)) + 3 < native_field_bits => lb < (native_field_bits - 4) + // / 2. The per-candidate soundness check below is the actual gate. 
let max_limb_bits = (native_field_bits.saturating_sub(4)) / 2; let mut best_cost = usize::MAX; let mut best_limb_bits = max_limb_bits.min(86); @@ -268,8 +251,7 @@ pub fn get_optimal_msm_params( // Search space for lb in (8..=max_limb_bits).step_by(2) { - let num_limbs = - ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); + let num_limbs = ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { continue; } @@ -329,15 +311,6 @@ mod tests { assert!(window_size >= 2 && window_size <= 8); } - #[test] - fn test_count_field_ops_sanity() { - let (add, sub, mul, inv) = count_scalar_mul_field_ops(256, 4); - assert!(add > 0); - assert!(sub > 0); - assert!(mul > 0); - assert!(inv > 0); - } - #[test] fn test_column_equation_soundness_boundary() { // For BN254 (254 bits) with N=3: max safe limb_bits is 124. diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index 53a1340f8..07c53891a 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -1,5 +1,5 @@ use { - ark_ff::{BigInteger, PrimeField}, + ark_ff::{BigInteger, Field, PrimeField}, provekit_common::FieldElement, }; @@ -9,10 +9,15 @@ pub struct CurveParams { pub curve_a: [u64; 4], pub curve_b: [u64; 4], pub generator: ([u64; 4], [u64; 4]), + /// A known non-identity point on the curve, used as the accumulator offset + /// in `scalar_mul_glv`. Must be deterministic and unrelated to typical + /// table entries (we use [2]G). + pub offset_point: ([u64; 4], [u64; 4]), } impl CurveParams { - /// Decompose the field modulus p into `num_limbs` limbs of `limb_bits` width each. + /// Decompose the field modulus p into `num_limbs` limbs of `limb_bits` + /// width each. 
pub fn p_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { decompose_to_limbs(&self.field_modulus_p, limb_bits, num_limbs) } @@ -23,7 +28,8 @@ impl CurveParams { decompose_to_limbs(&p_minus_1, limb_bits, num_limbs) } - /// Decompose the curve parameter `a` into `num_limbs` limbs of `limb_bits` width. + /// Decompose the curve parameter `a` into `num_limbs` limbs of `limb_bits` + /// width. pub fn curve_a_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { decompose_to_limbs(&self.curve_a, limb_bits, num_limbs) } @@ -46,15 +52,124 @@ impl CurveParams { self.field_modulus_p == native_mod.0 } - /// Convert modulus to a native field element (only valid when p < native modulus). + /// Convert modulus to a native field element (only valid when p < native + /// modulus). pub fn p_native_fe(&self) -> FieldElement { curve_native_point_fe(&self.field_modulus_p) } + + /// Returns the curve parameter b as a native field element. + pub fn curve_b_fe(&self) -> FieldElement { + curve_native_point_fe(&self.curve_b) + } + + /// Decompose the curve order n into `num_limbs` limbs of `limb_bits` width + /// each. + pub fn curve_order_n_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_order_n, limb_bits, num_limbs) + } + + /// Decompose (curve_order_n - 1) into `num_limbs` limbs of `limb_bits` + /// width each. + pub fn curve_order_n_minus_1_limbs( + &self, + limb_bits: u32, + num_limbs: usize, + ) -> Vec { + let n_minus_1 = sub_one_u64_4(&self.curve_order_n); + decompose_to_limbs(&n_minus_1, limb_bits, num_limbs) + } + + /// Number of bits in the curve order n. + pub fn curve_order_bits(&self) -> u32 { + // Compute bit length directly from raw limbs to avoid reduction + // mod the native field (curve_order_n may exceed the native modulus). 
+ let n = &self.curve_order_n; + for i in (0..4).rev() { + if n[i] != 0 { + return (i as u32) * 64 + (64 - n[i].leading_zeros()); + } + } + 0 + } + + /// Number of bits for the GLV half-scalar: `ceil(order_bits / 2)`. + /// This determines the bit width of the sub-scalars s1, s2 from half-GCD. + pub fn glv_half_bits(&self) -> u32 { + (self.curve_order_bits() + 1) / 2 + } + + /// Decompose the offset point x-coordinate into limbs. + pub fn offset_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.offset_point.0, limb_bits, num_limbs) + } + + /// Decompose the offset point y-coordinate into limbs. + pub fn offset_y_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.offset_point.1, limb_bits, num_limbs) + } + + /// Compute `[2^n_doublings] * offset_point` on the curve (compile-time + /// only). + /// + /// Used to compute the accumulated offset after the scalar_mul_glv loop: + /// since the accumulator starts at R and gets doubled n times total, the + /// offset to subtract is `[2^n]*R`, not just `R`. + pub fn accumulated_offset(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + if self.is_native_field() { + self.accumulated_offset_native(n_doublings) + } else { + self.accumulated_offset_generic(n_doublings) + } + } + + /// Compute accumulated offset using FieldElement arithmetic (native field). 
+ fn accumulated_offset_native(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + let mut x = curve_native_point_fe(&self.offset_point.0); + let mut y = curve_native_point_fe(&self.offset_point.1); + let a = curve_native_point_fe(&self.curve_a); + + for _ in 0..n_doublings { + let x_sq = x * x; + let num = x_sq + x_sq + x_sq + a; + let denom_inv = (y + y).inverse().unwrap(); + let lambda = num * denom_inv; + let x3 = lambda * lambda - x - x; + let y3 = lambda * (x - x3) - y; + x = x3; + y = y3; + } + + (x.into_bigint().0, y.into_bigint().0) + } + + /// Compute accumulated offset using generic 256-bit arithmetic (non-native + /// field). + fn accumulated_offset_generic(&self, n_doublings: usize) -> ([u64; 4], [u64; 4]) { + let p = &self.field_modulus_p; + let mut x = self.offset_point.0; + let mut y = self.offset_point.1; + let a = &self.curve_a; + + for _ in 0..n_doublings { + let (x3, y3) = u256_arith::ec_point_double(&x, &y, a, p); + x = x3; + y = y3; + } + + (x, y) + } } /// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width each, /// returned as FieldElements. -fn decompose_to_limbs(val: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec { +pub fn decompose_to_limbs(val: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec { + // Special case: when a single limb needs > 128 bits, FieldElement::from(u128) + // would truncate. Use from_sign_and_limbs to preserve the full value. + if num_limbs == 1 && limb_bits > 128 { + return vec![curve_native_point_fe(val)]; + } + let mask: u128 = if limb_bits >= 128 { u128::MAX } else { @@ -104,6 +219,25 @@ pub fn curve_native_point_fe(val: &[u64; 4]) -> FieldElement { FieldElement::from_sign_and_limbs(true, val) } +/// Negate a field element: compute `-val mod p` (i.e., `p - val`). +/// Returns `[0; 4]` when `val` is zero. 
+pub fn negate_field_element(val: &[u64; 4], modulus: &[u64; 4]) -> [u64; 4] { + if *val == [0u64; 4] { + return [0u64; 4]; + } + // val is in [1, p-1], so p - val is in [1, p-1] — no borrow. + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = modulus[i].overflowing_sub(val[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + debug_assert!(!borrow, "negate_field_element: val >= modulus"); + result +} + /// Grumpkin curve parameters. /// /// Grumpkin is a cycle-companion curve for BN254: its base field is the BN254 @@ -120,33 +254,344 @@ pub fn grumpkin_params() -> CurveParams { 0x30644e72e131a029_u64, ], // BN254 base field modulus - curve_order_n: [ + curve_order_n: [ 0x3c208c16d87cfd47_u64, 0x97816a916871ca8d_u64, 0xb85045b68181585d_u64, 0x30644e72e131a029_u64, ], - curve_a: [0; 4], + curve_a: [0; 4], // b = −17 mod p - curve_b: [ + curve_b: [ 0x43e1f593effffff0_u64, 0x2833e84879b97091_u64, 0xb85045b68181585d_u64, 0x30644e72e131a029_u64, ], // Generator G = (1, sqrt(−16) mod p) - generator: ( - [1, 0, 0, 0], + generator: ([1, 0, 0, 0], [ + 0x833fc48d823f272c_u64, + 0x2d270d45f1181294_u64, + 0xcf135e7506a45d63_u64, + 0x0000000000000002_u64, + ]), + // Offset point = [2]G + offset_point: ( + [ + 0x6d8bc688cdbffffe_u64, + 0x19a74caa311e13d4_u64, + 0xddeb49cdaa36306d_u64, + 0x06ce1b0827aafa85_u64, + ], [ - 0x833fc48d823f272c_u64, - 0x2d270d45f1181294_u64, - 0xcf135e7506a45d63_u64, - 0x0000000000000002_u64, + 0x467be7e7a43f80ac_u64, + 0xc93faf6fa1a788bf_u64, + 0x909ede0ba2a6855f_u64, + 0x1c122f81a3a14964_u64, ], ), } } +/// 256-bit modular arithmetic for compile-time EC point computations. +/// Only used to precompute accumulated offset points; not performance-critical. +mod u256_arith { + type U256 = [u64; 4]; + + /// Returns true if a >= b. 
+ fn gte(a: &U256, b: &U256) -> bool { + for i in (0..4).rev() { + if a[i] > b[i] { + return true; + } + if a[i] < b[i] { + return false; + } + } + true // equal + } + + /// a + b, returns (result, carry). + fn add(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut carry = 0u128; + for i in 0..4 { + carry += a[i] as u128 + b[i] as u128; + result[i] = carry as u64; + carry >>= 64; + } + (result, carry != 0) + } + + /// a - b, returns (result, borrow). + fn sub(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + (result, borrow) + } + + /// (a + b) mod p. + pub fn mod_add(a: &U256, b: &U256, p: &U256) -> U256 { + let (s, overflow) = add(a, b); + if overflow || gte(&s, p) { + sub(&s, p).0 + } else { + s + } + } + + /// (a - b) mod p. + fn mod_sub(a: &U256, b: &U256, p: &U256) -> U256 { + let (d, borrow) = sub(a, b); + if borrow { + add(&d, p).0 + } else { + d + } + } + + /// Schoolbook multiplication producing 512-bit result. + fn mul_wide(a: &U256, b: &U256) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let prod = (a[i] as u128) * (b[j] as u128) + result[i + j] as u128 + carry; + result[i + j] = prod as u64; + carry = prod >> 64; + } + result[i + 4] = result[i + 4].wrapping_add(carry as u64); + } + result + } + + /// Reduce a 512-bit value mod a 256-bit prime using bit-by-bit long + /// division. 
+ fn mod_reduce_wide(a: &[u64; 8], p: &U256) -> U256 { + let mut total_bits = 0; + for i in (0..8).rev() { + if a[i] != 0 { + total_bits = i * 64 + (64 - a[i].leading_zeros() as usize); + break; + } + } + if total_bits == 0 { + return [0; 4]; + } + + let mut r = [0u64; 4]; + for bit_idx in (0..total_bits).rev() { + // Left shift r by 1 + let overflow = r[3] >> 63; + for j in (1..4).rev() { + r[j] = (r[j] << 1) | (r[j - 1] >> 63); + } + r[0] <<= 1; + + // Insert current bit of a + let word = bit_idx / 64; + let bit = bit_idx % 64; + r[0] |= (a[word] >> bit) & 1; + + // If r >= p (or overflow from shift), subtract p + if overflow != 0 || gte(&r, p) { + r = sub(&r, p).0; + } + } + r + } + + /// (a * b) mod p. + pub fn mod_mul(a: &U256, b: &U256, p: &U256) -> U256 { + let wide = mul_wide(a, b); + mod_reduce_wide(&wide, p) + } + + /// a^exp mod p using square-and-multiply. + fn mod_pow(base: &U256, exp: &U256, p: &U256) -> U256 { + let mut highest_bit = 0; + for i in (0..4).rev() { + if exp[i] != 0 { + highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return [1, 0, 0, 0]; + } + + let mut result: U256 = [1, 0, 0, 0]; + let mut base = *base; + for bit_idx in 0..highest_bit { + let word = bit_idx / 64; + let bit = bit_idx % 64; + if (exp[word] >> bit) & 1 == 1 { + result = mod_mul(&result, &base, p); + } + base = mod_mul(&base, &base, p); + } + result + } + + /// a^(-1) mod p via Fermat's little theorem: a^(p-2) mod p. + fn mod_inv(a: &U256, p: &U256) -> U256 { + let two: U256 = [2, 0, 0, 0]; + let exp = sub(p, &two).0; + mod_pow(a, &exp, p) + } + + /// EC point addition on y^2 = x^3 + ax + b. + /// Computes (x1,y1) + (x2,y2). Requires x1 != x2. 
+ pub fn ec_point_add(x1: &U256, y1: &U256, x2: &U256, y2: &U256, p: &U256) -> (U256, U256) { + // lambda = (y2 - y1) / (x2 - x1) + let num = mod_sub(y2, y1, p); + let denom = mod_sub(x2, x1, p); + let denom_inv = mod_inv(&denom, p); + let lambda = mod_mul(&num, &denom_inv, p); + + // x3 = lambda^2 - x1 - x2 + let lambda_sq = mod_mul(&lambda, &lambda, p); + let x1_plus_x2 = mod_add(x1, x2, p); + let x3 = mod_sub(&lambda_sq, &x1_plus_x2, p); + + // y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = mod_sub(x1, &x3, p); + let lambda_dx = mod_mul(&lambda, &x1_minus_x3, p); + let y3 = mod_sub(&lambda_dx, y1, p); + + (x3, y3) + } + + /// EC point doubling on y^2 = x^3 + ax + b. + pub fn ec_point_double(x: &U256, y: &U256, a: &U256, p: &U256) -> (U256, U256) { + // lambda = (3*x^2 + a) / (2*y) + let x_sq = mod_mul(x, x, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let num = mod_add(&three_x_sq, a, p); + let two_y = mod_add(y, y, p); + let denom_inv = mod_inv(&two_y, p); + let lambda = mod_mul(&num, &denom_inv, p); + + // x3 = lambda^2 - 2*x + let lambda_sq = mod_mul(&lambda, &lambda, p); + let two_x = mod_add(x, x, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + // y3 = lambda * (x - x3) - y + let x_minus_x3 = mod_sub(x, &x3, p); + let lambda_dx = mod_mul(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, y, p); + + (x3, y3) + } +} + +#[cfg(test)] +mod tests { + use {super::*, ark_ff::Field}; + + #[test] + fn test_offset_point_on_curve_grumpkin() { + let c = grumpkin_params(); + let x = curve_native_point_fe(&c.offset_point.0); + let y = curve_native_point_fe(&c.offset_point.1); + let b = curve_native_point_fe(&c.curve_b); + // Grumpkin: y^2 = x^3 + b (a=0) + assert_eq!(y * y, x * x * x + b, "offset point not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_single_double_grumpkin() { + let c = grumpkin_params(); + let (x4, y4) = c.accumulated_offset(1); + let x = curve_native_point_fe(&x4); + let y = 
curve_native_point_fe(&y4); + let b = curve_native_point_fe(&c.curve_b); + // Should still be on curve + assert_eq!(y * y, x * x * x + b, "[4]G not on Grumpkin"); + } + + #[test] + fn test_accumulated_offset_native_vs_generic() { + let c = grumpkin_params(); + // Both paths should give the same result + let native = c.accumulated_offset_native(10); + let generic = c.accumulated_offset_generic(10); + assert_eq!(native, generic, "native vs generic mismatch for n=10"); + } + + #[test] + fn test_accumulated_offset_256_on_curve() { + let c = grumpkin_params(); + let (x, y) = c.accumulated_offset(256); + let xfe = curve_native_point_fe(&x); + let yfe = curve_native_point_fe(&y); + let b = curve_native_point_fe(&c.curve_b); + assert_eq!(yfe * yfe, xfe * xfe * xfe + b, "[2^257]G not on Grumpkin"); + } + + #[test] + fn test_offset_point_on_curve_secp256r1() { + let c = secp256r1_params(); + let p = &c.field_modulus_p; + let x = &c.offset_point.0; + let y = &c.offset_point.1; + let a = &c.curve_a; + let b = &c.curve_b; + // y^2 = x^3 + a*x + b (mod p) + let y_sq = u256_arith::mod_mul(y, y, p); + let x_sq = u256_arith::mod_mul(x, x, p); + let x_cubed = u256_arith::mod_mul(&x_sq, x, p); + let ax = u256_arith::mod_mul(a, x, p); + let x3_plus_ax = u256_arith::mod_add(&x_cubed, &ax, p); + let rhs = u256_arith::mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "offset point not on secp256r1"); + } + + #[test] + fn test_accumulated_offset_secp256r1() { + let c = secp256r1_params(); + let p = &c.field_modulus_p; + let a = &c.curve_a; + let b = &c.curve_b; + let (x, y) = c.accumulated_offset(256); + // Verify the accumulated offset is on the curve + let y_sq = u256_arith::mod_mul(&y, &y, p); + let x_sq = u256_arith::mod_mul(&x, &x, p); + let x_cubed = u256_arith::mod_mul(&x_sq, &x, p); + let ax = u256_arith::mod_mul(a, &x, p); + let x3_plus_ax = u256_arith::mod_add(&x_cubed, &ax, p); + let rhs = u256_arith::mod_add(&x3_plus_ax, b, p); + assert_eq!(y_sq, rhs, "accumulated offset not 
on secp256r1"); + } + + #[test] + fn test_fe_roundtrip() { + // Verify from_sign_and_limbs / into_bigint roundtrip + let val: [u64; 4] = [42, 0, 0, 0]; + let fe = curve_native_point_fe(&val); + let back = fe.into_bigint().0; + assert_eq!(val, back, "roundtrip failed for small value"); + + let val2: [u64; 4] = [ + 0x6d8bc688cdbffffe, + 0x19a74caa311e13d4, + 0xddeb49cdaa36306d, + 0x06ce1b0827aafa85, + ]; + let fe2 = curve_native_point_fe(&val2); + let back2 = fe2.into_bigint().0; + assert_eq!(val2, back2, "roundtrip failed for offset x"); + } +} + pub fn secp256r1_params() -> CurveParams { CurveParams { field_modulus_p: [ @@ -187,5 +632,20 @@ pub fn secp256r1_params() -> CurveParams { 0x4fe342e2fe1a7f9b_u64, ], ), + // Offset point = [2]G + offset_point: ( + [ + 0xa60b48fc47669978_u64, + 0xc08969e277f21b35_u64, + 0x8a52380304b51ac3_u64, + 0x7cf27b188d034f7e_u64, + ], + [ + 0x9e04b79d227873d1_u64, + 0xba7dade63ce98229_u64, + 0x293d9ac69f7430db_u64, + 0x07775510db8ed040_u64, + ], + ), } } diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 14712c78c..8f7172897 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -100,43 +100,6 @@ pub fn point_select( (x, y) } -/// Point addition with safe denominator for the `x1 = x2` edge case. -/// -/// When `x_eq = 1`, the denominator `(x2 - x1)` is zero and cannot be -/// inverted. This function replaces it with 1, producing a satisfiable -/// but meaningless result. The caller MUST discard this result via -/// `point_select` when `x_eq = 1`. -/// -/// The `denom` parameter is the precomputed `x2 - x1`. 
-fn safe_point_add( - ops: &mut F, - x1: F::Elem, - y1: F::Elem, - x2: F::Elem, - y2: F::Elem, - denom: F::Elem, - x_eq: usize, -) -> (F::Elem, F::Elem) { - let numerator = ops.sub(y2, y1); - - // When x_eq=1 (denom=0), substitute with 1 to keep inv satisfiable - let one = ops.constant_one(); - let safe_denom = ops.select(x_eq, denom, one); - - let denom_inv = ops.inv(safe_denom); - let lambda = ops.mul(numerator, denom_inv); - - let lambda_sq = ops.mul(lambda, lambda); - let x1_plus_x2 = ops.add(x1, x2); - let x3 = ops.sub(lambda_sq, x1_plus_x2); - - let x1_minus_x3 = ops.sub(x1, x3); - let lambda_dx = ops.mul(lambda, x1_minus_x3); - let y3 = ops.sub(lambda_dx, y1); - - (x3, y3) -} - /// Builds a point table for windowed scalar multiplication. /// /// T[0] = P (dummy entry, used when window digit = 0) @@ -161,7 +124,8 @@ fn build_point_table( table } -/// Selects T[d] from a point table using bit witnesses, where `d = Σ bits[i] * 2^i`. +/// Selects T[d] from a point table using bit witnesses, where `d = Σ bits[i] * +/// 2^i`. /// /// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, /// halving the candidate set at each level. Total: `(2^w - 1)` point selects @@ -185,100 +149,103 @@ fn table_lookup( current[0] } -/// Windowed scalar multiplication: computes `[scalar] * P`. +/// Interleaved two-point scalar multiplication for FakeGLV. /// -/// Takes pre-decomposed scalar bits (LSB first, `scalar_bits[0]` is the -/// least significant bit) and a window size `w`. Precomputes a table of -/// `2^w` point multiples and processes the scalar in `w`-bit windows from -/// MSB to LSB. +/// Computes `[s1]P + [s2]R` using shared doublings, where s1 and s2 are +/// half-width scalars (typically ~128-bit for 256-bit curves). The +/// accumulator starts at an offset point and the caller checks equality +/// with the accumulated offset to verify the constraint `[s1]P + [s2]R = O`. /// -/// Handles two edge cases: -/// 1. 
**MSB window digit = 0**: The accumulator is initialized from T[0] -/// (a dummy copy of P). An `acc_is_identity` flag tracks that no real -/// point has been accumulated yet. When the first non-zero window digit -/// is encountered, the looked-up point becomes the new accumulator. -/// 2. **x-coordinate collision** (`acc.x == looked_up.x`): Uses -/// `point_double` instead of `point_add`, with `safe_point_add` -/// guarding the zero denominator. +/// Structure per window (from MSB to LSB): +/// 1. `w` shared doublings on accumulator +/// 2. Table lookup in T_P[d1] for s1's window digit +/// 3. point_add(acc, T_P[d1]) + is_zero(d1) + point_select +/// 4. Table lookup in T_R[d2] for s2's window digit +/// 5. point_add(acc, T_R[d2]) + is_zero(d2) + point_select /// -/// The inverse-point case (`acc = -looked_up`, result is infinity) cannot -/// be represented in affine coordinates and remains unsupported — this has -/// negligible probability (~2^{-256}) for random scalars. -pub fn scalar_mul( +/// Returns the final accumulator (x, y). 
+pub fn scalar_mul_glv( ops: &mut F, + // Point P (table 1) px: F::Elem, py: F::Elem, - scalar_bits: &[usize], + s1_bits: &[usize], // 128 bit witnesses for |s1| + // Point R (table 2) — the claimed output + rx: F::Elem, + ry: F::Elem, + s2_bits: &[usize], // 128 bit witnesses for |s2| + // Shared parameters window_size: usize, + offset_x: F::Elem, + offset_y: F::Elem, ) -> (F::Elem, F::Elem) { - let n = scalar_bits.len(); + let n1 = s1_bits.len(); + let n2 = s2_bits.len(); + assert_eq!(n1, n2, "s1 and s2 must have the same number of bits"); + let n = n1; let w = window_size; let table_size = 1 << w; - // Build point table: T[i] = [i]P, with T[0] = P as dummy - let table = build_point_table(ops, px, py, table_size); + // Build point tables: T_P[i] = [i]P, T_R[i] = [i]R + let table_p = build_point_table(ops, px, py, table_size); + let table_r = build_point_table(ops, rx, ry, table_size); - // Number of windows (ceiling division) let num_windows = (n + w - 1) / w; - // Process MSB window first (may be shorter than w bits if n % w != 0) - let msb_start = (num_windows - 1) * w; - let msb_bits = &scalar_bits[msb_start..n]; - let msb_table = &table[..1 << msb_bits.len()]; - let mut acc = table_lookup(ops, msb_table, msb_bits); + // Initialize accumulator with the offset point + let mut acc = (offset_x, offset_y); - // Track whether acc represents the identity (no real point yet). - // When MSB digit = 0, T[0] = P is loaded as a dummy — we must not - // double or add it until the first non-zero window digit appears. - let msb_digit = ops.pack_bits(msb_bits); - let mut acc_is_identity = ops.is_zero(msb_digit); + // Process all windows from MSB down to LSB + for i in (0..num_windows).rev() { + let bit_start = i * w; + let bit_end = std::cmp::min(bit_start + w, n); + let actual_w = bit_end - bit_start; - // Process remaining windows from MSB-1 down to LSB - for i in (0..num_windows - 1).rev() { - // w doublings — only meaningful when acc is a real point. 
- // When acc_is_identity=1, the doubling result is garbage but will - // be discarded by the point_select below. + // w shared doublings on the accumulator let mut doubled_acc = acc; for _ in 0..w { doubled_acc = point_double(ops, doubled_acc.0, doubled_acc.1); } - // If acc is identity, keep dummy; otherwise use doubled result - acc = point_select(ops, acc_is_identity, doubled_acc, acc); - - // Table lookup for this window's digit - let window_bits = &scalar_bits[i * w..(i + 1) * w]; - let digit = ops.pack_bits(window_bits); - let digit_is_zero = ops.is_zero(digit); - let looked_up = table_lookup(ops, &table, window_bits); - - // Detect x-coordinate collision: acc.x == looked_up.x - let denom = ops.sub(looked_up.0, acc.0); - let x_eq = ops.elem_is_zero(denom); - - // point_double handles the acc == looked_up case (same point) - let doubled = point_double(ops, acc.0, acc.1); - - // Safe point_add (substitutes denominator when x_eq=1) - let added = safe_point_add( - ops, acc.0, acc.1, looked_up.0, looked_up.1, denom, x_eq, + // --- Process P's window digit (s1) --- + let s1_window_bits = &s1_bits[bit_start..bit_end]; + let lookup_table_p = if actual_w < w { + &table_p[..1 << actual_w] + } else { + &table_p[..] 
+ }; + let looked_up_p = table_lookup(ops, lookup_table_p, s1_window_bits); + let added_p = point_add( + ops, + doubled_acc.0, + doubled_acc.1, + looked_up_p.0, + looked_up_p.1, ); - - // x_eq=0 => use add result, x_eq=1 => use double result - let combined = point_select(ops, x_eq, added, doubled); - - // Four cases based on (acc_is_identity, digit_is_zero): - // (0, 0) => combined — normal add/double - // (0, 1) => acc — keep accumulator - // (1, 0) => looked_up — first real point - // (1, 1) => acc — still identity - let normal_result = point_select(ops, digit_is_zero, combined, acc); - let identity_result = point_select(ops, digit_is_zero, looked_up, acc); - acc = point_select(ops, acc_is_identity, normal_result, identity_result); - - // Update: acc is identity only if it was identity AND digit is zero - acc_is_identity = ops.bool_and(acc_is_identity, digit_is_zero); + let digit_p = ops.pack_bits(s1_window_bits); + let digit_p_is_zero = ops.is_zero(digit_p); + let after_p = point_select(ops, digit_p_is_zero, added_p, doubled_acc); + + // --- Process R's window digit (s2) --- + let s2_window_bits = &s2_bits[bit_start..bit_end]; + let lookup_table_r = if actual_w < w { + &table_r[..1 << actual_w] + } else { + &table_r[..] 
+ }; + let looked_up_r = table_lookup(ops, lookup_table_r, s2_window_bits); + let added_r = point_add( + ops, + after_p.0, + after_p.1, + looked_up_r.0, + looked_up_r.1, + ); + let digit_r = ops.pack_bits(s2_window_bits); + let digit_r_is_zero = ops.is_zero(digit_r); + acc = point_select(ops, digit_r_is_zero, added_r, after_p); } acc } + diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index dda1e064a..826d381c4 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -10,7 +10,7 @@ use { noir_to_r1cs::NoirToR1CSCompiler, }, ark_ff::{AdditiveGroup, Field}, - curve::CurveParams, + curve::{decompose_to_limbs as decompose_to_limbs_pub, CurveParams}, multi_limb_ops::{MultiLimbOps, MultiLimbParams}, provekit_common::{ witness::{ConstantOrR1CSWitness, ConstantTerm, SumTerm, WitnessBuilder}, @@ -151,16 +151,8 @@ pub trait FieldOps { /// Does NOT constrain bits to be boolean — caller must ensure that. fn pack_bits(&mut self, bits: &[usize]) -> usize; - /// Checks if a field element (in the curve's base field) is zero. - /// Returns a boolean witness: 1 if zero, 0 if non-zero. - fn elem_is_zero(&mut self, value: Self::Elem) -> usize; - - /// Returns the constant field element 1. - fn constant_one(&mut self) -> Self::Elem; - - /// Computes a * b for two boolean (0/1) native witnesses. - /// Used for boolean AND on flags in scalar_mul. - fn bool_and(&mut self, a: usize, b: usize) -> usize; + /// Returns a constant field element from its limb decomposition. + fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Self::Elem; } // --------------------------------------------------------------------------- @@ -319,8 +311,12 @@ fn add_single_msm( /// Process a full single-MSM with runtime `num_limbs`. /// -/// Handles coordinate decomposition, scalar_mul, accumulation, and -/// output constraining. 
+/// Uses FakeGLV for ALL points: each point P_i with scalar s_i is verified +/// using scalar decomposition and half-width interleaved scalar mul. +/// +/// For `n_points == 1`, R = (out_x, out_y) is the ACIR output. +/// For `n_points > 1`, R_i = EcScalarMulHint witnesses, accumulated via +/// point_add and constrained against the ACIR output. fn process_single_msm<'a>( mut compiler: &'a mut NoirToR1CSCompiler, point_wits: &[usize], @@ -333,79 +329,319 @@ fn process_single_msm<'a>( curve: &CurveParams, ) { let n_points = point_wits.len() / 3; - let mut acc: Option<(Limbs, Limbs)> = None; + let (out_x, out_y, out_inf) = outputs; - for i in 0..n_points { - let px_witness = point_wits[3 * i]; - let py_witness = point_wits[3 * i + 1]; + if n_points == 1 { + // Single-point: R is the ACIR output directly + let px_witness = point_wits[0]; + let py_witness = point_wits[1]; + // Constrain input infinity flag to 0 (affine coordinates cannot represent infinity) + constrain_zero(compiler, point_wits[2]); + let s_lo = scalar_wits[0]; + let s_hi = scalar_wits[1]; + + // Decompose P into limbs + let (px, py) = decompose_point_to_limbs( + compiler, + px_witness, + py_witness, + num_limbs, + limb_bits, + range_checks, + ); + // R = ACIR output, decompose into limbs + let (rx, ry) = decompose_point_to_limbs( + compiler, out_x, out_y, num_limbs, limb_bits, range_checks, + ); - let s_lo = scalar_wits[2 * i]; - let s_hi = scalar_wits[2 * i + 1]; - let scalar_bits = decompose_scalar_bits(compiler, s_lo, s_hi); + (compiler, range_checks) = verify_point_fakeglv( + compiler, + range_checks, + px, + py, + rx, + ry, + s_lo, + s_hi, + num_limbs, + limb_bits, + window_size, + curve, + ); - // Build coordinates as Limbs - let (px, py) = if num_limbs == 1 { - // Single-limb: wrap witness directly - (Limbs::single(px_witness), Limbs::single(py_witness)) - } else { - // Multi-limb: decompose single witness into num_limbs limbs - let px_limbs = decompose_witness_to_limbs( + 
constrain_zero(compiler, out_inf); + } else { + // Multi-point: compute R_i = [s_i]P_i via hints, verify each with FakeGLV, + // then accumulate R_i's and constrain against ACIR output. + let mut acc: Option<(Limbs, Limbs)> = None; + + for i in 0..n_points { + let px_witness = point_wits[3 * i]; + let py_witness = point_wits[3 * i + 1]; + // Constrain input infinity flag to 0 (affine coordinates cannot represent infinity) + constrain_zero(compiler, point_wits[3 * i + 2]); + let s_lo = scalar_wits[2 * i]; + let s_hi = scalar_wits[2 * i + 1]; + + // Add EcScalarMulHint → R_i = [s_i]P_i + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { + output_start: hint_start, + px: px_witness, + py: py_witness, + s_lo, + s_hi, + curve_a: curve.curve_a, + field_modulus_p: curve.field_modulus_p, + }); + let rx_witness = hint_start; + let ry_witness = hint_start + 1; + + // Decompose P_i into limbs + let (px, py) = decompose_point_to_limbs( compiler, px_witness, - limb_bits, + py_witness, num_limbs, + limb_bits, range_checks, ); - let py_limbs = decompose_witness_to_limbs( + // Decompose R_i into limbs + let (rx, ry) = decompose_point_to_limbs( compiler, - py_witness, - limb_bits, + rx_witness, + ry_witness, num_limbs, + limb_bits, range_checks, ); - (px_limbs, py_limbs) - }; + // Verify R_i = [s_i]P_i using FakeGLV + (compiler, range_checks) = verify_point_fakeglv( + compiler, + range_checks, + px, + py, + rx, + ry, + s_lo, + s_hi, + num_limbs, + limb_bits, + window_size, + curve, + ); + + // Accumulate R_i via point_add + acc = Some(match acc { + None => (rx, ry), + Some((ax, ay)) => { + let params = build_params(num_limbs, limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params, + }; + let sum = ec_points::point_add(&mut ops, ax, ay, rx, ry); + compiler = ops.compiler; + range_checks = ops.range_checks; + sum + } + }); + } + + let (computed_x, computed_y) = acc.expect("MSM must have at least 
one point"); + + if num_limbs == 1 { + constrain_equal(compiler, out_x, computed_x[0]); + constrain_equal(compiler, out_y, computed_y[0]); + } else { + let recomposed_x = recompose_limbs(compiler, computed_x.as_slice(), limb_bits); + let recomposed_y = recompose_limbs(compiler, computed_y.as_slice(), limb_bits); + constrain_equal(compiler, out_x, recomposed_x); + constrain_equal(compiler, out_y, recomposed_y); + } + constrain_zero(compiler, out_inf); + } +} + +/// Decompose a point (px_witness, py_witness) into Limbs. +fn decompose_point_to_limbs( + compiler: &mut NoirToR1CSCompiler, + px_witness: usize, + py_witness: usize, + num_limbs: usize, + limb_bits: u32, + range_checks: &mut BTreeMap>, +) -> (Limbs, Limbs) { + if num_limbs == 1 { + (Limbs::single(px_witness), Limbs::single(py_witness)) + } else { + let px_limbs = + decompose_witness_to_limbs(compiler, px_witness, limb_bits, num_limbs, range_checks); + let py_limbs = + decompose_witness_to_limbs(compiler, py_witness, limb_bits, num_limbs, range_checks); + (px_limbs, py_limbs) + } +} + +/// FakeGLV verification for a single point: verifies R = [s]P. +/// +/// Decomposes s via half-GCD into sub-scalars (s1, s2) and verifies +/// [s1]P + [s2]R = O using interleaved windowed scalar mul with +/// half-width scalars. +/// +/// Returns the mutable references back to the caller for continued use. 
+fn verify_point_fakeglv<'a>( + mut compiler: &'a mut NoirToR1CSCompiler, + mut range_checks: &'a mut BTreeMap>, + px: Limbs, + py: Limbs, + rx: Limbs, + ry: Limbs, + s_lo: usize, + s_hi: usize, + num_limbs: usize, + limb_bits: u32, + window_size: usize, + curve: &CurveParams, +) -> ( + &'a mut NoirToR1CSCompiler, + &'a mut BTreeMap>, +) { + // --- Step 1: On-curve checks for P and R --- + { let params = build_params(num_limbs, limb_bits, curve); let mut ops = MultiLimbOps { compiler, range_checks, params, }; - let result = ec_points::scalar_mul(&mut ops, px, py, &scalar_bits, window_size); + + let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs); + + verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); + verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); + compiler = ops.compiler; range_checks = ops.range_checks; + } + + // --- Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 --- + let glv_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::FakeGLVHint { + output_start: glv_start, + s_lo, + s_hi, + curve_order: curve.curve_order_n, + }); + let s1_witness = glv_start; + let s2_witness = glv_start + 1; + let neg1_witness = glv_start + 2; + let neg2_witness = glv_start + 3; + + // neg1 and neg2 are constrained to be boolean by the `select` calls + // in Step 4 below (MultiLimbOps::select calls constrain_boolean internally). 
+ + // --- Step 3: Decompose |s1|, |s2| into half_bits bits each --- + let half_bits = curve.glv_half_bits() as usize; + let s1_bits = decompose_half_scalar_bits(compiler, s1_witness, half_bits); + let s2_bits = decompose_half_scalar_bits(compiler, s2_witness, half_bits); + + // --- Step 4: Conditionally negate P.y and R.y + GLV scalar mul + identity check --- + { + let params = build_params(num_limbs, limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params, + }; + + // Compute negated y-coordinates: neg_y = 0 - y (mod p) + let zero_limbs = vec![FieldElement::from(0u64); num_limbs]; + let zero = ops.constant_limbs(&zero_limbs); + + let neg_py = ops.sub(zero, py); + let neg_ry = ops.sub(zero, ry); + + // Select: if neg1=1, use neg_py; else use py + let py_effective = ops.select(neg1_witness, py, neg_py); + // Select: if neg2=1, use neg_ry; else use ry + let ry_effective = ops.select(neg2_witness, ry, neg_ry); + + // GLV scalar mul + let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); + let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); + let offset_x = ops.constant_limbs(&offset_x_values); + let offset_y = ops.constant_limbs(&offset_y_values); + + let glv_acc = ec_points::scalar_mul_glv( + &mut ops, + px, + py_effective, + &s1_bits, + rx, + ry_effective, + &s2_bits, + window_size, + offset_x, + offset_y, + ); + + // Identity check: acc should equal [2^(num_windows * window_size)] * offset_point + let glv_num_windows = (half_bits + window_size - 1) / window_size; + let glv_n_doublings = glv_num_windows * window_size; + let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); + + let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs); + let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs); + let expected_x = ops.constant_limbs(&acc_off_x_values); + let expected_y = ops.constant_limbs(&acc_off_y_values); + + for i in 
0..num_limbs { + constrain_equal(ops.compiler, glv_acc.0[i], expected_x[i]); + constrain_equal(ops.compiler, glv_acc.1[i], expected_y[i]); + } - acc = Some(match acc { - None => result, - Some((ax, ay)) => { - let params = build_params(num_limbs, limb_bits, curve); - let mut ops = MultiLimbOps { - compiler, - range_checks, - params, - }; - let sum = ec_points::point_add(&mut ops, ax, ay, result.0, result.1); - compiler = ops.compiler; - range_checks = ops.range_checks; - sum - } - }); + compiler = ops.compiler; + range_checks = ops.range_checks; } - let (computed_x, computed_y) = acc.expect("MSM must have at least one point"); - let (out_x, out_y, out_inf) = outputs; + // --- Step 5: Scalar relation verification --- + verify_scalar_relation( + compiler, + range_checks, + s_lo, + s_hi, + s1_witness, + s2_witness, + neg1_witness, + neg2_witness, + curve, + ); - if num_limbs == 1 { - constrain_equal(compiler, out_x, computed_x[0]); - constrain_equal(compiler, out_y, computed_y[0]); - } else { - let recomposed_x = recompose_limbs(compiler, computed_x.as_slice(), limb_bits); - let recomposed_y = recompose_limbs(compiler, computed_y.as_slice(), limb_bits); - constrain_equal(compiler, out_x, recomposed_x); - constrain_equal(compiler, out_y, recomposed_y); + (compiler, range_checks) +} + +/// On-curve check: verifies y^2 = x^3 + a*x + b for a single point. 
+fn verify_on_curve( + ops: &mut MultiLimbOps, + x: Limbs, + y: Limbs, + b_limb_values: &[FieldElement], + num_limbs: usize, +) { + let y_sq = ops.mul(y, y); + let x_sq = ops.mul(x, x); + let x_cubed = ops.mul(x_sq, x); + let a = ops.curve_a(); + let ax = ops.mul(a, x); + let x3_plus_ax = ops.add(x_cubed, ax); + let b = ops.constant_limbs(b_limb_values); + let rhs = ops.add(x3_plus_ax, b); + for i in 0..num_limbs { + constrain_equal(ops.compiler, y_sq[i], rhs[i]); } - constrain_zero(compiler, out_inf); } /// Decompose a single witness into `num_limbs` limbs using digital @@ -456,26 +692,171 @@ fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitnes } } -/// Decomposes a scalar given as two 128-bit limbs into 256 bit witnesses (LSB -/// first). -fn decompose_scalar_bits( +/// Decomposes a half-scalar witness into `half_bits` bit witnesses (LSB first). +fn decompose_half_scalar_bits( compiler: &mut NoirToR1CSCompiler, - s_lo: usize, - s_hi: usize, + scalar: usize, + half_bits: usize, ) -> Vec { - let log_bases_128 = vec![1usize; 128]; + let log_bases = vec![1usize; half_bits]; + let dd = add_digital_decomposition(compiler, log_bases, vec![scalar]); + let mut bits = Vec::with_capacity(half_bits); + for bit_idx in 0..half_bits { + bits.push(dd.get_digit_witness_index(bit_idx, 0)); + } + bits +} + +/// Builds `MultiLimbParams` for scalar relation verification (mod +/// curve_order_n). +fn build_scalar_relation_params( + num_limbs: usize, + limb_bits: u32, + curve: &CurveParams, +) -> MultiLimbParams { + // Scalar relation uses curve_order_n as the modulus. + // This is always non-native (curve_order_n ≠ BN254 scalar field modulus, + // except for Grumpkin where they're very close but still different). 
+ let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let n_limbs = curve.curve_order_n_limbs(limb_bits, num_limbs); + let n_minus_1_limbs = curve.curve_order_n_minus_1_limbs(limb_bits, num_limbs); - let dd_lo = add_digital_decomposition(compiler, log_bases_128.clone(), vec![s_lo]); - let dd_hi = add_digital_decomposition(compiler, log_bases_128, vec![s_hi]); + // For N=1 non-native, we need the modulus as a FieldElement + let modulus_fe = if num_limbs == 1 { + Some(curve::curve_native_point_fe(&curve.curve_order_n)) + } else { + None + }; - let mut bits = Vec::with_capacity(256); - for bit_idx in 0..128 { - bits.push(dd_lo.get_digit_witness_index(bit_idx, 0)); + MultiLimbParams { + num_limbs, + limb_bits, + p_limbs: n_limbs, + p_minus_1_limbs: n_minus_1_limbs, + two_pow_w, + modulus_raw: curve.curve_order_n, + curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused + modulus_bits: curve.curve_order_bits(), + is_native: false, // always non-native + modulus_fe, } - for bit_idx in 0..128 { - bits.push(dd_hi.get_digit_witness_index(bit_idx, 0)); +} + +/// Verifies the scalar relation: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 +/// (mod n). +/// +/// Uses multi-limb arithmetic with curve_order_n as the modulus. +/// The sub-scalars s1, s2 have `half_bits = ceil(order_bits/2)` bits; +/// the full scalar s has up to `order_bits` bits. +fn verify_scalar_relation( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + s_lo: usize, + s_hi: usize, + s1_witness: usize, + s2_witness: usize, + neg1_witness: usize, + neg2_witness: usize, + curve: &CurveParams, +) { + // Use 64-bit limbs. Number of limbs covers the full curve order. 
+ let sr_limb_bits: u32 = 64; + let order_bits = curve.curve_order_bits() as usize; + let sr_num_limbs = (order_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + let half_bits = curve.glv_half_bits() as usize; + // Number of 64-bit limbs the half-scalar occupies + let half_limbs = (half_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + + let params = build_scalar_relation_params(sr_num_limbs, sr_limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params, + }; + + // Decompose s into sr_num_limbs × 64-bit limbs from (s_lo, s_hi) + // s_lo contains bits [0..128), s_hi contains bits [128..256) + let s_limbs = { + let dd_lo = add_digital_decomposition(ops.compiler, vec![64, 64], vec![s_lo]); + let dd_hi = add_digital_decomposition(ops.compiler, vec![64, 64], vec![s_hi]); + let mut limbs = Limbs::new(sr_num_limbs); + // s_lo provides limbs 0,1; s_hi provides limbs 2,3 (for sr_num_limbs=4) + let lo_n = 2.min(sr_num_limbs); + for i in 0..lo_n { + limbs[i] = dd_lo.get_digit_witness_index(i, 0); + ops.range_checks.entry(64).or_default().push(limbs[i]); + } + let hi_n = sr_num_limbs - lo_n; + for i in 0..hi_n { + limbs[lo_n + i] = dd_hi.get_digit_witness_index(i, 0); + ops.range_checks + .entry(64) + .or_default() + .push(limbs[lo_n + i]); + } + limbs + }; + + // Helper: decompose a half-scalar witness into sr_num_limbs × 64-bit limbs. + // The half-scalar has `half_bits` bits → occupies `half_limbs` 64-bit limbs. + // Upper limbs (half_limbs..sr_num_limbs) are zero-padded. 
+ let decompose_half_scalar = |ops: &mut MultiLimbOps, witness: usize| -> Limbs { + let dd_bases: Vec = (0..half_limbs) + .map(|i| { + let remaining = half_bits as u32 - (i as u32 * 64); + remaining.min(64) as usize + }) + .collect(); + let dd = add_digital_decomposition(ops.compiler, dd_bases, vec![witness]); + let mut limbs = Limbs::new(sr_num_limbs); + for i in 0..half_limbs { + limbs[i] = dd.get_digit_witness_index(i, 0); + let remaining_bits = (half_bits as u32) - (i as u32 * 64); + let this_limb_bits = remaining_bits.min(64); + ops.range_checks + .entry(this_limb_bits) + .or_default() + .push(limbs[i]); + } + // Zero-pad upper limbs + for i in half_limbs..sr_num_limbs { + let w = ops.compiler.num_witnesses(); + ops.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( + w, + FieldElement::from(0u64), + ))); + limbs[i] = w; + constrain_zero(ops.compiler, limbs[i]); + } + limbs + }; + + let s1_limbs = decompose_half_scalar(&mut ops, s1_witness); + let s2_limbs = decompose_half_scalar(&mut ops, s2_witness); + + // Compute product = s2 * s (mod n) + let product = ops.mul(s2_limbs, s_limbs); + + // Handle signs: compute effective values + // If neg2 is set: neg_product = n - product (mod n), i.e. 0 - product + let zero_limbs_vals = vec![FieldElement::from(0u64); sr_num_limbs]; + let zero = ops.constant_limbs(&zero_limbs_vals); + let neg_product = ops.sub(zero, product); + // Select: if neg2=1, use neg_product; else use product + let effective_product = ops.select(neg2_witness, product, neg_product); + + // If neg1 is set: neg_s1 = n - s1 (mod n), i.e. 
0 - s1 + let neg_s1 = ops.sub(zero, s1_limbs); + let effective_s1 = ops.select(neg1_witness, s1_limbs, neg_s1); + + // Sum: effective_s1 + effective_product (mod n) should be 0 + let sum = ops.add(effective_s1, effective_product); + + // Constrain sum == 0: all limbs must be zero + for i in 0..sr_num_limbs { + constrain_zero(ops.compiler, sum[i]); } - bits } /// Constrains two witnesses to be equal: `a - b = 0`. diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs index ab84fc9b7..12c30b382 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -54,10 +54,7 @@ pub fn reduce_mod_p( ); let modulus_bits = modulus.into_bigint().num_bits(); - range_checks - .entry(modulus_bits) - .or_default() - .push(result); + range_checks.entry(modulus_bits).or_default().push(result); result } @@ -169,13 +166,10 @@ pub fn compute_is_zero(compiler: &mut NoirToR1CSCompiler, value: usize) -> usize )); let is_zero = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::Sum( - is_zero, - vec![ - SumTerm(Some(FieldElement::ONE), compiler.witness_one()), - SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), - ], - )); + compiler.add_witness_builder(WitnessBuilder::Sum(is_zero, vec![ + SumTerm(Some(FieldElement::ONE), compiler.witness_one()), + SumTerm(Some(-FieldElement::ONE), value_mul_value_inv), + ])); // v × v^(-1) = 1 - is_zero compiler.r1cs.add_constraint( @@ -223,43 +217,47 @@ pub fn add_mod_p_multi( // Witness: q = floor((a + b) / p) ∈ {0, 1} let q = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::MultiLimbAddQuotient { - output: q, - a_limbs: a.as_slice().to_vec(), - b_limbs: b.as_slice().to_vec(), - modulus: *modulus_raw, + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, limb_bits, num_limbs: n as u32, }); // q is boolean - 
compiler.r1cs.add_constraint( - &[(FieldElement::ONE, q)], - &[(FieldElement::ONE, q)], - &[(FieldElement::ONE, q)], - ); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); let mut r = Limbs::new(n); let mut carry_prev: Option = None; for i in 0..n { // v_offset = a[i] + b[i] + 2^W - q*p[i] + carry_{i-1} + // When carry_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). + let w1_coeff = if carry_prev.is_some() { + two_pow_w - FieldElement::ONE + } else { + two_pow_w + }; let mut terms = vec![ SumTerm(None, a[i]), SumTerm(None, b[i]), - SumTerm(Some(two_pow_w), w1), + SumTerm(Some(w1_coeff), w1), SumTerm(Some(-p_limbs[i]), q), ]; if let Some(carry) = carry_prev { terms.push(SumTerm(None, carry)); - // Compensate for previous 2^W offset - terms.push(SumTerm(Some(-FieldElement::ONE), w1)); } let v_offset = compiler.add_sum(terms); // carry = floor(v_offset / 2^W) let carry = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - carry, v_offset, two_pow_w, - )); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_offset, two_pow_w)); // r[i] = v_offset - carry * 2^W r[i] = compiler.add_sum(vec![ SumTerm(None, v_offset), @@ -268,7 +266,14 @@ pub fn add_mod_p_multi( carry_prev = Some(carry); } - less_than_p_check_multi(compiler, range_checks, r, p_minus_1_limbs, two_pow_w, limb_bits); + less_than_p_check_multi( + compiler, + range_checks, + r, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); r } @@ -292,41 +297,46 @@ pub fn sub_mod_p_multi( // Witness: q = (a < b) ? 
1 : 0 let q = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::MultiLimbSubBorrow { - output: q, - a_limbs: a.as_slice().to_vec(), - b_limbs: b.as_slice().to_vec(), - modulus: *modulus_raw, + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, limb_bits, num_limbs: n as u32, }); // q is boolean - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, q)], - &[(FieldElement::ONE, q)], - &[(FieldElement::ONE, q)], - ); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( + FieldElement::ONE, + q, + )]); let mut r = Limbs::new(n); let mut carry_prev: Option = None; for i in 0..n { // v_offset = a[i] - b[i] + q*p[i] + 2^W + carry_{i-1} + // When carry_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). + let w1_coeff = if carry_prev.is_some() { + two_pow_w - FieldElement::ONE + } else { + two_pow_w + }; let mut terms = vec![ SumTerm(None, a[i]), SumTerm(Some(-FieldElement::ONE), b[i]), SumTerm(Some(p_limbs[i]), q), - SumTerm(Some(two_pow_w), w1), + SumTerm(Some(w1_coeff), w1), ]; if let Some(carry) = carry_prev { terms.push(SumTerm(None, carry)); - terms.push(SumTerm(Some(-FieldElement::ONE), w1)); } let v_offset = compiler.add_sum(terms); let carry = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - carry, v_offset, two_pow_w, - )); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_offset, two_pow_w)); r[i] = compiler.add_sum(vec![ SumTerm(None, v_offset), SumTerm(Some(-two_pow_w), carry), @@ -334,7 +344,14 @@ pub fn sub_mod_p_multi( carry_prev = Some(carry); } - less_than_p_check_multi(compiler, range_checks, r, p_minus_1_limbs, two_pow_w, limb_bits); + less_than_p_check_multi( + compiler, + range_checks, + r, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); r } @@ -367,9 +384,8 @@ pub fn mul_mod_p_multi( 
let max_bits = 2 * limb_bits + ceil_log2_n + 3; assert!( max_bits < FieldElement::MODULUS_BIT_SIZE, - "Schoolbook column equation overflow: limb_bits={limb_bits}, n={n} limbs \ - requires {max_bits} bits, but native field is only {} bits. \ - Use smaller limb_bits.", + "Schoolbook column equation overflow: limb_bits={limb_bits}, n={n} limbs requires \ + {max_bits} bits, but native field is only {} bits. Use smaller limb_bits.", FieldElement::MODULUS_BIT_SIZE, ); } @@ -382,18 +398,19 @@ pub fn mul_mod_p_multi( let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); // offset_w = carry_offset * 2^limb_bits let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); - // offset_w_minus_carry = offset_w - carry_offset = carry_offset * (2^limb_bits - 1) + // offset_w_minus_carry = offset_w - carry_offset = carry_offset * (2^limb_bits + // - 1) let offset_w_minus_carry = offset_w - carry_offset_fe; // Step 1: Allocate hint witnesses (q limbs, r limbs, carries) let os = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::MultiLimbMulModHint { output_start: os, - a_limbs: a.as_slice().to_vec(), - b_limbs: b.as_slice().to_vec(), - modulus: *modulus_raw, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, limb_bits, - num_limbs: n as u32, + num_limbs: n as u32, }); // q[0..n), r[n..2n), carries[2n..4n-2) @@ -459,7 +476,14 @@ pub fn mul_mod_p_multi( for (i, &ri) in r_indices.iter().enumerate() { r_limbs[i] = ri; } - less_than_p_check_multi(compiler, range_checks, r_limbs, p_minus_1_limbs, two_pow_w, limb_bits); + less_than_p_check_multi( + compiler, + range_checks, + r_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + ); // Step 5: Range checks for q limbs and carries for i in 0..n { @@ -475,7 +499,8 @@ pub fn mul_mod_p_multi( } /// a^(-1) mod p for multi-limb values. -/// Uses MultiLimbModularInverse hint, verifies via mul_mod_p(a, inv) = [1, 0, ..., 0]. 
+/// Uses MultiLimbModularInverse hint, verifies via mul_mod_p(a, inv) = [1, 0, +/// ..., 0]. pub fn inv_mod_p_multi( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap>, @@ -493,14 +518,15 @@ pub fn inv_mod_p_multi( let inv_start = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::MultiLimbModularInverse { output_start: inv_start, - a_limbs: a.as_slice().to_vec(), - modulus: *modulus_raw, + a_limbs: a.as_slice().to_vec(), + modulus: *modulus_raw, limb_bits, - num_limbs: n as u32, + num_limbs: n as u32, }); let mut inv = Limbs::new(n); for i in 0..n { inv[i] = inv_start + i; + range_checks.entry(limb_bits).or_default().push(inv[i]); } // Verify: a * inv mod p = [1, 0, ..., 0] @@ -535,7 +561,8 @@ pub fn inv_mod_p_multi( } /// Proves r < p by decomposing (p-1) - r into non-negative multi-limb values. -/// Uses borrow propagation: d[i] = (p-1)[i] - r[i] + borrow_in - borrow_out * 2^W +/// Uses borrow propagation: d[i] = (p-1)[i] - r[i] + borrow_in - borrow_out * +/// 2^W fn less_than_p_check_multi( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap>, @@ -547,25 +574,27 @@ fn less_than_p_check_multi( let n = r.len(); let w1 = compiler.witness_one(); let mut borrow_prev: Option = None; - for i in 0..n { // v_diff = (p-1)[i] + 2^W - r[i] + borrow_prev - let p_minus_1_plus_offset = p_minus_1_limbs[i] + two_pow_w; + // When borrow_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
+ let w1_coeff = if borrow_prev.is_some() { + p_minus_1_limbs[i] + two_pow_w - FieldElement::ONE + } else { + p_minus_1_limbs[i] + two_pow_w + }; let mut terms = vec![ - SumTerm(Some(p_minus_1_plus_offset), w1), + SumTerm(Some(w1_coeff), w1), SumTerm(Some(-FieldElement::ONE), r[i]), ]; if let Some(borrow) = borrow_prev { terms.push(SumTerm(None, borrow)); - terms.push(SumTerm(Some(-FieldElement::ONE), w1)); } let v_diff = compiler.add_sum(terms); // borrow = floor(v_diff / 2^W) let borrow = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient( - borrow, v_diff, two_pow_w, - )); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(borrow, v_diff, two_pow_w)); // d[i] = v_diff - borrow * 2^W let d_i = compiler.add_sum(vec![ SumTerm(None, v_diff), @@ -579,13 +608,15 @@ fn less_than_p_check_multi( borrow_prev = Some(borrow); } - // Constrain final borrow = 0: if borrow_out != 0, then r > p-1 (i.e. r >= p), - // which would mean the result is not properly reduced. + // Constrain final carry = 1: the 2^W offset at each limb propagates + // a carry of 1 through the chain. For valid r < p, the final carry + // must be exactly 1. If r >= p, the carry chain underflows and the + // final carry is 0. if let Some(final_borrow) = borrow_prev { compiler.r1cs.add_constraint( &[(FieldElement::ONE, compiler.witness_one())], &[(FieldElement::ONE, final_borrow)], - &[(FieldElement::ZERO, compiler.witness_one())], + &[(FieldElement::ONE, compiler.witness_one())], ); } } diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs index 4f1c45448..9b1d9db45 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -1,14 +1,11 @@ -//! `MultiLimbOps` — unified FieldOps implementation parameterized by runtime limb count. +//! `MultiLimbOps` — unified FieldOps implementation parameterized by runtime +//! limb count. //! //! 
Uses `Limbs` (a fixed-capacity Copy type) as `FieldOps::Elem`, enabling //! arbitrary limb counts without const generics or dispatch macros. use { - super::{ - multi_limb_arith, - Limbs, - FieldOps, - }, + super::{multi_limb_arith, FieldOps, Limbs}, crate::noir_to_r1cs::NoirToR1CSCompiler, ark_ff::{AdditiveGroup, Field}, provekit_common::{ @@ -20,18 +17,18 @@ use { /// Parameters for multi-limb field arithmetic. pub struct MultiLimbParams { - pub num_limbs: usize, - pub limb_bits: u32, - pub p_limbs: Vec, - pub p_minus_1_limbs: Vec, - pub two_pow_w: FieldElement, - pub modulus_raw: [u64; 4], - pub curve_a_limbs: Vec, - pub modulus_bits: u32, + pub num_limbs: usize, + pub limb_bits: u32, + pub p_limbs: Vec, + pub p_minus_1_limbs: Vec, + pub two_pow_w: FieldElement, + pub modulus_raw: [u64; 4], + pub curve_a_limbs: Vec, + pub modulus_bits: u32, /// p = native field → skip mod reduction - pub is_native: bool, + pub is_native: bool, /// For N=1 non-native: the modulus as a single FieldElement - pub modulus_fe: Option, + pub modulus_fe: Option, } /// Unified field operations struct parameterized by runtime limb count. @@ -66,20 +63,21 @@ impl FieldOps for MultiLimbOps<'_> { // term with coefficient 2 to avoid duplicate column indices in // the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
let r = if a[0] == b[0] { - self.compiler.add_sum(vec![ - SumTerm(Some(FieldElement::from(2u64)), a[0]), - ]) + self.compiler + .add_sum(vec![SumTerm(Some(FieldElement::from(2u64)), a[0])]) } else { - self.compiler.add_sum(vec![ - SumTerm(None, a[0]), - SumTerm(None, b[0]), - ]) + self.compiler + .add_sum(vec![SumTerm(None, a[0]), SumTerm(None, b[0])]) }; Limbs::single(r) } else if self.is_non_native_single() { let modulus = self.params.modulus_fe.unwrap(); let r = multi_limb_arith::add_mod_p_single( - self.compiler, a[0], b[0], modulus, self.range_checks, + self.compiler, + a[0], + b[0], + modulus, + self.range_checks, ); Limbs::single(r) } else { @@ -104,9 +102,8 @@ impl FieldOps for MultiLimbOps<'_> { // When both operands are the same witness, a - a = 0. Use a // single zero-coefficient term to avoid duplicate column indices. let r = if a[0] == b[0] { - self.compiler.add_sum(vec![ - SumTerm(Some(FieldElement::ZERO), a[0]), - ]) + self.compiler + .add_sum(vec![SumTerm(Some(FieldElement::ZERO), a[0])]) } else { self.compiler.add_sum(vec![ SumTerm(None, a[0]), @@ -117,7 +114,11 @@ impl FieldOps for MultiLimbOps<'_> { } else if self.is_non_native_single() { let modulus = self.params.modulus_fe.unwrap(); let r = multi_limb_arith::sub_mod_p_single( - self.compiler, a[0], b[0], modulus, self.range_checks, + self.compiler, + a[0], + b[0], + modulus, + self.range_checks, ); Limbs::single(r) } else { @@ -144,7 +145,11 @@ impl FieldOps for MultiLimbOps<'_> { } else if self.is_non_native_single() { let modulus = self.params.modulus_fe.unwrap(); let r = multi_limb_arith::mul_mod_p_single( - self.compiler, a[0], b[0], modulus, self.range_checks, + self.compiler, + a[0], + b[0], + modulus, + self.range_checks, ); Limbs::single(r) } else { @@ -177,9 +182,8 @@ impl FieldOps for MultiLimbOps<'_> { Limbs::single(a_inv) } else if self.is_non_native_single() { let modulus = self.params.modulus_fe.unwrap(); - let r = multi_limb_arith::inv_mod_p_single( - self.compiler, a[0], 
modulus, self.range_checks, - ); + let r = + multi_limb_arith::inv_mod_p_single(self.compiler, a[0], modulus, self.range_checks); Limbs::single(r) } else { multi_limb_arith::inv_mod_p_multi( @@ -210,12 +214,7 @@ impl FieldOps for MultiLimbOps<'_> { out } - fn select( - &mut self, - flag: usize, - on_false: Limbs, - on_true: Limbs, - ) -> Limbs { + fn select(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { super::constrain_boolean(self.compiler, flag); let n = self.n(); let mut out = Limbs::new(n); @@ -233,43 +232,21 @@ impl FieldOps for MultiLimbOps<'_> { super::pack_bits_helper(self.compiler, bits) } - fn elem_is_zero(&mut self, value: Limbs) -> usize { - let n = self.n(); - if n == 1 { - multi_limb_arith::compute_is_zero(self.compiler, value[0]) - } else { - // Check each limb is zero and AND the results together - let mut result = multi_limb_arith::compute_is_zero(self.compiler, value[0]); - for i in 1..n { - let limb_zero = multi_limb_arith::compute_is_zero(self.compiler, value[i]); - result = self.compiler.add_product(result, limb_zero); - } - result - } - } - - fn constant_one(&mut self) -> Limbs { + fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Limbs { let n = self.n(); + assert_eq!( + limbs.len(), + n, + "constant_limbs: expected {n} limbs, got {}", + limbs.len() + ); let mut out = Limbs::new(n); - // limb[0] = 1 - let w0 = self.compiler.num_witnesses(); - self.compiler - .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w0, FieldElement::ONE))); - out[0] = w0; - // limb[1..n] = 0 - for i in 1..n { + for i in 0..n { let w = self.compiler.num_witnesses(); self.compiler - .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( - w, - FieldElement::ZERO, - ))); + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, limbs[i]))); out[i] = w; } out } - - fn bool_and(&mut self, a: usize, b: usize) -> usize { - self.compiler.add_product(a, b) - } } diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs 
b/provekit/r1cs-compiler/src/noir_to_r1cs.rs index 18bc22ddc..2475ddf8c 100644 --- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs +++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs @@ -750,12 +750,19 @@ impl NoirToR1CSCompiler { let native_bits = FieldElement::MODULUS_BIT_SIZE; let curve_bits = curve.modulus_bits(); let (msm_limb_bits, msm_window_size) = if !msm_ops.is_empty() { - let n_points: usize = msm_ops.iter().map(|(pts, _, _)| pts.len() / 3).sum(); + let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / 3).sum(); crate::msm::cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256) } else { (native_bits, 4) }; - add_msm(self, msm_ops, msm_limb_bits, msm_window_size, &mut range_checks, &curve); + add_msm( + self, + msm_ops, + msm_limb_bits, + msm_window_size, + &mut range_checks, + &curve, + ); breakdown.msm_constraints = self.r1cs.num_constraints() - constraints_before_msm; breakdown.msm_witnesses = self.num_witnesses() - witnesses_before_msm; diff --git a/tooling/provekit-bench/tests/compiler.rs b/tooling/provekit-bench/tests/compiler.rs index 3b74c8e99..8212290b2 100644 --- a/tooling/provekit-bench/tests/compiler.rs +++ b/tooling/provekit-bench/tests/compiler.rs @@ -81,6 +81,7 @@ pub fn compile_workspace(workspace_path: impl AsRef) -> Result #[test_case("../../noir-examples/noir-r1cs-test-programs/bounded-vec")] #[test_case("../../noir-examples/noir-r1cs-test-programs/brillig-unconstrained")] #[test_case("../../noir-examples/noir-passport-monolithic/complete_age_check"; "complete_age_check")] +#[test_case("../../noir-examples/embedded_curve_msm"; "embedded_curve_msm")] fn case_noir(path: &str) { test_noir_compiler(path); } From 92ab2369e97cb2542d97aed2d490b960752a40a7 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 5 Mar 2026 05:11:17 +0530 Subject: [PATCH 05/19] opt: added unchecked select for already constrained values in scalar relation verification --- provekit/r1cs-compiler/src/msm/ec_points.rs | 25 ++++++++++++++++--- 
provekit/r1cs-compiler/src/msm/mod.rs | 23 ++++++++++++++--- .../r1cs-compiler/src/msm/multi_limb_ops.rs | 5 +++- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 8f7172897..1b591ed8a 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -100,6 +100,19 @@ pub fn point_select( (x, y) } +/// Conditional point select without boolean constraint on `flag`. +/// Caller must ensure `flag` is already constrained boolean. +fn point_select_unchecked( + ops: &mut F, + flag: usize, + on_false: (F::Elem, F::Elem), + on_true: (F::Elem, F::Elem), +) -> (F::Elem, F::Elem) { + let x = ops.select_unchecked(flag, on_false.0, on_true.0); + let y = ops.select_unchecked(flag, on_false.1, on_true.1); + (x, y) +} + /// Builds a point table for windowed scalar multiplication. /// /// T[0] = P (dummy entry, used when window digit = 0) @@ -130,6 +143,9 @@ fn build_point_table( /// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, /// halving the candidate set at each level. Total: `(2^w - 1)` point selects /// for a table of `2^w` entries. +/// +/// Each bit is constrained boolean exactly once, then all subsequent selects +/// on that bit use the unchecked variant. 
fn table_lookup( ops: &mut F, table: &[(F::Elem, F::Elem)], @@ -139,10 +155,11 @@ fn table_lookup( let mut current: Vec<(F::Elem, F::Elem)> = table.to_vec(); // Process bits from MSB to LSB for &bit in bits.iter().rev() { + ops.constrain_flag(bit); // constrain boolean once per bit let half = current.len() / 2; let mut next = Vec::with_capacity(half); for i in 0..half { - next.push(point_select(ops, bit, current[i], current[i + half])); + next.push(point_select_unchecked(ops, bit, current[i], current[i + half])); } current = next; } @@ -224,7 +241,8 @@ pub fn scalar_mul_glv( ); let digit_p = ops.pack_bits(s1_window_bits); let digit_p_is_zero = ops.is_zero(digit_p); - let after_p = point_select(ops, digit_p_is_zero, added_p, doubled_acc); + // is_zero already constrains its output boolean; skip redundant check + let after_p = point_select_unchecked(ops, digit_p_is_zero, added_p, doubled_acc); // --- Process R's window digit (s2) --- let s2_window_bits = &s2_bits[bit_start..bit_end]; @@ -243,7 +261,8 @@ pub fn scalar_mul_glv( ); let digit_r = ops.pack_bits(s2_window_bits); let digit_r_is_zero = ops.is_zero(digit_r); - acc = point_select(ops, digit_r_is_zero, added_r, after_p); + // is_zero already constrains its output boolean; skip redundant check + acc = point_select_unchecked(ops, digit_r_is_zero, added_r, after_p); } acc diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 826d381c4..6bc96b8f9 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -139,9 +139,24 @@ pub trait FieldOps { fn inv(&mut self, a: Self::Elem) -> Self::Elem; fn curve_a(&mut self) -> Self::Elem; + /// Constrains `flag` to be boolean (`flag * flag = flag`). + fn constrain_flag(&mut self, flag: usize); + + /// Conditional select without boolean constraint on `flag`. + /// Caller must ensure `flag` is already constrained boolean. 
+ fn select_unchecked( + &mut self, + flag: usize, + on_false: Self::Elem, + on_true: Self::Elem, + ) -> Self::Elem; + /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`). - fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem; + fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem { + self.constrain_flag(flag); + self.select_unchecked(flag, on_false, on_true) + } /// Checks if a BN254 native witness value is zero. /// Returns a boolean witness: 1 if zero, 0 if non-zero. @@ -844,11 +859,13 @@ fn verify_scalar_relation( let zero = ops.constant_limbs(&zero_limbs_vals); let neg_product = ops.sub(zero, product); // Select: if neg2=1, use neg_product; else use product - let effective_product = ops.select(neg2_witness, product, neg_product); + // neg2 already constrained boolean in verify_point_fakeglv + let effective_product = ops.select_unchecked(neg2_witness, product, neg_product); // If neg1 is set: neg_s1 = n - s1 (mod n), i.e. 
0 - s1 let neg_s1 = ops.sub(zero, s1_limbs); - let effective_s1 = ops.select(neg1_witness, s1_limbs, neg_s1); + // neg1 already constrained boolean in verify_point_fakeglv + let effective_s1 = ops.select_unchecked(neg1_witness, s1_limbs, neg_s1); // Sum: effective_s1 + effective_product (mod n) should be 0 let sum = ops.add(effective_s1, effective_product); diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs index 9b1d9db45..7ac8d78ac 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -214,8 +214,11 @@ impl FieldOps for MultiLimbOps<'_> { out } - fn select(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { + fn constrain_flag(&mut self, flag: usize) { super::constrain_boolean(self.compiler, flag); + } + + fn select_unchecked(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { let n = self.n(); let mut out = Limbs::new(n); for i in 0..n { From 405276574470e440579972cbcd66933ad60a1802 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 6 Mar 2026 02:09:22 +0530 Subject: [PATCH 06/19] feat : added scalar 0 and inf handling --- .../src/witness/scheduling/dependency.rs | 12 +- .../prover/src/{witness => }/bigint_mod.rs | 319 +----------- provekit/prover/src/lib.rs | 1 + provekit/prover/src/witness/mod.rs | 1 - .../prover/src/witness/witness_builder.rs | 32 +- provekit/r1cs-compiler/src/digits.rs | 4 +- provekit/r1cs-compiler/src/msm/cost_model.rs | 10 +- provekit/r1cs-compiler/src/msm/curve.rs | 35 +- provekit/r1cs-compiler/src/msm/ec_points.rs | 18 +- provekit/r1cs-compiler/src/msm/mod.rs | 483 +++++++++++++----- .../r1cs-compiler/src/msm/multi_limb_ops.rs | 16 +- provekit/r1cs-compiler/src/noir_to_r1cs.rs | 19 +- 12 files changed, 424 insertions(+), 526 deletions(-) rename provekit/prover/src/{witness => }/bigint_mod.rs (74%) diff --git a/provekit/common/src/witness/scheduling/dependency.rs 
b/provekit/common/src/witness/scheduling/dependency.rs index 98ae1368f..68bc7b6e1 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -327,12 +327,12 @@ impl DependencyInfo { num_limbs, .. } => (*output_start..*output_start + *num_limbs as usize).collect(), - WitnessBuilder::FakeGLVHint { - output_start, .. - } => (*output_start..*output_start + 4).collect(), - WitnessBuilder::EcScalarMulHint { - output_start, .. - } => (*output_start..*output_start + 2).collect(), + WitnessBuilder::FakeGLVHint { output_start, .. } => { + (*output_start..*output_start + 4).collect() + } + WitnessBuilder::EcScalarMulHint { output_start, .. } => { + (*output_start..*output_start + 2).collect() + } WitnessBuilder::MultiLimbAddQuotient { output, .. } => vec![*output], WitnessBuilder::MultiLimbSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) => { diff --git a/provekit/prover/src/witness/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs similarity index 74% rename from provekit/prover/src/witness/bigint_mod.rs rename to provekit/prover/src/bigint_mod.rs index 2874d49a3..796fd4cdf 100644 --- a/provekit/prover/src/witness/bigint_mod.rs +++ b/provekit/prover/src/bigint_mod.rs @@ -230,7 +230,8 @@ pub fn add_4limb_inplace(a: &mut [u64; 4], b: &[u64; 4]) -> u64 { } /// Subtract b from a in-place, returning true if a >= b (no underflow). -/// If a < b, the result is a += 2^256 - b (wrapping subtraction) and returns false. +/// If a < b, the result is a += 2^256 - b (wrapping subtraction) and returns +/// false. pub fn sub_4limb_checked(a: &mut [u64; 4], b: &[u64; 4]) -> bool { let mut borrow = 0u64; for i in 0..4 { @@ -272,16 +273,13 @@ fn build_threshold(half_bits: u32) -> [u64; 4] { /// Half-GCD scalar decomposition for FakeGLV. 
/// -/// Given scalar `s` and curve order `n`, finds `(|s1|, |s2|, neg1, neg2)` such that: -/// `(-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n)` +/// Given scalar `s` and curve order `n`, finds `(|s1|, |s2|, neg1, neg2)` such +/// that: `(-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n)` /// /// Uses the extended GCD on `(n, s)`, stopping when the remainder drops below /// `2^half_bits` where `half_bits = ceil(order_bits / 2)`. /// Returns `(val1, val2, neg1, neg2)` where both fit in `half_bits` bits. -pub fn half_gcd( - s: &[u64; 4], - n: &[u64; 4], -) -> ([u64; 4], [u64; 4], bool, bool) { +pub fn half_gcd(s: &[u64; 4], n: &[u64; 4]) -> ([u64; 4], [u64; 4], bool, bool) { // Extended GCD on (n, s): // We track: r_{i} = r_{i-2} - q_i * r_{i-1} // t_{i} = t_{i-2} - q_i * t_{i-1} @@ -304,7 +302,8 @@ pub fn half_gcd( let mut t_prev = [0u64; 4]; let mut t_curr = [1u64, 0, 0, 0]; - // Track sign of t: t_prev_neg=false (t_0=0, positive), t_curr_neg=false (t_1=1, positive) + // Track sign of t: t_prev_neg=false (t_0=0, positive), t_curr_neg=false (t_1=1, + // positive) let mut t_prev_neg = false; let mut t_curr_neg = false; @@ -328,7 +327,8 @@ pub fn half_gcd( // In terms of absolute values with sign tracking: // If t_prev and q*t_curr have the same sign → subtract magnitudes // If they have different signs → add magnitudes - // But actually: new_t = |t_prev| +/- q * |t_curr|, with sign flips each iteration. + // new_t = |t_prev| +/- q * |t_curr|, with sign flips each + // iteration. 
// // The standard extended GCD recurrence gives: // t_i = t_{i-2} - q_i * t_{i-1} @@ -357,12 +357,12 @@ pub fn half_gcd( iteration += 1; } - // At this point: r_curr < 2^half_bits and t_curr < ~2^half_bits (half-GCD property) - // The relation is: (-1)^(iteration) * r_curr + t_curr * s ≡ 0 (mod n) - // Or equivalently: r_curr ≡ (-1)^(iteration+1) * t_curr * s (mod n) + // At this point: r_curr < 2^half_bits and t_curr < ~2^half_bits (half-GCD + // property) The relation is: (-1)^(iteration) * r_curr + t_curr * s ≡ 0 + // (mod n) Or equivalently: r_curr ≡ (-1)^(iteration+1) * t_curr * s (mod n) - let val1 = r_curr; // |s1| = |r_i| - let val2 = t_curr; // |s2| = |t_i| + let val1 = r_curr; // |s1| = |r_i| + let val2 = t_curr; // |s2| = |t_i| // Determine signs: // We need: neg1 * val1 + neg2 * val2 * s ≡ 0 (mod n) @@ -376,7 +376,7 @@ pub fn half_gcd( // If iteration is odd: (-1)^(odd+1) = 1, so: r_i + t_i * s ≡ 0 // → neg1=false, neg2=t_curr_neg - let neg1 = iteration % 2 == 0; // negate val1 when iteration is even + let neg1 = iteration % 2 == 0; // negate val1 when iteration is even let neg2 = t_curr_neg; (val1, val2, neg1, neg2) @@ -581,165 +581,6 @@ pub fn divmod_wide(dividend: &[u64; 8], divisor: &[u64; 4]) -> ([u64; 4], [u64; (quotient, remainder) } -/// Split a 256-bit value into two 128-bit halves: (lo, hi). -pub fn decompose_128(val: &[u64; 4]) -> (u128, u128) { - let lo = val[0] as u128 | ((val[1] as u128) << 64); - let hi = val[2] as u128 | ((val[3] as u128) << 64); - (lo, hi) -} - -/// Split a 256-bit value into three 86-bit limbs: (l0, l1, l2). -/// l0 = bits [0..86), l1 = bits [86..172), l2 = bits [172..256). 
-pub fn decompose_86(val: &[u64; 4]) -> (u128, u128, u128) { - let mask_86: u128 = (1u128 << 86) - 1; - let lo128 = val[0] as u128 | ((val[1] as u128) << 64); - let hi128 = val[2] as u128 | ((val[3] as u128) << 64); - - let l0 = lo128 & mask_86; - // l1 spans bits [86..172): 42 bits from lo128, 44 bits from hi128 - let l1 = ((lo128 >> 86) | (hi128 << 42)) & mask_86; - // l2 = bits [172..256): 84 bits from hi128 - let l2 = hi128 >> 44; - - (l0, l1, l2) -} - -/// Compute carry values c0..c3 from the 86-bit schoolbook column equations -/// for the identity a*b = p*q + r (base W = 2^86). -/// -/// Column equations: -/// col0: a0*b0 - p0*q0 - r0 = c0*W -/// col1: a0*b1 + a1*b0 - p0*q1 - p1*q0 - r1 + c0 = c1*W -/// col2: a0*b2 + a1*b1 + a2*b0 - p0*q2 - p1*q1 - p2*q0 - r2 + c1 = c2*W -/// col3: a1*b2 + a2*b1 - p1*q2 - p2*q1 + c2 = c3*W -/// col4: a2*b2 - p2*q2 + c3 = 0 -pub fn compute_carries_86( - a: [u128; 3], - b: [u128; 3], - p: [u128; 3], - q: [u128; 3], - r: [u128; 3], -) -> [i128; 4] { - // Helper: convert u128 to [u64; 4] - fn to4(v: u128) -> [u64; 4] { - [v as u64, (v >> 64) as u64, 0, 0] - } - - // Helper: multiply two 86-bit values → [u64; 4] (result < 2^172) - fn mul86(x: u128, y: u128) -> [u64; 4] { - let w = widening_mul(&to4(x), &to4(y)); - [w[0], w[1], w[2], w[3]] - } - - // Helper: add two [u64; 4] values - fn add4(a: [u64; 4], b: [u64; 4]) -> [u64; 4] { - let mut r = [0u64; 4]; - let mut carry = 0u128; - for i in 0..4 { - let s = a[i] as u128 + b[i] as u128 + carry; - r[i] = s as u64; - carry = s >> 64; - } - r - } - - // Helper: subtract two [u64; 4] values (assumes a >= b) - fn sub4(a: [u64; 4], b: [u64; 4]) -> [u64; 4] { - let mut r = [0u64; 4]; - let mut borrow = 0u64; - for i in 0..4 { - let (d1, b1) = a[i].overflowing_sub(b[i]); - let (d2, b2) = d1.overflowing_sub(borrow); - r[i] = d2; - borrow = b1 as u64 + b2 as u64; - } - r - } - - // Helper: right-shift [u64; 4] by 86 bits (= 64 + 22) - fn shr86(a: [u64; 4]) -> [u64; 4] { - let s = [a[1], a[2], 
a[3], 0u64]; - [ - (s[0] >> 22) | (s[1] << 42), - (s[1] >> 22) | (s[2] << 42), - s[2] >> 22, - 0, - ] - } - - // Positive column sums (a_i * b_j terms) - let pos = [ - mul86(a[0], b[0]), - add4(mul86(a[0], b[1]), mul86(a[1], b[0])), - add4( - add4(mul86(a[0], b[2]), mul86(a[1], b[1])), - mul86(a[2], b[0]), - ), - add4(mul86(a[1], b[2]), mul86(a[2], b[1])), - mul86(a[2], b[2]), - ]; - - // Negative column sums (p_i * q_j + r_i terms) - let neg = [ - add4(mul86(p[0], q[0]), to4(r[0])), - add4(add4(mul86(p[0], q[1]), mul86(p[1], q[0])), to4(r[1])), - add4( - add4( - add4(mul86(p[0], q[2]), mul86(p[1], q[1])), - mul86(p[2], q[0]), - ), - to4(r[2]), - ), - add4(mul86(p[1], q[2]), mul86(p[2], q[1])), - mul86(p[2], q[2]), - ]; - - let mut carries = [0i128; 4]; - let mut carry_pos = [0u64; 4]; - let mut carry_neg = [0u64; 4]; - - for col in 0..4 { - let total_pos = add4(pos[col], carry_pos); - let total_neg = add4(neg[col], carry_neg); - - let (is_neg, diff) = if cmp_4limb(&total_pos, &total_neg) != std::cmp::Ordering::Less { - (false, sub4(total_pos, total_neg)) - } else { - (true, sub4(total_neg, total_pos)) - }; - - // Lower 86 bits must be zero (divisibility check) - let mask_86 = (1u128 << 86) - 1; - let low86 = (diff[0] as u128 | ((diff[1] as u128) << 64)) & mask_86; - debug_assert_eq!(low86, 0, "column {} not divisible by W=2^86", col); - - let carry_mag = shr86(diff); - debug_assert_eq!(carry_mag[2], 0, "carry overflow in column {}", col); - debug_assert_eq!(carry_mag[3], 0, "carry overflow in column {}", col); - - let carry_val = carry_mag[0] as i128 | ((carry_mag[1] as i128) << 64); - carries[col] = if is_neg { -carry_val } else { carry_val }; - - if is_neg { - carry_pos = [0; 4]; - carry_neg = carry_mag; - } else { - carry_pos = carry_mag; - carry_neg = [0; 4]; - } - } - - // Verify column 4 balances - let final_pos = add4(pos[4], carry_pos); - let final_neg = add4(neg[4], carry_neg); - debug_assert_eq!( - final_pos, final_neg, - "column 4 should balance: a2*b2 
- p2*q2 + c3 = 0" - ); - - carries -} - #[cfg(test)] mod tests { use super::*; @@ -1033,130 +874,6 @@ mod tests { assert_eq!(sum, product); } - #[test] - fn test_decompose_128_roundtrip() { - let val = [ - 0x123456789abcdef0, - 0xfedcba9876543210, - 0x1111111111111111, - 0x2222222222222222, - ]; - let (lo, hi) = decompose_128(&val); - // Roundtrip - assert_eq!(lo as u64, val[0]); - assert_eq!((lo >> 64) as u64, val[1]); - assert_eq!(hi as u64, val[2]); - assert_eq!((hi >> 64) as u64, val[3]); - } - - #[test] - fn test_decompose_86_roundtrip() { - let val = [ - 0x123456789abcdef0, - 0xfedcba9876543210, - 0x1111111111111111, - 0x2222222222222222, - ]; - let (l0, l1, l2) = decompose_86(&val); - - // Each limb should be < 2^86 - assert!(l0 < (1u128 << 86)); - assert!(l1 < (1u128 << 86)); - // l2 has at most 84 bits (256 - 172) - assert!(l2 < (1u128 << 84)); - - // Roundtrip: l0 + l1 * 2^86 + l2 * 2^172 should equal val - // Build from limbs back to [u64; 4] - let mut reconstructed = [0u128; 2]; // lo128, hi128 - reconstructed[0] = l0; - // l1 starts at bit 86 - reconstructed[0] |= (l1 & ((1u128 << 42) - 1)) << 86; // lower 42 bits of l1 into lo128 - reconstructed[1] = l1 >> 42; // upper 44 bits of l1 - // l2 starts at bit 172 = 128 + 44 - reconstructed[1] |= l2 << 44; - - assert_eq!(reconstructed[0] as u64, val[0]); - assert_eq!((reconstructed[0] >> 64) as u64, val[1]); - assert_eq!(reconstructed[1] as u64, val[2]); - assert_eq!((reconstructed[1] >> 64) as u64, val[3]); - } - - #[test] - fn test_decompose_86_secp256r1_p() { - // secp256r1 field modulus - let p = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; - let (l0, l1, l2) = decompose_86(&p); - assert!(l0 < (1u128 << 86)); - assert!(l1 < (1u128 << 86)); - assert!(l2 < (1u128 << 84)); - } - - #[test] - fn test_compute_carries_86_simple() { - // Test with small values: a=3, b=5, p=7 - // a*b = 15, 15 / 7 = 2 remainder 1 - // So q=2, r=1 - let a_val = [3u64, 0, 0, 0]; - let b_val = [5, 0, 0, 0]; - let 
p_val = [7, 0, 0, 0]; - let product = widening_mul(&a_val, &b_val); - let (q_val, r_val) = divmod_wide(&product, &p_val); - assert_eq!(q_val, [2, 0, 0, 0]); - assert_eq!(r_val, [1, 0, 0, 0]); - - let (a0, a1, a2) = decompose_86(&a_val); - let (b0, b1, b2) = decompose_86(&b_val); - let (p0, p1, p2) = decompose_86(&p_val); - let (q0, q1, q2) = decompose_86(&q_val); - let (r0, r1, r2) = decompose_86(&r_val); - - let carries = compute_carries_86([a0, a1, a2], [b0, b1, b2], [p0, p1, p2], [q0, q1, q2], [ - r0, r1, r2, - ]); - // For small values, all carries should be 0 - assert_eq!(carries, [0, 0, 0, 0]); - } - - #[test] - fn test_compute_carries_86_secp256r1() { - // Test with secp256r1-sized values - let p = [0xffffffffffffffff, 0xffffffff, 0x0, 0xffffffff00000001]; - let a_val = [0x123456789abcdef0, 0xfedcba9876543210, 0x0, 0x0]; // < p - let b_val = [0xaabbccddeeff0011, 0x1122334455667788, 0x0, 0x0]; // < p - - let product = widening_mul(&a_val, &b_val); - let (q_val, r_val) = divmod_wide(&product, &p); - - // Verify a*b = p*q + r - let pq = widening_mul(&p, &q_val); - let mut sum = pq; - let mut carry = 0u128; - for i in 0..4 { - let s = sum[i] as u128 + r_val[i] as u128 + carry; - sum[i] = s as u64; - carry = s >> 64; - } - for i in 4..8 { - let s = sum[i] as u128 + carry; - sum[i] = s as u64; - carry = s >> 64; - } - assert_eq!(sum, product); - - // Compute 86-bit decompositions - let (a0, a1, a2) = decompose_86(&a_val); - let (b0, b1, b2) = decompose_86(&b_val); - let (p0, p1, p2) = decompose_86(&p); - let (q0, q1, q2) = decompose_86(&q_val); - let (r0, r1, r2) = decompose_86(&r_val); - - // This should not panic - let _carries = - compute_carries_86([a0, a1, a2], [b0, b1, b2], [p0, p1, p2], [q0, q1, q2], [ - r0, r1, r2, - ]); - } - #[test] fn test_half_gcd_small() { // s = 42, n = 101 @@ -1232,6 +949,10 @@ mod tests { } else { divmod(&sum4, &n) }; - assert_eq!(remainder, [0, 0, 0, 0], "half_gcd relation failed for Grumpkin order"); + assert_eq!( + remainder, + 
[0, 0, 0, 0], + "half_gcd relation failed for Grumpkin order" + ); } } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index b78fea288..0fa9133ad 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -21,6 +21,7 @@ use { whir::transcript::{codecs::Empty, ProverState, VerifierMessage}, }; +pub(crate) mod bigint_mod; pub mod input_utils; mod r1cs; mod whir_r1cs; diff --git a/provekit/prover/src/witness/mod.rs b/provekit/prover/src/witness/mod.rs index fb5072440..5f5de8f0b 100644 --- a/provekit/prover/src/witness/mod.rs +++ b/provekit/prover/src/witness/mod.rs @@ -1,4 +1,3 @@ -pub(crate) mod bigint_mod; mod digits; mod ram; pub(crate) mod witness_builder; diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index ec3d61b06..cee33a18d 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -78,8 +78,8 @@ impl WitnessBuilderSolver for WitnessBuilder { let a_limbs = a.into_bigint().0; let m_limbs = modulus.into_bigint().0; // Fermat's little theorem: a^{-1} = a^{m-2} mod m - let exp = crate::witness::bigint_mod::sub_u64(&m_limbs, 2); - let result_limbs = crate::witness::bigint_mod::mod_pow(&a_limbs, &exp, &m_limbs); + let exp = crate::bigint_mod::sub_u64(&m_limbs, 2); + let result_limbs = crate::bigint_mod::mod_pow(&a_limbs, &exp, &m_limbs); witness[*witness_idx] = Some(FieldElement::from_bigint(ark_ff::BigInt(result_limbs)).unwrap()); } @@ -87,7 +87,7 @@ impl WitnessBuilderSolver for WitnessBuilder { let dividend = witness[*dividend_idx].unwrap(); let d_limbs = dividend.into_bigint().0; let m_limbs = divisor.into_bigint().0; - let (quotient, _remainder) = crate::witness::bigint_mod::divmod(&d_limbs, &m_limbs); + let (quotient, _remainder) = crate::bigint_mod::divmod(&d_limbs, &m_limbs); witness[*witness_idx] = Some(FieldElement::from_bigint(ark_ff::BigInt(quotient)).unwrap()); } @@ -353,7 +353,7 @@ impl 
WitnessBuilderSolver for WitnessBuilder { limb_bits, num_limbs, } => { - use crate::witness::bigint_mod::{divmod_wide, widening_mul}; + use crate::bigint_mod::{divmod_wide, widening_mul}; let n = *num_limbs as usize; let w = *limb_bits; let limb_mask: u128 = if w >= 128 { @@ -503,7 +503,7 @@ impl WitnessBuilderSolver for WitnessBuilder { limb_bits, num_limbs, } => { - use crate::witness::bigint_mod::{mod_pow, sub_u64}; + use crate::bigint_mod::{mod_pow, sub_u64}; let n = *num_limbs as usize; let w = *limb_bits; let limb_mask: u128 = if w >= 128 { @@ -570,7 +570,7 @@ impl WitnessBuilderSolver for WitnessBuilder { limb_bits, .. } => { - use crate::witness::bigint_mod::{add_4limb, cmp_4limb}; + use crate::bigint_mod::{add_4limb, cmp_4limb}; let w = *limb_bits; // Reconstruct from N limbs @@ -617,7 +617,7 @@ impl WitnessBuilderSolver for WitnessBuilder { limb_bits, .. } => { - use crate::witness::bigint_mod::cmp_4limb; + use crate::bigint_mod::cmp_4limb; let w = *limb_bits; let reconstruct = |limbs: &[usize]| -> [u64; 4] { @@ -670,24 +670,16 @@ impl WitnessBuilderSolver for WitnessBuilder { // Reconstruct s = s_lo + s_hi * 2^128 let s_lo_val = witness[*s_lo].unwrap().into_bigint().0; let s_hi_val = witness[*s_hi].unwrap().into_bigint().0; - let s_val: [u64; 4] = [ - s_lo_val[0], - s_lo_val[1], - s_hi_val[0], - s_hi_val[1], - ]; + let s_val: [u64; 4] = [s_lo_val[0], s_lo_val[1], s_hi_val[0], s_hi_val[1]]; - let (val1, val2, neg1, neg2) = - crate::witness::bigint_mod::half_gcd(&s_val, curve_order); + let (val1, val2, neg1, neg2) = crate::bigint_mod::half_gcd(&s_val, curve_order); witness[*output_start] = Some(FieldElement::from_bigint(ark_ff::BigInt(val1)).unwrap()); witness[*output_start + 1] = Some(FieldElement::from_bigint(ark_ff::BigInt(val2)).unwrap()); - witness[*output_start + 2] = - Some(FieldElement::from(neg1 as u64)); - witness[*output_start + 3] = - Some(FieldElement::from(neg2 as u64)); + witness[*output_start + 2] = Some(FieldElement::from(neg1 as u64)); + 
witness[*output_start + 3] = Some(FieldElement::from(neg2 as u64)); } WitnessBuilder::EcScalarMulHint { output_start, @@ -708,7 +700,7 @@ impl WitnessBuilderSolver for WitnessBuilder { let py_val = witness[*py].unwrap().into_bigint().0; // Compute R = [s]P - let (rx, ry) = crate::witness::bigint_mod::ec_scalar_mul( + let (rx, ry) = crate::bigint_mod::ec_scalar_mul( &px_val, &py_val, &scalar, diff --git a/provekit/r1cs-compiler/src/digits.rs b/provekit/r1cs-compiler/src/digits.rs index 3d8917d55..657f7bd78 100644 --- a/provekit/r1cs-compiler/src/digits.rs +++ b/provekit/r1cs-compiler/src/digits.rs @@ -67,8 +67,8 @@ pub(crate) fn add_digital_decomposition( // Add the constraints for the digital recomposition let mut digit_multipliers = vec![FieldElement::one()]; for log_base in log_bases[..log_bases.len() - 1].iter() { - let multiplier = *digit_multipliers.last().unwrap() - * FieldElement::from(2u64).pow([*log_base as u64]); + let multiplier = + *digit_multipliers.last().unwrap() * FieldElement::from(2u64).pow([*log_base as u64]); digit_multipliers.push(multiplier); } dd_struct diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index b896dc043..8a42735b3 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -15,9 +15,10 @@ pub enum FieldOpType { /// Count field ops in scalar_mul_glv for given parameters. /// -/// The GLV approach does interleaved two-point scalar mul with half-width scalars. -/// Per window: w shared doubles + 2 table lookups + 2 point_adds + 2 is_zero + 2 point_selects -/// Plus: 2 table builds, on-curve check, scalar relation overhead. +/// The GLV approach does interleaved two-point scalar mul with half-width +/// scalars. Per window: w shared doubles + 2 table lookups + 2 point_adds + 2 +/// is_zero + 2 point_selects Plus: 2 table builds, on-curve check, scalar +/// relation overhead. 
fn count_glv_field_ops( scalar_bits: usize, // half_bits = ceil(order_bits / 2) window_size: usize, @@ -72,7 +73,8 @@ fn count_glv_field_ops( } } - // On-curve checks for P and R: each needs 1 mul (y^2), 2 mul (x^2, x^3), 1 mul (a*x), 2 add + // On-curve checks for P and R: each needs 1 mul (y^2), 2 mul (x^2, x^3), 1 mul + // (a*x), 2 add total_mul += 8; total_add += 4; diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index 07c53891a..455e465de 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -58,11 +58,6 @@ impl CurveParams { curve_native_point_fe(&self.field_modulus_p) } - /// Returns the curve parameter b as a native field element. - pub fn curve_b_fe(&self) -> FieldElement { - curve_native_point_fe(&self.curve_b) - } - /// Decompose the curve order n into `num_limbs` limbs of `limb_bits` width /// each. pub fn curve_order_n_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { @@ -99,6 +94,11 @@ impl CurveParams { (self.curve_order_bits() + 1) / 2 } + /// Decompose the generator x-coordinate into limbs. + pub fn generator_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.generator.0, limb_bits, num_limbs) + } + /// Decompose the offset point x-coordinate into limbs. pub fn offset_x_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { decompose_to_limbs(&self.offset_point.0, limb_bits, num_limbs) @@ -446,28 +446,6 @@ mod u256_arith { mod_pow(a, &exp, p) } - /// EC point addition on y^2 = x^3 + ax + b. - /// Computes (x1,y1) + (x2,y2). Requires x1 != x2. 
- pub fn ec_point_add(x1: &U256, y1: &U256, x2: &U256, y2: &U256, p: &U256) -> (U256, U256) { - // lambda = (y2 - y1) / (x2 - x1) - let num = mod_sub(y2, y1, p); - let denom = mod_sub(x2, x1, p); - let denom_inv = mod_inv(&denom, p); - let lambda = mod_mul(&num, &denom_inv, p); - - // x3 = lambda^2 - x1 - x2 - let lambda_sq = mod_mul(&lambda, &lambda, p); - let x1_plus_x2 = mod_add(x1, x2, p); - let x3 = mod_sub(&lambda_sq, &x1_plus_x2, p); - - // y3 = lambda * (x1 - x3) - y1 - let x1_minus_x3 = mod_sub(x1, &x3, p); - let lambda_dx = mod_mul(&lambda, &x1_minus_x3, p); - let y3 = mod_sub(&lambda_dx, y1, p); - - (x3, y3) - } - /// EC point doubling on y^2 = x^3 + ax + b. pub fn ec_point_double(x: &U256, y: &U256, a: &U256, p: &U256) -> (U256, U256) { // lambda = (3*x^2 + a) / (2*y) @@ -495,7 +473,7 @@ mod u256_arith { #[cfg(test)] mod tests { - use {super::*, ark_ff::Field}; + use super::*; #[test] fn test_offset_point_on_curve_grumpkin() { @@ -592,6 +570,7 @@ mod tests { } } +#[allow(dead_code)] pub fn secp256r1_params() -> CurveParams { CurveParams { field_modulus_p: [ diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 1b591ed8a..1bf2aa264 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -159,7 +159,12 @@ fn table_lookup( let half = current.len() / 2; let mut next = Vec::with_capacity(half); for i in 0..half { - next.push(point_select_unchecked(ops, bit, current[i], current[i + half])); + next.push(point_select_unchecked( + ops, + bit, + current[i], + current[i + half], + )); } current = next; } @@ -203,6 +208,8 @@ pub fn scalar_mul_glv( let w = window_size; let table_size = 1 << w; + // TODO : implement lazy overflow as used in gnark. 
+ // Build point tables: T_P[i] = [i]P, T_R[i] = [i]R let table_p = build_point_table(ops, px, py, table_size); let table_r = build_point_table(ops, rx, ry, table_size); @@ -252,13 +259,7 @@ pub fn scalar_mul_glv( &table_r[..] }; let looked_up_r = table_lookup(ops, lookup_table_r, s2_window_bits); - let added_r = point_add( - ops, - after_p.0, - after_p.1, - looked_up_r.0, - looked_up_r.1, - ); + let added_r = point_add(ops, after_p.0, after_p.1, looked_up_r.0, looked_up_r.1); let digit_r = ops.pack_bits(s2_window_bits); let digit_r_is_zero = ops.is_zero(digit_r); // is_zero already constrains its output boolean; skip redundant check @@ -267,4 +268,3 @@ pub fn scalar_mul_glv( acc } - diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 6bc96b8f9..92606bf32 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -7,9 +7,10 @@ pub mod multi_limb_ops; use { crate::{ digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, + msm::multi_limb_arith::compute_is_zero, noir_to_r1cs::NoirToR1CSCompiler, }, - ark_ff::{AdditiveGroup, Field}, + ark_ff::{AdditiveGroup, Field, PrimeField}, curve::{decompose_to_limbs as decompose_to_limbs_pub, CurveParams}, multi_limb_ops::{MultiLimbOps, MultiLimbParams}, provekit_common::{ @@ -66,18 +67,6 @@ impl Limbs { l } - /// Create `Limbs` from a slice of witness indices. - pub fn from_slice(s: &[usize]) -> Self { - assert!( - !s.is_empty() && s.len() <= MAX_LIMBS, - "slice length must be 1..={MAX_LIMBS}, got {}", - s.len() - ); - let mut data = [Self::UNINIT; MAX_LIMBS]; - data[..s.len()].copy_from_slice(s); - Self { data, len: s.len() } - } - /// View the active limbs as a slice. pub fn as_slice(&self) -> &[usize] { &self.data[..self.len] @@ -158,7 +147,7 @@ pub trait FieldOps { self.select_unchecked(flag, on_false, on_true) } - /// Checks if a BN254 native witness value is zero. + /// Checks if a native witness value is zero. 
/// Returns a boolean witness: 1 if zero, 0 if non-zero. fn is_zero(&mut self, value: usize) -> usize; @@ -215,6 +204,51 @@ pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize] compiler.add_sum(terms) } +/// Computes `a OR b` for two boolean witnesses: `1 - (1 - a)(1 - b)`. +/// Does NOT constrain a or b to be boolean — caller must ensure that. +fn compute_boolean_or(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) -> usize { + let one = compiler.witness_one(); + let one_minus_a = compiler.add_sum(vec![ + SumTerm(None, one), + SumTerm(Some(-FieldElement::ONE), a), + ]); + let one_minus_b = compiler.add_sum(vec![ + SumTerm(None, one), + SumTerm(Some(-FieldElement::ONE), b), + ]); + let product = compiler.add_product(one_minus_a, one_minus_b); + compiler.add_sum(vec![ + SumTerm(None, one), + SumTerm(Some(-FieldElement::ONE), product), + ]) +} + +/// Detects whether a point-scalar pair is degenerate (scalar=0 or point at +/// infinity). Constrains `inf_flag` to boolean. Returns `is_skip` (1 if +/// degenerate). +fn detect_skip( + compiler: &mut NoirToR1CSCompiler, + s_lo: usize, + s_hi: usize, + inf_flag: usize, +) -> usize { + constrain_boolean(compiler, inf_flag); + let is_zero_s_lo = compute_is_zero(compiler, s_lo); + let is_zero_s_hi = compute_is_zero(compiler, s_hi); + let s_is_zero = compiler.add_product(is_zero_s_lo, is_zero_s_hi); + compute_boolean_or(compiler, s_is_zero, inf_flag) +} + +/// Constrains `a * b = 0`. 
+fn constrain_product_zero(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( + FieldElement::ZERO, + compiler.witness_one(), + )]); +} + // --------------------------------------------------------------------------- // Params builder (runtime num_limbs, no const generics) // --------------------------------------------------------------------------- @@ -236,7 +270,6 @@ fn build_params(num_limbs: usize, limb_bits: u32, curve: &CurveParams) -> MultiL two_pow_w, modulus_raw: curve.field_modulus_p, curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), - modulus_bits: curve.modulus_bits(), is_native, modulus_fe, } @@ -248,10 +281,15 @@ fn build_params(num_limbs: usize, limb_bits: u32, curve: &CurveParams) -> MultiL /// Processes all deferred MSM operations. /// +/// Internally selects the optimal (limb_bits, window_size) via cost model +/// and uses Grumpkin curve parameters. +/// /// Each entry is `(points, scalars, (out_x, out_y, out_inf))` where: /// - `points` has layout `[x1, y1, inf1, x2, y2, inf2, ...]` (3 per point) /// - `scalars` has layout `[s1_lo, s1_hi, s2_lo, s2_hi, ...]` (2 per scalar) /// - outputs are the R1CS witness indices for the result point +/// Grumpkin-specific MSM entry point (used by the Noir `MultiScalarMul` black +/// box). pub fn add_msm( compiler: &mut NoirToR1CSCompiler, msm_ops: Vec<( @@ -259,11 +297,34 @@ pub fn add_msm( Vec, (usize, usize, usize), )>, - limb_bits: u32, - window_size: usize, + range_checks: &mut BTreeMap>, +) { + let curve = curve::grumpkin_params(); + add_msm_with_curve(compiler, msm_ops, range_checks, &curve); +} + +/// Curve-agnostic MSM: compiles MSM operations for any curve described by +/// `curve`. 
+pub fn add_msm_with_curve( + compiler: &mut NoirToR1CSCompiler, + msm_ops: Vec<( + Vec, + Vec, + (usize, usize, usize), + )>, range_checks: &mut BTreeMap>, curve: &CurveParams, ) { + if msm_ops.is_empty() { + return; + } + + let native_bits = FieldElement::MODULUS_BIT_SIZE; + let curve_bits = curve.modulus_bits(); + let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / 3).sum(); + let (limb_bits, window_size) = + cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256); + for (points, scalars, outputs) in msm_ops { add_single_msm( compiler, @@ -350,25 +411,52 @@ fn process_single_msm<'a>( // Single-point: R is the ACIR output directly let px_witness = point_wits[0]; let py_witness = point_wits[1]; - // Constrain input infinity flag to 0 (affine coordinates cannot represent infinity) - constrain_zero(compiler, point_wits[2]); + let inf_flag = point_wits[2]; let s_lo = scalar_wits[0]; let s_hi = scalar_wits[1]; - // Decompose P into limbs + // --- Detect degenerate case: is_skip = (scalar == 0) OR (point is infinity) + let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); + + // --- Sanitize inputs: swap in generator G and scalar=1 when is_skip --- + let one = compiler.witness_one(); + let gen_x_witness = + add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.0)); + let gen_y_witness = + add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.1)); + + let sanitized_px = select_witness(compiler, is_skip, px_witness, gen_x_witness); + let sanitized_py = select_witness(compiler, is_skip, py_witness, gen_y_witness); + + // When is_skip=1, use scalar=(1, 0) so FakeGLV computes [1]*G = G + let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); + let sanitized_s_lo = select_witness(compiler, is_skip, s_lo, one); + let sanitized_s_hi = select_witness(compiler, is_skip, s_hi, zero_witness); + + // Sanitize R (output point): when is_skip=1, R must be G (since [1]*G = G) + let 
sanitized_rx = select_witness(compiler, is_skip, out_x, gen_x_witness); + let sanitized_ry = select_witness(compiler, is_skip, out_y, gen_y_witness); + + // Decompose sanitized P into limbs let (px, py) = decompose_point_to_limbs( compiler, - px_witness, - py_witness, + sanitized_px, + sanitized_py, num_limbs, limb_bits, range_checks, ); - // R = ACIR output, decompose into limbs + // Decompose sanitized R into limbs let (rx, ry) = decompose_point_to_limbs( - compiler, out_x, out_y, num_limbs, limb_bits, range_checks, + compiler, + sanitized_rx, + sanitized_ry, + num_limbs, + limb_bits, + range_checks, ); + // Run FakeGLV on sanitized values (always satisfiable) (compiler, range_checks) = verify_point_fakeglv( compiler, range_checks, @@ -376,62 +464,111 @@ fn process_single_msm<'a>( py, rx, ry, - s_lo, - s_hi, + sanitized_s_lo, + sanitized_s_hi, num_limbs, limb_bits, window_size, curve, ); - constrain_zero(compiler, out_inf); + // --- Mask output: when is_skip, output must be (0, 0, 1) --- + constrain_equal(compiler, out_inf, is_skip); + constrain_product_zero(compiler, is_skip, out_x); + constrain_product_zero(compiler, is_skip, out_y); } else { // Multi-point: compute R_i = [s_i]P_i via hints, verify each with FakeGLV, - // then accumulate R_i's and constrain against ACIR output. - let mut acc: Option<(Limbs, Limbs)> = None; + // then accumulate R_i's with offset-based accumulation and skip handling. 
+ let one = compiler.witness_one(); + + // Generator constants for sanitization + let gen_x_fe = curve::curve_native_point_fe(&curve.generator.0); + let gen_y_fe = curve::curve_native_point_fe(&curve.generator.1); + let gen_x_witness = add_constant_witness(compiler, gen_x_fe); + let gen_y_witness = add_constant_witness(compiler, gen_y_fe); + let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); + + // Build params once for all multi-limb ops in the multi-point path + let params = build_params(num_limbs, limb_bits, curve); + + // Offset point as limbs for accumulation + let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); + let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); + + // Start accumulator at offset_point + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + let mut acc_x = ops.constant_limbs(&offset_x_values); + let mut acc_y = ops.constant_limbs(&offset_y_values); + compiler = ops.compiler; + range_checks = ops.range_checks; + + // Track all_skipped = product of all is_skip flags + let mut all_skipped: Option = None; for i in 0..n_points { let px_witness = point_wits[3 * i]; let py_witness = point_wits[3 * i + 1]; - // Constrain input infinity flag to 0 (affine coordinates cannot represent infinity) - constrain_zero(compiler, point_wits[3 * i + 2]); + let inf_flag = point_wits[3 * i + 2]; let s_lo = scalar_wits[2 * i]; let s_hi = scalar_wits[2 * i + 1]; - // Add EcScalarMulHint → R_i = [s_i]P_i + // --- Detect degenerate case --- + let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); + + // Track all_skipped + all_skipped = Some(match all_skipped { + None => is_skip, + Some(prev) => compiler.add_product(prev, is_skip), + }); + + // --- Sanitize inputs --- + let sanitized_px = select_witness(compiler, is_skip, px_witness, gen_x_witness); + let sanitized_py = select_witness(compiler, is_skip, py_witness, gen_y_witness); + let sanitized_s_lo = select_witness(compiler, is_skip, 
s_lo, one); + let sanitized_s_hi = select_witness(compiler, is_skip, s_hi, zero_witness); + + // EcScalarMulHint uses sanitized inputs let hint_start = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { output_start: hint_start, - px: px_witness, - py: py_witness, - s_lo, - s_hi, + px: sanitized_px, + py: sanitized_py, + s_lo: sanitized_s_lo, + s_hi: sanitized_s_hi, curve_a: curve.curve_a, field_modulus_p: curve.field_modulus_p, }); let rx_witness = hint_start; let ry_witness = hint_start + 1; - // Decompose P_i into limbs + // When is_skip=1, R should be G (since [1]*G = G) + let sanitized_rx = select_witness(compiler, is_skip, rx_witness, gen_x_witness); + let sanitized_ry = select_witness(compiler, is_skip, ry_witness, gen_y_witness); + + // Decompose sanitized P_i into limbs let (px, py) = decompose_point_to_limbs( compiler, - px_witness, - py_witness, + sanitized_px, + sanitized_py, num_limbs, limb_bits, range_checks, ); - // Decompose R_i into limbs + // Decompose sanitized R_i into limbs let (rx, ry) = decompose_point_to_limbs( compiler, - rx_witness, - ry_witness, + sanitized_rx, + sanitized_ry, num_limbs, limb_bits, range_checks, ); - // Verify R_i = [s_i]P_i using FakeGLV + // Verify R_i = [s_i]P_i using FakeGLV (on sanitized values) (compiler, range_checks) = verify_point_fakeglv( compiler, range_checks, @@ -439,44 +576,87 @@ fn process_single_msm<'a>( py, rx, ry, - s_lo, - s_hi, + sanitized_s_lo, + sanitized_s_hi, num_limbs, limb_bits, window_size, curve, ); - // Accumulate R_i via point_add - acc = Some(match acc { - None => (rx, ry), - Some((ax, ay)) => { - let params = build_params(num_limbs, limb_bits, curve); - let mut ops = MultiLimbOps { - compiler, - range_checks, - params, - }; - let sum = ec_points::point_add(&mut ops, ax, ay, rx, ry); - compiler = ops.compiler; - range_checks = ops.range_checks; - sum - } - }); + // --- Offset-based accumulation with conditional select --- + // Compute candidate = 
point_add(acc, R_i) + // Then select: if is_skip, keep acc unchanged; else use candidate + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + let (cand_x, cand_y) = ec_points::point_add(&mut ops, acc_x, acc_y, rx, ry); + let (new_acc_x, new_acc_y) = + ec_points::point_select(&mut ops, is_skip, (cand_x, cand_y), (acc_x, acc_y)); + acc_x = new_acc_x; + acc_y = new_acc_y; + compiler = ops.compiler; + range_checks = ops.range_checks; } - let (computed_x, computed_y) = acc.expect("MSM must have at least one point"); + let all_skipped = all_skipped.expect("MSM must have at least one point"); + + // Subtract offset: result = point_add(acc, -offset) + // Negated offset = (offset_x, -offset_y) + let neg_offset_y_raw = + curve::negate_field_element(&curve.offset_point.1, &curve.field_modulus_p); + let neg_offset_y_values = + curve::decompose_to_limbs(&neg_offset_y_raw, limb_bits, num_limbs); + // When all_skipped, acc == offset_point, so subtracting offset would be + // point_add(O, -O) which fails (x1 == x2). Use generator G as the + // subtraction target instead; the result won't matter since we'll mask it. 
+ let gen_x_limb_values = curve.generator_x_limbs(limb_bits, num_limbs); + let neg_gen_y_raw = curve::negate_field_element(&curve.generator.1, &curve.field_modulus_p); + let neg_gen_y_values = curve::decompose_to_limbs(&neg_gen_y_raw, limb_bits, num_limbs); + + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + + // Select subtraction point: if all_skipped, use -G; else use -offset + let sub_x = { + let off_x = ops.constant_limbs(&offset_x_values); + let g_x = ops.constant_limbs(&gen_x_limb_values); + ops.select(all_skipped, off_x, g_x) + }; + let sub_y = { + let neg_off_y = ops.constant_limbs(&neg_offset_y_values); + let neg_g_y = ops.constant_limbs(&neg_gen_y_values); + ops.select(all_skipped, neg_off_y, neg_g_y) + }; + + let (result_x, result_y) = ec_points::point_add(&mut ops, acc_x, acc_y, sub_x, sub_y); + compiler = ops.compiler; + range_checks = ops.range_checks; + + // --- Constrain output --- + // When all_skipped: output is (0, 0, 1) + // Otherwise: output matches the computed result with inf=0 if num_limbs == 1 { - constrain_equal(compiler, out_x, computed_x[0]); - constrain_equal(compiler, out_y, computed_y[0]); + // Mask result with all_skipped: when all_skipped=1, out must be 0 + let masked_result_x = select_witness(compiler, all_skipped, result_x[0], zero_witness); + let masked_result_y = select_witness(compiler, all_skipped, result_y[0], zero_witness); + constrain_equal(compiler, out_x, masked_result_x); + constrain_equal(compiler, out_y, masked_result_y); } else { - let recomposed_x = recompose_limbs(compiler, computed_x.as_slice(), limb_bits); - let recomposed_y = recompose_limbs(compiler, computed_y.as_slice(), limb_bits); - constrain_equal(compiler, out_x, recomposed_x); - constrain_equal(compiler, out_y, recomposed_y); + let recomposed_x = recompose_limbs(compiler, result_x.as_slice(), limb_bits); + let recomposed_y = recompose_limbs(compiler, result_y.as_slice(), limb_bits); + let masked_result_x = 
select_witness(compiler, all_skipped, recomposed_x, zero_witness); + let masked_result_y = select_witness(compiler, all_skipped, recomposed_y, zero_witness); + constrain_equal(compiler, out_x, masked_result_x); + constrain_equal(compiler, out_y, masked_result_y); } - constrain_zero(compiler, out_inf); + constrain_equal(compiler, out_inf, all_skipped); } } @@ -524,62 +704,53 @@ fn verify_point_fakeglv<'a>( &'a mut NoirToR1CSCompiler, &'a mut BTreeMap>, ) { - // --- Step 1: On-curve checks for P and R --- + // --- Steps 1-4: On-curve checks, FakeGLV decomposition, and GLV scalar mul + // --- + let s1_witness; + let s2_witness; + let neg1_witness; + let neg2_witness; { let params = build_params(num_limbs, limb_bits, curve); let mut ops = MultiLimbOps { compiler, range_checks, - params, + params: ¶ms, }; + // Step 1: On-curve checks for P and R let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs); - verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); - compiler = ops.compiler; - range_checks = ops.range_checks; - } - - // --- Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 --- - let glv_start = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::FakeGLVHint { - output_start: glv_start, - s_lo, - s_hi, - curve_order: curve.curve_order_n, - }); - let s1_witness = glv_start; - let s2_witness = glv_start + 1; - let neg1_witness = glv_start + 2; - let neg2_witness = glv_start + 3; - - // neg1 and neg2 are constrained to be boolean by the `select` calls - // in Step 4 below (MultiLimbOps::select calls constrain_boolean internally). 
+ // Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 + let glv_start = ops.compiler.num_witnesses(); + ops.compiler + .add_witness_builder(WitnessBuilder::FakeGLVHint { + output_start: glv_start, + s_lo, + s_hi, + curve_order: curve.curve_order_n, + }); + s1_witness = glv_start; + s2_witness = glv_start + 1; + neg1_witness = glv_start + 2; + neg2_witness = glv_start + 3; - // --- Step 3: Decompose |s1|, |s2| into half_bits bits each --- - let half_bits = curve.glv_half_bits() as usize; - let s1_bits = decompose_half_scalar_bits(compiler, s1_witness, half_bits); - let s2_bits = decompose_half_scalar_bits(compiler, s2_witness, half_bits); + // Step 3: Decompose |s1|, |s2| into half_bits bits each + let half_bits = curve.glv_half_bits() as usize; + let s1_bits = decompose_half_scalar_bits(ops.compiler, s1_witness, half_bits); + let s2_bits = decompose_half_scalar_bits(ops.compiler, s2_witness, half_bits); - // --- Step 4: Conditionally negate P.y and R.y + GLV scalar mul + identity check --- - { - let params = build_params(num_limbs, limb_bits, curve); - let mut ops = MultiLimbOps { - compiler, - range_checks, - params, - }; + // Step 4: Conditionally negate P.y and R.y + GLV scalar mul + identity + // check // Compute negated y-coordinates: neg_y = 0 - y (mod p) - let zero_limbs = vec![FieldElement::from(0u64); num_limbs]; - let zero = ops.constant_limbs(&zero_limbs); - - let neg_py = ops.sub(zero, py); - let neg_ry = ops.sub(zero, ry); + let neg_py = ops.negate(py); + let neg_ry = ops.negate(ry); // Select: if neg1=1, use neg_py; else use py + // neg1 and neg2 are constrained to be boolean by ops.select internally. 
let py_effective = ops.select(neg1_witness, py, neg_py); // Select: if neg2=1, use neg_ry; else use ry let ry_effective = ops.select(neg2_witness, ry, neg_ry); @@ -603,7 +774,8 @@ fn verify_point_fakeglv<'a>( offset_y, ); - // Identity check: acc should equal [2^(num_windows * window_size)] * offset_point + // Identity check: acc should equal [2^(num_windows * window_size)] * + // offset_point let glv_num_windows = (half_bits + window_size - 1) / window_size; let glv_n_doublings = glv_num_windows * window_size; let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); @@ -751,12 +923,41 @@ fn build_scalar_relation_params( two_pow_w, modulus_raw: curve.curve_order_n, curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused - modulus_bits: curve.curve_order_bits(), - is_native: false, // always non-native + is_native: false, // always non-native modulus_fe, } } +/// Picks the largest limb size for the scalar-relation multi-limb arithmetic +/// that fits inside the native field without overflow. +/// +/// The schoolbook multiplication column equations require: +/// `2 * limb_bits + ceil(log2(num_limbs)) + 3 < native_field_bits` +/// +/// We start at 64 bits (the ideal case — inputs are 128-bit half-scalars) and +/// search downward until the soundness check passes. For BN254 (254-bit native +/// field) this resolves to 64; smaller fields like M31 (31 bits) will get a +/// proportionally smaller limb size. +/// +/// Panics if the native field is too small (< ~12 bits) to support any valid +/// limb decomposition. 
+fn scalar_relation_limb_bits(order_bits: usize) -> u32 { + let native_bits = FieldElement::MODULUS_BIT_SIZE; + let mut limb_bits: u32 = 64.min((native_bits.saturating_sub(4)) / 2); + loop { + let num_limbs = (order_bits + limb_bits as usize - 1) / limb_bits as usize; + if cost_model::column_equation_fits_native_field(native_bits, limb_bits, num_limbs) { + break; + } + limb_bits -= 1; + assert!( + limb_bits >= 4, + "native field too small for scalar relation verification" + ); + } + limb_bits +} + /// Verifies the scalar relation: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 /// (mod n). /// @@ -774,60 +975,70 @@ fn verify_scalar_relation( neg2_witness: usize, curve: &CurveParams, ) { - // Use 64-bit limbs. Number of limbs covers the full curve order. - let sr_limb_bits: u32 = 64; let order_bits = curve.curve_order_bits() as usize; + let sr_limb_bits = scalar_relation_limb_bits(order_bits); let sr_num_limbs = (order_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; let half_bits = curve.glv_half_bits() as usize; - // Number of 64-bit limbs the half-scalar occupies + // Number of limbs the half-scalar occupies let half_limbs = (half_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; let params = build_scalar_relation_params(sr_num_limbs, sr_limb_bits, curve); let mut ops = MultiLimbOps { compiler, range_checks, - params, + params: ¶ms, }; - // Decompose s into sr_num_limbs × 64-bit limbs from (s_lo, s_hi) - // s_lo contains bits [0..128), s_hi contains bits [128..256) + // Decompose s into sr_num_limbs limbs from (s_lo, s_hi). + // s_lo contains bits [0..128), s_hi contains bits [128..256). 
let s_limbs = { - let dd_lo = add_digital_decomposition(ops.compiler, vec![64, 64], vec![s_lo]); - let dd_hi = add_digital_decomposition(ops.compiler, vec![64, 64], vec![s_hi]); + let limbs_per_half = (128 + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + let dd_bases_128: Vec = (0..limbs_per_half) + .map(|i| { + let remaining = 128u32 - (i as u32 * sr_limb_bits); + remaining.min(sr_limb_bits) as usize + }) + .collect(); + let dd_lo = add_digital_decomposition(ops.compiler, dd_bases_128.clone(), vec![s_lo]); + let dd_hi = add_digital_decomposition(ops.compiler, dd_bases_128, vec![s_hi]); let mut limbs = Limbs::new(sr_num_limbs); - // s_lo provides limbs 0,1; s_hi provides limbs 2,3 (for sr_num_limbs=4) - let lo_n = 2.min(sr_num_limbs); + let lo_n = limbs_per_half.min(sr_num_limbs); for i in 0..lo_n { limbs[i] = dd_lo.get_digit_witness_index(i, 0); - ops.range_checks.entry(64).or_default().push(limbs[i]); + let remaining = 128u32 - (i as u32 * sr_limb_bits); + ops.range_checks + .entry(remaining.min(sr_limb_bits)) + .or_default() + .push(limbs[i]); } let hi_n = sr_num_limbs - lo_n; for i in 0..hi_n { limbs[lo_n + i] = dd_hi.get_digit_witness_index(i, 0); + let remaining = 128u32 - (i as u32 * sr_limb_bits); ops.range_checks - .entry(64) + .entry(remaining.min(sr_limb_bits)) .or_default() .push(limbs[lo_n + i]); } limbs }; - // Helper: decompose a half-scalar witness into sr_num_limbs × 64-bit limbs. - // The half-scalar has `half_bits` bits → occupies `half_limbs` 64-bit limbs. + // Helper: decompose a half-scalar witness into sr_num_limbs limbs. + // The half-scalar has `half_bits` bits → occupies `half_limbs` limbs. // Upper limbs (half_limbs..sr_num_limbs) are zero-padded. 
let decompose_half_scalar = |ops: &mut MultiLimbOps, witness: usize| -> Limbs { let dd_bases: Vec = (0..half_limbs) .map(|i| { - let remaining = half_bits as u32 - (i as u32 * 64); - remaining.min(64) as usize + let remaining = half_bits as u32 - (i as u32 * sr_limb_bits); + remaining.min(sr_limb_bits) as usize }) .collect(); let dd = add_digital_decomposition(ops.compiler, dd_bases, vec![witness]); let mut limbs = Limbs::new(sr_num_limbs); for i in 0..half_limbs { limbs[i] = dd.get_digit_witness_index(i, 0); - let remaining_bits = (half_bits as u32) - (i as u32 * 64); - let this_limb_bits = remaining_bits.min(64); + let remaining_bits = (half_bits as u32) - (i as u32 * sr_limb_bits); + let this_limb_bits = remaining_bits.min(sr_limb_bits); ops.range_checks .entry(this_limb_bits) .or_default() @@ -855,15 +1066,12 @@ fn verify_scalar_relation( // Handle signs: compute effective values // If neg2 is set: neg_product = n - product (mod n), i.e. 0 - product - let zero_limbs_vals = vec![FieldElement::from(0u64); sr_num_limbs]; - let zero = ops.constant_limbs(&zero_limbs_vals); - let neg_product = ops.sub(zero, product); - // Select: if neg2=1, use neg_product; else use product + let neg_product = ops.negate(product); // neg2 already constrained boolean in verify_point_fakeglv let effective_product = ops.select_unchecked(neg2_witness, product, neg_product); // If neg1 is set: neg_s1 = n - s1 (mod n), i.e. 0 - s1 - let neg_s1 = ops.sub(zero, s1_limbs); + let neg_s1 = ops.negate(s1_limbs); // neg1 already constrained boolean in verify_point_fakeglv let effective_s1 = ops.select_unchecked(neg1_witness, s1_limbs, neg_s1); @@ -876,6 +1084,13 @@ fn verify_scalar_relation( } } +/// Creates a constant witness with the given value. 
+fn add_constant_witness(compiler: &mut NoirToR1CSCompiler, value: FieldElement) -> usize { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value))); + w +} + /// Constrains two witnesses to be equal: `a - b = 0`. fn constrain_equal(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { compiler.r1cs.add_constraint( diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs index 7ac8d78ac..34275b890 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -24,7 +24,6 @@ pub struct MultiLimbParams { pub two_pow_w: FieldElement, pub modulus_raw: [u64; 4], pub curve_a_limbs: Vec, - pub modulus_bits: u32, /// p = native field → skip mod reduction pub is_native: bool, /// For N=1 non-native: the modulus as a single FieldElement @@ -32,13 +31,13 @@ pub struct MultiLimbParams { } /// Unified field operations struct parameterized by runtime limb count. -pub struct MultiLimbOps<'a> { +pub struct MultiLimbOps<'a, 'p> { pub compiler: &'a mut NoirToR1CSCompiler, pub range_checks: &'a mut BTreeMap>, - pub params: MultiLimbParams, + pub params: &'p MultiLimbParams, } -impl MultiLimbOps<'_> { +impl MultiLimbOps<'_, '_> { fn is_native_single(&self) -> bool { self.params.num_limbs == 1 && self.params.is_native } @@ -50,9 +49,16 @@ impl MultiLimbOps<'_> { fn n(&self) -> usize { self.params.num_limbs } + + /// Negate a multi-limb value: computes `0 - value (mod p)`. 
+ pub fn negate(&mut self, value: Limbs) -> Limbs { + let zero_vals = vec![FieldElement::from(0u64); self.params.num_limbs]; + let zero = self.constant_limbs(&zero_vals); + self.sub(zero, value) + } } -impl FieldOps for MultiLimbOps<'_> { +impl FieldOps for MultiLimbOps<'_, '_> { type Elem = Limbs; fn add(&mut self, a: Limbs, b: Limbs) -> Limbs { diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs b/provekit/r1cs-compiler/src/noir_to_r1cs.rs index 2475ddf8c..144730437 100644 --- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs +++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs @@ -745,24 +745,7 @@ impl NoirToR1CSCompiler { let constraints_before_msm = self.r1cs.num_constraints(); let witnesses_before_msm = self.num_witnesses(); - // Cost model: pick optimal (limb_bits, window_size) for MSM - let curve = crate::msm::curve::grumpkin_params(); - let native_bits = FieldElement::MODULUS_BIT_SIZE; - let curve_bits = curve.modulus_bits(); - let (msm_limb_bits, msm_window_size) = if !msm_ops.is_empty() { - let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / 3).sum(); - crate::msm::cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256) - } else { - (native_bits, 4) - }; - add_msm( - self, - msm_ops, - msm_limb_bits, - msm_window_size, - &mut range_checks, - &curve, - ); + add_msm(self, msm_ops, &mut range_checks); breakdown.msm_constraints = self.r1cs.num_constraints() - constraints_before_msm; breakdown.msm_witnesses = self.num_witnesses() - witnesses_before_msm; From 6e7197b4309fec1f215210369b24971d3a06ac3c Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 6 Mar 2026 04:21:46 +0530 Subject: [PATCH 07/19] feat : updated G offset for curve and refactor --- noir-examples/embedded_curve_msm/Prover.toml | 68 ++- noir-examples/embedded_curve_msm/src/main.nr | 3 +- provekit/prover/src/bigint_mod.rs | 167 +++++++- .../prover/src/witness/witness_builder.rs | 393 ++++-------------- provekit/r1cs-compiler/src/msm/curve.rs | 39 +- 5 files changed, 319 
insertions(+), 351 deletions(-) diff --git a/noir-examples/embedded_curve_msm/Prover.toml b/noir-examples/embedded_curve_msm/Prover.toml index 58c6933da..da0b3529b 100644 --- a/noir-examples/embedded_curve_msm/Prover.toml +++ b/noir-examples/embedded_curve_msm/Prover.toml @@ -1,5 +1,71 @@ -# MSM: result = s1 * G + s2 * G = 1*G + 2*G = 3*G +# ============================================================ +# MSM test vectors: result = s1 * G + s2 * G +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# Uncomment ONE test case at a time to run. +# ============================================================ + +# === Test 1: Small scalars (1*G + 2*G = 3*G) === scalar1_lo = "1" scalar1_hi = "0" scalar2_lo = "2" scalar2_hi = "0" + +# === Test 2: All-zero scalars (0*G + 0*G = point at infinity) === +# scalar1_lo = "0" +# scalar1_hi = "0" +# scalar2_lo = "0" +# scalar2_hi = "0" + +# === Test 3: One zero, one non-zero (0*G + 5*G = 5*G) === +# scalar1_lo = "0" +# scalar1_hi = "0" +# scalar2_lo = "5" +# scalar2_hi = "0" + +# === Test 4: Large lo, small hi (diff ≠ 2^128) === +# scalar1_lo = "64323764613183177041862057485226039389" +# scalar1_hi = "1" +# scalar2_lo = "99999999999999999999999999999999999999" +# scalar2_hi = "3" + +# === Test 5: Small lo, large hi === +# scalar1_lo = "1" +# scalar1_hi = "64323764613183177041862057485226039389" +# scalar2_lo = "2" +# scalar2_hi = "64323764613183177041862057485226039389" + +# === Test 6: Near-max scalars (n-10 and n-20) === +# scalar1_lo = "201385395114098847380338600778089168189" +# scalar1_hi = "64323764613183177041862057485226039389" +# scalar2_lo = "201385395114098847380338600778089168179" +# scalar2_hi = "64323764613183177041862057485226039389" + +# === Test 7: Powers of 2 (2^100 and 2^200) === +# scalar1_lo = "1267650600228229401496703205376" +# scalar1_hi = "0" +# scalar2_lo = "0" +# scalar2_hi = "4722366482869645213696" + +# === Test 8: Half curve order (n/2) and 1 === 
+# scalar1_lo = "270833881017518655421856604104928689827" +# scalar1_hi = "32161882306591588520931028742613019694" +# scalar2_lo = "1" +# scalar2_hi = "0" + +# === Test 9: Large mixed scalars === +# scalar1_lo = "340282366920938463463374607431768211455" +# scalar1_hi = "0" +# scalar2_lo = "170141183460469231731687303715884105727" +# scalar2_hi = "3" + +# === Test 10: Both scalars equal, ~2n/3 === +# scalar1_lo = "247684385716378719408017269662648849284" +# scalar1_hi = "42882509742122118027908038323484026259" +# scalar2_lo = "247684385716378719408017269662648849284" +# scalar2_hi = "42882509742122118027908038323484026259" + +# === Test 11: n - 2, n - 3 (previously failing with [2]G offset) === +# scalar1_lo = "201385395114098847380338600778089168197" +# scalar1_hi = "64323764613183177041862057485226039389" +# scalar2_lo = "201385395114098847380338600778089168196" +# scalar2_hi = "64323764613183177041862057485226039389" diff --git a/noir-examples/embedded_curve_msm/src/main.nr b/noir-examples/embedded_curve_msm/src/main.nr index cf0704211..19a193181 100644 --- a/noir-examples/embedded_curve_msm/src/main.nr +++ b/noir-examples/embedded_curve_msm/src/main.nr @@ -26,7 +26,8 @@ fn main( let result = multi_scalar_mul([g, g], [s1, s2]); // Prevent dead-code elimination - forces the blackbox to be retained - assert(!result.is_infinite); + // Using is_infinite as return value ensures the MSM is computed + assert(result.is_infinite == (scalar1_lo + scalar1_hi + scalar2_lo + scalar2_hi == 0)); } #[test] diff --git a/provekit/prover/src/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs index 796fd4cdf..eaeec3fa3 100644 --- a/provekit/prover/src/bigint_mod.rs +++ b/provekit/prover/src/bigint_mod.rs @@ -248,6 +248,149 @@ pub fn is_zero(val: &[u64; 4]) -> bool { val[0] == 0 && val[1] == 0 && val[2] == 0 && val[3] == 0 } +/// Compute the bit mask for a limb of the given width. 
+pub fn limb_mask(limb_bits: u32) -> u128 { + if limb_bits >= 128 { + u128::MAX + } else { + (1u128 << limb_bits) - 1 + } +} + +/// Right-shift a 4-limb (256-bit) value by `bits` positions. +pub fn shr_256(val: &[u64; 4], bits: u32) -> [u64; 4] { + if bits >= 256 { + return [0; 4]; + } + let mut shifted = [0u64; 4]; + let word_shift = (bits / 64) as usize; + let bit_shift = bits % 64; + for i in 0..4 { + if i + word_shift < 4 { + shifted[i] = val[i + word_shift] >> bit_shift; + if bit_shift > 0 && i + word_shift + 1 < 4 { + shifted[i] |= val[i + word_shift + 1] << (64 - bit_shift); + } + } + } + shifted +} + +/// Decompose a 256-bit value into `num_limbs` limbs of `limb_bits` width. +/// Returns u128 limb values (each < 2^limb_bits). +pub fn decompose_to_u128_limbs(val: &[u64; 4], num_limbs: usize, limb_bits: u32) -> Vec { + let mask = limb_mask(limb_bits); + let mut limbs = Vec::with_capacity(num_limbs); + let mut remaining = *val; + for _ in 0..num_limbs { + let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); + limbs.push(lo & mask); + remaining = shr_256(&remaining, limb_bits); + } + limbs +} + +/// Reconstruct a 256-bit value from u128 limb values packed at `limb_bits` +/// boundaries. +pub fn reconstruct_from_u128_limbs(limb_values: &[u128], limb_bits: u32) -> [u64; 4] { + let mut val = [0u64; 4]; + let mut bit_offset = 0u32; + for &limb_u128 in limb_values.iter() { + let word_start = (bit_offset / 64) as usize; + let bit_within = bit_offset % 64; + if word_start < 4 { + val[word_start] |= (limb_u128 as u64) << bit_within; + if word_start + 1 < 4 { + val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; + } + if word_start + 2 < 4 && bit_within > 0 { + let upper = limb_u128 >> (128 - bit_within); + if upper > 0 { + val[word_start + 2] |= upper as u64; + } + } + } + bit_offset += limb_bits; + } + val +} + +/// Compute schoolbook carries for a*b = p*q + r verification in base +/// 2^limb_bits. 
Returns unsigned-offset carries ready to be written as +/// witnesses. +pub fn compute_mul_mod_carries( + a_limbs: &[u128], + b_limbs: &[u128], + p_limbs: &[u128], + q_limbs: &[u128], + r_limbs: &[u128], + limb_bits: u32, +) -> Vec { + let n = a_limbs.len(); + let w = limb_bits; + let num_carries = 2 * n - 2; + let carry_offset = 1u128 << (w + ((n as f64).log2().ceil() as u32) + 1); + let mut carries = Vec::with_capacity(num_carries); + let mut carry: i128 = 0; + + for k in 0..(2 * n - 1) { + let mut ab_lo: u128 = 0; + let mut ab_hi: u64 = 0; + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + let prod = a_limbs[i] * b_limbs[j as usize]; + let (new_lo, ov) = ab_lo.overflowing_add(prod); + ab_lo = new_lo; + if ov { + ab_hi += 1; + } + } + } + let mut pq_lo: u128 = 0; + let mut pq_hi: u64 = 0; + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + let prod = p_limbs[i] * q_limbs[j as usize]; + let (new_lo, ov) = pq_lo.overflowing_add(prod); + pq_lo = new_lo; + if ov { + pq_hi += 1; + } + } + } + if k < n { + let (new_lo, ov) = pq_lo.overflowing_add(r_limbs[k]); + pq_lo = new_lo; + if ov { + pq_hi += 1; + } + } + + let diff_lo = ab_lo.wrapping_sub(pq_lo); + let borrow = if ab_lo < pq_lo { 1i64 } else { 0 }; + let diff_hi = ab_hi as i64 - pq_hi as i64 - borrow; + + let carry_lo = carry as u128; + let carry_hi: i64 = if carry < 0 { -1 } else { 0 }; + let (total_lo, ov) = diff_lo.overflowing_add(carry_lo); + let total_hi = diff_hi + carry_hi + if ov { 1i64 } else { 0 }; + + if k < 2 * n - 2 { + debug_assert_eq!( + total_lo & ((1u128 << w) - 1), + 0, + "non-zero remainder at column {k}" + ); + carry = total_hi as i128 * (1i128 << (128 - w)) + (total_lo >> w) as i128; + carries.push((carry + carry_offset as i128) as u128); + } + } + + carries +} + /// Compute the number of bits needed for the half-GCD sub-scalars. /// Returns `ceil(order_bits / 2)` where `order_bits` is the bit length of `n`. 
pub fn half_gcd_bits(n: &[u64; 4]) -> u32 { @@ -358,25 +501,19 @@ pub fn half_gcd(s: &[u64; 4], n: &[u64; 4]) -> ([u64; 4], [u64; 4], bool, bool) } // At this point: r_curr < 2^half_bits and t_curr < ~2^half_bits (half-GCD - // property) The relation is: (-1)^(iteration) * r_curr + t_curr * s ≡ 0 - // (mod n) Or equivalently: r_curr ≡ (-1)^(iteration+1) * t_curr * s (mod n) + // property). + // + // From the extended GCD identity: t_i * s ≡ r_i (mod n) + // Rearranging: -r_i + t_i * s ≡ 0 (mod n) + // + // The circuit checks: (-1)^neg1 * |r_i| + (-1)^neg2 * |t_i| * s ≡ 0 (mod n) + // Since r_i is always non-negative, neg1 must always be true (negate r_i). + // neg2 must match the actual sign of t_i so that (-1)^neg2 * |t_i| = t_i. let val1 = r_curr; // |s1| = |r_i| let val2 = t_curr; // |s2| = |t_i| - // Determine signs: - // We need: neg1 * val1 + neg2 * val2 * s ≡ 0 (mod n) - // From the extended GCD: r_i = (-1)^i * (... some relation with t_i * s mod n) - // The exact sign relationship: - // t_i * s ≡ (-1)^(i+1) * r_i (mod n) - // So: (-1)^(i+1) * r_i + t_i * s ≡ 0 (mod n) - // - // If iteration is even: (-1)^(even+1) = -1, so: -r_i + t_i * s ≡ 0 - // → neg1=true (negate r_i), neg2=t_curr_neg - // If iteration is odd: (-1)^(odd+1) = 1, so: r_i + t_i * s ≡ 0 - // → neg1=false, neg2=t_curr_neg - - let neg1 = iteration % 2 == 0; // negate val1 when iteration is even + let neg1 = true; // always negate r_i: -r_i + t_i * s ≡ 0 (mod n) let neg2 = t_curr_neg; (val1, val2, neg1, neg2) diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index cee33a18d..87b1105e6 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -23,6 +23,42 @@ pub trait WitnessBuilderSolver { ); } +/// Resolve a ConstantOrR1CSWitness to its FieldElement value. 
+fn resolve(witness: &[Option], v: &ConstantOrR1CSWitness) -> FieldElement { + match v { + ConstantOrR1CSWitness::Constant(c) => *c, + ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), + } +} + +/// Convert a u128 value to a FieldElement. +fn u128_to_fe(val: u128) -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt([val as u64, (val >> 64) as u64, 0, 0])).unwrap() +} + +/// Read witness limbs and reconstruct as [u64; 4]. +fn read_witness_limbs( + witness: &[Option], + indices: &[usize], + limb_bits: u32, +) -> [u64; 4] { + let limb_values: Vec = indices + .iter() + .map(|&idx| { + let bigint = witness[idx].unwrap().into_bigint().0; + bigint[0] as u128 | ((bigint[1] as u128) << 64) + }) + .collect(); + crate::bigint_mod::reconstruct_from_u128_limbs(&limb_values, limb_bits) +} + +/// Write u128 limb values as FieldElement witnesses starting at `start`. +fn write_limbs(witness: &mut [Option], start: usize, vals: &[u128]) { + for (i, &val) in vals.iter().enumerate() { + witness[start + i] = Some(u128_to_fe(val)); + } +} + impl WitnessBuilderSolver for WitnessBuilder { fn solve( &self, @@ -171,18 +207,9 @@ impl WitnessBuilderSolver for WitnessBuilder { rhs, output, ) => { - let lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let output = match output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); + let output = resolve(witness, output); witness[*witness_idx] = Some( witness[*sz_challenge].unwrap() - (lhs @@ -201,22 +228,10 @@ impl WitnessBuilderSolver for WitnessBuilder { and_output, xor_output, ) => { - let lhs = match lhs { - 
ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let and_out = match and_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let xor_out = match xor_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); + let and_out = resolve(witness, and_output); + let xor_out = resolve(witness, xor_output); // Encoding: sz - (lhs + rs*rhs + rs²*and_out + rs³*xor_out) witness[*witness_idx] = Some( witness[*sz_challenge].unwrap() @@ -229,18 +244,8 @@ impl WitnessBuilderSolver for WitnessBuilder { WitnessBuilder::MultiplicitiesForBinOp(witness_idx, atomic_bits, operands) => { let mut multiplicities = vec![0u32; 2usize.pow(2 * *atomic_bits)]; for (lhs, rhs) in operands { - let lhs = match lhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => { - witness[*witness_idx].unwrap() - } - }; - let rhs = match rhs { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => { - witness[*witness_idx].unwrap() - } - }; + let lhs = resolve(witness, lhs); + let rhs = resolve(witness, rhs); let index = (lhs.into_bigint().0[0] << *atomic_bits) + rhs.into_bigint().0[0]; multiplicities[index as usize] += 1; } @@ -249,14 +254,8 @@ impl WitnessBuilderSolver for WitnessBuilder { } } WitnessBuilder::U32Addition(result_witness_idx, carry_witness_idx, a, b) => { - let a_val = match a { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), - }; - let b_val = match b { - ConstantOrR1CSWitness::Constant(c) 
=> *c, - ConstantOrR1CSWitness::Witness(idx) => witness[*idx].unwrap(), - }; + let a_val = resolve(witness, a); + let b_val = resolve(witness, b); assert!( a_val.into_bigint().num_bits() <= 32, "a_val must be less than or equal to 32 bits, got {}", @@ -284,12 +283,7 @@ impl WitnessBuilderSolver for WitnessBuilder { // Sum all inputs as u64 to handle overflow. let mut sum: u64 = 0; for input in inputs { - let val = match input { - ConstantOrR1CSWitness::Constant(c) => c.into_bigint().0[0], - ConstantOrR1CSWitness::Witness(idx) => { - witness[*idx].unwrap().into_bigint().0[0] - } - }; + let val = resolve(witness, input).into_bigint().0[0]; assert!(val < (1u64 << 32), "input must be 32-bit"); sum += val; } @@ -300,14 +294,8 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*carry_witness_idx] = Some(FieldElement::from(quotient)); } WitnessBuilder::And(result_witness_idx, lh, rh) => { - let lh_val = match lh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rh_val = match rh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lh_val = resolve(witness, lh); + let rh_val = resolve(witness, rh); assert!( lh_val.into_bigint().num_bits() <= 32, "lh_val must be less than or equal to 32 bits, got {}", @@ -323,14 +311,8 @@ impl WitnessBuilderSolver for WitnessBuilder { )); } WitnessBuilder::Xor(result_witness_idx, lh, rh) => { - let lh_val = match lh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; - let rh_val = match rh { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(witness_idx) => witness[*witness_idx].unwrap(), - }; + let lh_val = resolve(witness, lh); + let rh_val = resolve(witness, rh); assert!( lh_val.into_bigint().num_bits() <= 32, "lh_val must be less than or equal to 32 bits, 
got {}", @@ -353,148 +335,33 @@ impl WitnessBuilderSolver for WitnessBuilder { limb_bits, num_limbs, } => { - use crate::bigint_mod::{divmod_wide, widening_mul}; + use crate::bigint_mod::{ + compute_mul_mod_carries, decompose_to_u128_limbs, divmod_wide, widening_mul, + }; let n = *num_limbs as usize; let w = *limb_bits; - let limb_mask: u128 = if w >= 128 { - u128::MAX - } else { - (1u128 << w) - 1 - }; - - // Reconstruct a, b as [u64; 4] from N limbs - let reconstruct = |limbs: &[usize]| -> [u64; 4] { - let mut val = [0u64; 4]; - let mut bit_offset = 0u32; - for &limb_idx in limbs.iter() { - let limb_val = witness[limb_idx].unwrap().into_bigint().0; - let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); - // Place into val at bit_offset - let word_start = (bit_offset / 64) as usize; - let bit_within = bit_offset % 64; - if word_start < 4 { - val[word_start] |= (limb_u128 as u64) << bit_within; - if word_start + 1 < 4 { - val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; - } - if word_start + 2 < 4 && bit_within > 0 { - let upper = limb_u128 >> (128 - bit_within); - if upper > 0 { - val[word_start + 2] |= upper as u64; - } - } - } - bit_offset += w; - } - val - }; - let a_val = reconstruct(a_limbs); - let b_val = reconstruct(b_limbs); + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); - // Compute product and divmod let product = widening_mul(&a_val, &b_val); let (q_val, r_val) = divmod_wide(&product, modulus); - // Decompose a [u64;4] into N limbs of limb_bits width. 
- let decompose_n_from_u64 = |val: &[u64; 4]| -> Vec { - let mut limbs = Vec::with_capacity(n); - let mut remaining = *val; - for _ in 0..n { - let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); - limbs.push(lo & limb_mask); - // Shift right by w bits - if w >= 256 { - remaining = [0; 4]; - } else { - let mut shifted = [0u64; 4]; - let word_shift = (w / 64) as usize; - let bit_shift = w % 64; - for i in 0..4 { - if i + word_shift < 4 { - shifted[i] = remaining[i + word_shift] >> bit_shift; - if bit_shift > 0 && i + word_shift + 1 < 4 { - shifted[i] |= - remaining[i + word_shift + 1] << (64 - bit_shift); - } - } - } - remaining = shifted; - } - } - limbs - }; - - let q_limbs_vals = decompose_n_from_u64(&q_val); - let r_limbs_vals = decompose_n_from_u64(&r_val); + let q_limbs_vals = decompose_to_u128_limbs(&q_val, n, w); + let r_limbs_vals = decompose_to_u128_limbs(&r_val, n, w); - // Compute carries for schoolbook verification: - // a·b = p·q + r in base W = 2^limb_bits - // For each column k (0..2N-2): - // lhs_k = Σ_{i+j=k} a[i]*b[j] + carry_{k-1} - // rhs_k = Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W - let p_limbs_vals = decompose_n_from_u64(modulus); - let a_limbs_vals = decompose_n_from_u64(&a_val); - let b_limbs_vals = decompose_n_from_u64(&b_val); - - let w_val = 1u128 << w; - let num_carries = 2 * n - 2; - let carry_offset = 1u128 << (w + ((n as f64).log2().ceil() as u32) + 1); - let mut carries = Vec::with_capacity(num_carries); - let mut running: i128 = 0; - - for k in 0..(2 * n - 1) { - // Sum a[i]*b[j] for i+j=k - let mut ab_sum: i128 = 0; - for i in 0..n { - let j = k as isize - i as isize; - if j >= 0 && (j as usize) < n { - ab_sum += a_limbs_vals[i] as i128 * b_limbs_vals[j as usize] as i128; - } - } - // Sum p[i]*q[j] for i+j=k - let mut pq_sum: i128 = 0; - for i in 0..n { - let j = k as isize - i as isize; - if j >= 0 && (j as usize) < n { - pq_sum += p_limbs_vals[i] as i128 * q_limbs_vals[j as usize] as i128; - } - } - let r_k = if k < 
n { r_limbs_vals[k] as i128 } else { 0 }; - - // column: ab_sum + carry_prev = pq_sum + r_k + carry_next * W - // carry_next = (ab_sum + carry_prev - pq_sum - r_k) / W - running += ab_sum - pq_sum - r_k; - if k < 2 * n - 2 { - let carry = running / w_val as i128; - carries.push(carry); - running -= carry * w_val as i128; - } - } - - let u128_to_fe = |val: u128| -> FieldElement { - FieldElement::from_bigint(ark_ff::BigInt([ - val as u64, - (val >> 64) as u64, - 0, - 0, - ])) - .unwrap() - }; + let carries = compute_mul_mod_carries( + &decompose_to_u128_limbs(&a_val, n, w), + &decompose_to_u128_limbs(&b_val, n, w), + &decompose_to_u128_limbs(modulus, n, w), + &q_limbs_vals, + &r_limbs_vals, + w, + ); - // Write q limbs - for i in 0..n { - witness[*output_start + i] = Some(u128_to_fe(q_limbs_vals[i])); - } - // Write r limbs - for i in 0..n { - witness[*output_start + n + i] = Some(u128_to_fe(r_limbs_vals[i])); - } - // Write carries (unsigned-offset) - for i in 0..num_carries { - let c_unsigned = (carries[i] + carry_offset as i128) as u128; - witness[*output_start + 2 * n + i] = Some(u128_to_fe(c_unsigned)); - } + write_limbs(witness, *output_start, &q_limbs_vals); + write_limbs(witness, *output_start + n, &r_limbs_vals); + write_limbs(witness, *output_start + 2 * n, &carries); } WitnessBuilder::MultiLimbModularInverse { output_start, @@ -503,64 +370,14 @@ impl WitnessBuilderSolver for WitnessBuilder { limb_bits, num_limbs, } => { - use crate::bigint_mod::{mod_pow, sub_u64}; + use crate::bigint_mod::{decompose_to_u128_limbs, mod_pow, sub_u64}; let n = *num_limbs as usize; let w = *limb_bits; - let limb_mask: u128 = if w >= 128 { - u128::MAX - } else { - (1u128 << w) - 1 - }; - - // Reconstruct a as [u64; 4] from N limbs - let mut a_val = [0u64; 4]; - let mut bit_offset = 0u32; - for &limb_idx in a_limbs.iter() { - let limb_val = witness[limb_idx].unwrap().into_bigint().0; - let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); - let word_start = 
(bit_offset / 64) as usize; - let bit_within = bit_offset % 64; - if word_start < 4 { - a_val[word_start] |= (limb_u128 as u64) << bit_within; - if word_start + 1 < 4 { - a_val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; - } - } - bit_offset += w; - } - // Compute inverse: a^{p-2} mod p + let a_val = read_witness_limbs(witness, a_limbs, w); let exp = sub_u64(modulus, 2); let inv = mod_pow(&a_val, &exp, modulus); - - // Decompose into N limbs - let mut remaining = inv; - let u128_to_fe = |val: u128| -> FieldElement { - FieldElement::from_bigint(ark_ff::BigInt([ - val as u64, - (val >> 64) as u64, - 0, - 0, - ])) - .unwrap() - }; - for i in 0..n { - let lo = remaining[0] as u128 | ((remaining[1] as u128) << 64); - witness[*output_start + i] = Some(u128_to_fe(lo & limb_mask)); - // Shift right by w bits - let mut shifted = [0u64; 4]; - let word_shift = (w / 64) as usize; - let bit_shift = w % 64; - for j in 0..4 { - if j + word_shift < 4 { - shifted[j] = remaining[j + word_shift] >> bit_shift; - if bit_shift > 0 && j + word_shift + 1 < 4 { - shifted[j] |= remaining[j + word_shift + 1] << (64 - bit_shift); - } - } - } - remaining = shifted; - } + write_limbs(witness, *output_start, &decompose_to_u128_limbs(&inv, n, w)); } WitnessBuilder::MultiLimbAddQuotient { output, @@ -573,28 +390,8 @@ impl WitnessBuilderSolver for WitnessBuilder { use crate::bigint_mod::{add_4limb, cmp_4limb}; let w = *limb_bits; - // Reconstruct from N limbs - let reconstruct = |limbs: &[usize]| -> [u64; 4] { - let mut val = [0u64; 4]; - let mut bit_offset = 0u32; - for &limb_idx in limbs.iter() { - let limb_val = witness[limb_idx].unwrap().into_bigint().0; - let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); - let word_start = (bit_offset / 64) as usize; - let bit_within = bit_offset % 64; - if word_start < 4 { - val[word_start] |= (limb_u128 as u64) << bit_within; - if word_start + 1 < 4 { - val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; - } - } 
- bit_offset += w; - } - val - }; - - let a_val = reconstruct(a_limbs); - let b_val = reconstruct(b_limbs); + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); let sum = add_4limb(&a_val, &b_val); let q = if sum[4] > 0 { @@ -620,27 +417,8 @@ impl WitnessBuilderSolver for WitnessBuilder { use crate::bigint_mod::cmp_4limb; let w = *limb_bits; - let reconstruct = |limbs: &[usize]| -> [u64; 4] { - let mut val = [0u64; 4]; - let mut bit_offset = 0u32; - for &limb_idx in limbs.iter() { - let limb_val = witness[limb_idx].unwrap().into_bigint().0; - let limb_u128 = limb_val[0] as u128 | ((limb_val[1] as u128) << 64); - let word_start = (bit_offset / 64) as usize; - let bit_within = bit_offset % 64; - if word_start < 4 { - val[word_start] |= (limb_u128 as u64) << bit_within; - if word_start + 1 < 4 { - val[word_start + 1] |= (limb_u128 >> (64 - bit_within)) as u64; - } - } - bit_offset += w; - } - val - }; - - let a_val = reconstruct(a_limbs); - let b_val = reconstruct(b_limbs); + let a_val = read_witness_limbs(witness, a_limbs, w); + let b_val = read_witness_limbs(witness, b_limbs, w); let q = if cmp_4limb(&a_val, &b_val) == std::cmp::Ordering::Less { 1u64 @@ -776,12 +554,7 @@ impl WitnessBuilderSolver for WitnessBuilder { let table_size = 1usize << *num_bits; let mut multiplicities = vec![0u32; table_size]; for query in queries { - let val = match query { - ConstantOrR1CSWitness::Constant(c) => c.into_bigint().0[0], - ConstantOrR1CSWitness::Witness(w) => { - witness[*w].unwrap().into_bigint().0[0] - } - }; + let val = resolve(witness, query).into_bigint().0[0]; multiplicities[val as usize] += 1; } for (i, count) in multiplicities.iter().enumerate() { @@ -791,14 +564,8 @@ impl WitnessBuilderSolver for WitnessBuilder { WitnessBuilder::SpreadLookupDenominator(idx, sz, rs, input, spread_output) => { let sz_val = witness[*sz].unwrap(); let rs_val = witness[*rs].unwrap(); - let input_val = match input { - 
ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(w) => witness[*w].unwrap(), - }; - let spread_val = match spread_output { - ConstantOrR1CSWitness::Constant(c) => *c, - ConstantOrR1CSWitness::Witness(w) => witness[*w].unwrap(), - }; + let input_val = resolve(witness, input); + let spread_val = resolve(witness, spread_output); // sz - (input + rs * spread_output) witness[*idx] = Some(sz_val - (input_val + rs_val * spread_val)); } diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index 455e465de..0876bfadb 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -9,9 +9,6 @@ pub struct CurveParams { pub curve_a: [u64; 4], pub curve_b: [u64; 4], pub generator: ([u64; 4], [u64; 4]), - /// A known non-identity point on the curve, used as the accumulator offset - /// in `scalar_mul_glv`. Must be deterministic and unrelated to typical - /// table entries (we use [2]G). pub offset_point: ([u64; 4], [u64; 4]), } @@ -275,19 +272,19 @@ pub fn grumpkin_params() -> CurveParams { 0xcf135e7506a45d63_u64, 0x0000000000000002_u64, ]), - // Offset point = [2]G + // Offset point = [2^128]G (large offset avoids collisions with small multiples of G) offset_point: ( [ - 0x6d8bc688cdbffffe_u64, - 0x19a74caa311e13d4_u64, - 0xddeb49cdaa36306d_u64, - 0x06ce1b0827aafa85_u64, + 0x626578b496650e95_u64, + 0x8678dcf264df6c01_u64, + 0xf0b3eb7e6d02aba8_u64, + 0x223748a4c4edde75_u64, ], [ - 0x467be7e7a43f80ac_u64, - 0xc93faf6fa1a788bf_u64, - 0x909ede0ba2a6855f_u64, - 0x1c122f81a3a14964_u64, + 0xb75fb4c26bcd4f35_u64, + 0x4d4ba4d97d5f99d9_u64, + 0xccab35fdbf52368a_u64, + 0x25b41c5f56f8472b_u64, ], ), } @@ -611,19 +608,19 @@ pub fn secp256r1_params() -> CurveParams { 0x4fe342e2fe1a7f9b_u64, ], ), - // Offset point = [2]G + // Offset point = [2^128]G (large offset avoids collisions with small multiples of G) offset_point: ( [ - 0xa60b48fc47669978_u64, - 0xc08969e277f21b35_u64, - 
0x8a52380304b51ac3_u64, - 0x7cf27b188d034f7e_u64, + 0x57c84fc9d789bd85_u64, + 0xfc35ff7dc297eac3_u64, + 0xfb982fd588c6766e_u64, + 0x447d739beedb5e67_u64, ], [ - 0x9e04b79d227873d1_u64, - 0xba7dade63ce98229_u64, - 0x293d9ac69f7430db_u64, - 0x07775510db8ed040_u64, + 0x0c7e33c972e25b32_u64, + 0x3d349b95a7fae500_u64, + 0xe12e9d953a4aaff7_u64, + 0x2d4825ab834131ee_u64, ], ), } From 64111f9d2b37f3a5155037e03230ddb6ab593c3b Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 6 Mar 2026 10:27:46 +0530 Subject: [PATCH 08/19] fix : private document items ci --- provekit/common/src/witness/witness_builder.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 2b7cd2f30..353f2575e 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -276,19 +276,19 @@ pub enum WitnessBuilder { /// (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 (mod n) /// /// Outputs 4 witnesses starting at output_start: - /// [0] |s1| (128-bit field element) - /// [1] |s2| (128-bit field element) - /// [2] neg1 (boolean: 0 or 1) - /// [3] neg2 (boolean: 0 or 1) + /// \[0\] |s1| (128-bit field element) + /// \[1\] |s2| (128-bit field element) + /// \[2\] neg1 (boolean: 0 or 1) + /// \[3\] neg2 (boolean: 0 or 1) FakeGLVHint { output_start: usize, s_lo: usize, s_hi: usize, curve_order: [u64; 4], }, - /// Prover hint for EC scalar multiplication: computes R = [s]P. + /// Prover hint for EC scalar multiplication: computes R = \[s\]P. /// Given point P = (px, py) and scalar s = s_lo + s_hi * 2^128, - /// computes R = [s]P on the curve with parameter `curve_a` and + /// computes R = \[s\]P on the curve with parameter `curve_a` and /// field modulus `field_modulus_p`. /// /// Outputs 2 witnesses at output_start: R_x, R_y. 
From 76ec4393213a604e228c5d512c738063c796d0d3 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 6 Mar 2026 11:05:18 +0530 Subject: [PATCH 09/19] fix : document-private-items ci --- provekit/prover/src/bigint_mod.rs | 2 +- provekit/r1cs-compiler/src/msm/ec_points.rs | 16 ++++++++-------- provekit/r1cs-compiler/src/msm/mod.rs | 6 +++--- .../r1cs-compiler/src/msm/multi_limb_arith.rs | 12 ++++++------ 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/provekit/prover/src/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs index eaeec3fa3..e4ea1fea8 100644 --- a/provekit/prover/src/bigint_mod.rs +++ b/provekit/prover/src/bigint_mod.rs @@ -622,7 +622,7 @@ pub fn ec_point_add( (x3, y3) } -/// EC scalar multiplication via double-and-add: returns [scalar]*P. +/// EC scalar multiplication via double-and-add: returns \[scalar\]*P. pub fn ec_scalar_mul( px: &[u64; 4], py: &[u64; 4], diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 1bf2aa264..0138a571b 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -115,8 +115,8 @@ fn point_select_unchecked( /// Builds a point table for windowed scalar multiplication. /// -/// T[0] = P (dummy entry, used when window digit = 0) -/// T[1] = P, T[2] = 2P, T[i] = T[i-1] + P for i >= 3. +/// T\[0\] = P (dummy entry, used when window digit = 0) +/// T\[1\] = P, T\[2\] = 2P, T\[i\] = T\[i-1\] + P for i >= 3. fn build_point_table( ops: &mut F, px: F::Elem, @@ -137,8 +137,8 @@ fn build_point_table( table } -/// Selects T[d] from a point table using bit witnesses, where `d = Σ bits[i] * -/// 2^i`. +/// Selects T\[d\] from a point table using bit witnesses, where `d = Σ +/// bits\[i\] * 2^i`. /// /// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, /// halving the candidate set at each level. 
Total: `(2^w - 1)` point selects @@ -180,10 +180,10 @@ fn table_lookup( /// /// Structure per window (from MSB to LSB): /// 1. `w` shared doublings on accumulator -/// 2. Table lookup in T_P[d1] for s1's window digit -/// 3. point_add(acc, T_P[d1]) + is_zero(d1) + point_select -/// 4. Table lookup in T_R[d2] for s2's window digit -/// 5. point_add(acc, T_R[d2]) + is_zero(d2) + point_select +/// 2. Table lookup in T_P\[d1\] for s1's window digit +/// 3. point_add(acc, T_P\[d1\]) + is_zero(d1) + point_select +/// 4. Table lookup in T_R\[d2\] for s2's window digit +/// 5. point_add(acc, T_R\[d2\]) + is_zero(d2) + point_select /// /// Returns the final accumulator (x, y). pub fn scalar_mul_glv( diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 92606bf32..0c24e4f7a 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -680,10 +680,10 @@ fn decompose_point_to_limbs( } } -/// FakeGLV verification for a single point: verifies R = [s]P. +/// FakeGLV verification for a single point: verifies R = \[s\]P. /// /// Decomposes s via half-GCD into sub-scalars (s1, s2) and verifies -/// [s1]P + [s2]R = O using interleaved windowed scalar mul with +/// \[s1\]P + \[s2\]R = O using interleaved windowed scalar mul with /// half-width scalars. /// /// Returns the mutable references back to the caller for continued use. 
@@ -853,7 +853,7 @@ fn decompose_witness_to_limbs( limbs } -/// Recompose limbs back into a single witness: val = Σ limb[i] * +/// Recompose limbs back into a single witness: val = Σ limb\[i\] * /// 2^(i*limb_bits) fn recompose_limbs(compiler: &mut NoirToR1CSCompiler, limbs: &[usize], limb_bits: u32) -> usize { let terms: Vec = limbs diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs index 12c30b382..840f8081a 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -196,9 +196,9 @@ pub fn compute_is_zero(compiler: &mut NoirToR1CSCompiler, value: usize) -> usize /// (a + b) mod p for multi-limb values. /// -/// Per limb i: v_i = a[i] + b[i] + 2^W - q*p[i] + carry_{i-1} +/// Per limb i: v_i = a\[i\] + b\[i\] + 2^W - q*p\[i\] + carry_{i-1} /// carry_i = floor(v_i / 2^W) -/// r[i] = v_i - carry_i * 2^W +/// r\[i\] = v_i - carry_i * 2^W pub fn add_mod_p_multi( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap>, @@ -359,8 +359,8 @@ pub fn sub_mod_p_multi( /// (a * b) mod p for multi-limb values using schoolbook multiplication. /// /// Verifies: a·b = p·q + r in base W = 2^limb_bits. -/// Column k: Σ_{i+j=k} a[i]*b[j] + carry_{k-1} + OFFSET -/// = Σ_{i+j=k} p[i]*q[j] + r[k] + carry_k * W +/// Column k: Σ_{i+j=k} a\[i\]*b\[j\] + carry_{k-1} + OFFSET +/// = Σ_{i+j=k} p\[i\]*q\[j\] + r\[k\] + carry_k * W pub fn mul_mod_p_multi( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap>, @@ -561,8 +561,8 @@ pub fn inv_mod_p_multi( } /// Proves r < p by decomposing (p-1) - r into non-negative multi-limb values. 
-/// Uses borrow propagation: d[i] = (p-1)[i] - r[i] + borrow_in - borrow_out * -/// 2^W +/// Uses borrow propagation: d\[i\] = (p-1)\[i\] - r\[i\] + borrow_in - +/// borrow_out * 2^W fn less_than_p_check_multi( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap>, From b308656e30fdf86d478db2c4a9e2070613c9156a Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 6 Mar 2026 20:51:48 +0530 Subject: [PATCH 10/19] fix : bug in curve params handling for bigger field curves --- noir-examples/embedded_curve_msm/Prover.toml | 16 +-- provekit/r1cs-compiler/src/msm/cost_model.rs | 137 ++++++++++++++++--- provekit/r1cs-compiler/src/msm/curve.rs | 12 +- provekit/r1cs-compiler/src/msm/ec_points.rs | 2 +- provekit/r1cs-compiler/src/msm/mod.rs | 11 +- 5 files changed, 145 insertions(+), 33 deletions(-) diff --git a/noir-examples/embedded_curve_msm/Prover.toml b/noir-examples/embedded_curve_msm/Prover.toml index da0b3529b..d36dddbd7 100644 --- a/noir-examples/embedded_curve_msm/Prover.toml +++ b/noir-examples/embedded_curve_msm/Prover.toml @@ -5,10 +5,10 @@ # ============================================================ # === Test 1: Small scalars (1*G + 2*G = 3*G) === -scalar1_lo = "1" -scalar1_hi = "0" -scalar2_lo = "2" -scalar2_hi = "0" +# scalar1_lo = "1" +# scalar1_hi = "0" +# scalar2_lo = "2" +# scalar2_hi = "0" # === Test 2: All-zero scalars (0*G + 0*G = point at infinity) === # scalar1_lo = "0" @@ -59,10 +59,10 @@ scalar2_hi = "0" # scalar2_hi = "3" # === Test 10: Both scalars equal, ~2n/3 === -# scalar1_lo = "247684385716378719408017269662648849284" -# scalar1_hi = "42882509742122118027908038323484026259" -# scalar2_lo = "247684385716378719408017269662648849284" -# scalar2_hi = "42882509742122118027908038323484026259" +scalar1_lo = "247684385716378719408017269662648849284" +scalar1_hi = "42882509742122118027908038323484026259" +scalar2_lo = "247684385716378719408017269662648849284" +scalar2_hi = "42882509742122118027908038323484026259" # === Test 11: n - 2, n - 3 
(previously failing with [2]G offset) === # scalar1_lo = "201385395114098847380338600778089168197" diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index 8a42735b3..daa89be87 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -15,14 +15,20 @@ pub enum FieldOpType { /// Count field ops in scalar_mul_glv for given parameters. /// +/// Returns `(n_add, n_sub, n_mul, n_inv, n_is_zero)`. +/// /// The GLV approach does interleaved two-point scalar mul with half-width /// scalars. Per window: w shared doubles + 2 table lookups + 2 point_adds + 2 /// is_zero + 2 point_selects Plus: 2 table builds, on-curve check, scalar /// relation overhead. +/// +/// `is_zero` is counted separately because `compute_is_zero` always creates +/// exactly 3 native witnesses (SafeInverse + Product + Sum) regardless of +/// num_limbs — it operates on the `pack_bits` result, not on multi-limb values. fn count_glv_field_ops( scalar_bits: usize, // half_bits = ceil(order_bits / 2) window_size: usize, -) -> (usize, usize, usize, usize) { +) -> (usize, usize, usize, usize, usize) { let w = window_size; let table_size = 1 << w; let num_windows = (scalar_bits + w - 1) / w; @@ -39,6 +45,7 @@ fn count_glv_field_ops( let mut total_sub = 2 * (table_doubles * double_ops.1 + table_adds * add_ops.1); let mut total_mul = 2 * (table_doubles * double_ops.2 + table_adds * add_ops.2); let mut total_inv = 2 * (table_doubles * double_ops.3 + table_adds * add_ops.3); + let mut total_is_zero = 0usize; for win_idx in (0..num_windows).rev() { let bit_start = win_idx * w; @@ -63,9 +70,8 @@ fn count_glv_field_ops( total_mul += add_ops.2; total_inv += add_ops.3; - total_inv += 1; // is_zero - total_add += 1; - total_mul += 1; + // is_zero: counted separately (3 fixed native witnesses each) + total_is_zero += 1; total_add += select_ops_per_point.0; total_sub += select_ops_per_point.1; @@ -84,7 +90,7 @@ fn 
count_glv_field_ops( total_sub += 2 * select_ops_per_point.1; total_mul += 2 * select_ops_per_point.2; - (total_add, total_sub, total_mul, total_inv) + (total_add, total_sub, total_mul, total_inv, total_is_zero) } /// Witnesses per single N-limb field operation. @@ -119,6 +125,66 @@ fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize } } +/// Count witnesses for scalar relation verification. +/// +/// The scalar relation verifies `(-1)^neg1*|s1| + (-1)^neg2*|s2|*s ≡ 0 (mod n)` +/// using multi-limb arithmetic with the curve order as modulus. This is always +/// non-native (curve_order_n ≠ native field modulus). +fn count_scalar_relation_witnesses(native_field_bits: u32, scalar_bits: usize) -> usize { + // Find sr_limb_bits (mirrors scalar_relation_limb_bits in mod.rs) + let mut sr_limb_bits: u32 = 64.min((native_field_bits.saturating_sub(4)) / 2); + loop { + let n = (scalar_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + if column_equation_fits_native_field(native_field_bits, sr_limb_bits, n) { + break; + } + sr_limb_bits -= 1; + assert!( + sr_limb_bits >= 4, + "native field too small for scalar relation cost estimation" + ); + } + + let sr_n = (scalar_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + let half_bits = (scalar_bits + 1) / 2; + let half_limbs = (half_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + let limbs_per_128 = (128 + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + + // Scalar relation always uses non-native multi-limb arithmetic + let wit_add = witnesses_per_op(sr_n, FieldOpType::Add, false); + let wit_sub = witnesses_per_op(sr_n, FieldOpType::Sub, false); + let wit_mul = witnesses_per_op(sr_n, FieldOpType::Mul, false); + + let mut total = 0; + + // Digital decompositions for s_lo and s_hi (128 bits each) + total += 2 * limbs_per_128; + + // decompose_half_scalar for s1 and s2: + // Each: half_limbs DD witnesses + (sr_n - half_limbs) zero-pad constants + total += 2 * sr_n; 
+ + // ops.mul(s2_limbs, s_limbs) + total += wit_mul; + + // ops.negate(product) = constant_limbs(sr_n) + sub + total += sr_n + wit_sub; + + // ops.select_unchecked(neg2, ...) = sr_n select witnesses + total += sr_n; + + // ops.negate(s1_limbs) = constant_limbs(sr_n) + sub + total += sr_n + wit_sub; + + // ops.select_unchecked(neg1, ...) = sr_n select witnesses + total += sr_n; + + // ops.add(effective_s1, effective_product) + total += wit_add; + + total +} + /// Total estimated witness cost for one scalar_mul. pub fn calculate_msm_witness_cost( native_field_bits: u32, @@ -127,8 +193,8 @@ pub fn calculate_msm_witness_cost( scalar_bits: usize, window_size: usize, limb_bits: u32, + is_native: bool, ) -> usize { - let is_native = curve_modulus_bits == native_field_bits; let num_limbs = if is_native { 1 } else { @@ -142,13 +208,14 @@ pub fn calculate_msm_witness_cost( // FakeGLV path for ALL points: half-width interleaved scalar mul let half_bits = (scalar_bits + 1) / 2; - let (n_add, n_sub, n_mul, n_inv) = count_glv_field_ops(half_bits, window_size); - let glv_scalarmul = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + let (n_add, n_sub, n_mul, n_inv, n_is_zero) = count_glv_field_ops(half_bits, window_size); + let glv_scalarmul = + n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv + n_is_zero * 3; // is_zero: 3 fixed native witnesses each // Per-point overhead: scalar decomposition (2 × half_bits for s1, s2) + - // scalar relation (~150 witnesses) + FakeGLVHint (4 witnesses) + // scalar relation (analytical) + FakeGLVHint (4 witnesses) let scalar_decomp = 2 * half_bits + 10; - let scalar_relation = 150; + let scalar_relation = count_scalar_relation_witnesses(native_field_bits, scalar_bits); let glv_hint = 4; // EcScalarMulHint: 2 witnesses per point (only for n_points > 1) @@ -214,13 +281,16 @@ pub fn column_equation_fits_native_field( /// Each candidate is checked for column equation soundness: the schoolbook /// 
multiplication's intermediate values must fit in the native field without /// modular wraparound (see [`column_equation_fits_native_field`]). +/// +/// `is_native` should come from `CurveParams::is_native_field()` which +/// compares actual modulus values, not just bit widths. pub fn get_optimal_msm_params( native_field_bits: u32, curve_modulus_bits: u32, n_points: usize, scalar_bits: usize, + is_native: bool, ) -> (u32, usize) { - let is_native = curve_modulus_bits == native_field_bits; if is_native { // For native field, limb_bits doesn't matter (no multi-limb decomposition). // Just optimize window_size. @@ -234,6 +304,7 @@ pub fn get_optimal_msm_params( scalar_bits, ws, native_field_bits, + true, ); if cost < best_cost { best_cost = cost; @@ -251,8 +322,9 @@ pub fn get_optimal_msm_params( let mut best_limb_bits = max_limb_bits.min(86); let mut best_window = 4; - // Search space - for lb in (8..=max_limb_bits).step_by(2) { + // Search space: test every limb_bits value (not step_by(2)) to avoid + // missing optimal values at num_limbs transition boundaries. 
+ for lb in 8..=max_limb_bits { let num_limbs = ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { continue; @@ -265,6 +337,7 @@ pub fn get_optimal_msm_params( scalar_bits, ws, lb, + false, ); if cost < best_cost { best_cost = cost; @@ -284,7 +357,7 @@ mod tests { #[test] fn test_optimal_params_bn254_native() { // Grumpkin over BN254: native field - let (limb_bits, window_size) = get_optimal_msm_params(254, 254, 1, 256); + let (limb_bits, window_size) = get_optimal_msm_params(254, 254, 1, 256, true); assert_eq!(limb_bits, 254); assert!(window_size >= 2 && window_size <= 8); } @@ -292,7 +365,7 @@ mod tests { #[test] fn test_optimal_params_secp256r1() { // secp256r1 over BN254: 256-bit modulus, non-native - let (limb_bits, window_size) = get_optimal_msm_params(254, 256, 1, 256); + let (limb_bits, window_size) = get_optimal_msm_params(254, 256, 1, 256, false); let num_limbs = ((256 + limb_bits - 1) / limb_bits) as usize; assert!( column_equation_fits_native_field(254, limb_bits, num_limbs), @@ -304,7 +377,7 @@ mod tests { #[test] fn test_optimal_params_goldilocks() { // Hypothetical 64-bit field over BN254 - let (limb_bits, window_size) = get_optimal_msm_params(254, 64, 1, 64); + let (limb_bits, window_size) = get_optimal_msm_params(254, 64, 1, 64, false); let num_limbs = ((64 + limb_bits - 1) / limb_bits) as usize; assert!( column_equation_fits_native_field(254, limb_bits, num_limbs), @@ -328,10 +401,40 @@ mod tests { fn test_secp256r1_limb_bits_not_126() { // Regression: limb_bits=126 with N=3 causes offset_w = 2^255 > p_BN254, // making the schoolbook column equations unsound. 
- let (limb_bits, _) = get_optimal_msm_params(254, 256, 1, 256); + let (limb_bits, _) = get_optimal_msm_params(254, 256, 1, 256, false); assert!( limb_bits <= 124, "secp256r1 limb_bits={limb_bits} exceeds safe maximum 124" ); } + + #[test] + fn test_scalar_relation_witnesses_grumpkin() { + // Grumpkin: scalar_bits=256, sr_limb_bits=64, sr_n=4 + let sr = count_scalar_relation_witnesses(254, 256); + // Should be ~145 (not the old hardcoded 150) + assert!(sr > 100 && sr < 200, "unexpected scalar_relation={sr}"); + } + + #[test] + fn test_scalar_relation_witnesses_small_curve() { + // 64-bit curve: scalar_bits=64, should be much smaller than 150 + let sr = count_scalar_relation_witnesses(254, 64); + assert!( + sr < 100, + "64-bit curve scalar_relation={sr} should be < 100" + ); + } + + #[test] + fn test_is_zero_cost_independent_of_num_limbs() { + // Verify that is_zero doesn't scale with num_limbs in the cost model. + // For the same window parameters, changing num_limbs should only affect + // field ops, not is_zero cost. + let (_, _, _, _, n_is_zero_w4) = count_glv_field_ops(128, 4); + let (_, _, _, _, n_is_zero_w3) = count_glv_field_ops(128, 3); + // is_zero count depends on num_windows, not num_limbs + assert!(n_is_zero_w4 > 0); + assert!(n_is_zero_w3 > 0); + } } diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index 0876bfadb..d4b3bff34 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -1,5 +1,5 @@ use { - ark_ff::{BigInteger, Field, PrimeField}, + ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField}, provekit_common::FieldElement, }; @@ -34,11 +34,15 @@ impl CurveParams { /// Number of bits in the field modulus. pub fn modulus_bits(&self) -> u32 { if self.is_native_field() { - // p mod p = 0 as a field element, so we use the constant directly. 
FieldElement::MODULUS_BIT_SIZE } else { - let fe = curve_native_point_fe(&self.field_modulus_p); - fe.into_bigint().num_bits() + let p = &self.field_modulus_p; + for i in (0..4).rev() { + if p[i] != 0 { + return (i as u32) * 64 + (64 - p[i].leading_zeros()); + } + } + 0 } } diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 0138a571b..5e096df59 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -102,7 +102,7 @@ pub fn point_select( /// Conditional point select without boolean constraint on `flag`. /// Caller must ensure `flag` is already constrained boolean. -fn point_select_unchecked( +pub fn point_select_unchecked( ops: &mut F, flag: usize, on_false: (F::Elem, F::Elem), diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 0c24e4f7a..1f6aa782e 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -321,9 +321,10 @@ pub fn add_msm_with_curve( let native_bits = FieldElement::MODULUS_BIT_SIZE; let curve_bits = curve.modulus_bits(); + let is_native = curve.is_native_field(); let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / 3).sum(); let (limb_bits, window_size) = - cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256); + cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256, is_native); for (points, scalars, outputs) in msm_ops { add_single_msm( @@ -593,8 +594,12 @@ fn process_single_msm<'a>( params: ¶ms, }; let (cand_x, cand_y) = ec_points::point_add(&mut ops, acc_x, acc_y, rx, ry); - let (new_acc_x, new_acc_y) = - ec_points::point_select(&mut ops, is_skip, (cand_x, cand_y), (acc_x, acc_y)); + let (new_acc_x, new_acc_y) = ec_points::point_select_unchecked( + &mut ops, + is_skip, + (cand_x, cand_y), + (acc_x, acc_y), + ); acc_x = new_acc_x; acc_y = new_acc_y; compiler = ops.compiler; From 
e0a3d1dcfbfe7747ad09b9135771a71effedece6 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 11 Mar 2026 06:34:00 +0530 Subject: [PATCH 11/19] feat : updated native and non native curve implementations --- .../src/witness/scheduling/dependency.rs | 25 + .../common/src/witness/scheduling/remapper.rs | 53 + .../common/src/witness/witness_builder.rs | 62 ++ .../prover/src/witness/witness_builder.rs | 102 ++ provekit/r1cs-compiler/src/msm/cost_model.rs | 632 +++++++++--- provekit/r1cs-compiler/src/msm/curve.rs | 2 +- provekit/r1cs-compiler/src/msm/ec_points.rs | 179 +++- provekit/r1cs-compiler/src/msm/mod.rs | 947 +++++------------- .../r1cs-compiler/src/msm/multi_limb_ops.rs | 18 +- provekit/r1cs-compiler/src/msm/native.rs | 413 ++++++++ provekit/r1cs-compiler/src/msm/non_native.rs | 424 ++++++++ .../r1cs-compiler/src/msm/scalar_relation.rs | 226 +++++ provekit/r1cs-compiler/src/range_check.rs | 33 + 13 files changed, 2273 insertions(+), 843 deletions(-) create mode 100644 provekit/r1cs-compiler/src/msm/native.rs create mode 100644 provekit/r1cs-compiler/src/msm/non_native.rs create mode 100644 provekit/r1cs-compiler/src/msm/scalar_relation.rs diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index 68bc7b6e1..ba63359c1 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -223,10 +223,22 @@ impl DependencyInfo { data.rs_cubed, ] } + WitnessBuilder::EcDoubleHint { px, py, .. } => vec![*px, *py], + WitnessBuilder::EcAddHint { + x1, y1, x2, y2, .. + } => vec![*x1, *y1, *x2, *y2], WitnessBuilder::FakeGLVHint { s_lo, s_hi, .. } => vec![*s_lo, *s_hi], WitnessBuilder::EcScalarMulHint { px, py, s_lo, s_hi, .. } => vec![*px, *py, *s_lo, *s_hi], + WitnessBuilder::SelectWitness { + flag, + on_false, + on_true, + .. + } => vec![*flag, *on_false, *on_true], + WitnessBuilder::BooleanOr { a, b, .. 
} => vec![*a, *b], + WitnessBuilder::SignedBitHint { scalar, .. } => vec![*scalar], WitnessBuilder::ChunkDecompose { packed, .. } => vec![*packed], WitnessBuilder::SpreadWitness(_, input) => vec![*input], WitnessBuilder::SpreadBitExtract { sum_terms, .. } => { @@ -286,6 +298,13 @@ impl DependencyInfo { | WitnessBuilder::SpreadWitness(idx, ..) | WitnessBuilder::SpreadLookupDenominator(idx, ..) | WitnessBuilder::SpreadTableQuotient { idx, .. } => vec![*idx], + WitnessBuilder::SelectWitness { output, .. } + | WitnessBuilder::BooleanOr { output, .. } => vec![*output], + WitnessBuilder::SignedBitHint { + output_start, + num_bits, + .. + } => (*output_start..*output_start + *num_bits + 1).collect(), WitnessBuilder::MultiplicitiesForRange(start, range, _) => { (*start..*start + *range).collect() @@ -327,6 +346,12 @@ impl DependencyInfo { num_limbs, .. } => (*output_start..*output_start + *num_limbs as usize).collect(), + WitnessBuilder::EcDoubleHint { output_start, .. } => { + (*output_start..*output_start + 3).collect() + } + WitnessBuilder::EcAddHint { output_start, .. } => { + (*output_start..*output_start + 3).collect() + } WitnessBuilder::FakeGLVHint { output_start, .. 
} => { (*output_start..*output_start + 4).collect() } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 696144113..0e039a1ba 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -366,6 +366,34 @@ impl WitnessIndexRemapper { }, ) } + WitnessBuilder::EcDoubleHint { + output_start, + px, + py, + curve_a, + field_modulus_p, + } => WitnessBuilder::EcDoubleHint { + output_start: self.remap(*output_start), + px: self.remap(*px), + py: self.remap(*py), + curve_a: *curve_a, + field_modulus_p: *field_modulus_p, + }, + WitnessBuilder::EcAddHint { + output_start, + x1, + y1, + x2, + y2, + field_modulus_p, + } => WitnessBuilder::EcAddHint { + output_start: self.remap(*output_start), + x1: self.remap(*x1), + y1: self.remap(*y1), + x2: self.remap(*x2), + y2: self.remap(*y2), + field_modulus_p: *field_modulus_p, + }, WitnessBuilder::FakeGLVHint { output_start, s_lo, @@ -394,6 +422,31 @@ impl WitnessIndexRemapper { curve_a: *curve_a, field_modulus_p: *field_modulus_p, }, + WitnessBuilder::SelectWitness { + output, + flag, + on_false, + on_true, + } => WitnessBuilder::SelectWitness { + output: self.remap(*output), + flag: self.remap(*flag), + on_false: self.remap(*on_false), + on_true: self.remap(*on_true), + }, + WitnessBuilder::BooleanOr { output, a, b } => WitnessBuilder::BooleanOr { + output: self.remap(*output), + a: self.remap(*a), + b: self.remap(*b), + }, + WitnessBuilder::SignedBitHint { + output_start, + scalar, + num_bits, + } => WitnessBuilder::SignedBitHint { + output_start: self.remap(*output_start), + scalar: self.remap(*scalar), + num_bits: *num_bits, + }, WitnessBuilder::ChunkDecompose { output_start, packed, diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 353f2575e..0b368f966 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ 
b/provekit/common/src/witness/witness_builder.rs @@ -301,6 +301,65 @@ pub enum WitnessBuilder { curve_a: [u64; 4], field_modulus_p: [u64; 4], }, + /// Prover hint for EC point doubling on native field. + /// Given P = (px, py) and curve parameter `a`, computes: + /// lambda = (3*px^2 + a) / (2*py) mod p + /// x3 = lambda^2 - 2*px mod p + /// y3 = lambda * (px - x3) - py mod p + /// + /// Outputs 3 witnesses at output_start: lambda, x3, y3. + EcDoubleHint { + output_start: usize, + px: usize, + py: usize, + curve_a: [u64; 4], + field_modulus_p: [u64; 4], + }, + /// Prover hint for EC point addition on native field. + /// Given P1 = (x1, y1) and P2 = (x2, y2), computes: + /// lambda = (y2 - y1) / (x2 - x1) mod p + /// x3 = lambda^2 - x1 - x2 mod p + /// y3 = lambda * (x1 - x3) - y1 mod p + /// + /// Outputs 3 witnesses at output_start: lambda, x3, y3. + EcAddHint { + output_start: usize, + x1: usize, + y1: usize, + x2: usize, + y2: usize, + field_modulus_p: [u64; 4], + }, + /// Conditional select: output = on_false + flag * (on_true - on_false). + /// When flag=0, output=on_false; when flag=1, output=on_true. + /// (output, flag, on_false, on_true) + SelectWitness { + output: usize, + flag: usize, + on_false: usize, + on_true: usize, + }, + /// Boolean OR: output = a + b - a*b = 1 - (1-a)*(1-b). + /// (output, a, b) + BooleanOr { + output: usize, + a: usize, + b: usize, + }, + /// Signed-bit decomposition hint for wNAF scalar multiplication. + /// Given scalar s with num_bits bits, computes sign-bits b_0..b_{n-1} + /// and skew ∈ {0,1} such that: + /// s + skew + (2^n - 1) = Σ b_i * 2^{i+1} + /// where d_i = 2*b_i - 1 ∈ {-1, +1}. + /// + /// Outputs (num_bits + 1) witnesses at output_start: + /// [0..num_bits) b_i sign bits + /// [num_bits] skew (0 if s is odd, 1 if s is even) + SignedBitHint { + output_start: usize, + scalar: usize, + num_bits: usize, + }, /// Computes spread(input): interleave bits with zeros. /// Output: 0 b_{n-1} 0 b_{n-2} ... 
0 b_1 0 b_0 /// (witness index of output, witness index of input) @@ -365,6 +424,9 @@ impl WitnessBuilder { WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => (4 * *num_limbs - 2) as usize, WitnessBuilder::MultiLimbModularInverse { num_limbs, .. } => *num_limbs as usize, + WitnessBuilder::SignedBitHint { num_bits, .. } => *num_bits + 1, + WitnessBuilder::EcDoubleHint { .. } => 3, + WitnessBuilder::EcAddHint { .. } => 3, WitnessBuilder::FakeGLVHint { .. } => 4, WitnessBuilder::EcScalarMulHint { .. } => 2, diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index 87b1105e6..9d98e7cb6 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -459,6 +459,75 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*output_start + 2] = Some(FieldElement::from(neg1 as u64)); witness[*output_start + 3] = Some(FieldElement::from(neg2 as u64)); } + WitnessBuilder::EcDoubleHint { + output_start, + px, + py, + curve_a, + field_modulus_p, + } => { + let px_val = witness[*px].unwrap().into_bigint().0; + let py_val = witness[*py].unwrap().into_bigint().0; + + // Compute lambda, x3, y3 using bigint_mod helpers + use crate::bigint_mod::{mod_add, mod_inverse, mod_sub, mul_mod}; + let x_sq = mul_mod(&px_val, &px_val, field_modulus_p); + let two_x_sq = mod_add(&x_sq, &x_sq, field_modulus_p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, field_modulus_p); + let numerator = mod_add(&three_x_sq, curve_a, field_modulus_p); + let two_y = mod_add(&py_val, &py_val, field_modulus_p); + let denom_inv = mod_inverse(&two_y, field_modulus_p); + let lambda = mul_mod(&numerator, &denom_inv, field_modulus_p); + + let lambda_sq = mul_mod(&lambda, &lambda, field_modulus_p); + let two_x = mod_add(&px_val, &px_val, field_modulus_p); + let x3 = mod_sub(&lambda_sq, &two_x, field_modulus_p); + + let 
x_minus_x3 = mod_sub(&px_val, &x3, field_modulus_p); + let lambda_dx = mul_mod(&lambda, &x_minus_x3, field_modulus_p); + let y3 = mod_sub(&lambda_dx, &py_val, field_modulus_p); + + witness[*output_start] = + Some(FieldElement::from_bigint(ark_ff::BigInt(lambda)).unwrap()); + witness[*output_start + 1] = + Some(FieldElement::from_bigint(ark_ff::BigInt(x3)).unwrap()); + witness[*output_start + 2] = + Some(FieldElement::from_bigint(ark_ff::BigInt(y3)).unwrap()); + } + WitnessBuilder::EcAddHint { + output_start, + x1, + y1, + x2, + y2, + field_modulus_p, + } => { + let x1_val = witness[*x1].unwrap().into_bigint().0; + let y1_val = witness[*y1].unwrap().into_bigint().0; + let x2_val = witness[*x2].unwrap().into_bigint().0; + let y2_val = witness[*y2].unwrap().into_bigint().0; + + use crate::bigint_mod::{mod_add, mod_inverse, mod_sub, mul_mod}; + let numerator = mod_sub(&y2_val, &y1_val, field_modulus_p); + let denominator = mod_sub(&x2_val, &x1_val, field_modulus_p); + let denom_inv = mod_inverse(&denominator, field_modulus_p); + let lambda = mul_mod(&numerator, &denom_inv, field_modulus_p); + + let lambda_sq = mul_mod(&lambda, &lambda, field_modulus_p); + let x1_plus_x2 = mod_add(&x1_val, &x2_val, field_modulus_p); + let x3 = mod_sub(&lambda_sq, &x1_plus_x2, field_modulus_p); + + let x1_minus_x3 = mod_sub(&x1_val, &x3, field_modulus_p); + let lambda_dx = mul_mod(&lambda, &x1_minus_x3, field_modulus_p); + let y3 = mod_sub(&lambda_dx, &y1_val, field_modulus_p); + + witness[*output_start] = + Some(FieldElement::from_bigint(ark_ff::BigInt(lambda)).unwrap()); + witness[*output_start + 1] = + Some(FieldElement::from_bigint(ark_ff::BigInt(x3)).unwrap()); + witness[*output_start + 2] = + Some(FieldElement::from_bigint(ark_ff::BigInt(y3)).unwrap()); + } WitnessBuilder::EcScalarMulHint { output_start, px, @@ -491,6 +560,39 @@ witness[*output_start + 1] = Some(FieldElement::from_bigint(ark_ff::BigInt(ry)).unwrap()); } + 
WitnessBuilder::SelectWitness { + output, + flag, + on_false, + on_true, + } => { + let f = witness[*flag].unwrap(); + let a = witness[*on_false].unwrap(); + let b = witness[*on_true].unwrap(); + witness[*output] = Some(a + f * (b - a)); + } + WitnessBuilder::BooleanOr { output, a, b } => { + let a_val = witness[*a].unwrap(); + let b_val = witness[*b].unwrap(); + witness[*output] = Some(a_val + b_val - a_val * b_val); + } + WitnessBuilder::SignedBitHint { + output_start, + scalar, + num_bits, + } => { + let s_fe = witness[*scalar].unwrap(); + let s_big = s_fe.into_bigint().0; + let s_val: u128 = s_big[0] as u128 | ((s_big[1] as u128) << 64); + let n = *num_bits; + let skew: u128 = if s_val & 1 == 0 { 1 } else { 0 }; + let s_adj = s_val + skew; + let t = (s_adj + ((1u128 << n) - 1)) / 2; + for i in 0..n { + witness[*output_start + i] = Some(FieldElement::from(((t >> i) & 1) as u64)); + } + witness[*output_start + n] = Some(FieldElement::from(skew as u64)); + } WitnessBuilder::CombinedTableEntryInverse(..) => { unreachable!( "CombinedTableEntryInverse should not be called - handled by batch inversion" diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index daa89be87..79ce03bfb 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -4,6 +4,13 @@ //! pure analytical estimator → exhaustive search → pick optimal (limb_bits, //! window_size). +use std::collections::BTreeMap; + +/// The 256-bit scalar is split into two halves (s_lo, s_hi) because it doesn't +/// fit in the native field. This constant is used throughout the scalar relation +/// cost model. +const SCALAR_HALF_BITS: usize = 128; + /// Type of field operation for cost estimation. #[derive(Clone, Copy)] pub enum FieldOpType { @@ -13,29 +20,33 @@ pub enum FieldOpType { Inv, } -/// Count field ops in scalar_mul_glv for given parameters. 
+/// Count field ops and selects in scalar_mul_glv for given parameters. /// -/// Returns `(n_add, n_sub, n_mul, n_inv, n_is_zero)`. +/// Returns `(n_add, n_sub, n_mul, n_inv, n_is_zero, n_point_selects, +/// n_coord_selects)`. /// -/// The GLV approach does interleaved two-point scalar mul with half-width -/// scalars. Per window: w shared doubles + 2 table lookups + 2 point_adds + 2 -/// is_zero + 2 point_selects Plus: 2 table builds, on-curve check, scalar -/// relation overhead. +/// Field ops (add/sub/mul/inv) come from point_double, point_add, and +/// on-curve checks. Selects are counted separately because they create +/// `num_limbs` witnesses per coordinate (via `select_witness`), not +/// multi-limb field op witnesses. /// -/// `is_zero` is counted separately because `compute_is_zero` always creates -/// exactly 3 native witnesses (SafeInverse + Product + Sum) regardless of -/// num_limbs — it operates on the `pack_bits` result, not on multi-limb values. +/// - `n_point_selects`: selects on EcPoint (2 coordinates), from table +/// lookups and conditional skip after point_add. +/// - `n_coord_selects`: selects on single Limbs coordinate, from +/// y-negation. +/// - `n_is_zero`: `compute_is_zero` calls, each creating exactly 3 native +/// witnesses regardless of num_limbs. 
fn count_glv_field_ops( scalar_bits: usize, // half_bits = ceil(order_bits / 2) window_size: usize, -) -> (usize, usize, usize, usize, usize) { +) -> (usize, usize, usize, usize, usize, usize, usize) { let w = window_size; let table_size = 1 << w; let num_windows = (scalar_bits + w - 1) / w; + // Field ops per primitive EC operation (add, sub, mul, inv): let double_ops = (4usize, 2usize, 5usize, 1usize); let add_ops = (2usize, 2usize, 3usize, 1usize); - let select_ops_per_point = (2usize, 2usize, 2usize, 0usize); // Two tables (one for P, one for R) let table_doubles = if table_size > 2 { 1 } else { 0 }; @@ -46,12 +57,13 @@ fn count_glv_field_ops( let mut total_mul = 2 * (table_doubles * double_ops.2 + table_adds * add_ops.2); let mut total_inv = 2 * (table_doubles * double_ops.3 + table_adds * add_ops.3); let mut total_is_zero = 0usize; + let mut total_point_selects = 0usize; for win_idx in (0..num_windows).rev() { let bit_start = win_idx * w; let bit_end = std::cmp::min(bit_start + w, scalar_bits); let actual_w = bit_end - bit_start; - let actual_selects = (1 << actual_w) - 1; + let actual_table_selects = (1 << actual_w) - 1; // w shared doublings total_add += w * double_ops.0; @@ -59,65 +71,86 @@ fn count_glv_field_ops( total_mul += w * double_ops.2; total_inv += w * double_ops.3; - // Two table lookups + two point_adds + two is_zeros + two point_selects + // Two table lookups + two point_adds + two is_zeros + two conditional + // skips for _ in 0..2 { - total_add += actual_selects * select_ops_per_point.0; - total_sub += actual_selects * select_ops_per_point.1; - total_mul += actual_selects * select_ops_per_point.2; + // Table lookup: (2^actual_w - 1) point selects + total_point_selects += actual_table_selects; + // Point add total_add += add_ops.0; total_sub += add_ops.1; total_mul += add_ops.2; total_inv += add_ops.3; - // is_zero: counted separately (3 fixed native witnesses each) + // is_zero: 3 fixed native witnesses each total_is_zero += 1; - total_add += 
select_ops_per_point.0; - total_sub += select_ops_per_point.1; - total_mul += select_ops_per_point.2; + // Conditional skip: 1 point select + total_point_selects += 1; } } - // On-curve checks for P and R: each needs 1 mul (y^2), 2 mul (x^2, x^3), 1 mul - // (a*x), 2 add + // On-curve checks for P and R: each needs mul(y²), mul(x²), mul(x³), + // mul(a·x), add(x³+ax), add(x³+ax+b) = 4 mul + 2 add per point total_mul += 8; total_add += 4; - // Conditional y-negation: 2 sub + 2 select (for P.y and R.y) + // Conditional y-negation: 2 negate (= 2 sub) + 2 Limbs selects (1 coord + // each) total_sub += 2; - total_add += 2 * select_ops_per_point.0; - total_sub += 2 * select_ops_per_point.1; - total_mul += 2 * select_ops_per_point.2; + let total_coord_selects = 2usize; + + ( + total_add, + total_sub, + total_mul, + total_inv, + total_is_zero, + total_point_selects, + total_coord_selects, + ) +} - (total_add, total_sub, total_mul, total_inv, total_is_zero) +/// Count only range-check-producing field ops in scalar_mul_glv. +/// +/// Returns `(n_add, n_sub, n_mul, n_inv)` excluding selects and is_zero, +/// which generate 0 range checks (selects are native `select_witness` calls, +/// is_zero operates on `pack_bits` results). +fn count_glv_real_field_ops( + scalar_bits: usize, + window_size: usize, +) -> (usize, usize, usize, usize) { + let (n_add, n_sub, n_mul, n_inv, _, _, _) = + count_glv_field_ops(scalar_bits, window_size); + (n_add, n_sub, n_mul, n_inv) } /// Witnesses per single N-limb field operation. 
fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize { if is_native { - // Native: no range checks, just standard R1CS witnesses match op { - FieldOpType::Add => 1, // sum witness - FieldOpType::Sub => 1, // sum witness - FieldOpType::Mul => 1, // product witness - FieldOpType::Inv => 1, // inverse witness + FieldOpType::Add => 1, + FieldOpType::Sub => 1, + FieldOpType::Mul => 1, + FieldOpType::Inv => 1, } } else if num_limbs == 1 { // Single-limb non-native: reduce_mod_p pattern match op { FieldOpType::Add => 5, // a+b, m const, k, k*m, result - FieldOpType::Sub => 5, // same + FieldOpType::Sub => 5, FieldOpType::Mul => 5, // a*b, m const, k, k*m, result - FieldOpType::Inv => 7, // a_inv + mul_mod_p(5) + range_check + FieldOpType::Inv => 6, // a_inv(1) + mul_mod_p_single(5) } } else { - // Multi-limb: N-limb operations let n = num_limbs; match op { - // add/sub: q + N*(v_offset, carry, r_limb) + N*(v_diff, borrow, d_limb) + // add/sub: q + N*(v_offset, carry, r_limb) + N*(v_diff, borrow, + // d_limb) FieldOpType::Add | FieldOpType::Sub => 1 + 3 * n + 3 * n, - // mul: hint(4N-2) + N² products + 2N-1 column constraints + lt_check + // mul: hint(4N-2) + N² products + 2N-1 column constraints + + // lt_check(3N) FieldOpType::Mul => (4 * n - 2) + n * n + 3 * n, // inv: hint(N) + mul costs FieldOpType::Inv => n + (4 * n - 2) + n * n + 3 * n, @@ -127,65 +160,112 @@ fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize /// Count witnesses for scalar relation verification. /// -/// The scalar relation verifies `(-1)^neg1*|s1| + (-1)^neg2*|s2|*s ≡ 0 (mod n)` -/// using multi-limb arithmetic with the curve order as modulus. This is always -/// non-native (curve_order_n ≠ native field modulus). +/// The scalar relation verifies `(-1)^neg1*|s1| + (-1)^neg2*|s2|*s ≡ 0 (mod +/// n)` using multi-limb arithmetic with the curve order as modulus. 
fn count_scalar_relation_witnesses(native_field_bits: u32, scalar_bits: usize) -> usize { - // Find sr_limb_bits (mirrors scalar_relation_limb_bits in mod.rs) - let mut sr_limb_bits: u32 = 64.min((native_field_bits.saturating_sub(4)) / 2); - loop { - let n = (scalar_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; - if column_equation_fits_native_field(native_field_bits, sr_limb_bits, n) { - break; - } - sr_limb_bits -= 1; - assert!( - sr_limb_bits >= 4, - "native field too small for scalar relation cost estimation" - ); - } - - let sr_n = (scalar_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; - let half_bits = (scalar_bits + 1) / 2; - let half_limbs = (half_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; - let limbs_per_128 = (128 + sr_limb_bits as usize - 1) / sr_limb_bits as usize; + let limb_bits = scalar_relation_limb_bits(native_field_bits, scalar_bits); + let num_limbs = (scalar_bits + limb_bits as usize - 1) / limb_bits as usize; + let scalar_half_limbs = + (SCALAR_HALF_BITS + limb_bits as usize - 1) / limb_bits as usize; - // Scalar relation always uses non-native multi-limb arithmetic - let wit_add = witnesses_per_op(sr_n, FieldOpType::Add, false); - let wit_sub = witnesses_per_op(sr_n, FieldOpType::Sub, false); - let wit_mul = witnesses_per_op(sr_n, FieldOpType::Mul, false); + let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, false); + let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, false); + let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, false); - let mut total = 0; + // Scalar decomposition: DD digits for s_lo + s_hi, plus cross-boundary + // witness when limb boundaries don't align with the 128-bit split + let has_cross_boundary = + num_limbs > 1 && SCALAR_HALF_BITS % limb_bits as usize != 0; + let scalar_decomp = 2 * scalar_half_limbs + has_cross_boundary as usize; - // Digital decompositions for s_lo and s_hi (128 bits each) - total += 2 * limbs_per_128; + // Half-scalar 
decomposition: DD digits + zero-pad constants for s1, s2 + let half_scalar_decomp = 2 * num_limbs; - // decompose_half_scalar for s1 and s2: - // Each: half_limbs DD witnesses + (sr_n - half_limbs) zero-pad constants - total += 2 * sr_n; + // Sign handling: sum + diff + XOR (2 native witnesses) + select + let sign_handling = wit_add + wit_sub + 2 + num_limbs; - // ops.mul(s2_limbs, s_limbs) - total += wit_mul; + scalar_decomp + half_scalar_decomp + wit_mul + sign_handling +} - // ops.negate(product) = constant_limbs(sr_n) + sub - total += sr_n + wit_sub; +/// Range checks generated by a single N-limb field operation. +/// +/// Returns entries as `(bit_width, count)` pairs. Native ops produce no +/// range checks. Single-limb non-native uses `reduce_mod_p` (1 check at +/// `curve_modulus_bits`). Multi-limb ops produce checks at `limb_bits` +/// and `carry_bits = limb_bits + ceil(log2(N)) + 2`. +fn range_checks_per_op( + num_limbs: usize, + op: FieldOpType, + is_native: bool, + limb_bits: u32, + curve_modulus_bits: u32, +) -> Vec<(u32, usize)> { + if is_native { + return vec![]; + } + if num_limbs == 1 { + let bits = curve_modulus_bits; + return match op { + FieldOpType::Add | FieldOpType::Sub | FieldOpType::Mul => vec![(bits, 1)], + FieldOpType::Inv => vec![(bits, 2)], + }; + } + let n = num_limbs; + let ceil_log2_n = if n <= 1 { + 0u32 + } else { + (n as f64).log2().ceil() as u32 + }; + let carry_bits = limb_bits + ceil_log2_n + 2; + match op { + // add/sub: 2N from less_than_p_check_multi + FieldOpType::Add | FieldOpType::Sub => vec![(limb_bits, 2 * n)], + // mul: N q-limbs + 2N from less_than_p at limb_bits, (2N-2) carries + // at carry_bits + FieldOpType::Mul => vec![(limb_bits, 3 * n), (carry_bits, 2 * n - 2)], + // inv: N inv-limbs + mul's checks + FieldOpType::Inv => vec![(limb_bits, 4 * n), (carry_bits, 2 * n - 2)], + } +} - // ops.select_unchecked(neg2, ...) = sr_n select witnesses - total += sr_n; +/// Count range checks for scalar relation verification. 
+///
+/// Sources: DD digits (scalar + half-scalar decompositions) and multi-limb
+/// field ops (1 mul + 1 add + 1 sub for XOR-based sign handling).
+fn count_scalar_relation_range_checks(
+    native_field_bits: u32,
+    scalar_bits: usize,
+) -> BTreeMap<u32, usize> {
+    let limb_bits = scalar_relation_limb_bits(native_field_bits, scalar_bits);
+    let num_limbs = (scalar_bits + limb_bits as usize - 1) / limb_bits as usize;
+    let half_bits = (scalar_bits + 1) / 2;
+    let half_limbs = (half_bits + limb_bits as usize - 1) / limb_bits as usize;
+    let scalar_half_limbs =
+        (SCALAR_HALF_BITS + limb_bits as usize - 1) / limb_bits as usize;
 
-    // ops.negate(s1_limbs) = constant_limbs(sr_n) + sub
-    total += sr_n + wit_sub;
+    let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new();
 
-    // ops.select_unchecked(neg1, ...) = sr_n select witnesses
-    total += sr_n;
+    // DD digits: s_lo + s_hi (2 × scalar_half_limbs) + s1 + s2 (2 × half_limbs)
+    *rc_map.entry(limb_bits).or_default() += 2 * scalar_half_limbs + 2 * half_limbs;
 
-    // ops.add(effective_s1, effective_product)
-    total += wit_add;
+    // Multi-limb field ops: mul + add + sub
+    let modulus_bits = scalar_bits as u32;
+    for op in [FieldOpType::Mul, FieldOpType::Add, FieldOpType::Sub] {
+        for (bits, count) in range_checks_per_op(num_limbs, op, false, limb_bits, modulus_bits) {
+            *rc_map.entry(bits).or_default() += count;
+        }
+    }
 
-    total
+    rc_map
 }
 
-/// Total estimated witness cost for one scalar_mul.
+/// Total estimated witness cost for an MSM.
+///
+/// Accounts for three categories of witnesses:
+/// 1. **Inline witnesses** — field ops, selects, is_zero, hints, DDs
+/// 2. **Range check resolution** — LogUp/naive cost for all range checks
+/// 3. 
**Per-point overhead** — detect_skip, sanitization, point +/// decomposition pub fn calculate_msm_witness_cost( native_field_bits: u32, curve_modulus_bits: u32, @@ -195,47 +275,274 @@ pub fn calculate_msm_witness_cost( limb_bits: u32, is_native: bool, ) -> usize { - let num_limbs = if is_native { - 1 - } else { - ((curve_modulus_bits as usize) + (limb_bits as usize) - 1) / (limb_bits as usize) - }; + if is_native { + return calculate_msm_witness_cost_native( + native_field_bits, + n_points, + scalar_bits, + window_size, + ); + } + + let num_limbs = + ((curve_modulus_bits as usize) + (limb_bits as usize) - 1) / (limb_bits as usize); - let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, is_native); - let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, is_native); - let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, is_native); - let wit_inv = witnesses_per_op(num_limbs, FieldOpType::Inv, is_native); + let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, false); + let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, false); + let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, false); + let wit_inv = witnesses_per_op(num_limbs, FieldOpType::Inv, false); - // FakeGLV path for ALL points: half-width interleaved scalar mul + // === GLV scalar mul witnesses === let half_bits = (scalar_bits + 1) / 2; - let (n_add, n_sub, n_mul, n_inv, n_is_zero) = count_glv_field_ops(half_bits, window_size); - let glv_scalarmul = - n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv + n_is_zero * 3; // is_zero: 3 fixed native witnesses each + let (n_add, n_sub, n_mul, n_inv, n_is_zero, n_point_selects, n_coord_selects) = + count_glv_field_ops(half_bits, window_size); + + // Field ops: priced at full multi-limb cost + let field_op_cost = + n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + + // Selects: each select_witness creates 1 witness per limb (inlined). + // Point select = 2 coords × num_limbs × 1. 
+ // Coord select = 1 coord × num_limbs × 1. + let select_cost = + n_point_selects * 2 * num_limbs + n_coord_selects * num_limbs; + + // is_zero: 3 fixed native witnesses each (SafeInverse + Product + Sum) + let is_zero_cost = n_is_zero * 3; + + let glv_scalarmul = field_op_cost + select_cost + is_zero_cost; + + // === Per-point overhead === + // Scalar bit decomposition: 2 DDs of half_bits 1-bit digits + let scalar_bit_decomp = 2 * half_bits; + + // detect_skip: 2×is_zero(3) + product(1) + boolean_or(1) = 8 + let detect_skip_cost = 8; + + // Sanitization: 3 constants (gen_x, gen_y, zero) + 6 select_witness × 1 + // For multi-point, constants are shared but impact is negligible. + let sanitize_cost = 3 + 6; - // Per-point overhead: scalar decomposition (2 × half_bits for s1, s2) + - // scalar relation (analytical) + FakeGLVHint (4 witnesses) - let scalar_decomp = 2 * half_bits + 10; + // Point decomposition digit witnesses (add_digital_decomposition creates + // num_limbs digit witnesses per coordinate; 2 coords × 2 points = 4). + // Only applies when num_limbs > 1 (decompose_point_to_limbs is a no-op + // for num_limbs == 1). + let point_decomp_digits = if num_limbs > 1 { 4 * num_limbs } else { 0 }; + + // Scalar relation (analytical) let scalar_relation = count_scalar_relation_witnesses(native_field_bits, scalar_bits); + + // FakeGLVHint: 4 witnesses (s1, s2, neg1, neg2) let glv_hint = 4; // EcScalarMulHint: 2 witnesses per point (only for n_points > 1) let ec_hint = if n_points > 1 { 2 } else { 0 }; - let per_point = glv_scalarmul + scalar_decomp + scalar_relation + glv_hint + ec_hint; + let per_point = glv_scalarmul + + scalar_bit_decomp + + detect_skip_cost + + sanitize_cost + + point_decomp_digits + + scalar_relation + + glv_hint + + ec_hint; + + // === Point accumulation (multi-point only) === + // Each point gets: point_add(acc, R_i) + point_select_unchecked(skip). + // Plus final offset subtraction: 1 point_add + constants + 2 Limbs + // selects. 
+    let point_add_cost = 2 * wit_add + 2 * wit_sub + 3 * wit_mul + wit_inv;
+    let accum = if n_points > 1 {
+        let accum_point_adds = n_points * point_add_cost;
+        let accum_point_selects = n_points * 2 * num_limbs;
+        // all_skipped tracking: (n_points - 1) product witnesses
+        let all_skipped_products = n_points - 1;
+        // Offset subtraction: point_add + 4×constant_limbs + 2 Limbs selects
+        // + 2×constant_limbs for initial acc
+        let offset_sub = point_add_cost + 6 * num_limbs + 2 * num_limbs;
+
+        accum_point_adds + accum_point_selects + all_skipped_products + offset_sub
+    } else {
+        0
+    };
+
+    // === Range check resolution cost ===
+    // All points' range checks share the same LogUp tables, so we aggregate
+    // across n_points before computing resolution cost (table amortizes).
+    let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new();
+
+    // 1. Range checks from GLV field ops (selects generate 0 range checks)
+    let (rc_n_add, rc_n_sub, rc_n_mul, rc_n_inv) =
+        count_glv_real_field_ops(half_bits, window_size);
+    for &(op, n_ops) in &[
+        (FieldOpType::Add, rc_n_add),
+        (FieldOpType::Sub, rc_n_sub),
+        (FieldOpType::Mul, rc_n_mul),
+        (FieldOpType::Inv, rc_n_inv),
+    ] {
+        for (bits, count) in
+            range_checks_per_op(num_limbs, op, false, limb_bits, curve_modulus_bits)
+        {
+            *rc_map.entry(bits).or_default() += n_points * n_ops * count;
+        }
+    }
+
+    // 2. Point decomposition range checks (num_limbs > 1 only).
+    //    4 coordinates: px, py, rx, ry.
+    if num_limbs > 1 {
+        *rc_map.entry(limb_bits).or_default() += n_points * 4 * num_limbs;
+    }
+
+    // 3. Scalar relation range checks (always non-native, per point)
+    let sr_checks = count_scalar_relation_range_checks(native_field_bits, scalar_bits);
+    for (bits, count) in &sr_checks {
+        *rc_map.entry(*bits).or_default() += n_points * count;
+    }
+
+    // 4. 
Accumulation range checks: n_points point_adds + 1 offset + // subtraction point_add (multi-point only) + if n_points > 1 { + let accum_point_adds = n_points + 1; // loop + offset subtraction + for &(op, n_ops) in &[ + (FieldOpType::Add, 2usize), + (FieldOpType::Sub, 2usize), + (FieldOpType::Mul, 3usize), + (FieldOpType::Inv, 1usize), + ] { + for (bits, count) in + range_checks_per_op(num_limbs, op, false, limb_bits, curve_modulus_bits) + { + *rc_map.entry(bits).or_default() += accum_point_adds * n_ops * count; + } + } + } + + // 5. Compute resolution cost + let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); + + n_points * per_point + accum + range_check_cost +} + +/// Total estimated witness cost for a native-field MSM using hint-verified EC +/// ops with signed-bit wNAF (w=1). +/// +/// The native path replaces expensive field inversions with prover hints +/// verified via raw R1CS constraints: +/// - `point_double_verified_native`: 4W (3 hint + 1 product) vs 12W generic +/// - `point_add_verified_native`: 3W (3 hint) vs 8W generic +/// - `verify_on_curve_native`: 2W (2 products) vs 6W generic +/// - No multi-limb arithmetic for EC ops → zero EC-related range checks +/// +/// Uses signed-bit wNAF (w=1): every digit is non-zero (±1), so we always +/// add — no conditional skip selects. +/// +/// For n_points >= 2, uses merged-loop optimization: all points share a +/// single doubling per bit, saving 4W × (n-1) per bit. +/// Per bit (merged): 4W (shared double) + n × 8W (2×(1W select + 3W add)). +/// Skew correction: n × 10W. 
+fn calculate_msm_witness_cost_native( + native_field_bits: u32, + n_points: usize, + scalar_bits: usize, + _window_size: usize, +) -> usize { + let half_bits = (scalar_bits + 1) / 2; - // Point accumulation: (n_points - 1) point_adds + // === Costs that are always per-point === + let on_curve = 2 * 2; // 2 × verify_on_curve_native (2W each) + let glv_hint = 4; // FakeGLVHint (s1, s2, neg1, neg2) + let scalar_bit_decomp = 2 * (half_bits + 1); // signed-bit hint witnesses + let y_negate = 2 + 2 + 2; // 2 neg_y + 2 py_eff + 2 neg_py_eff + let detect_skip_cost = 8; // 2×is_zero(3) + product(1) + boolean_or(1) + let sanitize_cost = 3 + 6; // 3 constants + 6 selects + let ec_hint = if n_points > 1 { 2 } else { 0 }; // EcScalarMulHint + let scalar_relation = count_scalar_relation_witnesses(native_field_bits, scalar_bits); + + let per_point_fixed = on_curve + + glv_hint + + scalar_bit_decomp + + y_negate + + detect_skip_cost + + sanitize_cost + + scalar_relation + + ec_hint; + + // === EC loop + skew + constants === + let inline_total = if n_points == 1 { + // Single-point: separate loop (unchanged path) + let ec_wit = half_bits * 12; + let skew_correction = 10; + let offset_const = 2; + let identity_const = 2; + per_point_fixed + ec_wit + skew_correction + offset_const + identity_const + } else { + // Multi-point: merged loop with shared doubling + // Per bit: 4W (shared double) + n_points × 8W (2×(1W select + 3W add)) + let ec_wit = half_bits * (4 + 8 * n_points); + // Skew correction: 10W per point + let skew_correction = n_points * 10; + // Offset and identity constants are shared (not per-point) + let offset_const = 2; + let identity_const = 2; + n_points * per_point_fixed + ec_wit + skew_correction + offset_const + identity_const + }; + + // === Point accumulation (multi-point only) === let accum = if n_points > 1 { - let accum_adds = n_points - 1; - accum_adds - * (witnesses_per_op(num_limbs, FieldOpType::Add, is_native) * 2 - + witnesses_per_op(num_limbs, 
FieldOpType::Sub, is_native) * 2
-                + witnesses_per_op(num_limbs, FieldOpType::Mul, is_native) * 3
-                + witnesses_per_op(num_limbs, FieldOpType::Inv, is_native))
+        // Initial accumulator: 2W (constant witnesses for offset x,y)
+        let acc_init = 2;
+        // Per point: point_add_verified_native (3W) + 2 skip selects (2W)
+        let per_point_accum = n_points * (3 + 2);
+        // all_skipped tracking: (n_points - 1) product witnesses
+        let all_skipped = n_points - 1;
+        // Offset subtraction: 3 constants + 2 selects + point_add (3W) + 2 mask selects
+        let offset_sub = 3 + 2 + 3 + 2;
+
+        acc_init + per_point_accum + all_skipped + offset_sub
     } else {
         0
     };
 
-    n_points * per_point + accum
+    // === Range check cost ===
+    // Native EC ops produce NO range checks (no multi-limb arithmetic).
+    // Only scalar relation produces range checks.
+    let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new();
+    let sr_checks = count_scalar_relation_range_checks(native_field_bits, scalar_bits);
+    for (bits, count) in &sr_checks {
+        *rc_map.entry(*bits).or_default() += n_points * count;
+    }
+    let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map);
+
+    inline_total + accum + range_check_cost
+}
+
+/// Picks the widest limb size for scalar-relation multi-limb arithmetic that
+/// fits inside the native field without overflow.
+///
+/// Searches for the minimum number of limbs N (starting from 1) such that
+/// the schoolbook column equations don't overflow the native field. Fewer
+/// limbs means wider limbs, which means fewer witnesses and range checks.
+///
+/// For BN254 (254-bit native field, ~254-bit order): N=3 @ 85-bit limbs.
+/// For small curves where half_scalar × full_scalar fits natively: N=1.
+pub(super) fn scalar_relation_limb_bits(native_field_bits: u32, order_bits: usize) -> u32 {
+    let half_bits = (order_bits + 1) / 2;
+
+    // N=1 is valid only if the mul product (half_scalar * full_scalar)
+    // fits in the native field without wrapping. 
+ if half_bits + order_bits < native_field_bits as usize { + return order_bits as u32; + } + + // For N>=2: find minimum N where schoolbook column equations fit. + for n in 2..=super::MAX_LIMBS { + let lb = ((order_bits + n - 1) / n) as u32; + if column_equation_fits_native_field(native_field_bits, lb, n) { + return lb; + } + } + + panic!("native field too small for scalar relation verification"); } /// Check whether schoolbook column equation values fit in the native field. @@ -267,11 +574,9 @@ pub fn column_equation_fits_native_field( num_limbs: usize, ) -> bool { if num_limbs <= 1 { - return true; // Single-limb path has no column equations. + return true; } let ceil_log2_n = (num_limbs as f64).log2().ceil() as u32; - // Max column value < 2^(2*limb_bits + ceil_log2_n + 3). - // Need this < p_native >= 2^(native_field_bits - 1). 2 * limb_bits + ceil_log2_n + 3 < native_field_bits } @@ -292,8 +597,6 @@ pub fn get_optimal_msm_params( is_native: bool, ) -> (u32, usize) { if is_native { - // For native field, limb_bits doesn't matter (no multi-limb decomposition). - // Just optimize window_size. let mut best_cost = usize::MAX; let mut best_window = 4; for ws in 2..=8 { @@ -314,16 +617,11 @@ pub fn get_optimal_msm_params( return (native_field_bits, best_window); } - // Upper bound on search: even with N=2 (best case), we need - // 2*lb + ceil(log2(2)) + 3 < native_field_bits => lb < (native_field_bits - 4) - // / 2. The per-candidate soundness check below is the actual gate. let max_limb_bits = (native_field_bits.saturating_sub(4)) / 2; let mut best_cost = usize::MAX; let mut best_limb_bits = max_limb_bits.min(86); let mut best_window = 4; - // Search space: test every limb_bits value (not step_by(2)) to avoid - // missing optimal values at num_limbs transition boundaries. 
for lb in 8..=max_limb_bits { let num_limbs = ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { @@ -356,7 +654,6 @@ mod tests { #[test] fn test_optimal_params_bn254_native() { - // Grumpkin over BN254: native field let (limb_bits, window_size) = get_optimal_msm_params(254, 254, 1, 256, true); assert_eq!(limb_bits, 254); assert!(window_size >= 2 && window_size <= 8); @@ -364,7 +661,6 @@ mod tests { #[test] fn test_optimal_params_secp256r1() { - // secp256r1 over BN254: 256-bit modulus, non-native let (limb_bits, window_size) = get_optimal_msm_params(254, 256, 1, 256, false); let num_limbs = ((256 + limb_bits - 1) / limb_bits) as usize; assert!( @@ -376,7 +672,6 @@ mod tests { #[test] fn test_optimal_params_goldilocks() { - // Hypothetical 64-bit field over BN254 let (limb_bits, window_size) = get_optimal_msm_params(254, 64, 1, 64, false); let num_limbs = ((64 + limb_bits - 1) / limb_bits) as usize; assert!( @@ -388,19 +683,13 @@ mod tests { #[test] fn test_column_equation_soundness_boundary() { - // For BN254 (254 bits) with N=3: max safe limb_bits is 124. - // 2*124 + ceil(log2(3)) + 3 = 248 + 2 + 3 = 253 < 254 ✓ assert!(column_equation_fits_native_field(254, 124, 3)); - // 2*125 + ceil(log2(3)) + 3 = 250 + 2 + 3 = 255 ≥ 254 ✗ assert!(!column_equation_fits_native_field(254, 125, 3)); - // 2*126 + ceil(log2(3)) + 3 = 252 + 2 + 3 = 257 ≥ 254 ✗ assert!(!column_equation_fits_native_field(254, 126, 3)); } #[test] fn test_secp256r1_limb_bits_not_126() { - // Regression: limb_bits=126 with N=3 causes offset_w = 2^255 > p_BN254, - // making the schoolbook column equations unsound. 
let (limb_bits, _) = get_optimal_msm_params(254, 256, 1, 256, false); assert!( limb_bits <= 124, @@ -410,31 +699,94 @@ mod tests { #[test] fn test_scalar_relation_witnesses_grumpkin() { - // Grumpkin: scalar_bits=256, sr_limb_bits=64, sr_n=4 let sr = count_scalar_relation_witnesses(254, 256); - // Should be ~145 (not the old hardcoded 150) - assert!(sr > 100 && sr < 200, "unexpected scalar_relation={sr}"); + assert!(sr > 50 && sr < 200, "unexpected scalar_relation={sr}"); } #[test] fn test_scalar_relation_witnesses_small_curve() { - // 64-bit curve: scalar_bits=64, should be much smaller than 150 let sr = count_scalar_relation_witnesses(254, 64); - assert!( - sr < 100, - "64-bit curve scalar_relation={sr} should be < 100" - ); + assert!(sr < 100, "64-bit curve scalar_relation={sr} should be < 100"); } #[test] fn test_is_zero_cost_independent_of_num_limbs() { - // Verify that is_zero doesn't scale with num_limbs in the cost model. - // For the same window parameters, changing num_limbs should only affect - // field ops, not is_zero cost. - let (_, _, _, _, n_is_zero_w4) = count_glv_field_ops(128, 4); - let (_, _, _, _, n_is_zero_w3) = count_glv_field_ops(128, 3); - // is_zero count depends on num_windows, not num_limbs + let (_, _, _, _, n_is_zero_w4, _, _) = count_glv_field_ops(128, 4); + let (_, _, _, _, n_is_zero_w3, _, _) = count_glv_field_ops(128, 3); assert!(n_is_zero_w4 > 0); assert!(n_is_zero_w3 > 0); } + + #[test] + fn test_inv_single_limb_witness_count() { + // inv_mod_p_single: a_inv(1) + mul_mod_p_single(5) = 6 + assert_eq!(witnesses_per_op(1, FieldOpType::Inv, false), 6); + } + + #[test] + fn test_selects_counted_separately() { + // Verify selects are returned as separate counts, not mixed into + // field ops. 
+ let (_, _, _, _, _, pt_sel, coord_sel) = count_glv_field_ops(128, 4); + assert!(pt_sel > 0, "expected point selects > 0"); + assert_eq!(coord_sel, 2, "expected 2 coord selects (y-negation)"); + } + + #[test] + fn test_select_cost_scales_with_num_limbs() { + // For N=3, select cost should be 2*N per point select (1 witness + // per limb per coordinate, inlined select_witness). + let half_bits = 129; + let (_, _, _, _, _, n_pt_sel, n_coord_sel) = count_glv_field_ops(half_bits, 4); + let select_cost_n1 = n_pt_sel * 2 * 1 + n_coord_sel * 1; + let select_cost_n3 = n_pt_sel * 2 * 3 + n_coord_sel * 3; + // N=3 should be exactly 3× N=1 for selects (linear in num_limbs) + assert_eq!(select_cost_n3, select_cost_n1 * 3); + } + + #[test] + fn test_range_checks_per_op_native() { + assert!(range_checks_per_op(1, FieldOpType::Add, true, 254, 254).is_empty()); + assert!(range_checks_per_op(1, FieldOpType::Mul, true, 254, 254).is_empty()); + assert!(range_checks_per_op(1, FieldOpType::Inv, true, 254, 254).is_empty()); + } + + #[test] + fn test_range_checks_per_op_single_limb() { + let rc = range_checks_per_op(1, FieldOpType::Add, false, 64, 64); + assert_eq!(rc, vec![(64, 1)]); + let rc = range_checks_per_op(1, FieldOpType::Inv, false, 64, 64); + assert_eq!(rc, vec![(64, 2)]); + } + + #[test] + fn test_range_checks_per_op_multi_limb() { + // N=3, limb_bits=86: carry_bits = 86 + ceil(log2(3)) + 2 = 90 + let rc = range_checks_per_op(3, FieldOpType::Add, false, 86, 256); + assert_eq!(rc, vec![(86, 6)]); + let rc = range_checks_per_op(3, FieldOpType::Mul, false, 86, 256); + assert_eq!(rc, vec![(86, 9), (90, 4)]); + let rc = range_checks_per_op(3, FieldOpType::Inv, false, 86, 256); + assert_eq!(rc, vec![(86, 12), (90, 4)]); + } + + #[test] + fn test_scalar_relation_range_checks_256bit() { + let rc = count_scalar_relation_range_checks(254, 256); + let total: usize = rc.values().sum(); + assert!(total > 30, "too few range checks: {total}"); + assert!(total < 200, "too many range checks: 
{total}"); + } + + #[test] + fn test_estimate_range_check_cost_basic() { + use crate::range_check::estimate_range_check_cost; + + assert_eq!(estimate_range_check_cost(&BTreeMap::new()), 0); + + let mut checks = BTreeMap::new(); + checks.insert(8u32, 100usize); + let cost = estimate_range_check_cost(&checks); + assert!(cost > 0, "expected nonzero cost for 100 8-bit checks"); + } } diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index d4b3bff34..ba70b8574 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -494,7 +494,7 @@ mod tests { let y = curve_native_point_fe(&y4); let b = curve_native_point_fe(&c.curve_b); // Should still be on curve - assert_eq!(y * y, x * x * x + b, "[4]G not on Grumpkin"); + assert_eq!(y * y, x * x * x + b, "[2]*offset not on Grumpkin"); } #[test] diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 5e096df59..b5348a4c8 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -1,4 +1,8 @@ -use super::FieldOps; +use { + super::{select_witness, FieldOps}, + crate::noir_to_r1cs::NoirToR1CSCompiler, + provekit_common::{witness::WitnessBuilder, FieldElement}, +}; /// Generic point doubling on y^2 = x^3 + ax + b. /// @@ -85,21 +89,6 @@ pub fn point_add( (x3, y3) } -/// Conditional point select: returns `on_true` if `flag` is 1, `on_false` if -/// `flag` is 0. -/// -/// Constrains `flag` to be boolean (`flag * flag = flag`). -pub fn point_select( - ops: &mut F, - flag: usize, - on_false: (F::Elem, F::Elem), - on_true: (F::Elem, F::Elem), -) -> (F::Elem, F::Elem) { - let x = ops.select(flag, on_false.0, on_true.0); - let y = ops.select(flag, on_false.1, on_true.1); - (x, y) -} - /// Conditional point select without boolean constraint on `flag`. /// Caller must ensure `flag` is already constrained boolean. 
pub fn point_select_unchecked( @@ -268,3 +257,161 @@ pub fn scalar_mul_glv( acc } + +// =========================================================================== +// Native-field hint-verified EC operations +// =========================================================================== +// These operate on single native field element witnesses (no multi-limb). +// Each EC op allocates a hint for (lambda, x3, y3) and verifies via raw +// R1CS constraints, eliminating expensive field inversions from the circuit. + +use super::curve::CurveParams; +use ark_ff::{Field, PrimeField}; + +/// Hint-verified point doubling for native field. +/// +/// Allocates EcDoubleHint → (lambda, x3, y3) = 3W. +/// Verification constraints (4C): +/// 1. x_sq = px * px (1C via add_product) +/// 2. lambda * 2*py = 3*x_sq + a (1C raw) +/// 3. lambda * lambda = x3 + 2*px (1C raw) +/// 4. lambda * (px - x3) = y3 + py (1C raw) +/// +/// Total: 4W + 4C (1W for x_sq via add_product, 3W from hint). +pub fn point_double_verified_native( + compiler: &mut NoirToR1CSCompiler, + px: usize, + py: usize, + curve: &CurveParams, +) -> (usize, usize) { + // Allocate hint witnesses + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcDoubleHint { + output_start: hint_start, + px, + py, + curve_a: curve.curve_a, + field_modulus_p: curve.field_modulus_p, + }); + let lambda = hint_start; + let x3 = hint_start + 1; + let y3 = hint_start + 2; + + // x_sq = px * px (1W + 1C) + let x_sq = compiler.add_product(px, px); + + // Constraint: lambda * (2 * py) = 3 * x_sq + a + // A = [lambda], B = [2*py], C = [3*x_sq + a_const] + let a_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_a)).unwrap(); + let three = FieldElement::from(3u64); + let two = FieldElement::from(2u64); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(two, py)], + &[(three, x_sq), (a_fe, compiler.witness_one())], + ); + + // Constraint: lambda^2 = x3 + 2*px + 
compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x3), (two, px)], + ); + + // Constraint: lambda * (px - x3) = y3 + py + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, px), (-FieldElement::ONE, x3)], + &[(FieldElement::ONE, y3), (FieldElement::ONE, py)], + ); + + (x3, y3) +} + +/// Hint-verified point addition for native field. +/// +/// Allocates EcAddHint → (lambda, x3, y3) = 3W. +/// Verification constraints (3C): +/// 1. lambda * (x2 - x1) = y2 - y1 (1C raw) +/// 2. lambda^2 = x3 + x1 + x2 (1C raw) +/// 3. lambda * (x1 - x3) = y3 + y1 (1C raw) +/// +/// Total: 3W + 3C. +pub fn point_add_verified_native( + compiler: &mut NoirToR1CSCompiler, + x1: usize, + y1: usize, + x2: usize, + y2: usize, + curve: &CurveParams, +) -> (usize, usize) { + // Allocate hint witnesses + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcAddHint { + output_start: hint_start, + x1, + y1, + x2, + y2, + field_modulus_p: curve.field_modulus_p, + }); + let lambda = hint_start; + let x3 = hint_start + 1; + let y3 = hint_start + 2; + + // Constraint: lambda * (x2 - x1) = y2 - y1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x2), (-FieldElement::ONE, x1)], + &[(FieldElement::ONE, y2), (-FieldElement::ONE, y1)], + ); + + // Constraint: lambda^2 = x3 + x1 + x2 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x3), (FieldElement::ONE, x1), (FieldElement::ONE, x2)], + ); + + // Constraint: lambda * (x1 - x3) = y3 + y1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x1), (-FieldElement::ONE, x3)], + &[(FieldElement::ONE, y3), (FieldElement::ONE, y1)], + ); + + (x3, y3) +} + +/// On-curve check for native field: y^2 = x^3 + a*x + b. +/// +/// Constraints (3C, 2W): +/// 1. 
x_sq = x * x (1C via add_product) +/// 2. x_cu = x_sq * x (1C via add_product) +/// 3. y * y = x_cu + a*x + b (1C raw) +/// +/// Total: 2W + 3C. +pub fn verify_on_curve_native( + compiler: &mut NoirToR1CSCompiler, + x: usize, + y: usize, + curve: &CurveParams, +) { + let x_sq = compiler.add_product(x, x); + let x_cu = compiler.add_product(x_sq, x); + + let a_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_a)).unwrap(); + let b_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_b)).unwrap(); + + // y * y = x_cu + a*x + b + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, y)], + &[(FieldElement::ONE, y)], + &[ + (FieldElement::ONE, x_cu), + (a_fe, x), + (b_fe, compiler.witness_one()), + ], + ); +} + diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 1f6aa782e..a953d9774 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -3,16 +3,14 @@ pub mod curve; pub mod ec_points; pub mod multi_limb_arith; pub mod multi_limb_ops; +mod native; +mod non_native; +mod scalar_relation; use { - crate::{ - digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, - msm::multi_limb_arith::compute_is_zero, - noir_to_r1cs::NoirToR1CSCompiler, - }, + crate::{msm::multi_limb_arith::compute_is_zero, noir_to_r1cs::NoirToR1CSCompiler}, ark_ff::{AdditiveGroup, Field, PrimeField}, - curve::{decompose_to_limbs as decompose_to_limbs_pub, CurveParams}, - multi_limb_ops::{MultiLimbOps, MultiLimbParams}, + curve::CurveParams, provekit_common::{ witness::{ConstantOrR1CSWitness, ConstantTerm, SumTerm, WitnessBuilder}, FieldElement, @@ -174,6 +172,9 @@ pub(crate) fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) /// Single-witness conditional select: `out = on_false + flag * (on_true - /// on_false)`. 
+/// +/// Uses a single witness + single R1CS constraint: +/// flag * (on_true - on_false) = result - on_false pub(crate) fn select_witness( compiler: &mut NoirToR1CSCompiler, flag: usize, @@ -181,17 +182,23 @@ pub(crate) fn select_witness( on_true: usize, ) -> usize { // When both branches are the same witness, result is trivially that witness. - // Avoids duplicate column indices in R1CS from `on_true - on_false` when - // both share the same witness index. if on_false == on_true { return on_false; } - let diff = compiler.add_sum(vec![ - SumTerm(None, on_true), - SumTerm(Some(-FieldElement::ONE), on_false), - ]); - let flag_diff = compiler.add_product(flag, diff); - compiler.add_sum(vec![SumTerm(None, on_false), SumTerm(None, flag_diff)]) + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SelectWitness { + output: result, + flag, + on_false, + on_true, + }); + // flag * (on_true - on_false) = result - on_false + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, on_true), (-FieldElement::ONE, on_false)], + &[(FieldElement::ONE, result), (-FieldElement::ONE, on_false)], + ); + result } /// Packs bit witnesses into a digit: `d = Σ bits[i] * 2^i`. @@ -206,21 +213,24 @@ pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize] /// Computes `a OR b` for two boolean witnesses: `1 - (1 - a)(1 - b)`. /// Does NOT constrain a or b to be boolean — caller must ensure that. 
+/// +/// Uses a single witness + single R1CS constraint: +/// (1 - a) * (1 - b) = 1 - result fn compute_boolean_or(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) -> usize { let one = compiler.witness_one(); - let one_minus_a = compiler.add_sum(vec![ - SumTerm(None, one), - SumTerm(Some(-FieldElement::ONE), a), - ]); - let one_minus_b = compiler.add_sum(vec![ - SumTerm(None, one), - SumTerm(Some(-FieldElement::ONE), b), - ]); - let product = compiler.add_product(one_minus_a, one_minus_b); - compiler.add_sum(vec![ - SumTerm(None, one), - SumTerm(Some(-FieldElement::ONE), product), - ]) + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::BooleanOr { + output: result, + a, + b, + }); + // (1 - a) * (1 - b) = 1 - result + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, one), (-FieldElement::ONE, a)], + &[(FieldElement::ONE, one), (-FieldElement::ONE, b)], + &[(FieldElement::ONE, one), (-FieldElement::ONE, result)], + ); + result } /// Detects whether a point-scalar pair is degenerate (scalar=0 or point at @@ -239,6 +249,39 @@ fn detect_skip( compute_boolean_or(compiler, s_is_zero, inf_flag) } +/// Sanitized point-scalar inputs after degenerate-case detection. +struct SanitizedInputs { + px: usize, + py: usize, + s_lo: usize, + s_hi: usize, + is_skip: usize, +} + +/// Detects degenerate cases (scalar=0 or point at infinity) and replaces +/// the point with the generator G and scalar with 1 when degenerate. 
+fn sanitize_point_scalar( + compiler: &mut NoirToR1CSCompiler, + px: usize, + py: usize, + s_lo: usize, + s_hi: usize, + inf_flag: usize, + gen_x: usize, + gen_y: usize, + zero: usize, + one: usize, +) -> SanitizedInputs { + let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); + SanitizedInputs { + px: select_witness(compiler, is_skip, px, gen_x), + py: select_witness(compiler, is_skip, py, gen_y), + s_lo: select_witness(compiler, is_skip, s_lo, one), + s_hi: select_witness(compiler, is_skip, s_hi, zero), + is_skip, + } +} + /// Constrains `a * b = 0`. fn constrain_product_zero(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { compiler @@ -249,30 +292,44 @@ fn constrain_product_zero(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) )]); } -// --------------------------------------------------------------------------- -// Params builder (runtime num_limbs, no const generics) -// --------------------------------------------------------------------------- +/// Negate a y-coordinate and conditionally select based on a sign flag. +/// Returns `(y_eff, neg_y_eff)` where: +/// - if `neg_flag=0`: `y_eff = y`, `neg_y_eff = -y` +/// - if `neg_flag=1`: `y_eff = -y`, `neg_y_eff = y` +fn negate_y_signed_native( + compiler: &mut NoirToR1CSCompiler, + neg_flag: usize, + y: usize, +) -> (usize, usize) { + constrain_boolean(compiler, neg_flag); + let neg_y = compiler.add_sum(vec![SumTerm(Some(-FieldElement::ONE), y)]); + let y_eff = select_witness(compiler, neg_flag, y, neg_y); + let neg_y_eff = select_witness(compiler, neg_flag, neg_y, y); + (y_eff, neg_y_eff) +} -/// Build `MultiLimbParams` for a given runtime `num_limbs`. 
-fn build_params(num_limbs: usize, limb_bits: u32, curve: &CurveParams) -> MultiLimbParams { - let is_native = curve.is_native_field(); - let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); - let modulus_fe = if !is_native { - Some(curve.p_native_fe()) - } else { - None - }; - MultiLimbParams { - num_limbs, - limb_bits, - p_limbs: curve.p_limbs(limb_bits, num_limbs), - p_minus_1_limbs: curve.p_minus_1_limbs(limb_bits, num_limbs), - two_pow_w, - modulus_raw: curve.field_modulus_p, - curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), - is_native, - modulus_fe, - } +/// Emit an `EcScalarMulHint` and sanitize the result point. +/// When `is_skip=1`, the result is swapped to the generator point. +fn emit_ec_scalar_mul_hint_and_sanitize( + compiler: &mut NoirToR1CSCompiler, + san: &SanitizedInputs, + gen_x_witness: usize, + gen_y_witness: usize, + curve: &CurveParams, +) -> (usize, usize) { + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { + output_start: hint_start, + px: san.px, + py: san.py, + s_lo: san.s_lo, + s_hi: san.s_hi, + curve_a: curve.curve_a, + field_modulus_p: curve.field_modulus_p, + }); + let rx = select_witness(compiler, san.is_skip, hint_start, gen_x_witness); + let ry = select_witness(compiler, san.is_skip, hint_start + 1, gen_y_witness); + (rx, ry) } // --------------------------------------------------------------------------- @@ -388,67 +445,110 @@ fn add_single_msm( /// Process a full single-MSM with runtime `num_limbs`. /// -/// Uses FakeGLV for ALL points: each point P_i with scalar s_i is verified -/// using scalar decomposition and half-width interleaved scalar mul. -/// -/// For `n_points == 1`, R = (out_x, out_y) is the ACIR output. -/// For `n_points > 1`, R_i = EcScalarMulHint witnesses, accumulated via -/// point_add and constrained against the ACIR output. 
-fn process_single_msm<'a>( - mut compiler: &'a mut NoirToR1CSCompiler, +/// Dispatches to single-point or multi-point path based on the number of +/// input points. +fn process_single_msm( + compiler: &mut NoirToR1CSCompiler, point_wits: &[usize], scalar_wits: &[usize], outputs: (usize, usize, usize), num_limbs: usize, limb_bits: u32, window_size: usize, - mut range_checks: &'a mut BTreeMap>, + range_checks: &mut BTreeMap>, curve: &CurveParams, ) { let n_points = point_wits.len() / 3; - let (out_x, out_y, out_inf) = outputs; - if n_points == 1 { - // Single-point: R is the ACIR output directly - let px_witness = point_wits[0]; - let py_witness = point_wits[1]; - let inf_flag = point_wits[2]; - let s_lo = scalar_wits[0]; - let s_hi = scalar_wits[1]; - - // --- Detect degenerate case: is_skip = (scalar == 0) OR (point is infinity) - let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); - - // --- Sanitize inputs: swap in generator G and scalar=1 when is_skip --- - let one = compiler.witness_one(); - let gen_x_witness = - add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.0)); - let gen_y_witness = - add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.1)); + process_single_point_msm( + compiler, + point_wits, + scalar_wits, + outputs, + num_limbs, + limb_bits, + window_size, + range_checks, + curve, + ); + } else { + process_multi_point_msm( + compiler, + point_wits, + scalar_wits, + outputs, + n_points, + num_limbs, + limb_bits, + window_size, + range_checks, + curve, + ); + } +} - let sanitized_px = select_witness(compiler, is_skip, px_witness, gen_x_witness); - let sanitized_py = select_witness(compiler, is_skip, py_witness, gen_y_witness); +/// Single-point MSM: R = [s]P with degenerate-case handling. +/// +/// The ACIR output (out_x, out_y) is the result directly. Sanitizes inputs +/// to handle scalar=0 and point-at-infinity, then verifies via FakeGLV. 
+fn process_single_point_msm<'a>( + mut compiler: &'a mut NoirToR1CSCompiler, + point_wits: &[usize], + scalar_wits: &[usize], + outputs: (usize, usize, usize), + num_limbs: usize, + limb_bits: u32, + window_size: usize, + range_checks: &'a mut BTreeMap>, + curve: &CurveParams, +) { + let (out_x, out_y, out_inf) = outputs; - // When is_skip=1, use scalar=(1, 0) so FakeGLV computes [1]*G = G - let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); - let sanitized_s_lo = select_witness(compiler, is_skip, s_lo, one); - let sanitized_s_hi = select_witness(compiler, is_skip, s_hi, zero_witness); + // Allocate constants + let one = compiler.witness_one(); + let gen_x_witness = + add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.0)); + let gen_y_witness = + add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.1)); + let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); + + // Sanitize inputs: swap in generator G and scalar=1 when degenerate + let san = sanitize_point_scalar( + compiler, + point_wits[0], + point_wits[1], + scalar_wits[0], + scalar_wits[1], + point_wits[2], + gen_x_witness, + gen_y_witness, + zero_witness, + one, + ); - // Sanitize R (output point): when is_skip=1, R must be G (since [1]*G = G) - let sanitized_rx = select_witness(compiler, is_skip, out_x, gen_x_witness); - let sanitized_ry = select_witness(compiler, is_skip, out_y, gen_y_witness); + // Sanitize R (output point): when is_skip=1, R must be G (since [1]*G = G) + let sanitized_rx = select_witness(compiler, san.is_skip, out_x, gen_x_witness); + let sanitized_ry = select_witness(compiler, san.is_skip, out_y, gen_y_witness); - // Decompose sanitized P into limbs - let (px, py) = decompose_point_to_limbs( + if curve.is_native_field() { + // Native-field optimized path: hint-verified EC + wNAF + native::verify_point_fakeglv_native( compiler, - sanitized_px, - sanitized_py, - num_limbs, - limb_bits, range_checks, + 
san.px, + san.py, + sanitized_rx, + sanitized_ry, + san.s_lo, + san.s_hi, + curve, ); - // Decompose sanitized R into limbs - let (rx, ry) = decompose_point_to_limbs( + } else { + // Generic multi-limb path + let (px, py) = non_native::decompose_point_to_limbs( + compiler, san.px, san.py, num_limbs, limb_bits, range_checks, + ); + let (rx, ry) = non_native::decompose_point_to_limbs( compiler, sanitized_rx, sanitized_ry, @@ -456,420 +556,89 @@ fn process_single_msm<'a>( limb_bits, range_checks, ); - - // Run FakeGLV on sanitized values (always satisfiable) - (compiler, range_checks) = verify_point_fakeglv( + (compiler, _) = non_native::verify_point_fakeglv( compiler, range_checks, px, py, rx, ry, - sanitized_s_lo, - sanitized_s_hi, + san.s_lo, + san.s_hi, num_limbs, limb_bits, window_size, curve, ); - - // --- Mask output: when is_skip, output must be (0, 0, 1) --- - constrain_equal(compiler, out_inf, is_skip); - constrain_product_zero(compiler, is_skip, out_x); - constrain_product_zero(compiler, is_skip, out_y); - } else { - // Multi-point: compute R_i = [s_i]P_i via hints, verify each with FakeGLV, - // then accumulate R_i's with offset-based accumulation and skip handling. 
- let one = compiler.witness_one(); - - // Generator constants for sanitization - let gen_x_fe = curve::curve_native_point_fe(&curve.generator.0); - let gen_y_fe = curve::curve_native_point_fe(&curve.generator.1); - let gen_x_witness = add_constant_witness(compiler, gen_x_fe); - let gen_y_witness = add_constant_witness(compiler, gen_y_fe); - let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); - - // Build params once for all multi-limb ops in the multi-point path - let params = build_params(num_limbs, limb_bits, curve); - - // Offset point as limbs for accumulation - let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); - let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); - - // Start accumulator at offset_point - let mut ops = MultiLimbOps { - compiler, - range_checks, - params: ¶ms, - }; - let mut acc_x = ops.constant_limbs(&offset_x_values); - let mut acc_y = ops.constant_limbs(&offset_y_values); - compiler = ops.compiler; - range_checks = ops.range_checks; - - // Track all_skipped = product of all is_skip flags - let mut all_skipped: Option = None; - - for i in 0..n_points { - let px_witness = point_wits[3 * i]; - let py_witness = point_wits[3 * i + 1]; - let inf_flag = point_wits[3 * i + 2]; - let s_lo = scalar_wits[2 * i]; - let s_hi = scalar_wits[2 * i + 1]; - - // --- Detect degenerate case --- - let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); - - // Track all_skipped - all_skipped = Some(match all_skipped { - None => is_skip, - Some(prev) => compiler.add_product(prev, is_skip), - }); - - // --- Sanitize inputs --- - let sanitized_px = select_witness(compiler, is_skip, px_witness, gen_x_witness); - let sanitized_py = select_witness(compiler, is_skip, py_witness, gen_y_witness); - let sanitized_s_lo = select_witness(compiler, is_skip, s_lo, one); - let sanitized_s_hi = select_witness(compiler, is_skip, s_hi, zero_witness); - - // EcScalarMulHint uses sanitized inputs - let hint_start = 
compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { - output_start: hint_start, - px: sanitized_px, - py: sanitized_py, - s_lo: sanitized_s_lo, - s_hi: sanitized_s_hi, - curve_a: curve.curve_a, - field_modulus_p: curve.field_modulus_p, - }); - let rx_witness = hint_start; - let ry_witness = hint_start + 1; - - // When is_skip=1, R should be G (since [1]*G = G) - let sanitized_rx = select_witness(compiler, is_skip, rx_witness, gen_x_witness); - let sanitized_ry = select_witness(compiler, is_skip, ry_witness, gen_y_witness); - - // Decompose sanitized P_i into limbs - let (px, py) = decompose_point_to_limbs( - compiler, - sanitized_px, - sanitized_py, - num_limbs, - limb_bits, - range_checks, - ); - // Decompose sanitized R_i into limbs - let (rx, ry) = decompose_point_to_limbs( - compiler, - sanitized_rx, - sanitized_ry, - num_limbs, - limb_bits, - range_checks, - ); - - // Verify R_i = [s_i]P_i using FakeGLV (on sanitized values) - (compiler, range_checks) = verify_point_fakeglv( - compiler, - range_checks, - px, - py, - rx, - ry, - sanitized_s_lo, - sanitized_s_hi, - num_limbs, - limb_bits, - window_size, - curve, - ); - - // --- Offset-based accumulation with conditional select --- - // Compute candidate = point_add(acc, R_i) - // Then select: if is_skip, keep acc unchanged; else use candidate - let mut ops = MultiLimbOps { - compiler, - range_checks, - params: ¶ms, - }; - let (cand_x, cand_y) = ec_points::point_add(&mut ops, acc_x, acc_y, rx, ry); - let (new_acc_x, new_acc_y) = ec_points::point_select_unchecked( - &mut ops, - is_skip, - (cand_x, cand_y), - (acc_x, acc_y), - ); - acc_x = new_acc_x; - acc_y = new_acc_y; - compiler = ops.compiler; - range_checks = ops.range_checks; - } - - let all_skipped = all_skipped.expect("MSM must have at least one point"); - - // Subtract offset: result = point_add(acc, -offset) - // Negated offset = (offset_x, -offset_y) - let neg_offset_y_raw = - 
curve::negate_field_element(&curve.offset_point.1, &curve.field_modulus_p); - let neg_offset_y_values = - curve::decompose_to_limbs(&neg_offset_y_raw, limb_bits, num_limbs); - - // When all_skipped, acc == offset_point, so subtracting offset would be - // point_add(O, -O) which fails (x1 == x2). Use generator G as the - // subtraction target instead; the result won't matter since we'll mask it. - let gen_x_limb_values = curve.generator_x_limbs(limb_bits, num_limbs); - let neg_gen_y_raw = curve::negate_field_element(&curve.generator.1, &curve.field_modulus_p); - let neg_gen_y_values = curve::decompose_to_limbs(&neg_gen_y_raw, limb_bits, num_limbs); - - let mut ops = MultiLimbOps { - compiler, - range_checks, - params: ¶ms, - }; - - // Select subtraction point: if all_skipped, use -G; else use -offset - let sub_x = { - let off_x = ops.constant_limbs(&offset_x_values); - let g_x = ops.constant_limbs(&gen_x_limb_values); - ops.select(all_skipped, off_x, g_x) - }; - let sub_y = { - let neg_off_y = ops.constant_limbs(&neg_offset_y_values); - let neg_g_y = ops.constant_limbs(&neg_gen_y_values); - ops.select(all_skipped, neg_off_y, neg_g_y) - }; - - let (result_x, result_y) = ec_points::point_add(&mut ops, acc_x, acc_y, sub_x, sub_y); - compiler = ops.compiler; - range_checks = ops.range_checks; - - // --- Constrain output --- - // When all_skipped: output is (0, 0, 1) - // Otherwise: output matches the computed result with inf=0 - if num_limbs == 1 { - // Mask result with all_skipped: when all_skipped=1, out must be 0 - let masked_result_x = select_witness(compiler, all_skipped, result_x[0], zero_witness); - let masked_result_y = select_witness(compiler, all_skipped, result_y[0], zero_witness); - constrain_equal(compiler, out_x, masked_result_x); - constrain_equal(compiler, out_y, masked_result_y); - } else { - let recomposed_x = recompose_limbs(compiler, result_x.as_slice(), limb_bits); - let recomposed_y = recompose_limbs(compiler, result_y.as_slice(), limb_bits); - let 
masked_result_x = select_witness(compiler, all_skipped, recomposed_x, zero_witness); - let masked_result_y = select_witness(compiler, all_skipped, recomposed_y, zero_witness); - constrain_equal(compiler, out_x, masked_result_x); - constrain_equal(compiler, out_y, masked_result_y); - } - constrain_equal(compiler, out_inf, all_skipped); } -} -/// Decompose a point (px_witness, py_witness) into Limbs. -fn decompose_point_to_limbs( - compiler: &mut NoirToR1CSCompiler, - px_witness: usize, - py_witness: usize, - num_limbs: usize, - limb_bits: u32, - range_checks: &mut BTreeMap>, -) -> (Limbs, Limbs) { - if num_limbs == 1 { - (Limbs::single(px_witness), Limbs::single(py_witness)) - } else { - let px_limbs = - decompose_witness_to_limbs(compiler, px_witness, limb_bits, num_limbs, range_checks); - let py_limbs = - decompose_witness_to_limbs(compiler, py_witness, limb_bits, num_limbs, range_checks); - (px_limbs, py_limbs) - } + // Mask output: when is_skip, output must be (0, 0, 1) + constrain_equal(compiler, out_inf, san.is_skip); + constrain_product_zero(compiler, san.is_skip, out_x); + constrain_product_zero(compiler, san.is_skip, out_y); } -/// FakeGLV verification for a single point: verifies R = \[s\]P. -/// -/// Decomposes s via half-GCD into sub-scalars (s1, s2) and verifies -/// \[s1\]P + \[s2\]R = O using interleaved windowed scalar mul with -/// half-width scalars. +/// Multi-point MSM: computes R_i = [s_i]P_i via hints, verifies each with +/// FakeGLV, then accumulates R_i's with offset-based accumulation and skip +/// handling. /// -/// Returns the mutable references back to the caller for continued use. 
-fn verify_point_fakeglv<'a>( - mut compiler: &'a mut NoirToR1CSCompiler, - mut range_checks: &'a mut BTreeMap>, - px: Limbs, - py: Limbs, - rx: Limbs, - ry: Limbs, - s_lo: usize, - s_hi: usize, +/// When `curve.is_native_field()`, uses a merged-loop optimization: all +/// points share a single doubling per bit, saving 4*(n-1) constraints per +/// bit of the half-scalar (≈512 for 2 points on Grumpkin). +fn process_multi_point_msm( + compiler: &mut NoirToR1CSCompiler, + point_wits: &[usize], + scalar_wits: &[usize], + outputs: (usize, usize, usize), + n_points: usize, num_limbs: usize, limb_bits: u32, window_size: usize, + range_checks: &mut BTreeMap>, curve: &CurveParams, -) -> ( - &'a mut NoirToR1CSCompiler, - &'a mut BTreeMap>, ) { - // --- Steps 1-4: On-curve checks, FakeGLV decomposition, and GLV scalar mul - // --- - let s1_witness; - let s2_witness; - let neg1_witness; - let neg2_witness; - { - let params = build_params(num_limbs, limb_bits, curve); - let mut ops = MultiLimbOps { + if curve.is_native_field() { + native::process_multi_point_native( compiler, + point_wits, + scalar_wits, + outputs, + n_points, range_checks, - params: ¶ms, - }; - - // Step 1: On-curve checks for P and R - let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs); - verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); - verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); - - // Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 - let glv_start = ops.compiler.num_witnesses(); - ops.compiler - .add_witness_builder(WitnessBuilder::FakeGLVHint { - output_start: glv_start, - s_lo, - s_hi, - curve_order: curve.curve_order_n, - }); - s1_witness = glv_start; - s2_witness = glv_start + 1; - neg1_witness = glv_start + 2; - neg2_witness = glv_start + 3; - - // Step 3: Decompose |s1|, |s2| into half_bits bits each - let half_bits = curve.glv_half_bits() as usize; - let s1_bits = decompose_half_scalar_bits(ops.compiler, s1_witness, half_bits); - let 
s2_bits = decompose_half_scalar_bits(ops.compiler, s2_witness, half_bits); - - // Step 4: Conditionally negate P.y and R.y + GLV scalar mul + identity - // check - - // Compute negated y-coordinates: neg_y = 0 - y (mod p) - let neg_py = ops.negate(py); - let neg_ry = ops.negate(ry); - - // Select: if neg1=1, use neg_py; else use py - // neg1 and neg2 are constrained to be boolean by ops.select internally. - let py_effective = ops.select(neg1_witness, py, neg_py); - // Select: if neg2=1, use neg_ry; else use ry - let ry_effective = ops.select(neg2_witness, ry, neg_ry); - - // GLV scalar mul - let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); - let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); - let offset_x = ops.constant_limbs(&offset_x_values); - let offset_y = ops.constant_limbs(&offset_y_values); - - let glv_acc = ec_points::scalar_mul_glv( - &mut ops, - px, - py_effective, - &s1_bits, - rx, - ry_effective, - &s2_bits, - window_size, - offset_x, - offset_y, + curve, ); - - // Identity check: acc should equal [2^(num_windows * window_size)] * - // offset_point - let glv_num_windows = (half_bits + window_size - 1) / window_size; - let glv_n_doublings = glv_num_windows * window_size; - let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); - - let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs); - let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs); - let expected_x = ops.constant_limbs(&acc_off_x_values); - let expected_y = ops.constant_limbs(&acc_off_y_values); - - for i in 0..num_limbs { - constrain_equal(ops.compiler, glv_acc.0[i], expected_x[i]); - constrain_equal(ops.compiler, glv_acc.1[i], expected_y[i]); - } - - compiler = ops.compiler; - range_checks = ops.range_checks; + return; } - // --- Step 5: Scalar relation verification --- - verify_scalar_relation( + non_native::process_multi_point_non_native( compiler, + point_wits, + 
scalar_wits, + outputs, + n_points, + num_limbs, + limb_bits, + window_size, range_checks, - s_lo, - s_hi, - s1_witness, - s2_witness, - neg1_witness, - neg2_witness, curve, ); - - (compiler, range_checks) } -/// On-curve check: verifies y^2 = x^3 + a*x + b for a single point. -fn verify_on_curve( - ops: &mut MultiLimbOps, - x: Limbs, - y: Limbs, - b_limb_values: &[FieldElement], - num_limbs: usize, -) { - let y_sq = ops.mul(y, y); - let x_sq = ops.mul(x, x); - let x_cubed = ops.mul(x_sq, x); - let a = ops.curve_a(); - let ax = ops.mul(a, x); - let x3_plus_ax = ops.add(x_cubed, ax); - let b = ops.constant_limbs(b_limb_values); - let rhs = ops.add(x3_plus_ax, b); - for i in 0..num_limbs { - constrain_equal(ops.compiler, y_sq[i], rhs[i]); - } -} - -/// Decompose a single witness into `num_limbs` limbs using digital -/// decomposition. -fn decompose_witness_to_limbs( +/// Allocates a FakeGLV hint and returns `(s1, s2, neg1, neg2)` witness indices. +fn emit_fakeglv_hint( compiler: &mut NoirToR1CSCompiler, - witness: usize, - limb_bits: u32, - num_limbs: usize, - range_checks: &mut BTreeMap>, -) -> Limbs { - let log_bases = vec![limb_bits as usize; num_limbs]; - let dd = add_digital_decomposition(compiler, log_bases, vec![witness]); - let mut limbs = Limbs::new(num_limbs); - for i in 0..num_limbs { - limbs[i] = dd.get_digit_witness_index(i, 0); - // Range-check each decomposed limb to [0, 2^limb_bits). - // add_digital_decomposition constrains the recomposition but does - // NOT range-check individual digits. 
- range_checks.entry(limb_bits).or_default().push(limbs[i]); - } - limbs -} - -/// Recompose limbs back into a single witness: val = Σ limb\[i\] * -/// 2^(i*limb_bits) -fn recompose_limbs(compiler: &mut NoirToR1CSCompiler, limbs: &[usize], limb_bits: u32) -> usize { - let terms: Vec = limbs - .iter() - .enumerate() - .map(|(i, &limb)| { - let coeff = FieldElement::from(2u64).pow([(i as u64) * (limb_bits as u64)]); - SumTerm(Some(coeff), limb) - }) - .collect(); - compiler.add_sum(terms) + s_lo: usize, + s_hi: usize, + curve: &CurveParams, +) -> (usize, usize, usize, usize) { + let glv_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::FakeGLVHint { + output_start: glv_start, + s_lo, + s_hi, + curve_order: curve.curve_order_n, + }); + (glv_start, glv_start + 1, glv_start + 2, glv_start + 3) } /// Resolves a `ConstantOrR1CSWitness` to a witness index. @@ -884,218 +653,32 @@ fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitnes } } -/// Decomposes a half-scalar witness into `half_bits` bit witnesses (LSB first). -fn decompose_half_scalar_bits( - compiler: &mut NoirToR1CSCompiler, - scalar: usize, - half_bits: usize, -) -> Vec { - let log_bases = vec![1usize; half_bits]; - let dd = add_digital_decomposition(compiler, log_bases, vec![scalar]); - let mut bits = Vec::with_capacity(half_bits); - for bit_idx in 0..half_bits { - bits.push(dd.get_digit_witness_index(bit_idx, 0)); - } - bits -} - -/// Builds `MultiLimbParams` for scalar relation verification (mod -/// curve_order_n). -fn build_scalar_relation_params( - num_limbs: usize, - limb_bits: u32, - curve: &CurveParams, -) -> MultiLimbParams { - // Scalar relation uses curve_order_n as the modulus. - // This is always non-native (curve_order_n ≠ BN254 scalar field modulus, - // except for Grumpkin where they're very close but still different). 
- let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); - let n_limbs = curve.curve_order_n_limbs(limb_bits, num_limbs); - let n_minus_1_limbs = curve.curve_order_n_minus_1_limbs(limb_bits, num_limbs); - - // For N=1 non-native, we need the modulus as a FieldElement - let modulus_fe = if num_limbs == 1 { - Some(curve::curve_native_point_fe(&curve.curve_order_n)) - } else { - None - }; - - MultiLimbParams { - num_limbs, - limb_bits, - p_limbs: n_limbs, - p_minus_1_limbs: n_minus_1_limbs, - two_pow_w, - modulus_raw: curve.curve_order_n, - curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused - is_native: false, // always non-native - modulus_fe, - } -} - -/// Picks the largest limb size for the scalar-relation multi-limb arithmetic -/// that fits inside the native field without overflow. -/// -/// The schoolbook multiplication column equations require: -/// `2 * limb_bits + ceil(log2(num_limbs)) + 3 < native_field_bits` -/// -/// We start at 64 bits (the ideal case — inputs are 128-bit half-scalars) and -/// search downward until the soundness check passes. For BN254 (254-bit native -/// field) this resolves to 64; smaller fields like M31 (31 bits) will get a -/// proportionally smaller limb size. -/// -/// Panics if the native field is too small (< ~12 bits) to support any valid -/// limb decomposition. -fn scalar_relation_limb_bits(order_bits: usize) -> u32 { - let native_bits = FieldElement::MODULUS_BIT_SIZE; - let mut limb_bits: u32 = 64.min((native_bits.saturating_sub(4)) / 2); - loop { - let num_limbs = (order_bits + limb_bits as usize - 1) / limb_bits as usize; - if cost_model::column_equation_fits_native_field(native_bits, limb_bits, num_limbs) { - break; - } - limb_bits -= 1; - assert!( - limb_bits >= 4, - "native field too small for scalar relation verification" - ); - } - limb_bits -} - -/// Verifies the scalar relation: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 -/// (mod n). 
-/// -/// Uses multi-limb arithmetic with curve_order_n as the modulus. -/// The sub-scalars s1, s2 have `half_bits = ceil(order_bits/2)` bits; -/// the full scalar s has up to `order_bits` bits. -fn verify_scalar_relation( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - s_lo: usize, - s_hi: usize, - s1_witness: usize, - s2_witness: usize, - neg1_witness: usize, - neg2_witness: usize, - curve: &CurveParams, -) { - let order_bits = curve.curve_order_bits() as usize; - let sr_limb_bits = scalar_relation_limb_bits(order_bits); - let sr_num_limbs = (order_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; - let half_bits = curve.glv_half_bits() as usize; - // Number of limbs the half-scalar occupies - let half_limbs = (half_bits + sr_limb_bits as usize - 1) / sr_limb_bits as usize; - - let params = build_scalar_relation_params(sr_num_limbs, sr_limb_bits, curve); - let mut ops = MultiLimbOps { - compiler, - range_checks, - params: ¶ms, - }; - - // Decompose s into sr_num_limbs limbs from (s_lo, s_hi). - // s_lo contains bits [0..128), s_hi contains bits [128..256). 
- let s_limbs = { - let limbs_per_half = (128 + sr_limb_bits as usize - 1) / sr_limb_bits as usize; - let dd_bases_128: Vec = (0..limbs_per_half) - .map(|i| { - let remaining = 128u32 - (i as u32 * sr_limb_bits); - remaining.min(sr_limb_bits) as usize - }) - .collect(); - let dd_lo = add_digital_decomposition(ops.compiler, dd_bases_128.clone(), vec![s_lo]); - let dd_hi = add_digital_decomposition(ops.compiler, dd_bases_128, vec![s_hi]); - let mut limbs = Limbs::new(sr_num_limbs); - let lo_n = limbs_per_half.min(sr_num_limbs); - for i in 0..lo_n { - limbs[i] = dd_lo.get_digit_witness_index(i, 0); - let remaining = 128u32 - (i as u32 * sr_limb_bits); - ops.range_checks - .entry(remaining.min(sr_limb_bits)) - .or_default() - .push(limbs[i]); - } - let hi_n = sr_num_limbs - lo_n; - for i in 0..hi_n { - limbs[lo_n + i] = dd_hi.get_digit_witness_index(i, 0); - let remaining = 128u32 - (i as u32 * sr_limb_bits); - ops.range_checks - .entry(remaining.min(sr_limb_bits)) - .or_default() - .push(limbs[lo_n + i]); - } - limbs - }; - - // Helper: decompose a half-scalar witness into sr_num_limbs limbs. - // The half-scalar has `half_bits` bits → occupies `half_limbs` limbs. - // Upper limbs (half_limbs..sr_num_limbs) are zero-padded. 
- let decompose_half_scalar = |ops: &mut MultiLimbOps, witness: usize| -> Limbs { - let dd_bases: Vec = (0..half_limbs) - .map(|i| { - let remaining = half_bits as u32 - (i as u32 * sr_limb_bits); - remaining.min(sr_limb_bits) as usize - }) - .collect(); - let dd = add_digital_decomposition(ops.compiler, dd_bases, vec![witness]); - let mut limbs = Limbs::new(sr_num_limbs); - for i in 0..half_limbs { - limbs[i] = dd.get_digit_witness_index(i, 0); - let remaining_bits = (half_bits as u32) - (i as u32 * sr_limb_bits); - let this_limb_bits = remaining_bits.min(sr_limb_bits); - ops.range_checks - .entry(this_limb_bits) - .or_default() - .push(limbs[i]); - } - // Zero-pad upper limbs - for i in half_limbs..sr_num_limbs { - let w = ops.compiler.num_witnesses(); - ops.compiler - .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( - w, - FieldElement::from(0u64), - ))); - limbs[i] = w; - constrain_zero(ops.compiler, limbs[i]); - } - limbs - }; - - let s1_limbs = decompose_half_scalar(&mut ops, s1_witness); - let s2_limbs = decompose_half_scalar(&mut ops, s2_witness); - - // Compute product = s2 * s (mod n) - let product = ops.mul(s2_limbs, s_limbs); - - // Handle signs: compute effective values - // If neg2 is set: neg_product = n - product (mod n), i.e. 0 - product - let neg_product = ops.negate(product); - // neg2 already constrained boolean in verify_point_fakeglv - let effective_product = ops.select_unchecked(neg2_witness, product, neg_product); - - // If neg1 is set: neg_s1 = n - s1 (mod n), i.e. 
0 - s1 - let neg_s1 = ops.negate(s1_limbs); - // neg1 already constrained boolean in verify_point_fakeglv - let effective_s1 = ops.select_unchecked(neg1_witness, s1_limbs, neg_s1); - - // Sum: effective_s1 + effective_product (mod n) should be 0 - let sum = ops.add(effective_s1, effective_product); - - // Constrain sum == 0: all limbs must be zero - for i in 0..sr_num_limbs { - constrain_zero(ops.compiler, sum[i]); - } -} - -/// Creates a constant witness with the given value. +/// Creates a constant witness with the given value, pinned by an R1CS +/// constraint so that a malicious prover cannot set it to an arbitrary value. fn add_constant_witness(compiler: &mut NoirToR1CSCompiler, value: FieldElement) -> usize { let w = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value))); + // Pin: 1 * w = value * 1 (embeds the constant into the constraint matrix) + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(value, compiler.witness_one())], + ); w } +/// Constrains a witness to equal a known constant value. +/// Uses the constant as an R1CS coefficient — no witness needed for the +/// expected value. Use this for identity checks where the witness must equal +/// a compile-time-known value. +fn constrain_to_constant(compiler: &mut NoirToR1CSCompiler, witness: usize, value: FieldElement) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, witness)], + &[(value, compiler.witness_one())], + ); +} + /// Constrains two witnesses to be equal: `a - b = 0`. 
fn constrain_equal(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { compiler.r1cs.add_constraint( diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs index 34275b890..d16f26c13 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -210,11 +210,15 @@ impl FieldOps for MultiLimbOps<'_, '_> { let mut out = Limbs::new(n); for i in 0..n { let w = self.compiler.num_witnesses(); + let value = self.params.curve_a_limbs[i]; self.compiler - .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( - w, - self.params.curve_a_limbs[i], - ))); + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value))); + // Pin: prevent malicious prover from choosing a different curve_a + self.compiler.r1cs.add_constraint( + &[(FieldElement::ONE, self.compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(value, self.compiler.witness_one())], + ); out[i] = w; } out @@ -254,6 +258,12 @@ impl FieldOps for MultiLimbOps<'_, '_> { let w = self.compiler.num_witnesses(); self.compiler .add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, limbs[i]))); + // Pin: prevent malicious prover from altering constant values + self.compiler.r1cs.add_constraint( + &[(FieldElement::ONE, self.compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(limbs[i], self.compiler.witness_one())], + ); out[i] = w; } out diff --git a/provekit/r1cs-compiler/src/msm/native.rs b/provekit/r1cs-compiler/src/msm/native.rs new file mode 100644 index 000000000..b8e2f7b52 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/native.rs @@ -0,0 +1,413 @@ +//! Native-field MSM path: hint-verified EC ops with signed-bit wNAF. +//! +//! Used when `curve.is_native_field()` — replaces expensive field inversions +//! with prover hints verified via raw R1CS constraints. 
+ +use { + super::{ + add_constant_witness, constrain_boolean, constrain_equal, constrain_to_constant, curve, + ec_points, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, + negate_y_signed_native, sanitize_point_scalar, scalar_relation, select_witness, + }, + ark_ff::{AdditiveGroup, Field}, + crate::noir_to_r1cs::NoirToR1CSCompiler, + curve::CurveParams, + provekit_common::{witness::WitnessBuilder, FieldElement}, + std::collections::BTreeMap, +}; + +/// Per-point preprocessed data for the merged native scalar mul loop. +/// +/// Holds the inputs needed by `scalar_mul_merged_native_wnaf` to process +/// one point's P and R branches inside the shared-doubling loop. +pub(super) struct NativePointData { + px: usize, + py_eff: usize, + neg_py_eff: usize, + s1_bits: Vec, + s1_skew: usize, + rx: usize, + ry_eff: usize, + neg_ry_eff: usize, + s2_bits: Vec, + s2_skew: usize, +} + +/// Native-field FakeGLV verification using hint-verified EC ops. +/// +/// This path is used when `curve.is_native_field()` and replaces +/// `verify_point_fakeglv` for significant constraint savings. 
+/// +/// Key differences from the generic path: +/// - EC ops use hint-verified formulas (4W+4C per double vs ~12, 3W+3C per add +/// vs ~8) +/// - On-curve checks use raw constraints (2W+3C vs ~6W+6C) +/// - Still uses binary bit decomposition + windowed scalar mul (same table +/// structure) +pub(super) fn verify_point_fakeglv_native( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + px: usize, + py: usize, + rx: usize, + ry: usize, + s_lo: usize, + s_hi: usize, + curve: &CurveParams, +) { + // Step 1: On-curve checks for P and R (native) + ec_points::verify_on_curve_native(compiler, px, py, curve); + ec_points::verify_on_curve_native(compiler, rx, ry, curve); + + // Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 + let (s1_witness, s2_witness, neg1_witness, neg2_witness) = + emit_fakeglv_hint(compiler, s_lo, s_hi, curve); + + // Step 3: Signed-bit decomposition + let half_bits = curve.glv_half_bits() as usize; + let (s1_bits, s1_skew) = decompose_signed_bits(compiler, s1_witness, half_bits); + let (s2_bits, s2_skew) = decompose_signed_bits(compiler, s2_witness, half_bits); + + // Step 4: Conditionally negate y-coordinates + let (py_effective, neg_py_effective) = negate_y_signed_native(compiler, neg1_witness, py); + let (ry_effective, neg_ry_effective) = negate_y_signed_native(compiler, neg2_witness, ry); + + // Step 5: Scalar mul via merged loop (single-point = one-element slice) + let point_data = NativePointData { + px, + py_eff: py_effective, + neg_py_eff: neg_py_effective, + s1_bits, + s1_skew, + rx, + ry_eff: ry_effective, + neg_ry_eff: neg_ry_effective, + s2_bits, + s2_skew, + }; + let offset_x_fe = curve::curve_native_point_fe(&curve.offset_point.0); + let offset_y_fe = curve::curve_native_point_fe(&curve.offset_point.1); + let offset_x = add_constant_witness(compiler, offset_x_fe); + let offset_y = add_constant_witness(compiler, offset_y_fe); + + let (acc_x, acc_y) = + scalar_mul_merged_native_wnaf(compiler, &[point_data], offset_x, 
offset_y, curve); + + // Step 6: Identity check — acc should equal accumulated offset + // (hardcoded into constraint matrix, not a prover-controlled witness) + let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(half_bits); + constrain_to_constant( + compiler, + acc_x, + curve::curve_native_point_fe(&acc_off_x_raw), + ); + constrain_to_constant( + compiler, + acc_y, + curve::curve_native_point_fe(&acc_off_y_raw), + ); + + // Step 7: Scalar relation verification (unchanged) + scalar_relation::verify_scalar_relation( + compiler, + range_checks, + s_lo, + s_hi, + s1_witness, + s2_witness, + neg1_witness, + neg2_witness, + curve, + ); +} + +/// Multi-point native-field MSM with merged-loop optimization. +/// +/// All points share a single doubling per bit, saving 4*(n-1) constraints +/// per bit of the half-scalar. +pub(super) fn process_multi_point_native( + compiler: &mut NoirToR1CSCompiler, + point_wits: &[usize], + scalar_wits: &[usize], + outputs: (usize, usize, usize), + n_points: usize, + range_checks: &mut BTreeMap>, + curve: &CurveParams, +) { + let (out_x, out_y, out_inf) = outputs; + let one = compiler.witness_one(); + + // Generator constants for sanitization + let gen_x_fe = curve::curve_native_point_fe(&curve.generator.0); + let gen_y_fe = curve::curve_native_point_fe(&curve.generator.1); + let gen_x_witness = add_constant_witness(compiler, gen_x_fe); + let gen_y_witness = add_constant_witness(compiler, gen_y_fe); + let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); + + let mut all_skipped: Option = None; + let mut native_points: Vec = Vec::new(); + let mut scalar_rel_inputs: Vec<(usize, usize, usize, usize, usize, usize)> = Vec::new(); + let mut accum_inputs: Vec<(usize, usize, usize)> = Vec::new(); + + // Phase 1: Per-point preprocessing + for i in 0..n_points { + let san = sanitize_point_scalar( + compiler, + point_wits[3 * i], + point_wits[3 * i + 1], + scalar_wits[2 * i], + scalar_wits[2 * i + 1], + point_wits[3 * i + 
2], + gen_x_witness, + gen_y_witness, + zero_witness, + one, + ); + + all_skipped = Some(match all_skipped { + None => san.is_skip, + Some(prev) => compiler.add_product(prev, san.is_skip), + }); + + let (sanitized_rx, sanitized_ry) = emit_ec_scalar_mul_hint_and_sanitize( + compiler, + &san, + gen_x_witness, + gen_y_witness, + curve, + ); + + // On-curve checks + ec_points::verify_on_curve_native(compiler, san.px, san.py, curve); + ec_points::verify_on_curve_native(compiler, sanitized_rx, sanitized_ry, curve); + + // FakeGLV decomposition + signed-bit decomposition + let (s1, s2, neg1, neg2) = emit_fakeglv_hint(compiler, san.s_lo, san.s_hi, curve); + let half_bits = curve.glv_half_bits() as usize; + let (s1_bits, s1_skew) = decompose_signed_bits(compiler, s1, half_bits); + let (s2_bits, s2_skew) = decompose_signed_bits(compiler, s2, half_bits); + + // Y-negation + let (py_eff, neg_py_eff) = negate_y_signed_native(compiler, neg1, san.py); + let (ry_eff, neg_ry_eff) = negate_y_signed_native(compiler, neg2, sanitized_ry); + + native_points.push(NativePointData { + px: san.px, + py_eff, + neg_py_eff, + s1_bits, + s1_skew, + rx: sanitized_rx, + ry_eff, + neg_ry_eff, + s2_bits, + s2_skew, + }); + + scalar_rel_inputs.push((san.s_lo, san.s_hi, s1, s2, neg1, neg2)); + accum_inputs.push((sanitized_rx, sanitized_ry, san.is_skip)); + } + + // Phase 2: Merged scalar mul verification (shared doubling) + let half_bits = curve.glv_half_bits() as usize; + let offset_x_fe = curve::curve_native_point_fe(&curve.offset_point.0); + let offset_y_fe = curve::curve_native_point_fe(&curve.offset_point.1); + let offset_x = add_constant_witness(compiler, offset_x_fe); + let offset_y = add_constant_witness(compiler, offset_y_fe); + + let (ver_acc_x, ver_acc_y) = + scalar_mul_merged_native_wnaf(compiler, &native_points, offset_x, offset_y, curve); + + // Identity check: acc should equal accumulated offset (hardcoded into + // constraint matrix — not a witness the prover can manipulate) + let 
(acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(half_bits); + constrain_to_constant( + compiler, + ver_acc_x, + curve::curve_native_point_fe(&acc_off_x_raw), + ); + constrain_to_constant( + compiler, + ver_acc_y, + curve::curve_native_point_fe(&acc_off_y_raw), + ); + + // Phase 3: Per-point scalar relations + for &(s_lo, s_hi, s1, s2, neg1, neg2) in &scalar_rel_inputs { + scalar_relation::verify_scalar_relation( + compiler, + range_checks, + s_lo, + s_hi, + s1, + s2, + neg1, + neg2, + curve, + ); + } + + // Phase 4: Accumulation (same offset-based logic) + let all_skipped = all_skipped.expect("MSM must have at least one point"); + + let mut acc_x = add_constant_witness(compiler, offset_x_fe); + let mut acc_y = add_constant_witness(compiler, offset_y_fe); + + for &(sanitized_rx, sanitized_ry, is_skip) in &accum_inputs { + let (cand_x, cand_y) = ec_points::point_add_verified_native( + compiler, + acc_x, + acc_y, + sanitized_rx, + sanitized_ry, + curve, + ); + acc_x = select_witness(compiler, is_skip, cand_x, acc_x); + acc_y = select_witness(compiler, is_skip, cand_y, acc_y); + } + + // Offset subtraction and output constraining + let neg_offset_y_raw = + curve::negate_field_element(&curve.offset_point.1, &curve.field_modulus_p); + let neg_offset_y_fe = curve::curve_native_point_fe(&neg_offset_y_raw); + let neg_gen_y_raw = curve::negate_field_element(&curve.generator.1, &curve.field_modulus_p); + let neg_gen_y_fe = curve::curve_native_point_fe(&neg_gen_y_raw); + + let sub_x_off = add_constant_witness(compiler, offset_x_fe); + let sub_x = select_witness(compiler, all_skipped, sub_x_off, gen_x_witness); + + let neg_off_y_w = add_constant_witness(compiler, neg_offset_y_fe); + let neg_g_y_w = add_constant_witness(compiler, neg_gen_y_fe); + let sub_y = select_witness(compiler, all_skipped, neg_off_y_w, neg_g_y_w); + + let (result_x, result_y) = + ec_points::point_add_verified_native(compiler, acc_x, acc_y, sub_x, sub_y, curve); + + let masked_result_x = 
select_witness(compiler, all_skipped, result_x, zero_witness); + let masked_result_y = select_witness(compiler, all_skipped, result_y, zero_witness); + constrain_equal(compiler, out_x, masked_result_x); + constrain_equal(compiler, out_y, masked_result_y); + constrain_equal(compiler, out_inf, all_skipped); +} + +/// Merged multi-point scalar multiplication for native field using +/// signed-bit wNAF (w=1) with shared doubling across all points. +/// +/// Instead of running separate 128-iteration loops per point (each with +/// its own doubling), this merges all points into a single loop with one +/// shared doubling per bit. Each bit costs: +/// 4C (shared double) + n_points × 8C (2×(1C select + 3C add)) +/// +/// Savings: 4C × (n_points - 1) per bit ≈ 512C for 2 points on Grumpkin. +fn scalar_mul_merged_native_wnaf( + compiler: &mut NoirToR1CSCompiler, + points: &[NativePointData], + offset_x: usize, + offset_y: usize, + curve: &CurveParams, +) -> (usize, usize) { + let n = points[0].s1_bits.len(); + let mut acc_x = offset_x; + let mut acc_y = offset_y; + + // wNAF loop: MSB to LSB, shared doubling + for i in (0..n).rev() { + // Single shared double + let (dx, dy) = ec_points::point_double_verified_native(compiler, acc_x, acc_y, curve); + let mut cur_x = dx; + let mut cur_y = dy; + + // For each point: P branch + R branch + for pt in points { + let sel_py = select_witness(compiler, pt.s1_bits[i], pt.neg_py_eff, pt.py_eff); + let (ax, ay) = + ec_points::point_add_verified_native(compiler, cur_x, cur_y, pt.px, sel_py, curve); + + let sel_ry = select_witness(compiler, pt.s2_bits[i], pt.neg_ry_eff, pt.ry_eff); + (cur_x, cur_y) = + ec_points::point_add_verified_native(compiler, ax, ay, pt.rx, sel_ry, curve); + } + + acc_x = cur_x; + acc_y = cur_y; + } + + // Skew corrections for all points + for pt in points { + let (sub_px, sub_py) = ec_points::point_add_verified_native( + compiler, + acc_x, + acc_y, + pt.px, + pt.neg_py_eff, + curve, + ); + acc_x = 
select_witness(compiler, pt.s1_skew, acc_x, sub_px); + acc_y = select_witness(compiler, pt.s1_skew, acc_y, sub_py); + + let (sub_rx, sub_ry) = ec_points::point_add_verified_native( + compiler, + acc_x, + acc_y, + pt.rx, + pt.neg_ry_eff, + curve, + ); + acc_x = select_witness(compiler, pt.s2_skew, acc_x, sub_rx); + acc_y = select_witness(compiler, pt.s2_skew, acc_y, sub_ry); + } + + (acc_x, acc_y) +} + +/// Signed-bit decomposition for wNAF scalar multiplication. +/// +/// Decomposes `scalar` into `num_bits` sign-bits b_i ∈ {0,1} and a skew ∈ {0,1} +/// such that the signed digits d_i = 2*b_i - 1 ∈ {-1, +1} satisfy: +/// scalar = Σ d_i * 2^i - skew +/// +/// Reconstruction constraint (1 linear R1CS): +/// scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} +/// +/// All bits and skew are boolean-constrained. +pub(super) fn decompose_signed_bits( + compiler: &mut NoirToR1CSCompiler, + scalar: usize, + num_bits: usize, +) -> (Vec, usize) { + let start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SignedBitHint { + output_start: start, + scalar, + num_bits, + }); + let bits: Vec = (start..start + num_bits).collect(); + let skew = start + num_bits; + + // Boolean-constrain each bit and skew + for &b in &bits { + constrain_boolean(compiler, b); + } + constrain_boolean(compiler, skew); + + // Reconstruction: scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} + // Rearranged as: scalar + skew + (2^n - 1) - Σ b_i * 2^{i+1} = 0 + let one = compiler.witness_one(); + let constant = FieldElement::from(1u128 << num_bits) - FieldElement::ONE; + let mut b_terms: Vec<(FieldElement, usize)> = bits + .iter() + .enumerate() + .map(|(i, &b)| (-FieldElement::from(1u128 << (i + 1)), b)) + .collect(); + b_terms.push((FieldElement::ONE, scalar)); + b_terms.push((FieldElement::ONE, skew)); + b_terms.push((constant, one)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, one)], &b_terms, &[( + FieldElement::ZERO, + one, + )]); + + (bits, skew) +} diff --git 
a/provekit/r1cs-compiler/src/msm/non_native.rs b/provekit/r1cs-compiler/src/msm/non_native.rs new file mode 100644 index 000000000..9c03d72cc --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/non_native.rs @@ -0,0 +1,424 @@ +//! Non-native (generic multi-limb) MSM path. +//! +//! Used when `!curve.is_native_field()` — uses `MultiLimbOps` for all EC +//! arithmetic with configurable limb width. + +use { + super::{ + add_constant_witness, constrain_equal, constrain_to_constant, curve, ec_points, + emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, sanitize_point_scalar, + scalar_relation, select_witness, FieldOps, Limbs, + }, + crate::{ + digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field}, + curve::{decompose_to_limbs as decompose_to_limbs_pub, CurveParams}, + multi_limb_ops::{MultiLimbOps, MultiLimbParams}, + provekit_common::{ + witness::SumTerm, + FieldElement, + }, + std::collections::BTreeMap, + super::multi_limb_ops, +}; + +/// Build `MultiLimbParams` for a given runtime `num_limbs`. +pub(super) fn build_params( + num_limbs: usize, + limb_bits: u32, + curve: &CurveParams, +) -> MultiLimbParams { + let is_native = curve.is_native_field(); + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let modulus_fe = if !is_native { + Some(curve.p_native_fe()) + } else { + None + }; + MultiLimbParams { + num_limbs, + limb_bits, + p_limbs: curve.p_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.p_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.field_modulus_p, + curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), + is_native, + modulus_fe, + } +} + +/// FakeGLV verification for a single point: verifies R = \[s\]P. +/// +/// Decomposes s via half-GCD into sub-scalars (s1, s2) and verifies +/// \[s1\]P + \[s2\]R = O using interleaved windowed scalar mul with +/// half-width scalars. 
+/// +/// Returns the mutable references back to the caller for continued use. +pub(super) fn verify_point_fakeglv<'a>( + mut compiler: &'a mut NoirToR1CSCompiler, + mut range_checks: &'a mut BTreeMap>, + px: Limbs, + py: Limbs, + rx: Limbs, + ry: Limbs, + s_lo: usize, + s_hi: usize, + num_limbs: usize, + limb_bits: u32, + window_size: usize, + curve: &CurveParams, +) -> ( + &'a mut NoirToR1CSCompiler, + &'a mut BTreeMap>, +) { + // --- Steps 1-4: On-curve checks, FakeGLV decomposition, and GLV scalar mul + // --- + let (s1_witness, s2_witness, neg1_witness, neg2_witness); + { + let params = build_params(num_limbs, limb_bits, curve); + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + + // Step 1: On-curve checks for P and R + let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs); + verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); + verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); + + // Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 + (s1_witness, s2_witness, neg1_witness, neg2_witness) = + emit_fakeglv_hint(ops.compiler, s_lo, s_hi, curve); + + // Step 3: Decompose |s1|, |s2| into half_bits bits each + let half_bits = curve.glv_half_bits() as usize; + let s1_bits = decompose_half_scalar_bits(ops.compiler, s1_witness, half_bits); + let s2_bits = decompose_half_scalar_bits(ops.compiler, s2_witness, half_bits); + + // Step 4: Conditionally negate P.y and R.y + GLV scalar mul + identity + // check + + // Compute negated y-coordinates: neg_y = 0 - y (mod p) + let neg_py = ops.negate(py); + let neg_ry = ops.negate(ry); + + // Select: if neg1=1, use neg_py; else use py + // neg1 and neg2 are constrained to be boolean by ops.select internally. 
+ let py_effective = ops.select(neg1_witness, py, neg_py); + // Select: if neg2=1, use neg_ry; else use ry + let ry_effective = ops.select(neg2_witness, ry, neg_ry); + + // GLV scalar mul + let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); + let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); + let offset_x = ops.constant_limbs(&offset_x_values); + let offset_y = ops.constant_limbs(&offset_y_values); + + let glv_acc = ec_points::scalar_mul_glv( + &mut ops, + px, + py_effective, + &s1_bits, + rx, + ry_effective, + &s2_bits, + window_size, + offset_x, + offset_y, + ); + + // Identity check: acc should equal [2^(num_windows * window_size)] * + // offset_point + let glv_num_windows = (half_bits + window_size - 1) / window_size; + let glv_n_doublings = glv_num_windows * window_size; + let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); + + // Identity check: hardcode expected limb values as R1CS coefficients + let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs); + let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs); + for i in 0..num_limbs { + constrain_to_constant(ops.compiler, glv_acc.0[i], acc_off_x_values[i]); + constrain_to_constant(ops.compiler, glv_acc.1[i], acc_off_y_values[i]); + } + + compiler = ops.compiler; + range_checks = ops.range_checks; + } + + // --- Step 5: Scalar relation verification --- + scalar_relation::verify_scalar_relation( + compiler, + range_checks, + s_lo, + s_hi, + s1_witness, + s2_witness, + neg1_witness, + neg2_witness, + curve, + ); + + (compiler, range_checks) +} + +/// Multi-point non-native MSM with offset-based accumulation. 
+pub(super) fn process_multi_point_non_native<'a>( + mut compiler: &'a mut NoirToR1CSCompiler, + point_wits: &[usize], + scalar_wits: &[usize], + outputs: (usize, usize, usize), + n_points: usize, + num_limbs: usize, + limb_bits: u32, + window_size: usize, + mut range_checks: &'a mut BTreeMap>, + curve: &CurveParams, +) { + let (out_x, out_y, out_inf) = outputs; + let one = compiler.witness_one(); + + // Generator constants for sanitization + let gen_x_fe = curve::curve_native_point_fe(&curve.generator.0); + let gen_y_fe = curve::curve_native_point_fe(&curve.generator.1); + let gen_x_witness = add_constant_witness(compiler, gen_x_fe); + let gen_y_witness = add_constant_witness(compiler, gen_y_fe); + let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); + + // Build params once for all multi-limb ops in the multi-point path + let params = build_params(num_limbs, limb_bits, curve); + + // Offset point as limbs for accumulation + let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); + let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); + + // Start accumulator at offset_point + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + let mut acc_x = ops.constant_limbs(&offset_x_values); + let mut acc_y = ops.constant_limbs(&offset_y_values); + compiler = ops.compiler; + range_checks = ops.range_checks; + + // Track all_skipped = product of all is_skip flags + let mut all_skipped: Option = None; + + for i in 0..n_points { + let san = sanitize_point_scalar( + compiler, + point_wits[3 * i], + point_wits[3 * i + 1], + scalar_wits[2 * i], + scalar_wits[2 * i + 1], + point_wits[3 * i + 2], + gen_x_witness, + gen_y_witness, + zero_witness, + one, + ); + + // Track all_skipped + all_skipped = Some(match all_skipped { + None => san.is_skip, + Some(prev) => compiler.add_product(prev, san.is_skip), + }); + + let (sanitized_rx, sanitized_ry) = emit_ec_scalar_mul_hint_and_sanitize( + compiler, + &san, + 
gen_x_witness, + gen_y_witness, + curve, + ); + + // Generic multi-limb path + let (px, py) = + decompose_point_to_limbs(compiler, san.px, san.py, num_limbs, limb_bits, range_checks); + let (rx, ry) = decompose_point_to_limbs( + compiler, + sanitized_rx, + sanitized_ry, + num_limbs, + limb_bits, + range_checks, + ); + + // Verify R_i = [s_i]P_i using FakeGLV (on sanitized values) + (compiler, range_checks) = verify_point_fakeglv( + compiler, + range_checks, + px, + py, + rx, + ry, + san.s_lo, + san.s_hi, + num_limbs, + limb_bits, + window_size, + curve, + ); + + // Offset-based accumulation with conditional select + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + let (cand_x, cand_y) = ec_points::point_add(&mut ops, acc_x, acc_y, rx, ry); + let (new_acc_x, new_acc_y) = ec_points::point_select_unchecked( + &mut ops, + san.is_skip, + (cand_x, cand_y), + (acc_x, acc_y), + ); + acc_x = new_acc_x; + acc_y = new_acc_y; + compiler = ops.compiler; + range_checks = ops.range_checks; + } + + let all_skipped = all_skipped.expect("MSM must have at least one point"); + + // Generic multi-limb offset subtraction + let neg_offset_y_raw = + curve::negate_field_element(&curve.offset_point.1, &curve.field_modulus_p); + let neg_offset_y_values = curve::decompose_to_limbs(&neg_offset_y_raw, limb_bits, num_limbs); + + let gen_x_limb_values = curve.generator_x_limbs(limb_bits, num_limbs); + let neg_gen_y_raw = curve::negate_field_element(&curve.generator.1, &curve.field_modulus_p); + let neg_gen_y_values = curve::decompose_to_limbs(&neg_gen_y_raw, limb_bits, num_limbs); + + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + + let sub_x = { + let off_x = ops.constant_limbs(&offset_x_values); + let g_x = ops.constant_limbs(&gen_x_limb_values); + ops.select(all_skipped, off_x, g_x) + }; + let sub_y = { + let neg_off_y = ops.constant_limbs(&neg_offset_y_values); + let neg_g_y = ops.constant_limbs(&neg_gen_y_values); + 
ops.select(all_skipped, neg_off_y, neg_g_y) + }; + + let (result_x, result_y) = ec_points::point_add(&mut ops, acc_x, acc_y, sub_x, sub_y); + compiler = ops.compiler; + + if num_limbs == 1 { + let masked_result_x = select_witness(compiler, all_skipped, result_x[0], zero_witness); + let masked_result_y = select_witness(compiler, all_skipped, result_y[0], zero_witness); + constrain_equal(compiler, out_x, masked_result_x); + constrain_equal(compiler, out_y, masked_result_y); + } else { + let recomposed_x = recompose_limbs(compiler, result_x.as_slice(), limb_bits); + let recomposed_y = recompose_limbs(compiler, result_y.as_slice(), limb_bits); + let masked_result_x = select_witness(compiler, all_skipped, recomposed_x, zero_witness); + let masked_result_y = select_witness(compiler, all_skipped, recomposed_y, zero_witness); + constrain_equal(compiler, out_x, masked_result_x); + constrain_equal(compiler, out_y, masked_result_y); + } + constrain_equal(compiler, out_inf, all_skipped); +} + +/// On-curve check: verifies y^2 = x^3 + a*x + b for a single point. +fn verify_on_curve( + ops: &mut MultiLimbOps, + x: Limbs, + y: Limbs, + b_limb_values: &[FieldElement], + num_limbs: usize, +) { + let y_sq = ops.mul(y, y); + let x_sq = ops.mul(x, x); + let x_cubed = ops.mul(x_sq, x); + let a = ops.curve_a(); + let ax = ops.mul(a, x); + let x3_plus_ax = ops.add(x_cubed, ax); + let b = ops.constant_limbs(b_limb_values); + let rhs = ops.add(x3_plus_ax, b); + for i in 0..num_limbs { + constrain_equal(ops.compiler, y_sq[i], rhs[i]); + } +} + +/// Decompose a point (px_witness, py_witness) into Limbs. 
+pub(super) fn decompose_point_to_limbs( + compiler: &mut NoirToR1CSCompiler, + px_witness: usize, + py_witness: usize, + num_limbs: usize, + limb_bits: u32, + range_checks: &mut BTreeMap>, +) -> (Limbs, Limbs) { + if num_limbs == 1 { + (Limbs::single(px_witness), Limbs::single(py_witness)) + } else { + let px_limbs = + decompose_witness_to_limbs(compiler, px_witness, limb_bits, num_limbs, range_checks); + let py_limbs = + decompose_witness_to_limbs(compiler, py_witness, limb_bits, num_limbs, range_checks); + (px_limbs, py_limbs) + } +} + +/// Decompose a single witness into `num_limbs` limbs using digital +/// decomposition. +fn decompose_witness_to_limbs( + compiler: &mut NoirToR1CSCompiler, + witness: usize, + limb_bits: u32, + num_limbs: usize, + range_checks: &mut BTreeMap>, +) -> Limbs { + let log_bases = vec![limb_bits as usize; num_limbs]; + let dd = add_digital_decomposition(compiler, log_bases, vec![witness]); + let mut limbs = Limbs::new(num_limbs); + for i in 0..num_limbs { + limbs[i] = dd.get_digit_witness_index(i, 0); + // Range-check each decomposed limb to [0, 2^limb_bits). + // add_digital_decomposition constrains the recomposition but does + // NOT range-check individual digits. + range_checks.entry(limb_bits).or_default().push(limbs[i]); + } + limbs +} + +/// Recompose limbs back into a single witness: val = Σ limb\[i\] * +/// 2^(i*limb_bits) +fn recompose_limbs(compiler: &mut NoirToR1CSCompiler, limbs: &[usize], limb_bits: u32) -> usize { + let terms: Vec = limbs + .iter() + .enumerate() + .map(|(i, &limb)| { + let coeff = FieldElement::from(2u64).pow([(i as u64) * (limb_bits as u64)]); + SumTerm(Some(coeff), limb) + }) + .collect(); + compiler.add_sum(terms) +} + +/// Decomposes a half-scalar witness into `half_bits` bit witnesses (LSB first). 
+fn decompose_half_scalar_bits(
+    compiler: &mut NoirToR1CSCompiler,
+    scalar: usize,
+    half_bits: usize,
+) -> Vec<usize> {
+    let log_bases = vec![1usize; half_bits];
+    let dd = add_digital_decomposition(compiler, log_bases, vec![scalar]);
+    let mut bits = Vec::with_capacity(half_bits);
+    for bit_idx in 0..half_bits {
+        bits.push(dd.get_digit_witness_index(bit_idx, 0));
+    }
+    bits
+}
diff --git a/provekit/r1cs-compiler/src/msm/scalar_relation.rs b/provekit/r1cs-compiler/src/msm/scalar_relation.rs
new file mode 100644
index 000000000..ff307abf3
--- /dev/null
+++ b/provekit/r1cs-compiler/src/msm/scalar_relation.rs
@@ -0,0 +1,226 @@
+//! Scalar relation verification: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0
+//! (mod n).
+//!
+//! Shared by both the native and non-native MSM paths.
+
+use {
+    super::{
+        constrain_zero, cost_model, curve,
+        multi_limb_ops::{MultiLimbOps, MultiLimbParams},
+        FieldOps, Limbs,
+    },
+    crate::{
+        digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder},
+        noir_to_r1cs::NoirToR1CSCompiler,
+    },
+    ark_ff::{Field, PrimeField},
+    curve::CurveParams,
+    provekit_common::{
+        witness::{ConstantTerm, SumTerm, WitnessBuilder},
+        FieldElement,
+    },
+    std::collections::BTreeMap,
+};
+
+/// The 256-bit scalar is split into two 128-bit halves (s_lo, s_hi) because the
+/// full value doesn't fit in the native field.
+const SCALAR_HALF_BITS: usize = 128;
+
+/// Compute digit widths for decomposing `total_bits` into chunks of at most
+/// `max_width` bits. The last chunk may be smaller.
+fn limb_widths(total_bits: usize, max_width: u32) -> Vec<usize> {
+    let n = (total_bits + max_width as usize - 1) / max_width as usize;
+    (0..n)
+        .map(|i| {
+            let remaining = total_bits - i * max_width as usize;
+            remaining.min(max_width as usize)
+        })
+        .collect()
+}
+
+/// Builds `MultiLimbParams` for scalar relation verification (mod
+/// curve_order_n).
+fn build_scalar_relation_params( + num_limbs: usize, + limb_bits: u32, + curve: &CurveParams, +) -> MultiLimbParams { + // Scalar relation uses curve_order_n as the modulus. + // This is always non-native (curve_order_n ≠ BN254 scalar field modulus, + // except for Grumpkin where they're very close but still different). + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let n_limbs = curve.curve_order_n_limbs(limb_bits, num_limbs); + let n_minus_1_limbs = curve.curve_order_n_minus_1_limbs(limb_bits, num_limbs); + + // For N=1 non-native, we need the modulus as a FieldElement + let modulus_fe = if num_limbs == 1 { + Some(curve::curve_native_point_fe(&curve.curve_order_n)) + } else { + None + }; + + MultiLimbParams { + num_limbs, + limb_bits, + p_limbs: n_limbs, + p_minus_1_limbs: n_minus_1_limbs, + two_pow_w, + modulus_raw: curve.curve_order_n, + curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused + is_native: false, // always non-native + modulus_fe, + } +} + +/// Verifies the scalar relation: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 +/// (mod n). +/// +/// Uses multi-limb arithmetic with curve_order_n as the modulus. 
+pub(super) fn verify_scalar_relation(
+    compiler: &mut NoirToR1CSCompiler,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+    s_lo: usize,
+    s_hi: usize,
+    s1_witness: usize,
+    s2_witness: usize,
+    neg1_witness: usize,
+    neg2_witness: usize,
+    curve: &CurveParams,
+) {
+    let order_bits = curve.curve_order_bits() as usize;
+    let limb_bits =
+        cost_model::scalar_relation_limb_bits(FieldElement::MODULUS_BIT_SIZE, order_bits);
+    let num_limbs = (order_bits + limb_bits as usize - 1) / limb_bits as usize;
+    let half_bits = curve.glv_half_bits() as usize;
+
+    let params = build_scalar_relation_params(num_limbs, limb_bits, curve);
+    let mut ops = MultiLimbOps {
+        compiler,
+        range_checks,
+        params: &params,
+    };
+
+    let s_limbs = decompose_scalar_from_halves(&mut ops, s_lo, s_hi, num_limbs, limb_bits);
+    let s1_limbs = decompose_half_scalar(&mut ops, s1_witness, num_limbs, half_bits, limb_bits);
+    let s2_limbs = decompose_half_scalar(&mut ops, s2_witness, num_limbs, half_bits, limb_bits);
+
+    let product = ops.mul(s2_limbs, s_limbs);
+
+    // Sign handling: when signs match check s1+product=0, otherwise s1-product=0.
+    // XOR = neg1 + neg2 - 2*neg1*neg2 gives 0 for same signs, 1 for different.
+    let sum = ops.add(s1_limbs, product);
+    let diff = ops.sub(s1_limbs, product);
+
+    let xor_prod = ops.compiler.add_product(neg1_witness, neg2_witness);
+    let xor = ops.compiler.add_sum(vec![
+        SumTerm(None, neg1_witness),
+        SumTerm(None, neg2_witness),
+        SumTerm(Some(-FieldElement::from(2u64)), xor_prod),
+    ]);
+
+    let effective = ops.select_unchecked(xor, sum, diff);
+    for i in 0..num_limbs {
+        constrain_zero(ops.compiler, effective[i]);
+    }
+}
+
+/// Decompose a 256-bit scalar from two 128-bit halves into `num_limbs` limbs.
+///
+/// When `limb_bits` divides 128 (e.g. 64), limb boundaries align with the
+/// s_lo/s_hi split. Otherwise (e.g. 85-bit limbs), one limb straddles bit 128
+/// and is assembled from a partial s_lo digit and a partial s_hi digit.
+fn decompose_scalar_from_halves( + ops: &mut MultiLimbOps, + s_lo: usize, + s_hi: usize, + num_limbs: usize, + limb_bits: u32, +) -> Limbs { + let lo_tail = SCALAR_HALF_BITS % limb_bits as usize; + + if lo_tail == 0 { + let widths = limb_widths(SCALAR_HALF_BITS, limb_bits); + let dd_lo = add_digital_decomposition(ops.compiler, widths.clone(), vec![s_lo]); + let dd_hi = add_digital_decomposition(ops.compiler, widths.clone(), vec![s_hi]); + let mut limbs = Limbs::new(num_limbs); + let from_lo = widths.len().min(num_limbs); + for (i, &w) in widths.iter().enumerate().take(from_lo) { + limbs[i] = dd_lo.get_digit_witness_index(i, 0); + ops.range_checks.entry(w as u32).or_default().push(limbs[i]); + } + for (i, &w) in widths.iter().enumerate().take(num_limbs - from_lo) { + limbs[from_lo + i] = dd_hi.get_digit_witness_index(i, 0); + ops.range_checks.entry(w as u32).or_default().push(limbs[from_lo + i]); + } + limbs + } else { + // Example: 85-bit limbs, 254-bit order → + // s_lo DD [85, 43], s_hi DD [42, 86] + // L0 = s_lo[0..85), L1 = s_lo[85..128) | s_hi[0..42), L2 = s_hi[42..128) + let hi_head = limb_bits as usize - lo_tail; + let hi_rest = SCALAR_HALF_BITS - hi_head; + let lo_full = SCALAR_HALF_BITS / limb_bits as usize; + + let lo_widths = limb_widths(SCALAR_HALF_BITS, limb_bits); + let hi_widths = vec![hi_head, hi_rest]; + + let dd_lo = add_digital_decomposition(ops.compiler, lo_widths, vec![s_lo]); + let dd_hi = add_digital_decomposition(ops.compiler, hi_widths, vec![s_hi]); + let mut limbs = Limbs::new(num_limbs); + + for i in 0..lo_full { + limbs[i] = dd_lo.get_digit_witness_index(i, 0); + ops.range_checks.entry(limb_bits).or_default().push(limbs[i]); + } + + // Cross-boundary limb: lo_tail bits from s_lo + hi_head bits from s_hi + let shift = FieldElement::from(2u64).pow([lo_tail as u64]); + let lo_digit = dd_lo.get_digit_witness_index(lo_full, 0); + let hi_digit = dd_hi.get_digit_witness_index(0, 0); + limbs[lo_full] = ops.compiler.add_sum(vec![ + SumTerm(None, 
lo_digit), + SumTerm(Some(shift), hi_digit), + ]); + ops.range_checks.entry(lo_tail as u32).or_default().push(lo_digit); + ops.range_checks.entry(hi_head as u32).or_default().push(hi_digit); + + if hi_rest > 0 { + limbs[lo_full + 1] = dd_hi.get_digit_witness_index(1, 0); + ops.range_checks.entry(hi_rest as u32).or_default().push(limbs[lo_full + 1]); + } + + limbs + } +} + +/// Decompose a half-scalar witness into `num_limbs` limbs, zero-padding the +/// upper limbs beyond `half_bits`. +fn decompose_half_scalar( + ops: &mut MultiLimbOps, + witness: usize, + num_limbs: usize, + half_bits: usize, + limb_bits: u32, +) -> Limbs { + let widths = limb_widths(half_bits, limb_bits); + let dd = add_digital_decomposition(ops.compiler, widths.clone(), vec![witness]); + let mut limbs = Limbs::new(num_limbs); + + for (i, &w) in widths.iter().enumerate() { + limbs[i] = dd.get_digit_witness_index(i, 0); + ops.range_checks.entry(w as u32).or_default().push(limbs[i]); + } + + for i in widths.len()..num_limbs { + let w = ops.compiler.num_witnesses(); + ops.compiler + .add_witness_builder(WitnessBuilder::Constant(ConstantTerm( + w, + FieldElement::from(0u64), + ))); + limbs[i] = w; + constrain_zero(ops.compiler, limbs[i]); + } + + limbs +} diff --git a/provekit/r1cs-compiler/src/range_check.rs b/provekit/r1cs-compiler/src/range_check.rs index f76fe94c3..936a33240 100644 --- a/provekit/r1cs-compiler/src/range_check.rs +++ b/provekit/r1cs-compiler/src/range_check.rs @@ -139,6 +139,39 @@ fn get_optimal_base_width(collected: &[RangeCheckRequest]) -> u32 { optimal_width } +/// Estimates total witness cost for resolving range checks without +/// constructing actual R1CS constraints. +/// +/// Takes a map of `bit_width → count` (number of witnesses needing that +/// range check). Uses the same optimal-base-width search and +/// LogUp-vs-naive cost model as [`add_range_checks`], but operates on +/// aggregate counts rather than concrete witness indices. 
+pub(crate) fn estimate_range_check_cost(checks: &BTreeMap<u32, usize>) -> usize {
+    if checks.is_empty() {
+        return 0;
+    }
+
+    // Create synthetic RangeCheckRequests with unique dummy indices.
+    let mut collected: Vec<RangeCheckRequest> = Vec::new();
+    let mut dummy_idx = 0usize;
+    for (&bits, &count) in checks {
+        for _ in 0..count {
+            collected.push(RangeCheckRequest {
+                witness_idx: dummy_idx,
+                bits,
+            });
+            dummy_idx += 1;
+        }
+    }
+
+    if collected.is_empty() {
+        return 0;
+    }
+
+    let base_width = get_optimal_base_width(&collected);
+    calculate_witness_cost(base_width, &collected)
+}
+
 /// Add witnesses and constraints that ensure that the values of the witness
 /// belong to a range 0..2^k (for some k).
 ///

From 5c5f1f1c0a97491a77dcbabcff5bf1c9ec249cbc Mon Sep 17 00:00:00 2001
From: ocdbytes
Date: Wed, 11 Mar 2026 06:56:30 +0530
Subject: [PATCH 12/19] fix : lint

---
 .../src/witness/scheduling/dependency.rs      |  4 +-
 provekit/prover/src/lib.rs                    | 61 -------------------
 .../prover/src/witness/witness_builder.rs     |  2 +-
 provekit/r1cs-compiler/src/msm/cost_model.rs  | 48 +++++++--------
 provekit/r1cs-compiler/src/msm/ec_points.rs   | 40 ++++++------
 provekit/r1cs-compiler/src/msm/mod.rs         |  7 ++-
 provekit/r1cs-compiler/src/msm/native.rs      |  6 +-
 provekit/r1cs-compiler/src/msm/non_native.rs  | 10 +--
 .../r1cs-compiler/src/msm/scalar_relation.rs  | 25 ++++++--
 9 files changed, 77 insertions(+), 126 deletions(-)

diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs
index ba63359c1..6f06268ea 100644
--- a/provekit/common/src/witness/scheduling/dependency.rs
+++ b/provekit/common/src/witness/scheduling/dependency.rs
@@ -224,9 +224,7 @@ impl DependencyInfo {
                 ]
             }
             WitnessBuilder::EcDoubleHint { px, py, .. } => vec![*px, *py],
-            WitnessBuilder::EcAddHint {
-                x1, y1, x2, y2, ..
-            } => vec![*x1, *y1, *x2, *y2],
+            WitnessBuilder::EcAddHint { x1, y1, x2, y2, .. } => vec![*x1, *y1, *x2, *y2],
             WitnessBuilder::FakeGLVHint {
                 s_lo,
                 s_hi,
} => vec![*s_lo, *s_hi], WitnessBuilder::EcScalarMulHint { px, py, s_lo, s_hi, .. diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 0fa9133ad..d87e636e7 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -197,67 +197,6 @@ impl Prove for NoirProver { .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) .collect::>>()?; - // DEBUG: Check R1CS constraint satisfaction with ALL witnesses solved - { - use ark_ff::Zero; - let debug_r1cs = r1cs.clone(); - let interner = &debug_r1cs.interner; - let ha = debug_r1cs.a.hydrate(interner); - let hb = debug_r1cs.b.hydrate(interner); - let hc = debug_r1cs.c.hydrate(interner); - let mut fail_count = 0usize; - for row in 0..debug_r1cs.num_constraints() { - let eval = |hm: &provekit_common::sparse_matrix::HydratedSparseMatrix, - r: usize| - -> FieldElement { - let mut sum = FieldElement::zero(); - for (col, coeff) in hm.iter_row(r) { - sum += coeff * full_witness[col]; - } - sum - }; - let a_val = eval(&ha, row); - let b_val = eval(&hb, row); - let c_val = eval(&hc, row); - if a_val * b_val != c_val { - if fail_count < 10 { - eprintln!( - "CONSTRAINT {} FAILED: A={:?} B={:?} C={:?} A*B={:?}", - row, - a_val, - b_val, - c_val, - a_val * b_val - ); - eprint!(" A terms:"); - for (col, coeff) in ha.iter_row(row) { - eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); - } - eprintln!(); - eprint!(" B terms:"); - for (col, coeff) in hb.iter_row(row) { - eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); - } - eprintln!(); - eprint!(" C terms:"); - for (col, coeff) in hc.iter_row(row) { - eprint!(" w{}={:?}*{:?}", col, full_witness[col], coeff); - } - eprintln!(); - } - fail_count += 1; - } - } - if fail_count > 0 { - eprintln!( - "TOTAL FAILING CONSTRAINTS: {fail_count} / {}", - debug_r1cs.num_constraints() - ); - } else { - eprintln!("ALL {} CONSTRAINTS SATISFIED", debug_r1cs.num_constraints()); - } - } - let whir_r1cs_proof = self 
.whir_for_witness .prove_noir(merlin, r1cs, commitments, full_witness, &public_inputs) diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index 9d98e7cb6..9cc8d8335 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -507,7 +507,7 @@ impl WitnessBuilderSolver for WitnessBuilder { let x2_val = witness[*x2].unwrap().into_bigint().0; let y2_val = witness[*y2].unwrap().into_bigint().0; - use crate::bigint_mod::{mod_inverse, mod_sub, mul_mod, mod_add}; + use crate::bigint_mod::{mod_add, mod_inverse, mod_sub, mul_mod}; let numerator = mod_sub(&y2_val, &y1_val, field_modulus_p); let denominator = mod_sub(&x2_val, &x1_val, field_modulus_p); let denom_inv = mod_inverse(&denominator, field_modulus_p); diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index 79ce03bfb..5992953e8 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -7,8 +7,8 @@ use std::collections::BTreeMap; /// The 256-bit scalar is split into two halves (s_lo, s_hi) because it doesn't -/// fit in the native field. This constant is used throughout the scalar relation -/// cost model. +/// fit in the native field. This constant is used throughout the scalar +/// relation cost model. const SCALAR_HALF_BITS: usize = 128; /// Type of field operation for cost estimation. @@ -30,10 +30,9 @@ pub enum FieldOpType { /// `num_limbs` witnesses per coordinate (via `select_witness`), not /// multi-limb field op witnesses. /// -/// - `n_point_selects`: selects on EcPoint (2 coordinates), from table -/// lookups and conditional skip after point_add. -/// - `n_coord_selects`: selects on single Limbs coordinate, from -/// y-negation. +/// - `n_point_selects`: selects on EcPoint (2 coordinates), from table lookups +/// and conditional skip after point_add. 
+/// - `n_coord_selects`: selects on single Limbs coordinate, from y-negation. /// - `n_is_zero`: `compute_is_zero` calls, each creating exactly 3 native /// witnesses regardless of num_limbs. fn count_glv_field_ops( @@ -121,8 +120,7 @@ fn count_glv_real_field_ops( scalar_bits: usize, window_size: usize, ) -> (usize, usize, usize, usize) { - let (n_add, n_sub, n_mul, n_inv, _, _, _) = - count_glv_field_ops(scalar_bits, window_size); + let (n_add, n_sub, n_mul, n_inv, ..) = count_glv_field_ops(scalar_bits, window_size); (n_add, n_sub, n_mul, n_inv) } @@ -165,8 +163,7 @@ fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize fn count_scalar_relation_witnesses(native_field_bits: u32, scalar_bits: usize) -> usize { let limb_bits = scalar_relation_limb_bits(native_field_bits, scalar_bits); let num_limbs = (scalar_bits + limb_bits as usize - 1) / limb_bits as usize; - let scalar_half_limbs = - (SCALAR_HALF_BITS + limb_bits as usize - 1) / limb_bits as usize; + let scalar_half_limbs = (SCALAR_HALF_BITS + limb_bits as usize - 1) / limb_bits as usize; let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, false); let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, false); @@ -174,8 +171,7 @@ fn count_scalar_relation_witnesses(native_field_bits: u32, scalar_bits: usize) - // Scalar decomposition: DD digits for s_lo + s_hi, plus cross-boundary // witness when limb boundaries don't align with the 128-bit split - let has_cross_boundary = - num_limbs > 1 && SCALAR_HALF_BITS % limb_bits as usize != 0; + let has_cross_boundary = num_limbs > 1 && SCALAR_HALF_BITS % limb_bits as usize != 0; let scalar_decomp = 2 * scalar_half_limbs + has_cross_boundary as usize; // Half-scalar decomposition: DD digits + zero-pad constants for s1, s2 @@ -240,8 +236,7 @@ fn count_scalar_relation_range_checks( let num_limbs = (scalar_bits + limb_bits as usize - 1) / limb_bits as usize; let half_bits = (scalar_bits + 1) / 2; let half_limbs = (half_bits + limb_bits 
as usize - 1) / limb_bits as usize; - let scalar_half_limbs = - (SCALAR_HALF_BITS + limb_bits as usize - 1) / limb_bits as usize; + let scalar_half_limbs = (SCALAR_HALF_BITS + limb_bits as usize - 1) / limb_bits as usize; let mut rc_map: BTreeMap = BTreeMap::new(); @@ -264,8 +259,7 @@ fn count_scalar_relation_range_checks( /// Accounts for three categories of witnesses: /// 1. **Inline witnesses** — field ops, selects, is_zero, hints, DDs /// 2. **Range check resolution** — LogUp/naive cost for all range checks -/// 3. **Per-point overhead** — detect_skip, sanitization, point -/// decomposition +/// 3. **Per-point overhead** — detect_skip, sanitization, point decomposition pub fn calculate_msm_witness_cost( native_field_bits: u32, curve_modulus_bits: u32, @@ -298,14 +292,12 @@ pub fn calculate_msm_witness_cost( count_glv_field_ops(half_bits, window_size); // Field ops: priced at full multi-limb cost - let field_op_cost = - n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; + let field_op_cost = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; // Selects: each select_witness creates 1 witness per limb (inlined). // Point select = 2 coords × num_limbs × 1. // Coord select = 1 coord × num_limbs × 1. - let select_cost = - n_point_selects * 2 * num_limbs + n_coord_selects * num_limbs; + let select_cost = n_point_selects * 2 * num_limbs + n_coord_selects * num_limbs; // is_zero: 3 fixed native witnesses each (SafeInverse + Product + Sum) let is_zero_cost = n_is_zero * 3; @@ -372,8 +364,7 @@ pub fn calculate_msm_witness_cost( let mut rc_map: BTreeMap = BTreeMap::new(); // 1. 
Range checks from GLV field ops (selects generate 0 range checks) - let (rc_n_add, rc_n_sub, rc_n_mul, rc_n_inv) = - count_glv_real_field_ops(half_bits, window_size); + let (rc_n_add, rc_n_sub, rc_n_mul, rc_n_inv) = count_glv_real_field_ops(half_bits, window_size); for &(op, n_ops) in &[ (FieldOpType::Add, rc_n_add), (FieldOpType::Sub, rc_n_sub), @@ -399,8 +390,8 @@ pub fn calculate_msm_witness_cost( *rc_map.entry(*bits).or_default() += n_points * count; } - // 4. Accumulation range checks: n_points point_adds + 1 offset - // subtraction point_add (multi-point only) + // 4. Accumulation range checks: n_points point_adds + 1 offset subtraction + // point_add (multi-point only) if n_points > 1 { let accum_point_adds = n_points + 1; // loop + offset subtraction for &(op, n_ops) in &[ @@ -706,13 +697,16 @@ mod tests { #[test] fn test_scalar_relation_witnesses_small_curve() { let sr = count_scalar_relation_witnesses(254, 64); - assert!(sr < 100, "64-bit curve scalar_relation={sr} should be < 100"); + assert!( + sr < 100, + "64-bit curve scalar_relation={sr} should be < 100" + ); } #[test] fn test_is_zero_cost_independent_of_num_limbs() { - let (_, _, _, _, n_is_zero_w4, _, _) = count_glv_field_ops(128, 4); - let (_, _, _, _, n_is_zero_w3, _, _) = count_glv_field_ops(128, 3); + let (_, _, _, _, n_is_zero_w4, ..) = count_glv_field_ops(128, 4); + let (_, _, _, _, n_is_zero_w3, ..) = count_glv_field_ops(128, 3); assert!(n_is_zero_w4 > 0); assert!(n_is_zero_w3 > 0); } diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index b5348a4c8..975963ed6 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -265,8 +265,10 @@ pub fn scalar_mul_glv( // Each EC op allocates a hint for (lambda, x3, y3) and verifies via raw // R1CS constraints, eliminating expensive field inversions from the circuit. 
-use super::curve::CurveParams; -use ark_ff::{Field, PrimeField}; +use { + super::curve::CurveParams, + ark_ff::{Field, PrimeField}, +}; /// Hint-verified point doubling for native field. /// @@ -287,10 +289,10 @@ pub fn point_double_verified_native( // Allocate hint witnesses let hint_start = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::EcDoubleHint { - output_start: hint_start, + output_start: hint_start, px, py, - curve_a: curve.curve_a, + curve_a: curve.curve_a, field_modulus_p: curve.field_modulus_p, }); let lambda = hint_start; @@ -305,11 +307,12 @@ pub fn point_double_verified_native( let a_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_a)).unwrap(); let three = FieldElement::from(3u64); let two = FieldElement::from(2u64); - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, lambda)], - &[(two, py)], - &[(three, x_sq), (a_fe, compiler.witness_one())], - ); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, lambda)], &[(two, py)], &[ + (three, x_sq), + (a_fe, compiler.witness_one()), + ]); // Constraint: lambda^2 = x3 + 2*px compiler.r1cs.add_constraint( @@ -348,7 +351,7 @@ pub fn point_add_verified_native( // Allocate hint witnesses let hint_start = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::EcAddHint { - output_start: hint_start, + output_start: hint_start, x1, y1, x2, @@ -370,7 +373,11 @@ pub fn point_add_verified_native( compiler.r1cs.add_constraint( &[(FieldElement::ONE, lambda)], &[(FieldElement::ONE, lambda)], - &[(FieldElement::ONE, x3), (FieldElement::ONE, x1), (FieldElement::ONE, x2)], + &[ + (FieldElement::ONE, x3), + (FieldElement::ONE, x1), + (FieldElement::ONE, x2), + ], ); // Constraint: lambda * (x1 - x3) = y3 + y1 @@ -404,14 +411,11 @@ pub fn verify_on_curve_native( let b_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_b)).unwrap(); // y * y = x_cu + a*x + b - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, y)], - &[(FieldElement::ONE, y)], - &[ + 
compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, y)], &[(FieldElement::ONE, y)], &[ (FieldElement::ONE, x_cu), (a_fe, x), (b_fe, compiler.witness_one()), - ], - ); + ]); } - diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index a953d9774..e55e717d1 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -546,7 +546,12 @@ fn process_single_point_msm<'a>( } else { // Generic multi-limb path let (px, py) = non_native::decompose_point_to_limbs( - compiler, san.px, san.py, num_limbs, limb_bits, range_checks, + compiler, + san.px, + san.py, + num_limbs, + limb_bits, + range_checks, ); let (rx, ry) = non_native::decompose_point_to_limbs( compiler, diff --git a/provekit/r1cs-compiler/src/msm/native.rs b/provekit/r1cs-compiler/src/msm/native.rs index b8e2f7b52..fb97884f4 100644 --- a/provekit/r1cs-compiler/src/msm/native.rs +++ b/provekit/r1cs-compiler/src/msm/native.rs @@ -6,11 +6,11 @@ use { super::{ add_constant_witness, constrain_boolean, constrain_equal, constrain_to_constant, curve, - ec_points, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, - negate_y_signed_native, sanitize_point_scalar, scalar_relation, select_witness, + ec_points, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, negate_y_signed_native, + sanitize_point_scalar, scalar_relation, select_witness, }, - ark_ff::{AdditiveGroup, Field}, crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::{AdditiveGroup, Field}, curve::CurveParams, provekit_common::{witness::WitnessBuilder, FieldElement}, std::collections::BTreeMap, diff --git a/provekit/r1cs-compiler/src/msm/non_native.rs b/provekit/r1cs-compiler/src/msm/non_native.rs index 9c03d72cc..4cb186a49 100644 --- a/provekit/r1cs-compiler/src/msm/non_native.rs +++ b/provekit/r1cs-compiler/src/msm/non_native.rs @@ -6,8 +6,8 @@ use { super::{ add_constant_witness, constrain_equal, constrain_to_constant, curve, ec_points, - 
emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, sanitize_point_scalar, - scalar_relation, select_witness, FieldOps, Limbs, + emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, multi_limb_ops, + sanitize_point_scalar, scalar_relation, select_witness, FieldOps, Limbs, }, crate::{ digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, @@ -16,12 +16,8 @@ use { ark_ff::{AdditiveGroup, Field}, curve::{decompose_to_limbs as decompose_to_limbs_pub, CurveParams}, multi_limb_ops::{MultiLimbOps, MultiLimbParams}, - provekit_common::{ - witness::SumTerm, - FieldElement, - }, + provekit_common::{witness::SumTerm, FieldElement}, std::collections::BTreeMap, - super::multi_limb_ops, }; /// Build `MultiLimbParams` for a given runtime `num_limbs`. diff --git a/provekit/r1cs-compiler/src/msm/scalar_relation.rs b/provekit/r1cs-compiler/src/msm/scalar_relation.rs index ff307abf3..b437c309e 100644 --- a/provekit/r1cs-compiler/src/msm/scalar_relation.rs +++ b/provekit/r1cs-compiler/src/msm/scalar_relation.rs @@ -150,7 +150,10 @@ fn decompose_scalar_from_halves( } for (i, &w) in widths.iter().enumerate().take(num_limbs - from_lo) { limbs[from_lo + i] = dd_hi.get_digit_witness_index(i, 0); - ops.range_checks.entry(w as u32).or_default().push(limbs[from_lo + i]); + ops.range_checks + .entry(w as u32) + .or_default() + .push(limbs[from_lo + i]); } limbs } else { @@ -170,7 +173,10 @@ fn decompose_scalar_from_halves( for i in 0..lo_full { limbs[i] = dd_lo.get_digit_witness_index(i, 0); - ops.range_checks.entry(limb_bits).or_default().push(limbs[i]); + ops.range_checks + .entry(limb_bits) + .or_default() + .push(limbs[i]); } // Cross-boundary limb: lo_tail bits from s_lo + hi_head bits from s_hi @@ -181,12 +187,21 @@ fn decompose_scalar_from_halves( SumTerm(None, lo_digit), SumTerm(Some(shift), hi_digit), ]); - ops.range_checks.entry(lo_tail as u32).or_default().push(lo_digit); - ops.range_checks.entry(hi_head as u32).or_default().push(hi_digit); + 
ops.range_checks + .entry(lo_tail as u32) + .or_default() + .push(lo_digit); + ops.range_checks + .entry(hi_head as u32) + .or_default() + .push(hi_digit); if hi_rest > 0 { limbs[lo_full + 1] = dd_hi.get_digit_witness_index(1, 0); - ops.range_checks.entry(hi_rest as u32).or_default().push(limbs[lo_full + 1]); + ops.range_checks + .entry(hi_rest as u32) + .or_default() + .push(limbs[lo_full + 1]); } limbs From ad1f6904586eddf58ff2d56582db4db18adc3f8b Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 11 Mar 2026 07:08:19 +0530 Subject: [PATCH 13/19] feat : added wNAF implementation in native msm noir circuit --- noir-examples/native_msm/Nargo.toml | 1 - noir-examples/native_msm/src/main.nr | 405 ++++++++++++++++++++++----- 2 files changed, 329 insertions(+), 77 deletions(-) diff --git a/noir-examples/native_msm/Nargo.toml b/noir-examples/native_msm/Nargo.toml index 5ff116db7..6b16fd3ae 100644 --- a/noir-examples/native_msm/Nargo.toml +++ b/noir-examples/native_msm/Nargo.toml @@ -2,6 +2,5 @@ name = "native_msm" type = "bin" authors = [""] -compiler_version = ">=0.22.0" [dependencies] diff --git a/noir-examples/native_msm/src/main.nr b/noir-examples/native_msm/src/main.nr index 80cfd3d0f..901722a4e 100644 --- a/noir-examples/native_msm/src/main.nr +++ b/noir-examples/native_msm/src/main.nr @@ -1,104 +1,357 @@ -// Grumpkin generator y-coordinate global GRUMPKIN_GEN_Y: Field = 17631683881184975370165255887551781615748388533673675138860; -struct Point { - x: Field, - y: Field, - is_infinite: bool, +// Hardcoded offset generators: offset = 5*G, offset_final = 2^252 * 5*G +// These are compile-time constants -- no unconstrained computation, no runtime verification needed. 
+global OFFSET_X: Field = 12229279139087521908560794489267966517139449915173592433539394009359081620359; +global OFFSET_Y: Field = 12096995292699515952722386974733884667125946823386040531322131902193094989869; +global OFFSET_FINAL_X: Field = 17097678145015848904467691187715743297134903912023447344174597163323183228319; +global OFFSET_FINAL_Y: Field = 14560299638432262069836824301755786319891239433999125203465365199349384123743; + +// BN254 scalar field modulus as wNAF slices (MSB first). +// Used for lexicographic range check: ensures wNAF integer < p, preventing mod-p ambiguity. +global MODULUS_SLICES: [u8; 64] = [ + 9, 8, 3, 2, 2, 7, 3, 9, 7, 0, 9, 8, 13, 0, 1, 4, + 13, 12, 2, 8, 2, 2, 13, 11, 4, 0, 12, 0, 10, 12, 2, 14, + 9, 4, 1, 9, 15, 4, 2, 4, 3, 12, 13, 12, 11, 8, 4, 8, + 10, 1, 15, 0, 15, 10, 12, 9, 15, 8, 0, 0, 0, 0, 0, 0, +]; + +struct GPoint { x: Field, y: Field } +struct GPointResult { x: Field, y: Field, is_infinity: bool } +struct Hint { lambda: Field, x3: Field, y3: Field } + +// ====== Constrained EC verification ====== + +// ~4 constraints: verify point doubling +fn c_double(p: GPoint, h: Hint) -> GPoint { + let xx = p.x * p.x; + assert(h.lambda * (p.y + p.y) == 3 * xx); // a=0 for Grumpkin + assert(h.lambda * h.lambda == h.x3 + p.x + p.x); + assert(h.lambda * (p.x - h.x3) == h.y3 + p.y); + GPoint { x: h.x3, y: h.y3 } } -fn point_double(p: Point) -> Point { - if p.is_infinite | (p.y == 0) { - Point { x: 0, y: 0, is_infinite: true } +// ~4 constraints: verify point addition (incomplete -- offset generator prevents edge cases) +fn c_add(p1: GPoint, p2: GPoint, h: Hint) -> GPoint { + assert(p1.x != p2.x); + assert(h.lambda * (p2.x - p1.x) == p2.y - p1.y); + assert(h.lambda * h.lambda == h.x3 + p1.x + p2.x); + assert(h.lambda * (p1.x - h.x3) == h.y3 + p1.y); + GPoint { x: h.x3, y: h.y3 } +} + +// ====== Unconstrained witness generation ====== + +unconstrained fn u_double(p: GPoint) -> (GPoint, Hint) { + let lambda = (3 * p.x * p.x) / (2 * p.y); + let 
x3 = lambda * lambda - 2 * p.x; + let y3 = lambda * (p.x - x3) - p.y; + (GPoint { x: x3, y: y3 }, Hint { lambda, x3, y3 }) +} + +unconstrained fn u_add(p1: GPoint, p2: GPoint) -> (GPoint, Hint) { + let lambda = (p2.y - p1.y) / (p2.x - p1.x); + let x3 = lambda * lambda - p1.x - p2.x; + let y3 = lambda * (p1.x - x3) - p1.y; + (GPoint { x: x3, y: y3 }, Hint { lambda, x3, y3 }) +} + +// Unconstrained complete addition hint (handles add, double, and inverse cases) +unconstrained fn u_complete_add_hint(p1: GPoint, p2: GPoint) -> Hint { + if p1.x == p2.x { + if p1.y == p2.y { + let (_, hint) = u_double(p1); + hint + } else { + Hint { lambda: 0, x3: 0, y3: 0 } + } } else { - // Grumpkin has a=0, so lambda = 3*x1^2 / (2*y1) - let lambda = (3 * p.x * p.x) / (2 * p.y); - let x3 = lambda * lambda - 2 * p.x; - let y3 = lambda * (p.x - x3) - p.y; - Point { x: x3, y: y3, is_infinite: false } + let (_, hint) = u_add(p1, p2); + hint } } -fn point_add(p1: Point, p2: Point) -> Point { - if p1.is_infinite { - p2 - } else if p2.is_infinite { - p1 - } else if (p1.x == p2.x) & (p1.y == p2.y) { - point_double(p1) - } else if (p1.x == p2.x) & (p1.y == (0 - p2.y)) { - Point { x: 0, y: 0, is_infinite: true } +// Unconstrained full addition hint (handles infinity inputs) +unconstrained fn u_full_add_hint( + p1: GPoint, p1_inf: bool, + p2: GPoint, p2_inf: bool, +) -> Hint { + if p1_inf | p2_inf { + Hint { lambda: 0, x3: 0, y3: 0 } } else { - let lambda = (p2.y - p1.y) / (p2.x - p1.x); - let x3 = lambda * lambda - p1.x - p2.x; - let y3 = lambda * (p1.x - x3) - p1.y; - Point { x: x3, y: y3, is_infinite: false } + u_complete_add_hint(p1, p2) } } -fn scalar_mul(p: Point, scalar_lo: Field, scalar_hi: Field) -> Point { - let lo_bits: [u1; 128] = scalar_lo.to_le_bits(); - let hi_bits: [u1; 128] = scalar_hi.to_le_bits(); +// Constrained complete addition: handles add, double, and inverse-point cases. +// Both inputs must be valid on-curve points (not identity). 
+// Uses `active * constraint == 0` pattern so constraints are trivially satisfied +// when the result is the identity (inverse-point case). +fn c_complete_add(p1: GPoint, p2: GPoint, h: Hint) -> GPointResult { + let x_eq: bool = p1.x == p2.x; + let y_eq: bool = p1.y == p2.y; + let is_infinity: bool = x_eq & !y_eq; + let is_double: bool = x_eq & y_eq; + let active: Field = (!is_infinity) as Field; - // Combine into a single 256-bit array (lo first, then hi) - let mut bits: [u1; 256] = [0; 256]; - for i in 0..128 { - bits[i] = lo_bits[i]; - bits[128 + i] = hi_bits[i]; - } + let lambda_lhs = if is_double { p1.y + p1.y } else { p2.x - p1.x }; + let lambda_rhs = if is_double { 3 * p1.x * p1.x } else { p2.y - p1.y }; + assert(active * (h.lambda * lambda_lhs - lambda_rhs) == 0); - // Find the highest set bit - let mut top = 0; - for i in 0..256 { - if bits[i] == 1 { - top = i; + // x3 verification: lambda^2 = x3 + x1 + x2 (same for add and double since x2=x1 when doubling) + assert(active * (h.lambda * h.lambda - h.x3 - p1.x - p2.x) == 0); + + // y3 verification: lambda * (x1 - x3) = y3 + y1 + assert(active * (h.lambda * (p1.x - h.x3) - h.y3 - p1.y) == 0); + + GPointResult { x: h.x3 * active, y: h.y3 * active, is_infinity } +} + +// Constrained full addition: handles all cases including identity inputs. +// This is used for the final MSM sum where either operand may be the identity. +fn c_full_add( + p1: GPoint, p1_inf: bool, + p2: GPoint, p2_inf: bool, + h: Hint, +) -> GPointResult { + let neither_inf = !p1_inf & !p2_inf; + let both_inf = p1_inf & p2_inf; + let only_p1_inf = p1_inf & !p2_inf; + + // EC constraints are only active when neither input is identity. + let ec_active: Field = neither_inf as Field; + + // Determine add/double/inverse case (only meaningful when neither is identity). + // Guard with neither_inf so garbage coordinates from identity points don't affect predicates. 
+ let x_eq: bool = (p1.x == p2.x) & neither_inf; + let y_eq: bool = (p1.y == p2.y) & neither_inf; + let is_inf_from_add: bool = x_eq & !y_eq; + let is_double: bool = x_eq & y_eq; + let arith_active: Field = ec_active * (!is_inf_from_add as Field); + + // Lambda, x3, y3 constraints (zeroed when inactive) + let lambda_lhs = if is_double { p1.y + p1.y } else { p2.x - p1.x }; + let lambda_rhs = if is_double { 3 * p1.x * p1.x } else { p2.y - p1.y }; + assert(arith_active * (h.lambda * lambda_lhs - lambda_rhs) == 0); + assert(arith_active * (h.lambda * h.lambda - h.x3 - p1.x - p2.x) == 0); + assert(arith_active * (h.lambda * (p1.x - h.x3) - h.y3 - p1.y) == 0); + + // Output selection + let result_is_inf: bool = both_inf | is_inf_from_add; + + let out_x = if result_is_inf { 0 } + else if only_p1_inf { p2.x } + else if p2_inf { p1.x } + else { h.x3 }; + let out_y = if result_is_inf { 0 } + else if only_p1_inf { p2.y } + else if p2_inf { p1.y } + else { h.y3 }; + + GPointResult { x: out_x, y: out_y, is_infinity: result_is_inf } +} + +unconstrained fn decompose_wnaf(scalar_lo: Field, scalar_hi: Field) -> ([u8; 64], bool) { + let lo_bytes = scalar_lo.to_le_bytes::<16>(); + let hi_bytes = scalar_hi.to_le_bytes::<16>(); + let mut nibbles: [u8; 64] = [0; 64]; + for i in 0..16 { + nibbles[2 * i] = lo_bytes[i] & 0x0F; + nibbles[2 * i + 1] = lo_bytes[i] >> 4; + } + for i in 0..16 { + nibbles[32 + 2 * i] = hi_bytes[i] & 0x0F; + nibbles[32 + 2 * i + 1] = hi_bytes[i] >> 4; + } + let skew: bool = (nibbles[0] & 1) == 0; + nibbles[0] = nibbles[0] + (skew as u8); + let mut slices: [u8; 64] = [0; 64]; + slices[63] = (nibbles[0] + 15) / 2; + for i in 1..64 { + let nibble = nibbles[i]; + slices[63 - i] = (nibble + 15) / 2; + if (nibble & 1) == 0 { + slices[63 - i] += 1; + slices[64 - i] -= 8; } } + (slices, skew) +} + +// 326 hints per scalar mul: 8 table + 1 init + 63*5 loop + 1 skew + 1 final +unconstrained fn compute_transcript( + P: GPoint, slices: [u8; 64], skew: bool, + offset: GPoint, 
offset_final: GPoint, +) -> [Hint; 326] { + let mut h: [Hint; 326] = [Hint { lambda: 0, x3: 0, y3: 0 }; 326]; + let mut p: u32 = 0; + + // Table: 2P, then P+2P, 3P+2P, ... + let (d2, d2h) = u_double(P); + h[p] = d2h; p += 1; + let mut table: [GPoint; 16] = [GPoint { x: 0, y: 0 }; 16]; + table[8] = P; + table[7] = GPoint { x: P.x, y: 0 - P.y }; + let mut A = P; + for i in 1..8 { + let (s, sh) = u_add(A, d2); h[p] = sh; p += 1; + A = s; + table[8 + i] = A; + table[7 - i] = GPoint { x: A.x, y: 0 - A.y }; + } + + // Init: offset + T[slices[0]] + let (ir, ih) = u_add(offset, table[slices[0] as u32]); + h[p] = ih; p += 1; + let mut acc = ir; + + // 63 windows: 4 doubles + 1 add each + for _w in 1..64 { + for _ in 0..4 { let (d, dh) = u_double(acc); h[p] = dh; p += 1; acc = d; } + let tp = table[slices[_w] as u32]; + let (s, sh) = u_add(acc, tp); h[p] = sh; p += 1; acc = s; + } + + // Skew correction (always compute valid hint even if unused) + let neg_P = GPoint { x: P.x, y: 0 - P.y }; + let (sr, sh) = u_add(acc, neg_P); + h[p] = sh; p += 1; + if skew { acc = sr; } - // Double-and-add from MSB down to bit 0 - let mut acc = Point { x: 0, y: 0, is_infinite: true }; - for j in 0..256 { - let i = 255 - j; - acc = point_double(acc); - if bits[i] == 1 { - acc = point_add(acc, p); + // Final offset subtraction (complete -- handles identity result when scalar = 0) + let neg_off = GPoint { x: offset_final.x, y: 0 - offset_final.y }; + h[p] = u_complete_add_hint(acc, neg_off); + + h +} + +// ====== Scalar range check ====== +// Lexicographic comparison: ensures wNAF slices represent an integer < field modulus. +// Without this, a prover could encode scalar + k*p (for k != 0) using valid 4-bit slices, +// since the Horner reconstruction only checks equality mod p. +// Mirrors noir_bigcurve's `compare_scalar_field_to_bignum`. 
+fn assert_slices_less_than_modulus(slices: [u8; 64]) { + let mut found_strictly_less: bool = false; + for i in 0..64 { + if !found_strictly_less { + let s = slices[i]; + let m = MODULUS_SLICES[i]; + // If we find a digit strictly less than modulus digit, scalar < modulus -- done. + if s as u8 < m { + found_strictly_less = true; + } else { + // If strictly greater at any position (without prior strictly-less), scalar >= modulus. + assert(s == m, "wNAF scalar exceeds field modulus"); + } } } + // If all digits equal, scalar == modulus, which is also invalid (must be strictly less). + assert(found_strictly_less, "wNAF scalar equals field modulus"); +} + +// ====== Main scalar multiplication ====== + +fn scalar_mul_wnaf(P: GPoint, scalar_lo: Field, scalar_hi: Field) -> GPointResult { + // 1. Decompose scalar into wNAF slices + // Safety: slices and skew are fully constrained below (range, reconstruction, and modulus bound) + let (slices, skew) = unsafe { decompose_wnaf(scalar_lo, scalar_hi) }; + + // Range check: each slice fits in 4 bits + for i in 0..64 { (slices[i] as Field).assert_max_bit_size::<4>(); } - acc + // Soundness fix #1: scalar range check -- slices represent integer < field modulus + assert_slices_less_than_modulus(slices); + + // Reconstruction check: wNAF Horner evaluation == scalar_lo + scalar_hi * 2^128 + let mut r: Field = 0; + for i in 0..64 { r = r * 16; r += (slices[i] as Field) * 2 - 15; } + r -= skew as Field; + + let lo_bits: [u1; 128] = scalar_lo.to_le_bits(); + let hi_bits: [u1; 128] = scalar_hi.to_le_bits(); + let mut expected: Field = 0; + let mut pow: Field = 1; + for i in 0..128 { expected += (lo_bits[i] as Field) * pow; pow *= 2; } + let two_128: Field = pow; + let mut hi_val: Field = 0; + pow = 1; + for i in 0..128 { hi_val += (hi_bits[i] as Field) * pow; pow *= 2; } + expected += hi_val * two_128; + assert(r == expected); + + // 2. 
Offset generators -- hardcoded compile-time constants (soundness fix #2) + let offset = GPoint { x: OFFSET_X, y: OFFSET_Y }; + let offset_final = GPoint { x: OFFSET_FINAL_X, y: OFFSET_FINAL_Y }; + + // 3. Transcript of EC operation hints + // Safety: every hint is verified by a constrained c_double or c_add call below + let hints = unsafe { compute_transcript(P, slices, skew, offset, offset_final) }; + let mut hp: u32 = 0; + + // 4. Build 16-entry lookup table: T[8]=P, T[9]=3P, ..., T[15]=15P, T[7]=-P, ..., T[0]=-15P + let d2 = c_double(P, hints[hp]); hp += 1; + let mut tx: [Field; 16] = [0; 16]; + let mut ty: [Field; 16] = [0; 16]; + tx[8] = P.x; ty[8] = P.y; + tx[7] = P.x; ty[7] = 0 - P.y; + let mut A = P; + for i in 1..8 { + A = c_add(A, d2, hints[hp]); hp += 1; + tx[8 + i] = A.x; ty[8 + i] = A.y; + tx[7 - i] = A.x; ty[7 - i] = 0 - A.y; + } + + // 5. Init accumulator: offset + T[slices[0]] + let first = GPoint { x: tx[slices[0] as u32], y: ty[slices[0] as u32] }; + let mut acc = c_add(offset, first, hints[hp]); hp += 1; + + // 6. Main wNAF loop: 63 windows * (4 doublings + 1 table add) + for _w in 1..64 { + acc = c_double(acc, hints[hp]); hp += 1; + acc = c_double(acc, hints[hp]); hp += 1; + acc = c_double(acc, hints[hp]); hp += 1; + acc = c_double(acc, hints[hp]); hp += 1; + let tp = GPoint { x: tx[slices[_w] as u32], y: ty[slices[_w] as u32] }; + acc = c_add(acc, tp, hints[hp]); hp += 1; + } + + // 7. Skew correction: if scalar was even, subtract P + let neg_P = GPoint { x: P.x, y: 0 - P.y }; + let skew_r = c_add(acc, neg_P, hints[hp]); hp += 1; + acc = if skew { skew_r } else { acc }; + + // 8. Subtract accumulated offset: result = acc - 2^252 * offset + // Uses complete addition to handle the identity result (scalar = 0 mod group_order) + let neg_off = GPoint { x: offset_final.x, y: 0 - offset_final.y }; + c_complete_add(acc, neg_off, hints[hp]) } -/// Native MSM: computes s1 * G + s2 * G using pure Noir field operations. 
-/// No blackbox functions -- all EC arithmetic is done natively over Grumpkin's -/// base field (= BN254 scalar field = Noir's native Field). +/// 2-point MSM on Grumpkin: s1*G + s2*G fn main( - scalar1_lo: Field, - scalar1_hi: Field, - scalar2_lo: Field, - scalar2_hi: Field, -) { - let g = Point { x: 1, y: GRUMPKIN_GEN_Y, is_infinite: false }; - - let r1 = scalar_mul(g, scalar1_lo, scalar1_hi); - let r2 = scalar_mul(g, scalar2_lo, scalar2_hi); - let result = point_add(r1, r2); - - // Prevent dead-code elimination - assert(!result.is_infinite); -} + scalar1_lo: pub Field, scalar1_hi: pub Field, + scalar2_lo: pub Field, scalar2_hi: pub Field, +) -> pub (Field, Field, bool) { + let g = GPoint { x: 1, y: GRUMPKIN_GEN_Y }; -#[test] -fn test_msm() { - // 3*G on Grumpkin (known coordinates) - let expected_x = 18660890509582237958343981571981920822503400000196279471655180441138020044621; - let expected_y = 8902249110305491597038405103722863701255802573786510474664632793109847672620; + let r1 = scalar_mul_wnaf(g, scalar1_lo, scalar1_hi); + let r2 = scalar_mul_wnaf(g, scalar2_lo, scalar2_hi); - main(1, 0, 2, 0); + // Full addition: handles r1 == r2 (doubling), r1 == -r2 (identity), and identity inputs + let add_hint = unsafe { + u_full_add_hint( + GPoint { x: r1.x, y: r1.y }, r1.is_infinity, + GPoint { x: r2.x, y: r2.y }, r2.is_infinity, + ) + }; + let result = c_full_add( + GPoint { x: r1.x, y: r1.y }, r1.is_infinity, + GPoint { x: r2.x, y: r2.y }, r2.is_infinity, + add_hint, + ); - // Verify 1*G + 2*G = 3*G by computing 3*G directly - let g = Point { x: 1, y: GRUMPKIN_GEN_Y, is_infinite: false }; - let three_g = scalar_mul(g, 3, 0); + // Verify result is on Grumpkin (skip for identity) + let on_curve = result.y * result.y - (result.x * result.x * result.x - 17); + assert((!result.is_infinity as Field) * on_curve == 0); - assert(three_g.x == expected_x); - assert(three_g.y == expected_y); + (result.x, result.y, result.is_infinity) } From 
089e4008d3e8ef223cf9f063c1e3a824265f4d97 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 11 Mar 2026 08:25:03 +0530 Subject: [PATCH 14/19] lint : document-private-items --- provekit/common/src/witness/witness_builder.rs | 2 +- provekit/r1cs-compiler/src/msm/mod.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 0b368f966..22ecee690 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -354,7 +354,7 @@ pub enum WitnessBuilder { /// /// Outputs (num_bits + 1) witnesses at output_start: /// [0..num_bits) b_i sign bits - /// [num_bits] skew (0 if s is odd, 1 if s is even) + /// \[num_bits\] skew (0 if s is odd, 1 if s is even) SignedBitHint { output_start: usize, scalar: usize, diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 678123da2..c52a604db 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -487,7 +487,7 @@ fn process_single_msm( } } -/// Single-point MSM: R = [s]P with degenerate-case handling. +/// Single-point MSM: R = \[s\]P with degenerate-case handling. /// /// The ACIR output (out_x, out_y) is the result directly. Sanitizes inputs /// to handle scalar=0 and point-at-infinity, then verifies via FakeGLV. @@ -583,7 +583,7 @@ fn process_single_point_msm<'a>( constrain_product_zero(compiler, san.is_skip, out_y); } -/// Multi-point MSM: computes R_i = [s_i]P_i via hints, verifies each with +/// Multi-point MSM: computes R_i = \[s_i\]P_i via hints, verifies each with /// FakeGLV, then accumulates R_i's with offset-based accumulation and skip /// handling. 
/// From 665c17f32444ebe080bc91551c7db2ced1d2a6e4 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 11 Mar 2026 12:42:25 +0530 Subject: [PATCH 15/19] feat : added tests for proper coverage of logic --- noir-examples/embedded_curve_msm/Prover.toml | 72 +- .../Prover_near_identity.toml | 6 + .../embedded_curve_msm/Prover_near_order.toml | 6 + .../Prover_single_nonzero.toml | 5 + .../Prover_zero_scalars.toml | 5 + .../r1cs-compiler/src/constraint_helpers.rs | 133 +++ provekit/r1cs-compiler/src/lib.rs | 1 + provekit/r1cs-compiler/src/msm/cost_model.rs | 803 ++++++------------ provekit/r1cs-compiler/src/msm/curve.rs | 3 +- provekit/r1cs-compiler/src/msm/ec_points.rs | 68 +- provekit/r1cs-compiler/src/msm/mod.rs | 488 +++-------- .../r1cs-compiler/src/msm/multi_limb_ops.rs | 113 ++- provekit/r1cs-compiler/src/msm/native.rs | 107 +-- provekit/r1cs-compiler/src/msm/non_native.rs | 44 +- .../r1cs-compiler/src/msm/scalar_relation.rs | 41 +- tooling/provekit-bench/tests/compiler.rs | 73 +- 16 files changed, 758 insertions(+), 1210 deletions(-) create mode 100644 noir-examples/embedded_curve_msm/Prover_near_identity.toml create mode 100644 noir-examples/embedded_curve_msm/Prover_near_order.toml create mode 100644 noir-examples/embedded_curve_msm/Prover_single_nonzero.toml create mode 100644 noir-examples/embedded_curve_msm/Prover_zero_scalars.toml create mode 100644 provekit/r1cs-compiler/src/constraint_helpers.rs diff --git a/noir-examples/embedded_curve_msm/Prover.toml b/noir-examples/embedded_curve_msm/Prover.toml index d36dddbd7..edf585681 100644 --- a/noir-examples/embedded_curve_msm/Prover.toml +++ b/noir-examples/embedded_curve_msm/Prover.toml @@ -1,71 +1,9 @@ # ============================================================ # MSM test vectors: result = s1 * G + s2 * G # Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -# Uncomment ONE test case at a time to run. 
# ============================================================ - -# === Test 1: Small scalars (1*G + 2*G = 3*G) === -# scalar1_lo = "1" -# scalar1_hi = "0" -# scalar2_lo = "2" -# scalar2_hi = "0" - -# === Test 2: All-zero scalars (0*G + 0*G = point at infinity) === -# scalar1_lo = "0" -# scalar1_hi = "0" -# scalar2_lo = "0" -# scalar2_hi = "0" - -# === Test 3: One zero, one non-zero (0*G + 5*G = 5*G) === -# scalar1_lo = "0" -# scalar1_hi = "0" -# scalar2_lo = "5" -# scalar2_hi = "0" - -# === Test 4: Large lo, small hi (diff ≠ 2^128) === -# scalar1_lo = "64323764613183177041862057485226039389" -# scalar1_hi = "1" -# scalar2_lo = "99999999999999999999999999999999999999" -# scalar2_hi = "3" - -# === Test 5: Small lo, large hi === -# scalar1_lo = "1" -# scalar1_hi = "64323764613183177041862057485226039389" -# scalar2_lo = "2" -# scalar2_hi = "64323764613183177041862057485226039389" - -# === Test 6: Near-max scalars (n-10 and n-20) === -# scalar1_lo = "201385395114098847380338600778089168189" -# scalar1_hi = "64323764613183177041862057485226039389" -# scalar2_lo = "201385395114098847380338600778089168179" -# scalar2_hi = "64323764613183177041862057485226039389" - -# === Test 7: Powers of 2 (2^100 and 2^200) === -# scalar1_lo = "1267650600228229401496703205376" -# scalar1_hi = "0" -# scalar2_lo = "0" -# scalar2_hi = "4722366482869645213696" - -# === Test 8: Half curve order (n/2) and 1 === -# scalar1_lo = "270833881017518655421856604104928689827" -# scalar1_hi = "32161882306591588520931028742613019694" -# scalar2_lo = "1" -# scalar2_hi = "0" - -# === Test 9: Large mixed scalars === -# scalar1_lo = "340282366920938463463374607431768211455" -# scalar1_hi = "0" -# scalar2_lo = "170141183460469231731687303715884105727" -# scalar2_hi = "3" - -# === Test 10: Both scalars equal, ~2n/3 === -scalar1_lo = "247684385716378719408017269662648849284" -scalar1_hi = "42882509742122118027908038323484026259" -scalar2_lo = "247684385716378719408017269662648849284" -scalar2_hi = 
"42882509742122118027908038323484026259" - -# === Test 11: n - 2, n - 3 (previously failing with [2]G offset) === -# scalar1_lo = "201385395114098847380338600778089168197" -# scalar1_hi = "64323764613183177041862057485226039389" -# scalar2_lo = "201385395114098847380338600778089168196" -# scalar2_hi = "64323764613183177041862057485226039389" +# n - 2, n - 3 (previously failing with [2]G offset) +scalar1_lo = "201385395114098847380338600778089168197" +scalar1_hi = "64323764613183177041862057485226039389" +scalar2_lo = "201385395114098847380338600778089168196" +scalar2_hi = "64323764613183177041862057485226039389" diff --git a/noir-examples/embedded_curve_msm/Prover_near_identity.toml b/noir-examples/embedded_curve_msm/Prover_near_identity.toml new file mode 100644 index 000000000..a156f96ee --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_near_identity.toml @@ -0,0 +1,6 @@ +# MSM edge case: n-2 and n-3 (previously failing with [2]G offset) +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +scalar1_lo = "201385395114098847380338600778089168197" +scalar1_hi = "64323764613183177041862057485226039389" +scalar2_lo = "201385395114098847380338600778089168196" +scalar2_hi = "64323764613183177041862057485226039389" diff --git a/noir-examples/embedded_curve_msm/Prover_near_order.toml b/noir-examples/embedded_curve_msm/Prover_near_order.toml new file mode 100644 index 000000000..d8ae04eb5 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_near_order.toml @@ -0,0 +1,6 @@ +# MSM edge case: near-max scalars (n-10 and n-20) +# Grumpkin curve order n = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +scalar1_lo = "201385395114098847380338600778089168189" +scalar1_hi = "64323764613183177041862057485226039389" +scalar2_lo = "201385395114098847380338600778089168179" +scalar2_hi = "64323764613183177041862057485226039389" diff --git 
a/noir-examples/embedded_curve_msm/Prover_single_nonzero.toml b/noir-examples/embedded_curve_msm/Prover_single_nonzero.toml new file mode 100644 index 000000000..8455db356 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_single_nonzero.toml @@ -0,0 +1,5 @@ +# MSM edge case: one zero scalar, one non-zero (0*G + 5*G = 5*G) +scalar1_lo = "0" +scalar1_hi = "0" +scalar2_lo = "5" +scalar2_hi = "0" diff --git a/noir-examples/embedded_curve_msm/Prover_zero_scalars.toml b/noir-examples/embedded_curve_msm/Prover_zero_scalars.toml new file mode 100644 index 000000000..0bd8866c7 --- /dev/null +++ b/noir-examples/embedded_curve_msm/Prover_zero_scalars.toml @@ -0,0 +1,5 @@ +# MSM edge case: all-zero scalars (0*G + 0*G = point at infinity) +scalar1_lo = "0" +scalar1_hi = "0" +scalar2_lo = "0" +scalar2_hi = "0" diff --git a/provekit/r1cs-compiler/src/constraint_helpers.rs b/provekit/r1cs-compiler/src/constraint_helpers.rs new file mode 100644 index 000000000..2bd033bc2 --- /dev/null +++ b/provekit/r1cs-compiler/src/constraint_helpers.rs @@ -0,0 +1,133 @@ +//! General-purpose R1CS constraint helpers. + +use { + crate::noir_to_r1cs::NoirToR1CSCompiler, + ark_ff::{AdditiveGroup, Field}, + provekit_common::{ + witness::{ConstantTerm, SumTerm, WitnessBuilder}, + FieldElement, + }, +}; + +/// Constrains `flag` to be boolean: `flag * flag = flag`. +pub(crate) fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, flag)], + ); +} + +/// Single-witness conditional select: `out = on_false + flag * (on_true - +/// on_false)`. +/// +/// Uses a single witness + single R1CS constraint: +/// flag * (on_true - on_false) = result - on_false +pub(crate) fn select_witness( + compiler: &mut NoirToR1CSCompiler, + flag: usize, + on_false: usize, + on_true: usize, +) -> usize { + // When both branches are the same witness, result is trivially that witness. 
+ if on_false == on_true { + return on_false; + } + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SelectWitness { + output: result, + flag, + on_false, + on_true, + }); + // flag * (on_true - on_false) = result - on_false + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, flag)], + &[(FieldElement::ONE, on_true), (-FieldElement::ONE, on_false)], + &[(FieldElement::ONE, result), (-FieldElement::ONE, on_false)], + ); + result +} + +/// Packs bit witnesses into a digit: `d = Σ bits[i] * 2^i`. +pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize]) -> usize { + let terms: Vec = bits + .iter() + .enumerate() + .map(|(i, &bit)| SumTerm(Some(FieldElement::from(1u128 << i)), bit)) + .collect(); + compiler.add_sum(terms) +} + +/// Computes `a OR b` for two boolean witnesses: `1 - (1 - a)(1 - b)`. +/// Does NOT constrain a or b to be boolean — caller must ensure that. +/// +/// Uses a single witness + single R1CS constraint: +/// (1 - a) * (1 - b) = 1 - result +pub(crate) fn compute_boolean_or(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) -> usize { + let one = compiler.witness_one(); + let result = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::BooleanOr { + output: result, + a, + b, + }); + // (1 - a) * (1 - b) = 1 - result + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, one), (-FieldElement::ONE, a)], + &[(FieldElement::ONE, one), (-FieldElement::ONE, b)], + &[(FieldElement::ONE, one), (-FieldElement::ONE, result)], + ); + result +} + +/// Creates a constant witness with the given value, pinned by an R1CS +/// constraint so that a malicious prover cannot set it to an arbitrary value. 
+pub(crate) fn add_constant_witness( + compiler: &mut NoirToR1CSCompiler, + value: FieldElement, +) -> usize { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value))); + // Pin: 1 * w = value * 1 (embeds the constant into the constraint matrix) + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(value, compiler.witness_one())], + ); + w +} + +/// Constrains a witness to equal a known constant value. +/// Uses the constant as an R1CS coefficient — no witness needed for the +/// expected value. Use this for identity checks where the witness must equal +/// a compile-time-known value. +pub(crate) fn constrain_to_constant( + compiler: &mut NoirToR1CSCompiler, + witness: usize, + value: FieldElement, +) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, witness)], + &[(value, compiler.witness_one())], + ); +} + +/// Constrains two witnesses to be equal: `a - b = 0`. +pub(crate) fn constrain_equal(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} + +/// Constrains a witness to be zero: `w = 0`. 
+pub(crate) fn constrain_zero(compiler: &mut NoirToR1CSCompiler, w: usize) { + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(FieldElement::ZERO, compiler.witness_one())], + ); +} diff --git a/provekit/r1cs-compiler/src/lib.rs b/provekit/r1cs-compiler/src/lib.rs index a1ee6b1ce..64e6eb46d 100644 --- a/provekit/r1cs-compiler/src/lib.rs +++ b/provekit/r1cs-compiler/src/lib.rs @@ -1,4 +1,5 @@ mod binops; +mod constraint_helpers; mod digits; mod memory; mod msm; diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index 5992953e8..39d2ed7f3 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -1,257 +1,122 @@ //! Analytical cost model for MSM parameter optimization. //! //! Follows the SHA256 pattern (`spread.rs:get_optimal_spread_width`): -//! pure analytical estimator → exhaustive search → pick optimal (limb_bits, -//! window_size). +//! `calculate_msm_witness_cost` estimates total cost, `get_optimal_msm_params` +//! searches the parameter space for the minimum. use std::collections::BTreeMap; /// The 256-bit scalar is split into two halves (s_lo, s_hi) because it doesn't -/// fit in the native field. This constant is used throughout the scalar -/// relation cost model. +/// fit in the native field. const SCALAR_HALF_BITS: usize = 128; -/// Type of field operation for cost estimation. -#[derive(Clone, Copy)] -pub enum FieldOpType { - Add, - Sub, - Mul, - Inv, +fn ceil_div(a: usize, b: usize) -> usize { + (a + b - 1) / b } -/// Count field ops and selects in scalar_mul_glv for given parameters. +/// Total witnesses produced by N-limb field operations. /// -/// Returns `(n_add, n_sub, n_mul, n_inv, n_is_zero, n_point_selects, -/// n_coord_selects)`. -/// -/// Field ops (add/sub/mul/inv) come from point_double, point_add, and -/// on-curve checks. 
Selects are counted separately because they create -/// `num_limbs` witnesses per coordinate (via `select_witness`), not -/// multi-limb field op witnesses. -/// -/// - `n_point_selects`: selects on EcPoint (2 coordinates), from table lookups -/// and conditional skip after point_add. -/// - `n_coord_selects`: selects on single Limbs coordinate, from y-negation. -/// - `n_is_zero`: `compute_is_zero` calls, each creating exactly 3 native -/// witnesses regardless of num_limbs. -fn count_glv_field_ops( - scalar_bits: usize, // half_bits = ceil(order_bits / 2) - window_size: usize, -) -> (usize, usize, usize, usize, usize, usize, usize) { - let w = window_size; - let table_size = 1 << w; - let num_windows = (scalar_bits + w - 1) / w; - - // Field ops per primitive EC operation (add, sub, mul, inv): - let double_ops = (4usize, 2usize, 5usize, 1usize); - let add_ops = (2usize, 2usize, 3usize, 1usize); - - // Two tables (one for P, one for R) - let table_doubles = if table_size > 2 { 1 } else { 0 }; - let table_adds = if table_size > 2 { table_size - 3 } else { 0 }; - - let mut total_add = 2 * (table_doubles * double_ops.0 + table_adds * add_ops.0); - let mut total_sub = 2 * (table_doubles * double_ops.1 + table_adds * add_ops.1); - let mut total_mul = 2 * (table_doubles * double_ops.2 + table_adds * add_ops.2); - let mut total_inv = 2 * (table_doubles * double_ops.3 + table_adds * add_ops.3); - let mut total_is_zero = 0usize; - let mut total_point_selects = 0usize; - - for win_idx in (0..num_windows).rev() { - let bit_start = win_idx * w; - let bit_end = std::cmp::min(bit_start + w, scalar_bits); - let actual_w = bit_end - bit_start; - let actual_table_selects = (1 << actual_w) - 1; - - // w shared doublings - total_add += w * double_ops.0; - total_sub += w * double_ops.1; - total_mul += w * double_ops.2; - total_inv += w * double_ops.3; - - // Two table lookups + two point_adds + two is_zeros + two conditional - // skips - for _ in 0..2 { - // Table lookup: (2^actual_w 
- 1) point selects - total_point_selects += actual_table_selects; - - // Point add - total_add += add_ops.0; - total_sub += add_ops.1; - total_mul += add_ops.2; - total_inv += add_ops.3; - - // is_zero: 3 fixed native witnesses each - total_is_zero += 1; - - // Conditional skip: 1 point select - total_point_selects += 1; - } - } - - // On-curve checks for P and R: each needs mul(y²), mul(x²), mul(x³), - // mul(a·x), add(x³+ax), add(x³+ax+b) = 4 mul + 2 add per point - total_mul += 8; - total_add += 4; - - // Conditional y-negation: 2 negate (= 2 sub) + 2 Limbs selects (1 coord - // each) - total_sub += 2; - let total_coord_selects = 2usize; - - ( - total_add, - total_sub, - total_mul, - total_inv, - total_is_zero, - total_point_selects, - total_coord_selects, - ) -} - -/// Count only range-check-producing field ops in scalar_mul_glv. -/// -/// Returns `(n_add, n_sub, n_mul, n_inv)` excluding selects and is_zero, -/// which generate 0 range checks (selects are native `select_witness` calls, -/// is_zero operates on `pack_bits` results). -fn count_glv_real_field_ops( - scalar_bits: usize, - window_size: usize, -) -> (usize, usize, usize, usize) { - let (n_add, n_sub, n_mul, n_inv, ..) = count_glv_field_ops(scalar_bits, window_size); - (n_add, n_sub, n_mul, n_inv) -} - -/// Witnesses per single N-limb field operation. 
-fn witnesses_per_op(num_limbs: usize, op: FieldOpType, is_native: bool) -> usize { +/// Per-op witness counts by configuration: +/// - Native (N=1, same field): 1 per op (direct R1CS) +/// - Single-limb non-native (N=1): 5 per add/sub/mul, 6 per inv (reduce_mod_p +/// pattern) +/// - Multi-limb (N>1): add/sub = 1+6N, mul = N²+7N-2, inv = N²+8N-2 (schoolbook +/// multiplication + quotient/remainder) +fn field_op_witnesses( + n_add: usize, + n_sub: usize, + n_mul: usize, + n_inv: usize, + num_limbs: usize, + is_native: bool, +) -> usize { if is_native { - match op { - FieldOpType::Add => 1, - FieldOpType::Sub => 1, - FieldOpType::Mul => 1, - FieldOpType::Inv => 1, - } + n_add + n_sub + n_mul + n_inv } else if num_limbs == 1 { - // Single-limb non-native: reduce_mod_p pattern - match op { - FieldOpType::Add => 5, // a+b, m const, k, k*m, result - FieldOpType::Sub => 5, - FieldOpType::Mul => 5, // a*b, m const, k, k*m, result - FieldOpType::Inv => 6, // a_inv(1) + mul_mod_p_single(5) - } + (n_add + n_sub + n_mul) * 5 + n_inv * 6 } else { let n = num_limbs; - match op { - // add/sub: q + N*(v_offset, carry, r_limb) + N*(v_diff, borrow, - // d_limb) - FieldOpType::Add | FieldOpType::Sub => 1 + 3 * n + 3 * n, - // mul: hint(4N-2) + N² products + 2N-1 column constraints + - // lt_check(3N) - FieldOpType::Mul => (4 * n - 2) + n * n + 3 * n, - // inv: hint(N) + mul costs - FieldOpType::Inv => n + (4 * n - 2) + n * n + 3 * n, - } + let w_as = 1 + 6 * n; + let w_m = n * n + 7 * n - 2; + let w_i = n * n + 8 * n - 2; + (n_add + n_sub) * w_as + n_mul * w_m + n_inv * w_i } } -/// Count witnesses for scalar relation verification. -/// -/// The scalar relation verifies `(-1)^neg1*|s1| + (-1)^neg2*|s2|*s ≡ 0 (mod -/// n)` using multi-limb arithmetic with the curve order as modulus. 
-fn count_scalar_relation_witnesses(native_field_bits: u32, scalar_bits: usize) -> usize { - let limb_bits = scalar_relation_limb_bits(native_field_bits, scalar_bits); - let num_limbs = (scalar_bits + limb_bits as usize - 1) / limb_bits as usize; - let scalar_half_limbs = (SCALAR_HALF_BITS + limb_bits as usize - 1) / limb_bits as usize; - - let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, false); - let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, false); - let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, false); - - // Scalar decomposition: DD digits for s_lo + s_hi, plus cross-boundary - // witness when limb boundaries don't align with the 128-bit split - let has_cross_boundary = num_limbs > 1 && SCALAR_HALF_BITS % limb_bits as usize != 0; - let scalar_decomp = 2 * scalar_half_limbs + has_cross_boundary as usize; - - // Half-scalar decomposition: DD digits + zero-pad constants for s1, s2 - let half_scalar_decomp = 2 * num_limbs; - - // Sign handling: sum + diff + XOR (2 native witnesses) + select - let sign_handling = wit_add + wit_sub + 2 + num_limbs; - - scalar_decomp + half_scalar_decomp + wit_mul + sign_handling -} - -/// Range checks generated by a single N-limb field operation. +/// Aggregate range checks from field ops into a map. /// -/// Returns entries as `(bit_width, count)` pairs. Native ops produce no -/// range checks. Single-limb non-native uses `reduce_mod_p` (1 check at -/// `curve_modulus_bits`). Multi-limb ops produce checks at `limb_bits` -/// and `carry_bits = limb_bits + ceil(log2(N)) + 2`. 
-fn range_checks_per_op( +/// - Native: no range checks +/// - Single-limb non-native: 1 check at `modulus_bits` per add/sub/mul, 2 per +/// inv +/// - Multi-limb: `limb_bits`-wide checks from less_than_p, plus +/// `carry_bits`-wide checks from schoolbook column carries +fn add_field_op_range_checks( + n_add: usize, + n_sub: usize, + n_mul: usize, + n_inv: usize, num_limbs: usize, - op: FieldOpType, - is_native: bool, limb_bits: u32, - curve_modulus_bits: u32, -) -> Vec<(u32, usize)> { + modulus_bits: u32, + is_native: bool, + rc_map: &mut BTreeMap, +) { if is_native { - return vec![]; + return; } if num_limbs == 1 { - let bits = curve_modulus_bits; - return match op { - FieldOpType::Add | FieldOpType::Sub | FieldOpType::Mul => vec![(bits, 1)], - FieldOpType::Inv => vec![(bits, 2)], - }; - } - let n = num_limbs; - let ceil_log2_n = if n <= 1 { - 0u32 + *rc_map.entry(modulus_bits).or_default() += n_add + n_sub + n_mul + 2 * n_inv; } else { - (n as f64).log2().ceil() as u32 - }; - let carry_bits = limb_bits + ceil_log2_n + 2; - match op { - // add/sub: 2N from less_than_p_check_multi - FieldOpType::Add | FieldOpType::Sub => vec![(limb_bits, 2 * n)], - // mul: N q-limbs + 2N from less_than_p at limb_bits, (2N-2) carries - // at carry_bits - FieldOpType::Mul => vec![(limb_bits, 3 * n), (carry_bits, 2 * n - 2)], - // inv: N inv-limbs + mul's checks - FieldOpType::Inv => vec![(limb_bits, 4 * n), (carry_bits, 2 * n - 2)], + let n = num_limbs; + let ceil_log2_n = (n as f64).log2().ceil() as u32; + let carry_bits = limb_bits + ceil_log2_n + 2; + *rc_map.entry(limb_bits).or_default() += + (n_add + n_sub) * 2 * n + n_mul * 3 * n + n_inv * 4 * n; + *rc_map.entry(carry_bits).or_default() += (n_mul + n_inv) * (2 * n - 2); } } -/// Count range checks for scalar relation verification. +/// Witnesses and range checks for scalar relation verification. 
 ///
-/// Sources: DD digits (scalar + half-scalar decompositions) and multi-limb
-/// field ops (1 mul + 1 add + 1 sub for XOR-based sign handling).
-fn count_scalar_relation_range_checks(
+/// Verifies `(-1)^neg1*|s1| + (-1)^neg2*|s2|*s ≡ 0 (mod n)` using multi-limb
+/// arithmetic with the curve order as modulus. Components:
+/// - Scalar decomposition (DD digits for s_lo, s_hi)
+/// - Half-scalar decomposition (DD digits for s1, s2)
+/// - One mul + one add + one sub for sign handling
+/// - XOR witnesses (2) + select (num_limbs)
+fn scalar_relation_cost(
     native_field_bits: u32,
     scalar_bits: usize,
-) -> BTreeMap<u32, usize> {
+) -> (usize, BTreeMap<u32, usize>) {
     let limb_bits = scalar_relation_limb_bits(native_field_bits, scalar_bits);
-    let num_limbs = (scalar_bits + limb_bits as usize - 1) / limb_bits as usize;
+    let n = ceil_div(scalar_bits, limb_bits as usize);
     let half_bits = (scalar_bits + 1) / 2;
-    let half_limbs = (half_bits + limb_bits as usize - 1) / limb_bits as usize;
-    let scalar_half_limbs = (SCALAR_HALF_BITS + limb_bits as usize - 1) / limb_bits as usize;
-
-    let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new();
-
-    // DD digits: s_lo + s_hi (2 × scalar_half_limbs) + s1 + s2 (2 × half_limbs)
+    let half_limbs = ceil_div(half_bits, limb_bits as usize);
+    let scalar_half_limbs = ceil_div(SCALAR_HALF_BITS, limb_bits as usize);
+
+    let has_cross = n > 1 && SCALAR_HALF_BITS % limb_bits as usize != 0;
+    let witnesses = 2 * scalar_half_limbs
+        + has_cross as usize
+        + 2 * n
+        + field_op_witnesses(1, 1, 1, 0, n, false)
+        + 2
+        + n;
+
+    let mut rc_map = BTreeMap::new();
     *rc_map.entry(limb_bits).or_default() += 2 * scalar_half_limbs + 2 * half_limbs;
-
-    // Multi-limb field ops: mul + add + sub
-    let modulus_bits = scalar_bits as u32;
-    for op in [FieldOpType::Mul, FieldOpType::Add, FieldOpType::Sub] {
-        for (bits, count) in range_checks_per_op(num_limbs, op, false, limb_bits, modulus_bits) {
-            *rc_map.entry(bits).or_default() += count;
-        }
-    }
-
-    rc_map
+    add_field_op_range_checks(
+        1,
+ 1, + 1, + 0, + n, + limb_bits, + scalar_bits as u32, + false, + &mut rc_map, + ); + + (witnesses, rc_map) } /// Total estimated witness cost for an MSM. @@ -270,262 +135,212 @@ pub fn calculate_msm_witness_cost( is_native: bool, ) -> usize { if is_native { - return calculate_msm_witness_cost_native( - native_field_bits, - n_points, - scalar_bits, - window_size, - ); + return calculate_msm_witness_cost_native(native_field_bits, n_points, scalar_bits); } - let num_limbs = - ((curve_modulus_bits as usize) + (limb_bits as usize) - 1) / (limb_bits as usize); - - let wit_add = witnesses_per_op(num_limbs, FieldOpType::Add, false); - let wit_sub = witnesses_per_op(num_limbs, FieldOpType::Sub, false); - let wit_mul = witnesses_per_op(num_limbs, FieldOpType::Mul, false); - let wit_inv = witnesses_per_op(num_limbs, FieldOpType::Inv, false); - - // === GLV scalar mul witnesses === + let n = ceil_div(curve_modulus_bits as usize, limb_bits as usize); let half_bits = (scalar_bits + 1) / 2; - let (n_add, n_sub, n_mul, n_inv, n_is_zero, n_point_selects, n_coord_selects) = - count_glv_field_ops(half_bits, window_size); - - // Field ops: priced at full multi-limb cost - let field_op_cost = n_add * wit_add + n_sub * wit_sub + n_mul * wit_mul + n_inv * wit_inv; - - // Selects: each select_witness creates 1 witness per limb (inlined). - // Point select = 2 coords × num_limbs × 1. - // Coord select = 1 coord × num_limbs × 1. 
- let select_cost = n_point_selects * 2 * num_limbs + n_coord_selects * num_limbs; + let w = window_size; + let table_size = 1usize << w; + let num_windows = ceil_div(half_bits, w); - // is_zero: 3 fixed native witnesses each (SafeInverse + Product + Sum) - let is_zero_cost = n_is_zero * 3; + // === GLV scalar mul field op counts === + // point_double: (5 add, 3 sub, 4 mul, 1 inv) + N constant witnesses (curve_a) + // point_add: (1 add, 5 sub, 3 mul, 1 inv) - let glv_scalarmul = field_op_cost + select_cost + is_zero_cost; + // Table building (2 tables for P and R) + let (tbl_d, tbl_a) = if table_size > 2 { + (1, table_size - 3) + } else { + (0, 0) + }; + let mut n_add = 2 * (tbl_d * 5 + tbl_a * 1); + let mut n_sub = 2 * (tbl_d * 3 + tbl_a * 5); + let mut n_mul = 2 * (tbl_d * 4 + tbl_a * 3); + let mut n_inv = 2 * (tbl_d + tbl_a); + + // Main loop: w shared doublings + 2 point_adds per window + n_add += num_windows * (w * 5 + 2 * 1); + n_sub += num_windows * (w * 3 + 2 * 5); + n_mul += num_windows * (w * 4 + 2 * 3); + n_inv += num_windows * (w + 2); + + // On-curve checks (P and R): 2 × (4 mul + 2 add) + n_mul += 8; + n_add += 4; + + // Y-negation: 2 negate = 2 sub (negate calls sub(zero, value)) + n_sub += 2; + + let glv_field_ops = field_op_witnesses(n_add, n_sub, n_mul, n_inv, n, false); + + // Constant witness allocations not captured by field ops: + // - curve_a() in each point_double: N per call + // - on-curve: 2 × (curve_a + curve_b) = 4N + // - negate: 2 × constant_limbs(zero) = 2N + // - offset point in verify_point_fakeglv: 2N + let n_doubles = 2 * tbl_d + num_windows * w; + let glv_constants = n_doubles * n + 4 * n + 2 * n + 2 * n; + + // Selects + is_zero (not field ops, priced separately) + let table_selects = num_windows * 2 * ((1 << w) - 1) * 2 * n; + let skip_selects = num_windows * 2 * 2 * n; + let y_negate_selects = 2 * n; + let is_zero_cost = num_windows * 2 * 3; // 3 native witnesses each + + let glv_cost = glv_field_ops + + glv_constants + + 
table_selects + + skip_selects + + y_negate_selects + + is_zero_cost; // === Per-point overhead === - // Scalar bit decomposition: 2 DDs of half_bits 1-bit digits let scalar_bit_decomp = 2 * half_bits; - - // detect_skip: 2×is_zero(3) + product(1) + boolean_or(1) = 8 - let detect_skip_cost = 8; - - // Sanitization: 3 constants (gen_x, gen_y, zero) + 6 select_witness × 1 - // For multi-point, constants are shared but impact is negligible. - let sanitize_cost = 3 + 6; - - // Point decomposition digit witnesses (add_digital_decomposition creates - // num_limbs digit witnesses per coordinate; 2 coords × 2 points = 4). - // Only applies when num_limbs > 1 (decompose_point_to_limbs is a no-op - // for num_limbs == 1). - let point_decomp_digits = if num_limbs > 1 { 4 * num_limbs } else { 0 }; - - // Scalar relation (analytical) - let scalar_relation = count_scalar_relation_witnesses(native_field_bits, scalar_bits); - - // FakeGLVHint: 4 witnesses (s1, s2, neg1, neg2) - let glv_hint = 4; - - // EcScalarMulHint: 2 witnesses per point (only for n_points > 1) - let ec_hint = if n_points > 1 { 2 } else { 0 }; - - let per_point = glv_scalarmul + let detect_skip = 8; // 2×is_zero(3W) + product(1W) + or(1W) + let sanitize = 4; // 4 select_witness + let ec_hint = 4; // 2W hint + 2W selects + let point_decomp = if n > 1 { 4 * n } else { 0 }; + let glv_hint = 4; // s1, s2, neg1, neg2 + let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits); + + let per_point = glv_cost + scalar_bit_decomp - + detect_skip_cost - + sanitize_cost - + point_decomp_digits - + scalar_relation + + detect_skip + + sanitize + + ec_hint + + point_decomp + glv_hint - + ec_hint; - - // === Point accumulation (multi-point only) === - // Each point gets: point_add(acc, R_i) + point_select_unchecked(skip). - // Plus final offset subtraction: 1 point_add + constants + 2 Limbs - // selects. 
- let point_add_cost = 2 * wit_add + 2 * wit_sub + 3 * wit_mul + wit_inv; - let accum = if n_points > 1 { - let accum_point_adds = n_points * point_add_cost; - let accum_point_selects = n_points * 2 * num_limbs; - // all_skipped tracking: (n_points - 1) product witnesses - let all_skipped_products = n_points - 1; - // Offset subtraction: point_add + 4×constant_limbs + 2 Limbs selects - // + 2×constant_limbs for initial acc - let offset_sub = point_add_cost + 6 * num_limbs + 2 * num_limbs; - - accum_point_adds + accum_point_selects + all_skipped_products + offset_sub - } else { - 0 - }; + + sr_witnesses; - // === Range check resolution cost === - // All points' range checks share the same LogUp tables, so we aggregate - // across n_points before computing resolution cost (table amortizes). - let mut rc_map: BTreeMap = BTreeMap::new(); + // === Shared constants (allocated once) === + // gen_x, gen_y, zero (3W) + offset_{x,y} (2×num_limbs W via constant_limbs) + let shared_constants = 3 + 2 * n; - // 1. Range checks from GLV field ops (selects generate 0 range checks) - let (rc_n_add, rc_n_sub, rc_n_mul, rc_n_inv) = count_glv_real_field_ops(half_bits, window_size); - for &(op, n_ops) in &[ - (FieldOpType::Add, rc_n_add), - (FieldOpType::Sub, rc_n_sub), - (FieldOpType::Mul, rc_n_mul), - (FieldOpType::Inv, rc_n_inv), - ] { - for (bits, count) in - range_checks_per_op(num_limbs, op, false, limb_bits, curve_modulus_bits) - { - *rc_map.entry(bits).or_default() += n_points * n_ops * count; - } - } + // === Point accumulation === + let pa_cost = field_op_witnesses(1, 5, 3, 1, n, false); // point_add + let accum = n_points * (pa_cost + 2 * n) // per-point add + skip select + + n_points.saturating_sub(1) // all_skipped products + + pa_cost + 4 * n + 2 * n // offset subtraction + constants + selects + + 2 + if n > 1 { 2 } else { 0 }; // mask + recompose - // 2. Point decomposition range checks (num_limbs > 1 only). - // 4 coordinates: px, py, rx, ry. 
- if num_limbs > 1 { - *rc_map.entry(limb_bits).or_default() += n_points * 4 * num_limbs; - } + // === Range check resolution === + let mut rc_map: BTreeMap = BTreeMap::new(); - // 3. Scalar relation range checks (always non-native, per point) - let sr_checks = count_scalar_relation_range_checks(native_field_bits, scalar_bits); - for (bits, count) in &sr_checks { + // GLV field ops (per point) + add_field_op_range_checks( + n_points * n_add, + n_points * n_sub, + n_points * n_mul, + n_points * n_inv, + n, + limb_bits, + curve_modulus_bits, + false, + &mut rc_map, + ); + + // Point decomposition (per point, N>1 only) + if n > 1 { + *rc_map.entry(limb_bits).or_default() += n_points * 4 * n; + } + + // Scalar relation (per point) + for (bits, count) in &sr_range_checks { *rc_map.entry(*bits).or_default() += n_points * count; } - // 4. Accumulation range checks: n_points point_adds + 1 offset subtraction - // point_add (multi-point only) - if n_points > 1 { - let accum_point_adds = n_points + 1; // loop + offset subtraction - for &(op, n_ops) in &[ - (FieldOpType::Add, 2usize), - (FieldOpType::Sub, 2usize), - (FieldOpType::Mul, 3usize), - (FieldOpType::Inv, 1usize), - ] { - for (bits, count) in - range_checks_per_op(num_limbs, op, false, limb_bits, curve_modulus_bits) - { - *rc_map.entry(bits).or_default() += accum_point_adds * n_ops * count; - } - } - } + // Accumulation: (n_points + 1) point_adds (1 add, 5 sub, 3 mul, 1 inv each) + add_field_op_range_checks( + (n_points + 1) * 1, + (n_points + 1) * 5, + (n_points + 1) * 3, + (n_points + 1) * 1, + n, + limb_bits, + curve_modulus_bits, + false, + &mut rc_map, + ); - // 5. Compute resolution cost let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); - n_points * per_point + accum + range_check_cost + n_points * per_point + shared_constants + accum + range_check_cost } -/// Total estimated witness cost for a native-field MSM using hint-verified EC -/// ops with signed-bit wNAF (w=1). 
-/// -/// The native path replaces expensive field inversions with prover hints -/// verified via raw R1CS constraints: -/// - `point_double_verified_native`: 4W (3 hint + 1 product) vs 12W generic -/// - `point_add_verified_native`: 3W (3 hint) vs 8W generic -/// - `verify_on_curve_native`: 2W (2 products) vs 6W generic -/// - No multi-limb arithmetic for EC ops → zero EC-related range checks +/// Native-field MSM cost: hint-verified EC ops with signed-bit wNAF (w=1). /// -/// Uses signed-bit wNAF (w=1): every digit is non-zero (±1), so we always -/// add — no conditional skip selects. +/// The native path uses prover hints verified via raw R1CS constraints: +/// - `point_double_verified_native`: 4W (3 hint + 1 product) +/// - `point_add_verified_native`: 3W (3 hint) +/// - `verify_on_curve_native`: 2W (2 products) +/// - No multi-limb arithmetic → zero EC-related range checks /// -/// For n_points >= 2, uses merged-loop optimization: all points share a -/// single doubling per bit, saving 4W × (n-1) per bit. -/// Per bit (merged): 4W (shared double) + n × 8W (2×(1W select + 3W add)). -/// Skew correction: n × 10W. +/// Uses merged-loop optimization: all points share a single doubling per bit. 
fn calculate_msm_witness_cost_native( native_field_bits: u32, n_points: usize, scalar_bits: usize, - _window_size: usize, ) -> usize { let half_bits = (scalar_bits + 1) / 2; - // === Costs that are always per-point === - let on_curve = 2 * 2; // 2 × verify_on_curve_native (2W each) - let glv_hint = 4; // FakeGLVHint (s1, s2, neg1, neg2) - let scalar_bit_decomp = 2 * (half_bits + 1); // signed-bit hint witnesses - let y_negate = 2 + 2 + 2; // 2 neg_y + 2 py_eff + 2 neg_py_eff - let detect_skip_cost = 8; // 2×is_zero(3) + product(1) + boolean_or(1) - let sanitize_cost = 3 + 6; // 3 constants + 6 selects - let ec_hint = if n_points > 1 { 2 } else { 0 }; // EcScalarMulHint - let scalar_relation = count_scalar_relation_witnesses(native_field_bits, scalar_bits); - - let per_point_fixed = on_curve + // === Per-point fixed costs === + let on_curve = 4; // 2 × verify_on_curve_native (2W each) + let glv_hint = 4; // s1, s2, neg1, neg2 + let scalar_bits_cost = 2 * (half_bits + 1); // 2 × (half_bits + skew) + let y_negate = 6; // 2 × 3W (neg_y, y_eff, neg_y_eff) + let detect_skip = 8; // 2×is_zero(3W) + product(1W) + or(1W) + let sanitize = 4; // 4 select_witness + let ec_hint = 4; // 2W hint + 2W selects + let (sr_wit, sr_rc) = scalar_relation_cost(native_field_bits, scalar_bits); + + let per_point = on_curve + glv_hint - + scalar_bit_decomp + + scalar_bits_cost + y_negate - + detect_skip_cost - + sanitize_cost - + scalar_relation - + ec_hint; - - // === EC loop + skew + constants === - let inline_total = if n_points == 1 { - // Single-point: separate loop (unchanged path) - let ec_wit = half_bits * 12; - let skew_correction = 10; - let offset_const = 2; - let identity_const = 2; - per_point_fixed + ec_wit + skew_correction + offset_const + identity_const - } else { - // Multi-point: merged loop with shared doubling - // Per bit: 4W (shared double) + n_points × 8W (2×(1W select + 3W add)) - let ec_wit = half_bits * (4 + 8 * n_points); - // Skew correction: 10W per point - let 
skew_correction = n_points * 10; - // Offset and identity constants are shared (not per-point) - let offset_const = 2; - let identity_const = 2; - n_points * per_point_fixed + ec_wit + skew_correction + offset_const + identity_const - }; - - // === Point accumulation (multi-point only) === - let accum = if n_points > 1 { - // Initial accumulator: 2W (constant witnesses for offset x,y) - let acc_init = 2; - // Per point: point_add_verified_native (3W) + 2 skip selects (2W) - let per_point_accum = n_points * (3 + 2); - // all_skipped tracking: (n_points - 1) product witnesses - let all_skipped = n_points - 1; - // Offset subtraction: 3 constants + 2 selects + point_add (3W) + 2 mask selects - let offset_sub = 3 + 2 + 3 + 2; - - acc_init + per_point_accum + all_skipped + offset_sub - } else { - 0 - }; - - // === Range check cost === - // Native EC ops produce NO range checks (no multi-limb arithmetic). - // Only scalar relation produces range checks. + + detect_skip + + sanitize + + ec_hint + + sr_wit; + + // === Shared constants === + let shared_constants = 5; // gen_x, gen_y, zero, offset_x, offset_y + + // === EC verification loop (merged, shared doubling) === + // Per bit: 4W (shared double) + n_points × 8W (2×(1W select + 3W add)) + let ec_loop = half_bits * (4 + 8 * n_points); + // Skew correction: 2 branches × (3W add + 2W select) = 10W per point + let skew = n_points * 10; + + // === Point accumulation === + let accum = 2 // initial acc constants + + n_points * 5 // add(3W) + skip_select(2W) + + n_points.saturating_sub(1) // all_skipped products + + 10; // offset sub: 3 const + 2 sel + 3 add + 2 mask + + // === Range checks (only from scalar relation for native) === let mut rc_map: BTreeMap = BTreeMap::new(); - let sr_checks = count_scalar_relation_range_checks(native_field_bits, scalar_bits); - for (bits, count) in &sr_checks { + for (bits, count) in &sr_rc { *rc_map.entry(*bits).or_default() += n_points * count; } let range_check_cost = 
crate::range_check::estimate_range_check_cost(&rc_map); - inline_total + accum + range_check_cost + n_points * per_point + shared_constants + ec_loop + skew + accum + range_check_cost } /// Picks the widest limb size for scalar-relation multi-limb arithmetic that /// fits inside the native field without overflow. /// -/// Searches for the minimum number of limbs N (starting from 1) such that -/// the schoolbook column equations don't overflow the native field. Fewer -/// limbs means wider limbs, which means fewer witnesses and range checks. -/// /// For BN254 (254-bit native field, ~254-bit order): N=3 @ 85-bit limbs. /// For small curves where half_scalar × full_scalar fits natively: N=1. pub(super) fn scalar_relation_limb_bits(native_field_bits: u32, order_bits: usize) -> u32 { let half_bits = (order_bits + 1) / 2; - // N=1 is valid only if the mul product (half_scalar * full_scalar) - // fits in the native field without wrapping. + // N=1 is valid only if mul product fits in the native field. if half_bits + order_bits < native_field_bits as usize { return order_bits as u32; } - // For N>=2: find minimum N where schoolbook column equations fit. for n in 2..=super::MAX_LIMBS { let lb = ((order_bits + n - 1) / n) as u32; if column_equation_fits_native_field(native_field_bits, lb, n) { @@ -538,27 +353,9 @@ pub(super) fn scalar_relation_limb_bits(native_field_bits: u32, order_bits: usiz /// Check whether schoolbook column equation values fit in the native field. /// -/// In `mul_mod_p_multi`, the schoolbook multiplication verifies `a·b = p·q + r` -/// via column equations that include product sums, carry offsets, and outgoing -/// carries. Both sides of each column equation must evaluate to less than the -/// native field modulus as **integers** — if they overflow, the field's modular -/// reduction makes `LHS ≡ RHS (mod p)` weaker than `LHS = RHS`, breaking -/// soundness. 
-/// -/// The maximum integer value across either side of any column equation is -/// bounded by: -/// -/// `2^(2W + ceil(log2(N)) + 3)` -/// -/// where `W = limb_bits` and `N = num_limbs`. This accounts for: -/// - Up to N cross-products per column, each < `2^(2W)` -/// - The carry offset `2^(2W + ceil(log2(N)) + 1)` (dominant term) -/// - Outgoing carry term `2^W * offset_carry` on the RHS -/// -/// Since the native field modulus satisfies `p >= 2^(native_field_bits - 1)`, -/// the conservative soundness condition is: -/// -/// `2 * limb_bits + ceil(log2(num_limbs)) + 3 < native_field_bits` +/// The maximum integer value in any column equation is bounded by +/// `2^(2W + ceil(log2(N)) + 3)` where W = limb_bits, N = num_limbs. +/// This must be less than the native field modulus (~`2^native_field_bits`). pub fn column_equation_fits_native_field( native_field_bits: u32, limb_bits: u32, @@ -573,13 +370,8 @@ pub fn column_equation_fits_native_field( /// Search for optimal (limb_bits, window_size) minimizing witness cost. /// -/// Searches limb_bits ∈ [8..max] and window_size ∈ [2..8]. -/// Each candidate is checked for column equation soundness: the schoolbook -/// multiplication's intermediate values must fit in the native field without -/// modular wraparound (see [`column_equation_fits_native_field`]). -/// -/// `is_native` should come from `CurveParams::is_native_field()` which -/// compares actual modulus values, not just bit widths. +/// Searches limb_bits ∈ \[8..max\] and window_size ∈ \[2..8\]. +/// Each candidate is checked for column equation soundness. 
pub fn get_optimal_msm_params( native_field_bits: u32, curve_modulus_bits: u32, @@ -588,24 +380,11 @@ pub fn get_optimal_msm_params( is_native: bool, ) -> (u32, usize) { if is_native { - let mut best_cost = usize::MAX; - let mut best_window = 4; - for ws in 2..=8 { - let cost = calculate_msm_witness_cost( - native_field_bits, - curve_modulus_bits, - n_points, - scalar_bits, - ws, - native_field_bits, - true, - ); - if cost < best_cost { - best_cost = cost; - best_window = ws; - } - } - return (native_field_bits, best_window); + // Native path uses signed-bit wNAF (w=1), no limb decomposition. + // Window size is unused; return a default. + let cost = calculate_msm_witness_cost_native(native_field_bits, n_points, scalar_bits); + let _ = cost; // cost is the same regardless of window_size + return (native_field_bits, 4); } let max_limb_bits = (native_field_bits.saturating_sub(4)) / 2; @@ -614,7 +393,7 @@ pub fn get_optimal_msm_params( let mut best_window = 4; for lb in 8..=max_limb_bits { - let num_limbs = ((curve_modulus_bits as usize) + (lb as usize) - 1) / (lb as usize); + let num_limbs = ceil_div(curve_modulus_bits as usize, lb as usize); if !column_equation_fits_native_field(native_field_bits, lb, num_limbs) { continue; } @@ -653,7 +432,7 @@ mod tests { #[test] fn test_optimal_params_secp256r1() { let (limb_bits, window_size) = get_optimal_msm_params(254, 256, 1, 256, false); - let num_limbs = ((256 + limb_bits - 1) / limb_bits) as usize; + let num_limbs = ceil_div(256, limb_bits as usize); assert!( column_equation_fits_native_field(254, limb_bits, num_limbs), "optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" @@ -664,7 +443,7 @@ mod tests { #[test] fn test_optimal_params_goldilocks() { let (limb_bits, window_size) = get_optimal_msm_params(254, 64, 1, 64, false); - let num_limbs = ((64 + limb_bits - 1) / limb_bits) as usize; + let num_limbs = ceil_div(64, limb_bits as usize); assert!( column_equation_fits_native_field(254, limb_bits, num_limbs), 
"optimizer selected unsound limb_bits={limb_bits} (N={num_limbs})" @@ -689,14 +468,17 @@ mod tests { } #[test] - fn test_scalar_relation_witnesses_grumpkin() { - let sr = count_scalar_relation_witnesses(254, 256); + fn test_scalar_relation_cost_grumpkin() { + let (sr, rc) = scalar_relation_cost(254, 256); assert!(sr > 50 && sr < 200, "unexpected scalar_relation={sr}"); + let total_rc: usize = rc.values().sum(); + assert!(total_rc > 30, "too few range checks: {total_rc}"); + assert!(total_rc < 200, "too many range checks: {total_rc}"); } #[test] - fn test_scalar_relation_witnesses_small_curve() { - let sr = count_scalar_relation_witnesses(254, 64); + fn test_scalar_relation_cost_small_curve() { + let (sr, _) = scalar_relation_cost(254, 64); assert!( sr < 100, "64-bit curve scalar_relation={sr} should be < 100" @@ -704,72 +486,11 @@ mod tests { } #[test] - fn test_is_zero_cost_independent_of_num_limbs() { - let (_, _, _, _, n_is_zero_w4, ..) = count_glv_field_ops(128, 4); - let (_, _, _, _, n_is_zero_w3, ..) = count_glv_field_ops(128, 3); - assert!(n_is_zero_w4 > 0); - assert!(n_is_zero_w3 > 0); - } - - #[test] - fn test_inv_single_limb_witness_count() { + fn test_field_op_witnesses_single_limb() { // inv_mod_p_single: a_inv(1) + mul_mod_p_single(5) = 6 - assert_eq!(witnesses_per_op(1, FieldOpType::Inv, false), 6); - } - - #[test] - fn test_selects_counted_separately() { - // Verify selects are returned as separate counts, not mixed into - // field ops. - let (_, _, _, _, _, pt_sel, coord_sel) = count_glv_field_ops(128, 4); - assert!(pt_sel > 0, "expected point selects > 0"); - assert_eq!(coord_sel, 2, "expected 2 coord selects (y-negation)"); - } - - #[test] - fn test_select_cost_scales_with_num_limbs() { - // For N=3, select cost should be 2*N per point select (1 witness - // per limb per coordinate, inlined select_witness). 
- let half_bits = 129; - let (_, _, _, _, _, n_pt_sel, n_coord_sel) = count_glv_field_ops(half_bits, 4); - let select_cost_n1 = n_pt_sel * 2 * 1 + n_coord_sel * 1; - let select_cost_n3 = n_pt_sel * 2 * 3 + n_coord_sel * 3; - // N=3 should be exactly 3× N=1 for selects (linear in num_limbs) - assert_eq!(select_cost_n3, select_cost_n1 * 3); - } - - #[test] - fn test_range_checks_per_op_native() { - assert!(range_checks_per_op(1, FieldOpType::Add, true, 254, 254).is_empty()); - assert!(range_checks_per_op(1, FieldOpType::Mul, true, 254, 254).is_empty()); - assert!(range_checks_per_op(1, FieldOpType::Inv, true, 254, 254).is_empty()); - } - - #[test] - fn test_range_checks_per_op_single_limb() { - let rc = range_checks_per_op(1, FieldOpType::Add, false, 64, 64); - assert_eq!(rc, vec![(64, 1)]); - let rc = range_checks_per_op(1, FieldOpType::Inv, false, 64, 64); - assert_eq!(rc, vec![(64, 2)]); - } - - #[test] - fn test_range_checks_per_op_multi_limb() { - // N=3, limb_bits=86: carry_bits = 86 + ceil(log2(3)) + 2 = 90 - let rc = range_checks_per_op(3, FieldOpType::Add, false, 86, 256); - assert_eq!(rc, vec![(86, 6)]); - let rc = range_checks_per_op(3, FieldOpType::Mul, false, 86, 256); - assert_eq!(rc, vec![(86, 9), (90, 4)]); - let rc = range_checks_per_op(3, FieldOpType::Inv, false, 86, 256); - assert_eq!(rc, vec![(86, 12), (90, 4)]); - } - - #[test] - fn test_scalar_relation_range_checks_256bit() { - let rc = count_scalar_relation_range_checks(254, 256); - let total: usize = rc.values().sum(); - assert!(total > 30, "too few range checks: {total}"); - assert!(total < 200, "too many range checks: {total}"); + assert_eq!(field_op_witnesses(0, 0, 0, 1, 1, false), 6); + // add_mod_p_single: 5 + assert_eq!(field_op_witnesses(1, 0, 0, 0, 1, false), 5); } #[test] diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index ba70b8574..1aa176f8b 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ 
b/provekit/r1cs-compiler/src/msm/curve.rs @@ -1,5 +1,5 @@ use { - ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField}, + ark_ff::{Field, PrimeField}, provekit_common::FieldElement, }; @@ -571,7 +571,6 @@ mod tests { } } -#[allow(dead_code)] pub fn secp256r1_params() -> CurveParams { CurveParams { field_modulus_p: [ diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 975963ed6..c3deeff61 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -1,5 +1,5 @@ use { - super::{select_witness, FieldOps}, + super::{multi_limb_ops::MultiLimbOps, Limbs}, crate::noir_to_r1cs::NoirToR1CSCompiler, provekit_common::{witness::WitnessBuilder, FieldElement}, }; @@ -19,7 +19,7 @@ use { /// verify 0 * inv = 1 mod p). The caller must check y1 = 0 using /// compute_is_zero and conditionally select the point-at-infinity /// result before calling this function. -pub fn point_double(ops: &mut F, x1: F::Elem, y1: F::Elem) -> (F::Elem, F::Elem) { +pub fn point_double(ops: &mut MultiLimbOps, x1: Limbs, y1: Limbs) -> (Limbs, Limbs) { let a = ops.curve_a(); // Computing numerator = 3 * x1^2 + a @@ -63,13 +63,13 @@ pub fn point_double(ops: &mut F, x1: F::Elem, y1: F::Elem) -> (F::E /// This function does NOT handle either case — the constraint system /// will be unsatisfiable if x1 = x2. The caller must detect this /// and branch accordingly. -pub fn point_add( - ops: &mut F, - x1: F::Elem, - y1: F::Elem, - x2: F::Elem, - y2: F::Elem, -) -> (F::Elem, F::Elem) { +pub fn point_add( + ops: &mut MultiLimbOps, + x1: Limbs, + y1: Limbs, + x2: Limbs, + y2: Limbs, +) -> (Limbs, Limbs) { // Computing lambda = (y2 - y1) / (x2 - x1) let numerator = ops.sub(y2, y1); let denominator = ops.sub(x2, x1); @@ -91,12 +91,12 @@ pub fn point_add( /// Conditional point select without boolean constraint on `flag`. /// Caller must ensure `flag` is already constrained boolean. 
-pub fn point_select_unchecked( - ops: &mut F, +pub fn point_select_unchecked( + ops: &mut MultiLimbOps, flag: usize, - on_false: (F::Elem, F::Elem), - on_true: (F::Elem, F::Elem), -) -> (F::Elem, F::Elem) { + on_false: (Limbs, Limbs), + on_true: (Limbs, Limbs), +) -> (Limbs, Limbs) { let x = ops.select_unchecked(flag, on_false.0, on_true.0); let y = ops.select_unchecked(flag, on_false.1, on_true.1); (x, y) @@ -106,12 +106,12 @@ pub fn point_select_unchecked( /// /// T\[0\] = P (dummy entry, used when window digit = 0) /// T\[1\] = P, T\[2\] = 2P, T\[i\] = T\[i-1\] + P for i >= 3. -fn build_point_table( - ops: &mut F, - px: F::Elem, - py: F::Elem, +fn build_point_table( + ops: &mut MultiLimbOps, + px: Limbs, + py: Limbs, table_size: usize, -) -> Vec<(F::Elem, F::Elem)> { +) -> Vec<(Limbs, Limbs)> { assert!(table_size >= 2); let mut table = Vec::with_capacity(table_size); table.push((px, py)); // T[0] = P (dummy) @@ -135,13 +135,13 @@ fn build_point_table( /// /// Each bit is constrained boolean exactly once, then all subsequent selects /// on that bit use the unchecked variant. -fn table_lookup( - ops: &mut F, - table: &[(F::Elem, F::Elem)], +fn table_lookup( + ops: &mut MultiLimbOps, + table: &[(Limbs, Limbs)], bits: &[usize], -) -> (F::Elem, F::Elem) { +) -> (Limbs, Limbs) { assert_eq!(table.len(), 1 << bits.len()); - let mut current: Vec<(F::Elem, F::Elem)> = table.to_vec(); + let mut current: Vec<(Limbs, Limbs)> = table.to_vec(); // Process bits from MSB to LSB for &bit in bits.iter().rev() { ops.constrain_flag(bit); // constrain boolean once per bit @@ -175,21 +175,21 @@ fn table_lookup( /// 5. point_add(acc, T_R\[d2\]) + is_zero(d2) + point_select /// /// Returns the final accumulator (x, y). 
-pub fn scalar_mul_glv( - ops: &mut F, +pub fn scalar_mul_glv( + ops: &mut MultiLimbOps, // Point P (table 1) - px: F::Elem, - py: F::Elem, + px: Limbs, + py: Limbs, s1_bits: &[usize], // 128 bit witnesses for |s1| // Point R (table 2) — the claimed output - rx: F::Elem, - ry: F::Elem, + rx: Limbs, + ry: Limbs, s2_bits: &[usize], // 128 bit witnesses for |s2| // Shared parameters window_size: usize, - offset_x: F::Elem, - offset_y: F::Elem, -) -> (F::Elem, F::Elem) { + offset_x: Limbs, + offset_y: Limbs, +) -> (Limbs, Limbs) { let n1 = s1_bits.len(); let n2 = s2_bits.len(); assert_eq!(n1, n2, "s1 and s2 must have the same number of bits"); @@ -197,8 +197,6 @@ pub fn scalar_mul_glv( let w = window_size; let table_size = 1 << w; - // TODO : implement lazy overflow as used in gnark. - // Build point tables: T_P[i] = [i]P, T_R[i] = [i]R let table_p = build_point_table(ops, px, py, table_size); let table_r = build_point_table(ops, rx, ry, table_size); diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index c52a604db..d39ac6572 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -1,18 +1,24 @@ -pub mod cost_model; -pub mod curve; -pub mod ec_points; -pub mod multi_limb_arith; -pub mod multi_limb_ops; +pub(crate) mod cost_model; +pub(crate) mod curve; +pub(crate) mod ec_points; +pub(crate) mod multi_limb_arith; +pub(crate) mod multi_limb_ops; mod native; mod non_native; mod scalar_relation; use { - crate::{msm::multi_limb_arith::compute_is_zero, noir_to_r1cs::NoirToR1CSCompiler}, - ark_ff::{AdditiveGroup, Field, PrimeField}, + crate::{ + constraint_helpers::{ + add_constant_witness, compute_boolean_or, constrain_boolean, select_witness, + }, + msm::multi_limb_arith::compute_is_zero, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{Field, PrimeField}, curve::CurveParams, provekit_common::{ - witness::{ConstantOrR1CSWitness, ConstantTerm, SumTerm, WitnessBuilder}, + 
witness::{ConstantOrR1CSWitness, SumTerm, WitnessBuilder}, FieldElement, }, std::collections::BTreeMap, @@ -28,9 +34,9 @@ pub const MAX_LIMBS: usize = 32; /// A fixed-capacity array of witness indices, indexed by limb position. /// -/// This type is `Copy`, so it can be used as `FieldOps::Elem` without -/// requiring const generics or dispatch macros. The runtime `len` field -/// tracks how many limbs are actually in use. +/// This type is `Copy`, so it can be passed by value without requiring +/// const generics or dispatch macros. The runtime `len` field tracks how +/// many limbs are actually in use. #[derive(Clone, Copy)] pub struct Limbs { data: [usize; MAX_LIMBS], @@ -114,125 +120,9 @@ impl std::ops::IndexMut for Limbs { } // --------------------------------------------------------------------------- -// FieldOps trait -// --------------------------------------------------------------------------- - -pub trait FieldOps { - type Elem: Copy; - - fn add(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; - fn sub(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; - fn mul(&mut self, a: Self::Elem, b: Self::Elem) -> Self::Elem; - fn inv(&mut self, a: Self::Elem) -> Self::Elem; - fn curve_a(&mut self) -> Self::Elem; - - /// Constrains `flag` to be boolean (`flag * flag = flag`). - fn constrain_flag(&mut self, flag: usize); - - /// Conditional select without boolean constraint on `flag`. - /// Caller must ensure `flag` is already constrained boolean. - fn select_unchecked( - &mut self, - flag: usize, - on_false: Self::Elem, - on_true: Self::Elem, - ) -> Self::Elem; - - /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if - /// `flag` is 0. Constrains `flag` to be boolean (`flag * flag = flag`). - fn select(&mut self, flag: usize, on_false: Self::Elem, on_true: Self::Elem) -> Self::Elem { - self.constrain_flag(flag); - self.select_unchecked(flag, on_false, on_true) - } - - /// Checks if a native witness value is zero. 
- /// Returns a boolean witness: 1 if zero, 0 if non-zero. - fn is_zero(&mut self, value: usize) -> usize; - - /// Packs bit witnesses into a single digit witness: `d = Σ bits[i] * 2^i`. - /// Does NOT constrain bits to be boolean — caller must ensure that. - fn pack_bits(&mut self, bits: &[usize]) -> usize; - - /// Returns a constant field element from its limb decomposition. - fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Self::Elem; -} - -// --------------------------------------------------------------------------- -// Private helpers +// Private helpers (MSM-specific) // --------------------------------------------------------------------------- -/// Constrains `flag` to be boolean: `flag * flag = flag`. -pub(crate) fn constrain_boolean(compiler: &mut NoirToR1CSCompiler, flag: usize) { - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, flag)], - &[(FieldElement::ONE, flag)], - &[(FieldElement::ONE, flag)], - ); -} - -/// Single-witness conditional select: `out = on_false + flag * (on_true - -/// on_false)`. -/// -/// Uses a single witness + single R1CS constraint: -/// flag * (on_true - on_false) = result - on_false -pub(crate) fn select_witness( - compiler: &mut NoirToR1CSCompiler, - flag: usize, - on_false: usize, - on_true: usize, -) -> usize { - // When both branches are the same witness, result is trivially that witness. - if on_false == on_true { - return on_false; - } - let result = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::SelectWitness { - output: result, - flag, - on_false, - on_true, - }); - // flag * (on_true - on_false) = result - on_false - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, flag)], - &[(FieldElement::ONE, on_true), (-FieldElement::ONE, on_false)], - &[(FieldElement::ONE, result), (-FieldElement::ONE, on_false)], - ); - result -} - -/// Packs bit witnesses into a digit: `d = Σ bits[i] * 2^i`. 
-pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize]) -> usize { - let terms: Vec = bits - .iter() - .enumerate() - .map(|(i, &bit)| SumTerm(Some(FieldElement::from(1u128 << i)), bit)) - .collect(); - compiler.add_sum(terms) -} - -/// Computes `a OR b` for two boolean witnesses: `1 - (1 - a)(1 - b)`. -/// Does NOT constrain a or b to be boolean — caller must ensure that. -/// -/// Uses a single witness + single R1CS constraint: -/// (1 - a) * (1 - b) = 1 - result -fn compute_boolean_or(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) -> usize { - let one = compiler.witness_one(); - let result = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::BooleanOr { - output: result, - a, - b, - }); - // (1 - a) * (1 - b) = 1 - result - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, one), (-FieldElement::ONE, a)], - &[(FieldElement::ONE, one), (-FieldElement::ONE, b)], - &[(FieldElement::ONE, one), (-FieldElement::ONE, result)], - ); - result -} - /// Detects whether a point-scalar pair is degenerate (scalar=0 or point at /// infinity). Constrains `inf_flag` to boolean. Returns `is_skip` (1 if /// degenerate). @@ -282,16 +172,6 @@ fn sanitize_point_scalar( } } -/// Constrains `a * b = 0`. -fn constrain_product_zero(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, a)], &[(FieldElement::ONE, b)], &[( - FieldElement::ZERO, - compiler.witness_one(), - )]); -} - /// Negate a y-coordinate and conditionally select based on a sign flag. /// Returns `(y_eff, neg_y_eff)` where: /// - if `neg_flag=0`: `y_eff = y`, `neg_y_eff = -y` @@ -430,52 +310,22 @@ fn add_single_msm( (curve.modulus_bits() as usize + limb_bits as usize - 1) / limb_bits as usize }; - process_single_msm( - compiler, - &point_wits, - &scalar_wits, - outputs, - num_limbs, - limb_bits, - window_size, - range_checks, - curve, - ); -} - -/// Process a full single-MSM with runtime `num_limbs`. 
-/// -/// Dispatches to single-point or multi-point path based on the number of -/// input points. -fn process_single_msm( - compiler: &mut NoirToR1CSCompiler, - point_wits: &[usize], - scalar_wits: &[usize], - outputs: (usize, usize, usize), - num_limbs: usize, - limb_bits: u32, - window_size: usize, - range_checks: &mut BTreeMap>, - curve: &CurveParams, -) { let n_points = point_wits.len() / 3; - if n_points == 1 { - process_single_point_msm( + if curve.is_native_field() { + native::process_multi_point_native( compiler, - point_wits, - scalar_wits, + &point_wits, + &scalar_wits, outputs, - num_limbs, - limb_bits, - window_size, + n_points, range_checks, curve, ); } else { - process_multi_point_msm( + non_native::process_multi_point_non_native( compiler, - point_wits, - scalar_wits, + &point_wits, + &scalar_wits, outputs, n_points, num_limbs, @@ -487,148 +337,6 @@ fn process_single_msm( } } -/// Single-point MSM: R = \[s\]P with degenerate-case handling. -/// -/// The ACIR output (out_x, out_y) is the result directly. Sanitizes inputs -/// to handle scalar=0 and point-at-infinity, then verifies via FakeGLV. 
-fn process_single_point_msm<'a>( - mut compiler: &'a mut NoirToR1CSCompiler, - point_wits: &[usize], - scalar_wits: &[usize], - outputs: (usize, usize, usize), - num_limbs: usize, - limb_bits: u32, - window_size: usize, - range_checks: &'a mut BTreeMap>, - curve: &CurveParams, -) { - let (out_x, out_y, out_inf) = outputs; - - // Allocate constants - let one = compiler.witness_one(); - let gen_x_witness = - add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.0)); - let gen_y_witness = - add_constant_witness(compiler, curve::curve_native_point_fe(&curve.generator.1)); - let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); - - // Sanitize inputs: swap in generator G and scalar=1 when degenerate - let san = sanitize_point_scalar( - compiler, - point_wits[0], - point_wits[1], - scalar_wits[0], - scalar_wits[1], - point_wits[2], - gen_x_witness, - gen_y_witness, - zero_witness, - one, - ); - - // Sanitize R (output point): when is_skip=1, R must be G (since [1]*G = G) - let sanitized_rx = select_witness(compiler, san.is_skip, out_x, gen_x_witness); - let sanitized_ry = select_witness(compiler, san.is_skip, out_y, gen_y_witness); - - if curve.is_native_field() { - // Native-field optimized path: hint-verified EC + wNAF - native::verify_point_fakeglv_native( - compiler, - range_checks, - san.px, - san.py, - sanitized_rx, - sanitized_ry, - san.s_lo, - san.s_hi, - curve, - ); - } else { - // Generic multi-limb path - let (px, py) = non_native::decompose_point_to_limbs( - compiler, - san.px, - san.py, - num_limbs, - limb_bits, - range_checks, - ); - let (rx, ry) = non_native::decompose_point_to_limbs( - compiler, - sanitized_rx, - sanitized_ry, - num_limbs, - limb_bits, - range_checks, - ); - (compiler, _) = non_native::verify_point_fakeglv( - compiler, - range_checks, - px, - py, - rx, - ry, - san.s_lo, - san.s_hi, - num_limbs, - limb_bits, - window_size, - curve, - ); - } - - // Mask output: when is_skip, output must be (0, 0, 1) 
- constrain_equal(compiler, out_inf, san.is_skip); - constrain_product_zero(compiler, san.is_skip, out_x); - constrain_product_zero(compiler, san.is_skip, out_y); -} - -/// Multi-point MSM: computes R_i = \[s_i\]P_i via hints, verifies each with -/// FakeGLV, then accumulates R_i's with offset-based accumulation and skip -/// handling. -/// -/// When `curve.is_native_field()`, uses a merged-loop optimization: all -/// points share a single doubling per bit, saving 4*(n-1) constraints per -/// bit of the half-scalar (≈512 for 2 points on Grumpkin). -fn process_multi_point_msm( - compiler: &mut NoirToR1CSCompiler, - point_wits: &[usize], - scalar_wits: &[usize], - outputs: (usize, usize, usize), - n_points: usize, - num_limbs: usize, - limb_bits: u32, - window_size: usize, - range_checks: &mut BTreeMap>, - curve: &CurveParams, -) { - if curve.is_native_field() { - native::process_multi_point_native( - compiler, - point_wits, - scalar_wits, - outputs, - n_points, - range_checks, - curve, - ); - return; - } - - non_native::process_multi_point_non_native( - compiler, - point_wits, - scalar_wits, - outputs, - n_points, - num_limbs, - limb_bits, - window_size, - range_checks, - curve, - ); -} - /// Allocates a FakeGLV hint and returns `(s1, s2, neg1, neg2)` witness indices. fn emit_fakeglv_hint( compiler: &mut NoirToR1CSCompiler, @@ -654,46 +362,110 @@ fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitnes } } -/// Creates a constant witness with the given value, pinned by an R1CS -/// constraint so that a malicious prover cannot set it to an arbitrary value. 
-fn add_constant_witness(compiler: &mut NoirToR1CSCompiler, value: FieldElement) -> usize { - let w = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::Constant(ConstantTerm(w, value))); - // Pin: 1 * w = value * 1 (embeds the constant into the constraint matrix) - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, compiler.witness_one())], - &[(FieldElement::ONE, w)], - &[(value, compiler.witness_one())], - ); - w -} - -/// Constrains a witness to equal a known constant value. -/// Uses the constant as an R1CS coefficient — no witness needed for the -/// expected value. Use this for identity checks where the witness must equal -/// a compile-time-known value. -fn constrain_to_constant(compiler: &mut NoirToR1CSCompiler, witness: usize, value: FieldElement) { - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, compiler.witness_one())], - &[(FieldElement::ONE, witness)], - &[(value, compiler.witness_one())], - ); -} +#[cfg(test)] +mod tests { + use {super::*, crate::noir_to_r1cs::NoirToR1CSCompiler}; + + /// Verify that the non-native (SECP256R1) single-point MSM path generates + /// constraints without panicking. 
This does multi-limb arithmetic, + /// range checks, and FakeGLV verification — the entire non-native code path + /// that has no Noir e2e coverage for now : ) + #[test] + fn test_secp256r1_single_point_msm_generates_constraints() { + let mut compiler = NoirToR1CSCompiler::new(); + let curve = curve::secp256r1_params(); + let mut range_checks: BTreeMap> = BTreeMap::new(); + + // Allocate witness slots for: px, py, inf, s_lo, s_hi, out_x, out_y, out_inf + // (witness 0 is the constant-one witness) + let base = compiler.num_witnesses(); + compiler.r1cs.add_witnesses(8); + let px = base; + let py = base + 1; + let inf = base + 2; + let s_lo = base + 3; + let s_hi = base + 4; + let out_x = base + 5; + let out_y = base + 6; + let out_inf = base + 7; + + let points = vec![ + ConstantOrR1CSWitness::Witness(px), + ConstantOrR1CSWitness::Witness(py), + ConstantOrR1CSWitness::Witness(inf), + ]; + let scalars = vec![ + ConstantOrR1CSWitness::Witness(s_lo), + ConstantOrR1CSWitness::Witness(s_hi), + ]; + let msm_ops = vec![(points, scalars, (out_x, out_y, out_inf))]; + + add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve); + + let n_constraints = compiler.r1cs.num_constraints(); + let n_witnesses = compiler.num_witnesses(); -/// Constrains two witnesses to be equal: `a - b = 0`. -fn constrain_equal(compiler: &mut NoirToR1CSCompiler, a: usize, b: usize) { - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, compiler.witness_one())], - &[(FieldElement::ONE, a), (-FieldElement::ONE, b)], - &[(FieldElement::ZERO, compiler.witness_one())], - ); -} + assert!( + n_constraints > 100, + "expected substantial constraints for non-native MSM, got {n_constraints}" + ); + assert!( + n_witnesses > 100, + "expected substantial witnesses for non-native MSM, got {n_witnesses}" + ); + assert!( + !range_checks.is_empty(), + "non-native MSM should produce range checks" + ); + } -/// Constrains a witness to be zero: `w = 0`. 
-fn constrain_zero(compiler: &mut NoirToR1CSCompiler, w: usize) { - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, compiler.witness_one())], - &[(FieldElement::ONE, w)], - &[(FieldElement::ZERO, compiler.witness_one())], - ); + /// Verify that the non-native multi-point MSM path (2 points, SECP256R1) + /// generates constraints. does the multi-point accumulation and offset + /// subtraction logic for the non-native path. + #[test] + fn test_secp256r1_multi_point_msm_generates_constraints() { + let mut compiler = NoirToR1CSCompiler::new(); + let curve = curve::secp256r1_params(); + let mut range_checks: BTreeMap> = BTreeMap::new(); + + // 2 points: px1, py1, inf1, px2, py2, inf2, s1_lo, s1_hi, s2_lo, s2_hi, + // out_x, out_y, out_inf + let base = compiler.num_witnesses(); + compiler.r1cs.add_witnesses(13); + + let points = vec![ + ConstantOrR1CSWitness::Witness(base), // px1 + ConstantOrR1CSWitness::Witness(base + 1), // py1 + ConstantOrR1CSWitness::Witness(base + 2), // inf1 + ConstantOrR1CSWitness::Witness(base + 3), // px2 + ConstantOrR1CSWitness::Witness(base + 4), // py2 + ConstantOrR1CSWitness::Witness(base + 5), // inf2 + ]; + let scalars = vec![ + ConstantOrR1CSWitness::Witness(base + 6), // s1_lo + ConstantOrR1CSWitness::Witness(base + 7), // s1_hi + ConstantOrR1CSWitness::Witness(base + 8), // s2_lo + ConstantOrR1CSWitness::Witness(base + 9), // s2_hi + ]; + let out_x = base + 10; + let out_y = base + 11; + let out_inf = base + 12; + + let msm_ops = vec![(points, scalars, (out_x, out_y, out_inf))]; + + add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve); + + let n_constraints = compiler.r1cs.num_constraints(); + let n_witnesses = compiler.num_witnesses(); + + // Multi-point should produce more constraints than single-point + assert!( + n_constraints > 200, + "expected substantial constraints for 2-point non-native MSM, got {n_constraints}" + ); + assert!( + n_witnesses > 200, + "expected substantial witnesses for 2-point non-native 
MSM, got {n_witnesses}" + ); + } } diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs index d16f26c13..a3d6707df 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -1,12 +1,14 @@ -//! `MultiLimbOps` — unified FieldOps implementation parameterized by runtime -//! limb count. +//! `MultiLimbOps` — field arithmetic parameterized by runtime limb count. //! -//! Uses `Limbs` (a fixed-capacity Copy type) as `FieldOps::Elem`, enabling -//! arbitrary limb counts without const generics or dispatch macros. +//! Uses `Limbs` (a fixed-capacity Copy type) as the element representation, +//! enabling arbitrary limb counts without const generics or dispatch macros. use { - super::{multi_limb_arith, FieldOps, Limbs}, - crate::noir_to_r1cs::NoirToR1CSCompiler, + super::{multi_limb_arith, Limbs}, + crate::{ + constraint_helpers::{constrain_boolean, pack_bits_helper, select_witness}, + noir_to_r1cs::NoirToR1CSCompiler, + }, ark_ff::{AdditiveGroup, Field}, provekit_common::{ witness::{ConstantTerm, SumTerm, WitnessBuilder}, @@ -30,6 +32,60 @@ pub struct MultiLimbParams { pub modulus_fe: Option, } +impl MultiLimbParams { + /// Build params for EC field operations (mod field_modulus_p). + pub fn for_field_modulus( + num_limbs: usize, + limb_bits: u32, + curve: &super::curve::CurveParams, + ) -> Self { + let is_native = curve.is_native_field(); + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let modulus_fe = if !is_native { + Some(curve.p_native_fe()) + } else { + None + }; + Self { + num_limbs, + limb_bits, + p_limbs: curve.p_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.p_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.field_modulus_p, + curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), + is_native, + modulus_fe, + } + } + + /// Build params for scalar relation verification (mod curve_order_n). 
+ pub fn for_curve_order( + num_limbs: usize, + limb_bits: u32, + curve: &super::curve::CurveParams, + ) -> Self { + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + let modulus_fe = if num_limbs == 1 { + Some(super::curve::curve_native_point_fe(&curve.curve_order_n)) + } else { + None + }; + Self { + num_limbs, + limb_bits, + p_limbs: curve.curve_order_n_limbs(limb_bits, num_limbs), + p_minus_1_limbs: curve.curve_order_n_minus_1_limbs(limb_bits, num_limbs), + two_pow_w, + modulus_raw: curve.curve_order_n, + curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused + is_native: false, /* always non-native for + * scalar relation */ + modulus_fe, + } + } +} + /// Unified field operations struct parameterized by runtime limb count. pub struct MultiLimbOps<'a, 'p> { pub compiler: &'a mut NoirToR1CSCompiler, @@ -56,12 +112,8 @@ impl MultiLimbOps<'_, '_> { let zero = self.constant_limbs(&zero_vals); self.sub(zero, value) } -} -impl FieldOps for MultiLimbOps<'_, '_> { - type Elem = Limbs; - - fn add(&mut self, a: Limbs, b: Limbs) -> Limbs { + pub fn add(&mut self, a: Limbs, b: Limbs) -> Limbs { debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs"); if self.is_native_single() { @@ -101,7 +153,7 @@ impl FieldOps for MultiLimbOps<'_, '_> { } } - fn sub(&mut self, a: Limbs, b: Limbs) -> Limbs { + pub fn sub(&mut self, a: Limbs, b: Limbs) -> Limbs { debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs"); if self.is_native_single() { @@ -142,7 +194,7 @@ impl FieldOps for MultiLimbOps<'_, '_> { } } - fn mul(&mut self, a: Limbs, b: Limbs) -> Limbs { + pub fn mul(&mut self, a: Limbs, b: Limbs) -> Limbs { debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); debug_assert_eq!(b.len(), self.n(), "b.len() != num_limbs"); if self.is_native_single() { @@ -173,7 +225,7 @@ impl FieldOps for MultiLimbOps<'_, '_> { } } - fn 
inv(&mut self, a: Limbs) -> Limbs { + pub fn inv(&mut self, a: Limbs) -> Limbs { debug_assert_eq!(a.len(), self.n(), "a.len() != num_limbs"); if self.is_native_single() { let a_inv = self.compiler.num_witnesses(); @@ -205,7 +257,7 @@ impl FieldOps for MultiLimbOps<'_, '_> { } } - fn curve_a(&mut self) -> Limbs { + pub fn curve_a(&mut self) -> Limbs { let n = self.n(); let mut out = Limbs::new(n); for i in 0..n { @@ -224,28 +276,43 @@ impl FieldOps for MultiLimbOps<'_, '_> { out } - fn constrain_flag(&mut self, flag: usize) { - super::constrain_boolean(self.compiler, flag); + /// Constrains `flag` to be boolean (`flag * flag = flag`). + pub fn constrain_flag(&mut self, flag: usize) { + constrain_boolean(self.compiler, flag); } - fn select_unchecked(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { + /// Conditional select without boolean constraint on `flag`. + /// Caller must ensure `flag` is already constrained boolean. + pub fn select_unchecked(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { let n = self.n(); let mut out = Limbs::new(n); for i in 0..n { - out[i] = super::select_witness(self.compiler, flag, on_false[i], on_true[i]); + out[i] = select_witness(self.compiler, flag, on_false[i], on_true[i]); } out } - fn is_zero(&mut self, value: usize) -> usize { + /// Conditional select: returns `on_true` if `flag` is 1, `on_false` if + /// `flag` is 0. Constrains `flag` to be boolean. + pub fn select(&mut self, flag: usize, on_false: Limbs, on_true: Limbs) -> Limbs { + self.constrain_flag(flag); + self.select_unchecked(flag, on_false, on_true) + } + + /// Checks if a native witness value is zero. + /// Returns a boolean witness: 1 if zero, 0 if non-zero. 
+ pub fn is_zero(&mut self, value: usize) -> usize { multi_limb_arith::compute_is_zero(self.compiler, value) } - fn pack_bits(&mut self, bits: &[usize]) -> usize { - super::pack_bits_helper(self.compiler, bits) + /// Packs bit witnesses into a single digit witness: `d = Σ bits[i] * 2^i`. + /// Does NOT constrain bits to be boolean — caller must ensure that. + pub fn pack_bits(&mut self, bits: &[usize]) -> usize { + pack_bits_helper(self.compiler, bits) } - fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Limbs { + /// Returns a constant field element from its limb decomposition. + pub fn constant_limbs(&mut self, limbs: &[FieldElement]) -> Limbs { let n = self.n(); assert_eq!( limbs.len(), diff --git a/provekit/r1cs-compiler/src/msm/native.rs b/provekit/r1cs-compiler/src/msm/native.rs index fb97884f4..7583977de 100644 --- a/provekit/r1cs-compiler/src/msm/native.rs +++ b/provekit/r1cs-compiler/src/msm/native.rs @@ -5,11 +5,16 @@ use { super::{ - add_constant_witness, constrain_boolean, constrain_equal, constrain_to_constant, curve, - ec_points, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, negate_y_signed_native, - sanitize_point_scalar, scalar_relation, select_witness, + curve, ec_points, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, + negate_y_signed_native, sanitize_point_scalar, scalar_relation, + }, + crate::{ + constraint_helpers::{ + add_constant_witness, constrain_boolean, constrain_equal, constrain_to_constant, + select_witness, + }, + noir_to_r1cs::NoirToR1CSCompiler, }, - crate::noir_to_r1cs::NoirToR1CSCompiler, ark_ff::{AdditiveGroup, Field}, curve::CurveParams, provekit_common::{witness::WitnessBuilder, FieldElement}, @@ -20,7 +25,7 @@ use { /// /// Holds the inputs needed by `scalar_mul_merged_native_wnaf` to process /// one point's P and R branches inside the shared-doubling loop. 
-pub(super) struct NativePointData { +struct NativePointData { px: usize, py_eff: usize, neg_py_eff: usize, @@ -33,95 +38,7 @@ pub(super) struct NativePointData { s2_skew: usize, } -/// Native-field FakeGLV verification using hint-verified EC ops. -/// -/// This path is used when `curve.is_native_field()` and replaces -/// `verify_point_fakeglv` for significant constraint savings. -/// -/// Key differences from the generic path: -/// - EC ops use hint-verified formulas (4W+4C per double vs ~12, 3W+3C per add -/// vs ~8) -/// - On-curve checks use raw constraints (2W+3C vs ~6W+6C) -/// - Still uses binary bit decomposition + windowed scalar mul (same table -/// structure) -pub(super) fn verify_point_fakeglv_native( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - px: usize, - py: usize, - rx: usize, - ry: usize, - s_lo: usize, - s_hi: usize, - curve: &CurveParams, -) { - // Step 1: On-curve checks for P and R (native) - ec_points::verify_on_curve_native(compiler, px, py, curve); - ec_points::verify_on_curve_native(compiler, rx, ry, curve); - - // Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 - let (s1_witness, s2_witness, neg1_witness, neg2_witness) = - emit_fakeglv_hint(compiler, s_lo, s_hi, curve); - - // Step 3: Signed-bit decomposition - let half_bits = curve.glv_half_bits() as usize; - let (s1_bits, s1_skew) = decompose_signed_bits(compiler, s1_witness, half_bits); - let (s2_bits, s2_skew) = decompose_signed_bits(compiler, s2_witness, half_bits); - - // Step 4: Conditionally negate y-coordinates - let (py_effective, neg_py_effective) = negate_y_signed_native(compiler, neg1_witness, py); - let (ry_effective, neg_ry_effective) = negate_y_signed_native(compiler, neg2_witness, ry); - - // Step 5: Scalar mul via merged loop (single-point = one-element slice) - let point_data = NativePointData { - px, - py_eff: py_effective, - neg_py_eff: neg_py_effective, - s1_bits, - s1_skew, - rx, - ry_eff: ry_effective, - neg_ry_eff: neg_ry_effective, - 
s2_bits, - s2_skew, - }; - let offset_x_fe = curve::curve_native_point_fe(&curve.offset_point.0); - let offset_y_fe = curve::curve_native_point_fe(&curve.offset_point.1); - let offset_x = add_constant_witness(compiler, offset_x_fe); - let offset_y = add_constant_witness(compiler, offset_y_fe); - - let (acc_x, acc_y) = - scalar_mul_merged_native_wnaf(compiler, &[point_data], offset_x, offset_y, curve); - - // Step 6: Identity check — acc should equal accumulated offset - // (hardcoded into constraint matrix, not a prover-controlled witness) - let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(half_bits); - constrain_to_constant( - compiler, - acc_x, - curve::curve_native_point_fe(&acc_off_x_raw), - ); - constrain_to_constant( - compiler, - acc_y, - curve::curve_native_point_fe(&acc_off_y_raw), - ); - - // Step 7: Scalar relation verification (unchanged) - scalar_relation::verify_scalar_relation( - compiler, - range_checks, - s_lo, - s_hi, - s1_witness, - s2_witness, - neg1_witness, - neg2_witness, - curve, - ); -} - -/// Multi-point native-field MSM with merged-loop optimization. +/// Native-field MSM with merged-loop optimization. /// /// All points share a single doubling per bit, saving 4*(n-1) constraints /// per bit of the half-scalar. @@ -370,7 +287,7 @@ fn scalar_mul_merged_native_wnaf( /// scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} /// /// All bits and skew are boolean-constrained. 
-pub(super) fn decompose_signed_bits( +fn decompose_signed_bits( compiler: &mut NoirToR1CSCompiler, scalar: usize, num_bits: usize, diff --git a/provekit/r1cs-compiler/src/msm/non_native.rs b/provekit/r1cs-compiler/src/msm/non_native.rs index 4cb186a49..1137edb66 100644 --- a/provekit/r1cs-compiler/src/msm/non_native.rs +++ b/provekit/r1cs-compiler/src/msm/non_native.rs @@ -5,47 +5,23 @@ use { super::{ - add_constant_witness, constrain_equal, constrain_to_constant, curve, ec_points, - emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, multi_limb_ops, - sanitize_point_scalar, scalar_relation, select_witness, FieldOps, Limbs, + curve, ec_points, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, + multi_limb_ops::{MultiLimbOps, MultiLimbParams}, + sanitize_point_scalar, scalar_relation, Limbs, }, crate::{ + constraint_helpers::{ + add_constant_witness, constrain_equal, constrain_to_constant, select_witness, + }, digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, noir_to_r1cs::NoirToR1CSCompiler, }, ark_ff::{AdditiveGroup, Field}, curve::{decompose_to_limbs as decompose_to_limbs_pub, CurveParams}, - multi_limb_ops::{MultiLimbOps, MultiLimbParams}, provekit_common::{witness::SumTerm, FieldElement}, std::collections::BTreeMap, }; -/// Build `MultiLimbParams` for a given runtime `num_limbs`. 
-pub(super) fn build_params( - num_limbs: usize, - limb_bits: u32, - curve: &CurveParams, -) -> MultiLimbParams { - let is_native = curve.is_native_field(); - let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); - let modulus_fe = if !is_native { - Some(curve.p_native_fe()) - } else { - None - }; - MultiLimbParams { - num_limbs, - limb_bits, - p_limbs: curve.p_limbs(limb_bits, num_limbs), - p_minus_1_limbs: curve.p_minus_1_limbs(limb_bits, num_limbs), - two_pow_w, - modulus_raw: curve.field_modulus_p, - curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), - is_native, - modulus_fe, - } -} - /// FakeGLV verification for a single point: verifies R = \[s\]P. /// /// Decomposes s via half-GCD into sub-scalars (s1, s2) and verifies @@ -53,7 +29,7 @@ pub(super) fn build_params( /// half-width scalars. /// /// Returns the mutable references back to the caller for continued use. -pub(super) fn verify_point_fakeglv<'a>( +fn verify_point_fakeglv<'a>( mut compiler: &'a mut NoirToR1CSCompiler, mut range_checks: &'a mut BTreeMap>, px: Limbs, @@ -74,7 +50,7 @@ pub(super) fn verify_point_fakeglv<'a>( // --- let (s1_witness, s2_witness, neg1_witness, neg2_witness); { - let params = build_params(num_limbs, limb_bits, curve); + let params = MultiLimbParams::for_field_modulus(num_limbs, limb_bits, curve); let mut ops = MultiLimbOps { compiler, range_checks, @@ -185,7 +161,7 @@ pub(super) fn process_multi_point_non_native<'a>( let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); // Build params once for all multi-limb ops in the multi-point path - let params = build_params(num_limbs, limb_bits, curve); + let params = MultiLimbParams::for_field_modulus(num_limbs, limb_bits, curve); // Offset point as limbs for accumulation let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); @@ -349,7 +325,7 @@ fn verify_on_curve( } /// Decompose a point (px_witness, py_witness) into Limbs. 
-pub(super) fn decompose_point_to_limbs( +fn decompose_point_to_limbs( compiler: &mut NoirToR1CSCompiler, px_witness: usize, py_witness: usize, diff --git a/provekit/r1cs-compiler/src/msm/scalar_relation.rs b/provekit/r1cs-compiler/src/msm/scalar_relation.rs index b437c309e..0126564ed 100644 --- a/provekit/r1cs-compiler/src/msm/scalar_relation.rs +++ b/provekit/r1cs-compiler/src/msm/scalar_relation.rs @@ -5,11 +5,12 @@ use { super::{ - constrain_zero, cost_model, curve, + cost_model, curve, multi_limb_ops::{MultiLimbOps, MultiLimbParams}, - FieldOps, Limbs, + Limbs, }, crate::{ + constraint_helpers::constrain_zero, digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder}, noir_to_r1cs::NoirToR1CSCompiler, }, @@ -38,40 +39,6 @@ fn limb_widths(total_bits: usize, max_width: u32) -> Vec { .collect() } -/// Builds `MultiLimbParams` for scalar relation verification (mod -/// curve_order_n). -fn build_scalar_relation_params( - num_limbs: usize, - limb_bits: u32, - curve: &CurveParams, -) -> MultiLimbParams { - // Scalar relation uses curve_order_n as the modulus. - // This is always non-native (curve_order_n ≠ BN254 scalar field modulus, - // except for Grumpkin where they're very close but still different). 
- let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); - let n_limbs = curve.curve_order_n_limbs(limb_bits, num_limbs); - let n_minus_1_limbs = curve.curve_order_n_minus_1_limbs(limb_bits, num_limbs); - - // For N=1 non-native, we need the modulus as a FieldElement - let modulus_fe = if num_limbs == 1 { - Some(curve::curve_native_point_fe(&curve.curve_order_n)) - } else { - None - }; - - MultiLimbParams { - num_limbs, - limb_bits, - p_limbs: n_limbs, - p_minus_1_limbs: n_minus_1_limbs, - two_pow_w, - modulus_raw: curve.curve_order_n, - curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused - is_native: false, // always non-native - modulus_fe, - } -} - /// Verifies the scalar relation: (-1)^neg1 * |s1| + (-1)^neg2 * |s2| * s ≡ 0 /// (mod n). /// @@ -93,7 +60,7 @@ pub(super) fn verify_scalar_relation( let num_limbs = (order_bits + limb_bits as usize - 1) / limb_bits as usize; let half_bits = curve.glv_half_bits() as usize; - let params = build_scalar_relation_params(num_limbs, limb_bits, curve); + let params = MultiLimbParams::for_curve_order(num_limbs, limb_bits, curve); let mut ops = MultiLimbOps { compiler, range_checks, diff --git a/tooling/provekit-bench/tests/compiler.rs b/tooling/provekit-bench/tests/compiler.rs index fee5b64b9..0f621293d 100644 --- a/tooling/provekit-bench/tests/compiler.rs +++ b/tooling/provekit-bench/tests/compiler.rs @@ -23,7 +23,7 @@ struct NargoTomlPackage { name: String, } -fn test_noir_compiler(test_case_path: impl AsRef) { +fn test_noir_compiler(test_case_path: impl AsRef, witness_file: &str) { let test_case_path = test_case_path.as_ref(); compile_workspace(test_case_path).expect("Compiling workspace"); @@ -36,7 +36,7 @@ fn test_noir_compiler(test_case_path: impl AsRef) { let package_name = nargo_toml.package.name; let circuit_path = test_case_path.join(format!("target/{package_name}.json")); - let witness_file_path = test_case_path.join("Prover.toml"); + let witness_file_path = 
test_case_path.join(witness_file); let schema = NoirCompiler::from_file(&circuit_path, provekit_common::HashConfig::default()) .expect("Reading proof scheme"); @@ -69,20 +69,57 @@ pub fn compile_workspace(workspace_path: impl AsRef) -> Result Ok(workspace) } -#[test_case("../../noir-examples/noir-r1cs-test-programs/acir_assert_zero")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/simplest-read-only-memory")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/read-only-memory")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/range-check-u8")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/range-check-u16")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/range-check-mixed-bases")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/read-write-memory")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/conditional-write")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/bin-opcode")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/small-sha")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/bounded-vec")] -#[test_case("../../noir-examples/noir-r1cs-test-programs/brillig-unconstrained")] -#[test_case("../../noir-examples/noir-passport-monolithic/complete_age_check"; "complete_age_check")] -#[test_case("../../noir-examples/embedded_curve_msm"; "embedded_curve_msm")] -fn case_noir(path: &str) { - test_noir_compiler(path); +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/acir_assert_zero", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/simplest-read-only-memory", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/read-only-memory", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/range-check-u8", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/range-check-u16", + "Prover.toml" +)] +#[test_case( + 
"../../noir-examples/noir-r1cs-test-programs/range-check-mixed-bases", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/read-write-memory", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/conditional-write", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/bin-opcode", + "Prover.toml" +)] +#[test_case("../../noir-examples/noir-r1cs-test-programs/small-sha", "Prover.toml")] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/bounded-vec", + "Prover.toml" +)] +#[test_case( + "../../noir-examples/noir-r1cs-test-programs/brillig-unconstrained", + "Prover.toml" +)] +#[test_case("../../noir-examples/noir-passport-monolithic/complete_age_check", "Prover.toml"; "complete_age_check")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover.toml"; "embedded_curve_msm")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_zero_scalars.toml"; "msm_zero_scalars")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_single_nonzero.toml"; "msm_single_nonzero")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_near_order.toml"; "msm_near_order")] +#[test_case("../../noir-examples/embedded_curve_msm", "Prover_near_identity.toml"; "msm_near_identity")] +fn case_noir(path: &str, witness_file: &str) { + test_noir_compiler(path, witness_file); } From 3e0ddab95601bf9590f58669cc0b56debcfed5d7 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 12 Mar 2026 12:15:23 +0530 Subject: [PATCH 16/19] feat : added merged doubles and signed digit window with skew correction for non native msm --- provekit/prover/src/bigint_mod.rs | 70 ++++- .../prover/src/witness/witness_builder.rs | 148 +++------ provekit/r1cs-compiler/src/msm/cost_model.rs | 128 +++++--- provekit/r1cs-compiler/src/msm/ec_points.rs | 290 +++++++++++------ provekit/r1cs-compiler/src/msm/mod.rs | 61 +++- provekit/r1cs-compiler/src/msm/native.rs | 63 +--- 
provekit/r1cs-compiler/src/msm/non_native.rs | 293 ++++++++---------- 7 files changed, 586 insertions(+), 467 deletions(-) diff --git a/provekit/prover/src/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs index e4ea1fea8..95a605793 100644 --- a/provekit/prover/src/bigint_mod.rs +++ b/provekit/prover/src/bigint_mod.rs @@ -3,6 +3,7 @@ /// These helpers compute modular inverse via Fermat's little theorem: /// a^{-1} = a^{m-2} mod m, using schoolbook multiplication and /// square-and-multiply exponentiation. +use {ark_ff::PrimeField, provekit_common::FieldElement}; /// Schoolbook multiplication: 4×4 limbs → 8 limbs (256-bit × 256-bit → /// 512-bit). @@ -527,6 +528,26 @@ fn mul_mod_no_reduce(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { [wide[0], wide[1], wide[2], wide[3]] } +// --------------------------------------------------------------------------- +// Conversion helpers +// --------------------------------------------------------------------------- + +/// Convert a `[u64; 4]` bigint to a `FieldElement`. +pub fn bigint_to_fe(val: &[u64; 4]) -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt(*val)).unwrap() +} + +/// Read a `FieldElement` witness as a `[u64; 4]` bigint. +pub fn fe_to_bigint(fe: FieldElement) -> [u64; 4] { + fe.into_bigint().0 +} + +/// Reconstruct a 256-bit scalar from two 128-bit halves: `scalar = lo + hi * +/// 2^128`. +pub fn reconstruct_from_halves(lo: &[u64; 4], hi: &[u64; 4]) -> [u64; 4] { + [lo[0], lo[1], hi[0], hi[1]] +} + // --------------------------------------------------------------------------- // Modular arithmetic helpers for EC operations (prover-side) // --------------------------------------------------------------------------- @@ -564,15 +585,15 @@ pub fn mod_inverse(a: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { mod_pow(a, &exp, p) } -/// EC point doubling in affine coordinates on y^2 = x^3 + ax + b. -/// Returns (x3, y3) = 2*(px, py). 
-pub fn ec_point_double( +/// EC point doubling with lambda exposed: returns (lambda, x3, y3). +/// +/// Used by the `EcDoubleHint` prover which needs lambda as a witness. +pub fn ec_point_double_with_lambda( px: &[u64; 4], py: &[u64; 4], a: &[u64; 4], p: &[u64; 4], -) -> ([u64; 4], [u64; 4]) { - // lambda = (3*x^2 + a) / (2*y) +) -> ([u64; 4], [u64; 4], [u64; 4]) { let x_sq = mul_mod(px, px, p); let two_x_sq = mod_add(&x_sq, &x_sq, p); let three_x_sq = mod_add(&two_x_sq, &x_sq, p); @@ -581,44 +602,65 @@ pub fn ec_point_double( let denom_inv = mod_inverse(&two_y, p); let lambda = mul_mod(&numerator, &denom_inv, p); - // x3 = lambda^2 - 2*x let lambda_sq = mul_mod(&lambda, &lambda, p); let two_x = mod_add(px, px, p); let x3 = mod_sub(&lambda_sq, &two_x, p); - // y3 = lambda * (x - x3) - y let x_minus_x3 = mod_sub(px, &x3, p); let lambda_dx = mul_mod(&lambda, &x_minus_x3, p); let y3 = mod_sub(&lambda_dx, py, p); + (lambda, x3, y3) +} + +/// EC point doubling in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = 2*(px, py). +pub fn ec_point_double( + px: &[u64; 4], + py: &[u64; 4], + a: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + let (_, x3, y3) = ec_point_double_with_lambda(px, py, a, p); (x3, y3) } -/// EC point addition in affine coordinates on y^2 = x^3 + ax + b. -/// Returns (x3, y3) = (p1x, p1y) + (p2x, p2y). Requires p1x != p2x. -pub fn ec_point_add( +/// EC point addition with lambda exposed: returns (lambda, x3, y3). +/// +/// Used by the `EcAddHint` prover which needs lambda as a witness. 
+pub fn ec_point_add_with_lambda( p1x: &[u64; 4], p1y: &[u64; 4], p2x: &[u64; 4], p2y: &[u64; 4], p: &[u64; 4], -) -> ([u64; 4], [u64; 4]) { - // lambda = (y2 - y1) / (x2 - x1) +) -> ([u64; 4], [u64; 4], [u64; 4]) { let numerator = mod_sub(p2y, p1y, p); let denominator = mod_sub(p2x, p1x, p); let denom_inv = mod_inverse(&denominator, p); let lambda = mul_mod(&numerator, &denom_inv, p); - // x3 = lambda^2 - x1 - x2 let lambda_sq = mul_mod(&lambda, &lambda, p); let x1_plus_x2 = mod_add(p1x, p2x, p); let x3 = mod_sub(&lambda_sq, &x1_plus_x2, p); - // y3 = lambda * (x1 - x3) - y1 let x1_minus_x3 = mod_sub(p1x, &x3, p); let lambda_dx = mul_mod(&lambda, &x1_minus_x3, p); let y3 = mod_sub(&lambda_dx, p1y, p); + (lambda, x3, y3) +} + +/// EC point addition in affine coordinates on y^2 = x^3 + ax + b. +/// Returns (x3, y3) = (p1x, p1y) + (p2x, p2y). Requires p1x != p2x. +pub fn ec_point_add( + p1x: &[u64; 4], + p1y: &[u64; 4], + p2x: &[u64; 4], + p2y: &[u64; 4], + p: &[u64; 4], +) -> ([u64; 4], [u64; 4]) { + let (_, x3, y3) = ec_point_add_with_lambda(p1x, p1y, p2x, p2y, p); (x3, y3) } diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index 9cc8d8335..53e91d1d2 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -1,5 +1,13 @@ use { - crate::witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, + crate::{ + bigint_mod::{ + add_4limb, bigint_to_fe, cmp_4limb, compute_mul_mod_carries, decompose_to_u128_limbs, + divmod, divmod_wide, ec_point_add_with_lambda, ec_point_double_with_lambda, + ec_scalar_mul, fe_to_bigint, half_gcd, mod_pow, reconstruct_from_halves, + reconstruct_from_u128_limbs, sub_u64, widening_mul, + }, + witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, + }, acir::native_types::WitnessMap, ark_ff::{BigInteger, Field, PrimeField}, ark_std::Zero, @@ -49,7 +57,7 @@ fn 
read_witness_limbs( bigint[0] as u128 | ((bigint[1] as u128) << 64) }) .collect(); - crate::bigint_mod::reconstruct_from_u128_limbs(&limb_values, limb_bits) + reconstruct_from_u128_limbs(&limb_values, limb_bits) } /// Write u128 limb values as FieldElement witnesses starting at `start`. @@ -110,22 +118,16 @@ impl WitnessBuilderSolver for WitnessBuilder { }); } WitnessBuilder::ModularInverse(witness_idx, operand_idx, modulus) => { - let a = witness[*operand_idx].unwrap(); - let a_limbs = a.into_bigint().0; + let a_limbs = fe_to_bigint(witness[*operand_idx].unwrap()); let m_limbs = modulus.into_bigint().0; - // Fermat's little theorem: a^{-1} = a^{m-2} mod m - let exp = crate::bigint_mod::sub_u64(&m_limbs, 2); - let result_limbs = crate::bigint_mod::mod_pow(&a_limbs, &exp, &m_limbs); - witness[*witness_idx] = - Some(FieldElement::from_bigint(ark_ff::BigInt(result_limbs)).unwrap()); + let exp = sub_u64(&m_limbs, 2); + witness[*witness_idx] = Some(bigint_to_fe(&mod_pow(&a_limbs, &exp, &m_limbs))); } WitnessBuilder::IntegerQuotient(witness_idx, dividend_idx, divisor) => { - let dividend = witness[*dividend_idx].unwrap(); - let d_limbs = dividend.into_bigint().0; + let d_limbs = fe_to_bigint(witness[*dividend_idx].unwrap()); let m_limbs = divisor.into_bigint().0; - let (quotient, _remainder) = crate::bigint_mod::divmod(&d_limbs, &m_limbs); - witness[*witness_idx] = - Some(FieldElement::from_bigint(ark_ff::BigInt(quotient)).unwrap()); + let (quotient, _) = divmod(&d_limbs, &m_limbs); + witness[*witness_idx] = Some(bigint_to_fe(&quotient)); } WitnessBuilder::IndexedLogUpDenominator( witness_idx, 
crate::bigint_mod::{decompose_to_u128_limbs, mod_pow, sub_u64}; let n = *num_limbs as usize; let w = *limb_bits; @@ -387,7 +385,6 @@ impl WitnessBuilderSolver for WitnessBuilder { limb_bits, .. } => { - use crate::bigint_mod::{add_4limb, cmp_4limb}; let w = *limb_bits; let a_val = read_witness_limbs(witness, a_limbs, w); @@ -414,7 +411,6 @@ impl WitnessBuilderSolver for WitnessBuilder { limb_bits, .. } => { - use crate::bigint_mod::cmp_4limb; let w = *limb_bits; let a_val = read_witness_limbs(witness, a_limbs, w); @@ -445,17 +441,15 @@ impl WitnessBuilderSolver for WitnessBuilder { s_hi, curve_order, } => { - // Reconstruct s = s_lo + s_hi * 2^128 - let s_lo_val = witness[*s_lo].unwrap().into_bigint().0; - let s_hi_val = witness[*s_hi].unwrap().into_bigint().0; - let s_val: [u64; 4] = [s_lo_val[0], s_lo_val[1], s_hi_val[0], s_hi_val[1]]; + let s_val = reconstruct_from_halves( + &fe_to_bigint(witness[*s_lo].unwrap()), + &fe_to_bigint(witness[*s_hi].unwrap()), + ); - let (val1, val2, neg1, neg2) = crate::bigint_mod::half_gcd(&s_val, curve_order); + let (val1, val2, neg1, neg2) = half_gcd(&s_val, curve_order); - witness[*output_start] = - Some(FieldElement::from_bigint(ark_ff::BigInt(val1)).unwrap()); - witness[*output_start + 1] = - Some(FieldElement::from_bigint(ark_ff::BigInt(val2)).unwrap()); + witness[*output_start] = Some(bigint_to_fe(&val1)); + witness[*output_start + 1] = Some(bigint_to_fe(&val2)); witness[*output_start + 2] = Some(FieldElement::from(neg1 as u64)); witness[*output_start + 3] = Some(FieldElement::from(neg2 as u64)); } @@ -466,33 +460,15 @@ impl WitnessBuilderSolver for WitnessBuilder { curve_a, field_modulus_p, } => { - let px_val = witness[*px].unwrap().into_bigint().0; - let py_val = witness[*py].unwrap().into_bigint().0; - - // Compute lambda, x3, y3 using bigint_mod helpers - use crate::bigint_mod::{mod_add, mod_inverse, mod_sub, mul_mod}; - let x_sq = mul_mod(&px_val, &px_val, field_modulus_p); - let two_x_sq = mod_add(&x_sq, &x_sq, 
field_modulus_p); - let three_x_sq = mod_add(&two_x_sq, &x_sq, field_modulus_p); - let numerator = mod_add(&three_x_sq, curve_a, field_modulus_p); - let two_y = mod_add(&py_val, &py_val, field_modulus_p); - let denom_inv = mod_inverse(&two_y, field_modulus_p); - let lambda = mul_mod(&numerator, &denom_inv, field_modulus_p); + let px_val = fe_to_bigint(witness[*px].unwrap()); + let py_val = fe_to_bigint(witness[*py].unwrap()); - let lambda_sq = mul_mod(&lambda, &lambda, field_modulus_p); - let two_x = mod_add(&px_val, &px_val, field_modulus_p); - let x3 = mod_sub(&lambda_sq, &two_x, field_modulus_p); + let (lambda, x3, y3) = + ec_point_double_with_lambda(&px_val, &py_val, curve_a, field_modulus_p); - let x_minus_x3 = mod_sub(&px_val, &x3, field_modulus_p); - let lambda_dx = mul_mod(&lambda, &x_minus_x3, field_modulus_p); - let y3 = mod_sub(&lambda_dx, &py_val, field_modulus_p); - - witness[*output_start] = - Some(FieldElement::from_bigint(ark_ff::BigInt(lambda)).unwrap()); - witness[*output_start + 1] = - Some(FieldElement::from_bigint(ark_ff::BigInt(x3)).unwrap()); - witness[*output_start + 2] = - Some(FieldElement::from_bigint(ark_ff::BigInt(y3)).unwrap()); + witness[*output_start] = Some(bigint_to_fe(&lambda)); + witness[*output_start + 1] = Some(bigint_to_fe(&x3)); + witness[*output_start + 2] = Some(bigint_to_fe(&y3)); } WitnessBuilder::EcAddHint { output_start, @@ -502,31 +478,17 @@ impl WitnessBuilderSolver for WitnessBuilder { y2, field_modulus_p, } => { - let x1_val = witness[*x1].unwrap().into_bigint().0; - let y1_val = witness[*y1].unwrap().into_bigint().0; - let x2_val = witness[*x2].unwrap().into_bigint().0; - let y2_val = witness[*y2].unwrap().into_bigint().0; - - use crate::bigint_mod::{mod_add, mod_inverse, mod_sub, mul_mod}; - let numerator = mod_sub(&y2_val, &y1_val, field_modulus_p); - let denominator = mod_sub(&x2_val, &x1_val, field_modulus_p); - let denom_inv = mod_inverse(&denominator, field_modulus_p); - let lambda = mul_mod(&numerator, 
&denom_inv, field_modulus_p); - - let lambda_sq = mul_mod(&lambda, &lambda, field_modulus_p); - let x1_plus_x2 = mod_add(&x1_val, &x2_val, field_modulus_p); - let x3 = mod_sub(&lambda_sq, &x1_plus_x2, field_modulus_p); + let x1_val = fe_to_bigint(witness[*x1].unwrap()); + let y1_val = fe_to_bigint(witness[*y1].unwrap()); + let x2_val = fe_to_bigint(witness[*x2].unwrap()); + let y2_val = fe_to_bigint(witness[*y2].unwrap()); - let x1_minus_x3 = mod_sub(&x1_val, &x3, field_modulus_p); - let lambda_dx = mul_mod(&lambda, &x1_minus_x3, field_modulus_p); - let y3 = mod_sub(&lambda_dx, &y1_val, field_modulus_p); + let (lambda, x3, y3) = + ec_point_add_with_lambda(&x1_val, &y1_val, &x2_val, &y2_val, field_modulus_p); - witness[*output_start] = - Some(FieldElement::from_bigint(ark_ff::BigInt(lambda)).unwrap()); - witness[*output_start + 1] = - Some(FieldElement::from_bigint(ark_ff::BigInt(x3)).unwrap()); - witness[*output_start + 2] = - Some(FieldElement::from_bigint(ark_ff::BigInt(y3)).unwrap()); + witness[*output_start] = Some(bigint_to_fe(&lambda)); + witness[*output_start + 1] = Some(bigint_to_fe(&x3)); + witness[*output_start + 2] = Some(bigint_to_fe(&y3)); } WitnessBuilder::EcScalarMulHint { output_start, @@ -537,28 +499,17 @@ impl WitnessBuilderSolver for WitnessBuilder { curve_a, field_modulus_p, } => { - // Reconstruct scalar s = s_lo + s_hi * 2^128 - let s_lo_val = witness[*s_lo].unwrap().into_bigint().0; - let s_hi_val = witness[*s_hi].unwrap().into_bigint().0; - let scalar: [u64; 4] = [s_lo_val[0], s_lo_val[1], s_hi_val[0], s_hi_val[1]]; - - // Reconstruct point P - let px_val = witness[*px].unwrap().into_bigint().0; - let py_val = witness[*py].unwrap().into_bigint().0; - - // Compute R = [s]P - let (rx, ry) = crate::bigint_mod::ec_scalar_mul( - &px_val, - &py_val, - &scalar, - curve_a, - field_modulus_p, + let scalar = reconstruct_from_halves( + &fe_to_bigint(witness[*s_lo].unwrap()), + &fe_to_bigint(witness[*s_hi].unwrap()), ); + let px_val = 
fe_to_bigint(witness[*px].unwrap()); + let py_val = fe_to_bigint(witness[*py].unwrap()); + + let (rx, ry) = ec_scalar_mul(&px_val, &py_val, &scalar, curve_a, field_modulus_p); - witness[*output_start] = - Some(FieldElement::from_bigint(ark_ff::BigInt(rx)).unwrap()); - witness[*output_start + 1] = - Some(FieldElement::from_bigint(ark_ff::BigInt(ry)).unwrap()); + witness[*output_start] = Some(bigint_to_fe(&rx)); + witness[*output_start + 1] = Some(bigint_to_fe(&ry)); } WitnessBuilder::SelectWitness { output, @@ -583,6 +534,9 @@ impl WitnessBuilderSolver for WitnessBuilder { } => { let s_fe = witness[*scalar].unwrap(); let s_big = s_fe.into_bigint().0; + // NOTE: Only reads lower 128 bits. Safe for FakeGLV half-scalars + // (≤128 bits for 256-bit curves) but would silently truncate + // larger values. The R1CS reconstruction constraint catches this. let s_val: u128 = s_big[0] as u128 | ((s_big[1] as u128) << 64); let n = *num_bits; let skew: u128 = if s_val & 1 == 0 { 1 } else { 0 }; diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index 39d2ed7f3..d055fa89d 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -141,62 +141,86 @@ pub fn calculate_msm_witness_cost( let n = ceil_div(curve_modulus_bits as usize, limb_bits as usize); let half_bits = (scalar_bits + 1) / 2; let w = window_size; - let table_size = 1usize << w; + let half_table_size = 1usize << (w - 1); let num_windows = ceil_div(half_bits, w); // === GLV scalar mul field op counts === // point_double: (5 add, 3 sub, 4 mul, 1 inv) + N constant witnesses (curve_a) // point_add: (1 add, 5 sub, 3 mul, 1 inv) - // Table building (2 tables for P and R) - let (tbl_d, tbl_a) = if table_size > 2 { - (1, table_size - 3) + // --- Shared costs (counted once, NOT per-point) --- + // Main loop doublings: w doublings per window, shared across all points + let mut shared_add = num_windows * w * 5; + let mut 
shared_sub = num_windows * w * 3; + let mut shared_mul = num_windows * w * 4; + let mut shared_inv = num_windows * w; + + // --- Per-point costs --- + // Signed table building (2 tables for P and R): odd multiples [P, 3P, 5P, ...] + // Build cost: 1 double (for 2P) + (half_table_size - 1) adds when size >= 2. + let (tbl_d, tbl_a) = if half_table_size >= 2 { + (1, half_table_size - 1) } else { (0, 0) }; - let mut n_add = 2 * (tbl_d * 5 + tbl_a * 1); - let mut n_sub = 2 * (tbl_d * 3 + tbl_a * 5); - let mut n_mul = 2 * (tbl_d * 4 + tbl_a * 3); - let mut n_inv = 2 * (tbl_d + tbl_a); - - // Main loop: w shared doublings + 2 point_adds per window - n_add += num_windows * (w * 5 + 2 * 1); - n_sub += num_windows * (w * 3 + 2 * 5); - n_mul += num_windows * (w * 4 + 2 * 3); - n_inv += num_windows * (w + 2); + let mut pp_add = 2 * (tbl_d * 5 + tbl_a * 1); + let mut pp_sub = 2 * (tbl_d * 3 + tbl_a * 5); + let mut pp_mul = 2 * (tbl_d * 4 + tbl_a * 3); + let mut pp_inv = 2 * (tbl_d + tbl_a); + + // Main loop per-point: 2 point_adds + 2 negates per window + pp_add += num_windows * 2 * 1; + pp_sub += num_windows * (2 * 5 + 2); // +2 for signed lookup negates + pp_mul += num_windows * 2 * 3; + pp_inv += num_windows * 2; + + // Skew corrections: 2 branches × (1 negate + 1 point_add) per point + pp_add += 2 * 1; + pp_sub += 2 * (5 + 1); // point_add subs + negate sub + pp_mul += 2 * 3; + pp_inv += 2; // On-curve checks (P and R): 2 × (4 mul + 2 add) - n_mul += 8; - n_add += 4; + pp_mul += 8; + pp_add += 4; - // Y-negation: 2 negate = 2 sub (negate calls sub(zero, value)) - n_sub += 2; + // Y-negation (FakeGLV): 2 negate = 2 sub + pp_sub += 2; - let glv_field_ops = field_op_witnesses(n_add, n_sub, n_mul, n_inv, n, false); + let shared_field_ops = + field_op_witnesses(shared_add, shared_sub, shared_mul, shared_inv, n, false); + let pp_field_ops = field_op_witnesses(pp_add, pp_sub, pp_mul, pp_inv, n, false); // Constant witness allocations not captured by field ops: // - curve_a() in 
each point_double: N per call // - on-curve: 2 × (curve_a + curve_b) = 4N - // - negate: 2 × constant_limbs(zero) = 2N - // - offset point in verify_point_fakeglv: 2N - let n_doubles = 2 * tbl_d + num_windows * w; - let glv_constants = n_doubles * n + 4 * n + 2 * n + 2 * n; - - // Selects + is_zero (not field ops, priced separately) - let table_selects = num_windows * 2 * ((1 << w) - 1) * 2 * n; - let skip_selects = num_windows * 2 * 2 * n; + // - negate zeros: FakeGLV(2) + signed_lookup(2*num_windows) + skew(2) = + // (4+2*num_windows)×N + // - offset point: 2N (shared) + let shared_doubles = num_windows * w; + let pp_doubles = 2 * tbl_d; + let pp_negate_zeros = (4 + 2 * num_windows) * n; + let shared_constants_glv = shared_doubles * n + 2 * n; // shared double curve_a + offset + let pp_constants = pp_doubles * n + 4 * n + pp_negate_zeros; + + // Selects (not field ops, priced separately) + // Signed table: halved from 2^w to 2^(w-1) entries + let table_selects = num_windows * 2 * (half_table_size.saturating_sub(1)) * 2 * n; + // XOR bits for signed lookup: 2 witnesses per bit, (w-1) bits, 2 branches + let xor_cost = num_windows * 2 * 2 * w.saturating_sub(1); + // Y-select after negate in signed lookup: 2 branches per window + let signed_y_selects = num_windows * 2 * n; + // FakeGLV y-negation selects let y_negate_selects = 2 * n; - let is_zero_cost = num_windows * 2 * 3; // 3 native witnesses each + // Skew correction selects: 2 branches × 2 (x+y) + let skew_selects = 2 * 2 * n; - let glv_cost = glv_field_ops - + glv_constants - + table_selects - + skip_selects - + y_negate_selects - + is_zero_cost; + let pp_selects = table_selects + xor_cost + signed_y_selects + y_negate_selects + skew_selects; // === Per-point overhead === - let scalar_bit_decomp = 2 * half_bits; + // Signed-bit decomposition: num_bits bits + 1 skew per half-scalar, 2 + // half-scalars + let scalar_bit_decomp = 2 * (half_bits + 1); let detect_skip = 8; // 2×is_zero(3W) + product(1W) + or(1W) 
let sanitize = 4; // 4 select_witness let ec_hint = 4; // 2W hint + 2W selects @@ -204,7 +228,9 @@ pub fn calculate_msm_witness_cost( let glv_hint = 4; // s1, s2, neg1, neg2 let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits); - let per_point = glv_cost + let per_point = pp_field_ops + + pp_constants + + pp_selects + scalar_bit_decomp + detect_skip + sanitize @@ -227,12 +253,25 @@ pub fn calculate_msm_witness_cost( // === Range check resolution === let mut rc_map: BTreeMap = BTreeMap::new(); - // GLV field ops (per point) + // Shared GLV field ops (doublings — counted once) + add_field_op_range_checks( + shared_add, + shared_sub, + shared_mul, + shared_inv, + n, + limb_bits, + curve_modulus_bits, + false, + &mut rc_map, + ); + + // Per-point GLV field ops add_field_op_range_checks( - n_points * n_add, - n_points * n_sub, - n_points * n_mul, - n_points * n_inv, + n_points * pp_add, + n_points * pp_sub, + n_points * pp_mul, + n_points * pp_inv, n, limb_bits, curve_modulus_bits, @@ -265,7 +304,12 @@ pub fn calculate_msm_witness_cost( let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); - n_points * per_point + shared_constants + accum + range_check_cost + shared_field_ops + + shared_constants_glv + + n_points * per_point + + shared_constants + + accum + + range_check_cost } /// Native-field MSM cost: hint-verified EC ops with signed-bit wNAF (w=1). 
diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index c3deeff61..764b64c55 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -1,7 +1,11 @@ use { - super::{multi_limb_ops::MultiLimbOps, Limbs}, + super::{curve::CurveParams, multi_limb_ops::MultiLimbOps, Limbs}, crate::noir_to_r1cs::NoirToR1CSCompiler, - provekit_common::{witness::WitnessBuilder, FieldElement}, + ark_ff::{Field, PrimeField}, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, }; /// Generic point doubling on y^2 = x^3 + ax + b. @@ -102,25 +106,27 @@ pub fn point_select_unchecked( (x, y) } -/// Builds a point table for windowed scalar multiplication. +/// Builds a signed point table of odd multiples for signed-digit windowed +/// scalar multiplication. +/// +/// T\[0\] = P, T\[1\] = 3P, T\[2\] = 5P, ..., T\[k-1\] = (2k-1)P +/// where k = `half_table_size` = 2^(w-1). /// -/// T\[0\] = P (dummy entry, used when window digit = 0) -/// T\[1\] = P, T\[2\] = 2P, T\[i\] = T\[i-1\] + P for i >= 3. -fn build_point_table( +/// Build cost: 1 point_double (for 2P) + (k-1) point_adds when k >= 2. 
+fn build_signed_point_table( ops: &mut MultiLimbOps, px: Limbs, py: Limbs, - table_size: usize, + half_table_size: usize, ) -> Vec<(Limbs, Limbs)> { - assert!(table_size >= 2); - let mut table = Vec::with_capacity(table_size); - table.push((px, py)); // T[0] = P (dummy) - table.push((px, py)); // T[1] = P - if table_size > 2 { - table.push(point_double(ops, px, py)); // T[2] = 2P - for i in 3..table_size { + assert!(half_table_size >= 1); + let mut table = Vec::with_capacity(half_table_size); + table.push((px, py)); // T[0] = 1*P + if half_table_size >= 2 { + let two_p = point_double(ops, px, py); // 2P + for i in 1..half_table_size { let prev = table[i - 1]; - table.push(point_add(ops, prev.0, prev.1, px, py)); + table.push(point_add(ops, prev.0, prev.1, two_p.0, two_p.1)); } } table @@ -160,50 +166,148 @@ fn table_lookup( current[0] } -/// Interleaved two-point scalar multiplication for FakeGLV. +/// Like `table_lookup`, but skips boolean constraints on bits. /// -/// Computes `[s1]P + [s2]R` using shared doublings, where s1 and s2 are -/// half-width scalars (typically ~128-bit for 256-bit curves). The -/// accumulator starts at an offset point and the caller checks equality -/// with the accumulated offset to verify the constraint `[s1]P + [s2]R = O`. +/// Use when bits are already known boolean (e.g. XOR'd bits derived from +/// boolean-constrained inputs in `signed_table_lookup`). 
+fn table_lookup_unchecked( + ops: &mut MultiLimbOps, + table: &[(Limbs, Limbs)], + bits: &[usize], +) -> (Limbs, Limbs) { + assert_eq!(table.len(), 1 << bits.len()); + let mut current: Vec<(Limbs, Limbs)> = table.to_vec(); + for &bit in bits.iter().rev() { + let half = current.len() / 2; + let mut next = Vec::with_capacity(half); + for i in 0..half { + next.push(point_select_unchecked( + ops, + bit, + current[i], + current[i + half], + )); + } + current = next; + } + current[0] +} + +/// Signed-digit table lookup: selects from a half-size table using XOR'd +/// index bits, then conditionally negates y based on the sign bit. /// -/// Structure per window (from MSB to LSB): -/// 1. `w` shared doublings on accumulator -/// 2. Table lookup in T_P\[d1\] for s1's window digit -/// 3. point_add(acc, T_P\[d1\]) + is_zero(d1) + point_select -/// 4. Table lookup in T_R\[d2\] for s2's window digit -/// 5. point_add(acc, T_R\[d2\]) + is_zero(d2) + point_select +/// For a w-bit window with bits \[b_0, ..., b_{w-1}\] (LSB first): +/// - sign_bit = b_{w-1} (MSB): 1 = positive digit, 0 = negative digit +/// - index_bits = \[b_0, ..., b_{w-2}\] (lower w-1 bits) +/// - When positive: table index = lower bits as-is +/// - When negative: table index = bitwise complement of lower bits, and y is +/// negated /// -/// Returns the final accumulator (x, y). -pub fn scalar_mul_glv( +/// The XOR'd bits are computed as: `idx_i = 1 - b_i - MSB + 2*b_i*MSB`, +/// which equals `b_i` when MSB=1, and `1-b_i` when MSB=0. +/// +/// # Precondition +/// `sign_bit` must be boolean-constrained by the caller. This function uses +/// it in `select_unchecked` without re-constraining. Currently satisfied: +/// `decompose_signed_bits` boolean-constrains all bits including the MSB +/// used as `sign_bit`. 
+fn signed_table_lookup( ops: &mut MultiLimbOps, - // Point P (table 1) - px: Limbs, - py: Limbs, - s1_bits: &[usize], // 128 bit witnesses for |s1| - // Point R (table 2) — the claimed output - rx: Limbs, - ry: Limbs, - s2_bits: &[usize], // 128 bit witnesses for |s2| - // Shared parameters + table: &[(Limbs, Limbs)], + index_bits: &[usize], + sign_bit: usize, +) -> (Limbs, Limbs) { + let (x, y) = if index_bits.is_empty() { + // w=1: single entry, no lookup needed + assert_eq!(table.len(), 1); + table[0] + } else { + // Compute XOR'd index bits: idx_i = 1 - b_i - MSB + 2*b_i*MSB + let one_w = ops.compiler.witness_one(); + let two = FieldElement::from(2u64); + let xor_bits: Vec = index_bits + .iter() + .map(|&bit| { + let prod = ops.compiler.add_product(bit, sign_bit); + ops.compiler.add_sum(vec![ + SumTerm(Some(FieldElement::ONE), one_w), + SumTerm(Some(-FieldElement::ONE), bit), + SumTerm(Some(-FieldElement::ONE), sign_bit), + SumTerm(Some(two), prod), + ]) + }) + .collect(); + + // XOR'd bits are boolean by construction (product of two booleans + // combined linearly), so skip redundant boolean constraints. + table_lookup_unchecked(ops, table, &xor_bits) + }; + + // Conditionally negate y: sign_bit=0 (negative) → -y, sign_bit=1 (positive) → y + let neg_y = ops.negate(y); + let eff_y = ops.select_unchecked(sign_bit, neg_y, y); + // select_unchecked(flag, on_false, on_true): + // sign_bit=0 → on_false=neg_y (negative digit, negate y) ✓ + // sign_bit=1 → on_true=y (positive digit, keep y) ✓ + + (x, eff_y) +} + +/// Per-point data for merged multi-point GLV scalar multiplication. 
+pub struct MergedGlvPoint { + /// Point P x-coordinate (limbs) + pub px: Limbs, + /// Point P y-coordinate (effective, post-negation) + pub py: Limbs, + /// Signed-bit decomposition of |s1| (half-scalar for P), LSB first + pub s1_bits: Vec, + /// Skew correction witness for s1 branch (boolean) + pub s1_skew: usize, + /// Point R x-coordinate (limbs) + pub rx: Limbs, + /// Point R y-coordinate (effective, post-negation) + pub ry: Limbs, + /// Signed-bit decomposition of |s2| (half-scalar for R), LSB first + pub s2_bits: Vec, + /// Skew correction witness for s2 branch (boolean) + pub s2_skew: usize, +} + +/// Merged multi-point GLV scalar multiplication with shared doublings +/// and signed-digit windows. +/// +/// Uses signed-digit encoding: each w-bit window produces a signed odd digit +/// d ∈ {±1, ±3, ..., ±(2^w - 1)}, eliminating zero-digit handling. +/// Tables store odd multiples \[P, 3P, 5P, ..., (2^w-1)P\] with only +/// 2^(w-1) entries (half the unsigned table size). +/// +/// After the main loop, applies skew corrections: if skew=1, subtracts P +/// (or R) to account for the signed decomposition bias. +/// +/// Returns the final accumulator `(x, y)`. 
+pub fn scalar_mul_merged_glv( + ops: &mut MultiLimbOps, + points: &[MergedGlvPoint], window_size: usize, offset_x: Limbs, offset_y: Limbs, ) -> (Limbs, Limbs) { - let n1 = s1_bits.len(); - let n2 = s2_bits.len(); - assert_eq!(n1, n2, "s1 and s2 must have the same number of bits"); - let n = n1; + assert!(!points.is_empty()); + let n = points[0].s1_bits.len(); let w = window_size; - let table_size = 1 << w; - - // Build point tables: T_P[i] = [i]P, T_R[i] = [i]R - let table_p = build_point_table(ops, px, py, table_size); - let table_r = build_point_table(ops, rx, ry, table_size); + let half_table_size = 1usize << (w - 1); + + // Build signed point tables (odd multiples) for all points upfront + let tables: Vec<(Vec<(Limbs, Limbs)>, Vec<(Limbs, Limbs)>)> = points + .iter() + .map(|pt| { + let tp = build_signed_point_table(ops, pt.px, pt.py, half_table_size); + let tr = build_signed_point_table(ops, pt.rx, pt.ry, half_table_size); + (tp, tr) + }) + .collect(); let num_windows = (n + w - 1) / w; - - // Initialize accumulator with the offset point let mut acc = (offset_x, offset_y); // Process all windows from MSB down to LSB @@ -212,45 +316,62 @@ pub fn scalar_mul_glv( let bit_end = std::cmp::min(bit_start + w, n); let actual_w = bit_end - bit_start; - // w shared doublings on the accumulator + // w shared doublings on the accumulator (shared across ALL points) let mut doubled_acc = acc; for _ in 0..w { doubled_acc = point_double(ops, doubled_acc.0, doubled_acc.1); } - // --- Process P's window digit (s1) --- - let s1_window_bits = &s1_bits[bit_start..bit_end]; - let lookup_table_p = if actual_w < w { - &table_p[..1 << actual_w] - } else { - &table_p[..] 
- }; - let looked_up_p = table_lookup(ops, lookup_table_p, s1_window_bits); - let added_p = point_add( - ops, - doubled_acc.0, - doubled_acc.1, - looked_up_p.0, - looked_up_p.1, - ); - let digit_p = ops.pack_bits(s1_window_bits); - let digit_p_is_zero = ops.is_zero(digit_p); - // is_zero already constrains its output boolean; skip redundant check - let after_p = point_select_unchecked(ops, digit_p_is_zero, added_p, doubled_acc); - - // --- Process R's window digit (s2) --- - let s2_window_bits = &s2_bits[bit_start..bit_end]; - let lookup_table_r = if actual_w < w { - &table_r[..1 << actual_w] - } else { - &table_r[..] - }; - let looked_up_r = table_lookup(ops, lookup_table_r, s2_window_bits); - let added_r = point_add(ops, after_p.0, after_p.1, looked_up_r.0, looked_up_r.1); - let digit_r = ops.pack_bits(s2_window_bits); - let digit_r_is_zero = ops.is_zero(digit_r); - // is_zero already constrains its output boolean; skip redundant check - acc = point_select_unchecked(ops, digit_r_is_zero, added_r, after_p); + let mut cur = doubled_acc; + + // For each point: P branch + R branch (signed-digit lookup) + for (pt, (table_p, table_r)) in points.iter().zip(tables.iter()) { + // --- P branch (s1 window) --- + let s1_window_bits = &pt.s1_bits[bit_start..bit_end]; + let sign_bit_p = s1_window_bits[actual_w - 1]; // MSB + let index_bits_p = &s1_window_bits[..actual_w - 1]; // lower bits + let actual_table_p = if actual_w < w { + &table_p[..1 << (actual_w - 1)] + } else { + &table_p[..] 
+ }; + let looked_up_p = signed_table_lookup(ops, actual_table_p, index_bits_p, sign_bit_p); + // All signed digits are non-zero — no is_zero check needed + cur = point_add(ops, cur.0, cur.1, looked_up_p.0, looked_up_p.1); + + // --- R branch (s2 window) --- + let s2_window_bits = &pt.s2_bits[bit_start..bit_end]; + let sign_bit_r = s2_window_bits[actual_w - 1]; // MSB + let index_bits_r = &s2_window_bits[..actual_w - 1]; // lower bits + let actual_table_r = if actual_w < w { + &table_r[..1 << (actual_w - 1)] + } else { + &table_r[..] + }; + let looked_up_r = signed_table_lookup(ops, actual_table_r, index_bits_r, sign_bit_r); + cur = point_add(ops, cur.0, cur.1, looked_up_r.0, looked_up_r.1); + } + + acc = cur; + } + + // Skew corrections: subtract P (or R) if skew=1 for each point. + // The signed decomposition gives: scalar = Σ d_i * 2^i - skew, + // so the main loop computed (scalar + skew) * P. If skew=1, subtract P. + for pt in points { + // P branch skew + let neg_py = ops.negate(pt.py); + let (sub_px, sub_py) = point_add(ops, acc.0, acc.1, pt.px, neg_py); + let new_x = ops.select_unchecked(pt.s1_skew, acc.0, sub_px); + let new_y = ops.select_unchecked(pt.s1_skew, acc.1, sub_py); + acc = (new_x, new_y); + + // R branch skew + let neg_ry = ops.negate(pt.ry); + let (sub_rx, sub_ry) = point_add(ops, acc.0, acc.1, pt.rx, neg_ry); + let new_x = ops.select_unchecked(pt.s2_skew, acc.0, sub_rx); + let new_y = ops.select_unchecked(pt.s2_skew, acc.1, sub_ry); + acc = (new_x, new_y); } acc @@ -263,11 +384,6 @@ pub fn scalar_mul_glv( // Each EC op allocates a hint for (lambda, x3, y3) and verifies via raw // R1CS constraints, eliminating expensive field inversions from the circuit. -use { - super::curve::CurveParams, - ark_ff::{Field, PrimeField}, -}; - /// Hint-verified point doubling for native field. /// /// Allocates EcDoubleHint → (lambda, x3, y3) = 3W. 
diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index d39ac6572..35b1f3483 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -15,7 +15,7 @@ use { msm::multi_limb_arith::compute_is_zero, noir_to_r1cs::NoirToR1CSCompiler, }, - ark_ff::{Field, PrimeField}, + ark_ff::{AdditiveGroup, Field, PrimeField}, curve::CurveParams, provekit_common::{ witness::{ConstantOrR1CSWitness, SumTerm, WitnessBuilder}, @@ -354,6 +354,65 @@ fn emit_fakeglv_hint( (glv_start, glv_start + 1, glv_start + 2, glv_start + 3) } +/// Signed-bit decomposition for wNAF scalar multiplication. +/// +/// Decomposes `scalar` into `num_bits` sign-bits b_i ∈ {0,1} and a skew ∈ {0,1} +/// such that the signed digits d_i = 2*b_i - 1 ∈ {-1, +1} satisfy: +/// scalar = Σ d_i * 2^i - skew +/// +/// Reconstruction constraint (1 linear R1CS): +/// scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} +/// +/// All bits and skew are boolean-constrained. +/// +/// # Limitation +/// The prover's `SignedBitHint` solver reads the scalar as a `u128` (lower +/// 128 bits of the field element). This is correct for FakeGLV half-scalars +/// (≤128 bits for 256-bit curves) but would silently truncate if `num_bits` +/// exceeds 128. The R1CS reconstruction constraint would then fail. 
+pub(super) fn decompose_signed_bits(
+    compiler: &mut NoirToR1CSCompiler,
+    scalar: usize,
+    num_bits: usize,
+) -> (Vec<usize>, usize) {
+    let start = compiler.num_witnesses();
+    compiler.add_witness_builder(WitnessBuilder::SignedBitHint {
+        output_start: start,
+        scalar,
+        num_bits,
+    });
+    let bits: Vec<usize> = (start..start + num_bits).collect();
+    let skew = start + num_bits;
+
+    // Boolean-constrain each bit and skew
+    for &b in &bits {
+        constrain_boolean(compiler, b);
+    }
+    constrain_boolean(compiler, skew);
+
+    // Reconstruction: scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1}
+    // Rearranged as: scalar + skew + (2^n - 1) - Σ b_i * 2^{i+1} = 0
+    let one = compiler.witness_one();
+    let two = FieldElement::from(2u64);
+    let constant = two.pow([num_bits as u64]) - FieldElement::ONE;
+    let mut b_terms: Vec<(FieldElement, usize)> = bits
+        .iter()
+        .enumerate()
+        .map(|(i, &b)| (-two.pow([(i + 1) as u64]), b))
+        .collect();
+    b_terms.push((FieldElement::ONE, scalar));
+    b_terms.push((FieldElement::ONE, skew));
+    b_terms.push((constant, one));
+    compiler
+        .r1cs
+        .add_constraint(&[(FieldElement::ONE, one)], &b_terms, &[(
+            FieldElement::ZERO,
+            one,
+        )]);
+
+    (bits, skew)
+}
+
+/// Resolves a `ConstantOrR1CSWitness` to a witness index.
fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitness) -> usize { match input { diff --git a/provekit/r1cs-compiler/src/msm/native.rs b/provekit/r1cs-compiler/src/msm/native.rs index 7583977de..ba62cf7ff 100644 --- a/provekit/r1cs-compiler/src/msm/native.rs +++ b/provekit/r1cs-compiler/src/msm/native.rs @@ -10,14 +10,13 @@ use { }, crate::{ constraint_helpers::{ - add_constant_witness, constrain_boolean, constrain_equal, constrain_to_constant, - select_witness, + add_constant_witness, constrain_equal, constrain_to_constant, select_witness, }, noir_to_r1cs::NoirToR1CSCompiler, }, - ark_ff::{AdditiveGroup, Field}, + ark_ff::AdditiveGroup, curve::CurveParams, - provekit_common::{witness::WitnessBuilder, FieldElement}, + provekit_common::FieldElement, std::collections::BTreeMap, }; @@ -101,8 +100,8 @@ pub(super) fn process_multi_point_native( // FakeGLV decomposition + signed-bit decomposition let (s1, s2, neg1, neg2) = emit_fakeglv_hint(compiler, san.s_lo, san.s_hi, curve); let half_bits = curve.glv_half_bits() as usize; - let (s1_bits, s1_skew) = decompose_signed_bits(compiler, s1, half_bits); - let (s2_bits, s2_skew) = decompose_signed_bits(compiler, s2, half_bits); + let (s1_bits, s1_skew) = super::decompose_signed_bits(compiler, s1, half_bits); + let (s2_bits, s2_skew) = super::decompose_signed_bits(compiler, s2, half_bits); // Y-negation let (py_eff, neg_py_eff) = negate_y_signed_native(compiler, neg1, san.py); @@ -276,55 +275,3 @@ fn scalar_mul_merged_native_wnaf( (acc_x, acc_y) } - -/// Signed-bit decomposition for wNAF scalar multiplication. -/// -/// Decomposes `scalar` into `num_bits` sign-bits b_i ∈ {0,1} and a skew ∈ {0,1} -/// such that the signed digits d_i = 2*b_i - 1 ∈ {-1, +1} satisfy: -/// scalar = Σ d_i * 2^i - skew -/// -/// Reconstruction constraint (1 linear R1CS): -/// scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} -/// -/// All bits and skew are boolean-constrained. 
-fn decompose_signed_bits(
-    compiler: &mut NoirToR1CSCompiler,
-    scalar: usize,
-    num_bits: usize,
-) -> (Vec<usize>, usize) {
-    let start = compiler.num_witnesses();
-    compiler.add_witness_builder(WitnessBuilder::SignedBitHint {
-        output_start: start,
-        scalar,
-        num_bits,
-    });
-    let bits: Vec<usize> = (start..start + num_bits).collect();
-    let skew = start + num_bits;
-
-    // Boolean-constrain each bit and skew
-    for &b in &bits {
-        constrain_boolean(compiler, b);
-    }
-    constrain_boolean(compiler, skew);
-
-    // Reconstruction: scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1}
-    // Rearranged as: scalar + skew + (2^n - 1) - Σ b_i * 2^{i+1} = 0
-    let one = compiler.witness_one();
-    let constant = FieldElement::from(1u128 << num_bits) - FieldElement::ONE;
-    let mut b_terms: Vec<(FieldElement, usize)> = bits
-        .iter()
-        .enumerate()
-        .map(|(i, &b)| (-FieldElement::from(1u128 << (i + 1)), b))
-        .collect();
-    b_terms.push((FieldElement::ONE, scalar));
-    b_terms.push((FieldElement::ONE, skew));
-    b_terms.push((constant, one));
-    compiler
-        .r1cs
-        .add_constraint(&[(FieldElement::ONE, one)], &b_terms, &[(
-            FieldElement::ZERO,
-            one,
-        )]);
-
-    (bits, skew)
-}
diff --git a/provekit/r1cs-compiler/src/msm/non_native.rs b/provekit/r1cs-compiler/src/msm/non_native.rs
index 1137edb66..a506fabb7 100644
--- a/provekit/r1cs-compiler/src/msm/non_native.rs
+++ b/provekit/r1cs-compiler/src/msm/non_native.rs
@@ -2,6 +2,9 @@
 //!
 //! Used when `!curve.is_native_field()` — uses `MultiLimbOps` for all EC
 //! arithmetic with configurable limb width.
+//!
+//! Multi-point MSM uses merged doublings: all points share a single set of
+//! `w` doublings per window, saving `w × (n_points - 1)` doublings per window.
 
 use {
     super::{
@@ -22,122 +25,10 @@ use {
     std::collections::BTreeMap,
 };
 
-/// FakeGLV verification for a single point: verifies R = \[s\]P. 
-/// -/// Decomposes s via half-GCD into sub-scalars (s1, s2) and verifies -/// \[s1\]P + \[s2\]R = O using interleaved windowed scalar mul with -/// half-width scalars. +/// Multi-point non-native MSM with merged-loop optimization. /// -/// Returns the mutable references back to the caller for continued use. -fn verify_point_fakeglv<'a>( - mut compiler: &'a mut NoirToR1CSCompiler, - mut range_checks: &'a mut BTreeMap>, - px: Limbs, - py: Limbs, - rx: Limbs, - ry: Limbs, - s_lo: usize, - s_hi: usize, - num_limbs: usize, - limb_bits: u32, - window_size: usize, - curve: &CurveParams, -) -> ( - &'a mut NoirToR1CSCompiler, - &'a mut BTreeMap>, -) { - // --- Steps 1-4: On-curve checks, FakeGLV decomposition, and GLV scalar mul - // --- - let (s1_witness, s2_witness, neg1_witness, neg2_witness); - { - let params = MultiLimbParams::for_field_modulus(num_limbs, limb_bits, curve); - let mut ops = MultiLimbOps { - compiler, - range_checks, - params: ¶ms, - }; - - // Step 1: On-curve checks for P and R - let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs); - verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); - verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); - - // Step 2: FakeGLVHint → |s1|, |s2|, neg1, neg2 - (s1_witness, s2_witness, neg1_witness, neg2_witness) = - emit_fakeglv_hint(ops.compiler, s_lo, s_hi, curve); - - // Step 3: Decompose |s1|, |s2| into half_bits bits each - let half_bits = curve.glv_half_bits() as usize; - let s1_bits = decompose_half_scalar_bits(ops.compiler, s1_witness, half_bits); - let s2_bits = decompose_half_scalar_bits(ops.compiler, s2_witness, half_bits); - - // Step 4: Conditionally negate P.y and R.y + GLV scalar mul + identity - // check - - // Compute negated y-coordinates: neg_y = 0 - y (mod p) - let neg_py = ops.negate(py); - let neg_ry = ops.negate(ry); - - // Select: if neg1=1, use neg_py; else use py - // neg1 and neg2 are constrained to be boolean by ops.select internally. 
- let py_effective = ops.select(neg1_witness, py, neg_py); - // Select: if neg2=1, use neg_ry; else use ry - let ry_effective = ops.select(neg2_witness, ry, neg_ry); - - // GLV scalar mul - let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); - let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); - let offset_x = ops.constant_limbs(&offset_x_values); - let offset_y = ops.constant_limbs(&offset_y_values); - - let glv_acc = ec_points::scalar_mul_glv( - &mut ops, - px, - py_effective, - &s1_bits, - rx, - ry_effective, - &s2_bits, - window_size, - offset_x, - offset_y, - ); - - // Identity check: acc should equal [2^(num_windows * window_size)] * - // offset_point - let glv_num_windows = (half_bits + window_size - 1) / window_size; - let glv_n_doublings = glv_num_windows * window_size; - let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); - - // Identity check: hardcode expected limb values as R1CS coefficients - let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs); - let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs); - for i in 0..num_limbs { - constrain_to_constant(ops.compiler, glv_acc.0[i], acc_off_x_values[i]); - constrain_to_constant(ops.compiler, glv_acc.1[i], acc_off_y_values[i]); - } - - compiler = ops.compiler; - range_checks = ops.range_checks; - } - - // --- Step 5: Scalar relation verification --- - scalar_relation::verify_scalar_relation( - compiler, - range_checks, - s_lo, - s_hi, - s1_witness, - s2_witness, - neg1_witness, - neg2_witness, - curve, - ); - - (compiler, range_checks) -} - -/// Multi-point non-native MSM with offset-based accumulation. +/// All points share a single set of doublings per window, saving +/// `w × (n_points - 1)` doublings per window compared to separate loops. 
pub(super) fn process_multi_point_non_native<'a>( mut compiler: &'a mut NoirToR1CSCompiler, point_wits: &[usize], @@ -160,27 +51,21 @@ pub(super) fn process_multi_point_non_native<'a>( let gen_y_witness = add_constant_witness(compiler, gen_y_fe); let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); - // Build params once for all multi-limb ops in the multi-point path + // Build params once for all multi-limb ops let params = MultiLimbParams::for_field_modulus(num_limbs, limb_bits, curve); + let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs); // Offset point as limbs for accumulation let offset_x_values = curve.offset_x_limbs(limb_bits, num_limbs); let offset_y_values = curve.offset_y_limbs(limb_bits, num_limbs); - // Start accumulator at offset_point - let mut ops = MultiLimbOps { - compiler, - range_checks, - params: ¶ms, - }; - let mut acc_x = ops.constant_limbs(&offset_x_values); - let mut acc_y = ops.constant_limbs(&offset_y_values); - compiler = ops.compiler; - range_checks = ops.range_checks; - // Track all_skipped = product of all is_skip flags let mut all_skipped: Option = None; + let mut merged_points: Vec = Vec::new(); + let mut scalar_rel_inputs: Vec<(usize, usize, usize, usize, usize, usize)> = Vec::new(); + let mut accum_inputs: Vec<(Limbs, Limbs, usize)> = Vec::new(); + // Phase 1: Per-point preprocessing for i in 0..n_points { let san = sanitize_point_scalar( compiler, @@ -209,7 +94,7 @@ pub(super) fn process_multi_point_non_native<'a>( curve, ); - // Generic multi-limb path + // Decompose points to limbs let (px, py) = decompose_point_to_limbs(compiler, san.px, san.py, num_limbs, limb_bits, range_checks); let (rx, ry) = decompose_point_to_limbs( @@ -221,44 +106,134 @@ pub(super) fn process_multi_point_non_native<'a>( range_checks, ); - // Verify R_i = [s_i]P_i using FakeGLV (on sanitized values) - (compiler, range_checks) = verify_point_fakeglv( - compiler, - range_checks, + // On-curve checks, 
FakeGLV, bit decomposition, y-negation (via MultiLimbOps) + let (s1_witness, s2_witness, neg1_witness, neg2_witness); + let (py_effective, ry_effective, s1_bits, s2_bits, s1_skew, s2_skew); + { + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + + // On-curve checks for P and R + verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); + verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); + + // FakeGLVHint → |s1|, |s2|, neg1, neg2 + (s1_witness, s2_witness, neg1_witness, neg2_witness) = + emit_fakeglv_hint(ops.compiler, san.s_lo, san.s_hi, curve); + + // Signed-bit decomposition of |s1|, |s2| — produces signed digits + // d_i = 2*b_i - 1 ∈ {-1, +1} with skew correction, matching the + // native path's wNAF approach. + let half_bits = curve.glv_half_bits() as usize; + (s1_bits, s1_skew) = super::decompose_signed_bits(ops.compiler, s1_witness, half_bits); + (s2_bits, s2_skew) = super::decompose_signed_bits(ops.compiler, s2_witness, half_bits); + + // Conditionally negate y-coordinates + let neg_py = ops.negate(py); + let neg_ry = ops.negate(ry); + py_effective = ops.select(neg1_witness, py, neg_py); + ry_effective = ops.select(neg2_witness, ry, neg_ry); + + compiler = ops.compiler; + range_checks = ops.range_checks; + } + + merged_points.push(ec_points::MergedGlvPoint { px, - py, + py: py_effective, + s1_bits, + s1_skew, rx, - ry, + ry: ry_effective, + s2_bits, + s2_skew, + }); + + scalar_rel_inputs.push(( san.s_lo, san.s_hi, - num_limbs, - limb_bits, - window_size, - curve, - ); + s1_witness, + s2_witness, + neg1_witness, + neg2_witness, + )); + accum_inputs.push((rx, ry, san.is_skip)); + } - // Offset-based accumulation with conditional select + // Phase 2: Merged scalar mul verification (shared doublings across all points) + let half_bits = curve.glv_half_bits() as usize; + let glv_acc; + { let mut ops = MultiLimbOps { compiler, range_checks, params: ¶ms, }; - let (cand_x, cand_y) = ec_points::point_add(&mut ops, acc_x, 
acc_y, rx, ry); - let (new_acc_x, new_acc_y) = ec_points::point_select_unchecked( + let offset_x = ops.constant_limbs(&offset_x_values); + let offset_y = ops.constant_limbs(&offset_y_values); + + glv_acc = ec_points::scalar_mul_merged_glv( &mut ops, - san.is_skip, - (cand_x, cand_y), - (acc_x, acc_y), + &merged_points, + window_size, + offset_x, + offset_y, ); - acc_x = new_acc_x; - acc_y = new_acc_y; + + // Identity check: acc should equal accumulated offset + let glv_num_windows = (half_bits + window_size - 1) / window_size; + let glv_n_doublings = glv_num_windows * window_size; + let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); + + let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs); + let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs); + for i in 0..num_limbs { + constrain_to_constant(ops.compiler, glv_acc.0[i], acc_off_x_values[i]); + constrain_to_constant(ops.compiler, glv_acc.1[i], acc_off_y_values[i]); + } + compiler = ops.compiler; range_checks = ops.range_checks; } + // Phase 3: Per-point scalar relations + for &(s_lo, s_hi, s1, s2, neg1, neg2) in &scalar_rel_inputs { + scalar_relation::verify_scalar_relation( + compiler, + range_checks, + s_lo, + s_hi, + s1, + s2, + neg1, + neg2, + curve, + ); + } + + // Phase 4: Accumulation (offset-based, same as before) let all_skipped = all_skipped.expect("MSM must have at least one point"); - // Generic multi-limb offset subtraction + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: ¶ms, + }; + let mut acc_x = ops.constant_limbs(&offset_x_values); + let mut acc_y = ops.constant_limbs(&offset_y_values); + + for &(rx, ry, is_skip) in &accum_inputs { + let (cand_x, cand_y) = ec_points::point_add(&mut ops, acc_x, acc_y, rx, ry); + let (new_acc_x, new_acc_y) = + ec_points::point_select_unchecked(&mut ops, is_skip, (cand_x, cand_y), (acc_x, acc_y)); + acc_x = new_acc_x; + acc_y = new_acc_y; + } + + // Offset 
subtraction let neg_offset_y_raw = curve::negate_field_element(&curve.offset_point.1, &curve.field_modulus_p); let neg_offset_y_values = curve::decompose_to_limbs(&neg_offset_y_raw, limb_bits, num_limbs); @@ -267,12 +242,6 @@ pub(super) fn process_multi_point_non_native<'a>( let neg_gen_y_raw = curve::negate_field_element(&curve.generator.1, &curve.field_modulus_p); let neg_gen_y_values = curve::decompose_to_limbs(&neg_gen_y_raw, limb_bits, num_limbs); - let mut ops = MultiLimbOps { - compiler, - range_checks, - params: ¶ms, - }; - let sub_x = { let off_x = ops.constant_limbs(&offset_x_values); let g_x = ops.constant_limbs(&gen_x_limb_values); @@ -380,17 +349,5 @@ fn recompose_limbs(compiler: &mut NoirToR1CSCompiler, limbs: &[usize], limb_bits compiler.add_sum(terms) } -/// Decomposes a half-scalar witness into `half_bits` bit witnesses (LSB first). -fn decompose_half_scalar_bits( - compiler: &mut NoirToR1CSCompiler, - scalar: usize, - half_bits: usize, -) -> Vec { - let log_bases = vec![1usize; half_bits]; - let dd = add_digital_decomposition(compiler, log_bases, vec![scalar]); - let mut bits = Vec::with_capacity(half_bits); - for bit_idx in 0..half_bits { - bits.push(dd.get_digit_witness_index(bit_idx, 0)); - } - bits -} +// `decompose_half_scalar_bits` replaced by `super::decompose_signed_bits` +// which produces signed digits with skew correction, halving lookup tables. 
From 1ac6a8e3729e1552ffd9809f26198b451cd0114c Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 13 Mar 2026 06:09:53 +0530 Subject: [PATCH 17/19] feat : added prover hint for step by step multi limb ops and fused table doubling --- provekit/common/src/witness/mod.rs | 4 +- .../src/witness/scheduling/dependency.rs | 19 +- .../common/src/witness/scheduling/remapper.rs | 22 + .../common/src/witness/witness_builder.rs | 32 + provekit/prover/src/bigint_mod.rs | 169 +++++ .../prover/src/witness/witness_builder.rs | 268 +++++++- provekit/r1cs-compiler/src/msm/cost_model.rs | 320 +++++++-- provekit/r1cs-compiler/src/msm/curve.rs | 4 + provekit/r1cs-compiler/src/msm/ec_points.rs | 650 +++++++++++++++++- .../r1cs-compiler/src/msm/multi_limb_arith.rs | 64 +- .../r1cs-compiler/src/msm/multi_limb_ops.rs | 34 +- provekit/r1cs-compiler/src/msm/non_native.rs | 26 +- .../r1cs-compiler/src/msm/scalar_relation.rs | 81 ++- 13 files changed, 1584 insertions(+), 109 deletions(-) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index f7cf80db2..e4968563c 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -19,8 +19,8 @@ pub use { ram::{SpiceMemoryOperation, SpiceWitnesses}, scheduling::{Layer, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders}, witness_builder::{ - CombinedTableEntryInverseData, ConstantTerm, ProductLinearTerm, SumTerm, WitnessBuilder, - WitnessCoefficient, + CombinedTableEntryInverseData, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm, + WitnessBuilder, WitnessCoefficient, }, witness_generator::NoirWitnessGenerator, }; diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index 6f06268ea..f86ea414e 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -1,7 +1,7 @@ use { crate::witness::{ - ConstantOrR1CSWitness, ConstantTerm, 
ProductLinearTerm, SumTerm, WitnessBuilder, - WitnessCoefficient, + ConstantOrR1CSWitness, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm, + WitnessBuilder, WitnessCoefficient, }, std::collections::HashMap, }; @@ -225,6 +225,9 @@ impl DependencyInfo { } WitnessBuilder::EcDoubleHint { px, py, .. } => vec![*px, *py], WitnessBuilder::EcAddHint { x1, y1, x2, y2, .. } => vec![*x1, *y1, *x2, *y2], + WitnessBuilder::NonNativeEcHint { inputs, .. } => { + inputs.iter().flatten().copied().collect() + } WitnessBuilder::FakeGLVHint { s_lo, s_hi, .. } => vec![*s_lo, *s_hi], WitnessBuilder::EcScalarMulHint { px, py, s_lo, s_hi, .. @@ -350,6 +353,18 @@ impl DependencyInfo { WitnessBuilder::EcAddHint { output_start, .. } => { (*output_start..*output_start + 3).collect() } + WitnessBuilder::NonNativeEcHint { + output_start, + num_limbs, + op, + .. + } => { + let count = match op { + NonNativeEcOp::Double | NonNativeEcOp::Add => (12 * *num_limbs - 6) as usize, + NonNativeEcOp::OnCurve => (7 * *num_limbs - 4) as usize, + }; + (*output_start..*output_start + count).collect() + } WitnessBuilder::FakeGLVHint { output_start, .. 
} => { (*output_start..*output_start + 4).collect() } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 0e039a1ba..5d8b17023 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -394,6 +394,28 @@ impl WitnessIndexRemapper { y2: self.remap(*y2), field_modulus_p: *field_modulus_p, }, + WitnessBuilder::NonNativeEcHint { + output_start, + op, + inputs, + curve_a, + curve_b, + field_modulus_p, + limb_bits, + num_limbs, + } => WitnessBuilder::NonNativeEcHint { + output_start: self.remap(*output_start), + op: op.clone(), + inputs: inputs + .iter() + .map(|v| v.iter().map(|&w| self.remap(w)).collect()) + .collect(), + curve_a: *curve_a, + curve_b: *curve_b, + field_modulus_p: *field_modulus_p, + limb_bits: *limb_bits, + num_limbs: *num_limbs, + }, WitnessBuilder::FakeGLVHint { output_start, s_lo, diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 22ecee690..0e03090a7 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -54,9 +54,21 @@ pub struct CombinedTableEntryInverseData { pub xor_out: FieldElement, } +/// Operation type for the unified non-native EC hint. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum NonNativeEcOp { + /// Point doubling: inputs = [[px_limbs], [py_limbs]], outputs 12N-6 + Double, + /// Point addition: inputs = [[x1_limbs], [y1_limbs], [x2_limbs], + /// [y2_limbs]], outputs 12N-6 + Add, + /// On-curve check: inputs = [[px_limbs], [py_limbs]], outputs 7N-4 + OnCurve, +} + /// Indicates how to solve for a collection of R1CS witnesses in terms of /// earlier (i.e. already solved for) R1CS witnesses and/or ACIR witness values. 
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum WitnessBuilder { /// Constant value, used for the constant one witness & e.g. static lookups /// (witness index, constant value) @@ -346,6 +358,22 @@ pub enum WitnessBuilder { a: usize, b: usize, }, + /// Unified prover hint for non-native EC operations (multi-limb). + /// + /// `op` selects the operation: + /// - `Double`: inputs = [[px], [py]], outputs 12N-6 witnesses + /// - `Add`: inputs = [[x1], [y1], [x2], [y2]], outputs 12N-6 witnesses + /// - `OnCurve`: inputs = [[px], [py]], outputs 7N-4 witnesses + NonNativeEcHint { + output_start: usize, + op: NonNativeEcOp, + inputs: Vec>, + curve_a: [u64; 4], + curve_b: [u64; 4], + field_modulus_p: [u64; 4], + limb_bits: u32, + num_limbs: u32, + }, /// Signed-bit decomposition hint for wNAF scalar multiplication. /// Given scalar s with num_bits bits, computes sign-bits b_0..b_{n-1} /// and skew ∈ {0,1} such that: @@ -427,6 +455,10 @@ impl WitnessBuilder { WitnessBuilder::SignedBitHint { num_bits, .. } => *num_bits + 1, WitnessBuilder::EcDoubleHint { .. } => 3, WitnessBuilder::EcAddHint { .. } => 3, + WitnessBuilder::NonNativeEcHint { op, num_limbs, .. } => match op { + NonNativeEcOp::Double | NonNativeEcOp::Add => (12 * *num_limbs - 6) as usize, + NonNativeEcOp::OnCurve => (7 * *num_limbs - 4) as usize, + }, WitnessBuilder::FakeGLVHint { .. } => 4, WitnessBuilder::EcScalarMulHint { .. } => 2, diff --git a/provekit/prover/src/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs index 95a605793..09f8c5308 100644 --- a/provekit/prover/src/bigint_mod.rs +++ b/provekit/prover/src/bigint_mod.rs @@ -291,6 +291,89 @@ pub fn decompose_to_u128_limbs(val: &[u64; 4], num_limbs: usize, limb_bits: u32) limbs } +/// Convert u128 limbs to i128 limbs (for carry computation linear terms). 
+pub fn to_i128_limbs(limbs: &[u128]) -> Vec { + limbs.iter().map(|&v| v as i128).collect() +} + +/// Compute signed quotient q such that: +/// Σ lhs_products[i] * coeff_i - Σ rhs_products[j] * coeff_j - rhs_sub ≡ 0 +/// (mod p) Returns q as decomposed limbs, with negative q stored as -q in the +/// native field. +pub fn signed_quotient_wide( + lhs_products: &[(&[u64; 4], &[u64; 4], u64)], + rhs_products: &[(&[u64; 4], &[u64; 4], u64)], + rhs_sub: Option<&[u64; 4]>, + p: &[u64; 4], + n: usize, + w: u32, +) -> Vec { + fn accumulate_wide(terms: &[(&[u64; 4], &[u64; 4], u64)]) -> [u64; 8] { + let mut acc = [0u64; 8]; + for &(a, b, coeff) in terms { + let prod = widening_mul(a, b); + let mut carry = 0u128; + for i in 0..8 { + let v = acc[i] as u128 + (prod[i] as u128) * (coeff as u128) + carry; + acc[i] = v as u64; + carry = v >> 64; + } + } + acc + } + + let lhs_wide = accumulate_wide(lhs_products); + let mut rhs_wide = accumulate_wide(rhs_products); + if let Some(sub) = rhs_sub { + let mut carry = 0u128; + for i in 0..4 { + let v = rhs_wide[i] as u128 + sub[i] as u128 + carry; + rhs_wide[i] = v as u64; + carry = v >> 64; + } + for i in 4..8 { + let v = rhs_wide[i] as u128 + carry; + rhs_wide[i] = v as u64; + carry = v >> 64; + } + } + + let lhs_ge = { + let mut ge = false; + for i in (0..8).rev() { + if lhs_wide[i] > rhs_wide[i] { + ge = true; + break; + } else if lhs_wide[i] < rhs_wide[i] { + break; + } + if i == 0 { + ge = true; + } + } + ge + }; + let (big, small) = if lhs_ge { + (lhs_wide, rhs_wide) + } else { + (rhs_wide, lhs_wide) + }; + let mut diff = [0u64; 8]; + let mut bw = 0u64; + for i in 0..8 { + let (d1, b1) = big[i].overflowing_sub(small[i]); + let (d2, b2) = d1.overflowing_sub(bw); + diff[i] = d2; + bw = b1 as u64 + b2 as u64; + } + let (q_abs, _) = divmod_wide(&diff, p); + if lhs_ge { + decompose_to_u128_limbs(&q_abs, n, w) + } else { + decompose_to_u128_limbs(&fe_to_bigint(-bigint_to_fe(&q_abs)), n, w) + } +} + /// Reconstruct a 256-bit value from 
u128 limb values packed at `limb_bits` /// boundaries. pub fn reconstruct_from_u128_limbs(limb_values: &[u128], limb_bits: u32) -> [u64; 4] { @@ -760,6 +843,92 @@ pub fn divmod_wide(dividend: &[u64; 8], divisor: &[u64; 4]) -> ([u64; 4], [u64; (quotient, remainder) } +/// Compute unsigned-offset carries for a general merged column equation. +/// +/// Each `product_set` entry is (a_limbs, b_limbs, coefficient): +/// LHS_terms = Σ coeff * Σ_{i+j=k} a[i]*b[j] +/// +/// Each `linear_set` entry is (limb_values, coefficient) for non-product terms: +/// LHS_terms += Σ coeff * val[k] (for k < val.len()) +/// +/// The equation verified is: LHS = Σ p[i]*q[j] + carry_chain +/// (no separate result — the "result" is encoded in the linear terms). +pub fn compute_ec_verification_carries( + product_sets: &[(&[u128], &[u128], i64)], + linear_terms: &[(Vec, i64)], // (limb_values extended to 2N-1, coefficient) + p_limbs: &[u128], + q_limbs: &[u128], + n: usize, + limb_bits: u32, +) -> Vec { + let w = limb_bits; + let num_columns = 2 * n - 1; + let num_carries = num_columns - 1; + + // Use a larger offset to account for merged terms. + // Max terms per column: sum of coefficients × N products + linear terms. 
+ let max_coeff_sum: u64 = product_sets + .iter() + .map(|(_, _, c)| c.unsigned_abs() as u64) + .sum::() + + linear_terms + .iter() + .map(|(_, c)| c.unsigned_abs() as u64) + .sum::() + + n as u64; // p*q terms + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + let carry_offset_bits = w + extra_bits; + let carry_offset = 1u128 << carry_offset_bits; + + let mut carries = Vec::with_capacity(num_carries); + let mut carry: i128 = 0; + + for k in 0..num_columns { + let mut col_value: i128 = 0; + + // Product terms + for &(a, b, coeff) in product_sets { + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + col_value += coeff as i128 * (a[i] as i128) * (b[j as usize] as i128); + } + } + } + + // Linear terms + for (vals, coeff) in linear_terms { + if k < vals.len() { + col_value += *coeff as i128 * vals[k]; + } + } + + // Subtract p*q contribution + for i in 0..n { + let j = k as isize - i as isize; + if j >= 0 && (j as usize) < n { + col_value -= (p_limbs[i] as i128) * (q_limbs[j as usize] as i128); + } + } + + col_value += carry; + + if k < num_carries { + debug_assert_eq!( + col_value & ((1i128 << w) - 1), + 0, + "non-zero remainder at column {k}: col_value={col_value}" + ); + carry = col_value >> w; + carries.push((carry + carry_offset as i128) as u128); + } else { + debug_assert_eq!(col_value, 0, "non-zero final column value: {col_value}"); + } + } + + carries +} + #[cfg(test)] mod tests { use super::*; diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index 53e91d1d2..d663728f9 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -1,10 +1,12 @@ use { crate::{ bigint_mod::{ - add_4limb, bigint_to_fe, cmp_4limb, compute_mul_mod_carries, decompose_to_u128_limbs, - divmod, divmod_wide, ec_point_add_with_lambda, ec_point_double_with_lambda, - ec_scalar_mul, fe_to_bigint, half_gcd, 
mod_pow, reconstruct_from_halves, - reconstruct_from_u128_limbs, sub_u64, widening_mul, + add_4limb, bigint_to_fe, cmp_4limb, compute_ec_verification_carries, + compute_mul_mod_carries, decompose_to_u128_limbs, divmod, divmod_wide, + ec_point_add_with_lambda, ec_point_double_with_lambda, ec_scalar_mul, fe_to_bigint, + half_gcd, mod_add, mod_pow, mod_sub, mul_mod, reconstruct_from_halves, + reconstruct_from_u128_limbs, signed_quotient_wide, sub_u64, to_i128_limbs, + widening_mul, }, witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, }, @@ -14,8 +16,8 @@ use { provekit_common::{ utils::noir_to_native, witness::{ - compute_spread, ConstantOrR1CSWitness, ConstantTerm, ProductLinearTerm, SumTerm, - WitnessBuilder, WitnessCoefficient, + compute_spread, ConstantOrR1CSWitness, ConstantTerm, NonNativeEcOp, ProductLinearTerm, + SumTerm, WitnessBuilder, WitnessCoefficient, }, FieldElement, NoirElement, TranscriptSponge, }, @@ -490,6 +492,260 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*output_start + 1] = Some(bigint_to_fe(&x3)); witness[*output_start + 2] = Some(bigint_to_fe(&y3)); } + WitnessBuilder::NonNativeEcHint { + output_start, + op, + inputs, + curve_a, + curve_b, + field_modulus_p, + limb_bits, + num_limbs, + } => { + let n = *num_limbs as usize; + let w = *limb_bits; + let os = *output_start; + + let p_l = decompose_to_u128_limbs(field_modulus_p, n, w); + + match op { + NonNativeEcOp::Double => { + let px_val = read_witness_limbs(witness, &inputs[0], w); + let py_val = read_witness_limbs(witness, &inputs[1], w); + let (lam, x3v, y3v) = + ec_point_double_with_lambda(&px_val, &py_val, curve_a, field_modulus_p); + let ll = decompose_to_u128_limbs(&lam, n, w); + let xl = decompose_to_u128_limbs(&x3v, n, w); + let yl = decompose_to_u128_limbs(&y3v, n, w); + let pl = decompose_to_u128_limbs(&px_val, n, w); + let pyl = decompose_to_u128_limbs(&py_val, n, w); + let a_l = decompose_to_u128_limbs(curve_a, n, w); + 
write_limbs(witness, os, &ll); + write_limbs(witness, os + n, &xl); + write_limbs(witness, os + 2 * n, &yl); + + // Eq1: 2*λ*py - 3*px² - a = q1*p + let q1 = signed_quotient_wide( + &[(&lam, &py_val, 2)], + &[(&px_val, &px_val, 3)], + Some(curve_a), + field_modulus_p, + n, + w, + ); + write_limbs(witness, os + 3 * n, &q1); + let c1 = compute_ec_verification_carries( + &[(&ll, &pyl, 2), (&pl, &pl, -3)], + &[(to_i128_limbs(&a_l), -1)], + &p_l, + &q1, + n, + w, + ); + write_limbs(witness, os + 4 * n, &c1); + + // Eq2: λ² - x3 - 2*px = q2*p + let two_px = mod_add(&px_val, &px_val, field_modulus_p); + let rhs2 = mod_add(&x3v, &two_px, field_modulus_p); + let q2 = signed_quotient_wide( + &[(&lam, &lam, 1)], + &[], + Some(&rhs2), + field_modulus_p, + n, + w, + ); + write_limbs(witness, os + 6 * n - 2, &q2); + let c2 = compute_ec_verification_carries( + &[(&ll, &ll, 1)], + &[ + (to_i128_limbs(&xl), -1), + (pl.iter().map(|&v| 2 * v as i128).collect(), -1), + ], + &p_l, + &q2, + n, + w, + ); + write_limbs(witness, os + 7 * n - 2, &c2); + + // Eq3: λ*(px-x3) - y3 - py = q3*p + let dx = mod_sub(&px_val, &x3v, field_modulus_p); + let rhs3 = mod_add(&y3v, &py_val, field_modulus_p); + let q3 = signed_quotient_wide( + &[(&lam, &dx, 1)], + &[], + Some(&rhs3), + field_modulus_p, + n, + w, + ); + write_limbs(witness, os + 9 * n - 4, &q3); + let c3 = compute_ec_verification_carries( + &[(&ll, &pl, 1), (&ll, &xl, -1)], + &[(to_i128_limbs(&yl), -1), (to_i128_limbs(&pyl), -1)], + &p_l, + &q3, + n, + w, + ); + write_limbs(witness, os + 10 * n - 4, &c3); + } + NonNativeEcOp::Add => { + let x1v = read_witness_limbs(witness, &inputs[0], w); + let y1v = read_witness_limbs(witness, &inputs[1], w); + let x2v = read_witness_limbs(witness, &inputs[2], w); + let y2v = read_witness_limbs(witness, &inputs[3], w); + let (lam, x3v, y3v) = + ec_point_add_with_lambda(&x1v, &y1v, &x2v, &y2v, field_modulus_p); + let ll = decompose_to_u128_limbs(&lam, n, w); + let xl = decompose_to_u128_limbs(&x3v, n, w); 
+ let yl = decompose_to_u128_limbs(&y3v, n, w); + let x1l = decompose_to_u128_limbs(&x1v, n, w); + let y1l = decompose_to_u128_limbs(&y1v, n, w); + let x2l = decompose_to_u128_limbs(&x2v, n, w); + let y2l = decompose_to_u128_limbs(&y2v, n, w); + write_limbs(witness, os, &ll); + write_limbs(witness, os + n, &xl); + write_limbs(witness, os + 2 * n, &yl); + + // Eq1: λ*(x2-x1) - (y2-y1) = q1*p + let dx = mod_sub(&x2v, &x1v, field_modulus_p); + let dy = mod_sub(&y2v, &y1v, field_modulus_p); + let q1 = signed_quotient_wide( + &[(&lam, &dx, 1)], + &[], + Some(&dy), + field_modulus_p, + n, + w, + ); + write_limbs(witness, os + 3 * n, &q1); + let c1 = compute_ec_verification_carries( + &[(&ll, &x2l, 1), (&ll, &x1l, -1)], + &[(to_i128_limbs(&y2l), -1), (to_i128_limbs(&y1l), 1)], + &p_l, + &q1, + n, + w, + ); + write_limbs(witness, os + 4 * n, &c1); + + // Eq2: λ² - x3 - x1 - x2 = q2*p + let sum_x = + mod_add(&x3v, &mod_add(&x1v, &x2v, field_modulus_p), field_modulus_p); + let q2 = signed_quotient_wide( + &[(&lam, &lam, 1)], + &[], + Some(&sum_x), + field_modulus_p, + n, + w, + ); + write_limbs(witness, os + 6 * n - 2, &q2); + let c2 = compute_ec_verification_carries( + &[(&ll, &ll, 1)], + &[ + (to_i128_limbs(&xl), -1), + (to_i128_limbs(&x1l), -1), + (to_i128_limbs(&x2l), -1), + ], + &p_l, + &q2, + n, + w, + ); + write_limbs(witness, os + 7 * n - 2, &c2); + + // Eq3: λ*(x1-x3) - y3 - y1 = q3*p + let dx3 = mod_sub(&x1v, &x3v, field_modulus_p); + let rhs3 = mod_add(&y3v, &y1v, field_modulus_p); + let q3 = signed_quotient_wide( + &[(&lam, &dx3, 1)], + &[], + Some(&rhs3), + field_modulus_p, + n, + w, + ); + write_limbs(witness, os + 9 * n - 4, &q3); + let c3 = compute_ec_verification_carries( + &[(&ll, &x1l, 1), (&ll, &xl, -1)], + &[(to_i128_limbs(&yl), -1), (to_i128_limbs(&y1l), -1)], + &p_l, + &q3, + n, + w, + ); + write_limbs(witness, os + 10 * n - 4, &c3); + } + NonNativeEcOp::OnCurve => { + let px_val = read_witness_limbs(witness, &inputs[0], w); + let py_val = 
read_witness_limbs(witness, &inputs[1], w); + let x_sq_val = mul_mod(&px_val, &px_val, field_modulus_p); + let xsl = decompose_to_u128_limbs(&x_sq_val, n, w); + let pl = decompose_to_u128_limbs(&px_val, n, w); + let pyl = decompose_to_u128_limbs(&py_val, n, w); + write_limbs(witness, os, &xsl); + + // Eq1: px·px - x_sq = q1·p (always non-negative quotient) + let q1 = signed_quotient_wide( + &[(&px_val, &px_val, 1)], + &[], + Some(&x_sq_val), + field_modulus_p, + n, + w, + ); + write_limbs(witness, os + n, &q1); + let c1 = compute_ec_verification_carries( + &[(&pl, &pl, 1)], + &[(to_i128_limbs(&xsl), -1)], + &p_l, + &q1, + n, + w, + ); + write_limbs(witness, os + 2 * n, &c1); + + // Eq2: py·py - x_sq·px - a·px - b = q2·p + let x_sq_px = mul_mod(&x_sq_val, &px_val, field_modulus_p); + let a_px = mul_mod(curve_a, &px_val, field_modulus_p); + let rhs_val = mod_add( + &mod_add(&x_sq_px, &a_px, field_modulus_p), + curve_b, + field_modulus_p, + ); + let q2 = signed_quotient_wide( + &[(&py_val, &py_val, 1)], + &[], + Some(&rhs_val), + field_modulus_p, + n, + w, + ); + write_limbs(witness, os + 4 * n - 2, &q2); + + let a_l = decompose_to_u128_limbs(curve_a, n, w); + let b_l = decompose_to_u128_limbs(curve_b, n, w); + let a_is_zero = curve_a.iter().all(|&v| v == 0); + let mut prod_sets: Vec<(&[u128], &[u128], i64)> = + vec![(&pyl, &pyl, 1), (&xsl, &pl, -1)]; + if !a_is_zero { + prod_sets.push((&a_l, &pl, -1)); + } + let c2 = compute_ec_verification_carries( + &prod_sets, + &[(to_i128_limbs(&b_l), -1)], + &p_l, + &q2, + n, + w, + ); + write_limbs(witness, os + 5 * n - 2, &c2); + } + } + } WitnessBuilder::EcScalarMulHint { output_start, px, diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index d055fa89d..d06bc0eee 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -102,8 +102,11 @@ fn scalar_relation_cost( + 2 + n; + // Only n limbs worth of scalar DD digits 
get range checks; unused digits + // are zero-constrained instead (soundness fix for small curves). + let scalar_dd_rcs = n.min(2 * scalar_half_limbs); let mut rc_map = BTreeMap::new(); - *rc_map.entry(limb_bits).or_default() += 2 * scalar_half_limbs + 2 * half_limbs; + *rc_map.entry(limb_bits).or_default() += scalar_dd_rcs + 2 * half_limbs; add_field_op_range_checks( 1, 1, @@ -144,20 +147,249 @@ pub fn calculate_msm_witness_cost( let half_table_size = 1usize << (w - 1); let num_windows = ceil_div(half_bits, w); + // Use hint-verified costs for multi-limb (n >= 2), generic field ops for + // single-limb. + let use_hint_verified = n >= 2; + + if use_hint_verified { + calculate_msm_witness_cost_hint_verified( + native_field_bits, + n_points, + scalar_bits, + w, + limb_bits, + n, + half_bits, + half_table_size, + num_windows, + ) + } else { + calculate_msm_witness_cost_generic( + native_field_bits, + curve_modulus_bits, + n_points, + scalar_bits, + w, + limb_bits, + n, + half_bits, + half_table_size, + num_windows, + ) + } +} + +/// Hint-verified non-native MSM cost (num_limbs >= 2). 
+/// +/// EC ops use prover hints verified via schoolbook column equations: +/// - `point_double_verified_non_native`: (12N-6)W hint + 5N² products + N +/// constants +/// - `point_add_verified_non_native`: (12N-6)W hint + 4N² products +/// - `verify_on_curve_non_native`: (7N-4)W hint + 4N² products + 2N constants +/// (worst case) +/// +/// Each hint-verified op also produces: +/// - 3N less_than_p witnesses per double/add (3 calls × 3N per call = 9N) +/// - 1N less_than_p witnesses per on-curve (1 call × 3N = 3N) +/// - Range checks: hint limbs at limb_bits + carries at carry_range_bits + +/// less_than_p at limb_bits +#[allow(clippy::too_many_arguments)] +fn calculate_msm_witness_cost_hint_verified( + native_field_bits: u32, + n_points: usize, + scalar_bits: usize, + w: usize, + limb_bits: u32, + n: usize, + half_bits: usize, + half_table_size: usize, + num_windows: usize, +) -> usize { + // === Hint-verified EC op witness counts === + // point_double: (12N-6) hint + 5N² products + N pinned constants + 9N + // less_than_p + let double_hint = 12 * n - 6; + let double_products = 5 * n * n; + let double_constants = n; // a_limbs + let double_ltp = 3 * 3 * n; // 3 less_than_p calls × 3N each + let double_wit = double_hint + double_products + double_constants + double_ltp; + + // point_add: (12N-6) hint + 4N² products + 9N less_than_p + let add_hint = 12 * n - 6; + let add_products = 4 * n * n; + let add_ltp = 3 * 3 * n; + let add_wit = add_hint + add_products + add_ltp; + + // on_curve (worst case, a != 0): (7N-4) hint + 4N² products + 2N constants + 3N + // less_than_p + let oncurve_hint = 7 * n - 4; + let oncurve_products = 4 * n * n; // px_px + py_py + xsq_px + a_px + let oncurve_constants = 2 * n; // a_limbs + b_limbs + let oncurve_ltp = 3 * n; // 1 less_than_p call + let oncurve_wit = oncurve_hint + oncurve_products + oncurve_constants + oncurve_ltp; + + // === Shared costs (doublings, counted once) === + let shared_doubles = num_windows * w; + let 
shared_ec_wit = shared_doubles * double_wit; + // Offset point constant_limbs: 2N (shared, allocated once in Phase 2) + let shared_offset_constants = 2 * n; + + // === Per-point EC costs === + // Table building: 2 tables × (1 double + (half_table_size-1) adds) when size >= + // 2 + let (tbl_d, tbl_a) = if half_table_size >= 2 { + (1, half_table_size - 1) + } else { + (0, 0) + }; + let pp_table_ec = 2 * (tbl_d * double_wit + tbl_a * add_wit); + + // Main loop per-point: 2 adds per window + let pp_loop_ec = num_windows * 2 * add_wit; + + // Skew corrections: 2 branches × 1 add per point + let pp_skew_ec = 2 * add_wit; + + // On-curve checks (P and R): 2 calls + let pp_oncurve = 2 * oncurve_wit; + + // Y-negation via negate_mod_p_multi (borrow chain, no less_than_p): + // negate = 3N witnesses (N v-sums + N borrows + N r-sums) + // select = N witnesses + let negate_wit = 3 * n; + let pp_y_negate = 2 * (negate_wit + n); // 2 × (negate + select) + + // Signed table lookup per window: negate + select_unchecked on y + // negate(y): 3N, select_unchecked(y): N + let pp_signed_lookup_negate = num_windows * 2 * (negate_wit + n); + + // Skew correction negate: 2 × negate(py) via negate_mod_p_multi + let pp_skew_negate = 2 * negate_wit; + // Skew correction selects: 2 branches × 2N (x+y select_unchecked) + let pp_skew_selects = 2 * 2 * n; + + // Selects for signed table lookup (not field ops) + let table_selects = num_windows * 2 * (half_table_size.saturating_sub(1)) * 2 * n; + let xor_cost = num_windows * 2 * 2 * w.saturating_sub(1); + + let pp_selects = table_selects + xor_cost; + + // === Per-point overhead (non-EC) === + let scalar_bit_decomp = 2 * (half_bits + 1); + let detect_skip = 8; + let sanitize = 4; + let ec_hint = 4; // EcScalarMulHint (2W) + 2W selects + let point_decomp = 4 * n; // 4 witnesses per coord (N>1 always here) + let glv_hint = 4; + let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits); + + let per_point = 
pp_table_ec + + pp_loop_ec + + pp_skew_ec + + pp_oncurve + + pp_y_negate + + pp_signed_lookup_negate + + pp_skew_negate + + pp_skew_selects + + pp_selects + + scalar_bit_decomp + + detect_skip + + sanitize + + ec_hint + + point_decomp + + glv_hint + + sr_witnesses; + + // === Shared constants === + let shared_constants = 3 + shared_offset_constants; // gen_x, gen_y, zero + offset + + // === Point accumulation === + // Accumulation adds use hint-verified point_add_dispatch + let accum_add_wit = add_wit; + let accum = n_points * (accum_add_wit + 2 * n) // per-point add + skip select + + n_points.saturating_sub(1) // all_skipped products + + accum_add_wit + 4 * n + 2 * n // offset subtraction + constants + selects + + 2 + 2; // mask + recompose (n > 1 always) + + // === Range check resolution === + let mut rc_map: BTreeMap = BTreeMap::new(); + + // Hint-verified EC range checks + let double_carry_bits = hint_verified_carry_range_bits(limb_bits, 6 + n as u64, n); + let add_carry_bits = hint_verified_carry_range_bits(limb_bits, 4 + n as u64, n); + let oncurve_carry_bits = hint_verified_carry_range_bits(limb_bits, 5 + n as u64, n); + + // Shared doublings range checks + // Per double: 6N limb checks + 3*(2N-2) carry checks + 3 × 2N less_than_p limb + // checks + *rc_map.entry(limb_bits).or_default() += shared_doubles * (6 * n + 3 * 2 * n); + *rc_map.entry(double_carry_bits).or_default() += shared_doubles * 3 * (2 * n - 2); + + // Per-point EC range checks + let pp_doubles_count = 2 * tbl_d; + let pp_adds_count = 2 * tbl_a + num_windows * 2 + 2; // table + loop + skew + let pp_oncurve_count = 2; + + // Per double: 6N limb + 6N ltp limb + 3*(2N-2) carry + *rc_map.entry(limb_bits).or_default() += n_points * pp_doubles_count * (6 * n + 3 * 2 * n); + *rc_map.entry(double_carry_bits).or_default() += n_points * pp_doubles_count * 3 * (2 * n - 2); + + // Per add: 6N limb + 6N ltp limb + 3*(2N-2) carry + *rc_map.entry(limb_bits).or_default() += n_points * pp_adds_count * (6 * n + 
3 * 2 * n); + *rc_map.entry(add_carry_bits).or_default() += n_points * pp_adds_count * 3 * (2 * n - 2); + + // Per on-curve: 3N limb + 2N ltp limb + 2*(2N-2) carry + *rc_map.entry(limb_bits).or_default() += n_points * pp_oncurve_count * (3 * n + 2 * n); + *rc_map.entry(oncurve_carry_bits).or_default() += n_points * pp_oncurve_count * 2 * (2 * n - 2); + + // Accumulation adds range checks (hint-verified) + *rc_map.entry(limb_bits).or_default() += (n_points + 1) * (6 * n + 3 * 2 * n); + *rc_map.entry(add_carry_bits).or_default() += (n_points + 1) * 3 * (2 * n - 2); + + // Negate range checks via negate_mod_p_multi: N limb checks per negate + // (no less_than_p, so N instead of 2N per negate) + let negate_count_pp = 2 + num_windows * 2 + 2; // y-negate(2) + signed_lookup + skew + *rc_map.entry(limb_bits).or_default() += n_points * negate_count_pp * n; + + // Point decomposition + *rc_map.entry(limb_bits).or_default() += n_points * 4 * n; + + // Scalar relation + for (bits, count) in &sr_range_checks { + *rc_map.entry(*bits).or_default() += n_points * count; + } + + let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); + + shared_ec_wit + shared_constants + n_points * per_point + accum + range_check_cost +} + +/// Generic (single-limb) non-native MSM cost using MultiLimbOps field op +/// chains. 
+#[allow(clippy::too_many_arguments)] +fn calculate_msm_witness_cost_generic( + native_field_bits: u32, + curve_modulus_bits: u32, + n_points: usize, + scalar_bits: usize, + w: usize, + limb_bits: u32, + n: usize, + half_bits: usize, + half_table_size: usize, + num_windows: usize, +) -> usize { // === GLV scalar mul field op counts === // point_double: (5 add, 3 sub, 4 mul, 1 inv) + N constant witnesses (curve_a) // point_add: (1 add, 5 sub, 3 mul, 1 inv) // --- Shared costs (counted once, NOT per-point) --- - // Main loop doublings: w doublings per window, shared across all points - let mut shared_add = num_windows * w * 5; - let mut shared_sub = num_windows * w * 3; - let mut shared_mul = num_windows * w * 4; - let mut shared_inv = num_windows * w; + let shared_add = num_windows * w * 5; + let shared_sub = num_windows * w * 3; + let shared_mul = num_windows * w * 4; + let shared_inv = num_windows * w; // --- Per-point costs --- - // Signed table building (2 tables for P and R): odd multiples [P, 3P, 5P, ...] - // Build cost: 1 double (for 2P) + (half_table_size - 1) adds when size >= 2. 
let (tbl_d, tbl_a) = if half_table_size >= 2 { (1, half_table_size - 1) } else { @@ -168,64 +400,43 @@ pub fn calculate_msm_witness_cost( let mut pp_mul = 2 * (tbl_d * 4 + tbl_a * 3); let mut pp_inv = 2 * (tbl_d + tbl_a); - // Main loop per-point: 2 point_adds + 2 negates per window pp_add += num_windows * 2 * 1; - pp_sub += num_windows * (2 * 5 + 2); // +2 for signed lookup negates + pp_sub += num_windows * (2 * 5 + 2); pp_mul += num_windows * 2 * 3; pp_inv += num_windows * 2; - // Skew corrections: 2 branches × (1 negate + 1 point_add) per point pp_add += 2 * 1; - pp_sub += 2 * (5 + 1); // point_add subs + negate sub + pp_sub += 2 * (5 + 1); pp_mul += 2 * 3; pp_inv += 2; - // On-curve checks (P and R): 2 × (4 mul + 2 add) - pp_mul += 8; + pp_mul += 8; // on-curve pp_add += 4; - - // Y-negation (FakeGLV): 2 negate = 2 sub - pp_sub += 2; + pp_sub += 2; // y-negation let shared_field_ops = field_op_witnesses(shared_add, shared_sub, shared_mul, shared_inv, n, false); let pp_field_ops = field_op_witnesses(pp_add, pp_sub, pp_mul, pp_inv, n, false); - // Constant witness allocations not captured by field ops: - // - curve_a() in each point_double: N per call - // - on-curve: 2 × (curve_a + curve_b) = 4N - // - negate zeros: FakeGLV(2) + signed_lookup(2*num_windows) + skew(2) = - // (4+2*num_windows)×N - // - offset point: 2N (shared) let shared_doubles = num_windows * w; let pp_doubles = 2 * tbl_d; let pp_negate_zeros = (4 + 2 * num_windows) * n; - let shared_constants_glv = shared_doubles * n + 2 * n; // shared double curve_a + offset + let shared_constants_glv = shared_doubles * n + 2 * n; let pp_constants = pp_doubles * n + 4 * n + pp_negate_zeros; - // Selects (not field ops, priced separately) - // Signed table: halved from 2^w to 2^(w-1) entries let table_selects = num_windows * 2 * (half_table_size.saturating_sub(1)) * 2 * n; - // XOR bits for signed lookup: 2 witnesses per bit, (w-1) bits, 2 branches let xor_cost = num_windows * 2 * 2 * w.saturating_sub(1); - // 
Y-select after negate in signed lookup: 2 branches per window let signed_y_selects = num_windows * 2 * n; - // FakeGLV y-negation selects let y_negate_selects = 2 * n; - // Skew correction selects: 2 branches × 2 (x+y) let skew_selects = 2 * 2 * n; - let pp_selects = table_selects + xor_cost + signed_y_selects + y_negate_selects + skew_selects; - // === Per-point overhead === - // Signed-bit decomposition: num_bits bits + 1 skew per half-scalar, 2 - // half-scalars let scalar_bit_decomp = 2 * (half_bits + 1); - let detect_skip = 8; // 2×is_zero(3W) + product(1W) + or(1W) - let sanitize = 4; // 4 select_witness - let ec_hint = 4; // 2W hint + 2W selects + let detect_skip = 8; + let sanitize = 4; + let ec_hint = 4; let point_decomp = if n > 1 { 4 * n } else { 0 }; - let glv_hint = 4; // s1, s2, neg1, neg2 + let glv_hint = 4; let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits); let per_point = pp_field_ops @@ -239,21 +450,18 @@ pub fn calculate_msm_witness_cost( + glv_hint + sr_witnesses; - // === Shared constants (allocated once) === - // gen_x, gen_y, zero (3W) + offset_{x,y} (2×num_limbs W via constant_limbs) let shared_constants = 3 + 2 * n; - // === Point accumulation === - let pa_cost = field_op_witnesses(1, 5, 3, 1, n, false); // point_add - let accum = n_points * (pa_cost + 2 * n) // per-point add + skip select - + n_points.saturating_sub(1) // all_skipped products - + pa_cost + 4 * n + 2 * n // offset subtraction + constants + selects - + 2 + if n > 1 { 2 } else { 0 }; // mask + recompose + let pa_cost = field_op_witnesses(1, 5, 3, 1, n, false); + let accum = n_points * (pa_cost + 2 * n) + + n_points.saturating_sub(1) + + pa_cost + + 4 * n + + 2 * n + + 2 + + if n > 1 { 2 } else { 0 }; - // === Range check resolution === let mut rc_map: BTreeMap = BTreeMap::new(); - - // Shared GLV field ops (doublings — counted once) add_field_op_range_checks( shared_add, shared_sub, @@ -265,8 +473,6 @@ pub fn 
calculate_msm_witness_cost( false, &mut rc_map, ); - - // Per-point GLV field ops add_field_op_range_checks( n_points * pp_add, n_points * pp_sub, @@ -278,18 +484,12 @@ pub fn calculate_msm_witness_cost( false, &mut rc_map, ); - - // Point decomposition (per point, N>1 only) if n > 1 { *rc_map.entry(limb_bits).or_default() += n_points * 4 * n; } - - // Scalar relation (per point) for (bits, count) in &sr_range_checks { *rc_map.entry(*bits).or_default() += n_points * count; } - - // Accumulation: (n_points + 1) point_adds (1 add, 5 sub, 3 mul, 1 inv each) add_field_op_range_checks( (n_points + 1) * 1, (n_points + 1) * 5, @@ -312,6 +512,12 @@ pub fn calculate_msm_witness_cost( + range_check_cost } +/// Carry range check bits for hint-verified EC column equations. +fn hint_verified_carry_range_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + limb_bits + extra_bits +} + /// Native-field MSM cost: hint-verified EC ops with signed-bit wNAF (w=1). /// /// The native path uses prover hints verified via raw R1CS constraints: diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve.rs index 1aa176f8b..278f7de47 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve.rs @@ -31,6 +31,10 @@ impl CurveParams { decompose_to_limbs(&self.curve_a, limb_bits, num_limbs) } + pub fn curve_b_limbs(&self, limb_bits: u32, num_limbs: usize) -> Vec { + decompose_to_limbs(&self.curve_b, limb_bits, num_limbs) + } + /// Number of bits in the field modulus. 
pub fn modulus_bits(&self) -> u32 { if self.is_native_field() { diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs index 764b64c55..6cb2e28b8 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points.rs @@ -1,13 +1,45 @@ use { - super::{curve::CurveParams, multi_limb_ops::MultiLimbOps, Limbs}, + super::{ + curve::CurveParams, + multi_limb_arith::less_than_p_check_multi, + multi_limb_ops::{MultiLimbOps, MultiLimbParams}, + Limbs, + }, crate::noir_to_r1cs::NoirToR1CSCompiler, ark_ff::{Field, PrimeField}, provekit_common::{ - witness::{SumTerm, WitnessBuilder}, + witness::{NonNativeEcOp, SumTerm, WitnessBuilder}, FieldElement, }, + std::collections::BTreeMap, }; +/// Dispatching point doubling: uses hint-verified for multi-limb non-native, +/// generic field-ops otherwise. +pub fn point_double_dispatch(ops: &mut MultiLimbOps, x1: Limbs, y1: Limbs) -> (Limbs, Limbs) { + if ops.params.num_limbs >= 2 && !ops.params.is_native { + point_double_verified_non_native(ops.compiler, ops.range_checks, x1, y1, ops.params) + } else { + point_double(ops, x1, y1) + } +} + +/// Dispatching point addition: uses hint-verified for multi-limb non-native, +/// generic field-ops otherwise. +pub fn point_add_dispatch( + ops: &mut MultiLimbOps, + x1: Limbs, + y1: Limbs, + x2: Limbs, + y2: Limbs, +) -> (Limbs, Limbs) { + if ops.params.num_limbs >= 2 && !ops.params.is_native { + point_add_verified_non_native(ops.compiler, ops.range_checks, x1, y1, x2, y2, ops.params) + } else { + point_add(ops, x1, y1, x2, y2) + } +} + /// Generic point doubling on y^2 = x^3 + ax + b. 
/// /// Given P = (x1, y1), computes 2P = (x3, y3): @@ -123,10 +155,10 @@ fn build_signed_point_table( let mut table = Vec::with_capacity(half_table_size); table.push((px, py)); // T[0] = 1*P if half_table_size >= 2 { - let two_p = point_double(ops, px, py); // 2P + let two_p = point_double_dispatch(ops, px, py); // 2P for i in 1..half_table_size { let prev = table[i - 1]; - table.push(point_add(ops, prev.0, prev.1, two_p.0, two_p.1)); + table.push(point_add_dispatch(ops, prev.0, prev.1, two_p.0, two_p.1)); } } table @@ -319,7 +351,7 @@ pub fn scalar_mul_merged_glv( // w shared doublings on the accumulator (shared across ALL points) let mut doubled_acc = acc; for _ in 0..w { - doubled_acc = point_double(ops, doubled_acc.0, doubled_acc.1); + doubled_acc = point_double_dispatch(ops, doubled_acc.0, doubled_acc.1); } let mut cur = doubled_acc; @@ -337,7 +369,7 @@ pub fn scalar_mul_merged_glv( }; let looked_up_p = signed_table_lookup(ops, actual_table_p, index_bits_p, sign_bit_p); // All signed digits are non-zero — no is_zero check needed - cur = point_add(ops, cur.0, cur.1, looked_up_p.0, looked_up_p.1); + cur = point_add_dispatch(ops, cur.0, cur.1, looked_up_p.0, looked_up_p.1); // --- R branch (s2 window) --- let s2_window_bits = &pt.s2_bits[bit_start..bit_end]; @@ -349,7 +381,7 @@ pub fn scalar_mul_merged_glv( &table_r[..] 
}; let looked_up_r = signed_table_lookup(ops, actual_table_r, index_bits_r, sign_bit_r); - cur = point_add(ops, cur.0, cur.1, looked_up_r.0, looked_up_r.1); + cur = point_add_dispatch(ops, cur.0, cur.1, looked_up_r.0, looked_up_r.1); } acc = cur; @@ -361,14 +393,14 @@ pub fn scalar_mul_merged_glv( for pt in points { // P branch skew let neg_py = ops.negate(pt.py); - let (sub_px, sub_py) = point_add(ops, acc.0, acc.1, pt.px, neg_py); + let (sub_px, sub_py) = point_add_dispatch(ops, acc.0, acc.1, pt.px, neg_py); let new_x = ops.select_unchecked(pt.s1_skew, acc.0, sub_px); let new_y = ops.select_unchecked(pt.s1_skew, acc.1, sub_py); acc = (new_x, new_y); // R branch skew let neg_ry = ops.negate(pt.ry); - let (sub_rx, sub_ry) = point_add(ops, acc.0, acc.1, pt.rx, neg_ry); + let (sub_rx, sub_ry) = point_add_dispatch(ops, acc.0, acc.1, pt.rx, neg_ry); let new_x = ops.select_unchecked(pt.s2_skew, acc.0, sub_rx); let new_y = ops.select_unchecked(pt.s2_skew, acc.1, sub_ry); acc = (new_x, new_y); @@ -533,3 +565,603 @@ pub fn verify_on_curve_native( (b_fe, compiler.witness_one()), ]); } + +// =========================================================================== +// Non-native hint-verified EC operations (multi-limb schoolbook) +// =========================================================================== +// These replace the step-by-step MultiLimbOps chain with prover hints verified +// via schoolbook column equations. Each bilinear mod-p equation is checked by: +// 1. Pre-computing product witnesses a[i]*b[j] +// 2. Column equations: Σ coeff·prod[k] + linear[k] + carry_in + offset = Σ +// p[i]*q[j] + carry_out * W +// Since p is constant, p[i]*q[j] terms are linear in q (no product witness). + +/// Collect witness indices from `start..start+len`. +fn witness_range(start: usize, len: usize) -> Vec { + (start..start + len).collect() +} + +/// Allocate N×N product witnesses for `a[i]*b[j]`. 
+fn make_products(compiler: &mut NoirToR1CSCompiler, a: &[usize], b: &[usize]) -> Vec> { + let n = a.len(); + debug_assert_eq!(n, b.len()); + let mut prods = vec![vec![0usize; n]; n]; + for i in 0..n { + for j in 0..n { + prods[i][j] = compiler.add_product(a[i], b[j]); + } + } + prods +} + +/// Allocate pinned constant witnesses from pre-decomposed `FieldElement` limbs. +fn allocate_pinned_constant_limbs( + compiler: &mut NoirToR1CSCompiler, + limb_values: &[FieldElement], +) -> Vec { + limb_values + .iter() + .map(|&val| { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(w, val), + )); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(val, compiler.witness_one())], + ); + w + }) + .collect() +} + +/// Range-check limb witnesses at `limb_bits` and carry witnesses at +/// `carry_range_bits`. +fn range_check_limbs_and_carries( + range_checks: &mut BTreeMap>, + limb_vecs: &[&[usize]], + carry_vecs: &[&[usize]], + limb_bits: u32, + carry_range_bits: u32, +) { + for limbs in limb_vecs { + for &w in *limbs { + range_checks.entry(limb_bits).or_default().push(w); + } + } + for carries in carry_vecs { + for &c in *carries { + range_checks.entry(carry_range_bits).or_default().push(c); + } + } +} + +/// Convert `Vec` to `Limbs` and do a less-than-p check. +fn less_than_p_check_vec( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + v: &[usize], + params: &MultiLimbParams, +) { + let n = v.len(); + let mut limbs = Limbs::new(n); + for i in 0..n { + limbs[i] = v[i]; + } + less_than_p_check_multi( + compiler, + range_checks, + limbs, + ¶ms.p_minus_1_limbs, + params.two_pow_w, + params.limb_bits, + ); +} + +/// Emit schoolbook column equations for a merged verification equation. 
+/// +/// Verifies: Σ (coeff_i × A_i ⊗ B_i) + Σ linear_k = q·p (mod p, as integers) +/// +/// `product_sets`: each (products_2d, coefficient) where products_2d[i][j] +/// is the witness index for a[i]*b[j]. +/// `linear_limbs`: each (limb_witnesses, coefficient) for non-product terms +/// (limb_witnesses has N entries, zero-padded). +/// `q_witnesses`: quotient limbs (N entries). +/// `carry_witnesses`: unsigned-offset carry witnesses (2N-2 entries). +fn emit_schoolbook_column_equations( + compiler: &mut NoirToR1CSCompiler, + product_sets: &[(&[Vec], FieldElement)], // (products[i][j], coeff) + linear_limbs: &[(&[usize], FieldElement)], // (limb_witnesses, coeff) + q_witnesses: &[usize], + carry_witnesses: &[usize], + p_limbs: &[FieldElement], + n: usize, + limb_bits: u32, + max_coeff_sum: u64, +) { + let w1 = compiler.witness_one(); + let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); + + // Carry offset scaled for the merged equation's larger coefficients + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + let carry_offset_bits = limb_bits + extra_bits; + let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); + let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); + let offset_w_minus_carry = offset_w - carry_offset_fe; + + let num_columns = 2 * n - 1; + + for k in 0..num_columns { + // LHS: Σ coeff * products[i][j] for i+j=k + carry_in + offset + let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + + for &(products, coeff) in product_sets { + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((coeff, products[i][j_val as usize])); + } + } + } + + // Add linear terms (for k < N only, since linear_limbs are N-length) + for &(limbs, coeff) in linear_limbs { + if k < limbs.len() { + lhs_terms.push((coeff, limbs[k])); + } + } + + // Add carry_in and offset + if k > 0 { + 
lhs_terms.push((FieldElement::ONE, carry_witnesses[k - 1])); + lhs_terms.push((offset_w_minus_carry, w1)); + } else { + lhs_terms.push((offset_w, w1)); + } + + // RHS: Σ p[i]*q[j] for i+j=k + carry_out * W (or offset at last column) + let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new(); + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + rhs_terms.push((p_limbs[i], q_witnesses[j_val as usize])); + } + } + + if k < num_columns - 1 { + rhs_terms.push((two_pow_w, carry_witnesses[k])); + } else { + // Last column: balance with offset_w (no outgoing carry) + rhs_terms.push((offset_w, w1)); + } + + compiler + .r1cs + .add_constraint(&lhs_terms, &[(FieldElement::ONE, w1)], &rhs_terms); + } +} + +/// Hint-verified on-curve check for non-native field (multi-limb). +/// +/// Verifies y² = x³ + ax + b (mod p) via: +/// Eq1: x·x - x_sq = q1·p (x_sq correctness) +/// Eq2: y·y - x_sq·x - a·x - b = q2·p (on-curve) +/// +/// Total: (7N-4)W hint + (N² + 2N² [+ N²])products + 2×(2N-1) constraints +/// + 1 less-than-p check. 
+pub fn verify_on_curve_non_native( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + px: Limbs, + py: Limbs, + params: &MultiLimbParams, +) { + let n = params.num_limbs; + assert!(n >= 2, "hint-verified on-curve check requires n >= 2"); + + let a_is_zero = params.curve_a_raw.iter().all(|&v| v == 0); + + // Soundness check + { + // max terms in a column: px·px(1) + x_sq(1) + py·py(1) + x_sq·px(1) + [a·px(1)] + // + b(1) + pq(N) + let max_coeff_sum: u64 = if a_is_zero { + 4 + n as u64 + } else { + 5 + n as u64 + }; + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + let max_bits = 2 * params.limb_bits + extra_bits + 1; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "On-curve column equation overflow: limb_bits={}, n={n}, needs {max_bits} bits", + params.limb_bits + ); + } + + // Allocate hint + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { + output_start: os, + op: NonNativeEcOp::OnCurve, + inputs: vec![px.as_slice()[..n].to_vec(), py.as_slice()[..n].to_vec()], + curve_a: params.curve_a_raw, + curve_b: params.curve_b_raw, + field_modulus_p: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + + // Parse hint layout: [x_sq(N), q1(N), c1(2N-2), q2(N), c2(2N-2)] + let x_sq = witness_range(os, n); + let q1 = witness_range(os + n, n); + let c1 = witness_range(os + 2 * n, 2 * n - 2); + let q2 = witness_range(os + 4 * n - 2, n); + let c2 = witness_range(os + 5 * n - 2, 2 * n - 2); + + // Eq1: px·px - x_sq = q1·p + let prod_px_px = make_products(compiler, &px.as_slice()[..n], &px.as_slice()[..n]); + + let max_coeff_eq1: u64 = 1 + 1 + n as u64; + emit_schoolbook_column_equations( + compiler, + &[(&prod_px_px, FieldElement::ONE)], + &[(&x_sq, -FieldElement::ONE)], + &q1, + &c1, + ¶ms.p_limbs, + n, + params.limb_bits, + max_coeff_eq1, + ); + + // Eq2: py·py - x_sq·px - a·px - b = q2·p + let prod_py_py = make_products(compiler, 
&py.as_slice()[..n], &py.as_slice()[..n]); + let prod_xsq_px = make_products(compiler, &x_sq, &px.as_slice()[..n]); + let b_limbs = allocate_pinned_constant_limbs(compiler, &params.curve_b_limbs[..n]); + + if a_is_zero { + let max_coeff_eq2: u64 = 1 + 1 + 1 + n as u64; + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_py_py, FieldElement::ONE), + (&prod_xsq_px, -FieldElement::ONE), + ], + &[(&b_limbs, -FieldElement::ONE)], + &q2, + &c2, + &params.p_limbs, + n, + params.limb_bits, + max_coeff_eq2, + ); + } else { + let a_limbs = allocate_pinned_constant_limbs(compiler, &params.curve_a_limbs[..n]); + let prod_a_px = make_products(compiler, &a_limbs, &px.as_slice()[..n]); + + let max_coeff_eq2: u64 = 1 + 1 + 1 + 1 + n as u64; + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_py_py, FieldElement::ONE), + (&prod_xsq_px, -FieldElement::ONE), + (&prod_a_px, -FieldElement::ONE), + ], + &[(&b_limbs, -FieldElement::ONE)], + &q2, + &c2, + &params.p_limbs, + n, + params.limb_bits, + max_coeff_eq2, + ); + } + + // Range checks on hint outputs + let max_coeff = if a_is_zero { + 4 + n as u64 + } else { + 5 + n as u64 + }; + let carry_extra_bits = ((max_coeff as f64 * n as f64).log2().ceil() as u32) + 1; + let carry_range_bits = params.limb_bits + carry_extra_bits; + range_check_limbs_and_carries( + range_checks, + &[&x_sq, &q1, &q2], + &[&c1, &c2], + params.limb_bits, + carry_range_bits, + ); + + // Less-than-p check for x_sq + less_than_p_check_vec(compiler, range_checks, &x_sq, params); +} + +/// Hint-verified point doubling for non-native field (multi-limb). +/// +/// Allocates NonNativeEcDoubleHint → (lambda, x3, y3, q1, c1, q2, c2, q3, c3). +/// Verifies via schoolbook column equations on 3 EC verification equations. +/// Total: (12N-6)W hint + ~(4N²+N) products + 3×(2N-1) column constraints +/// + 3 less-than-p checks. 
+pub fn point_double_verified_non_native( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap<u32, Vec<usize>>, + px: Limbs, + py: Limbs, + params: &MultiLimbParams, +) -> (Limbs, Limbs) { + let n = params.num_limbs; + assert!(n >= 2, "hint-verified non-native requires n >= 2"); + + // Soundness check: merged column equations fit native field + { + let max_coeff_sum: u64 = 2 + 3 + 1 + n as u64; // λy(2) + xx(3) + a(1) + pq(N) + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + let max_bits = 2 * params.limb_bits + extra_bits + 1; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "Merged EC column equation overflow: limb_bits={}, n={n}, needs {max_bits} bits", + params.limb_bits + ); + } + + // Allocate hint + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { + output_start: os, + op: NonNativeEcOp::Double, + inputs: vec![px.as_slice()[..n].to_vec(), py.as_slice()[..n].to_vec()], + curve_a: params.curve_a_raw, + curve_b: [0; 4], // unused for double + field_modulus_p: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + + // Parse hint layout: [lambda(N), x3(N), y3(N), q1(N), c1(2N-2), q2(N), + // c2(2N-2), q3(N), c3(2N-2)] + let lambda = witness_range(os, n); + let x3 = witness_range(os + n, n); + let y3 = witness_range(os + 2 * n, n); + let q1 = witness_range(os + 3 * n, n); + let c1 = witness_range(os + 4 * n, 2 * n - 2); + let q2 = witness_range(os + 6 * n - 2, n); + let c2 = witness_range(os + 7 * n - 2, 2 * n - 2); + let q3 = witness_range(os + 9 * n - 4, n); + let c3 = witness_range(os + 10 * n - 4, 2 * n - 2); + + let px_s = &px.as_slice()[..n]; + let py_s = &py.as_slice()[..n]; + + // Eq1: 2*lambda*py - 3*px*px - a = q1*p + let prod_lam_py = make_products(compiler, &lambda, py_s); + let prod_px_px = make_products(compiler, px_s, px_s); + let a_limbs = allocate_pinned_constant_limbs(compiler, &params.curve_a_limbs[..n]); + + let max_coeff_eq1: u64 
= 2 + 3 + 1 + n as u64; + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_py, FieldElement::from(2u64)), + (&prod_px_px, -FieldElement::from(3u64)), + ], + &[(&a_limbs, -FieldElement::ONE)], + &q1, + &c1, + &params.p_limbs, + n, + params.limb_bits, + max_coeff_eq1, + ); + + // Eq2: lambda² - x3 - 2*px = q2*p + let prod_lam_lam = make_products(compiler, &lambda, &lambda); + + let max_coeff_eq2: u64 = 1 + 1 + 2 + n as u64; + emit_schoolbook_column_equations( + compiler, + &[(&prod_lam_lam, FieldElement::ONE)], + &[(&x3, -FieldElement::ONE), (px_s, -FieldElement::from(2u64))], + &q2, + &c2, + &params.p_limbs, + n, + params.limb_bits, + max_coeff_eq2, + ); + + // Eq3: lambda*px - lambda*x3 - y3 - py = q3*p + let prod_lam_px = make_products(compiler, &lambda, px_s); + let prod_lam_x3 = make_products(compiler, &lambda, &x3); + + let max_coeff_eq3: u64 = 1 + 1 + 1 + 1 + n as u64; + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_px, FieldElement::ONE), + (&prod_lam_x3, -FieldElement::ONE), + ], + &[(&y3, -FieldElement::ONE), (py_s, -FieldElement::ONE)], + &q3, + &c3, + &params.p_limbs, + n, + params.limb_bits, + max_coeff_eq3, + ); + + // Range checks on hint outputs + // max_coeff across eqs: Eq1 = 6+N, Eq2 = 4+N, Eq3 = 4+N → worst = 6+N + let max_coeff_carry = 6u64 + n as u64; + let carry_extra_bits = ((max_coeff_carry as f64 * n as f64).log2().ceil() as u32) + 1; + let carry_range_bits = params.limb_bits + carry_extra_bits; + range_check_limbs_and_carries( + range_checks, + &[&lambda, &x3, &y3, &q1, &q2, &q3], + &[&c1, &c2, &c3], + params.limb_bits, + carry_range_bits, + ); + + // Less-than-p checks for lambda, x3, y3 + less_than_p_check_vec(compiler, range_checks, &lambda, params); + less_than_p_check_vec(compiler, range_checks, &x3, params); + less_than_p_check_vec(compiler, range_checks, &y3, params); + + let mut x3_limbs = Limbs::new(n); + let mut y3_limbs = Limbs::new(n); + for i in 0..n { + x3_limbs[i] = x3[i]; + y3_limbs[i] = y3[i]; + } + 
(x3_limbs, y3_limbs) +} + +/// Hint-verified point addition for non-native field (multi-limb). +/// +/// Same approach as `point_double_verified_non_native` but verifies: +/// Eq1: lambda*(x2-x1) = y2-y1 (mod p) +/// Eq2: lambda² = x3+x1+x2 (mod p) +/// Eq3: lambda*(x1-x3) = y3+y1 (mod p) +pub fn point_add_verified_non_native( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap<u32, Vec<usize>>, + x1: Limbs, + y1: Limbs, + x2: Limbs, + y2: Limbs, + params: &MultiLimbParams, +) -> (Limbs, Limbs) { + let n = params.num_limbs; + assert!(n >= 2, "hint-verified non-native requires n >= 2"); + + // Soundness check: column equations fit native field + { + let max_coeff_sum: u64 = 4 + n as u64; // all 3 eqs: 1+1+1+1+N + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + let max_bits = 2 * params.limb_bits + extra_bits + 1; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "EC add column equation overflow: limb_bits={}, n={n}, needs {max_bits} bits", + params.limb_bits + ); + } + + let os = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { + output_start: os, + op: NonNativeEcOp::Add, + inputs: vec![ + x1.as_slice()[..n].to_vec(), + y1.as_slice()[..n].to_vec(), + x2.as_slice()[..n].to_vec(), + y2.as_slice()[..n].to_vec(), + ], + curve_a: [0; 4], // unused for add + curve_b: [0; 4], // unused for add + field_modulus_p: params.modulus_raw, + limb_bits: params.limb_bits, + num_limbs: n as u32, + }); + + let lambda = witness_range(os, n); + let x3 = witness_range(os + n, n); + let y3 = witness_range(os + 2 * n, n); + let q1 = witness_range(os + 3 * n, n); + let c1 = witness_range(os + 4 * n, 2 * n - 2); + let q2 = witness_range(os + 6 * n - 2, n); + let c2 = witness_range(os + 7 * n - 2, 2 * n - 2); + let q3 = witness_range(os + 9 * n - 4, n); + let c3 = witness_range(os + 10 * n - 4, 2 * n - 2); + + let x1_s = &x1.as_slice()[..n]; + let y1_s = &y1.as_slice()[..n]; + let x2_s = &x2.as_slice()[..n]; + 
let y2_s = &y2.as_slice()[..n]; + + // Eq1: lambda*x2 - lambda*x1 - y2 + y1 = q1*p + let prod_lam_x2 = make_products(compiler, &lambda, x2_s); + let prod_lam_x1 = make_products(compiler, &lambda, x1_s); + + let max_coeff: u64 = 1 + 1 + 1 + 1 + n as u64; + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_x2, FieldElement::ONE), + (&prod_lam_x1, -FieldElement::ONE), + ], + &[(y2_s, -FieldElement::ONE), (y1_s, FieldElement::ONE)], + &q1, + &c1, + &params.p_limbs, + n, + params.limb_bits, + max_coeff, + ); + + // Eq2: lambda² - x3 - x1 - x2 = q2*p + let prod_lam_lam = make_products(compiler, &lambda, &lambda); + + emit_schoolbook_column_equations( + compiler, + &[(&prod_lam_lam, FieldElement::ONE)], + &[ + (&x3, -FieldElement::ONE), + (x1_s, -FieldElement::ONE), + (x2_s, -FieldElement::ONE), + ], + &q2, + &c2, + &params.p_limbs, + n, + params.limb_bits, + max_coeff, + ); + + // Eq3: lambda*x1 - lambda*x3 - y3 - y1 = q3*p + // Reuse prod_lam_x1 from Eq1 + let prod_lam_x3 = make_products(compiler, &lambda, &x3); + + emit_schoolbook_column_equations( + compiler, + &[ + (&prod_lam_x1, FieldElement::ONE), + (&prod_lam_x3, -FieldElement::ONE), + ], + &[(&y3, -FieldElement::ONE), (y1_s, -FieldElement::ONE)], + &q3, + &c3, + &params.p_limbs, + n, + params.limb_bits, + max_coeff, + ); + + // Range checks + // max_coeff across all 3 eqs = 4+N + let max_coeff_carry = 4u64 + n as u64; + let carry_extra_bits = ((max_coeff_carry as f64 * n as f64).log2().ceil() as u32) + 1; + let carry_range_bits = params.limb_bits + carry_extra_bits; + range_check_limbs_and_carries( + range_checks, + &[&lambda, &x3, &y3, &q1, &q2, &q3], + &[&c1, &c2, &c3], + params.limb_bits, + carry_range_bits, + ); + + // Less-than-p checks + less_than_p_check_vec(compiler, range_checks, &lambda, params); + less_than_p_check_vec(compiler, range_checks, &x3, params); + less_than_p_check_vec(compiler, range_checks, &y3, params); + + let mut x3_limbs = Limbs::new(n); + let mut y3_limbs = Limbs::new(n); + for i in 
0..n { + x3_limbs[i] = x3[i]; + y3_limbs[i] = y3[i]; + } + (x3_limbs, y3_limbs) +} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs index 840f8081a..ef19dd31c 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -278,6 +278,68 @@ pub fn add_mod_p_multi( r } +/// Negate a multi-limb value: computes `p - y` directly via borrow chain. +/// +/// Since inputs are already verified `y ∈ [0, p)` (from less_than_p on +/// prior operations), the result `p - y` is in `(0, p]`. For `y > 0`, +/// the result is in `(0, p)` — canonical. For `y = 0`, the result is `p ≡ 0`, +/// which has valid limbs (each < 2^limb_bits) and is correct modulo p. +/// +/// This avoids the generic `sub_mod_p_multi` pathway which allocates +/// N zero-constant witnesses, a borrow quotient, and a less_than_p check. +/// +/// Witnesses: 3N (N v-sums + N borrows + N result sums). +/// Range checks: N at limb_bits. +pub fn negate_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap<u32, Vec<usize>>, + y: Limbs, + p_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, +) -> Limbs { + let n = y.len(); + assert!(n >= 2, "negate_mod_p_multi requires n >= 2, got n={n}"); + let w1 = compiler.witness_one(); + + let mut r = Limbs::new(n); + let mut borrow_prev: Option<usize> = None; + + for i in 0..n { + // v[i] = p[i] + 2^W - y[i] + borrow_{i-1} + // The 2^W offset ensures v[i] >= 0 (since p[i] >= 0, y[i] < 2^W). + // When borrow_prev exists, combine w1 terms to avoid duplicate column + // indices in the R1CS sparse matrix. 
+ let w1_coeff = if borrow_prev.is_some() { + p_limbs[i] + two_pow_w - FieldElement::ONE + } else { + p_limbs[i] + two_pow_w + }; + let mut terms = vec![ + SumTerm(Some(w1_coeff), w1), + SumTerm(Some(-FieldElement::ONE), y[i]), + ]; + if let Some(borrow) = borrow_prev { + terms.push(SumTerm(None, borrow)); + } + let v = compiler.add_sum(terms); + + // borrow[i] = floor(v[i] / 2^W) + let borrow = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(borrow, v, two_pow_w)); + + // r[i] = v[i] - borrow[i] * 2^W + r[i] = compiler.add_sum(vec![SumTerm(None, v), SumTerm(Some(-two_pow_w), borrow)]); + + // Range check r[i] — ensures borrow is uniquely determined + range_checks.entry(limb_bits).or_default().push(r[i]); + + borrow_prev = Some(borrow); + } + + r +} + /// (a - b) mod p for multi-limb values. pub fn sub_mod_p_multi( compiler: &mut NoirToR1CSCompiler, @@ -563,7 +625,7 @@ pub fn inv_mod_p_multi( /// Proves r < p by decomposing (p-1) - r into non-negative multi-limb values. 
/// Uses borrow propagation: d\[i\] = (p-1)\[i\] - r\[i\] + borrow_in - /// borrow_out * 2^W -fn less_than_p_check_multi( +pub fn less_than_p_check_multi( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap<u32, Vec<usize>>, r: Limbs, diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs index a3d6707df..52bd3d934 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -26,6 +26,11 @@ pub struct MultiLimbParams { pub two_pow_w: FieldElement, pub modulus_raw: [u64; 4], pub curve_a_limbs: Vec<FieldElement>, + /// Raw curve_a value as [u64; 4] for hint-verified EC ops + pub curve_a_raw: [u64; 4], + pub curve_b_limbs: Vec<FieldElement>, + /// Raw curve_b value as [u64; 4] for hint-verified on-curve checks + pub curve_b_raw: [u64; 4], /// p = native field → skip mod reduction pub is_native: bool, /// For N=1 non-native: the modulus as a single FieldElement @@ -54,6 +59,9 @@ impl MultiLimbParams { two_pow_w, modulus_raw: curve.field_modulus_p, curve_a_limbs: curve.curve_a_limbs(limb_bits, num_limbs), + curve_a_raw: curve.curve_a, + curve_b_limbs: curve.curve_b_limbs(limb_bits, num_limbs), + curve_b_raw: curve.curve_b, is_native, modulus_fe, } @@ -79,6 +87,9 @@ impl MultiLimbParams { two_pow_w, modulus_raw: curve.curve_order_n, curve_a_limbs: vec![FieldElement::from(0u64); num_limbs], // unused + curve_a_raw: [0u64; 4], // unused for scalar relation + curve_b_limbs: vec![FieldElement::from(0u64); num_limbs], // unused + curve_b_raw: [0u64; 4], // unused for scalar relation is_native: false, /* always non-native for * scalar relation */ modulus_fe, } @@ -106,11 +117,26 @@ impl MultiLimbOps<'_, '_> { self.params.num_limbs } - /// Negate a multi-limb value: computes `0 - value (mod p)`. + /// Negate a multi-limb value: computes `p - value (mod p)`. 
+ /// + /// For multi-limb non-native (N≥2), uses a dedicated borrow chain that + /// skips zero-constant allocation, the borrow-quotient hint, and the + /// less_than_p check — saving (4N+1) witnesses per call. pub fn negate(&mut self, value: Limbs) -> Limbs { - let zero_vals = vec![FieldElement::from(0u64); self.params.num_limbs]; - let zero = self.constant_limbs(&zero_vals); - self.sub(zero, value) + if self.params.num_limbs >= 2 && !self.params.is_native { + multi_limb_arith::negate_mod_p_multi( + self.compiler, + self.range_checks, + value, + &self.params.p_limbs, + self.params.two_pow_w, + self.params.limb_bits, + ) + } else { + let zero_vals = vec![FieldElement::from(0u64); self.params.num_limbs]; + let zero = self.constant_limbs(&zero_vals); + self.sub(zero, value) + } } pub fn add(&mut self, a: Limbs, b: Limbs) -> Limbs { diff --git a/provekit/r1cs-compiler/src/msm/non_native.rs b/provekit/r1cs-compiler/src/msm/non_native.rs index a506fabb7..05254b3d8 100644 --- a/provekit/r1cs-compiler/src/msm/non_native.rs +++ b/provekit/r1cs-compiler/src/msm/non_native.rs @@ -106,7 +106,23 @@ pub(super) fn process_multi_point_non_native<'a>( range_checks, ); - // On-curve checks, FakeGLV, bit decomposition, y-negation (via MultiLimbOps) + // On-curve checks: use hint-verified for multi-limb, generic for single-limb + if num_limbs >= 2 { + ec_points::verify_on_curve_non_native(compiler, range_checks, px, py, &params); + ec_points::verify_on_curve_non_native(compiler, range_checks, rx, ry, &params); + } else { + let mut ops = MultiLimbOps { + compiler, + range_checks, + params: &params, + }; + verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); + verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); + compiler = ops.compiler; + range_checks = ops.range_checks; + } + + // FakeGLV, bit decomposition, y-negation (via MultiLimbOps) let (s1_witness, s2_witness, neg1_witness, neg2_witness); let (py_effective, ry_effective, s1_bits, s2_bits, s1_skew, s2_skew); { @@ -116,10 
+132,6 @@ pub(super) fn process_multi_point_non_native<'a>( params: &params, }; - // On-curve checks for P and R - verify_on_curve(&mut ops, px, py, &b_limb_values, num_limbs); - verify_on_curve(&mut ops, rx, ry, &b_limb_values, num_limbs); - // FakeGLVHint → |s1|, |s2|, neg1, neg2 (s1_witness, s2_witness, neg1_witness, neg2_witness) = emit_fakeglv_hint(ops.compiler, san.s_lo, san.s_hi, curve); @@ -226,7 +238,7 @@ pub(super) fn process_multi_point_non_native<'a>( let mut acc_y = ops.constant_limbs(&offset_y_values); for &(rx, ry, is_skip) in &accum_inputs { - let (cand_x, cand_y) = ec_points::point_add(&mut ops, acc_x, acc_y, rx, ry); + let (cand_x, cand_y) = ec_points::point_add_dispatch(&mut ops, acc_x, acc_y, rx, ry); let (new_acc_x, new_acc_y) = ec_points::point_select_unchecked(&mut ops, is_skip, (cand_x, cand_y), (acc_x, acc_y)); acc_x = new_acc_x; @@ -253,7 +265,7 @@ pub(super) fn process_multi_point_non_native<'a>( ops.select(all_skipped, neg_off_y, neg_g_y) }; - let (result_x, result_y) = ec_points::point_add(&mut ops, acc_x, acc_y, sub_x, sub_y); + let (result_x, result_y) = ec_points::point_add_dispatch(&mut ops, acc_x, acc_y, sub_x, sub_y); compiler = ops.compiler; if num_limbs == 1 { diff --git a/provekit/r1cs-compiler/src/msm/scalar_relation.rs b/provekit/r1cs-compiler/src/msm/scalar_relation.rs index 0126564ed..7cca0221c 100644 --- a/provekit/r1cs-compiler/src/msm/scalar_relation.rs +++ b/provekit/r1cs-compiler/src/msm/scalar_relation.rs @@ -96,6 +96,10 @@ pub(super) fn verify_scalar_relation( /// When `limb_bits` divides 128 (e.g. 64), limb boundaries align with the /// s_lo/s_hi split. Otherwise (e.g. 85-bit limbs), one limb straddles bit 128 /// and is assembled from a partial s_lo digit and a partial s_hi digit. +/// +/// For small curves where `num_limbs * limb_bits < 256`, the digits beyond the +/// used limbs are constrained to zero. This ensures the scalar fits in the +/// representation and prevents truncation attacks. 
fn decompose_scalar_from_halves( ops: &mut MultiLimbOps, s_lo: usize, @@ -115,13 +119,26 @@ fn decompose_scalar_from_halves( limbs[i] = dd_lo.get_digit_witness_index(i, 0); ops.range_checks.entry(w as u32).or_default().push(limbs[i]); } - for (i, &w) in widths.iter().enumerate().take(num_limbs - from_lo) { + let from_hi = (num_limbs - from_lo).min(widths.len()); + for (i, &w) in widths.iter().enumerate().take(from_hi) { limbs[from_lo + i] = dd_hi.get_digit_witness_index(i, 0); ops.range_checks .entry(w as u32) .or_default() .push(limbs[from_lo + i]); } + + // Constrain unused dd_lo digits to zero (small curves where num_limbs + // covers fewer than 128 bits of s_lo). + for i in from_lo..widths.len() { + constrain_zero(ops.compiler, dd_lo.get_digit_witness_index(i, 0)); + } + // Constrain unused dd_hi digits to zero (small curves where the upper + // half is partially or entirely unused). + for i in from_hi..widths.len() { + constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(i, 0)); + } + limbs } else { // Example: 85-bit limbs, 254-bit order → @@ -134,11 +151,12 @@ fn decompose_scalar_from_halves( let lo_widths = limb_widths(SCALAR_HALF_BITS, limb_bits); let hi_widths = vec![hi_head, hi_rest]; - let dd_lo = add_digital_decomposition(ops.compiler, lo_widths, vec![s_lo]); + let dd_lo = add_digital_decomposition(ops.compiler, lo_widths.clone(), vec![s_lo]); let dd_hi = add_digital_decomposition(ops.compiler, hi_widths, vec![s_hi]); let mut limbs = Limbs::new(num_limbs); - for i in 0..lo_full { + let lo_used = lo_full.min(num_limbs); + for i in 0..lo_used { limbs[i] = dd_lo.get_digit_witness_index(i, 0); ops.range_checks .entry(limb_bits) @@ -146,24 +164,29 @@ fn decompose_scalar_from_halves( .push(limbs[i]); } - // Cross-boundary limb: lo_tail bits from s_lo + hi_head bits from s_hi - let shift = FieldElement::from(2u64).pow([lo_tail as u64]); - let lo_digit = dd_lo.get_digit_witness_index(lo_full, 0); - let hi_digit = dd_hi.get_digit_witness_index(0, 0); - 
limbs[lo_full] = ops.compiler.add_sum(vec![ - SumTerm(None, lo_digit), - SumTerm(Some(shift), hi_digit), - ]); - ops.range_checks - .entry(lo_tail as u32) - .or_default() - .push(lo_digit); - ops.range_checks - .entry(hi_head as u32) - .or_default() - .push(hi_digit); - - if hi_rest > 0 { + // Cross-boundary limb and hi_rest, only if num_limbs needs them. + let needs_cross = num_limbs > lo_full; + let needs_hi_rest = num_limbs > lo_full + 1 && hi_rest > 0; + + if needs_cross { + let shift = FieldElement::from(2u64).pow([lo_tail as u64]); + let lo_digit = dd_lo.get_digit_witness_index(lo_full, 0); + let hi_digit = dd_hi.get_digit_witness_index(0, 0); + limbs[lo_full] = ops.compiler.add_sum(vec![ + SumTerm(None, lo_digit), + SumTerm(Some(shift), hi_digit), + ]); + ops.range_checks + .entry(lo_tail as u32) + .or_default() + .push(lo_digit); + ops.range_checks + .entry(hi_head as u32) + .or_default() + .push(hi_digit); + } + + if needs_hi_rest { limbs[lo_full + 1] = dd_hi.get_digit_witness_index(1, 0); ops.range_checks .entry(hi_rest as u32) @@ -171,6 +194,22 @@ fn decompose_scalar_from_halves( .push(limbs[lo_full + 1]); } + // Constrain unused digits to zero for small curves. + // dd_lo: digits beyond lo_used that aren't part of the cross-boundary. + if !needs_cross { + // The tail digit of dd_lo and all dd_hi digits are unused. + for i in lo_used..lo_widths.len() { + constrain_zero(ops.compiler, dd_lo.get_digit_witness_index(i, 0)); + } + constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(0, 0)); + if hi_rest > 0 { + constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(1, 0)); + } + } else if !needs_hi_rest && hi_rest > 0 { + // Cross-boundary is used but hi_rest digit is unused. 
+ constrain_zero(ops.compiler, dd_hi.get_digit_witness_index(1, 0)); + } + limbs } } From 8f57e57cd99b2c01d99f1afef80c08be2da041a3 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 13 Mar 2026 07:13:20 +0530 Subject: [PATCH 18/19] refactor : modules and comments --- .../common/src/witness/witness_builder.rs | 15 +- provekit/prover/src/bigint_mod.rs | 10 +- .../r1cs-compiler/src/constraint_helpers.rs | 2 +- provekit/r1cs-compiler/src/msm/cost_model.rs | 425 +++--- provekit/r1cs-compiler/src/msm/ec_points.rs | 1167 ----------------- .../src/msm/ec_points/generic.rs | 104 ++ .../src/msm/ec_points/hints_native.rs | 161 +++ .../src/msm/ec_points/hints_non_native.rs | 596 +++++++++ .../r1cs-compiler/src/msm/ec_points/mod.rs | 51 + .../r1cs-compiler/src/msm/ec_points/tables.rs | 259 ++++ provekit/r1cs-compiler/src/msm/mod.rs | 301 +---- .../r1cs-compiler/src/msm/multi_limb_arith.rs | 170 +-- .../r1cs-compiler/src/msm/multi_limb_ops.rs | 5 +- provekit/r1cs-compiler/src/msm/native.rs | 33 +- provekit/r1cs-compiler/src/msm/non_native.rs | 36 +- provekit/r1cs-compiler/src/msm/sanitize.rs | 181 +++ provekit/r1cs-compiler/src/msm/tests.rs | 107 ++ 17 files changed, 1837 insertions(+), 1786 deletions(-) delete mode 100644 provekit/r1cs-compiler/src/msm/ec_points.rs create mode 100644 provekit/r1cs-compiler/src/msm/ec_points/generic.rs create mode 100644 provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs create mode 100644 provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs create mode 100644 provekit/r1cs-compiler/src/msm/ec_points/mod.rs create mode 100644 provekit/r1cs-compiler/src/msm/ec_points/tables.rs create mode 100644 provekit/r1cs-compiler/src/msm/sanitize.rs create mode 100644 provekit/r1cs-compiler/src/msm/tests.rs diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 0e03090a7..7a5fe1e7a 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ 
b/provekit/common/src/witness/witness_builder.rs @@ -57,12 +57,12 @@ pub struct CombinedTableEntryInverseData { /// Operation type for the unified non-native EC hint. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum NonNativeEcOp { - /// Point doubling: inputs = [[px_limbs], [py_limbs]], outputs 12N-6 + /// Point doubling: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 12N-6 Double, - /// Point addition: inputs = [[x1_limbs], [y1_limbs], [x2_limbs], - /// [y2_limbs]], outputs 12N-6 + /// Point addition: inputs = \[\[x1_limbs\], \[y1_limbs\], \[x2_limbs\], + /// \[y2_limbs\]\], outputs 12N-6 Add, - /// On-curve check: inputs = [[px_limbs], [py_limbs]], outputs 7N-4 + /// On-curve check: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 7N-4 OnCurve, } @@ -361,9 +361,10 @@ pub enum WitnessBuilder { /// Unified prover hint for non-native EC operations (multi-limb). /// /// `op` selects the operation: - /// - `Double`: inputs = [[px], [py]], outputs 12N-6 witnesses - /// - `Add`: inputs = [[x1], [y1], [x2], [y2]], outputs 12N-6 witnesses - /// - `OnCurve`: inputs = [[px], [py]], outputs 7N-4 witnesses + /// - `Double`: inputs = \[\[px\], \[py\]\], outputs 12N-6 witnesses + /// - `Add`: inputs = \[\[x1\], \[y1\], \[x2\], \[y2\]\], outputs 12N-6 + /// witnesses + /// - `OnCurve`: inputs = \[\[px\], \[py\]\], outputs 7N-4 witnesses NonNativeEcHint { output_start: usize, op: NonNativeEcOp, diff --git a/provekit/prover/src/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs index 09f8c5308..f61b97f8c 100644 --- a/provekit/prover/src/bigint_mod.rs +++ b/provekit/prover/src/bigint_mod.rs @@ -297,8 +297,8 @@ pub fn to_i128_limbs(limbs: &[u128]) -> Vec { } /// Compute signed quotient q such that: -/// Σ lhs_products[i] * coeff_i - Σ rhs_products[j] * coeff_j - rhs_sub ≡ 0 -/// (mod p) Returns q as decomposed limbs, with negative q stored as -q in the +/// Σ lhs_products\[i\] * coeff_i - Σ rhs_products\[j\] * coeff_j - rhs_sub ≡ +/// 0 (mod p) Returns q as 
decomposed limbs, with negative q stored as -q in the /// native field. pub fn signed_quotient_wide( lhs_products: &[(&[u64; 4], &[u64; 4], u64)], @@ -846,12 +846,12 @@ pub fn divmod_wide(dividend: &[u64; 8], divisor: &[u64; 4]) -> ([u64; 4], [u64; /// Compute unsigned-offset carries for a general merged column equation. /// /// Each `product_set` entry is (a_limbs, b_limbs, coefficient): -/// LHS_terms = Σ coeff * Σ_{i+j=k} a[i]*b[j] +/// LHS_terms = Σ coeff * Σ_{i+j=k} a\[i\]*b\[j\] /// /// Each `linear_set` entry is (limb_values, coefficient) for non-product terms: -/// LHS_terms += Σ coeff * val[k] (for k < val.len()) +/// LHS_terms += Σ coeff * val\[k\] (for k < val.len()) /// -/// The equation verified is: LHS = Σ p[i]*q[j] + carry_chain +/// The equation verified is: LHS = Σ p\[i\]*q\[j\] + carry_chain /// (no separate result — the "result" is encoded in the linear terms). pub fn compute_ec_verification_carries( product_sets: &[(&[u128], &[u128], i64)], diff --git a/provekit/r1cs-compiler/src/constraint_helpers.rs b/provekit/r1cs-compiler/src/constraint_helpers.rs index 2bd033bc2..9561fe7f3 100644 --- a/provekit/r1cs-compiler/src/constraint_helpers.rs +++ b/provekit/r1cs-compiler/src/constraint_helpers.rs @@ -49,7 +49,7 @@ pub(crate) fn select_witness( result } -/// Packs bit witnesses into a digit: `d = Σ bits[i] * 2^i`. +/// Packs bit witnesses into a digit: `d = Σ bits\[i\] * 2^i`. pub(crate) fn pack_bits_helper(compiler: &mut NoirToR1CSCompiler, bits: &[usize]) -> usize { let terms: Vec = bits .iter() diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index d06bc0eee..be924ecc5 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -10,10 +10,47 @@ use std::collections::BTreeMap; /// fit in the native field. const SCALAR_HALF_BITS: usize = 128; +/// Per-point overhead witnesses shared across all non-native paths. 
+/// - detect_skip: 2×is_zero(3W) + product(1W) + or(1W) = 8W +/// - sanitize: 4 select_witness = 4W +/// - ec_hint: EcScalarMulHint(2W) + 2W selects = 4W +/// - glv_hint: s1, s2, neg1, neg2 = 4W +const DETECT_SKIP_WIT: usize = 8; +const SANITIZE_WIT: usize = 4; +const EC_HINT_WIT: usize = 4; +const GLV_HINT_WIT: usize = 4; + fn ceil_div(a: usize, b: usize) -> usize { (a + b - 1) / b } +/// Table building ops: (doubles, adds) for constructing a signed-digit table. +/// Each table has `half_table_size` entries of odd multiples. +fn table_build_ops(half_table_size: usize) -> (usize, usize) { + if half_table_size >= 2 { + (1, half_table_size - 1) + } else { + (0, 0) + } +} + +/// Per-point overhead witnesses common to all non-native paths. +fn per_point_overhead(half_bits: usize, num_limbs: usize, sr_witnesses: usize) -> usize { + let scalar_bit_decomp = 2 * (half_bits + 1); + let point_decomp = if num_limbs > 1 { 4 * num_limbs } else { 0 }; + scalar_bit_decomp + + DETECT_SKIP_WIT + + SANITIZE_WIT + + EC_HINT_WIT + + GLV_HINT_WIT + + point_decomp + + sr_witnesses +} + +// --------------------------------------------------------------------------- +// Field op cost helpers (used by generic single-limb path + scalar relation) +// --------------------------------------------------------------------------- + /// Total witnesses produced by N-limb field operations. /// /// Per-op witness counts by configuration: @@ -76,6 +113,70 @@ fn add_field_op_range_checks( } } +// --------------------------------------------------------------------------- +// Hint-verified EC op cost model (non-native, num_limbs >= 2) +// --------------------------------------------------------------------------- + +/// Witness and range check costs for a single hint-verified EC operation. 
+struct HintVerifiedEcCost { + witnesses: usize, + rc_limb: usize, + rc_carry: usize, + carry_bits: u32, +} + +impl HintVerifiedEcCost { + /// point_double: (12N-6)W hint + 5N² products + N constants + 3×3N ltp + fn point_double(n: usize, limb_bits: u32) -> Self { + let wit = (12 * n - 6) + 5 * n * n + n + 9 * n; + Self { + witnesses: wit, + rc_limb: 6 * n + 6 * n, // 6N hint limbs + 3×2N ltp limbs + rc_carry: 3 * (2 * n - 2), // 3 equations × (2N-2) carries + carry_bits: hint_carry_bits(limb_bits, 6 + n as u64, n), + } + } + + /// point_add: (12N-6)W hint + 4N² products + 3×3N ltp + fn point_add(n: usize, limb_bits: u32) -> Self { + let wit = (12 * n - 6) + 4 * n * n + 9 * n; + Self { + witnesses: wit, + rc_limb: 6 * n + 6 * n, + rc_carry: 3 * (2 * n - 2), + carry_bits: hint_carry_bits(limb_bits, 4 + n as u64, n), + } + } + + /// on_curve (worst case, a != 0): (7N-4)W hint + 4N² products + 2N + /// constants + 3N ltp + fn on_curve(n: usize, limb_bits: u32) -> Self { + let wit = (7 * n - 4) + 4 * n * n + 2 * n + 3 * n; + Self { + witnesses: wit, + rc_limb: 3 * n + 2 * n, // 3N hint limbs + 2N ltp limbs + rc_carry: 2 * (2 * n - 2), // 2 equations × (2N-2) carries + carry_bits: hint_carry_bits(limb_bits, 5 + n as u64, n), + } + } + + /// Accumulate `count` of this op's range checks into `rc_map`. + fn add_range_checks(&self, count: usize, limb_bits: u32, rc_map: &mut BTreeMap<u32, usize>) { + *rc_map.entry(limb_bits).or_default() += count * self.rc_limb; + *rc_map.entry(self.carry_bits).or_default() += count * self.rc_carry; + } +} + +/// Carry range check bits for hint-verified EC column equations. 
+fn hint_carry_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + limb_bits + extra_bits +} + +// --------------------------------------------------------------------------- +// Scalar relation cost +// --------------------------------------------------------------------------- + /// Witnesses and range checks for scalar relation verification. /// /// Verifies `(-1)^neg1*|s1| + (-1)^neg2*|s2|*s ≡ 0 (mod n)` using multi-limb @@ -122,6 +223,10 @@ fn scalar_relation_cost( (witnesses, rc_map) } +// --------------------------------------------------------------------------- +// MSM cost entry point +// --------------------------------------------------------------------------- + /// Total estimated witness cost for an MSM. /// /// Accounts for three categories of witnesses: @@ -147,11 +252,7 @@ pub fn calculate_msm_witness_cost( let half_table_size = 1usize << (w - 1); let num_windows = ceil_div(half_bits, w); - // Use hint-verified costs for multi-limb (n >= 2), generic field ops for - // single-limb. - let use_hint_verified = n >= 2; - - if use_hint_verified { + if n >= 2 { calculate_msm_witness_cost_hint_verified( native_field_bits, n_points, @@ -179,20 +280,11 @@ pub fn calculate_msm_witness_cost( } } +// --------------------------------------------------------------------------- +// Hint-verified (multi-limb) non-native cost +// --------------------------------------------------------------------------- + /// Hint-verified non-native MSM cost (num_limbs >= 2). 
-/// -/// EC ops use prover hints verified via schoolbook column equations: -/// - `point_double_verified_non_native`: (12N-6)W hint + 5N² products + N -/// constants -/// - `point_add_verified_non_native`: (12N-6)W hint + 4N² products -/// - `verify_on_curve_non_native`: (7N-4)W hint + 4N² products + 2N constants -/// (worst case) -/// -/// Each hint-verified op also produces: -/// - 3N less_than_p witnesses per double/add (3 calls × 3N per call = 9N) -/// - 1N less_than_p witnesses per on-curve (1 call × 3N = 3N) -/// - Range checks: hint limbs at limb_bits + carries at carry_range_bits + -/// less_than_p at limb_bits #[allow(clippy::too_many_arguments)] fn calculate_msm_witness_cost_hint_verified( native_field_bits: u32, @@ -205,83 +297,31 @@ fn calculate_msm_witness_cost_hint_verified( half_table_size: usize, num_windows: usize, ) -> usize { - // === Hint-verified EC op witness counts === - // point_double: (12N-6) hint + 5N² products + N pinned constants + 9N - // less_than_p - let double_hint = 12 * n - 6; - let double_products = 5 * n * n; - let double_constants = n; // a_limbs - let double_ltp = 3 * 3 * n; // 3 less_than_p calls × 3N each - let double_wit = double_hint + double_products + double_constants + double_ltp; - - // point_add: (12N-6) hint + 4N² products + 9N less_than_p - let add_hint = 12 * n - 6; - let add_products = 4 * n * n; - let add_ltp = 3 * 3 * n; - let add_wit = add_hint + add_products + add_ltp; - - // on_curve (worst case, a != 0): (7N-4) hint + 4N² products + 2N constants + 3N - // less_than_p - let oncurve_hint = 7 * n - 4; - let oncurve_products = 4 * n * n; // px_px + py_py + xsq_px + a_px - let oncurve_constants = 2 * n; // a_limbs + b_limbs - let oncurve_ltp = 3 * n; // 1 less_than_p call - let oncurve_wit = oncurve_hint + oncurve_products + oncurve_constants + oncurve_ltp; - - // === Shared costs (doublings, counted once) === - let shared_doubles = num_windows * w; - let shared_ec_wit = shared_doubles * double_wit; - // Offset 
point constant_limbs: 2N (shared, allocated once in Phase 2) - let shared_offset_constants = 2 * n; - - // === Per-point EC costs === - // Table building: 2 tables × (1 double + (half_table_size-1) adds) when size >= - // 2 - let (tbl_d, tbl_a) = if half_table_size >= 2 { - (1, half_table_size - 1) - } else { - (0, 0) - }; - let pp_table_ec = 2 * (tbl_d * double_wit + tbl_a * add_wit); - - // Main loop per-point: 2 adds per window - let pp_loop_ec = num_windows * 2 * add_wit; + let ec_double = HintVerifiedEcCost::point_double(n, limb_bits); + let ec_add = HintVerifiedEcCost::point_add(n, limb_bits); + let ec_oncurve = HintVerifiedEcCost::on_curve(n, limb_bits); + let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits); + let (tbl_d, tbl_a) = table_build_ops(half_table_size); - // Skew corrections: 2 branches × 1 add per point - let pp_skew_ec = 2 * add_wit; + // negate_mod_p_multi: 3N witnesses, N range checks (no less_than_p) + let negate_wit = 3 * n; - // On-curve checks (P and R): 2 calls - let pp_oncurve = 2 * oncurve_wit; + // --- Shared costs (one doubling chain for all points) --- + let shared_doubles = num_windows * w; + let shared_ec_wit = shared_doubles * ec_double.witnesses; + let shared_offset_constants = 2 * n; - // Y-negation via negate_mod_p_multi (borrow chain, no less_than_p): - // negate = 3N witnesses (N v-sums + N borrows + N r-sums) - // select = N witnesses - let negate_wit = 3 * n; + // --- Per-point EC witnesses --- + let pp_table_ec = 2 * (tbl_d * ec_double.witnesses + tbl_a * ec_add.witnesses); + let pp_loop_ec = num_windows * 2 * ec_add.witnesses; + let pp_skew_ec = 2 * ec_add.witnesses; + let pp_oncurve = 2 * ec_oncurve.witnesses; let pp_y_negate = 2 * (negate_wit + n); // 2 × (negate + select) - - // Signed table lookup per window: negate + select_unchecked on y - // negate(y): 3N, select_unchecked(y): N let pp_signed_lookup_negate = num_windows * 2 * (negate_wit + n); - - // Skew correction negate: 2 
× negate(py) via negate_mod_p_multi let pp_skew_negate = 2 * negate_wit; - // Skew correction selects: 2 branches × 2N (x+y select_unchecked) - let pp_skew_selects = 2 * 2 * n; - - // Selects for signed table lookup (not field ops) - let table_selects = num_windows * 2 * (half_table_size.saturating_sub(1)) * 2 * n; - let xor_cost = num_windows * 2 * 2 * w.saturating_sub(1); - - let pp_selects = table_selects + xor_cost; - - // === Per-point overhead (non-EC) === - let scalar_bit_decomp = 2 * (half_bits + 1); - let detect_skip = 8; - let sanitize = 4; - let ec_hint = 4; // EcScalarMulHint (2W) + 2W selects - let point_decomp = 4 * n; // 4 witnesses per coord (N>1 always here) - let glv_hint = 4; - let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits); + let pp_skew_selects = 2 * 2 * n; // 2 branches × 2N + let pp_table_selects = num_windows * 2 * half_table_size.saturating_sub(1) * 2 * n; + let pp_xor = num_windows * 2 * 2 * w.saturating_sub(1); let per_point = pp_table_ec + pp_loop_ec @@ -291,79 +331,55 @@ fn calculate_msm_witness_cost_hint_verified( + pp_signed_lookup_negate + pp_skew_negate + pp_skew_selects - + pp_selects - + scalar_bit_decomp - + detect_skip - + sanitize - + ec_hint - + point_decomp - + glv_hint - + sr_witnesses; + + pp_table_selects + + pp_xor + + per_point_overhead(half_bits, n, sr_witnesses); - // === Shared constants === + // --- Shared constants --- let shared_constants = 3 + shared_offset_constants; // gen_x, gen_y, zero + offset - // === Point accumulation === - // Accumulation adds use hint-verified point_add_dispatch - let accum_add_wit = add_wit; - let accum = n_points * (accum_add_wit + 2 * n) // per-point add + skip select - + n_points.saturating_sub(1) // all_skipped products - + accum_add_wit + 4 * n + 2 * n // offset subtraction + constants + selects - + 2 + 2; // mask + recompose (n > 1 always) + // --- Point accumulation --- + let accum = n_points * (ec_add.witnesses + 2 * n) // per-point 
add + skip select
+ + n_points.saturating_sub(1) // all_skipped products
+ + ec_add.witnesses + 4 * n + 2 * n // offset sub + constants + selects
+ + 2 + 2; // mask + recompose
- // === Range check resolution ===
+ // --- Range checks ---
 let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new();
- // Hint-verified EC range checks
- let double_carry_bits = hint_verified_carry_range_bits(limb_bits, 6 + n as u64, n);
- let add_carry_bits = hint_verified_carry_range_bits(limb_bits, 4 + n as u64, n);
- let oncurve_carry_bits = hint_verified_carry_range_bits(limb_bits, 5 + n as u64, n);
+ // Shared doublings
+ ec_double.add_range_checks(shared_doubles, limb_bits, &mut rc_map);
- // Shared doublings range checks
- // Per double: 6N limb checks + 3*(2N-2) carry checks + 3 × 2N less_than_p limb
- // checks
- *rc_map.entry(limb_bits).or_default() += shared_doubles * (6 * n + 3 * 2 * n);
- *rc_map.entry(double_carry_bits).or_default() += shared_doubles * 3 * (2 * n - 2);
-
- // Per-point EC range checks
+ // Per-point: table doubles + table/loop/skew adds + on-curve
 let pp_doubles_count = 2 * tbl_d;
- let pp_adds_count = 2 * tbl_a + num_windows * 2 + 2; // table + loop + skew
- let pp_oncurve_count = 2;
-
- // Per double: 6N limb + 6N ltp limb + 3*(2N-2) carry
- *rc_map.entry(limb_bits).or_default() += n_points * pp_doubles_count * (6 * n + 3 * 2 * n);
- *rc_map.entry(double_carry_bits).or_default() += n_points * pp_doubles_count * 3 * (2 * n - 2);
-
- // Per add: 6N limb + 6N ltp limb + 3*(2N-2) carry
- *rc_map.entry(limb_bits).or_default() += n_points * pp_adds_count * (6 * n + 3 * 2 * n);
- *rc_map.entry(add_carry_bits).or_default() += n_points * pp_adds_count * 3 * (2 * n - 2);
+ let pp_adds_count = 2 * tbl_a + num_windows * 2 + 2;
+ ec_double.add_range_checks(n_points * pp_doubles_count, limb_bits, &mut rc_map);
+ ec_add.add_range_checks(n_points * pp_adds_count, limb_bits, &mut rc_map);
+ ec_oncurve.add_range_checks(n_points * 2, limb_bits, &mut rc_map);
- // Per on-curve: 3N limb + 
2N ltp limb + 2*(2N-2) carry - *rc_map.entry(limb_bits).or_default() += n_points * pp_oncurve_count * (3 * n + 2 * n); - *rc_map.entry(oncurve_carry_bits).or_default() += n_points * pp_oncurve_count * 2 * (2 * n - 2); + // Accumulation adds + ec_add.add_range_checks(n_points + 1, limb_bits, &mut rc_map); - // Accumulation adds range checks (hint-verified) - *rc_map.entry(limb_bits).or_default() += (n_points + 1) * (6 * n + 3 * 2 * n); - *rc_map.entry(add_carry_bits).or_default() += (n_points + 1) * 3 * (2 * n - 2); - - // Negate range checks via negate_mod_p_multi: N limb checks per negate - // (no less_than_p, so N instead of 2N per negate) - let negate_count_pp = 2 + num_windows * 2 + 2; // y-negate(2) + signed_lookup + skew + // Negate range checks: N limb checks per negate + let negate_count_pp = 2 + num_windows * 2 + 2; // y-negate + signed_lookup + skew *rc_map.entry(limb_bits).or_default() += n_points * negate_count_pp * n; // Point decomposition *rc_map.entry(limb_bits).or_default() += n_points * 4 * n; // Scalar relation - for (bits, count) in &sr_range_checks { - *rc_map.entry(*bits).or_default() += n_points * count; + for (&bits, &count) in &sr_range_checks { + *rc_map.entry(bits).or_default() += n_points * count; } let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); - shared_ec_wit + shared_constants + n_points * per_point + accum + range_check_cost } +// --------------------------------------------------------------------------- +// Generic (single-limb) non-native cost +// --------------------------------------------------------------------------- + /// Generic (single-limb) non-native MSM cost using MultiLimbOps field op /// chains. 
#[allow(clippy::too_many_arguments)] @@ -379,79 +395,48 @@ fn calculate_msm_witness_cost_generic( half_table_size: usize, num_windows: usize, ) -> usize { - // === GLV scalar mul field op counts === - // point_double: (5 add, 3 sub, 4 mul, 1 inv) + N constant witnesses (curve_a) + // point_double: (5 add, 3 sub, 4 mul, 1 inv) // point_add: (1 add, 5 sub, 3 mul, 1 inv) + let (tbl_d, tbl_a) = table_build_ops(half_table_size); + let shared_doubles = num_windows * w; - // --- Shared costs (counted once, NOT per-point) --- - let shared_add = num_windows * w * 5; - let shared_sub = num_windows * w * 3; - let shared_mul = num_windows * w * 4; - let shared_inv = num_windows * w; + // --- Shared doubling field ops --- + let shared_add = shared_doubles * 5; + let shared_sub = shared_doubles * 3; + let shared_mul = shared_doubles * 4; + let shared_inv = shared_doubles; - // --- Per-point costs --- - let (tbl_d, tbl_a) = if half_table_size >= 2 { - (1, half_table_size - 1) - } else { - (0, 0) - }; - let mut pp_add = 2 * (tbl_d * 5 + tbl_a * 1); - let mut pp_sub = 2 * (tbl_d * 3 + tbl_a * 5); - let mut pp_mul = 2 * (tbl_d * 4 + tbl_a * 3); - let mut pp_inv = 2 * (tbl_d + tbl_a); - - pp_add += num_windows * 2 * 1; - pp_sub += num_windows * (2 * 5 + 2); - pp_mul += num_windows * 2 * 3; - pp_inv += num_windows * 2; - - pp_add += 2 * 1; - pp_sub += 2 * (5 + 1); - pp_mul += 2 * 3; - pp_inv += 2; - - pp_mul += 8; // on-curve - pp_add += 4; - pp_sub += 2; // y-negation + // --- Per-point field ops: tables + loop adds + skew + on-curve + y-negate --- + let mut pp_add = 2 * (tbl_d * 5 + tbl_a) + num_windows * 2 + 2 + 4; + let mut pp_sub = 2 * (tbl_d * 3 + tbl_a * 5) + num_windows * (2 * 5 + 2) + 2 * 6 + 2; + let mut pp_mul = 2 * (tbl_d * 4 + tbl_a * 3) + num_windows * 2 * 3 + 2 * 3 + 8; + let mut pp_inv = 2 * (tbl_d + tbl_a) + num_windows * 2 + 2; let shared_field_ops = field_op_witnesses(shared_add, shared_sub, shared_mul, shared_inv, n, false); let pp_field_ops = 
field_op_witnesses(pp_add, pp_sub, pp_mul, pp_inv, n, false);
- let shared_doubles = num_windows * w;
 let pp_doubles = 2 * tbl_d;
 let pp_negate_zeros = (4 + 2 * num_windows) * n;
 let shared_constants_glv = shared_doubles * n + 2 * n;
 let pp_constants = pp_doubles * n + 4 * n + pp_negate_zeros;
- let table_selects = num_windows * 2 * (half_table_size.saturating_sub(1)) * 2 * n;
- let xor_cost = num_windows * 2 * 2 * w.saturating_sub(1);
- let signed_y_selects = num_windows * 2 * n;
- let y_negate_selects = 2 * n;
- let skew_selects = 2 * 2 * n;
- let pp_selects = table_selects + xor_cost + signed_y_selects + y_negate_selects + skew_selects;
+ let pp_table_selects = num_windows * 2 * half_table_size.saturating_sub(1) * 2 * n;
+ let pp_xor = num_windows * 2 * 2 * w.saturating_sub(1);
+ let pp_signed_y_selects = num_windows * 2 * n;
+ let pp_y_negate_selects = 2 * n;
+ let pp_skew_selects = 2 * 2 * n;
+ let pp_selects =
+ pp_table_selects + pp_xor + pp_signed_y_selects + pp_y_negate_selects + pp_skew_selects;
- let scalar_bit_decomp = 2 * (half_bits + 1);
- let detect_skip = 8;
- let sanitize = 4;
- let ec_hint = 4;
- let point_decomp = if n > 1 { 4 * n } else { 0 };
- let glv_hint = 4;
 let (sr_witnesses, sr_range_checks) = scalar_relation_cost(native_field_bits, scalar_bits);
- let per_point = pp_field_ops
- + pp_constants
- + pp_selects
- + scalar_bit_decomp
- + detect_skip
- + sanitize
- + ec_hint
- + point_decomp
- + glv_hint
- + sr_witnesses;
+ let per_point =
+ pp_field_ops + pp_constants + pp_selects + per_point_overhead(half_bits, n, sr_witnesses);
 let shared_constants = 3 + 2 * n;
+ // --- Point accumulation ---
 let pa_cost = field_op_witnesses(1, 5, 3, 1, n, false);
 let accum = n_points * (pa_cost + 2 * n)
 + n_points.saturating_sub(1)
@@ -461,6 +446,7 @@ fn calculate_msm_witness_cost_generic(
 + 2
 + if n > 1 { 2 } else { 0 };
+ // --- Range checks ---
 let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new();
 add_field_op_range_checks(
 shared_add,
@@ -487,14 +473,14 @@ fn 
calculate_msm_witness_cost_generic( if n > 1 { *rc_map.entry(limb_bits).or_default() += n_points * 4 * n; } - for (bits, count) in &sr_range_checks { - *rc_map.entry(*bits).or_default() += n_points * count; + for (&bits, &count) in &sr_range_checks { + *rc_map.entry(bits).or_default() += n_points * count; } add_field_op_range_checks( - (n_points + 1) * 1, + n_points + 1, (n_points + 1) * 5, (n_points + 1) * 3, - (n_points + 1) * 1, + n_points + 1, n, limb_bits, curve_modulus_bits, @@ -512,11 +498,9 @@ fn calculate_msm_witness_cost_generic( + range_check_cost } -/// Carry range check bits for hint-verified EC column equations. -fn hint_verified_carry_range_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { - let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; - limb_bits + extra_bits -} +// --------------------------------------------------------------------------- +// Native-field cost +// --------------------------------------------------------------------------- /// Native-field MSM cost: hint-verified EC ops with signed-bit wNAF (w=1). 
///
@@ -534,50 +518,41 @@ fn calculate_msm_witness_cost_native(
 ) -> usize {
 let half_bits = (scalar_bits + 1) / 2;
- // === Per-point fixed costs ===
 let on_curve = 4; // 2 × verify_on_curve_native (2W each)
- let glv_hint = 4; // s1, s2, neg1, neg2
- let scalar_bits_cost = 2 * (half_bits + 1); // 2 × (half_bits + skew)
 let y_negate = 6; // 2 × 3W (neg_y, y_eff, neg_y_eff)
- let detect_skip = 8; // 2×is_zero(3W) + product(1W) + or(1W)
- let sanitize = 4; // 4 select_witness
- let ec_hint = 4; // 2W hint + 2W selects
 let (sr_wit, sr_rc) = scalar_relation_cost(native_field_bits, scalar_bits);
- let per_point = on_curve
- + glv_hint
- + scalar_bits_cost
- + y_negate
- + detect_skip
- + sanitize
- + ec_hint
+ let per_point = on_curve + y_negate
+ + 2 * (half_bits + 1) // scalar bit decomposition
+ + DETECT_SKIP_WIT + SANITIZE_WIT + EC_HINT_WIT + GLV_HINT_WIT
 + sr_wit;
- // === Shared constants ===
 let shared_constants = 5; // gen_x, gen_y, zero, offset_x, offset_y
- // === EC verification loop (merged, shared doubling) ===
 // Per bit: 4W (shared double) + n_points × 8W (2×(1W select + 3W add))
 let ec_loop = half_bits * (4 + 8 * n_points);
 // Skew correction: 2 branches × (3W add + 2W select) = 10W per point
 let skew = n_points * 10;
- // === Point accumulation ===
- let accum = 2 // initial acc constants
- + n_points * 5 // add(3W) + skip_select(2W)
- + n_points.saturating_sub(1) // all_skipped products
+ let accum = 2 // initial acc constants
+ + n_points * 5 // add(3W) + skip_select(2W)
+ + n_points.saturating_sub(1) // all_skipped products
 + 10; // offset sub: 3 const + 2 sel + 3 add + 2 mask
- // === Range checks (only from scalar relation for native) ===
+ // Range checks (only from scalar relation for native)
 let mut rc_map: BTreeMap<u32, usize> = BTreeMap::new();
- for (bits, count) in &sr_rc {
- *rc_map.entry(*bits).or_default() += n_points * count;
+ for (&bits, &count) in &sr_rc {
+ *rc_map.entry(bits).or_default() += n_points * count;
 }
 let range_check_cost = 
crate::range_check::estimate_range_check_cost(&rc_map); n_points * per_point + shared_constants + ec_loop + skew + accum + range_check_cost } +// --------------------------------------------------------------------------- +// Parameter search +// --------------------------------------------------------------------------- + /// Picks the widest limb size for scalar-relation multi-limb arithmetic that /// fits inside the native field without overflow. /// diff --git a/provekit/r1cs-compiler/src/msm/ec_points.rs b/provekit/r1cs-compiler/src/msm/ec_points.rs deleted file mode 100644 index 6cb2e28b8..000000000 --- a/provekit/r1cs-compiler/src/msm/ec_points.rs +++ /dev/null @@ -1,1167 +0,0 @@ -use { - super::{ - curve::CurveParams, - multi_limb_arith::less_than_p_check_multi, - multi_limb_ops::{MultiLimbOps, MultiLimbParams}, - Limbs, - }, - crate::noir_to_r1cs::NoirToR1CSCompiler, - ark_ff::{Field, PrimeField}, - provekit_common::{ - witness::{NonNativeEcOp, SumTerm, WitnessBuilder}, - FieldElement, - }, - std::collections::BTreeMap, -}; - -/// Dispatching point doubling: uses hint-verified for multi-limb non-native, -/// generic field-ops otherwise. -pub fn point_double_dispatch(ops: &mut MultiLimbOps, x1: Limbs, y1: Limbs) -> (Limbs, Limbs) { - if ops.params.num_limbs >= 2 && !ops.params.is_native { - point_double_verified_non_native(ops.compiler, ops.range_checks, x1, y1, ops.params) - } else { - point_double(ops, x1, y1) - } -} - -/// Dispatching point addition: uses hint-verified for multi-limb non-native, -/// generic field-ops otherwise. -pub fn point_add_dispatch( - ops: &mut MultiLimbOps, - x1: Limbs, - y1: Limbs, - x2: Limbs, - y2: Limbs, -) -> (Limbs, Limbs) { - if ops.params.num_limbs >= 2 && !ops.params.is_native { - point_add_verified_non_native(ops.compiler, ops.range_checks, x1, y1, x2, y2, ops.params) - } else { - point_add(ops, x1, y1, x2, y2) - } -} - -/// Generic point doubling on y^2 = x^3 + ax + b. 
-/// -/// Given P = (x1, y1), computes 2P = (x3, y3): -/// lambda = (3 * x1^2 + a) / (2 * y1) -/// x3 = lambda^2 - 2 * x1 -/// y3 = lambda * (x1 - x3) - y1 -/// -/// Edge case — y1 = 0 (point of order 2): -/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. -/// The result should be the point at infinity (identity element). -/// This function does NOT handle that case — the constraint system will -/// be unsatisfiable if y1 = 0 (the inverse verification will fail to -/// verify 0 * inv = 1 mod p). The caller must check y1 = 0 using -/// compute_is_zero and conditionally select the point-at-infinity -/// result before calling this function. -pub fn point_double(ops: &mut MultiLimbOps, x1: Limbs, y1: Limbs) -> (Limbs, Limbs) { - let a = ops.curve_a(); - - // Computing numerator = 3 * x1^2 + a - let x1_sq = ops.mul(x1, x1); - let two_x1_sq = ops.add(x1_sq, x1_sq); - let three_x1_sq = ops.add(two_x1_sq, x1_sq); - let numerator = ops.add(three_x1_sq, a); - - // Computing denominator = 2 * y1 - let denominator = ops.add(y1, y1); - - // Computing lambda = numerator * denominator^(-1) - let denom_inv = ops.inv(denominator); - let lambda = ops.mul(numerator, denom_inv); - - // Computing x3 = lambda^2 - 2 * x1 - let lambda_sq = ops.mul(lambda, lambda); - let two_x1 = ops.add(x1, x1); - let x3 = ops.sub(lambda_sq, two_x1); - - // Computing y3 = lambda * (x1 - x3) - y1 - let x1_minus_x3 = ops.sub(x1, x3); - let lambda_dx = ops.mul(lambda, x1_minus_x3); - let y3 = ops.sub(lambda_dx, y1); - - (x3, y3) -} - -/// Generic point addition on y^2 = x^3 + ax + b. -/// -/// Given P1 = (x1, y1) and P2 = (x2, y2), computes P1 + P2 = (x3, y3): -/// lambda = (y2 - y1) / (x2 - x1) -/// x3 = lambda^2 - x1 - x2 -/// y3 = lambda * (x1 - x3) - y1 -/// -/// Edge cases — x1 = x2: -/// When x1 = x2, the denominator (x2 - x1) = 0 and the inverse does -/// not exist. This covers two cases: -/// - P1 = P2 (same point): use `point_double` instead. 
-/// - P1 = -P2 (y1 = -y2): the result is the point at infinity. -/// This function does NOT handle either case — the constraint system -/// will be unsatisfiable if x1 = x2. The caller must detect this -/// and branch accordingly. -pub fn point_add( - ops: &mut MultiLimbOps, - x1: Limbs, - y1: Limbs, - x2: Limbs, - y2: Limbs, -) -> (Limbs, Limbs) { - // Computing lambda = (y2 - y1) / (x2 - x1) - let numerator = ops.sub(y2, y1); - let denominator = ops.sub(x2, x1); - let denom_inv = ops.inv(denominator); - let lambda = ops.mul(numerator, denom_inv); - - // Computing x3 = lambda^2 - x1 - x2 - let lambda_sq = ops.mul(lambda, lambda); - let x1_plus_x2 = ops.add(x1, x2); - let x3 = ops.sub(lambda_sq, x1_plus_x2); - - // Computing y3 = lambda * (x1 - x3) - y1 - let x1_minus_x3 = ops.sub(x1, x3); - let lambda_dx = ops.mul(lambda, x1_minus_x3); - let y3 = ops.sub(lambda_dx, y1); - - (x3, y3) -} - -/// Conditional point select without boolean constraint on `flag`. -/// Caller must ensure `flag` is already constrained boolean. -pub fn point_select_unchecked( - ops: &mut MultiLimbOps, - flag: usize, - on_false: (Limbs, Limbs), - on_true: (Limbs, Limbs), -) -> (Limbs, Limbs) { - let x = ops.select_unchecked(flag, on_false.0, on_true.0); - let y = ops.select_unchecked(flag, on_false.1, on_true.1); - (x, y) -} - -/// Builds a signed point table of odd multiples for signed-digit windowed -/// scalar multiplication. -/// -/// T\[0\] = P, T\[1\] = 3P, T\[2\] = 5P, ..., T\[k-1\] = (2k-1)P -/// where k = `half_table_size` = 2^(w-1). -/// -/// Build cost: 1 point_double (for 2P) + (k-1) point_adds when k >= 2. 
-fn build_signed_point_table( - ops: &mut MultiLimbOps, - px: Limbs, - py: Limbs, - half_table_size: usize, -) -> Vec<(Limbs, Limbs)> { - assert!(half_table_size >= 1); - let mut table = Vec::with_capacity(half_table_size); - table.push((px, py)); // T[0] = 1*P - if half_table_size >= 2 { - let two_p = point_double_dispatch(ops, px, py); // 2P - for i in 1..half_table_size { - let prev = table[i - 1]; - table.push(point_add_dispatch(ops, prev.0, prev.1, two_p.0, two_p.1)); - } - } - table -} - -/// Selects T\[d\] from a point table using bit witnesses, where `d = Σ -/// bits\[i\] * 2^i`. -/// -/// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, -/// halving the candidate set at each level. Total: `(2^w - 1)` point selects -/// for a table of `2^w` entries. -/// -/// Each bit is constrained boolean exactly once, then all subsequent selects -/// on that bit use the unchecked variant. -fn table_lookup( - ops: &mut MultiLimbOps, - table: &[(Limbs, Limbs)], - bits: &[usize], -) -> (Limbs, Limbs) { - assert_eq!(table.len(), 1 << bits.len()); - let mut current: Vec<(Limbs, Limbs)> = table.to_vec(); - // Process bits from MSB to LSB - for &bit in bits.iter().rev() { - ops.constrain_flag(bit); // constrain boolean once per bit - let half = current.len() / 2; - let mut next = Vec::with_capacity(half); - for i in 0..half { - next.push(point_select_unchecked( - ops, - bit, - current[i], - current[i + half], - )); - } - current = next; - } - current[0] -} - -/// Like `table_lookup`, but skips boolean constraints on bits. -/// -/// Use when bits are already known boolean (e.g. XOR'd bits derived from -/// boolean-constrained inputs in `signed_table_lookup`). 
-fn table_lookup_unchecked( - ops: &mut MultiLimbOps, - table: &[(Limbs, Limbs)], - bits: &[usize], -) -> (Limbs, Limbs) { - assert_eq!(table.len(), 1 << bits.len()); - let mut current: Vec<(Limbs, Limbs)> = table.to_vec(); - for &bit in bits.iter().rev() { - let half = current.len() / 2; - let mut next = Vec::with_capacity(half); - for i in 0..half { - next.push(point_select_unchecked( - ops, - bit, - current[i], - current[i + half], - )); - } - current = next; - } - current[0] -} - -/// Signed-digit table lookup: selects from a half-size table using XOR'd -/// index bits, then conditionally negates y based on the sign bit. -/// -/// For a w-bit window with bits \[b_0, ..., b_{w-1}\] (LSB first): -/// - sign_bit = b_{w-1} (MSB): 1 = positive digit, 0 = negative digit -/// - index_bits = \[b_0, ..., b_{w-2}\] (lower w-1 bits) -/// - When positive: table index = lower bits as-is -/// - When negative: table index = bitwise complement of lower bits, and y is -/// negated -/// -/// The XOR'd bits are computed as: `idx_i = 1 - b_i - MSB + 2*b_i*MSB`, -/// which equals `b_i` when MSB=1, and `1-b_i` when MSB=0. -/// -/// # Precondition -/// `sign_bit` must be boolean-constrained by the caller. This function uses -/// it in `select_unchecked` without re-constraining. Currently satisfied: -/// `decompose_signed_bits` boolean-constrains all bits including the MSB -/// used as `sign_bit`. 
-fn signed_table_lookup(
- ops: &mut MultiLimbOps,
- table: &[(Limbs, Limbs)],
- index_bits: &[usize],
- sign_bit: usize,
-) -> (Limbs, Limbs) {
- let (x, y) = if index_bits.is_empty() {
- // w=1: single entry, no lookup needed
- assert_eq!(table.len(), 1);
- table[0]
- } else {
- // Compute XOR'd index bits: idx_i = 1 - b_i - MSB + 2*b_i*MSB
- let one_w = ops.compiler.witness_one();
- let two = FieldElement::from(2u64);
- let xor_bits: Vec<usize> = index_bits
- .iter()
- .map(|&bit| {
- let prod = ops.compiler.add_product(bit, sign_bit);
- ops.compiler.add_sum(vec![
- SumTerm(Some(FieldElement::ONE), one_w),
- SumTerm(Some(-FieldElement::ONE), bit),
- SumTerm(Some(-FieldElement::ONE), sign_bit),
- SumTerm(Some(two), prod),
- ])
- })
- .collect();
-
- // XOR'd bits are boolean by construction (product of two booleans
- // combined linearly), so skip redundant boolean constraints.
- table_lookup_unchecked(ops, table, &xor_bits)
- };
-
- // Conditionally negate y: sign_bit=0 (negative) → -y, sign_bit=1 (positive) → y
- let neg_y = ops.negate(y);
- let eff_y = ops.select_unchecked(sign_bit, neg_y, y);
- // select_unchecked(flag, on_false, on_true):
- // sign_bit=0 → on_false=neg_y (negative digit, negate y) ✓
- // sign_bit=1 → on_true=y (positive digit, keep y) ✓
-
- (x, eff_y)
-}
-
-/// Per-point data for merged multi-point GLV scalar multiplication. 
-pub struct MergedGlvPoint {
- /// Point P x-coordinate (limbs)
- pub px: Limbs,
- /// Point P y-coordinate (effective, post-negation)
- pub py: Limbs,
- /// Signed-bit decomposition of |s1| (half-scalar for P), LSB first
- pub s1_bits: Vec<usize>,
- /// Skew correction witness for s1 branch (boolean)
- pub s1_skew: usize,
- /// Point R x-coordinate (limbs)
- pub rx: Limbs,
- /// Point R y-coordinate (effective, post-negation)
- pub ry: Limbs,
- /// Signed-bit decomposition of |s2| (half-scalar for R), LSB first
- pub s2_bits: Vec<usize>,
- /// Skew correction witness for s2 branch (boolean)
- pub s2_skew: usize,
-}
-
-/// Merged multi-point GLV scalar multiplication with shared doublings
-/// and signed-digit windows.
-///
-/// Uses signed-digit encoding: each w-bit window produces a signed odd digit
-/// d ∈ {±1, ±3, ..., ±(2^w - 1)}, eliminating zero-digit handling.
-/// Tables store odd multiples \[P, 3P, 5P, ..., (2^w-1)P\] with only
-/// 2^(w-1) entries (half the unsigned table size).
-///
-/// After the main loop, applies skew corrections: if skew=1, subtracts P
-/// (or R) to account for the signed decomposition bias.
-///
-/// Returns the final accumulator `(x, y)`. 
-pub fn scalar_mul_merged_glv( - ops: &mut MultiLimbOps, - points: &[MergedGlvPoint], - window_size: usize, - offset_x: Limbs, - offset_y: Limbs, -) -> (Limbs, Limbs) { - assert!(!points.is_empty()); - let n = points[0].s1_bits.len(); - let w = window_size; - let half_table_size = 1usize << (w - 1); - - // Build signed point tables (odd multiples) for all points upfront - let tables: Vec<(Vec<(Limbs, Limbs)>, Vec<(Limbs, Limbs)>)> = points - .iter() - .map(|pt| { - let tp = build_signed_point_table(ops, pt.px, pt.py, half_table_size); - let tr = build_signed_point_table(ops, pt.rx, pt.ry, half_table_size); - (tp, tr) - }) - .collect(); - - let num_windows = (n + w - 1) / w; - let mut acc = (offset_x, offset_y); - - // Process all windows from MSB down to LSB - for i in (0..num_windows).rev() { - let bit_start = i * w; - let bit_end = std::cmp::min(bit_start + w, n); - let actual_w = bit_end - bit_start; - - // w shared doublings on the accumulator (shared across ALL points) - let mut doubled_acc = acc; - for _ in 0..w { - doubled_acc = point_double_dispatch(ops, doubled_acc.0, doubled_acc.1); - } - - let mut cur = doubled_acc; - - // For each point: P branch + R branch (signed-digit lookup) - for (pt, (table_p, table_r)) in points.iter().zip(tables.iter()) { - // --- P branch (s1 window) --- - let s1_window_bits = &pt.s1_bits[bit_start..bit_end]; - let sign_bit_p = s1_window_bits[actual_w - 1]; // MSB - let index_bits_p = &s1_window_bits[..actual_w - 1]; // lower bits - let actual_table_p = if actual_w < w { - &table_p[..1 << (actual_w - 1)] - } else { - &table_p[..] 
- }; - let looked_up_p = signed_table_lookup(ops, actual_table_p, index_bits_p, sign_bit_p); - // All signed digits are non-zero — no is_zero check needed - cur = point_add_dispatch(ops, cur.0, cur.1, looked_up_p.0, looked_up_p.1); - - // --- R branch (s2 window) --- - let s2_window_bits = &pt.s2_bits[bit_start..bit_end]; - let sign_bit_r = s2_window_bits[actual_w - 1]; // MSB - let index_bits_r = &s2_window_bits[..actual_w - 1]; // lower bits - let actual_table_r = if actual_w < w { - &table_r[..1 << (actual_w - 1)] - } else { - &table_r[..] - }; - let looked_up_r = signed_table_lookup(ops, actual_table_r, index_bits_r, sign_bit_r); - cur = point_add_dispatch(ops, cur.0, cur.1, looked_up_r.0, looked_up_r.1); - } - - acc = cur; - } - - // Skew corrections: subtract P (or R) if skew=1 for each point. - // The signed decomposition gives: scalar = Σ d_i * 2^i - skew, - // so the main loop computed (scalar + skew) * P. If skew=1, subtract P. - for pt in points { - // P branch skew - let neg_py = ops.negate(pt.py); - let (sub_px, sub_py) = point_add_dispatch(ops, acc.0, acc.1, pt.px, neg_py); - let new_x = ops.select_unchecked(pt.s1_skew, acc.0, sub_px); - let new_y = ops.select_unchecked(pt.s1_skew, acc.1, sub_py); - acc = (new_x, new_y); - - // R branch skew - let neg_ry = ops.negate(pt.ry); - let (sub_rx, sub_ry) = point_add_dispatch(ops, acc.0, acc.1, pt.rx, neg_ry); - let new_x = ops.select_unchecked(pt.s2_skew, acc.0, sub_rx); - let new_y = ops.select_unchecked(pt.s2_skew, acc.1, sub_ry); - acc = (new_x, new_y); - } - - acc -} - -// =========================================================================== -// Native-field hint-verified EC operations -// =========================================================================== -// These operate on single native field element witnesses (no multi-limb). -// Each EC op allocates a hint for (lambda, x3, y3) and verifies via raw -// R1CS constraints, eliminating expensive field inversions from the circuit. 
- -/// Hint-verified point doubling for native field. -/// -/// Allocates EcDoubleHint → (lambda, x3, y3) = 3W. -/// Verification constraints (4C): -/// 1. x_sq = px * px (1C via add_product) -/// 2. lambda * 2*py = 3*x_sq + a (1C raw) -/// 3. lambda * lambda = x3 + 2*px (1C raw) -/// 4. lambda * (px - x3) = y3 + py (1C raw) -/// -/// Total: 4W + 4C (1W for x_sq via add_product, 3W from hint). -pub fn point_double_verified_native( - compiler: &mut NoirToR1CSCompiler, - px: usize, - py: usize, - curve: &CurveParams, -) -> (usize, usize) { - // Allocate hint witnesses - let hint_start = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::EcDoubleHint { - output_start: hint_start, - px, - py, - curve_a: curve.curve_a, - field_modulus_p: curve.field_modulus_p, - }); - let lambda = hint_start; - let x3 = hint_start + 1; - let y3 = hint_start + 2; - - // x_sq = px * px (1W + 1C) - let x_sq = compiler.add_product(px, px); - - // Constraint: lambda * (2 * py) = 3 * x_sq + a - // A = [lambda], B = [2*py], C = [3*x_sq + a_const] - let a_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_a)).unwrap(); - let three = FieldElement::from(3u64); - let two = FieldElement::from(2u64); - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, lambda)], &[(two, py)], &[ - (three, x_sq), - (a_fe, compiler.witness_one()), - ]); - - // Constraint: lambda^2 = x3 + 2*px - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, lambda)], - &[(FieldElement::ONE, lambda)], - &[(FieldElement::ONE, x3), (two, px)], - ); - - // Constraint: lambda * (px - x3) = y3 + py - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, lambda)], - &[(FieldElement::ONE, px), (-FieldElement::ONE, x3)], - &[(FieldElement::ONE, y3), (FieldElement::ONE, py)], - ); - - (x3, y3) -} - -/// Hint-verified point addition for native field. -/// -/// Allocates EcAddHint → (lambda, x3, y3) = 3W. -/// Verification constraints (3C): -/// 1. lambda * (x2 - x1) = y2 - y1 (1C raw) -/// 2. 
lambda^2 = x3 + x1 + x2 (1C raw) -/// 3. lambda * (x1 - x3) = y3 + y1 (1C raw) -/// -/// Total: 3W + 3C. -pub fn point_add_verified_native( - compiler: &mut NoirToR1CSCompiler, - x1: usize, - y1: usize, - x2: usize, - y2: usize, - curve: &CurveParams, -) -> (usize, usize) { - // Allocate hint witnesses - let hint_start = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::EcAddHint { - output_start: hint_start, - x1, - y1, - x2, - y2, - field_modulus_p: curve.field_modulus_p, - }); - let lambda = hint_start; - let x3 = hint_start + 1; - let y3 = hint_start + 2; - - // Constraint: lambda * (x2 - x1) = y2 - y1 - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, lambda)], - &[(FieldElement::ONE, x2), (-FieldElement::ONE, x1)], - &[(FieldElement::ONE, y2), (-FieldElement::ONE, y1)], - ); - - // Constraint: lambda^2 = x3 + x1 + x2 - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, lambda)], - &[(FieldElement::ONE, lambda)], - &[ - (FieldElement::ONE, x3), - (FieldElement::ONE, x1), - (FieldElement::ONE, x2), - ], - ); - - // Constraint: lambda * (x1 - x3) = y3 + y1 - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, lambda)], - &[(FieldElement::ONE, x1), (-FieldElement::ONE, x3)], - &[(FieldElement::ONE, y3), (FieldElement::ONE, y1)], - ); - - (x3, y3) -} - -/// On-curve check for native field: y^2 = x^3 + a*x + b. -/// -/// Constraints (3C, 2W): -/// 1. x_sq = x * x (1C via add_product) -/// 2. x_cu = x_sq * x (1C via add_product) -/// 3. y * y = x_cu + a*x + b (1C raw) -/// -/// Total: 2W + 3C. 
-pub fn verify_on_curve_native( - compiler: &mut NoirToR1CSCompiler, - x: usize, - y: usize, - curve: &CurveParams, -) { - let x_sq = compiler.add_product(x, x); - let x_cu = compiler.add_product(x_sq, x); - - let a_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_a)).unwrap(); - let b_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_b)).unwrap(); - - // y * y = x_cu + a*x + b - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, y)], &[(FieldElement::ONE, y)], &[ - (FieldElement::ONE, x_cu), - (a_fe, x), - (b_fe, compiler.witness_one()), - ]); -} - -// =========================================================================== -// Non-native hint-verified EC operations (multi-limb schoolbook) -// =========================================================================== -// These replace the step-by-step MultiLimbOps chain with prover hints verified -// via schoolbook column equations. Each bilinear mod-p equation is checked by: -// 1. Pre-computing product witnesses a[i]*b[j] -// 2. Column equations: Σ coeff·prod[k] + linear[k] + carry_in + offset = Σ -// p[i]*q[j] + carry_out * W -// Since p is constant, p[i]*q[j] terms are linear in q (no product witness). - -/// Collect witness indices from `start..start+len`. -fn witness_range(start: usize, len: usize) -> Vec { - (start..start + len).collect() -} - -/// Allocate N×N product witnesses for `a[i]*b[j]`. -fn make_products(compiler: &mut NoirToR1CSCompiler, a: &[usize], b: &[usize]) -> Vec> { - let n = a.len(); - debug_assert_eq!(n, b.len()); - let mut prods = vec![vec![0usize; n]; n]; - for i in 0..n { - for j in 0..n { - prods[i][j] = compiler.add_product(a[i], b[j]); - } - } - prods -} - -/// Allocate pinned constant witnesses from pre-decomposed `FieldElement` limbs. 
-fn allocate_pinned_constant_limbs( - compiler: &mut NoirToR1CSCompiler, - limb_values: &[FieldElement], -) -> Vec { - limb_values - .iter() - .map(|&val| { - let w = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::Constant( - provekit_common::witness::ConstantTerm(w, val), - )); - compiler.r1cs.add_constraint( - &[(FieldElement::ONE, compiler.witness_one())], - &[(FieldElement::ONE, w)], - &[(val, compiler.witness_one())], - ); - w - }) - .collect() -} - -/// Range-check limb witnesses at `limb_bits` and carry witnesses at -/// `carry_range_bits`. -fn range_check_limbs_and_carries( - range_checks: &mut BTreeMap>, - limb_vecs: &[&[usize]], - carry_vecs: &[&[usize]], - limb_bits: u32, - carry_range_bits: u32, -) { - for limbs in limb_vecs { - for &w in *limbs { - range_checks.entry(limb_bits).or_default().push(w); - } - } - for carries in carry_vecs { - for &c in *carries { - range_checks.entry(carry_range_bits).or_default().push(c); - } - } -} - -/// Convert `Vec` to `Limbs` and do a less-than-p check. -fn less_than_p_check_vec( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - v: &[usize], - params: &MultiLimbParams, -) { - let n = v.len(); - let mut limbs = Limbs::new(n); - for i in 0..n { - limbs[i] = v[i]; - } - less_than_p_check_multi( - compiler, - range_checks, - limbs, - ¶ms.p_minus_1_limbs, - params.two_pow_w, - params.limb_bits, - ); -} - -/// Emit schoolbook column equations for a merged verification equation. -/// -/// Verifies: Σ (coeff_i × A_i ⊗ B_i) + Σ linear_k = q·p (mod p, as integers) -/// -/// `product_sets`: each (products_2d, coefficient) where products_2d[i][j] -/// is the witness index for a[i]*b[j]. -/// `linear_limbs`: each (limb_witnesses, coefficient) for non-product terms -/// (limb_witnesses has N entries, zero-padded). -/// `q_witnesses`: quotient limbs (N entries). -/// `carry_witnesses`: unsigned-offset carry witnesses (2N-2 entries). 
-fn emit_schoolbook_column_equations( - compiler: &mut NoirToR1CSCompiler, - product_sets: &[(&[Vec], FieldElement)], // (products[i][j], coeff) - linear_limbs: &[(&[usize], FieldElement)], // (limb_witnesses, coeff) - q_witnesses: &[usize], - carry_witnesses: &[usize], - p_limbs: &[FieldElement], - n: usize, - limb_bits: u32, - max_coeff_sum: u64, -) { - let w1 = compiler.witness_one(); - let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]); - - // Carry offset scaled for the merged equation's larger coefficients - let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; - let carry_offset_bits = limb_bits + extra_bits; - let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]); - let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]); - let offset_w_minus_carry = offset_w - carry_offset_fe; - - let num_columns = 2 * n - 1; - - for k in 0..num_columns { - // LHS: Σ coeff * products[i][j] for i+j=k + carry_in + offset - let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new(); - - for &(products, coeff) in product_sets { - for i in 0..n { - let j_val = k as isize - i as isize; - if j_val >= 0 && (j_val as usize) < n { - lhs_terms.push((coeff, products[i][j_val as usize])); - } - } - } - - // Add linear terms (for k < N only, since linear_limbs are N-length) - for &(limbs, coeff) in linear_limbs { - if k < limbs.len() { - lhs_terms.push((coeff, limbs[k])); - } - } - - // Add carry_in and offset - if k > 0 { - lhs_terms.push((FieldElement::ONE, carry_witnesses[k - 1])); - lhs_terms.push((offset_w_minus_carry, w1)); - } else { - lhs_terms.push((offset_w, w1)); - } - - // RHS: Σ p[i]*q[j] for i+j=k + carry_out * W (or offset at last column) - let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new(); - for i in 0..n { - let j_val = k as isize - i as isize; - if j_val >= 0 && (j_val as usize) < n { - rhs_terms.push((p_limbs[i], q_witnesses[j_val as usize])); - } - } - - if k < 
num_columns - 1 { - rhs_terms.push((two_pow_w, carry_witnesses[k])); - } else { - // Last column: balance with offset_w (no outgoing carry) - rhs_terms.push((offset_w, w1)); - } - - compiler - .r1cs - .add_constraint(&lhs_terms, &[(FieldElement::ONE, w1)], &rhs_terms); - } -} - -/// Hint-verified on-curve check for non-native field (multi-limb). -/// -/// Verifies y² = x³ + ax + b (mod p) via: -/// Eq1: x·x - x_sq = q1·p (x_sq correctness) -/// Eq2: y·y - x_sq·x - a·x - b = q2·p (on-curve) -/// -/// Total: (7N-4)W hint + (N² + 2N² [+ N²])products + 2×(2N-1) constraints -/// + 1 less-than-p check. -pub fn verify_on_curve_non_native( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - px: Limbs, - py: Limbs, - params: &MultiLimbParams, -) { - let n = params.num_limbs; - assert!(n >= 2, "hint-verified on-curve check requires n >= 2"); - - let a_is_zero = params.curve_a_raw.iter().all(|&v| v == 0); - - // Soundness check - { - // max terms in a column: px·px(1) + x_sq(1) + py·py(1) + x_sq·px(1) + [a·px(1)] - // + b(1) + pq(N) - let max_coeff_sum: u64 = if a_is_zero { - 4 + n as u64 - } else { - 5 + n as u64 - }; - let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; - let max_bits = 2 * params.limb_bits + extra_bits + 1; - assert!( - max_bits < FieldElement::MODULUS_BIT_SIZE, - "On-curve column equation overflow: limb_bits={}, n={n}, needs {max_bits} bits", - params.limb_bits - ); - } - - // Allocate hint - let os = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { - output_start: os, - op: NonNativeEcOp::OnCurve, - inputs: vec![px.as_slice()[..n].to_vec(), py.as_slice()[..n].to_vec()], - curve_a: params.curve_a_raw, - curve_b: params.curve_b_raw, - field_modulus_p: params.modulus_raw, - limb_bits: params.limb_bits, - num_limbs: n as u32, - }); - - // Parse hint layout: [x_sq(N), q1(N), c1(2N-2), q2(N), c2(2N-2)] - let x_sq = witness_range(os, n); - let q1 = witness_range(os + n, n); 
- let c1 = witness_range(os + 2 * n, 2 * n - 2); - let q2 = witness_range(os + 4 * n - 2, n); - let c2 = witness_range(os + 5 * n - 2, 2 * n - 2); - - // Eq1: px·px - x_sq = q1·p - let prod_px_px = make_products(compiler, &px.as_slice()[..n], &px.as_slice()[..n]); - - let max_coeff_eq1: u64 = 1 + 1 + n as u64; - emit_schoolbook_column_equations( - compiler, - &[(&prod_px_px, FieldElement::ONE)], - &[(&x_sq, -FieldElement::ONE)], - &q1, - &c1, - ¶ms.p_limbs, - n, - params.limb_bits, - max_coeff_eq1, - ); - - // Eq2: py·py - x_sq·px - a·px - b = q2·p - let prod_py_py = make_products(compiler, &py.as_slice()[..n], &py.as_slice()[..n]); - let prod_xsq_px = make_products(compiler, &x_sq, &px.as_slice()[..n]); - let b_limbs = allocate_pinned_constant_limbs(compiler, ¶ms.curve_b_limbs[..n]); - - if a_is_zero { - let max_coeff_eq2: u64 = 1 + 1 + 1 + n as u64; - emit_schoolbook_column_equations( - compiler, - &[ - (&prod_py_py, FieldElement::ONE), - (&prod_xsq_px, -FieldElement::ONE), - ], - &[(&b_limbs, -FieldElement::ONE)], - &q2, - &c2, - ¶ms.p_limbs, - n, - params.limb_bits, - max_coeff_eq2, - ); - } else { - let a_limbs = allocate_pinned_constant_limbs(compiler, ¶ms.curve_a_limbs[..n]); - let prod_a_px = make_products(compiler, &a_limbs, &px.as_slice()[..n]); - - let max_coeff_eq2: u64 = 1 + 1 + 1 + 1 + n as u64; - emit_schoolbook_column_equations( - compiler, - &[ - (&prod_py_py, FieldElement::ONE), - (&prod_xsq_px, -FieldElement::ONE), - (&prod_a_px, -FieldElement::ONE), - ], - &[(&b_limbs, -FieldElement::ONE)], - &q2, - &c2, - ¶ms.p_limbs, - n, - params.limb_bits, - max_coeff_eq2, - ); - } - - // Range checks on hint outputs - let max_coeff = if a_is_zero { - 4 + n as u64 - } else { - 5 + n as u64 - }; - let carry_extra_bits = ((max_coeff as f64 * n as f64).log2().ceil() as u32) + 1; - let carry_range_bits = params.limb_bits + carry_extra_bits; - range_check_limbs_and_carries( - range_checks, - &[&x_sq, &q1, &q2], - &[&c1, &c2], - params.limb_bits, - 
carry_range_bits, - ); - - // Less-than-p check for x_sq - less_than_p_check_vec(compiler, range_checks, &x_sq, params); -} - -/// Hint-verified point doubling for non-native field (multi-limb). -/// -/// Allocates NonNativeEcDoubleHint → (lambda, x3, y3, q1, c1, q2, c2, q3, c3). -/// Verifies via schoolbook column equations on 3 EC verification equations. -/// Total: (12N-6)W hint + ~(4N²+N) products + 3×(2N-1) column constraints -/// + 3 less-than-p checks. -pub fn point_double_verified_non_native( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - px: Limbs, - py: Limbs, - params: &MultiLimbParams, -) -> (Limbs, Limbs) { - let n = params.num_limbs; - assert!(n >= 2, "hint-verified non-native requires n >= 2"); - - // Soundness check: merged column equations fit native field - { - let max_coeff_sum: u64 = 2 + 3 + 1 + n as u64; // λy(2) + xx(3) + a(1) + pq(N) - let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; - let max_bits = 2 * params.limb_bits + extra_bits + 1; - assert!( - max_bits < FieldElement::MODULUS_BIT_SIZE, - "Merged EC column equation overflow: limb_bits={}, n={n}, needs {max_bits} bits", - params.limb_bits - ); - } - - // Allocate hint - let os = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { - output_start: os, - op: NonNativeEcOp::Double, - inputs: vec![px.as_slice()[..n].to_vec(), py.as_slice()[..n].to_vec()], - curve_a: params.curve_a_raw, - curve_b: [0; 4], // unused for double - field_modulus_p: params.modulus_raw, - limb_bits: params.limb_bits, - num_limbs: n as u32, - }); - - // Parse hint layout: [lambda(N), x3(N), y3(N), q1(N), c1(2N-2), q2(N), - // c2(2N-2), q3(N), c3(2N-2)] - let lambda = witness_range(os, n); - let x3 = witness_range(os + n, n); - let y3 = witness_range(os + 2 * n, n); - let q1 = witness_range(os + 3 * n, n); - let c1 = witness_range(os + 4 * n, 2 * n - 2); - let q2 = witness_range(os + 6 * n - 2, n); - let c2 = witness_range(os 
+ 7 * n - 2, 2 * n - 2); - let q3 = witness_range(os + 9 * n - 4, n); - let c3 = witness_range(os + 10 * n - 4, 2 * n - 2); - - let px_s = &px.as_slice()[..n]; - let py_s = &py.as_slice()[..n]; - - // Eq1: 2*lambda*py - 3*px*px - a = q1*p - let prod_lam_py = make_products(compiler, &lambda, py_s); - let prod_px_px = make_products(compiler, px_s, px_s); - let a_limbs = allocate_pinned_constant_limbs(compiler, ¶ms.curve_a_limbs[..n]); - - let max_coeff_eq1: u64 = 2 + 3 + 1 + n as u64; - emit_schoolbook_column_equations( - compiler, - &[ - (&prod_lam_py, FieldElement::from(2u64)), - (&prod_px_px, -FieldElement::from(3u64)), - ], - &[(&a_limbs, -FieldElement::ONE)], - &q1, - &c1, - ¶ms.p_limbs, - n, - params.limb_bits, - max_coeff_eq1, - ); - - // Eq2: lambda² - x3 - 2*px = q2*p - let prod_lam_lam = make_products(compiler, &lambda, &lambda); - - let max_coeff_eq2: u64 = 1 + 1 + 2 + n as u64; - emit_schoolbook_column_equations( - compiler, - &[(&prod_lam_lam, FieldElement::ONE)], - &[(&x3, -FieldElement::ONE), (px_s, -FieldElement::from(2u64))], - &q2, - &c2, - ¶ms.p_limbs, - n, - params.limb_bits, - max_coeff_eq2, - ); - - // Eq3: lambda*px - lambda*x3 - y3 - py = q3*p - let prod_lam_px = make_products(compiler, &lambda, px_s); - let prod_lam_x3 = make_products(compiler, &lambda, &x3); - - let max_coeff_eq3: u64 = 1 + 1 + 1 + 1 + n as u64; - emit_schoolbook_column_equations( - compiler, - &[ - (&prod_lam_px, FieldElement::ONE), - (&prod_lam_x3, -FieldElement::ONE), - ], - &[(&y3, -FieldElement::ONE), (py_s, -FieldElement::ONE)], - &q3, - &c3, - ¶ms.p_limbs, - n, - params.limb_bits, - max_coeff_eq3, - ); - - // Range checks on hint outputs - // max_coeff across eqs: Eq1 = 6+N, Eq2 = 4+N, Eq3 = 4+N → worst = 6+N - let max_coeff_carry = 6u64 + n as u64; - let carry_extra_bits = ((max_coeff_carry as f64 * n as f64).log2().ceil() as u32) + 1; - let carry_range_bits = params.limb_bits + carry_extra_bits; - range_check_limbs_and_carries( - range_checks, - &[&lambda, &x3, &y3, 
&q1, &q2, &q3], - &[&c1, &c2, &c3], - params.limb_bits, - carry_range_bits, - ); - - // Less-than-p checks for lambda, x3, y3 - less_than_p_check_vec(compiler, range_checks, &lambda, params); - less_than_p_check_vec(compiler, range_checks, &x3, params); - less_than_p_check_vec(compiler, range_checks, &y3, params); - - let mut x3_limbs = Limbs::new(n); - let mut y3_limbs = Limbs::new(n); - for i in 0..n { - x3_limbs[i] = x3[i]; - y3_limbs[i] = y3[i]; - } - (x3_limbs, y3_limbs) -} - -/// Hint-verified point addition for non-native field (multi-limb). -/// -/// Same approach as `point_double_verified_non_native` but verifies: -/// Eq1: lambda*(x2-x1) = y2-y1 (mod p) -/// Eq2: lambda² = x3+x1+x2 (mod p) -/// Eq3: lambda*(x1-x3) = y3+y1 (mod p) -pub fn point_add_verified_non_native( - compiler: &mut NoirToR1CSCompiler, - range_checks: &mut BTreeMap>, - x1: Limbs, - y1: Limbs, - x2: Limbs, - y2: Limbs, - params: &MultiLimbParams, -) -> (Limbs, Limbs) { - let n = params.num_limbs; - assert!(n >= 2, "hint-verified non-native requires n >= 2"); - - // Soundness check: column equations fit native field - { - let max_coeff_sum: u64 = 4 + n as u64; // all 3 eqs: 1+1+1+1+N - let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; - let max_bits = 2 * params.limb_bits + extra_bits + 1; - assert!( - max_bits < FieldElement::MODULUS_BIT_SIZE, - "EC add column equation overflow: limb_bits={}, n={n}, needs {max_bits} bits", - params.limb_bits - ); - } - - let os = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint { - output_start: os, - op: NonNativeEcOp::Add, - inputs: vec![ - x1.as_slice()[..n].to_vec(), - y1.as_slice()[..n].to_vec(), - x2.as_slice()[..n].to_vec(), - y2.as_slice()[..n].to_vec(), - ], - curve_a: [0; 4], // unused for add - curve_b: [0; 4], // unused for add - field_modulus_p: params.modulus_raw, - limb_bits: params.limb_bits, - num_limbs: n as u32, - }); - - let lambda = witness_range(os, n); - let x3 = 
witness_range(os + n, n); - let y3 = witness_range(os + 2 * n, n); - let q1 = witness_range(os + 3 * n, n); - let c1 = witness_range(os + 4 * n, 2 * n - 2); - let q2 = witness_range(os + 6 * n - 2, n); - let c2 = witness_range(os + 7 * n - 2, 2 * n - 2); - let q3 = witness_range(os + 9 * n - 4, n); - let c3 = witness_range(os + 10 * n - 4, 2 * n - 2); - - let x1_s = &x1.as_slice()[..n]; - let y1_s = &y1.as_slice()[..n]; - let x2_s = &x2.as_slice()[..n]; - let y2_s = &y2.as_slice()[..n]; - - // Eq1: lambda*x2 - lambda*x1 - y2 + y1 = q1*p - let prod_lam_x2 = make_products(compiler, &lambda, x2_s); - let prod_lam_x1 = make_products(compiler, &lambda, x1_s); - - let max_coeff: u64 = 1 + 1 + 1 + 1 + n as u64; - emit_schoolbook_column_equations( - compiler, - &[ - (&prod_lam_x2, FieldElement::ONE), - (&prod_lam_x1, -FieldElement::ONE), - ], - &[(y2_s, -FieldElement::ONE), (y1_s, FieldElement::ONE)], - &q1, - &c1, - ¶ms.p_limbs, - n, - params.limb_bits, - max_coeff, - ); - - // Eq2: lambda² - x3 - x1 - x2 = q2*p - let prod_lam_lam = make_products(compiler, &lambda, &lambda); - - emit_schoolbook_column_equations( - compiler, - &[(&prod_lam_lam, FieldElement::ONE)], - &[ - (&x3, -FieldElement::ONE), - (x1_s, -FieldElement::ONE), - (x2_s, -FieldElement::ONE), - ], - &q2, - &c2, - ¶ms.p_limbs, - n, - params.limb_bits, - max_coeff, - ); - - // Eq3: lambda*x1 - lambda*x3 - y3 - y1 = q3*p - // Reuse prod_lam_x1 from Eq1 - let prod_lam_x3 = make_products(compiler, &lambda, &x3); - - emit_schoolbook_column_equations( - compiler, - &[ - (&prod_lam_x1, FieldElement::ONE), - (&prod_lam_x3, -FieldElement::ONE), - ], - &[(&y3, -FieldElement::ONE), (y1_s, -FieldElement::ONE)], - &q3, - &c3, - ¶ms.p_limbs, - n, - params.limb_bits, - max_coeff, - ); - - // Range checks - // max_coeff across all 3 eqs = 4+N - let max_coeff_carry = 4u64 + n as u64; - let carry_extra_bits = ((max_coeff_carry as f64 * n as f64).log2().ceil() as u32) + 1; - let carry_range_bits = params.limb_bits + 
carry_extra_bits; - range_check_limbs_and_carries( - range_checks, - &[&lambda, &x3, &y3, &q1, &q2, &q3], - &[&c1, &c2, &c3], - params.limb_bits, - carry_range_bits, - ); - - // Less-than-p checks - less_than_p_check_vec(compiler, range_checks, &lambda, params); - less_than_p_check_vec(compiler, range_checks, &x3, params); - less_than_p_check_vec(compiler, range_checks, &y3, params); - - let mut x3_limbs = Limbs::new(n); - let mut y3_limbs = Limbs::new(n); - for i in 0..n { - x3_limbs[i] = x3[i]; - y3_limbs[i] = y3[i]; - } - (x3_limbs, y3_limbs) -} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/generic.rs b/provekit/r1cs-compiler/src/msm/ec_points/generic.rs new file mode 100644 index 000000000..c0236fe8f --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/generic.rs @@ -0,0 +1,104 @@ +//! Generic point operations using `MultiLimbOps` field arithmetic. +//! +//! These work for any field (native or non-native) by going through the +//! `MultiLimbOps` abstraction layer. + +use crate::msm::{multi_limb_ops::MultiLimbOps, Limbs}; + +/// Generic point doubling on y^2 = x^3 + ax + b. +/// +/// Given P = (x1, y1), computes 2P = (x3, y3): +/// lambda = (3 * x1^2 + a) / (2 * y1) +/// x3 = lambda^2 - 2 * x1 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge case — y1 = 0 (point of order 2): +/// When y1 = 0, the denominator 2*y1 = 0 and the inverse does not exist. +/// The result should be the point at infinity (identity element). +/// This function does NOT handle that case — the constraint system will +/// be unsatisfiable if y1 = 0 (the inverse verification will fail to +/// verify 0 * inv = 1 mod p). The caller must check y1 = 0 using +/// compute_is_zero and conditionally select the point-at-infinity +/// result before calling this function. 
+pub fn point_double(ops: &mut MultiLimbOps, x1: Limbs, y1: Limbs) -> (Limbs, Limbs) { + let a = ops.curve_a(); + + // Computing numerator = 3 * x1^2 + a + let x1_sq = ops.mul(x1, x1); + let two_x1_sq = ops.add(x1_sq, x1_sq); + let three_x1_sq = ops.add(two_x1_sq, x1_sq); + let numerator = ops.add(three_x1_sq, a); + + // Computing denominator = 2 * y1 + let denominator = ops.add(y1, y1); + + // Computing lambda = numerator * denominator^(-1) + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - 2 * x1 + let lambda_sq = ops.mul(lambda, lambda); + let two_x1 = ops.add(x1, x1); + let x3 = ops.sub(lambda_sq, two_x1); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Generic point addition on y^2 = x^3 + ax + b. +/// +/// Given P1 = (x1, y1) and P2 = (x2, y2), computes P1 + P2 = (x3, y3): +/// lambda = (y2 - y1) / (x2 - x1) +/// x3 = lambda^2 - x1 - x2 +/// y3 = lambda * (x1 - x3) - y1 +/// +/// Edge cases — x1 = x2: +/// When x1 = x2, the denominator (x2 - x1) = 0 and the inverse does +/// not exist. This covers two cases: +/// - P1 = P2 (same point): use `point_double` instead. +/// - P1 = -P2 (y1 = -y2): the result is the point at infinity. +/// This function does NOT handle either case — the constraint system +/// will be unsatisfiable if x1 = x2. The caller must detect this +/// and branch accordingly. 
+pub fn point_add( + ops: &mut MultiLimbOps, + x1: Limbs, + y1: Limbs, + x2: Limbs, + y2: Limbs, +) -> (Limbs, Limbs) { + // Computing lambda = (y2 - y1) / (x2 - x1) + let numerator = ops.sub(y2, y1); + let denominator = ops.sub(x2, x1); + let denom_inv = ops.inv(denominator); + let lambda = ops.mul(numerator, denom_inv); + + // Computing x3 = lambda^2 - x1 - x2 + let lambda_sq = ops.mul(lambda, lambda); + let x1_plus_x2 = ops.add(x1, x2); + let x3 = ops.sub(lambda_sq, x1_plus_x2); + + // Computing y3 = lambda * (x1 - x3) - y1 + let x1_minus_x3 = ops.sub(x1, x3); + let lambda_dx = ops.mul(lambda, x1_minus_x3); + let y3 = ops.sub(lambda_dx, y1); + + (x3, y3) +} + +/// Conditional point select without boolean constraint on `flag`. +/// Caller must ensure `flag` is already constrained boolean. +pub fn point_select_unchecked( + ops: &mut MultiLimbOps, + flag: usize, + on_false: (Limbs, Limbs), + on_true: (Limbs, Limbs), +) -> (Limbs, Limbs) { + let x = ops.select_unchecked(flag, on_false.0, on_true.0); + let y = ops.select_unchecked(flag, on_false.1, on_true.1); + (x, y) +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs b/provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs new file mode 100644 index 000000000..b7ddf20a6 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/hints_native.rs @@ -0,0 +1,161 @@ +//! Native-field hint-verified EC operations. +//! +//! These operate on single native field element witnesses (no multi-limb). +//! Each EC op allocates a hint for (lambda, x3, y3) and verifies via raw +//! R1CS constraints, eliminating expensive field inversions from the circuit. + +use { + crate::{msm::curve::CurveParams, noir_to_r1cs::NoirToR1CSCompiler}, + ark_ff::{Field, PrimeField}, + provekit_common::{witness::WitnessBuilder, FieldElement}, +}; + +/// Hint-verified point doubling for native field. +/// +/// Allocates EcDoubleHint → (lambda, x3, y3) = 3W hint. +/// Verification via 4 R1CS constraints: +/// 1. 
x_sq = px * px (1W+1C via add_product) +/// 2. lambda · 2·py = 3·x_sq + a (1C) +/// 3. lambda² = x3 + 2·px (1C) +/// 4. lambda · (px - x3) = y3 + py (1C) +/// +/// Total: 4W + 4C. +pub fn point_double_verified_native( + compiler: &mut NoirToR1CSCompiler, + px: usize, + py: usize, + curve: &CurveParams, +) -> (usize, usize) { + // Allocate hint witnesses + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcDoubleHint { + output_start: hint_start, + px, + py, + curve_a: curve.curve_a, + field_modulus_p: curve.field_modulus_p, + }); + let lambda = hint_start; + let x3 = hint_start + 1; + let y3 = hint_start + 2; + + // x_sq = px * px (1W + 1C) + let x_sq = compiler.add_product(px, px); + + // Constraint: lambda * (2 * py) = 3 * x_sq + a + // A = [lambda], B = [2*py], C = [3*x_sq + a_const] + let a_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_a)).unwrap(); + let three = FieldElement::from(3u64); + let two = FieldElement::from(2u64); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, lambda)], &[(two, py)], &[ + (three, x_sq), + (a_fe, compiler.witness_one()), + ]); + + // Constraint: lambda^2 = x3 + 2*px + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x3), (two, px)], + ); + + // Constraint: lambda * (px - x3) = y3 + py + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, px), (-FieldElement::ONE, x3)], + &[(FieldElement::ONE, y3), (FieldElement::ONE, py)], + ); + + (x3, y3) +} + +/// Hint-verified point addition for native field. +/// +/// Allocates EcAddHint → (lambda, x3, y3) = 3W. +/// Verification constraints (3C): +/// 1. lambda * (x2 - x1) = y2 - y1 (1C raw) +/// 2. lambda^2 = x3 + x1 + x2 (1C raw) +/// 3. lambda * (x1 - x3) = y3 + y1 (1C raw) +/// +/// Total: 3W + 3C. 
+pub fn point_add_verified_native( + compiler: &mut NoirToR1CSCompiler, + x1: usize, + y1: usize, + x2: usize, + y2: usize, + curve: &CurveParams, +) -> (usize, usize) { + // Allocate hint witnesses + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcAddHint { + output_start: hint_start, + x1, + y1, + x2, + y2, + field_modulus_p: curve.field_modulus_p, + }); + let lambda = hint_start; + let x3 = hint_start + 1; + let y3 = hint_start + 2; + + // Constraint: lambda * (x2 - x1) = y2 - y1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x2), (-FieldElement::ONE, x1)], + &[(FieldElement::ONE, y2), (-FieldElement::ONE, y1)], + ); + + // Constraint: lambda^2 = x3 + x1 + x2 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, lambda)], + &[ + (FieldElement::ONE, x3), + (FieldElement::ONE, x1), + (FieldElement::ONE, x2), + ], + ); + + // Constraint: lambda * (x1 - x3) = y3 + y1 + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, lambda)], + &[(FieldElement::ONE, x1), (-FieldElement::ONE, x3)], + &[(FieldElement::ONE, y3), (FieldElement::ONE, y1)], + ); + + (x3, y3) +} + +/// On-curve check for native field: y² = x³ + a·x + b. +/// +/// Verification via 3 R1CS constraints: +/// 1. x_sq = x · x (1W+1C via add_product) +/// 2. x_cu = x_sq · x (1W+1C via add_product) +/// 3. y · y = x_cu + a·x + b (1C) +/// +/// Total: 2W + 3C. 
+pub fn verify_on_curve_native( + compiler: &mut NoirToR1CSCompiler, + x: usize, + y: usize, + curve: &CurveParams, +) { + let x_sq = compiler.add_product(x, x); + let x_cu = compiler.add_product(x_sq, x); + + let a_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_a)).unwrap(); + let b_fe = FieldElement::from_bigint(ark_ff::BigInt(curve.curve_b)).unwrap(); + + // y * y = x_cu + a*x + b + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, y)], &[(FieldElement::ONE, y)], &[ + (FieldElement::ONE, x_cu), + (a_fe, x), + (b_fe, compiler.witness_one()), + ]); +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs b/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs new file mode 100644 index 000000000..fdd6c1756 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs @@ -0,0 +1,596 @@ +//! Non-native hint-verified EC operations (multi-limb schoolbook). +//! +//! These replace the step-by-step MultiLimbOps chain with prover hints verified +//! via schoolbook column equations. Each bilinear mod-p equation is checked by: +//! 1. Pre-computing product witnesses a\[i\]*b\[j\] +//! 2. Column equations: Σ coeff·prod\[k\] + linear\[k\] + carry_in + offset = Σ +//! p\[i\]*q\[j\] + carry_out * W +//! Since p is constant, p\[i\]*q\[j\] terms are linear in q (no product +//! witness). + +use { + crate::{ + msm::{multi_limb_arith::less_than_p_check_multi, multi_limb_ops::MultiLimbParams, Limbs}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{Field, PrimeField}, + provekit_common::{ + witness::{NonNativeEcOp, WitnessBuilder}, + FieldElement, + }, + std::collections::BTreeMap, +}; + +/// Collect witness indices from `start..start+len`. +fn witness_range(start: usize, len: usize) -> Vec { + (start..start + len).collect() +} + +/// Allocate N×N product witnesses for `a\[i\]*b\[j\]`. 
+fn make_products(compiler: &mut NoirToR1CSCompiler, a: &[usize], b: &[usize]) -> Vec> { + let n = a.len(); + debug_assert_eq!(n, b.len()); + let mut prods = vec![vec![0usize; n]; n]; + for i in 0..n { + for j in 0..n { + prods[i][j] = compiler.add_product(a[i], b[j]); + } + } + prods +} + +/// Allocate pinned constant witnesses from pre-decomposed `FieldElement` limbs. +fn allocate_pinned_constant_limbs( + compiler: &mut NoirToR1CSCompiler, + limb_values: &[FieldElement], +) -> Vec { + limb_values + .iter() + .map(|&val| { + let w = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::Constant( + provekit_common::witness::ConstantTerm(w, val), + )); + compiler.r1cs.add_constraint( + &[(FieldElement::ONE, compiler.witness_one())], + &[(FieldElement::ONE, w)], + &[(val, compiler.witness_one())], + ); + w + }) + .collect() +} + +/// Range-check limb witnesses at `limb_bits` and carry witnesses at +/// `carry_range_bits`. +fn range_check_limbs_and_carries( + range_checks: &mut BTreeMap>, + limb_vecs: &[&[usize]], + carry_vecs: &[&[usize]], + limb_bits: u32, + carry_range_bits: u32, +) { + for limbs in limb_vecs { + for &w in *limbs { + range_checks.entry(limb_bits).or_default().push(w); + } + } + for carries in carry_vecs { + for &c in *carries { + range_checks.entry(carry_range_bits).or_default().push(c); + } + } +} + +/// Convert `Vec` to `Limbs` and do a less-than-p check. +fn less_than_p_check_vec( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + v: &[usize], + params: &MultiLimbParams, +) { + let n = v.len(); + let mut limbs = Limbs::new(n); + for i in 0..n { + limbs[i] = v[i]; + } + less_than_p_check_multi( + compiler, + range_checks, + limbs, + ¶ms.p_minus_1_limbs, + params.two_pow_w, + params.limb_bits, + ); +} + +/// Compute carry range bits for hint-verified column equations. 
+fn carry_range_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + limb_bits + extra_bits +} + +/// Soundness check: verify that merged column equations fit the native field. +fn check_column_equation_fits(limb_bits: u32, max_coeff_sum: u64, n: usize, op_name: &str) { + let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; + let max_bits = 2 * limb_bits + extra_bits + 1; + assert!( + max_bits < FieldElement::MODULUS_BIT_SIZE, + "{op_name} column equation overflow: limb_bits={limb_bits}, n={n}, needs {max_bits} bits", + ); +} + +/// Emit schoolbook column equations for a merged verification equation. +/// +/// Verifies: Σ (coeff_i × A_i ⊗ B_i) + Σ linear_k = q·p (mod p, as integers) +/// +/// `product_sets`: each (products_2d, coefficient) where products_2d\[i\]\[j\] +/// is the witness index for a\[i\]*b\[j\]. +/// `linear_limbs`: each (limb_witnesses, coefficient) for non-product terms +/// (limb_witnesses has N entries, zero-padded). +/// `q_witnesses`: quotient limbs (N entries). +/// `carry_witnesses`: unsigned-offset carry witnesses (2N-2 entries). 
+fn emit_schoolbook_column_equations(
+    compiler: &mut NoirToR1CSCompiler,
+    product_sets: &[(&[Vec<usize>], FieldElement)], // (products[i][j], coeff)
+    linear_limbs: &[(&[usize], FieldElement)], // (limb_witnesses, coeff)
+    q_witnesses: &[usize],
+    carry_witnesses: &[usize],
+    p_limbs: &[FieldElement],
+    n: usize,
+    limb_bits: u32,
+    max_coeff_sum: u64,
+) {
+    let w1 = compiler.witness_one();
+    let two_pow_w = FieldElement::from(2u64).pow([limb_bits as u64]);
+
+    // Carry offset scaled for the merged equation's larger coefficients
+    let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1;
+    let carry_offset_bits = limb_bits + extra_bits;
+    let carry_offset_fe = FieldElement::from(2u64).pow([carry_offset_bits as u64]);
+    let offset_w = FieldElement::from(2u64).pow([(carry_offset_bits + limb_bits) as u64]);
+    let offset_w_minus_carry = offset_w - carry_offset_fe;
+
+    let num_columns = 2 * n - 1;
+
+    for k in 0..num_columns {
+        // LHS: Σ coeff * products[i][j] for i+j=k + carry_in + offset
+        let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new();
+
+        for &(products, coeff) in product_sets {
+            for i in 0..n {
+                let j_val = k as isize - i as isize;
+                if j_val >= 0 && (j_val as usize) < n {
+                    lhs_terms.push((coeff, products[i][j_val as usize]));
+                }
+            }
+        }
+
+        // Add linear terms (for k < N only, since linear_limbs are N-length)
+        for &(limbs, coeff) in linear_limbs {
+            if k < limbs.len() {
+                lhs_terms.push((coeff, limbs[k]));
+            }
+        }
+
+        // Add carry_in and offset
+        if k > 0 {
+            lhs_terms.push((FieldElement::ONE, carry_witnesses[k - 1]));
+            lhs_terms.push((offset_w_minus_carry, w1));
+        } else {
+            lhs_terms.push((offset_w, w1));
+        }
+
+        // RHS: Σ p[i]*q[j] for i+j=k + carry_out * W (or offset at last column)
+        let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new();
+        for i in 0..n {
+            let j_val = k as isize - i as isize;
+            if j_val >= 0 && (j_val as usize) < n {
+                rhs_terms.push((p_limbs[i], q_witnesses[j_val as usize]));
+            }
+        }
+
+        if k < 
num_columns - 1 {
+            rhs_terms.push((two_pow_w, carry_witnesses[k]));
+        } else {
+            // Last column: balance with offset_w (no outgoing carry)
+            rhs_terms.push((offset_w, w1));
+        }
+
+        compiler
+            .r1cs
+            .add_constraint(&lhs_terms, &[(FieldElement::ONE, w1)], &rhs_terms);
+    }
+}
+
+/// Helper to convert witness Vec to Limbs.
+fn vec_to_limbs(v: &[usize]) -> Limbs {
+    let n = v.len();
+    let mut limbs = Limbs::new(n);
+    for i in 0..n {
+        limbs[i] = v[i];
+    }
+    limbs
+}
+
+/// Hint-verified on-curve check for non-native field (multi-limb).
+///
+/// Verifies y² = x³ + ax + b (mod p) via two schoolbook column equations:
+/// Eq1: x·x - x_sq = q1·p (x_sq correctness)
+/// Eq2: y·y - x_sq·x - a·x - b = q2·p (on-curve)
+///
+/// Total: (7N-4)W hint + 3N² products (or 2N² when a=0) + 2×(2N-1) column
+/// constraints + 1 less-than-p check.
+pub fn verify_on_curve_non_native(
+    compiler: &mut NoirToR1CSCompiler,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+    px: Limbs,
+    py: Limbs,
+    params: &MultiLimbParams,
+) {
+    let n = params.num_limbs;
+    assert!(n >= 2, "hint-verified on-curve check requires n >= 2");
+
+    let a_is_zero = params.curve_a_raw.iter().all(|&v| v == 0);
+
+    let max_coeff_sum: u64 = if a_is_zero {
+        4 + n as u64
+    } else {
+        5 + n as u64
+    };
+    check_column_equation_fits(params.limb_bits, max_coeff_sum, n, "On-curve");
+
+    // Allocate hint
+    let os = compiler.num_witnesses();
+    compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint {
+        output_start: os,
+        op: NonNativeEcOp::OnCurve,
+        inputs: vec![px.as_slice()[..n].to_vec(), py.as_slice()[..n].to_vec()],
+        curve_a: params.curve_a_raw,
+        curve_b: params.curve_b_raw,
+        field_modulus_p: params.modulus_raw,
+        limb_bits: params.limb_bits,
+        num_limbs: n as u32,
+    });
+
+    // Parse hint layout: [x_sq(N), q1(N), c1(2N-2), q2(N), c2(2N-2)]
+    let x_sq = witness_range(os, n);
+    let q1 = witness_range(os + n, n);
+    let c1 = witness_range(os + 2 * n, 2 * n - 2);
+    let q2 = witness_range(os + 4 * n - 2, n);
+    let c2 = 
witness_range(os + 5 * n - 2, 2 * n - 2);
+
+    // Eq1: px·px - x_sq = q1·p
+    let prod_px_px = make_products(compiler, &px.as_slice()[..n], &px.as_slice()[..n]);
+
+    let max_coeff_eq1: u64 = 1 + 1 + n as u64;
+    emit_schoolbook_column_equations(
+        compiler,
+        &[(&prod_px_px, FieldElement::ONE)],
+        &[(&x_sq, -FieldElement::ONE)],
+        &q1,
+        &c1,
+        &params.p_limbs,
+        n,
+        params.limb_bits,
+        max_coeff_eq1,
+    );
+
+    // Eq2: py·py - x_sq·px - a·px - b = q2·p
+    let prod_py_py = make_products(compiler, &py.as_slice()[..n], &py.as_slice()[..n]);
+    let prod_xsq_px = make_products(compiler, &x_sq, &px.as_slice()[..n]);
+    let b_limbs = allocate_pinned_constant_limbs(compiler, &params.curve_b_limbs[..n]);
+
+    if a_is_zero {
+        let max_coeff_eq2: u64 = 1 + 1 + 1 + n as u64;
+        emit_schoolbook_column_equations(
+            compiler,
+            &[
+                (&prod_py_py, FieldElement::ONE),
+                (&prod_xsq_px, -FieldElement::ONE),
+            ],
+            &[(&b_limbs, -FieldElement::ONE)],
+            &q2,
+            &c2,
+            &params.p_limbs,
+            n,
+            params.limb_bits,
+            max_coeff_eq2,
+        );
+    } else {
+        let a_limbs = allocate_pinned_constant_limbs(compiler, &params.curve_a_limbs[..n]);
+        let prod_a_px = make_products(compiler, &a_limbs, &px.as_slice()[..n]);
+
+        let max_coeff_eq2: u64 = 1 + 1 + 1 + 1 + n as u64;
+        emit_schoolbook_column_equations(
+            compiler,
+            &[
+                (&prod_py_py, FieldElement::ONE),
+                (&prod_xsq_px, -FieldElement::ONE),
+                (&prod_a_px, -FieldElement::ONE),
+            ],
+            &[(&b_limbs, -FieldElement::ONE)],
+            &q2,
+            &c2,
+            &params.p_limbs,
+            n,
+            params.limb_bits,
+            max_coeff_eq2,
+        );
+    }
+
+    // Range checks on hint outputs
+    let crb = carry_range_bits(params.limb_bits, max_coeff_sum, n);
+    range_check_limbs_and_carries(
+        range_checks,
+        &[&x_sq, &q1, &q2],
+        &[&c1, &c2],
+        params.limb_bits,
+        crb,
+    );
+
+    // Less-than-p check for x_sq
+    less_than_p_check_vec(compiler, range_checks, &x_sq, params);
+}
+
+/// Hint-verified point doubling for non-native field (multi-limb).
+///
+/// Allocates NonNativeEcDoubleHint → (lambda, x3, y3, q1, c1, q2, c2, q3, c3).
+/// Verifies via schoolbook column equations on 3 EC equations:
+/// Eq1: 2·lambda·py - 3·px² - a = q1·p (2N² products: lam·py, px·px)
+/// Eq2: lambda² - x3 - 2·px = q2·p (1N² products: lam·lam)
+/// Eq3: lambda·(px - x3) - y3 - py = q3·p (2N² products: lam·px, lam·x3)
+///
+/// Total: (12N-6)W hint + 5N²+N products + 3×(2N-1) column constraints
+/// + 3 less-than-p checks.
+pub fn point_double_verified_non_native(
+    compiler: &mut NoirToR1CSCompiler,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+    px: Limbs,
+    py: Limbs,
+    params: &MultiLimbParams,
+) -> (Limbs, Limbs) {
+    let n = params.num_limbs;
+    assert!(n >= 2, "hint-verified non-native requires n >= 2");
+
+    let max_coeff_sum: u64 = 2 + 3 + 1 + n as u64; // λy(2) + xx(3) + a(1) + pq(N)
+    check_column_equation_fits(params.limb_bits, max_coeff_sum, n, "Merged EC double");
+
+    // Allocate hint
+    let os = compiler.num_witnesses();
+    compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint {
+        output_start: os,
+        op: NonNativeEcOp::Double,
+        inputs: vec![px.as_slice()[..n].to_vec(), py.as_slice()[..n].to_vec()],
+        curve_a: params.curve_a_raw,
+        curve_b: [0; 4], // unused for double
+        field_modulus_p: params.modulus_raw,
+        limb_bits: params.limb_bits,
+        num_limbs: n as u32,
+    });
+
+    // Parse hint layout: [lambda(N), x3(N), y3(N), q1(N), c1(2N-2), q2(N),
+    // c2(2N-2), q3(N), c3(2N-2)]
+    let lambda = witness_range(os, n);
+    let x3 = witness_range(os + n, n);
+    let y3 = witness_range(os + 2 * n, n);
+    let q1 = witness_range(os + 3 * n, n);
+    let c1 = witness_range(os + 4 * n, 2 * n - 2);
+    let q2 = witness_range(os + 6 * n - 2, n);
+    let c2 = witness_range(os + 7 * n - 2, 2 * n - 2);
+    let q3 = witness_range(os + 9 * n - 4, n);
+    let c3 = witness_range(os + 10 * n - 4, 2 * n - 2);
+
+    let px_s = &px.as_slice()[..n];
+    let py_s = &py.as_slice()[..n];
+
+    // Eq1: 2*lambda*py - 3*px*px - a = q1*p
+    let prod_lam_py = make_products(compiler, &lambda, py_s);
+    let prod_px_px = make_products(compiler, px_s, px_s);
+    let a_limbs = 
allocate_pinned_constant_limbs(compiler, &params.curve_a_limbs[..n]);
+
+    let max_coeff_eq1: u64 = 2 + 3 + 1 + n as u64;
+    emit_schoolbook_column_equations(
+        compiler,
+        &[
+            (&prod_lam_py, FieldElement::from(2u64)),
+            (&prod_px_px, -FieldElement::from(3u64)),
+        ],
+        &[(&a_limbs, -FieldElement::ONE)],
+        &q1,
+        &c1,
+        &params.p_limbs,
+        n,
+        params.limb_bits,
+        max_coeff_eq1,
+    );
+
+    // Eq2: lambda² - x3 - 2*px = q2*p
+    let prod_lam_lam = make_products(compiler, &lambda, &lambda);
+
+    let max_coeff_eq2: u64 = 1 + 1 + 2 + n as u64;
+    emit_schoolbook_column_equations(
+        compiler,
+        &[(&prod_lam_lam, FieldElement::ONE)],
+        &[(&x3, -FieldElement::ONE), (px_s, -FieldElement::from(2u64))],
+        &q2,
+        &c2,
+        &params.p_limbs,
+        n,
+        params.limb_bits,
+        max_coeff_eq2,
+    );
+
+    // Eq3: lambda*px - lambda*x3 - y3 - py = q3*p
+    let prod_lam_px = make_products(compiler, &lambda, px_s);
+    let prod_lam_x3 = make_products(compiler, &lambda, &x3);
+
+    let max_coeff_eq3: u64 = 1 + 1 + 1 + 1 + n as u64;
+    emit_schoolbook_column_equations(
+        compiler,
+        &[
+            (&prod_lam_px, FieldElement::ONE),
+            (&prod_lam_x3, -FieldElement::ONE),
+        ],
+        &[(&y3, -FieldElement::ONE), (py_s, -FieldElement::ONE)],
+        &q3,
+        &c3,
+        &params.p_limbs,
+        n,
+        params.limb_bits,
+        max_coeff_eq3,
+    );
+
+    // Range checks on hint outputs
+    // max_coeff across eqs: Eq1 = 6+N, Eq2 = 4+N, Eq3 = 4+N → worst = 6+N
+    let max_coeff_carry = 6u64 + n as u64;
+    let crb = carry_range_bits(params.limb_bits, max_coeff_carry, n);
+    range_check_limbs_and_carries(
+        range_checks,
+        &[&lambda, &x3, &y3, &q1, &q2, &q3],
+        &[&c1, &c2, &c3],
+        params.limb_bits,
+        crb,
+    );
+
+    // Less-than-p checks for lambda, x3, y3
+    less_than_p_check_vec(compiler, range_checks, &lambda, params);
+    less_than_p_check_vec(compiler, range_checks, &x3, params);
+    less_than_p_check_vec(compiler, range_checks, &y3, params);
+
+    (vec_to_limbs(&x3), vec_to_limbs(&y3))
+}
+
+/// Hint-verified point addition for non-native field (multi-limb).
+///
+/// Same approach as `point_double_verified_non_native` but verifies:
+/// Eq1: lambda·(x2-x1) - (y2-y1) = q1·p (2N² products: lam·x2, lam·x1)
+/// Eq2: lambda² - x3 - x1 - x2 = q2·p (1N² products: lam·lam)
+/// Eq3: lambda·(x1-x3) - y3 - y1 = q3·p (1N² products: lam·x3; lam·x1
+/// reused)
+///
+/// Total: (12N-6)W hint + 4N² products + 3×(2N-1) column constraints
+/// + 3 less-than-p checks.
+pub fn point_add_verified_non_native(
+    compiler: &mut NoirToR1CSCompiler,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+    x1: Limbs,
+    y1: Limbs,
+    x2: Limbs,
+    y2: Limbs,
+    params: &MultiLimbParams,
+) -> (Limbs, Limbs) {
+    let n = params.num_limbs;
+    assert!(n >= 2, "hint-verified non-native requires n >= 2");
+
+    let max_coeff: u64 = 1 + 1 + 1 + 1 + n as u64; // all 3 eqs: 1+1+1+1+N
+    check_column_equation_fits(params.limb_bits, max_coeff, n, "EC add");
+
+    let os = compiler.num_witnesses();
+    compiler.add_witness_builder(WitnessBuilder::NonNativeEcHint {
+        output_start: os,
+        op: NonNativeEcOp::Add,
+        inputs: vec![
+            x1.as_slice()[..n].to_vec(),
+            y1.as_slice()[..n].to_vec(),
+            x2.as_slice()[..n].to_vec(),
+            y2.as_slice()[..n].to_vec(),
+        ],
+        curve_a: [0; 4], // unused for add
+        curve_b: [0; 4], // unused for add
+        field_modulus_p: params.modulus_raw,
+        limb_bits: params.limb_bits,
+        num_limbs: n as u32,
+    });
+
+    let lambda = witness_range(os, n);
+    let x3 = witness_range(os + n, n);
+    let y3 = witness_range(os + 2 * n, n);
+    let q1 = witness_range(os + 3 * n, n);
+    let c1 = witness_range(os + 4 * n, 2 * n - 2);
+    let q2 = witness_range(os + 6 * n - 2, n);
+    let c2 = witness_range(os + 7 * n - 2, 2 * n - 2);
+    let q3 = witness_range(os + 9 * n - 4, n);
+    let c3 = witness_range(os + 10 * n - 4, 2 * n - 2);
+
+    let x1_s = &x1.as_slice()[..n];
+    let y1_s = &y1.as_slice()[..n];
+    let x2_s = &x2.as_slice()[..n];
+    let y2_s = &y2.as_slice()[..n];
+
+    // Eq1: lambda*x2 - lambda*x1 - y2 + y1 = q1*p
+    let prod_lam_x2 = make_products(compiler, &lambda, x2_s);
+    let prod_lam_x1 = 
make_products(compiler, &lambda, x1_s);
+
+    emit_schoolbook_column_equations(
+        compiler,
+        &[
+            (&prod_lam_x2, FieldElement::ONE),
+            (&prod_lam_x1, -FieldElement::ONE),
+        ],
+        &[(y2_s, -FieldElement::ONE), (y1_s, FieldElement::ONE)],
+        &q1,
+        &c1,
+        &params.p_limbs,
+        n,
+        params.limb_bits,
+        max_coeff,
+    );
+
+    // Eq2: lambda² - x3 - x1 - x2 = q2*p
+    let prod_lam_lam = make_products(compiler, &lambda, &lambda);
+
+    emit_schoolbook_column_equations(
+        compiler,
+        &[(&prod_lam_lam, FieldElement::ONE)],
+        &[
+            (&x3, -FieldElement::ONE),
+            (x1_s, -FieldElement::ONE),
+            (x2_s, -FieldElement::ONE),
+        ],
+        &q2,
+        &c2,
+        &params.p_limbs,
+        n,
+        params.limb_bits,
+        max_coeff,
+    );
+
+    // Eq3: lambda*x1 - lambda*x3 - y3 - y1 = q3*p
+    // Reuse prod_lam_x1 from Eq1
+    let prod_lam_x3 = make_products(compiler, &lambda, &x3);
+
+    emit_schoolbook_column_equations(
+        compiler,
+        &[
+            (&prod_lam_x1, FieldElement::ONE),
+            (&prod_lam_x3, -FieldElement::ONE),
+        ],
+        &[(&y3, -FieldElement::ONE), (y1_s, -FieldElement::ONE)],
+        &q3,
+        &c3,
+        &params.p_limbs,
+        n,
+        params.limb_bits,
+        max_coeff,
+    );
+
+    // Range checks
+    // max_coeff across all 3 eqs = 4+N
+    let max_coeff_carry = 4u64 + n as u64;
+    let crb = carry_range_bits(params.limb_bits, max_coeff_carry, n);
+    range_check_limbs_and_carries(
+        range_checks,
+        &[&lambda, &x3, &y3, &q1, &q2, &q3],
+        &[&c1, &c2, &c3],
+        params.limb_bits,
+        crb,
+    );
+
+    // Less-than-p checks
+    less_than_p_check_vec(compiler, range_checks, &lambda, params);
+    less_than_p_check_vec(compiler, range_checks, &x3, params);
+    less_than_p_check_vec(compiler, range_checks, &y3, params);
+
+    (vec_to_limbs(&x3), vec_to_limbs(&y3))
+}
diff --git a/provekit/r1cs-compiler/src/msm/ec_points/mod.rs b/provekit/r1cs-compiler/src/msm/ec_points/mod.rs
new file mode 100644
index 000000000..debfb3c14
--- /dev/null
+++ b/provekit/r1cs-compiler/src/msm/ec_points/mod.rs
@@ -0,0 +1,51 @@
+//! Elliptic curve point operations for MSM.
+//!
+//! 
- `generic` — generic EC ops via `MultiLimbOps` abstraction +//! - `tables` — point table construction, lookup, and merged GLV loop +//! - `hints_native` — native-field hint-verified EC ops +//! - `hints_non_native` — non-native hint-verified EC ops (schoolbook) + +mod generic; +mod hints_native; +mod hints_non_native; +mod tables; + +// Re-exports +use super::{multi_limb_ops::MultiLimbOps, Limbs}; +pub use { + generic::{point_add, point_double, point_select_unchecked}, + hints_native::{ + point_add_verified_native, point_double_verified_native, verify_on_curve_native, + }, + hints_non_native::{ + point_add_verified_non_native, point_double_verified_non_native, verify_on_curve_non_native, + }, + tables::{scalar_mul_merged_glv, MergedGlvPoint}, +}; + +/// Dispatching point doubling: uses hint-verified for multi-limb non-native, +/// generic field-ops otherwise. +pub fn point_double_dispatch(ops: &mut MultiLimbOps, x1: Limbs, y1: Limbs) -> (Limbs, Limbs) { + if ops.params.num_limbs >= 2 && !ops.params.is_native { + point_double_verified_non_native(ops.compiler, ops.range_checks, x1, y1, ops.params) + } else { + point_double(ops, x1, y1) + } +} + +/// Dispatching point addition: uses hint-verified for multi-limb non-native, +/// generic field-ops otherwise. +pub fn point_add_dispatch( + ops: &mut MultiLimbOps, + x1: Limbs, + y1: Limbs, + x2: Limbs, + y2: Limbs, +) -> (Limbs, Limbs) { + if ops.params.num_limbs >= 2 && !ops.params.is_native { + point_add_verified_non_native(ops.compiler, ops.range_checks, x1, y1, x2, y2, ops.params) + } else { + point_add(ops, x1, y1, x2, y2) + } +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/tables.rs b/provekit/r1cs-compiler/src/msm/ec_points/tables.rs new file mode 100644 index 000000000..27f2f7e5a --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/ec_points/tables.rs @@ -0,0 +1,259 @@ +//! Point table construction and lookup for windowed scalar multiplication. +//! +//! 
Builds tables of point multiples and performs lookups using bit witnesses +//! for both unsigned and signed-digit windowed approaches. + +use { + super::{generic::point_select_unchecked, point_add_dispatch, point_double_dispatch}, + crate::msm::{multi_limb_ops::MultiLimbOps, Limbs}, + ark_ff::Field, + provekit_common::{witness::SumTerm, FieldElement}, +}; + +/// Builds a signed point table of odd multiples for signed-digit windowed +/// scalar multiplication. +/// +/// T\[0\] = P, T\[1\] = 3P, T\[2\] = 5P, ..., T\[k-1\] = (2k-1)P +/// where k = `half_table_size` = 2^(w-1). +/// +/// Build cost: 1 point_double (for 2P) + (k-1) point_adds when k >= 2. +fn build_signed_point_table( + ops: &mut MultiLimbOps, + px: Limbs, + py: Limbs, + half_table_size: usize, +) -> Vec<(Limbs, Limbs)> { + assert!(half_table_size >= 1); + let mut table = Vec::with_capacity(half_table_size); + table.push((px, py)); // T[0] = 1*P + if half_table_size >= 2 { + let two_p = point_double_dispatch(ops, px, py); // 2P + for i in 1..half_table_size { + let prev = table[i - 1]; + table.push(point_add_dispatch(ops, prev.0, prev.1, two_p.0, two_p.1)); + } + } + table +} + +/// Selects T\[d\] from a point table using bit witnesses, where `d = Σ +/// bits\[i\] * 2^i`. +/// +/// Uses a binary tree of `point_select`s: processes bits from MSB to LSB, +/// halving the candidate set at each level. Total: `(2^w - 1)` point selects +/// for a table of `2^w` entries. +/// +/// When `constrain_bits` is true, each bit is constrained boolean exactly +/// once. When false, bits are assumed already constrained (e.g. XOR'd bits +/// derived from boolean-constrained inputs). 
+fn table_lookup( + ops: &mut MultiLimbOps, + table: &[(Limbs, Limbs)], + bits: &[usize], + constrain_bits: bool, +) -> (Limbs, Limbs) { + assert_eq!(table.len(), 1 << bits.len()); + let mut current: Vec<(Limbs, Limbs)> = table.to_vec(); + // Process bits from MSB to LSB + for &bit in bits.iter().rev() { + if constrain_bits { + ops.constrain_flag(bit); + } + let half = current.len() / 2; + let mut next = Vec::with_capacity(half); + for i in 0..half { + next.push(point_select_unchecked( + ops, + bit, + current[i], + current[i + half], + )); + } + current = next; + } + current[0] +} + +/// Signed-digit table lookup: selects from a half-size table using XOR'd +/// index bits, then conditionally negates y based on the sign bit. +/// +/// For a w-bit window with bits \[b_0, ..., b_{w-1}\] (LSB first): +/// - sign_bit = b_{w-1} (MSB): 1 = positive digit, 0 = negative digit +/// - index_bits = \[b_0, ..., b_{w-2}\] (lower w-1 bits) +/// - When positive: table index = lower bits as-is +/// - When negative: table index = bitwise complement of lower bits, and y is +/// negated +/// +/// The XOR'd bits are computed as: `idx_i = 1 - b_i - MSB + 2*b_i*MSB`, +/// which equals `b_i` when MSB=1, and `1-b_i` when MSB=0. +/// +/// # Precondition +/// `sign_bit` must be boolean-constrained by the caller. This function uses +/// it in `select_unchecked` without re-constraining. Currently satisfied: +/// `decompose_signed_bits` boolean-constrains all bits including the MSB +/// used as `sign_bit`. 
+fn signed_table_lookup(
+    ops: &mut MultiLimbOps,
+    table: &[(Limbs, Limbs)],
+    index_bits: &[usize],
+    sign_bit: usize,
+) -> (Limbs, Limbs) {
+    let (x, y) = if index_bits.is_empty() {
+        // w=1: single entry, no lookup needed
+        assert_eq!(table.len(), 1);
+        table[0]
+    } else {
+        // Compute XOR'd index bits: idx_i = 1 - b_i - MSB + 2*b_i*MSB
+        let one_w = ops.compiler.witness_one();
+        let two = FieldElement::from(2u64);
+        let xor_bits: Vec<usize> = index_bits
+            .iter()
+            .map(|&bit| {
+                let prod = ops.compiler.add_product(bit, sign_bit);
+                ops.compiler.add_sum(vec![
+                    SumTerm(Some(FieldElement::ONE), one_w),
+                    SumTerm(Some(-FieldElement::ONE), bit),
+                    SumTerm(Some(-FieldElement::ONE), sign_bit),
+                    SumTerm(Some(two), prod),
+                ])
+            })
+            .collect();
+
+        // XOR'd bits are boolean by construction (product of two booleans
+        // combined linearly), so skip redundant boolean constraints.
+        table_lookup(ops, table, &xor_bits, false)
+    };
+
+    // Conditionally negate y: sign_bit=0 (negative) → -y, sign_bit=1 (positive) → y
+    let neg_y = ops.negate(y);
+    let eff_y = ops.select_unchecked(sign_bit, neg_y, y);
+    // select_unchecked(flag, on_false, on_true):
+    // sign_bit=0 → on_false=neg_y (negative digit, negate y) ✓
+    // sign_bit=1 → on_true=y (positive digit, keep y) ✓
+
+    (x, eff_y)
+}
+
+/// Per-point data for merged multi-point GLV scalar multiplication.
+pub struct MergedGlvPoint {
+    /// Point P x-coordinate (limbs)
+    pub px: Limbs,
+    /// Point P y-coordinate (effective, post-negation)
+    pub py: Limbs,
+    /// Signed-bit decomposition of |s1| (half-scalar for P), LSB first
+    pub s1_bits: Vec<usize>,
+    /// Skew correction witness for s1 branch (boolean)
+    pub s1_skew: usize,
+    /// Point R x-coordinate (limbs)
+    pub rx: Limbs,
+    /// Point R y-coordinate (effective, post-negation)
+    pub ry: Limbs,
+    /// Signed-bit decomposition of |s2| (half-scalar for R), LSB first
+    pub s2_bits: Vec<usize>,
+    /// Skew correction witness for s2 branch (boolean)
+    pub s2_skew: usize,
+}
+
+/// Merged multi-point GLV scalar multiplication with shared doublings
+/// and signed-digit windows.
+///
+/// Uses signed-digit encoding: each w-bit window produces a signed odd digit
+/// d ∈ {±1, ±3, ..., ±(2^w - 1)}, eliminating zero-digit handling.
+/// Tables store odd multiples \[P, 3P, 5P, ..., (2^w-1)P\] with only
+/// 2^(w-1) entries (half the unsigned table size).
+///
+/// After the main loop, applies skew corrections: if skew=1, subtracts P
+/// (or R) to account for the signed decomposition bias.
+///
+/// Returns the final accumulator `(x, y)`.
+pub fn scalar_mul_merged_glv( + ops: &mut MultiLimbOps, + points: &[MergedGlvPoint], + window_size: usize, + offset_x: Limbs, + offset_y: Limbs, +) -> (Limbs, Limbs) { + assert!(!points.is_empty()); + let n = points[0].s1_bits.len(); + let w = window_size; + let half_table_size = 1usize << (w - 1); + + // Build signed point tables (odd multiples) for all points upfront + let tables: Vec<(Vec<(Limbs, Limbs)>, Vec<(Limbs, Limbs)>)> = points + .iter() + .map(|pt| { + let tp = build_signed_point_table(ops, pt.px, pt.py, half_table_size); + let tr = build_signed_point_table(ops, pt.rx, pt.ry, half_table_size); + (tp, tr) + }) + .collect(); + + let num_windows = (n + w - 1) / w; + let mut acc = (offset_x, offset_y); + + // Process all windows from MSB down to LSB + for i in (0..num_windows).rev() { + let bit_start = i * w; + let bit_end = std::cmp::min(bit_start + w, n); + let actual_w = bit_end - bit_start; + + // w shared doublings on the accumulator (shared across ALL points) + let mut doubled_acc = acc; + for _ in 0..w { + doubled_acc = point_double_dispatch(ops, doubled_acc.0, doubled_acc.1); + } + + let mut cur = doubled_acc; + + // For each point: P branch + R branch (signed-digit lookup) + for (pt, (table_p, table_r)) in points.iter().zip(tables.iter()) { + // --- P branch (s1 window) --- + let s1_window_bits = &pt.s1_bits[bit_start..bit_end]; + let sign_bit_p = s1_window_bits[actual_w - 1]; // MSB + let index_bits_p = &s1_window_bits[..actual_w - 1]; // lower bits + let actual_table_p = if actual_w < w { + &table_p[..1 << (actual_w - 1)] + } else { + &table_p[..] 
+ }; + let looked_up_p = signed_table_lookup(ops, actual_table_p, index_bits_p, sign_bit_p); + // All signed digits are non-zero — no is_zero check needed + cur = point_add_dispatch(ops, cur.0, cur.1, looked_up_p.0, looked_up_p.1); + + // --- R branch (s2 window) --- + let s2_window_bits = &pt.s2_bits[bit_start..bit_end]; + let sign_bit_r = s2_window_bits[actual_w - 1]; // MSB + let index_bits_r = &s2_window_bits[..actual_w - 1]; // lower bits + let actual_table_r = if actual_w < w { + &table_r[..1 << (actual_w - 1)] + } else { + &table_r[..] + }; + let looked_up_r = signed_table_lookup(ops, actual_table_r, index_bits_r, sign_bit_r); + cur = point_add_dispatch(ops, cur.0, cur.1, looked_up_r.0, looked_up_r.1); + } + + acc = cur; + } + + // Skew corrections: subtract P (or R) if skew=1 for each point. + // The signed decomposition gives: scalar = Σ d_i * 2^i - skew, + // so the main loop computed (scalar + skew) * P. If skew=1, subtract P. + for pt in points { + // P branch skew + let neg_py = ops.negate(pt.py); + let (sub_px, sub_py) = point_add_dispatch(ops, acc.0, acc.1, pt.px, neg_py); + let new_x = ops.select_unchecked(pt.s1_skew, acc.0, sub_px); + let new_y = ops.select_unchecked(pt.s1_skew, acc.1, sub_py); + acc = (new_x, new_y); + + // R branch skew + let neg_ry = ops.negate(pt.ry); + let (sub_rx, sub_ry) = point_add_dispatch(ops, acc.0, acc.1, pt.rx, neg_ry); + let new_x = ops.select_unchecked(pt.s2_skew, acc.0, sub_rx); + let new_y = ops.select_unchecked(pt.s2_skew, acc.1, sub_ry); + acc = (new_x, new_y); + } + + acc +} diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs index 35b1f3483..816bbb929 100644 --- a/provekit/r1cs-compiler/src/msm/mod.rs +++ b/provekit/r1cs-compiler/src/msm/mod.rs @@ -5,21 +5,21 @@ pub(crate) mod multi_limb_arith; pub(crate) mod multi_limb_ops; mod native; mod non_native; +mod sanitize; mod scalar_relation; +#[cfg(test)] +mod tests; +// Re-export sanitize helpers so submodules (native, 
non_native) can use +// `super::sanitize_point_scalar` etc. use { - crate::{ - constraint_helpers::{ - add_constant_witness, compute_boolean_or, constrain_boolean, select_witness, - }, - msm::multi_limb_arith::compute_is_zero, - noir_to_r1cs::NoirToR1CSCompiler, - }, - ark_ff::{AdditiveGroup, Field, PrimeField}, + crate::{constraint_helpers::add_constant_witness, noir_to_r1cs::NoirToR1CSCompiler}, + ark_ff::PrimeField, curve::CurveParams, - provekit_common::{ - witness::{ConstantOrR1CSWitness, SumTerm, WitnessBuilder}, - FieldElement, + provekit_common::witness::ConstantOrR1CSWitness, + sanitize::{ + decompose_signed_bits, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, + negate_y_signed_native, sanitize_point_scalar, }, std::collections::BTreeMap, }; @@ -119,99 +119,6 @@ impl std::ops::IndexMut for Limbs { } } -// --------------------------------------------------------------------------- -// Private helpers (MSM-specific) -// --------------------------------------------------------------------------- - -/// Detects whether a point-scalar pair is degenerate (scalar=0 or point at -/// infinity). Constrains `inf_flag` to boolean. Returns `is_skip` (1 if -/// degenerate). -fn detect_skip( - compiler: &mut NoirToR1CSCompiler, - s_lo: usize, - s_hi: usize, - inf_flag: usize, -) -> usize { - constrain_boolean(compiler, inf_flag); - let is_zero_s_lo = compute_is_zero(compiler, s_lo); - let is_zero_s_hi = compute_is_zero(compiler, s_hi); - let s_is_zero = compiler.add_product(is_zero_s_lo, is_zero_s_hi); - compute_boolean_or(compiler, s_is_zero, inf_flag) -} - -/// Sanitized point-scalar inputs after degenerate-case detection. -struct SanitizedInputs { - px: usize, - py: usize, - s_lo: usize, - s_hi: usize, - is_skip: usize, -} - -/// Detects degenerate cases (scalar=0 or point at infinity) and replaces -/// the point with the generator G and scalar with 1 when degenerate. 
-fn sanitize_point_scalar( - compiler: &mut NoirToR1CSCompiler, - px: usize, - py: usize, - s_lo: usize, - s_hi: usize, - inf_flag: usize, - gen_x: usize, - gen_y: usize, - zero: usize, - one: usize, -) -> SanitizedInputs { - let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); - SanitizedInputs { - px: select_witness(compiler, is_skip, px, gen_x), - py: select_witness(compiler, is_skip, py, gen_y), - s_lo: select_witness(compiler, is_skip, s_lo, one), - s_hi: select_witness(compiler, is_skip, s_hi, zero), - is_skip, - } -} - -/// Negate a y-coordinate and conditionally select based on a sign flag. -/// Returns `(y_eff, neg_y_eff)` where: -/// - if `neg_flag=0`: `y_eff = y`, `neg_y_eff = -y` -/// - if `neg_flag=1`: `y_eff = -y`, `neg_y_eff = y` -fn negate_y_signed_native( - compiler: &mut NoirToR1CSCompiler, - neg_flag: usize, - y: usize, -) -> (usize, usize) { - constrain_boolean(compiler, neg_flag); - let neg_y = compiler.add_sum(vec![SumTerm(Some(-FieldElement::ONE), y)]); - let y_eff = select_witness(compiler, neg_flag, y, neg_y); - let neg_y_eff = select_witness(compiler, neg_flag, neg_y, y); - (y_eff, neg_y_eff) -} - -/// Emit an `EcScalarMulHint` and sanitize the result point. -/// When `is_skip=1`, the result is swapped to the generator point. 
-fn emit_ec_scalar_mul_hint_and_sanitize( - compiler: &mut NoirToR1CSCompiler, - san: &SanitizedInputs, - gen_x_witness: usize, - gen_y_witness: usize, - curve: &CurveParams, -) -> (usize, usize) { - let hint_start = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { - output_start: hint_start, - px: san.px, - py: san.py, - s_lo: san.s_lo, - s_hi: san.s_hi, - curve_a: curve.curve_a, - field_modulus_p: curve.field_modulus_p, - }); - let rx = select_witness(compiler, san.is_skip, hint_start, gen_x_witness); - let ry = select_witness(compiler, san.is_skip, hint_start + 1, gen_y_witness); - (rx, ry) -} - // --------------------------------------------------------------------------- // MSM entry point // --------------------------------------------------------------------------- @@ -256,7 +163,7 @@ pub fn add_msm_with_curve( return; } - let native_bits = FieldElement::MODULUS_BIT_SIZE; + let native_bits = provekit_common::FieldElement::MODULUS_BIT_SIZE; let curve_bits = curve.modulus_bits(); let is_native = curve.is_native_field(); let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / 3).sum(); @@ -337,82 +244,6 @@ fn add_single_msm( } } -/// Allocates a FakeGLV hint and returns `(s1, s2, neg1, neg2)` witness indices. -fn emit_fakeglv_hint( - compiler: &mut NoirToR1CSCompiler, - s_lo: usize, - s_hi: usize, - curve: &CurveParams, -) -> (usize, usize, usize, usize) { - let glv_start = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::FakeGLVHint { - output_start: glv_start, - s_lo, - s_hi, - curve_order: curve.curve_order_n, - }); - (glv_start, glv_start + 1, glv_start + 2, glv_start + 3) -} - -/// Signed-bit decomposition for wNAF scalar multiplication. 
-/// -/// Decomposes `scalar` into `num_bits` sign-bits b_i ∈ {0,1} and a skew ∈ {0,1} -/// such that the signed digits d_i = 2*b_i - 1 ∈ {-1, +1} satisfy: -/// scalar = Σ d_i * 2^i - skew -/// -/// Reconstruction constraint (1 linear R1CS): -/// scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} -/// -/// All bits and skew are boolean-constrained. -/// -/// # Limitation -/// The prover's `SignedBitHint` solver reads the scalar as a `u128` (lower -/// 128 bits of the field element). This is correct for FakeGLV half-scalars -/// (≤128 bits for 256-bit curves) but would silently truncate if `num_bits` -/// exceeds 128. The R1CS reconstruction constraint would then fail. -pub(super) fn decompose_signed_bits( - compiler: &mut NoirToR1CSCompiler, - scalar: usize, - num_bits: usize, -) -> (Vec, usize) { - let start = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::SignedBitHint { - output_start: start, - scalar, - num_bits, - }); - let bits: Vec = (start..start + num_bits).collect(); - let skew = start + num_bits; - - // Boolean-constrain each bit and skew - for &b in &bits { - constrain_boolean(compiler, b); - } - constrain_boolean(compiler, skew); - - // Reconstruction: scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} - // Rearranged as: scalar + skew + (2^n - 1) - Σ b_i * 2^{i+1} = 0 - let one = compiler.witness_one(); - let two = FieldElement::from(2u64); - let constant = two.pow([num_bits as u64]) - FieldElement::ONE; - let mut b_terms: Vec<(FieldElement, usize)> = bits - .iter() - .enumerate() - .map(|(i, &b)| (-two.pow([(i + 1) as u64]), b)) - .collect(); - b_terms.push((FieldElement::ONE, scalar)); - b_terms.push((FieldElement::ONE, skew)); - b_terms.push((constant, one)); - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, one)], &b_terms, &[( - FieldElement::ZERO, - one, - )]); - - (bits, skew) -} - /// Resolves a `ConstantOrR1CSWitness` to a witness index. 
fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitness) -> usize { match input { @@ -420,111 +251,3 @@ fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitnes ConstantOrR1CSWitness::Constant(value) => add_constant_witness(compiler, *value), } } - -#[cfg(test)] -mod tests { - use {super::*, crate::noir_to_r1cs::NoirToR1CSCompiler}; - - /// Verify that the non-native (SECP256R1) single-point MSM path generates - /// constraints without panicking. This does multi-limb arithmetic, - /// range checks, and FakeGLV verification — the entire non-native code path - /// that has no Noir e2e coverage for now : ) - #[test] - fn test_secp256r1_single_point_msm_generates_constraints() { - let mut compiler = NoirToR1CSCompiler::new(); - let curve = curve::secp256r1_params(); - let mut range_checks: BTreeMap> = BTreeMap::new(); - - // Allocate witness slots for: px, py, inf, s_lo, s_hi, out_x, out_y, out_inf - // (witness 0 is the constant-one witness) - let base = compiler.num_witnesses(); - compiler.r1cs.add_witnesses(8); - let px = base; - let py = base + 1; - let inf = base + 2; - let s_lo = base + 3; - let s_hi = base + 4; - let out_x = base + 5; - let out_y = base + 6; - let out_inf = base + 7; - - let points = vec![ - ConstantOrR1CSWitness::Witness(px), - ConstantOrR1CSWitness::Witness(py), - ConstantOrR1CSWitness::Witness(inf), - ]; - let scalars = vec![ - ConstantOrR1CSWitness::Witness(s_lo), - ConstantOrR1CSWitness::Witness(s_hi), - ]; - let msm_ops = vec![(points, scalars, (out_x, out_y, out_inf))]; - - add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve); - - let n_constraints = compiler.r1cs.num_constraints(); - let n_witnesses = compiler.num_witnesses(); - - assert!( - n_constraints > 100, - "expected substantial constraints for non-native MSM, got {n_constraints}" - ); - assert!( - n_witnesses > 100, - "expected substantial witnesses for non-native MSM, got {n_witnesses}" - ); - assert!( - 
!range_checks.is_empty(), - "non-native MSM should produce range checks" - ); - } - - /// Verify that the non-native multi-point MSM path (2 points, SECP256R1) - /// generates constraints. does the multi-point accumulation and offset - /// subtraction logic for the non-native path. - #[test] - fn test_secp256r1_multi_point_msm_generates_constraints() { - let mut compiler = NoirToR1CSCompiler::new(); - let curve = curve::secp256r1_params(); - let mut range_checks: BTreeMap> = BTreeMap::new(); - - // 2 points: px1, py1, inf1, px2, py2, inf2, s1_lo, s1_hi, s2_lo, s2_hi, - // out_x, out_y, out_inf - let base = compiler.num_witnesses(); - compiler.r1cs.add_witnesses(13); - - let points = vec![ - ConstantOrR1CSWitness::Witness(base), // px1 - ConstantOrR1CSWitness::Witness(base + 1), // py1 - ConstantOrR1CSWitness::Witness(base + 2), // inf1 - ConstantOrR1CSWitness::Witness(base + 3), // px2 - ConstantOrR1CSWitness::Witness(base + 4), // py2 - ConstantOrR1CSWitness::Witness(base + 5), // inf2 - ]; - let scalars = vec![ - ConstantOrR1CSWitness::Witness(base + 6), // s1_lo - ConstantOrR1CSWitness::Witness(base + 7), // s1_hi - ConstantOrR1CSWitness::Witness(base + 8), // s2_lo - ConstantOrR1CSWitness::Witness(base + 9), // s2_hi - ]; - let out_x = base + 10; - let out_y = base + 11; - let out_inf = base + 12; - - let msm_ops = vec![(points, scalars, (out_x, out_y, out_inf))]; - - add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve); - - let n_constraints = compiler.r1cs.num_constraints(); - let n_witnesses = compiler.num_witnesses(); - - // Multi-point should produce more constraints than single-point - assert!( - n_constraints > 200, - "expected substantial constraints for 2-point non-native MSM, got {n_constraints}" - ); - assert!( - n_witnesses > 200, - "expected substantial witnesses for 2-point non-native MSM, got {n_witnesses}" - ); - } -} diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs 
b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs index ef19dd31c..f0163b88e 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_arith.rs @@ -194,14 +194,21 @@ pub fn compute_is_zero(compiler: &mut NoirToR1CSCompiler, value: usize) -> usize // N≥2 multi-limb path (generalization of wide_ops.rs) // --------------------------------------------------------------------------- -/// (a + b) mod p for multi-limb values. +/// Shared core for `add_mod_p_multi` and `sub_mod_p_multi`. /// -/// Per limb i: v_i = a\[i\] + b\[i\] + 2^W - q*p\[i\] + carry_{i-1} -/// carry_i = floor(v_i / 2^W) -/// r\[i\] = v_i - carry_i * 2^W -pub fn add_mod_p_multi( +/// Both operations follow the same carry-chain structure; only the per-limb +/// formula and quotient witness builder differ: +/// add: v\[i\] = a\[i\] + b\[i\] + 2^W - q·p\[i\] + carry (q via +/// MultiLimbAddQuotient) sub: v\[i\] = a\[i\] - b\[i\] + q·p\[i\] + 2^W + +/// carry (q via MultiLimbSubBorrow) +/// +/// When `carry_prev` is present, the w1 coefficient absorbs a `-1` term +/// (i.e. `w1_coeff = ... + 2^W - 1`) so that `carry_prev` contributes +/// `+1 · carry_prev` without creating duplicate (row, col) entries. 
+fn add_sub_mod_p_core( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap>, + is_add: bool, a: Limbs, b: Limbs, p_limbs: &[FieldElement], @@ -211,19 +218,30 @@ pub fn add_mod_p_multi( modulus_raw: &[u64; 4], ) -> Limbs { let n = a.len(); - assert!(n >= 2, "add_mod_p_multi requires n >= 2, got n={n}"); + assert!(n >= 2, "add/sub_mod_p_multi requires n >= 2, got n={n}"); let w1 = compiler.witness_one(); - // Witness: q = floor((a + b) / p) ∈ {0, 1} + // Witness: q ∈ {0, 1} let q = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::MultiLimbAddQuotient { - output: q, - a_limbs: a.as_slice().to_vec(), - b_limbs: b.as_slice().to_vec(), - modulus: *modulus_raw, - limb_bits, - num_limbs: n as u32, - }); + if is_add { + compiler.add_witness_builder(WitnessBuilder::MultiLimbAddQuotient { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + } else { + compiler.add_witness_builder(WitnessBuilder::MultiLimbSubBorrow { + output: q, + a_limbs: a.as_slice().to_vec(), + b_limbs: b.as_slice().to_vec(), + modulus: *modulus_raw, + limb_bits, + num_limbs: n as u32, + }); + } // q is boolean compiler .r1cs @@ -236,7 +254,6 @@ pub fn add_mod_p_multi( let mut carry_prev: Option = None; for i in 0..n { - // v_offset = a[i] + b[i] + 2^W - q*p[i] + carry_{i-1} // When carry_prev exists, combine w1 terms to avoid duplicate column // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). 
let w1_coeff = if carry_prev.is_some() { @@ -244,12 +261,18 @@ pub fn add_mod_p_multi( } else { two_pow_w }; - let mut terms = vec![ - SumTerm(None, a[i]), - SumTerm(None, b[i]), - SumTerm(Some(w1_coeff), w1), - SumTerm(Some(-p_limbs[i]), q), - ]; + // add: a[i] + b[i] + w1_coeff*1 - p[i]*q + carry + // sub: a[i] - b[i] + p[i]*q + w1_coeff*1 + carry + let mut terms = vec![SumTerm(None, a[i])]; + if is_add { + terms.push(SumTerm(None, b[i])); + terms.push(SumTerm(Some(w1_coeff), w1)); + terms.push(SumTerm(Some(-p_limbs[i]), q)); + } else { + terms.push(SumTerm(Some(-FieldElement::ONE), b[i])); + terms.push(SumTerm(Some(p_limbs[i]), q)); + terms.push(SumTerm(Some(w1_coeff), w1)); + } if let Some(carry) = carry_prev { terms.push(SumTerm(None, carry)); } @@ -278,6 +301,36 @@ pub fn add_mod_p_multi( r } +/// (a + b) mod p for multi-limb values. +/// +/// Per limb i: v_i = a\[i\] + b\[i\] + 2^W - q*p\[i\] + carry_{i-1} +/// carry_i = floor(v_i / 2^W) +/// r\[i\] = v_i - carry_i * 2^W +pub fn add_mod_p_multi( + compiler: &mut NoirToR1CSCompiler, + range_checks: &mut BTreeMap>, + a: Limbs, + b: Limbs, + p_limbs: &[FieldElement], + p_minus_1_limbs: &[FieldElement], + two_pow_w: FieldElement, + limb_bits: u32, + modulus_raw: &[u64; 4], +) -> Limbs { + add_sub_mod_p_core( + compiler, + range_checks, + true, + a, + b, + p_limbs, + p_minus_1_limbs, + two_pow_w, + limb_bits, + modulus_raw, + ) +} + /// Negate a multi-limb value: computes `p - y` directly via borrow chain. /// /// Since inputs are already verified `y ∈ [0, p)` (from less_than_p on @@ -288,7 +341,8 @@ pub fn add_mod_p_multi( /// This avoids the generic `sub_mod_p_multi` pathway which allocates /// N zero-constant witnesses, a borrow quotient, and a less_than_p check. /// -/// Witnesses: 3N (N v-sums + N borrows + N result sums). +/// Witnesses: 3N (N v-sums + N borrows + N result limbs). +/// Constraints: N (one `add_sum` per limb, each producing 1W+1C). /// Range checks: N at limb_bits. 
pub fn negate_mod_p_multi( compiler: &mut NoirToR1CSCompiler, @@ -352,70 +406,18 @@ pub fn sub_mod_p_multi( limb_bits: u32, modulus_raw: &[u64; 4], ) -> Limbs { - let n = a.len(); - assert!(n >= 2, "sub_mod_p_multi requires n >= 2, got n={n}"); - let w1 = compiler.witness_one(); - - // Witness: q = (a < b) ? 1 : 0 - let q = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::MultiLimbSubBorrow { - output: q, - a_limbs: a.as_slice().to_vec(), - b_limbs: b.as_slice().to_vec(), - modulus: *modulus_raw, - limb_bits, - num_limbs: n as u32, - }); - // q is boolean - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, q)], &[(FieldElement::ONE, q)], &[( - FieldElement::ONE, - q, - )]); - - let mut r = Limbs::new(n); - let mut carry_prev: Option = None; - - for i in 0..n { - // v_offset = a[i] - b[i] + q*p[i] + 2^W + carry_{i-1} - // When carry_prev exists, combine w1 terms to avoid duplicate column - // indices in the R1CS sparse matrix (set overwrites on duplicate (row,col)). - let w1_coeff = if carry_prev.is_some() { - two_pow_w - FieldElement::ONE - } else { - two_pow_w - }; - let mut terms = vec![ - SumTerm(None, a[i]), - SumTerm(Some(-FieldElement::ONE), b[i]), - SumTerm(Some(p_limbs[i]), q), - SumTerm(Some(w1_coeff), w1), - ]; - if let Some(carry) = carry_prev { - terms.push(SumTerm(None, carry)); - } - let v_offset = compiler.add_sum(terms); - - let carry = compiler.num_witnesses(); - compiler.add_witness_builder(WitnessBuilder::IntegerQuotient(carry, v_offset, two_pow_w)); - r[i] = compiler.add_sum(vec![ - SumTerm(None, v_offset), - SumTerm(Some(-two_pow_w), carry), - ]); - carry_prev = Some(carry); - } - - less_than_p_check_multi( + add_sub_mod_p_core( compiler, range_checks, - r, + false, + a, + b, + p_limbs, p_minus_1_limbs, two_pow_w, limb_bits, - ); - - r + modulus_raw, + ) } /// (a * b) mod p for multi-limb values using schoolbook multiplication. 
@@ -623,8 +625,8 @@ pub fn inv_mod_p_multi( } /// Proves r < p by decomposing (p-1) - r into non-negative multi-limb values. -/// Uses borrow propagation: d\[i\] = (p-1)\[i\] - r\[i\] + borrow_in - -/// borrow_out * 2^W +/// Uses borrow propagation: d\[i\] = (p-1)\[i\] + 2^W - r\[i\] + borrow_in - +/// borrow_out · 2^W. The 2^W offset keeps intermediate values non-negative. pub fn less_than_p_check_multi( compiler: &mut NoirToR1CSCompiler, range_checks: &mut BTreeMap>, diff --git a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs index 52bd3d934..3fccdbd5c 100644 --- a/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs +++ b/provekit/r1cs-compiler/src/msm/multi_limb_ops.rs @@ -331,8 +331,9 @@ impl MultiLimbOps<'_, '_> { multi_limb_arith::compute_is_zero(self.compiler, value) } - /// Packs bit witnesses into a single digit witness: `d = Σ bits[i] * 2^i`. - /// Does NOT constrain bits to be boolean — caller must ensure that. + /// Packs bit witnesses into a single digit witness: `d = Σ bits\[i\] * + /// 2^i`. Does NOT constrain bits to be boolean — caller must ensure + /// that. pub fn pack_bits(&mut self, bits: &[usize]) -> usize { pack_bits_helper(self.compiler, bits) } diff --git a/provekit/r1cs-compiler/src/msm/native.rs b/provekit/r1cs-compiler/src/msm/native.rs index ba62cf7ff..cecc58fc2 100644 --- a/provekit/r1cs-compiler/src/msm/native.rs +++ b/provekit/r1cs-compiler/src/msm/native.rs @@ -1,7 +1,36 @@ //! Native-field MSM path: hint-verified EC ops with signed-bit wNAF. //! -//! Used when `curve.is_native_field()` — replaces expensive field inversions -//! with prover hints verified via raw R1CS constraints. +//! Used when `curve.is_native_field()` — the curve's base field matches the +//! R1CS native field (e.g. Grumpkin over BN254). Field operations are free +//! (single native witness each), so EC operations use prover hints verified +//! via raw R1CS constraints instead of expensive field inversions. +//! 
+//! ## Key techniques +//! +//! - **Hint-verified EC ops**: point_double (4W+4C), point_add (3W+3C), and +//! on-curve checks (2W+3C) use prover-supplied (lambda, x3, y3) hints +//! verified by the EC equations directly as R1CS constraints. +//! - **FakeGLV**: each 256-bit scalar is decomposed into two ~128-bit +//! half-scalars (s1, s2) via half-GCD, verified by a scalar relation mod the +//! curve order. +//! - **Signed-bit wNAF (w=1)**: each half-scalar is decomposed into signed bits +//! d_i ∈ {-1, +1}, eliminating zero-digit handling. A skew bit corrects the +//! bias post-loop. +//! - **Merged doubling**: all points share a single doubling per bit, saving 4W +//! per extra point per bit. +//! +//! ## Phases +//! +//! 1. **Preprocessing**: per-point sanitization (degenerate-case replacement), +//! on-curve checks, FakeGLV decomposition, signed-bit decomposition, +//! y-negation based on sign flags. +//! 2. **Merged scalar mul**: single loop from MSB to LSB with one shared +//! doubling + per-point P/R branch adds. Skew corrections after loop. +//! Identity check: final accumulator must equal the known offset. +//! 3. **Scalar relations**: per-point verification that (-1)^neg1·|s1| + +//! (-1)^neg2·|s2|·s ≡ 0 (mod curve_order). +//! 4. **Accumulation**: adds each point's scalar-mul result to an accumulator +//! (skipping degenerate points), subtracts the offset, constrains outputs. use { super::{ diff --git a/provekit/r1cs-compiler/src/msm/non_native.rs b/provekit/r1cs-compiler/src/msm/non_native.rs index 05254b3d8..270dd14d8 100644 --- a/provekit/r1cs-compiler/src/msm/non_native.rs +++ b/provekit/r1cs-compiler/src/msm/non_native.rs @@ -1,10 +1,38 @@ //! Non-native (generic multi-limb) MSM path. //! -//! Used when `!curve.is_native_field()` — uses `MultiLimbOps` for all EC -//! arithmetic with configurable limb width. +//! Used when `!curve.is_native_field()` — the curve's base field differs from +//! the R1CS native field (e.g. SECP256R1 over BN254). 
Field elements are +//! represented as multi-limb values (N limbs of `limb_bits` each), and +//! arithmetic is verified via schoolbook column equations. //! -//! Multi-point MSM uses merged doublings: all points share a single set of -//! `w` doublings per window, saving `w × (n_points - 1)` doublings per window. +//! ## Key techniques +//! +//! - **Multi-limb arithmetic**: field elements split into N limbs; add/sub use +//! carry chains with boolean quotients, mul uses schoolbook column equations, +//! all verified mod the curve's base field. +//! - **Hint-verified EC ops** (N ≥ 2): point_double, point_add, and on-curve +//! checks use prover hints verified via schoolbook column equations, avoiding +//! the step-by-step MultiLimbOps chain (which requires field inversions). +//! - **FakeGLV**: same half-GCD scalar decomposition as the native path. +//! - **Signed-digit windows**: w-bit windows produce signed odd digits d ∈ {±1, +//! ±3, ..., ±(2^w-1)}, eliminating zero-digit handling and halving the lookup +//! table to 2^(w-1) entries. Skew correction applied post-loop. +//! - **Merged doublings**: all points share w doublings per window, saving w × +//! (n_points - 1) doublings per window. +//! +//! ## Phases +//! +//! 1. **Preprocessing**: per-point sanitization, limb decomposition of point +//! coordinates, on-curve checks, FakeGLV decomposition, signed-bit +//! decomposition, y-negation via `negate_mod_p_multi`. +//! 2. **Merged scalar mul**: `scalar_mul_merged_glv` runs a single windowed +//! loop with shared doublings + per-point signed table lookups. Identity +//! check: final accumulator must equal the known offset. +//! 3. **Scalar relations**: per-point verification that (-1)^neg1·|s1| + +//! (-1)^neg2·|s2|·s ≡ 0 (mod curve_order). +//! 4. **Accumulation**: adds each point's scalar-mul result (via dispatch to +//! hint-verified or generic add), subtracts offset, recomposes limbs to +//! native field elements, constrains outputs. 
use { super::{ diff --git a/provekit/r1cs-compiler/src/msm/sanitize.rs b/provekit/r1cs-compiler/src/msm/sanitize.rs new file mode 100644 index 000000000..2ec8d4155 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/sanitize.rs @@ -0,0 +1,181 @@ +//! Degenerate-case detection, sanitization, and bit decomposition helpers +//! for MSM point-scalar pairs. + +use { + super::curve::CurveParams, + crate::{ + constraint_helpers::{compute_boolean_or, constrain_boolean, select_witness}, + msm::multi_limb_arith::compute_is_zero, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field}, + provekit_common::{ + witness::{SumTerm, WitnessBuilder}, + FieldElement, + }, +}; + +/// Detects whether a point-scalar pair is degenerate (scalar=0 or point at +/// infinity). Constrains `inf_flag` to boolean. Returns `is_skip` (1 if +/// degenerate). +fn detect_skip( + compiler: &mut NoirToR1CSCompiler, + s_lo: usize, + s_hi: usize, + inf_flag: usize, +) -> usize { + constrain_boolean(compiler, inf_flag); + let is_zero_s_lo = compute_is_zero(compiler, s_lo); + let is_zero_s_hi = compute_is_zero(compiler, s_hi); + let s_is_zero = compiler.add_product(is_zero_s_lo, is_zero_s_hi); + compute_boolean_or(compiler, s_is_zero, inf_flag) +} + +/// Sanitized point-scalar inputs after degenerate-case detection. +pub(super) struct SanitizedInputs { + pub px: usize, + pub py: usize, + pub s_lo: usize, + pub s_hi: usize, + pub is_skip: usize, +} + +/// Detects degenerate cases (scalar=0 or point at infinity) and replaces +/// the point with the generator G and scalar with 1 when degenerate. 
+pub(super) fn sanitize_point_scalar( + compiler: &mut NoirToR1CSCompiler, + px: usize, + py: usize, + s_lo: usize, + s_hi: usize, + inf_flag: usize, + gen_x: usize, + gen_y: usize, + zero: usize, + one: usize, +) -> SanitizedInputs { + let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); + SanitizedInputs { + px: select_witness(compiler, is_skip, px, gen_x), + py: select_witness(compiler, is_skip, py, gen_y), + s_lo: select_witness(compiler, is_skip, s_lo, one), + s_hi: select_witness(compiler, is_skip, s_hi, zero), + is_skip, + } +} + +/// Negate a y-coordinate and conditionally select based on a sign flag. +/// Returns `(y_eff, neg_y_eff)` where: +/// - if `neg_flag=0`: `y_eff = y`, `neg_y_eff = -y` +/// - if `neg_flag=1`: `y_eff = -y`, `neg_y_eff = y` +pub(super) fn negate_y_signed_native( + compiler: &mut NoirToR1CSCompiler, + neg_flag: usize, + y: usize, +) -> (usize, usize) { + constrain_boolean(compiler, neg_flag); + let neg_y = compiler.add_sum(vec![SumTerm(Some(-FieldElement::ONE), y)]); + let y_eff = select_witness(compiler, neg_flag, y, neg_y); + let neg_y_eff = select_witness(compiler, neg_flag, neg_y, y); + (y_eff, neg_y_eff) +} + +/// Emit an `EcScalarMulHint` and sanitize the result point. +/// When `is_skip=1`, the result is swapped to the generator point. 
+pub(super) fn emit_ec_scalar_mul_hint_and_sanitize( + compiler: &mut NoirToR1CSCompiler, + san: &SanitizedInputs, + gen_x_witness: usize, + gen_y_witness: usize, + curve: &CurveParams, +) -> (usize, usize) { + let hint_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { + output_start: hint_start, + px: san.px, + py: san.py, + s_lo: san.s_lo, + s_hi: san.s_hi, + curve_a: curve.curve_a, + field_modulus_p: curve.field_modulus_p, + }); + let rx = select_witness(compiler, san.is_skip, hint_start, gen_x_witness); + let ry = select_witness(compiler, san.is_skip, hint_start + 1, gen_y_witness); + (rx, ry) +} + +/// Allocates a FakeGLV hint and returns `(s1, s2, neg1, neg2)` witness indices. +pub(super) fn emit_fakeglv_hint( + compiler: &mut NoirToR1CSCompiler, + s_lo: usize, + s_hi: usize, + curve: &CurveParams, +) -> (usize, usize, usize, usize) { + let glv_start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::FakeGLVHint { + output_start: glv_start, + s_lo, + s_hi, + curve_order: curve.curve_order_n, + }); + (glv_start, glv_start + 1, glv_start + 2, glv_start + 3) +} + +/// Signed-bit decomposition for wNAF scalar multiplication. +/// +/// Decomposes `scalar` into `num_bits` sign-bits b_i ∈ {0,1} and a skew ∈ {0,1} +/// such that the signed digits d_i = 2*b_i - 1 ∈ {-1, +1} satisfy: +/// scalar = Σ d_i * 2^i - skew +/// +/// Reconstruction constraint (1 linear R1CS): +/// scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} +/// +/// All bits and skew are boolean-constrained. +/// +/// # Limitation +/// The prover's `SignedBitHint` solver reads the scalar as a `u128` (lower +/// 128 bits of the field element). This is correct for FakeGLV half-scalars +/// (≤128 bits for 256-bit curves) but would silently truncate if `num_bits` +/// exceeds 128. The R1CS reconstruction constraint would then fail. 
+pub(super) fn decompose_signed_bits( + compiler: &mut NoirToR1CSCompiler, + scalar: usize, + num_bits: usize, +) -> (Vec, usize) { + let start = compiler.num_witnesses(); + compiler.add_witness_builder(WitnessBuilder::SignedBitHint { + output_start: start, + scalar, + num_bits, + }); + let bits: Vec = (start..start + num_bits).collect(); + let skew = start + num_bits; + + // Boolean-constrain each bit and skew + for &b in &bits { + constrain_boolean(compiler, b); + } + constrain_boolean(compiler, skew); + + // Reconstruction: scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} + // Rearranged as: scalar + skew + (2^n - 1) - Σ b_i * 2^{i+1} = 0 + let one = compiler.witness_one(); + let two = FieldElement::from(2u64); + let constant = two.pow([num_bits as u64]) - FieldElement::ONE; + let mut b_terms: Vec<(FieldElement, usize)> = bits + .iter() + .enumerate() + .map(|(i, &b)| (-two.pow([(i + 1) as u64]), b)) + .collect(); + b_terms.push((FieldElement::ONE, scalar)); + b_terms.push((FieldElement::ONE, skew)); + b_terms.push((constant, one)); + compiler + .r1cs + .add_constraint(&[(FieldElement::ONE, one)], &b_terms, &[( + FieldElement::ZERO, + one, + )]); + + (bits, skew) +} diff --git a/provekit/r1cs-compiler/src/msm/tests.rs b/provekit/r1cs-compiler/src/msm/tests.rs new file mode 100644 index 000000000..ea326ac15 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/tests.rs @@ -0,0 +1,107 @@ +use { + super::*, crate::noir_to_r1cs::NoirToR1CSCompiler, + provekit_common::witness::ConstantOrR1CSWitness, std::collections::BTreeMap, +}; + +/// Verify that the non-native (SECP256R1) single-point MSM path generates +/// constraints without panicking. 
This does multi-limb arithmetic, +/// range checks, and FakeGLV verification — the entire non-native code path +/// that has no Noir e2e coverage for now : ) +#[test] +fn test_secp256r1_single_point_msm_generates_constraints() { + let mut compiler = NoirToR1CSCompiler::new(); + let curve = curve::secp256r1_params(); + let mut range_checks: BTreeMap> = BTreeMap::new(); + + // Allocate witness slots for: px, py, inf, s_lo, s_hi, out_x, out_y, out_inf + // (witness 0 is the constant-one witness) + let base = compiler.num_witnesses(); + compiler.r1cs.add_witnesses(8); + let px = base; + let py = base + 1; + let inf = base + 2; + let s_lo = base + 3; + let s_hi = base + 4; + let out_x = base + 5; + let out_y = base + 6; + let out_inf = base + 7; + + let points = vec![ + ConstantOrR1CSWitness::Witness(px), + ConstantOrR1CSWitness::Witness(py), + ConstantOrR1CSWitness::Witness(inf), + ]; + let scalars = vec![ + ConstantOrR1CSWitness::Witness(s_lo), + ConstantOrR1CSWitness::Witness(s_hi), + ]; + let msm_ops = vec![(points, scalars, (out_x, out_y, out_inf))]; + + add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve); + + let n_constraints = compiler.r1cs.num_constraints(); + let n_witnesses = compiler.num_witnesses(); + + assert!( + n_constraints > 100, + "expected substantial constraints for non-native MSM, got {n_constraints}" + ); + assert!( + n_witnesses > 100, + "expected substantial witnesses for non-native MSM, got {n_witnesses}" + ); + assert!( + !range_checks.is_empty(), + "non-native MSM should produce range checks" + ); +} + +/// Verify that the non-native multi-point MSM path (2 points, SECP256R1) +/// generates constraints. does the multi-point accumulation and offset +/// subtraction logic for the non-native path. 
+#[test] +fn test_secp256r1_multi_point_msm_generates_constraints() { + let mut compiler = NoirToR1CSCompiler::new(); + let curve = curve::secp256r1_params(); + let mut range_checks: BTreeMap> = BTreeMap::new(); + + // 2 points: px1, py1, inf1, px2, py2, inf2, s1_lo, s1_hi, s2_lo, s2_hi, + // out_x, out_y, out_inf + let base = compiler.num_witnesses(); + compiler.r1cs.add_witnesses(13); + + let points = vec![ + ConstantOrR1CSWitness::Witness(base), // px1 + ConstantOrR1CSWitness::Witness(base + 1), // py1 + ConstantOrR1CSWitness::Witness(base + 2), // inf1 + ConstantOrR1CSWitness::Witness(base + 3), // px2 + ConstantOrR1CSWitness::Witness(base + 4), // py2 + ConstantOrR1CSWitness::Witness(base + 5), // inf2 + ]; + let scalars = vec![ + ConstantOrR1CSWitness::Witness(base + 6), // s1_lo + ConstantOrR1CSWitness::Witness(base + 7), // s1_hi + ConstantOrR1CSWitness::Witness(base + 8), // s2_lo + ConstantOrR1CSWitness::Witness(base + 9), // s2_hi + ]; + let out_x = base + 10; + let out_y = base + 11; + let out_inf = base + 12; + + let msm_ops = vec![(points, scalars, (out_x, out_y, out_inf))]; + + add_msm_with_curve(&mut compiler, msm_ops, &mut range_checks, &curve); + + let n_constraints = compiler.r1cs.num_constraints(); + let n_witnesses = compiler.num_witnesses(); + + // Multi-point should produce more constraints than single-point + assert!( + n_constraints > 200, + "expected substantial constraints for 2-point non-native MSM, got {n_constraints}" + ); + assert!( + n_witnesses > 200, + "expected substantial witnesses for 2-point non-native MSM, got {n_witnesses}" + ); +} From cf115a0cea4a15090aaaf55ed25901f3b5f7fa85 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 13 Mar 2026 13:40:50 +0530 Subject: [PATCH 19/19] feat : bug fix and witness solving test for non native msm --- Cargo.toml | 1 + provekit/common/src/witness/mod.rs | 4 +- .../src/witness/scheduling/dependency.rs | 21 +- .../common/src/witness/scheduling/remapper.rs | 12 +- 
.../common/src/witness/witness_builder.rs | 29 +- provekit/prover/Cargo.toml | 1 + provekit/prover/src/bigint_mod.rs | 297 +++++----- provekit/prover/src/lib.rs | 4 +- .../prover/src/witness/witness_builder.rs | 263 ++++++--- provekit/r1cs-compiler/src/lib.rs | 6 +- provekit/r1cs-compiler/src/msm/cost_model.rs | 132 ++--- .../src/msm/{curve.rs => curve/mod.rs} | 290 ++-------- .../r1cs-compiler/src/msm/curve/u256_arith.rs | 176 ++++++ .../src/msm/ec_points/hints_non_native.rs | 165 ++++-- provekit/r1cs-compiler/src/msm/limbs.rs | 96 ++++ provekit/r1cs-compiler/src/msm/mod.rs | 260 +++++---- provekit/r1cs-compiler/src/msm/native.rs | 61 ++- provekit/r1cs-compiler/src/msm/non_native.rs | 291 +++++++--- provekit/r1cs-compiler/src/msm/sanitize.rs | 144 +++-- .../r1cs-compiler/src/msm/scalar_relation.rs | 7 + provekit/r1cs-compiler/src/noir_to_r1cs.rs | 6 +- provekit/r1cs-compiler/src/range_check.rs | 4 +- tooling/provekit-bench/Cargo.toml | 1 + .../tests/msm_witness_solving.rs | 512 ++++++++++++++++++ 24 files changed, 1920 insertions(+), 863 deletions(-) rename provekit/r1cs-compiler/src/msm/{curve.rs => curve/mod.rs} (74%) create mode 100644 provekit/r1cs-compiler/src/msm/curve/u256_arith.rs create mode 100644 provekit/r1cs-compiler/src/msm/limbs.rs create mode 100644 tooling/provekit-bench/tests/msm_witness_solving.rs diff --git a/Cargo.toml b/Cargo.toml index 2e1cea09e..5c2491c47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,6 +111,7 @@ chrono = "0.4.41" divan = "0.1.21" hex = "0.4.3" itertools = "0.14.0" +num-bigint = "0.4" paste = "1.0.15" postcard = { version = "1.1.1", features = ["use-std"] } primitive-types = "0.13.1" diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index e4968563c..4c9950064 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -17,7 +17,9 @@ pub use { binops::{BINOP_ATOMIC_BITS, BINOP_BITS, NUM_DIGITS}, digits::{decompose_into_digits, 
DigitalDecompositionWitnesses}, ram::{SpiceMemoryOperation, SpiceWitnesses}, - scheduling::{Layer, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders}, + scheduling::{ + Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders, + }, witness_builder::{ CombinedTableEntryInverseData, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm, WitnessBuilder, WitnessCoefficient, diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index f86ea414e..0794744ae 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -230,8 +230,17 @@ impl DependencyInfo { } WitnessBuilder::FakeGLVHint { s_lo, s_hi, .. } => vec![*s_lo, *s_hi], WitnessBuilder::EcScalarMulHint { - px, py, s_lo, s_hi, .. - } => vec![*px, *py, *s_lo, *s_hi], + px_limbs, + py_limbs, + s_lo, + s_hi, + .. + } => px_limbs + .iter() + .chain(py_limbs.iter()) + .copied() + .chain([*s_lo, *s_hi]) + .collect(), WitnessBuilder::SelectWitness { flag, on_false, @@ -368,9 +377,11 @@ impl DependencyInfo { WitnessBuilder::FakeGLVHint { output_start, .. } => { (*output_start..*output_start + 4).collect() } - WitnessBuilder::EcScalarMulHint { output_start, .. } => { - (*output_start..*output_start + 2).collect() - } + WitnessBuilder::EcScalarMulHint { + output_start, + num_limbs, + .. + } => (*output_start..*output_start + 2 * *num_limbs as usize).collect(), WitnessBuilder::MultiLimbAddQuotient { output, .. } => vec![*output], WitnessBuilder::MultiLimbSubBorrow { output, .. } => vec![*output], WitnessBuilder::U32Addition(result_idx, carry_idx, ..) 
=> { diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 5d8b17023..dd63190c2 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -429,20 +429,24 @@ impl WitnessIndexRemapper { }, WitnessBuilder::EcScalarMulHint { output_start, - px, - py, + px_limbs, + py_limbs, s_lo, s_hi, curve_a, field_modulus_p, + num_limbs, + limb_bits, } => WitnessBuilder::EcScalarMulHint { output_start: self.remap(*output_start), - px: self.remap(*px), - py: self.remap(*py), + px_limbs: px_limbs.iter().map(|&w| self.remap(w)).collect(), + py_limbs: py_limbs.iter().map(|&w| self.remap(w)).collect(), s_lo: self.remap(*s_lo), s_hi: self.remap(*s_hi), curve_a: *curve_a, field_modulus_p: *field_modulus_p, + num_limbs: *num_limbs, + limb_bits: *limb_bits, }, WitnessBuilder::SelectWitness { output, diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 7a5fe1e7a..8bcbc3d09 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -57,12 +57,12 @@ pub struct CombinedTableEntryInverseData { /// Operation type for the unified non-native EC hint. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum NonNativeEcOp { - /// Point doubling: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 12N-6 + /// Point doubling: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 15N-6 Double, /// Point addition: inputs = \[\[x1_limbs\], \[y1_limbs\], \[x2_limbs\], - /// \[y2_limbs\]\], outputs 12N-6 + /// \[y2_limbs\]\], outputs 15N-6 Add, - /// On-curve check: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 7N-4 + /// On-curve check: inputs = \[\[px_limbs\], \[py_limbs\]\], outputs 9N-4 OnCurve, } @@ -303,15 +303,20 @@ pub enum WitnessBuilder { /// computes R = \[s\]P on the curve with parameter `curve_a` and /// field modulus `field_modulus_p`. 
/// - /// Outputs 2 witnesses at output_start: R_x, R_y. + /// When `num_limbs == 1`: inputs are single witnesses, outputs 2 + /// witnesses (R_x, R_y) as native field elements. + /// When `num_limbs >= 2`: inputs are limb witnesses, outputs + /// `2 * num_limbs` witnesses (R_x limbs then R_y limbs). EcScalarMulHint { output_start: usize, - px: usize, - py: usize, + px_limbs: Vec, + py_limbs: Vec, s_lo: usize, s_hi: usize, curve_a: [u64; 4], field_modulus_p: [u64; 4], + num_limbs: u32, + limb_bits: u32, }, /// Prover hint for EC point doubling on native field. /// Given P = (px, py) and curve parameter `a`, computes: @@ -361,10 +366,10 @@ pub enum WitnessBuilder { /// Unified prover hint for non-native EC operations (multi-limb). /// /// `op` selects the operation: - /// - `Double`: inputs = \[\[px\], \[py\]\], outputs 12N-6 witnesses - /// - `Add`: inputs = \[\[x1\], \[y1\], \[x2\], \[y2\]\], outputs 12N-6 + /// - `Double`: inputs = \[\[px\], \[py\]\], outputs 15N-6 witnesses + /// - `Add`: inputs = \[\[x1\], \[y1\], \[x2\], \[y2\]\], outputs 15N-6 /// witnesses - /// - `OnCurve`: inputs = \[\[px\], \[py\]\], outputs 7N-4 witnesses + /// - `OnCurve`: inputs = \[\[px\], \[py\]\], outputs 9N-4 witnesses NonNativeEcHint { output_start: usize, op: NonNativeEcOp, @@ -457,11 +462,11 @@ impl WitnessBuilder { WitnessBuilder::EcDoubleHint { .. } => 3, WitnessBuilder::EcAddHint { .. } => 3, WitnessBuilder::NonNativeEcHint { op, num_limbs, .. } => match op { - NonNativeEcOp::Double | NonNativeEcOp::Add => (12 * *num_limbs - 6) as usize, - NonNativeEcOp::OnCurve => (7 * *num_limbs - 4) as usize, + NonNativeEcOp::Double | NonNativeEcOp::Add => (15 * *num_limbs - 6) as usize, + NonNativeEcOp::OnCurve => (9 * *num_limbs - 4) as usize, }, WitnessBuilder::FakeGLVHint { .. } => 4, - WitnessBuilder::EcScalarMulHint { .. } => 2, + WitnessBuilder::EcScalarMulHint { num_limbs, .. 
} => 2 * *num_limbs as usize, _ => 1, } diff --git a/provekit/prover/Cargo.toml b/provekit/prover/Cargo.toml index a4390623d..0b80faef9 100644 --- a/provekit/prover/Cargo.toml +++ b/provekit/prover/Cargo.toml @@ -33,6 +33,7 @@ whir.workspace = true # 3rd party anyhow.workspace = true +num-bigint.workspace = true postcard.workspace = true rand.workspace = true rayon.workspace = true diff --git a/provekit/prover/src/bigint_mod.rs b/provekit/prover/src/bigint_mod.rs index f61b97f8c..85c0d471f 100644 --- a/provekit/prover/src/bigint_mod.rs +++ b/provekit/prover/src/bigint_mod.rs @@ -3,7 +3,11 @@ /// These helpers compute modular inverse via Fermat's little theorem: /// a^{-1} = a^{m-2} mod m, using schoolbook multiplication and /// square-and-multiply exponentiation. -use {ark_ff::PrimeField, provekit_common::FieldElement}; +use { + ark_ff::PrimeField, + num_bigint::{BigInt, Sign}, + provekit_common::FieldElement, +}; /// Schoolbook multiplication: 4×4 limbs → 8 limbs (256-bit × 256-bit → /// 512-bit). @@ -61,8 +65,7 @@ fn reduce_wide(wide: &[u64; 8], modulus: &[u64; 4]) -> [u64; 4] { let mut remainder = [0u64; 4]; for bit_pos in (0..highest_bit).rev() { // Left-shift remainder by 1 - let carry = shift_left_one(&mut remainder); - debug_assert_eq!(carry, 0, "remainder overflow during shift"); + let shift_carry = shift_left_one(&mut remainder); // Bring in the next bit from wide let limb_idx = bit_pos / 64; @@ -70,9 +73,16 @@ fn reduce_wide(wide: &[u64; 8], modulus: &[u64; 4]) -> [u64; 4] { let bit = (wide[limb_idx] >> bit_idx) & 1; remainder[0] |= bit; - // If remainder >= modulus, subtract - if cmp_4limb(&remainder, modulus) != std::cmp::Ordering::Less { - sub_4limb_inplace(&mut remainder, modulus); + // If shift_carry is set, the effective remainder is 2^256 + remainder, + // which is always > any 256-bit modulus, so we must subtract. 
+ if shift_carry != 0 || cmp_4limb(&remainder, modulus) != std::cmp::Ordering::Less { + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = remainder[i].overflowing_sub(modulus[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + remainder[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } } } @@ -296,82 +306,119 @@ pub fn to_i128_limbs(limbs: &[u128]) -> Vec { limbs.iter().map(|&v| v as i128).collect() } +/// Convert a `[u64; 8]` wide value to a `BigInt`. +fn wide_to_bigint(v: &[u64; 8]) -> BigInt { + let mut bytes = [0u8; 64]; + for (i, &limb) in v.iter().enumerate() { + bytes[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes()); + } + BigInt::from_bytes_le(Sign::Plus, &bytes) +} + +/// Convert a `[u64; 4]` to a `BigInt`. +fn u256_to_bigint(v: &[u64; 4]) -> BigInt { + let mut bytes = [0u8; 32]; + for (i, &limb) in v.iter().enumerate() { + bytes[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes()); + } + BigInt::from_bytes_le(Sign::Plus, &bytes) +} + +/// Convert a non-negative `BigInt` to `u128`. Panics if negative or too large. +fn bigint_to_u128(v: &BigInt) -> u128 { + assert!(v.sign() != Sign::Minus, "bigint_to_u128: negative value"); + let (_, bytes) = v.to_bytes_le(); + assert!(bytes.len() <= 16, "bigint_to_u128: value exceeds 128 bits"); + let mut buf = [0u8; 16]; + buf[..bytes.len()].copy_from_slice(&bytes); + u128::from_le_bytes(buf) +} + +/// Convert a non-negative `BigInt` to `[u64; 4]`. Panics if it doesn't fit. 
+fn bigint_to_u256(v: &BigInt) -> [u64; 4] { + assert!(v.sign() != Sign::Minus, "bigint_to_u256: negative value"); + let (_, bytes) = v.to_bytes_le(); + let mut result = [0u64; 4]; + for (i, chunk) in bytes.chunks(8).enumerate() { + if i >= 4 { + assert!( + chunk.iter().all(|&b| b == 0), + "bigint_to_u256: value exceeds 256 bits" + ); + break; + } + let mut buf = [0u8; 8]; + buf[..chunk.len()].copy_from_slice(chunk); + result[i] = u64::from_le_bytes(buf); + } + result +} + /// Compute signed quotient q such that: -/// Σ lhs_products\[i\] * coeff_i - Σ rhs_products\[j\] * coeff_j - rhs_sub ≡ -/// 0 (mod p) Returns q as decomposed limbs, with negative q stored as -q in the -/// native field. +/// Σ lhs_products\[i\] * coeff_i + Σ lhs_linear\[j\] * coeff_j +/// - Σ rhs_products\[i\] * coeff_i - Σ rhs_linear\[j\] * coeff_j ≡ 0 (mod p) +/// +/// Returns (|q| limbs, is_negative) where q = (LHS - RHS) / p. pub fn signed_quotient_wide( lhs_products: &[(&[u64; 4], &[u64; 4], u64)], rhs_products: &[(&[u64; 4], &[u64; 4], u64)], - rhs_sub: Option<&[u64; 4]>, + lhs_linear: &[(&[u64; 4], u64)], + rhs_linear: &[(&[u64; 4], u64)], p: &[u64; 4], n: usize, w: u32, -) -> Vec { - fn accumulate_wide(terms: &[(&[u64; 4], &[u64; 4], u64)]) -> [u64; 8] { - let mut acc = [0u64; 8]; +) -> (Vec, bool) { + fn accumulate_wide_products(terms: &[(&[u64; 4], &[u64; 4], u64)]) -> BigInt { + let mut acc = BigInt::from(0); for &(a, b, coeff) in terms { let prod = widening_mul(a, b); - let mut carry = 0u128; - for i in 0..8 { - let v = acc[i] as u128 + (prod[i] as u128) * (coeff as u128) + carry; - acc[i] = v as u64; - carry = v >> 64; - } + acc += wide_to_bigint(&prod) * BigInt::from(coeff); } acc } - let lhs_wide = accumulate_wide(lhs_products); - let mut rhs_wide = accumulate_wide(rhs_products); - if let Some(sub) = rhs_sub { - let mut carry = 0u128; - for i in 0..4 { - let v = rhs_wide[i] as u128 + sub[i] as u128 + carry; - rhs_wide[i] = v as u64; - carry = v >> 64; - } - for i in 4..8 { - let v 
= rhs_wide[i] as u128 + carry; - rhs_wide[i] = v as u64; - carry = v >> 64; + fn accumulate_wide_linear(terms: &[(&[u64; 4], u64)]) -> BigInt { + let mut acc = BigInt::from(0); + for &(val, coeff) in terms { + acc += u256_to_bigint(val) * BigInt::from(coeff); } + acc } - let lhs_ge = { - let mut ge = false; - for i in (0..8).rev() { - if lhs_wide[i] > rhs_wide[i] { - ge = true; - break; - } else if lhs_wide[i] < rhs_wide[i] { - break; - } - if i == 0 { - ge = true; - } - } - ge - }; - let (big, small) = if lhs_ge { - (lhs_wide, rhs_wide) - } else { - (rhs_wide, lhs_wide) - }; - let mut diff = [0u64; 8]; - let mut bw = 0u64; - for i in 0..8 { - let (d1, b1) = big[i].overflowing_sub(small[i]); - let (d2, b2) = d1.overflowing_sub(bw); - diff[i] = d2; - bw = b1 as u64 + b2 as u64; - } - let (q_abs, _) = divmod_wide(&diff, p); - if lhs_ge { - decompose_to_u128_limbs(&q_abs, n, w) - } else { - decompose_to_u128_limbs(&fe_to_bigint(-bigint_to_fe(&q_abs)), n, w) - } + let lhs = accumulate_wide_products(lhs_products) + accumulate_wide_linear(lhs_linear); + let rhs = accumulate_wide_products(rhs_products) + accumulate_wide_linear(rhs_linear); + + let diff = lhs - rhs; + let p_big = u256_to_bigint(p); + + let q_big = &diff / &p_big; + let rem = &diff - &q_big * &p_big; + debug_assert_eq!( + rem, + BigInt::from(0), + "signed_quotient_wide: non-zero remainder" + ); + + let is_neg = q_big.sign() == Sign::Minus; + let q_abs_big = if is_neg { -&q_big } else { q_big }; + + // Decompose directly from BigInt into u128 limbs at `w` bits each, + // since the quotient may exceed 256 bits. 
+ let limb_mask = (BigInt::from(1u64) << w) - 1; + let mut limbs = Vec::with_capacity(n); + let mut remaining = q_abs_big; + for _ in 0..n { + let limb_val = &remaining & &limb_mask; + limbs.push(bigint_to_u128(&limb_val)); + remaining >>= w; + } + debug_assert_eq!( + remaining, + BigInt::from(0), + "quotient doesn't fit in {n} limbs at {w} bits" + ); + + (limbs, is_neg) } /// Reconstruct a 256-bit value from u128 limb values packed at `limb_bits` @@ -413,62 +460,44 @@ pub fn compute_mul_mod_carries( let n = a_limbs.len(); let w = limb_bits; let num_carries = 2 * n - 2; - let carry_offset = 1u128 << (w + ((n as f64).log2().ceil() as u32) + 1); + let carry_offset = BigInt::from(1u64) << (w + ((n as f64).log2().ceil() as u32) + 1); let mut carries = Vec::with_capacity(num_carries); - let mut carry: i128 = 0; + let mut carry = BigInt::from(0); for k in 0..(2 * n - 1) { - let mut ab_lo: u128 = 0; - let mut ab_hi: u64 = 0; + let mut col_value = BigInt::from(0); + + // a*b products for i in 0..n { let j = k as isize - i as isize; if j >= 0 && (j as usize) < n { - let prod = a_limbs[i] * b_limbs[j as usize]; - let (new_lo, ov) = ab_lo.overflowing_add(prod); - ab_lo = new_lo; - if ov { - ab_hi += 1; - } + col_value += BigInt::from(a_limbs[i]) * BigInt::from(b_limbs[j as usize]); } } - let mut pq_lo: u128 = 0; - let mut pq_hi: u64 = 0; + + // Subtract p*q + r for i in 0..n { let j = k as isize - i as isize; if j >= 0 && (j as usize) < n { - let prod = p_limbs[i] * q_limbs[j as usize]; - let (new_lo, ov) = pq_lo.overflowing_add(prod); - pq_lo = new_lo; - if ov { - pq_hi += 1; - } + col_value -= BigInt::from(p_limbs[i]) * BigInt::from(q_limbs[j as usize]); } } if k < n { - let (new_lo, ov) = pq_lo.overflowing_add(r_limbs[k]); - pq_lo = new_lo; - if ov { - pq_hi += 1; - } + col_value -= BigInt::from(r_limbs[k]); } - let diff_lo = ab_lo.wrapping_sub(pq_lo); - let borrow = if ab_lo < pq_lo { 1i64 } else { 0 }; - let diff_hi = ab_hi as i64 - pq_hi as i64 - borrow; - - let 
carry_lo = carry as u128; - let carry_hi: i64 = if carry < 0 { -1 } else { 0 }; - let (total_lo, ov) = diff_lo.overflowing_add(carry_lo); - let total_hi = diff_hi + carry_hi + if ov { 1i64 } else { 0 }; + col_value += &carry; if k < 2 * n - 2 { + let mask = (BigInt::from(1u64) << w) - 1; debug_assert_eq!( - total_lo & ((1u128 << w) - 1), - 0, + &col_value & &mask, + BigInt::from(0), "non-zero remainder at column {k}" ); - carry = total_hi as i128 * (1i128 << (128 - w)) + (total_lo >> w) as i128; - carries.push((carry + carry_offset as i128) as u128); + carry = &col_value >> w; + let stored = &carry + &carry_offset; + carries.push(bigint_to_u128(&stored)); } } @@ -640,9 +669,15 @@ pub fn mod_add(a: &[u64; 4], b: &[u64; 4], p: &[u64; 4]) -> [u64; 4] { let sum = add_4limb(a, b); let sum4 = [sum[0], sum[1], sum[2], sum[3]]; if sum[4] > 0 || cmp_4limb(&sum4, p) != std::cmp::Ordering::Less { - // sum >= p, subtract p + // sum >= p, subtract p (borrow absorbs carry bit if sum[4] > 0) let mut result = sum4; - sub_4limb_inplace(&mut result, p); + let mut borrow = 0u64; + for i in 0..4 { + let (d1, b1) = result[i].overflowing_sub(p[i]); + let (d2, b2) = d1.overflowing_sub(borrow); + result[i] = d2; + borrow = (b1 as u64) + (b2 as u64); + } result } else { sum4 @@ -851,47 +886,42 @@ pub fn divmod_wide(dividend: &[u64; 8], divisor: &[u64; 4]) -> ([u64; 4], [u64; /// Each `linear_set` entry is (limb_values, coefficient) for non-product terms: /// LHS_terms += Σ coeff * val\[k\] (for k < val.len()) /// -/// The equation verified is: LHS = Σ p\[i\]*q\[j\] + carry_chain -/// (no separate result — the "result" is encoded in the linear terms). +/// The equation verified is: +/// LHS + Σ p\[i\]*q_neg\[j\] = RHS + Σ p\[i\]*q_pos\[j\] + carry_chain +/// +/// `q_pos_limbs` and `q_neg_limbs` are both non-negative; at most one is +/// non-zero. 
pub fn compute_ec_verification_carries( product_sets: &[(&[u128], &[u128], i64)], linear_terms: &[(Vec, i64)], // (limb_values extended to 2N-1, coefficient) p_limbs: &[u128], - q_limbs: &[u128], + q_pos_limbs: &[u128], + q_neg_limbs: &[u128], n: usize, limb_bits: u32, + max_coeff_sum: u64, ) -> Vec { let w = limb_bits; let num_columns = 2 * n - 1; let num_carries = num_columns - 1; - // Use a larger offset to account for merged terms. - // Max terms per column: sum of coefficients × N products + linear terms. - let max_coeff_sum: u64 = product_sets - .iter() - .map(|(_, _, c)| c.unsigned_abs() as u64) - .sum::() - + linear_terms - .iter() - .map(|(_, c)| c.unsigned_abs() as u64) - .sum::() - + n as u64; // p*q terms let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; let carry_offset_bits = w + extra_bits; - let carry_offset = 1u128 << carry_offset_bits; + let carry_offset = BigInt::from(1u64) << carry_offset_bits; let mut carries = Vec::with_capacity(num_carries); - let mut carry: i128 = 0; + let mut carry = BigInt::from(0); for k in 0..num_columns { - let mut col_value: i128 = 0; + let mut col_value = BigInt::from(0); // Product terms for &(a, b, coeff) in product_sets { for i in 0..n { let j = k as isize - i as isize; if j >= 0 && (j as usize) < n { - col_value += coeff as i128 * (a[i] as i128) * (b[j as usize] as i128); + col_value += + BigInt::from(coeff) * BigInt::from(a[i]) * BigInt::from(b[j as usize]); } } } @@ -899,30 +929,37 @@ pub fn compute_ec_verification_carries( // Linear terms for (vals, coeff) in linear_terms { if k < vals.len() { - col_value += *coeff as i128 * vals[k]; + col_value += BigInt::from(*coeff) * BigInt::from(vals[k]); } } - // Subtract p*q contribution + // p*q_neg on positive side, p*q_pos on negative side for i in 0..n { let j = k as isize - i as isize; if j >= 0 && (j as usize) < n { - col_value -= (p_limbs[i] as i128) * (q_limbs[j as usize] as i128); + col_value += BigInt::from(p_limbs[i]) * 
BigInt::from(q_neg_limbs[j as usize]); + col_value -= BigInt::from(p_limbs[i]) * BigInt::from(q_pos_limbs[j as usize]); } } - col_value += carry; + col_value += &carry; if k < num_carries { + let mask = (BigInt::from(1u64) << w) - 1; debug_assert_eq!( - col_value & ((1i128 << w) - 1), - 0, + &col_value & &mask, + BigInt::from(0), "non-zero remainder at column {k}: col_value={col_value}" ); - carry = col_value >> w; - carries.push((carry + carry_offset as i128) as u128); + carry = &col_value >> w; + let stored = &carry + &carry_offset; + carries.push(bigint_to_u128(&stored)); } else { - debug_assert_eq!(col_value, 0, "non-zero final column value: {col_value}"); + debug_assert_eq!( + col_value, + BigInt::from(0), + "non-zero final column value: {col_value}" + ); } } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 355642b5d..de84a360c 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -21,9 +21,9 @@ use { whir::transcript::{codecs::Empty, ProverState, VerifierMessage}, }; -pub(crate) mod bigint_mod; +pub mod bigint_mod; pub mod input_utils; -mod r1cs; +pub mod r1cs; mod whir_r1cs; mod witness; diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index d663728f9..a895b1c80 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -4,9 +4,8 @@ use { add_4limb, bigint_to_fe, cmp_4limb, compute_ec_verification_carries, compute_mul_mod_carries, decompose_to_u128_limbs, divmod, divmod_wide, ec_point_add_with_lambda, ec_point_double_with_lambda, ec_scalar_mul, fe_to_bigint, - half_gcd, mod_add, mod_pow, mod_sub, mul_mod, reconstruct_from_halves, - reconstruct_from_u128_limbs, signed_quotient_wide, sub_u64, to_i128_limbs, - widening_mul, + half_gcd, mod_pow, mul_mod, reconstruct_from_halves, reconstruct_from_u128_limbs, + signed_quotient_wide, sub_u64, to_i128_limbs, widening_mul, }, 
witness::{digits::DigitalDecompositionWitnessesSolver, ram::SpiceWitnessesSolver}, }, @@ -508,6 +507,19 @@ impl WitnessBuilderSolver for WitnessBuilder { let p_l = decompose_to_u128_limbs(field_modulus_p, n, w); + // Helper to split signed quotient into (q_pos, q_neg). + fn split_quotient( + q_abs: Vec, + is_neg: bool, + n: usize, + ) -> (Vec, Vec) { + if is_neg { + (vec![0u128; n], q_abs) + } else { + (q_abs, vec![0u128; n]) + } + } + match op { NonNativeEcOp::Double => { let px_val = read_witness_limbs(witness, &inputs[0], w); @@ -524,72 +536,91 @@ impl WitnessBuilderSolver for WitnessBuilder { write_limbs(witness, os + n, &xl); write_limbs(witness, os + 2 * n, &yl); + // Per-equation max_coeff_sum must match compiler + let mcs_eq1 = 6 + 2 * n as u64; // 2+3+1+2n + let mcs_eq2 = 4 + 2 * n as u64; // 1+1+2+2n + let mcs_eq3 = 4 + 2 * n as u64; // 1+1+1+1+2n + + // Layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 + // Eq1: 2*λ*py - 3*px² - a = q1*p - let q1 = signed_quotient_wide( + let (q1_abs, q1_neg) = signed_quotient_wide( &[(&lam, &py_val, 2)], &[(&px_val, &px_val, 3)], - Some(curve_a), + &[], + &[(curve_a, 1)], field_modulus_p, n, w, ); - write_limbs(witness, os + 3 * n, &q1); + let (q1_pos, q1_neg) = split_quotient(q1_abs, q1_neg, n); + write_limbs(witness, os + 3 * n, &q1_pos); + write_limbs(witness, os + 4 * n, &q1_neg); let c1 = compute_ec_verification_carries( &[(&ll, &pyl, 2), (&pl, &pl, -3)], &[(to_i128_limbs(&a_l), -1)], &p_l, - &q1, + &q1_pos, + &q1_neg, n, w, + mcs_eq1, ); - write_limbs(witness, os + 4 * n, &c1); + write_limbs(witness, os + 5 * n, &c1); // Eq2: λ² - x3 - 2*px = q2*p - let two_px = mod_add(&px_val, &px_val, field_modulus_p); - let rhs2 = mod_add(&x3v, &two_px, field_modulus_p); - let q2 = signed_quotient_wide( + let (q2_abs, q2_neg) = signed_quotient_wide( &[(&lam, &lam, 1)], &[], - Some(&rhs2), + &[], + &[(&x3v, 1), 
(&px_val, 2)], field_modulus_p, n, w, ); - write_limbs(witness, os + 6 * n - 2, &q2); + let (q2_pos, q2_neg) = split_quotient(q2_abs, q2_neg, n); + write_limbs(witness, os + 7 * n - 2, &q2_pos); + write_limbs(witness, os + 8 * n - 2, &q2_neg); let c2 = compute_ec_verification_carries( &[(&ll, &ll, 1)], - &[ - (to_i128_limbs(&xl), -1), - (pl.iter().map(|&v| 2 * v as i128).collect(), -1), - ], + &[(to_i128_limbs(&xl), -1), (to_i128_limbs(&pl), -2)], &p_l, - &q2, + &q2_pos, + &q2_neg, n, w, + mcs_eq2, ); - write_limbs(witness, os + 7 * n - 2, &c2); + write_limbs(witness, os + 9 * n - 2, &c2); - // Eq3: λ*(px-x3) - y3 - py = q3*p - let dx = mod_sub(&px_val, &x3v, field_modulus_p); - let rhs3 = mod_add(&y3v, &py_val, field_modulus_p); - let q3 = signed_quotient_wide( - &[(&lam, &dx, 1)], + // Eq3: λ*px - λ*x3 - y3 - py = q3*p + let (q3_abs, q3_neg) = signed_quotient_wide( + &[(&lam, &px_val, 1)], + &[(&lam, &x3v, 1)], &[], - Some(&rhs3), + &[(&y3v, 1), (&py_val, 1)], field_modulus_p, n, w, ); - write_limbs(witness, os + 9 * n - 4, &q3); + let (q3_pos, q3_neg) = split_quotient(q3_abs, q3_neg, n); + write_limbs(witness, os + 11 * n - 4, &q3_pos); + write_limbs(witness, os + 12 * n - 4, &q3_neg); let c3 = compute_ec_verification_carries( &[(&ll, &pl, 1), (&ll, &xl, -1)], &[(to_i128_limbs(&yl), -1), (to_i128_limbs(&pyl), -1)], &p_l, - &q3, + &q3_pos, + &q3_neg, n, w, + mcs_eq3, ); - write_limbs(witness, os + 10 * n - 4, &c3); + write_limbs(witness, os + 13 * n - 4, &c3); } NonNativeEcOp::Add => { let x1v = read_witness_limbs(witness, &inputs[0], w); @@ -609,40 +640,53 @@ impl WitnessBuilderSolver for WitnessBuilder { write_limbs(witness, os + n, &xl); write_limbs(witness, os + 2 * n, &yl); - // Eq1: λ*(x2-x1) - (y2-y1) = q1*p - let dx = mod_sub(&x2v, &x1v, field_modulus_p); - let dy = mod_sub(&y2v, &y1v, field_modulus_p); - let q1 = signed_quotient_wide( - &[(&lam, &dx, 1)], - &[], - Some(&dy), + // Must match compiler's max_coeff_sum: 1+1+1+1 + 2*n + let mcs = 4 + 2 * n as 
u64; + + // Layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 + + // Eq1: λ*x2 - λ*x1 + y1 - y2 = q1*p + let (q1_abs, q1_neg) = signed_quotient_wide( + &[(&lam, &x2v, 1)], + &[(&lam, &x1v, 1)], + &[(&y1v, 1)], + &[(&y2v, 1)], field_modulus_p, n, w, ); - write_limbs(witness, os + 3 * n, &q1); + let (q1_pos, q1_neg) = split_quotient(q1_abs, q1_neg, n); + write_limbs(witness, os + 3 * n, &q1_pos); + write_limbs(witness, os + 4 * n, &q1_neg); let c1 = compute_ec_verification_carries( &[(&ll, &x2l, 1), (&ll, &x1l, -1)], &[(to_i128_limbs(&y2l), -1), (to_i128_limbs(&y1l), 1)], &p_l, - &q1, + &q1_pos, + &q1_neg, n, w, + mcs, ); - write_limbs(witness, os + 4 * n, &c1); + write_limbs(witness, os + 5 * n, &c1); // Eq2: λ² - x3 - x1 - x2 = q2*p - let sum_x = - mod_add(&x3v, &mod_add(&x1v, &x2v, field_modulus_p), field_modulus_p); - let q2 = signed_quotient_wide( + let (q2_abs, q2_neg) = signed_quotient_wide( &[(&lam, &lam, 1)], &[], - Some(&sum_x), + &[], + &[(&x3v, 1), (&x1v, 1), (&x2v, 1)], field_modulus_p, n, w, ); - write_limbs(witness, os + 6 * n - 2, &q2); + let (q2_pos, q2_neg) = split_quotient(q2_abs, q2_neg, n); + write_limbs(witness, os + 7 * n - 2, &q2_pos); + write_limbs(witness, os + 8 * n - 2, &q2_neg); let c2 = compute_ec_verification_carries( &[(&ll, &ll, 1)], &[ @@ -651,33 +695,38 @@ impl WitnessBuilderSolver for WitnessBuilder { (to_i128_limbs(&x2l), -1), ], &p_l, - &q2, + &q2_pos, + &q2_neg, n, w, + mcs, ); - write_limbs(witness, os + 7 * n - 2, &c2); + write_limbs(witness, os + 9 * n - 2, &c2); - // Eq3: λ*(x1-x3) - y3 - y1 = q3*p - let dx3 = mod_sub(&x1v, &x3v, field_modulus_p); - let rhs3 = mod_add(&y3v, &y1v, field_modulus_p); - let q3 = signed_quotient_wide( - &[(&lam, &dx3, 1)], + // Eq3: λ*x1 - λ*x3 - y3 - y1 = q3*p + let (q3_abs, q3_neg) = signed_quotient_wide( + &[(&lam, &x1v, 1)], + &[(&lam, &x3v, 1)], &[], - Some(&rhs3), + &[(&y3v, 1), 
(&y1v, 1)], field_modulus_p, n, w, ); - write_limbs(witness, os + 9 * n - 4, &q3); + let (q3_pos, q3_neg) = split_quotient(q3_abs, q3_neg, n); + write_limbs(witness, os + 11 * n - 4, &q3_pos); + write_limbs(witness, os + 12 * n - 4, &q3_neg); let c3 = compute_ec_verification_carries( &[(&ll, &x1l, 1), (&ll, &xl, -1)], &[(to_i128_limbs(&yl), -1), (to_i128_limbs(&y1l), -1)], &p_l, - &q3, + &q3_pos, + &q3_neg, n, w, + mcs, ); - write_limbs(witness, os + 10 * n - 4, &c3); + write_limbs(witness, os + 13 * n - 4, &c3); } NonNativeEcOp::OnCurve => { let px_val = read_witness_limbs(witness, &inputs[0], w); @@ -688,47 +737,68 @@ impl WitnessBuilderSolver for WitnessBuilder { let pyl = decompose_to_u128_limbs(&py_val, n, w); write_limbs(witness, os, &xsl); - // Eq1: px·px - x_sq = q1·p (always non-negative quotient) - let q1 = signed_quotient_wide( + let a_is_zero = curve_a.iter().all(|&v| v == 0); + // Per-equation max_coeff_sum must match compiler + let mcs_eq1: u64 = 2 + 2 * n as u64; // 1+1+2n + let mcs_eq2: u64 = if a_is_zero { + 3 + 2 * n as u64 // 1+1+1+2n + } else { + 4 + 2 * n as u64 // 1+1+1+1+2n + }; + + // Layout: [x_sq(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2)] + // Total: 9N-4 + + // Eq1: px·px - x_sq = q1·p + let (q1_abs, q1_neg) = signed_quotient_wide( &[(&px_val, &px_val, 1)], &[], - Some(&x_sq_val), + &[], + &[(&x_sq_val, 1)], field_modulus_p, n, w, ); - write_limbs(witness, os + n, &q1); + let (q1_pos, q1_neg) = split_quotient(q1_abs, q1_neg, n); + write_limbs(witness, os + n, &q1_pos); + write_limbs(witness, os + 2 * n, &q1_neg); let c1 = compute_ec_verification_carries( &[(&pl, &pl, 1)], &[(to_i128_limbs(&xsl), -1)], &p_l, - &q1, + &q1_pos, + &q1_neg, n, w, + mcs_eq1, ); - write_limbs(witness, os + 2 * n, &c1); + write_limbs(witness, os + 3 * n, &c1); // Eq2: py·py - x_sq·px - a·px - b = q2·p - let x_sq_px = mul_mod(&x_sq_val, &px_val, field_modulus_p); - let a_px = mul_mod(curve_a, &px_val, field_modulus_p); - let 
rhs_val = mod_add( - &mod_add(&x_sq_px, &a_px, field_modulus_p), - curve_b, - field_modulus_p, - ); - let q2 = signed_quotient_wide( + let a_l = decompose_to_u128_limbs(curve_a, n, w); + let b_l = decompose_to_u128_limbs(curve_b, n, w); + + let mut rhs_prods: Vec<(&[u64; 4], &[u64; 4], u64)> = + vec![(&x_sq_val, &px_val, 1)]; + if !a_is_zero { + rhs_prods.push((curve_a, &px_val, 1)); + } + let (q2_abs, q2_neg) = signed_quotient_wide( &[(&py_val, &py_val, 1)], + &rhs_prods, &[], - Some(&rhs_val), + &[(curve_b, 1)], field_modulus_p, n, w, ); - write_limbs(witness, os + 4 * n - 2, &q2); - let a_l = decompose_to_u128_limbs(curve_a, n, w); - let b_l = decompose_to_u128_limbs(curve_b, n, w); - let a_is_zero = curve_a.iter().all(|&v| v == 0); + let (q2_pos, q2_neg) = split_quotient(q2_abs, q2_neg, n); + write_limbs(witness, os + 5 * n - 2, &q2_pos); + write_limbs(witness, os + 6 * n - 2, &q2_neg); + let mut prod_sets: Vec<(&[u128], &[u128], i64)> = vec![(&pyl, &pyl, 1), (&xsl, &pl, -1)]; if !a_is_zero { @@ -738,34 +808,55 @@ impl WitnessBuilderSolver for WitnessBuilder { &prod_sets, &[(to_i128_limbs(&b_l), -1)], &p_l, - &q2, + &q2_pos, + &q2_neg, n, w, + mcs_eq2, ); - write_limbs(witness, os + 5 * n - 2, &c2); + write_limbs(witness, os + 7 * n - 2, &c2); } } } WitnessBuilder::EcScalarMulHint { output_start, - px, - py, + px_limbs, + py_limbs, s_lo, s_hi, curve_a, field_modulus_p, + num_limbs, + limb_bits, } => { + let n = *num_limbs as usize; let scalar = reconstruct_from_halves( &fe_to_bigint(witness[*s_lo].unwrap()), &fe_to_bigint(witness[*s_hi].unwrap()), ); - let px_val = fe_to_bigint(witness[*px].unwrap()); - let py_val = fe_to_bigint(witness[*py].unwrap()); + + let px_val = if n == 1 { + fe_to_bigint(witness[px_limbs[0]].unwrap()) + } else { + read_witness_limbs(witness, px_limbs, *limb_bits) + }; + let py_val = if n == 1 { + fe_to_bigint(witness[py_limbs[0]].unwrap()) + } else { + read_witness_limbs(witness, py_limbs, *limb_bits) + }; let (rx, ry) = 
ec_scalar_mul(&px_val, &py_val, &scalar, curve_a, field_modulus_p); - witness[*output_start] = Some(bigint_to_fe(&rx)); - witness[*output_start + 1] = Some(bigint_to_fe(&ry)); + if n == 1 { + witness[*output_start] = Some(bigint_to_fe(&rx)); + witness[*output_start + 1] = Some(bigint_to_fe(&ry)); + } else { + let rx_limbs = decompose_to_u128_limbs(&rx, n, *limb_bits); + let ry_limbs = decompose_to_u128_limbs(&ry, n, *limb_bits); + write_limbs(witness, *output_start, &rx_limbs); + write_limbs(witness, *output_start + n, &ry_limbs); + } } WitnessBuilder::SelectWitness { output, @@ -797,7 +888,15 @@ impl WitnessBuilderSolver for WitnessBuilder { let n = *num_bits; let skew: u128 = if s_val & 1 == 0 { 1 } else { 0 }; let s_adj = s_val + skew; - let t = (s_adj + ((1u128 << n) - 1)) / 2; + // t = (s_adj + 2^n - 1) / 2 + // Both s_adj and 2^n-1 are odd, so sum is even. + // To avoid u128 overflow when n >= 128, rewrite as: + // t = (s_adj - 1) / 2 + (2^n - 1 + 1) / 2 = (s_adj - 1) / 2 + 2^(n-1) + let t = if n == 0 { + s_adj / 2 + } else { + (s_adj - 1) / 2 + (1u128 << (n - 1)) + }; for i in 0..n { witness[*output_start + i] = Some(FieldElement::from(((t >> i) & 1) as u64)); } diff --git a/provekit/r1cs-compiler/src/lib.rs b/provekit/r1cs-compiler/src/lib.rs index 64e6eb46d..0b9890a8a 100644 --- a/provekit/r1cs-compiler/src/lib.rs +++ b/provekit/r1cs-compiler/src/lib.rs @@ -2,11 +2,11 @@ mod binops; mod constraint_helpers; mod digits; mod memory; -mod msm; +pub mod msm; mod noir_proof_scheme; -mod noir_to_r1cs; +pub mod noir_to_r1cs; mod poseidon2; -mod range_check; +pub mod range_check; mod sha256_compression; mod spread; mod uints; diff --git a/provekit/r1cs-compiler/src/msm/cost_model.rs b/provekit/r1cs-compiler/src/msm/cost_model.rs index be924ecc5..de5edda41 100644 --- a/provekit/r1cs-compiler/src/msm/cost_model.rs +++ b/provekit/r1cs-compiler/src/msm/cost_model.rs @@ -126,37 +126,37 @@ struct HintVerifiedEcCost { } impl HintVerifiedEcCost { - /// point_double: 
(12N-6)W hint + 5N² products + N constants + 3×3N ltp + /// point_double: (15N-6)W hint + 5N² products + N constants + 3×3N ltp fn point_double(n: usize, limb_bits: u32) -> Self { - let wit = (12 * n - 6) + 5 * n * n + n + 9 * n; + let wit = (15 * n - 6) + 5 * n * n + n + 9 * n; Self { witnesses: wit, - rc_limb: 6 * n + 6 * n, // 6N hint limbs + 3×2N ltp limbs + rc_limb: 9 * n + 6 * n, // 9N hint limbs (3+6 q_pos/q_neg) + 3×2N ltp limbs rc_carry: 3 * (2 * n - 2), // 3 equations × (2N-2) carries - carry_bits: hint_carry_bits(limb_bits, 6 + n as u64, n), + carry_bits: hint_carry_bits(limb_bits, 6 + 2 * n as u64, n), } } - /// point_add: (12N-6)W hint + 4N² products + 3×3N ltp + /// point_add: (15N-6)W hint + 4N² products + 3×3N ltp fn point_add(n: usize, limb_bits: u32) -> Self { - let wit = (12 * n - 6) + 4 * n * n + 9 * n; + let wit = (15 * n - 6) + 4 * n * n + 9 * n; Self { witnesses: wit, - rc_limb: 6 * n + 6 * n, + rc_limb: 9 * n + 6 * n, rc_carry: 3 * (2 * n - 2), - carry_bits: hint_carry_bits(limb_bits, 4 + n as u64, n), + carry_bits: hint_carry_bits(limb_bits, 4 + 2 * n as u64, n), } } - /// on_curve (worst case, a != 0): (7N-4)W hint + 4N² products + 2N + /// on_curve (worst case, a != 0): (9N-4)W hint + 4N² products + 2N /// constants + 3N ltp fn on_curve(n: usize, limb_bits: u32) -> Self { - let wit = (7 * n - 4) + 4 * n * n + 2 * n + 3 * n; + let wit = (9 * n - 4) + 4 * n * n + 2 * n + 3 * n; Self { witnesses: wit, - rc_limb: 3 * n + 2 * n, // 3N hint limbs + 2N ltp limbs + rc_limb: 5 * n + 2 * n, // 5N hint limbs (1+4 q_pos/q_neg) + 2N ltp limbs rc_carry: 2 * (2 * n - 2), // 2 equations × (2N-2) carries - carry_bits: hint_carry_bits(limb_bits, 5 + n as u64, n), + carry_bits: hint_carry_bits(limb_bits, 5 + 2 * n as u64, n), } } @@ -170,7 +170,7 @@ impl HintVerifiedEcCost { /// Carry range check bits for hint-verified EC column equations. 
fn hint_carry_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; - limb_bits + extra_bits + limb_bits + extra_bits + 1 } // --------------------------------------------------------------------------- @@ -185,6 +185,7 @@ fn hint_carry_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { /// - Half-scalar decomposition (DD digits for s1, s2) /// - One mul + one add + one sub for sign handling /// - XOR witnesses (2) + select (num_limbs) +/// - s2 non-zero check: compute_is_zero(3W) + constrain_zero fn scalar_relation_cost( native_field_bits: u32, scalar_bits: usize, @@ -201,7 +202,8 @@ fn scalar_relation_cost( + 2 * n + field_op_witnesses(1, 1, 1, 0, n, false) + 2 - + n; + + n + + 3; // compute_is_zero(s2): inv + product + is_zero // Only n limbs worth of scalar DD digits get range checks; unused digits // are zero-constrained instead (soundness fix for small curves). @@ -306,12 +308,10 @@ fn calculate_msm_witness_cost_hint_verified( // negate_mod_p_multi: 3N witnesses, N range checks (no less_than_p) let negate_wit = 3 * n; - // --- Shared costs (one doubling chain for all points) --- - let shared_doubles = num_windows * w; - let shared_ec_wit = shared_doubles * ec_double.witnesses; - let shared_offset_constants = 2 * n; - - // --- Per-point EC witnesses --- + // --- Per-point EC witnesses (each point has its own doubling chain) --- + let pp_doubles = num_windows * w; // per-point doublings (no longer shared) + let pp_doubles_ec = pp_doubles * ec_double.witnesses; + let pp_offset_constants = 2 * n; // offset limbs per point let pp_table_ec = 2 * (tbl_d * ec_double.witnesses + tbl_a * ec_add.witnesses); let pp_loop_ec = num_windows * 2 * ec_add.witnesses; let pp_skew_ec = 2 * ec_add.witnesses; @@ -323,7 +323,9 @@ fn calculate_msm_witness_cost_hint_verified( let pp_table_selects = num_windows * 2 * half_table_size.saturating_sub(1) * 2 * n; let pp_xor = num_windows * 2 * 2 * 
w.saturating_sub(1); - let per_point = pp_table_ec + let per_point = pp_doubles_ec + + pp_offset_constants + + pp_table_ec + pp_loop_ec + pp_skew_ec + pp_oncurve @@ -336,7 +338,7 @@ fn calculate_msm_witness_cost_hint_verified( + per_point_overhead(half_bits, n, sr_witnesses); // --- Shared constants --- - let shared_constants = 3 + shared_offset_constants; // gen_x, gen_y, zero + offset + let shared_constants = 3; // gen_x, gen_y, zero // --- Point accumulation --- let accum = n_points * (ec_add.witnesses + 2 * n) // per-point add + skip select @@ -347,11 +349,8 @@ fn calculate_msm_witness_cost_hint_verified( // --- Range checks --- let mut rc_map: BTreeMap = BTreeMap::new(); - // Shared doublings - ec_double.add_range_checks(shared_doubles, limb_bits, &mut rc_map); - - // Per-point: table doubles + table/loop/skew adds + on-curve - let pp_doubles_count = 2 * tbl_d; + // Per-point: loop doublings + table doubles + table/loop/skew adds + on-curve + let pp_doubles_count = pp_doubles + 2 * tbl_d; let pp_adds_count = 2 * tbl_a + num_windows * 2 + 2; ec_double.add_range_checks(n_points * pp_doubles_count, limb_bits, &mut rc_map); ec_add.add_range_checks(n_points * pp_adds_count, limb_bits, &mut rc_map); @@ -373,7 +372,7 @@ fn calculate_msm_witness_cost_hint_verified( } let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); - shared_ec_wit + shared_constants + n_points * per_point + accum + range_check_cost + shared_constants + n_points * per_point + accum + range_check_cost } // --------------------------------------------------------------------------- @@ -381,7 +380,7 @@ fn calculate_msm_witness_cost_hint_verified( // --------------------------------------------------------------------------- /// Generic (single-limb) non-native MSM cost using MultiLimbOps field op -/// chains. +/// chains. Uses per-point accumulators (no shared doublings). 
#[allow(clippy::too_many_arguments)] fn calculate_msm_witness_cost_generic( native_field_bits: u32, @@ -398,29 +397,22 @@ fn calculate_msm_witness_cost_generic( // point_double: (5 add, 3 sub, 4 mul, 1 inv) // point_add: (1 add, 5 sub, 3 mul, 1 inv) let (tbl_d, tbl_a) = table_build_ops(half_table_size); - let shared_doubles = num_windows * w; - - // --- Shared doubling field ops --- - let shared_add = shared_doubles * 5; - let shared_sub = shared_doubles * 3; - let shared_mul = shared_doubles * 4; - let shared_inv = shared_doubles; - - // --- Per-point field ops: tables + loop adds + skew + on-curve + y-negate --- - let mut pp_add = 2 * (tbl_d * 5 + tbl_a) + num_windows * 2 + 2 + 4; - let mut pp_sub = 2 * (tbl_d * 3 + tbl_a * 5) + num_windows * (2 * 5 + 2) + 2 * 6 + 2; - let mut pp_mul = 2 * (tbl_d * 4 + tbl_a * 3) + num_windows * 2 * 3 + 2 * 3 + 8; - let mut pp_inv = 2 * (tbl_d + tbl_a) + num_windows * 2 + 2; - - let shared_field_ops = - field_op_witnesses(shared_add, shared_sub, shared_mul, shared_inv, n, false); + let pp_loop_doubles = num_windows * w; + + // --- Per-point field ops: loop doubles + tables + loop adds + skew + on-curve + // + y-negate --- + let pp_add = pp_loop_doubles * 5 + 2 * (tbl_d * 5 + tbl_a) + num_windows * 2 + 2 + 4; + let pp_sub = + pp_loop_doubles * 3 + 2 * (tbl_d * 3 + tbl_a * 5) + num_windows * (2 * 5 + 2) + 2 * 6 + 2; + let pp_mul = + pp_loop_doubles * 4 + 2 * (tbl_d * 4 + tbl_a * 3) + num_windows * 2 * 3 + 2 * 3 + 8; + let pp_inv = pp_loop_doubles + 2 * (tbl_d + tbl_a) + num_windows * 2 + 2; + let pp_field_ops = field_op_witnesses(pp_add, pp_sub, pp_mul, pp_inv, n, false); - let pp_doubles = 2 * tbl_d; + let pp_doubles = pp_loop_doubles + 2 * tbl_d; let pp_negate_zeros = (4 + 2 * num_windows) * n; - let shared_constants_glv = shared_doubles * n + 2 * n; - let pp_constants = pp_doubles * n + 4 * n + pp_negate_zeros; - + let pp_constants = pp_doubles * n + 4 * n + pp_negate_zeros + 2 * n; // +2N for offset limbs let pp_table_selects = 
num_windows * 2 * half_table_size.saturating_sub(1) * 2 * n; let pp_xor = num_windows * 2 * 2 * w.saturating_sub(1); let pp_signed_y_selects = num_windows * 2 * n; @@ -434,7 +426,7 @@ fn calculate_msm_witness_cost_generic( let per_point = pp_field_ops + pp_constants + pp_selects + per_point_overhead(half_bits, n, sr_witnesses); - let shared_constants = 3 + 2 * n; + let shared_constants = 3; // gen_x, gen_y, zero // --- Point accumulation --- let pa_cost = field_op_witnesses(1, 5, 3, 1, n, false); @@ -448,17 +440,6 @@ fn calculate_msm_witness_cost_generic( // --- Range checks --- let mut rc_map: BTreeMap = BTreeMap::new(); - add_field_op_range_checks( - shared_add, - shared_sub, - shared_mul, - shared_inv, - n, - limb_bits, - curve_modulus_bits, - false, - &mut rc_map, - ); add_field_op_range_checks( n_points * pp_add, n_points * pp_sub, @@ -490,12 +471,7 @@ fn calculate_msm_witness_cost_generic( let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); - shared_field_ops - + shared_constants_glv - + n_points * per_point - + shared_constants - + accum - + range_check_cost + n_points * per_point + shared_constants + accum + range_check_cost } // --------------------------------------------------------------------------- @@ -510,7 +486,8 @@ fn calculate_msm_witness_cost_generic( /// - `verify_on_curve_native`: 2W (2 products) /// - No multi-limb arithmetic → zero EC-related range checks /// -/// Uses merged-loop optimization: all points share a single doubling per bit. +/// Uses per-point accumulators: each point has its own doubling chain and +/// identity check for soundness. 
fn calculate_msm_witness_cost_native( native_field_bits: u32, n_points: usize, @@ -522,17 +499,20 @@ fn calculate_msm_witness_cost_native( let y_negate = 6; // 2 × 3W (neg_y, y_eff, neg_y_eff) let (sr_wit, sr_rc) = scalar_relation_cost(native_field_bits, scalar_bits); + // Per point per bit: 4W (double) + 2×(1W select + 3W add) = 12W + let ec_loop_pp = half_bits * 12; + // Skew correction: 2 branches × (3W add + 2W select) = 10W per point + let skew_pp = 10; + // Offset constants per point + let offset_pp = 2; + let per_point = on_curve + y_negate + 2 * (half_bits + 1) // scalar bit decomposition + DETECT_SKIP_WIT + SANITIZE_WIT + EC_HINT_WIT + GLV_HINT_WIT - + sr_wit; + + sr_wit + + ec_loop_pp + skew_pp + offset_pp; - let shared_constants = 5; // gen_x, gen_y, zero, offset_x, offset_y - - // Per bit: 4W (shared double) + n_points × 8W (2×(1W select + 3W add)) - let ec_loop = half_bits * (4 + 8 * n_points); - // Skew correction: 2 branches × (3W add + 2W select) = 10W per point - let skew = n_points * 10; + let shared_constants = 3; // gen_x, gen_y, zero let accum = 2 // initial acc constants + n_points * 5 // add(3W) + skip_select(2W) @@ -546,7 +526,7 @@ fn calculate_msm_witness_cost_native( } let range_check_cost = crate::range_check::estimate_range_check_cost(&rc_map); - n_points * per_point + shared_constants + ec_loop + skew + accum + range_check_cost + n_points * per_point + shared_constants + accum + range_check_cost } // --------------------------------------------------------------------------- diff --git a/provekit/r1cs-compiler/src/msm/curve.rs b/provekit/r1cs-compiler/src/msm/curve/mod.rs similarity index 74% rename from provekit/r1cs-compiler/src/msm/curve.rs rename to provekit/r1cs-compiler/src/msm/curve/mod.rs index 278f7de47..3f2c75d20 100644 --- a/provekit/r1cs-compiler/src/msm/curve.rs +++ b/provekit/r1cs-compiler/src/msm/curve/mod.rs @@ -3,6 +3,8 @@ use { provekit_common::FieldElement, }; +mod u256_arith; + pub struct CurveParams { pub 
field_modulus_p: [u64; 4], pub curve_order_n: [u64; 4], @@ -298,181 +300,61 @@ pub fn grumpkin_params() -> CurveParams { } } -/// 256-bit modular arithmetic for compile-time EC point computations. -/// Only used to precompute accumulated offset points; not performance-critical. -mod u256_arith { - type U256 = [u64; 4]; - - /// Returns true if a >= b. - fn gte(a: &U256, b: &U256) -> bool { - for i in (0..4).rev() { - if a[i] > b[i] { - return true; - } - if a[i] < b[i] { - return false; - } - } - true // equal - } - - /// a + b, returns (result, carry). - fn add(a: &U256, b: &U256) -> (U256, bool) { - let mut result = [0u64; 4]; - let mut carry = 0u128; - for i in 0..4 { - carry += a[i] as u128 + b[i] as u128; - result[i] = carry as u64; - carry >>= 64; - } - (result, carry != 0) - } - - /// a - b, returns (result, borrow). - fn sub(a: &U256, b: &U256) -> (U256, bool) { - let mut result = [0u64; 4]; - let mut borrow = false; - for i in 0..4 { - let (d1, b1) = a[i].overflowing_sub(b[i]); - let (d2, b2) = d1.overflowing_sub(borrow as u64); - result[i] = d2; - borrow = b1 || b2; - } - (result, borrow) - } - - /// (a + b) mod p. - pub fn mod_add(a: &U256, b: &U256, p: &U256) -> U256 { - let (s, overflow) = add(a, b); - if overflow || gte(&s, p) { - sub(&s, p).0 - } else { - s - } - } - - /// (a - b) mod p. - fn mod_sub(a: &U256, b: &U256, p: &U256) -> U256 { - let (d, borrow) = sub(a, b); - if borrow { - add(&d, p).0 - } else { - d - } - } - - /// Schoolbook multiplication producing 512-bit result. - fn mul_wide(a: &U256, b: &U256) -> [u64; 8] { - let mut result = [0u64; 8]; - for i in 0..4 { - let mut carry = 0u128; - for j in 0..4 { - let prod = (a[i] as u128) * (b[j] as u128) + result[i + j] as u128 + carry; - result[i + j] = prod as u64; - carry = prod >> 64; - } - result[i + 4] = result[i + 4].wrapping_add(carry as u64); - } - result - } - - /// Reduce a 512-bit value mod a 256-bit prime using bit-by-bit long - /// division. 
- fn mod_reduce_wide(a: &[u64; 8], p: &U256) -> U256 { - let mut total_bits = 0; - for i in (0..8).rev() { - if a[i] != 0 { - total_bits = i * 64 + (64 - a[i].leading_zeros() as usize); - break; - } - } - if total_bits == 0 { - return [0; 4]; - } - - let mut r = [0u64; 4]; - for bit_idx in (0..total_bits).rev() { - // Left shift r by 1 - let overflow = r[3] >> 63; - for j in (1..4).rev() { - r[j] = (r[j] << 1) | (r[j - 1] >> 63); - } - r[0] <<= 1; - - // Insert current bit of a - let word = bit_idx / 64; - let bit = bit_idx % 64; - r[0] |= (a[word] >> bit) & 1; - - // If r >= p (or overflow from shift), subtract p - if overflow != 0 || gte(&r, p) { - r = sub(&r, p).0; - } - } - r - } - - /// (a * b) mod p. - pub fn mod_mul(a: &U256, b: &U256, p: &U256) -> U256 { - let wide = mul_wide(a, b); - mod_reduce_wide(&wide, p) - } - - /// a^exp mod p using square-and-multiply. - fn mod_pow(base: &U256, exp: &U256, p: &U256) -> U256 { - let mut highest_bit = 0; - for i in (0..4).rev() { - if exp[i] != 0 { - highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); - break; - } - } - if highest_bit == 0 { - return [1, 0, 0, 0]; - } - - let mut result: U256 = [1, 0, 0, 0]; - let mut base = *base; - for bit_idx in 0..highest_bit { - let word = bit_idx / 64; - let bit = bit_idx % 64; - if (exp[word] >> bit) & 1 == 1 { - result = mod_mul(&result, &base, p); - } - base = mod_mul(&base, &base, p); - } - result - } - - /// a^(-1) mod p via Fermat's little theorem: a^(p-2) mod p. - fn mod_inv(a: &U256, p: &U256) -> U256 { - let two: U256 = [2, 0, 0, 0]; - let exp = sub(p, &two).0; - mod_pow(a, &exp, p) - } - - /// EC point doubling on y^2 = x^3 + ax + b. 
- pub fn ec_point_double(x: &U256, y: &U256, a: &U256, p: &U256) -> (U256, U256) { - // lambda = (3*x^2 + a) / (2*y) - let x_sq = mod_mul(x, x, p); - let two_x_sq = mod_add(&x_sq, &x_sq, p); - let three_x_sq = mod_add(&two_x_sq, &x_sq, p); - let num = mod_add(&three_x_sq, a, p); - let two_y = mod_add(y, y, p); - let denom_inv = mod_inv(&two_y, p); - let lambda = mod_mul(&num, &denom_inv, p); - - // x3 = lambda^2 - 2*x - let lambda_sq = mod_mul(&lambda, &lambda, p); - let two_x = mod_add(x, x, p); - let x3 = mod_sub(&lambda_sq, &two_x, p); - - // y3 = lambda * (x - x3) - y - let x_minus_x3 = mod_sub(x, &x3, p); - let lambda_dx = mod_mul(&lambda, &x_minus_x3, p); - let y3 = mod_sub(&lambda_dx, y, p); - - (x3, y3) +pub fn secp256r1_params() -> CurveParams { + CurveParams { + field_modulus_p: [ + 0xffffffffffffffff_u64, + 0xffffffff_u64, + 0x0_u64, + 0xffffffff00000001_u64, + ], + curve_order_n: [ + 0xf3b9cac2fc632551_u64, + 0xbce6faada7179e84_u64, + 0xffffffffffffffff_u64, + 0xffffffff00000000_u64, + ], + curve_a: [ + 0xfffffffffffffffc_u64, + 0x00000000ffffffff_u64, + 0x0000000000000000_u64, + 0xffffffff00000001_u64, + ], + curve_b: [ + 0x3bce3c3e27d2604b_u64, + 0x651d06b0cc53b0f6_u64, + 0xb3ebbd55769886bc_u64, + 0x5ac635d8aa3a93e7_u64, + ], + generator: ( + [ + 0xf4a13945d898c296_u64, + 0x77037d812deb33a0_u64, + 0xf8bce6e563a440f2_u64, + 0x6b17d1f2e12c4247_u64, + ], + [ + 0xcbb6406837bf51f5_u64, + 0x2bce33576b315ece_u64, + 0x8ee7eb4a7c0f9e16_u64, + 0x4fe342e2fe1a7f9b_u64, + ], + ), + // Offset point = [2^128]G (large offset avoids collisions with small multiples of G) + offset_point: ( + [ + 0x57c84fc9d789bd85_u64, + 0xfc35ff7dc297eac3_u64, + 0xfb982fd588c6766e_u64, + 0x447d739beedb5e67_u64, + ], + [ + 0x0c7e33c972e25b32_u64, + 0x3d349b95a7fae500_u64, + 0xe12e9d953a4aaff7_u64, + 0x2d4825ab834131ee_u64, + ], + ), } } @@ -574,61 +456,3 @@ mod tests { assert_eq!(val2, back2, "roundtrip failed for offset x"); } } - -pub fn secp256r1_params() -> CurveParams { - 
CurveParams { - field_modulus_p: [ - 0xffffffffffffffff_u64, - 0xffffffff_u64, - 0x0_u64, - 0xffffffff00000001_u64, - ], - curve_order_n: [ - 0xf3b9cac2fc632551_u64, - 0xbce6faada7179e84_u64, - 0xffffffffffffffff_u64, - 0xffffffff00000000_u64, - ], - curve_a: [ - 0xfffffffffffffffc_u64, - 0x00000000ffffffff_u64, - 0x0000000000000000_u64, - 0xffffffff00000001_u64, - ], - curve_b: [ - 0x3bce3c3e27d2604b_u64, - 0x651d06b0cc53b0f6_u64, - 0xb3ebbd55769886bc_u64, - 0x5ac635d8aa3a93e7_u64, - ], - generator: ( - [ - 0xf4a13945d898c296_u64, - 0x77037d812deb33a0_u64, - 0xf8bce6e563a440f2_u64, - 0x6b17d1f2e12c4247_u64, - ], - [ - 0xcbb6406837bf51f5_u64, - 0x2bce33576b315ece_u64, - 0x8ee7eb4a7c0f9e16_u64, - 0x4fe342e2fe1a7f9b_u64, - ], - ), - // Offset point = [2^128]G (large offset avoids collisions with small multiples of G) - offset_point: ( - [ - 0x57c84fc9d789bd85_u64, - 0xfc35ff7dc297eac3_u64, - 0xfb982fd588c6766e_u64, - 0x447d739beedb5e67_u64, - ], - [ - 0x0c7e33c972e25b32_u64, - 0x3d349b95a7fae500_u64, - 0xe12e9d953a4aaff7_u64, - 0x2d4825ab834131ee_u64, - ], - ), - } -} diff --git a/provekit/r1cs-compiler/src/msm/curve/u256_arith.rs b/provekit/r1cs-compiler/src/msm/curve/u256_arith.rs new file mode 100644 index 000000000..e2bcea5e3 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/curve/u256_arith.rs @@ -0,0 +1,176 @@ +//! 256-bit modular arithmetic for compile-time EC point computations. +//! Only used to precompute accumulated offset points; not performance-critical. + +pub(super) type U256 = [u64; 4]; + +/// Returns true if a >= b. +fn gte(a: &U256, b: &U256) -> bool { + for i in (0..4).rev() { + if a[i] > b[i] { + return true; + } + if a[i] < b[i] { + return false; + } + } + true // equal +} + +/// a + b, returns (result, carry). 
+fn add(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut carry = 0u128; + for i in 0..4 { + carry += a[i] as u128 + b[i] as u128; + result[i] = carry as u64; + carry >>= 64; + } + (result, carry != 0) +} + +/// a - b, returns (result, borrow). +fn sub(a: &U256, b: &U256) -> (U256, bool) { + let mut result = [0u64; 4]; + let mut borrow = false; + for i in 0..4 { + let (d1, b1) = a[i].overflowing_sub(b[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + result[i] = d2; + borrow = b1 || b2; + } + (result, borrow) +} + +/// (a + b) mod p. +pub fn mod_add(a: &U256, b: &U256, p: &U256) -> U256 { + let (s, overflow) = add(a, b); + if overflow || gte(&s, p) { + sub(&s, p).0 + } else { + s + } +} + +/// (a - b) mod p. +fn mod_sub(a: &U256, b: &U256, p: &U256) -> U256 { + let (d, borrow) = sub(a, b); + if borrow { + add(&d, p).0 + } else { + d + } +} + +/// Schoolbook multiplication producing 512-bit result. +fn mul_wide(a: &U256, b: &U256) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let prod = (a[i] as u128) * (b[j] as u128) + result[i + j] as u128 + carry; + result[i + j] = prod as u64; + carry = prod >> 64; + } + result[i + 4] = result[i + 4].wrapping_add(carry as u64); + } + result +} + +/// Reduce a 512-bit value mod a 256-bit prime using bit-by-bit long +/// division. 
+fn mod_reduce_wide(a: &[u64; 8], p: &U256) -> U256 { + let mut total_bits = 0; + for i in (0..8).rev() { + if a[i] != 0 { + total_bits = i * 64 + (64 - a[i].leading_zeros() as usize); + break; + } + } + if total_bits == 0 { + return [0; 4]; + } + + let mut r = [0u64; 4]; + for bit_idx in (0..total_bits).rev() { + // Left shift r by 1 + let overflow = r[3] >> 63; + for j in (1..4).rev() { + r[j] = (r[j] << 1) | (r[j - 1] >> 63); + } + r[0] <<= 1; + + // Insert current bit of a + let word = bit_idx / 64; + let bit = bit_idx % 64; + r[0] |= (a[word] >> bit) & 1; + + // If r >= p (or overflow from shift), subtract p + if overflow != 0 || gte(&r, p) { + r = sub(&r, p).0; + } + } + r +} + +/// (a * b) mod p. +pub fn mod_mul(a: &U256, b: &U256, p: &U256) -> U256 { + let wide = mul_wide(a, b); + mod_reduce_wide(&wide, p) +} + +/// a^exp mod p using square-and-multiply. +fn mod_pow(base: &U256, exp: &U256, p: &U256) -> U256 { + let mut highest_bit = 0; + for i in (0..4).rev() { + if exp[i] != 0 { + highest_bit = i * 64 + (64 - exp[i].leading_zeros() as usize); + break; + } + } + if highest_bit == 0 { + return [1, 0, 0, 0]; + } + + let mut result: U256 = [1, 0, 0, 0]; + let mut base = *base; + for bit_idx in 0..highest_bit { + let word = bit_idx / 64; + let bit = bit_idx % 64; + if (exp[word] >> bit) & 1 == 1 { + result = mod_mul(&result, &base, p); + } + base = mod_mul(&base, &base, p); + } + result +} + +/// a^(-1) mod p via Fermat's little theorem: a^(p-2) mod p. +fn mod_inv(a: &U256, p: &U256) -> U256 { + let two: U256 = [2, 0, 0, 0]; + let exp = sub(p, &two).0; + mod_pow(a, &exp, p) +} + +/// EC point doubling on y^2 = x^3 + ax + b. 
+pub fn ec_point_double(x: &U256, y: &U256, a: &U256, p: &U256) -> (U256, U256) { + // lambda = (3*x^2 + a) / (2*y) + let x_sq = mod_mul(x, x, p); + let two_x_sq = mod_add(&x_sq, &x_sq, p); + let three_x_sq = mod_add(&two_x_sq, &x_sq, p); + let num = mod_add(&three_x_sq, a, p); + let two_y = mod_add(y, y, p); + let denom_inv = mod_inv(&two_y, p); + let lambda = mod_mul(&num, &denom_inv, p); + + // x3 = lambda^2 - 2*x + let lambda_sq = mod_mul(&lambda, &lambda, p); + let two_x = mod_add(x, x, p); + let x3 = mod_sub(&lambda_sq, &two_x, p); + + // y3 = lambda * (x - x3) - y + let x_minus_x3 = mod_sub(x, &x3, p); + let lambda_dx = mod_mul(&lambda, &x_minus_x3, p); + let y3 = mod_sub(&lambda_dx, y, p); + + (x3, y3) +} diff --git a/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs b/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs index fdd6c1756..836e98339 100644 --- a/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs +++ b/provekit/r1cs-compiler/src/msm/ec_points/hints_non_native.rs @@ -107,7 +107,7 @@ fn less_than_p_check_vec( /// Compute carry range bits for hint-verified column equations. fn carry_range_bits(limb_bits: u32, max_coeff_sum: u64, n: usize) -> u32 { let extra_bits = ((max_coeff_sum as f64 * n as f64).log2().ceil() as u32) + 1; - limb_bits + extra_bits + limb_bits + extra_bits + 1 } /// Soundness check: verify that merged column equations fit the native field. @@ -120,21 +120,34 @@ fn check_column_equation_fits(limb_bits: u32, max_coeff_sum: u64, n: usize, op_n ); } +/// Merge terms with the same witness index by summing their coefficients. 
+fn merge_terms(terms: &[(FieldElement, usize)]) -> Vec<(FieldElement, usize)> { + use {ark_ff::Zero, std::collections::HashMap}; + let mut map: HashMap = HashMap::new(); + for &(coeff, idx) in terms { + *map.entry(idx).or_insert_with(FieldElement::zero) += coeff; + } + let mut result: Vec<(FieldElement, usize)> = map.into_iter().map(|(idx, c)| (c, idx)).collect(); + result.sort_by_key(|&(_, idx)| idx); + result +} + /// Emit schoolbook column equations for a merged verification equation. /// -/// Verifies: Σ (coeff_i × A_i ⊗ B_i) + Σ linear_k = q·p (mod p, as integers) +/// Verifies: Σ (coeff_i × A_i ⊗ B_i) + Σ linear_k + Σ p\[i\]*q_neg\[j\] +/// = Σ p\[i\]*q_pos\[j\] + carry_chain (as integers) /// /// `product_sets`: each (products_2d, coefficient) where products_2d\[i\]\[j\] /// is the witness index for a\[i\]*b\[j\]. -/// `linear_limbs`: each (limb_witnesses, coefficient) for non-product terms -/// (limb_witnesses has N entries, zero-padded). -/// `q_witnesses`: quotient limbs (N entries). +/// `linear_limbs`: each (limb_witnesses, coefficient) for non-product terms. +/// `q_pos_witnesses`, `q_neg_witnesses`: split quotient limbs (N entries each). /// `carry_witnesses`: unsigned-offset carry witnesses (2N-2 entries). 
fn emit_schoolbook_column_equations( compiler: &mut NoirToR1CSCompiler, product_sets: &[(&[Vec], FieldElement)], // (products[i][j], coeff) linear_limbs: &[(&[usize], FieldElement)], // (limb_witnesses, coeff) - q_witnesses: &[usize], + q_pos_witnesses: &[usize], + q_neg_witnesses: &[usize], carry_witnesses: &[usize], p_limbs: &[FieldElement], n: usize, @@ -154,7 +167,7 @@ fn emit_schoolbook_column_equations( let num_columns = 2 * n - 1; for k in 0..num_columns { - // LHS: Σ coeff * products[i][j] for i+j=k + carry_in + offset + // LHS: Σ coeff * products[i][j] for i+j=k + Σ p[i]*q_neg[j] + carry_in + offset let mut lhs_terms: Vec<(FieldElement, usize)> = Vec::new(); for &(products, coeff) in product_sets { @@ -173,6 +186,14 @@ fn emit_schoolbook_column_equations( } } + // Add p*q_neg on the LHS (positive side) + for i in 0..n { + let j_val = k as isize - i as isize; + if j_val >= 0 && (j_val as usize) < n { + lhs_terms.push((p_limbs[i], q_neg_witnesses[j_val as usize])); + } + } + // Add carry_in and offset if k > 0 { lhs_terms.push((FieldElement::ONE, carry_witnesses[k - 1])); @@ -181,12 +202,12 @@ fn emit_schoolbook_column_equations( lhs_terms.push((offset_w, w1)); } - // RHS: Σ p[i]*q[j] for i+j=k + carry_out * W (or offset at last column) + // RHS: Σ p[i]*q_pos[j] for i+j=k + carry_out * W (or offset at last column) let mut rhs_terms: Vec<(FieldElement, usize)> = Vec::new(); for i in 0..n { let j_val = k as isize - i as isize; if j_val >= 0 && (j_val as usize) < n { - rhs_terms.push((p_limbs[i], q_witnesses[j_val as usize])); + rhs_terms.push((p_limbs[i], q_pos_witnesses[j_val as usize])); } } @@ -197,9 +218,12 @@ fn emit_schoolbook_column_equations( rhs_terms.push((offset_w, w1)); } + // Merge terms with the same witness index (products may share cached witnesses) + let lhs_merged = merge_terms(&lhs_terms); + let rhs_merged = merge_terms(&rhs_terms); compiler .r1cs - .add_constraint(&lhs_terms, &[(FieldElement::ONE, w1)], &rhs_terms); + 
.add_constraint(&lhs_merged, &[(FieldElement::ONE, w1)], &rhs_merged); } } @@ -234,9 +258,9 @@ pub fn verify_on_curve_non_native( let a_is_zero = params.curve_a_raw.iter().all(|&v| v == 0); let max_coeff_sum: u64 = if a_is_zero { - 4 + n as u64 + 4 + 2 * n as u64 } else { - 5 + n as u64 + 5 + 2 * n as u64 }; check_column_equation_fits(params.limb_bits, max_coeff_sum, n, "On-curve"); @@ -253,22 +277,27 @@ pub fn verify_on_curve_non_native( num_limbs: n as u32, }); - // Parse hint layout: [x_sq(N), q1(N), c1(2N-2), q2(N), c2(2N-2)] + // Parse hint layout: [x_sq(N), q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2)] + // Total: 9N-4 let x_sq = witness_range(os, n); - let q1 = witness_range(os + n, n); - let c1 = witness_range(os + 2 * n, 2 * n - 2); - let q2 = witness_range(os + 4 * n - 2, n); - let c2 = witness_range(os + 5 * n - 2, 2 * n - 2); + let q1_pos = witness_range(os + n, n); + let q1_neg = witness_range(os + 2 * n, n); + let c1 = witness_range(os + 3 * n, 2 * n - 2); + let q2_pos = witness_range(os + 5 * n - 2, n); + let q2_neg = witness_range(os + 6 * n - 2, n); + let c2 = witness_range(os + 7 * n - 2, 2 * n - 2); // Eq1: px·px - x_sq = q1·p let prod_px_px = make_products(compiler, &px.as_slice()[..n], &px.as_slice()[..n]); - let max_coeff_eq1: u64 = 1 + 1 + n as u64; + let max_coeff_eq1: u64 = 1 + 1 + 2 * n as u64; emit_schoolbook_column_equations( compiler, &[(&prod_px_px, FieldElement::ONE)], &[(&x_sq, -FieldElement::ONE)], - &q1, + &q1_pos, + &q1_neg, &c1, ¶ms.p_limbs, n, @@ -282,7 +311,7 @@ pub fn verify_on_curve_non_native( let b_limbs = allocate_pinned_constant_limbs(compiler, ¶ms.curve_b_limbs[..n]); if a_is_zero { - let max_coeff_eq2: u64 = 1 + 1 + 1 + n as u64; + let max_coeff_eq2: u64 = 1 + 1 + 1 + 2 * n as u64; emit_schoolbook_column_equations( compiler, &[ @@ -290,7 +319,8 @@ pub fn verify_on_curve_non_native( (&prod_xsq_px, -FieldElement::ONE), ], &[(&b_limbs, -FieldElement::ONE)], - &q2, + &q2_pos, + &q2_neg, &c2, 
¶ms.p_limbs, n, @@ -301,7 +331,7 @@ pub fn verify_on_curve_non_native( let a_limbs = allocate_pinned_constant_limbs(compiler, ¶ms.curve_a_limbs[..n]); let prod_a_px = make_products(compiler, &a_limbs, &px.as_slice()[..n]); - let max_coeff_eq2: u64 = 1 + 1 + 1 + 1 + n as u64; + let max_coeff_eq2: u64 = 1 + 1 + 1 + 1 + 2 * n as u64; emit_schoolbook_column_equations( compiler, &[ @@ -310,7 +340,8 @@ pub fn verify_on_curve_non_native( (&prod_a_px, -FieldElement::ONE), ], &[(&b_limbs, -FieldElement::ONE)], - &q2, + &q2_pos, + &q2_neg, &c2, ¶ms.p_limbs, n, @@ -323,7 +354,7 @@ pub fn verify_on_curve_non_native( let crb = carry_range_bits(params.limb_bits, max_coeff_sum, n); range_check_limbs_and_carries( range_checks, - &[&x_sq, &q1, &q2], + &[&x_sq, &q1_pos, &q1_neg, &q2_pos, &q2_neg], &[&c1, &c2], params.limb_bits, crb, @@ -353,7 +384,7 @@ pub fn point_double_verified_non_native( let n = params.num_limbs; assert!(n >= 2, "hint-verified non-native requires n >= 2"); - let max_coeff_sum: u64 = 2 + 3 + 1 + n as u64; // λy(2) + xx(3) + a(1) + pq(N) + let max_coeff_sum: u64 = 2 + 3 + 1 + 2 * n as u64; // λy(2) + xx(3) + a(1) + pq_pos(N) + pq_neg(N) check_column_equation_fits(params.limb_bits, max_coeff_sum, n, "Merged EC double"); // Allocate hint @@ -369,17 +400,23 @@ pub fn point_double_verified_non_native( num_limbs: n as u32, }); - // Parse hint layout: [lambda(N), x3(N), y3(N), q1(N), c1(2N-2), q2(N), - // c2(2N-2), q3(N), c3(2N-2)] + // Parse hint layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 let lambda = witness_range(os, n); let x3 = witness_range(os + n, n); let y3 = witness_range(os + 2 * n, n); - let q1 = witness_range(os + 3 * n, n); - let c1 = witness_range(os + 4 * n, 2 * n - 2); - let q2 = witness_range(os + 6 * n - 2, n); - let c2 = witness_range(os + 7 * n - 2, 2 * n - 2); - let q3 = witness_range(os + 9 * n - 4, n); - let c3 = 
witness_range(os + 10 * n - 4, 2 * n - 2); + let q1_pos = witness_range(os + 3 * n, n); + let q1_neg = witness_range(os + 4 * n, n); + let c1 = witness_range(os + 5 * n, 2 * n - 2); + let q2_pos = witness_range(os + 7 * n - 2, n); + let q2_neg = witness_range(os + 8 * n - 2, n); + let c2 = witness_range(os + 9 * n - 2, 2 * n - 2); + let q3_pos = witness_range(os + 11 * n - 4, n); + let q3_neg = witness_range(os + 12 * n - 4, n); + let c3 = witness_range(os + 13 * n - 4, 2 * n - 2); let px_s = &px.as_slice()[..n]; let py_s = &py.as_slice()[..n]; @@ -389,7 +426,7 @@ pub fn point_double_verified_non_native( let prod_px_px = make_products(compiler, px_s, px_s); let a_limbs = allocate_pinned_constant_limbs(compiler, ¶ms.curve_a_limbs[..n]); - let max_coeff_eq1: u64 = 2 + 3 + 1 + n as u64; + let max_coeff_eq1: u64 = 2 + 3 + 1 + 2 * n as u64; emit_schoolbook_column_equations( compiler, &[ @@ -397,7 +434,8 @@ pub fn point_double_verified_non_native( (&prod_px_px, -FieldElement::from(3u64)), ], &[(&a_limbs, -FieldElement::ONE)], - &q1, + &q1_pos, + &q1_neg, &c1, ¶ms.p_limbs, n, @@ -408,12 +446,13 @@ pub fn point_double_verified_non_native( // Eq2: lambda² - x3 - 2*px = q2*p let prod_lam_lam = make_products(compiler, &lambda, &lambda); - let max_coeff_eq2: u64 = 1 + 1 + 2 + n as u64; + let max_coeff_eq2: u64 = 1 + 1 + 2 + 2 * n as u64; emit_schoolbook_column_equations( compiler, &[(&prod_lam_lam, FieldElement::ONE)], &[(&x3, -FieldElement::ONE), (px_s, -FieldElement::from(2u64))], - &q2, + &q2_pos, + &q2_neg, &c2, ¶ms.p_limbs, n, @@ -425,7 +464,7 @@ pub fn point_double_verified_non_native( let prod_lam_px = make_products(compiler, &lambda, px_s); let prod_lam_x3 = make_products(compiler, &lambda, &x3); - let max_coeff_eq3: u64 = 1 + 1 + 1 + 1 + n as u64; + let max_coeff_eq3: u64 = 1 + 1 + 1 + 1 + 2 * n as u64; emit_schoolbook_column_equations( compiler, &[ @@ -433,7 +472,8 @@ pub fn point_double_verified_non_native( (&prod_lam_x3, -FieldElement::ONE), ], &[(&y3, 
-FieldElement::ONE), (py_s, -FieldElement::ONE)], - &q3, + &q3_pos, + &q3_neg, &c3, ¶ms.p_limbs, n, @@ -442,12 +482,14 @@ pub fn point_double_verified_non_native( ); // Range checks on hint outputs - // max_coeff across eqs: Eq1 = 6+N, Eq2 = 4+N, Eq3 = 4+N → worst = 6+N - let max_coeff_carry = 6u64 + n as u64; + // max_coeff across eqs: Eq1 = 6+2N, Eq2 = 4+2N, Eq3 = 4+2N → worst = 6+2N + let max_coeff_carry = 6u64 + 2 * n as u64; let crb = carry_range_bits(params.limb_bits, max_coeff_carry, n); range_check_limbs_and_carries( range_checks, - &[&lambda, &x3, &y3, &q1, &q2, &q3], + &[ + &lambda, &x3, &y3, &q1_pos, &q1_neg, &q2_pos, &q2_neg, &q3_pos, &q3_neg, + ], &[&c1, &c2, &c3], params.limb_bits, crb, @@ -483,7 +525,7 @@ pub fn point_add_verified_non_native( let n = params.num_limbs; assert!(n >= 2, "hint-verified non-native requires n >= 2"); - let max_coeff: u64 = 1 + 1 + 1 + 1 + n as u64; // all 3 eqs: 1+1+1+1+N + let max_coeff: u64 = 1 + 1 + 1 + 1 + 2 * n as u64; // all 3 eqs: 1+1+1+1+2N check_column_equation_fits(params.limb_bits, max_coeff, n, "EC add"); let os = compiler.num_witnesses(); @@ -503,15 +545,23 @@ pub fn point_add_verified_non_native( num_limbs: n as u32, }); + // Parse hint layout: [lambda(N), x3(N), y3(N), + // q1_pos(N), q1_neg(N), c1(2N-2), + // q2_pos(N), q2_neg(N), c2(2N-2), + // q3_pos(N), q3_neg(N), c3(2N-2)] + // Total: 15N-6 let lambda = witness_range(os, n); let x3 = witness_range(os + n, n); let y3 = witness_range(os + 2 * n, n); - let q1 = witness_range(os + 3 * n, n); - let c1 = witness_range(os + 4 * n, 2 * n - 2); - let q2 = witness_range(os + 6 * n - 2, n); - let c2 = witness_range(os + 7 * n - 2, 2 * n - 2); - let q3 = witness_range(os + 9 * n - 4, n); - let c3 = witness_range(os + 10 * n - 4, 2 * n - 2); + let q1_pos = witness_range(os + 3 * n, n); + let q1_neg = witness_range(os + 4 * n, n); + let c1 = witness_range(os + 5 * n, 2 * n - 2); + let q2_pos = witness_range(os + 7 * n - 2, n); + let q2_neg = witness_range(os + 8 * n 
- 2, n); + let c2 = witness_range(os + 9 * n - 2, 2 * n - 2); + let q3_pos = witness_range(os + 11 * n - 4, n); + let q3_neg = witness_range(os + 12 * n - 4, n); + let c3 = witness_range(os + 13 * n - 4, 2 * n - 2); let x1_s = &x1.as_slice()[..n]; let y1_s = &y1.as_slice()[..n]; @@ -529,7 +579,8 @@ pub fn point_add_verified_non_native( (&prod_lam_x1, -FieldElement::ONE), ], &[(y2_s, -FieldElement::ONE), (y1_s, FieldElement::ONE)], - &q1, + &q1_pos, + &q1_neg, &c1, ¶ms.p_limbs, n, @@ -548,7 +599,8 @@ pub fn point_add_verified_non_native( (x1_s, -FieldElement::ONE), (x2_s, -FieldElement::ONE), ], - &q2, + &q2_pos, + &q2_neg, &c2, ¶ms.p_limbs, n, @@ -567,7 +619,8 @@ pub fn point_add_verified_non_native( (&prod_lam_x3, -FieldElement::ONE), ], &[(&y3, -FieldElement::ONE), (y1_s, -FieldElement::ONE)], - &q3, + &q3_pos, + &q3_neg, &c3, ¶ms.p_limbs, n, @@ -576,12 +629,14 @@ pub fn point_add_verified_non_native( ); // Range checks - // max_coeff across all 3 eqs = 4+N - let max_coeff_carry = 4u64 + n as u64; + // max_coeff across all 3 eqs = 4+2N + let max_coeff_carry = 4u64 + 2 * n as u64; let crb = carry_range_bits(params.limb_bits, max_coeff_carry, n); range_check_limbs_and_carries( range_checks, - &[&lambda, &x3, &y3, &q1, &q2, &q3], + &[ + &lambda, &x3, &y3, &q1_pos, &q1_neg, &q2_pos, &q2_neg, &q3_pos, &q3_neg, + ], &[&c1, &c2, &c3], params.limb_bits, crb, diff --git a/provekit/r1cs-compiler/src/msm/limbs.rs b/provekit/r1cs-compiler/src/msm/limbs.rs new file mode 100644 index 000000000..46d3bfdb5 --- /dev/null +++ b/provekit/r1cs-compiler/src/msm/limbs.rs @@ -0,0 +1,96 @@ +//! `Limbs`: fixed-capacity, `Copy` array of witness indices. + +// --------------------------------------------------------------------------- +// Limbs: fixed-capacity, Copy array of witness indices +// --------------------------------------------------------------------------- + +/// Maximum number of limbs supported. Covers all practical field sizes +/// (e.g. 
a 512-bit modulus with 16-bit limbs = 32 limbs). +pub const MAX_LIMBS: usize = 32; + +/// A fixed-capacity array of witness indices, indexed by limb position. +/// +/// This type is `Copy`, so it can be passed by value without requiring +/// const generics or dispatch macros. The runtime `len` field tracks how +/// many limbs are actually in use. +#[derive(Clone, Copy)] +pub struct Limbs { + data: [usize; MAX_LIMBS], + len: usize, +} + +impl Limbs { + /// Sentinel value for uninitialized limb slots. Using `usize::MAX` + /// ensures accidental use of an unfilled slot indexes an absurdly + /// large witness, causing an immediate out-of-bounds panic. + const UNINIT: usize = usize::MAX; + + /// Create a new `Limbs` with `len` limbs, all initialized to `UNINIT`. + pub fn new(len: usize) -> Self { + assert!( + len > 0 && len <= MAX_LIMBS, + "limb count must be 1..={MAX_LIMBS}, got {len}" + ); + Self { + data: [Self::UNINIT; MAX_LIMBS], + len, + } + } + + /// Create a single-limb `Limbs` wrapping one witness index. + pub fn single(value: usize) -> Self { + let mut l = Self { + data: [Self::UNINIT; MAX_LIMBS], + len: 1, + }; + l.data[0] = value; + l + } + + /// View the active limbs as a slice. + pub fn as_slice(&self) -> &[usize] { + &self.data[..self.len] + } + + /// Number of active limbs. 
+    #[allow(clippy::len_without_is_empty)]
+    pub fn len(&self) -> usize {
+        self.len
+    }
+}
+
+impl std::fmt::Debug for Limbs {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_list().entries(self.as_slice().iter()).finish()
+    }
+}
+
+impl PartialEq for Limbs {
+    fn eq(&self, other: &Self) -> bool {
+        self.len == other.len && self.data[..self.len] == other.data[..other.len]
+    }
+}
+impl Eq for Limbs {}
+
+impl std::ops::Index<usize> for Limbs {
+    type Output = usize;
+    fn index(&self, i: usize) -> &usize {
+        debug_assert!(
+            i < self.len,
+            "Limbs index {i} out of bounds (len={})",
+            self.len
+        );
+        &self.data[i]
+    }
+}
+
+impl std::ops::IndexMut<usize> for Limbs {
+    fn index_mut(&mut self, i: usize) -> &mut usize {
+        debug_assert!(
+            i < self.len,
+            "Limbs index {i} out of bounds (len={})",
+            self.len
+        );
+        &mut self.data[i]
+    }
+}
diff --git a/provekit/r1cs-compiler/src/msm/mod.rs b/provekit/r1cs-compiler/src/msm/mod.rs
index 816bbb929..19c161a83 100644
--- a/provekit/r1cs-compiler/src/msm/mod.rs
+++ b/provekit/r1cs-compiler/src/msm/mod.rs
@@ -1,6 +1,7 @@
-pub(crate) mod cost_model;
-pub(crate) mod curve;
+pub mod cost_model;
+pub mod curve;
 pub(crate) mod ec_points;
+mod limbs;
 pub(crate) mod multi_limb_arith;
 pub(crate) mod multi_limb_ops;
 mod native;
@@ -10,115 +11,21 @@ mod scalar_relation;
 #[cfg(test)]
 mod tests;
 
-// Re-export sanitize helpers so submodules (native, non_native) can use
-// `super::sanitize_point_scalar` etc.
+pub use limbs::{Limbs, MAX_LIMBS}; use { - crate::{constraint_helpers::add_constant_witness, noir_to_r1cs::NoirToR1CSCompiler}, - ark_ff::PrimeField, + crate::{ + constraint_helpers::{add_constant_witness, constrain_boolean}, + noir_to_r1cs::NoirToR1CSCompiler, + }, + ark_ff::{AdditiveGroup, Field, PrimeField}, curve::CurveParams, - provekit_common::witness::ConstantOrR1CSWitness, - sanitize::{ - decompose_signed_bits, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, - negate_y_signed_native, sanitize_point_scalar, + provekit_common::{ + witness::{ConstantOrR1CSWitness, WitnessBuilder}, + FieldElement, }, std::collections::BTreeMap, }; -// --------------------------------------------------------------------------- -// Limbs: fixed-capacity, Copy array of witness indices -// --------------------------------------------------------------------------- - -/// Maximum number of limbs supported. Covers all practical field sizes -/// (e.g. a 512-bit modulus with 16-bit limbs = 32 limbs). -pub const MAX_LIMBS: usize = 32; - -/// A fixed-capacity array of witness indices, indexed by limb position. -/// -/// This type is `Copy`, so it can be passed by value without requiring -/// const generics or dispatch macros. The runtime `len` field tracks how -/// many limbs are actually in use. -#[derive(Clone, Copy)] -pub struct Limbs { - data: [usize; MAX_LIMBS], - len: usize, -} - -impl Limbs { - /// Sentinel value for uninitialized limb slots. Using `usize::MAX` - /// ensures accidental use of an unfilled slot indexes an absurdly - /// large witness, causing an immediate out-of-bounds panic. - const UNINIT: usize = usize::MAX; - - /// Create a new `Limbs` with `len` limbs, all initialized to `UNINIT`. - pub fn new(len: usize) -> Self { - assert!( - len > 0 && len <= MAX_LIMBS, - "limb count must be 1..={MAX_LIMBS}, got {len}" - ); - Self { - data: [Self::UNINIT; MAX_LIMBS], - len, - } - } - - /// Create a single-limb `Limbs` wrapping one witness index. 
-    pub fn single(value: usize) -> Self {
-        let mut l = Self {
-            data: [Self::UNINIT; MAX_LIMBS],
-            len: 1,
-        };
-        l.data[0] = value;
-        l
-    }
-
-    /// View the active limbs as a slice.
-    pub fn as_slice(&self) -> &[usize] {
-        &self.data[..self.len]
-    }
-
-    /// Number of active limbs.
-    #[allow(clippy::len_without_is_empty)]
-    pub fn len(&self) -> usize {
-        self.len
-    }
-}
-
-impl std::fmt::Debug for Limbs {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_list().entries(self.as_slice().iter()).finish()
-    }
-}
-
-impl PartialEq for Limbs {
-    fn eq(&self, other: &Self) -> bool {
-        self.len == other.len && self.data[..self.len] == other.data[..other.len]
-    }
-}
-impl Eq for Limbs {}
-
-impl std::ops::Index<usize> for Limbs {
-    type Output = usize;
-    fn index(&self, i: usize) -> &usize {
-        debug_assert!(
-            i < self.len,
-            "Limbs index {i} out of bounds (len={})",
-            self.len
-        );
-        &self.data[i]
-    }
-}
-
-impl std::ops::IndexMut<usize> for Limbs {
-    fn index_mut(&mut self, i: usize) -> &mut usize {
-        debug_assert!(
-            i < self.len,
-            "Limbs index {i} out of bounds (len={})",
-            self.len
-        );
-        &mut self.data[i]
-    }
-}
-
 // ---------------------------------------------------------------------------
 // MSM entry point
 // ---------------------------------------------------------------------------
@@ -251,3 +158,146 @@ fn resolve_input(compiler: &mut NoirToR1CSCompiler, input: &ConstantOrR1CSWitnes
         ConstantOrR1CSWitness::Constant(value) => add_constant_witness(compiler, *value),
     }
 }
+
+// ---------------------------------------------------------------------------
+// Multi-limb MSM interface (for non-native curves with coords > BN254_Fr)
+// ---------------------------------------------------------------------------
+
+/// MSM outputs when coordinates are in multi-limb form.
+pub struct MsmLimbedOutputs {
+    pub out_x_limbs: Vec<usize>,
+    pub out_y_limbs: Vec<usize>,
+    pub out_inf: usize,
+}
+
+/// Multi-limb MSM entry point for non-native curves.
+///
+/// Point coordinates are provided as limbs, avoiding truncation when
+/// values exceed BN254_Fr. Each point uses stride `2*num_limbs + 1`:
+/// `[px_l0..px_lN-1, py_l0..py_lN-1, inf]`.
+///
+/// Scalars remain as `[s_lo, s_hi]` pairs (128-bit halves fit in BN254_Fr).
+/// Outputs are per-limb: `MsmLimbedOutputs` with N limbs for each coordinate.
+pub fn add_msm_with_curve_limbed(
+    compiler: &mut NoirToR1CSCompiler,
+    msm_ops: Vec<(
+        Vec<ConstantOrR1CSWitness>,
+        Vec<ConstantOrR1CSWitness>,
+        MsmLimbedOutputs,
+    )>,
+    range_checks: &mut BTreeMap<u32, Vec<usize>>,
+    curve: &CurveParams,
+    num_limbs: usize,
+) {
+    assert!(
+        !curve.is_native_field(),
+        "limbed MSM is only for non-native curves"
+    );
+    if msm_ops.is_empty() {
+        return;
+    }
+
+    let native_bits = provekit_common::FieldElement::MODULUS_BIT_SIZE;
+    let curve_bits = curve.modulus_bits();
+    let stride = 2 * num_limbs + 1;
+    let n_points: usize = msm_ops.iter().map(|(pts, ..)| pts.len() / stride).sum();
+    let (limb_bits, window_size) =
+        cost_model::get_optimal_msm_params(native_bits, curve_bits, n_points, 256, false);
+
+    // Verify num_limbs matches what cost model produces
+    let expected_num_limbs = (curve_bits as usize + limb_bits as usize - 1) / limb_bits as usize;
+    assert_eq!(
+        num_limbs, expected_num_limbs,
+        "num_limbs mismatch: caller passed {num_limbs}, cost model expects {expected_num_limbs}"
+    );
+
+    for (points, scalars, outputs) in msm_ops {
+        assert!(
+            points.len() % stride == 0,
+            "points length must be a multiple of {stride} (2*{num_limbs}+1)"
+        );
+        let n = points.len() / stride;
+        assert_eq!(scalars.len(), 2 * n, "scalars length must be 2x n_points");
+        assert_eq!(outputs.out_x_limbs.len(), num_limbs);
+        assert_eq!(outputs.out_y_limbs.len(), num_limbs);
+
+        let point_wits: Vec<usize> = points.iter().map(|p| resolve_input(compiler, p)).collect();
+        let scalar_wits: Vec<usize> = scalars.iter().map(|s| resolve_input(compiler, s)).collect();
+
+        non_native::process_multi_point_non_native_limbed(
+            compiler,
+            &point_wits,
+            &scalar_wits,
+            &outputs,
+            n,
+            num_limbs,
+            limb_bits,
+            window_size,
+            range_checks,
+            curve,
+        );
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Signed-bit decomposition (shared by native and non-native paths)
+// ---------------------------------------------------------------------------
+
+/// Signed-bit decomposition for wNAF scalar multiplication.
+///
+/// Decomposes `scalar` into `num_bits` sign-bits b_i ∈ {0,1} and a skew ∈ {0,1}
+/// such that the signed digits d_i = 2*b_i - 1 ∈ {-1, +1} satisfy:
+///     scalar = Σ d_i * 2^i - skew
+///
+/// Reconstruction constraint (1 linear R1CS):
+///     scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1}
+///
+/// All bits and skew are boolean-constrained.
+///
+/// # Limitation
+/// The prover's `SignedBitHint` solver reads the scalar as a `u128` (lower
+/// 128 bits of the field element). This is correct for FakeGLV half-scalars
+/// (≤128 bits for 256-bit curves) but would silently truncate if `num_bits`
+/// exceeds 128. The R1CS reconstruction constraint would then fail.
+pub(crate) fn decompose_signed_bits(
+    compiler: &mut NoirToR1CSCompiler,
+    scalar: usize,
+    num_bits: usize,
+) -> (Vec<usize>, usize) {
+    let start = compiler.num_witnesses();
+    compiler.add_witness_builder(WitnessBuilder::SignedBitHint {
+        output_start: start,
+        scalar,
+        num_bits,
+    });
+    let bits: Vec<usize> = (start..start + num_bits).collect();
+    let skew = start + num_bits;
+
+    // Boolean-constrain each bit and skew
+    for &b in &bits {
+        constrain_boolean(compiler, b);
+    }
+    constrain_boolean(compiler, skew);
+
+    // Reconstruction: scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1}
+    // Rearranged as: scalar + skew + (2^n - 1) - Σ b_i * 2^{i+1} = 0
+    let one = compiler.witness_one();
+    let two = FieldElement::from(2u64);
+    let constant = two.pow([num_bits as u64]) - FieldElement::ONE;
+    let mut b_terms: Vec<(FieldElement, usize)> = bits
+        .iter()
+        .enumerate()
+        .map(|(i, &b)| (-two.pow([(i + 1) as u64]), b))
+        .collect();
+    b_terms.push((FieldElement::ONE, scalar));
+    b_terms.push((FieldElement::ONE, skew));
+    b_terms.push((constant, one));
+    compiler
+        .r1cs
+        .add_constraint(&[(FieldElement::ONE, one)], &b_terms, &[(
+            FieldElement::ZERO,
+            one,
+        )]);
+
+    (bits, skew)
+}
diff --git a/provekit/r1cs-compiler/src/msm/native.rs b/provekit/r1cs-compiler/src/msm/native.rs
index cecc58fc2..62f9374e2 100644
--- a/provekit/r1cs-compiler/src/msm/native.rs
+++ b/provekit/r1cs-compiler/src/msm/native.rs
@@ -34,8 +34,12 @@
 use {
     super::{
-        curve, ec_points, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint,
-        negate_y_signed_native, sanitize_point_scalar, scalar_relation,
+        curve, ec_points,
+        sanitize::{
+            emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, negate_y_signed_native,
+            sanitize_point_scalar,
+        },
+        scalar_relation,
     },
     crate::{
         constraint_helpers::{
@@ -153,29 +157,33 @@ pub(super) fn process_multi_point_native(
         accum_inputs.push((sanitized_rx, sanitized_ry, san.is_skip));
     }
 
-    // Phase 2: Merged scalar mul verification (shared doubling)
+    // Phase 2: Per-point
scalar mul verification + // + // Each point gets its own accumulator and identity check. This ensures + // per-point soundness: b_i * (R_i - scalar_i * P_i) = O with b_i ≠ 0 + // implies R_i = scalar_i * P_i for each point independently. let half_bits = curve.glv_half_bits() as usize; let offset_x_fe = curve::curve_native_point_fe(&curve.offset_point.0); let offset_y_fe = curve::curve_native_point_fe(&curve.offset_point.1); - let offset_x = add_constant_witness(compiler, offset_x_fe); - let offset_y = add_constant_witness(compiler, offset_y_fe); + let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(half_bits); + let acc_off_x_fe = curve::curve_native_point_fe(&acc_off_x_raw); + let acc_off_y_fe = curve::curve_native_point_fe(&acc_off_y_raw); - let (ver_acc_x, ver_acc_y) = - scalar_mul_merged_native_wnaf(compiler, &native_points, offset_x, offset_y, curve); + for pt in &native_points { + let offset_x = add_constant_witness(compiler, offset_x_fe); + let offset_y = add_constant_witness(compiler, offset_y_fe); - // Identity check: acc should equal accumulated offset (hardcoded into - // constraint matrix — not a witness the prover can manipulate) - let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(half_bits); - constrain_to_constant( - compiler, - ver_acc_x, - curve::curve_native_point_fe(&acc_off_x_raw), - ); - constrain_to_constant( - compiler, - ver_acc_y, - curve::curve_native_point_fe(&acc_off_y_raw), - ); + let (ver_acc_x, ver_acc_y) = scalar_mul_merged_native_wnaf( + compiler, + std::slice::from_ref(pt), + offset_x, + offset_y, + curve, + ); + + constrain_to_constant(compiler, ver_acc_x, acc_off_x_fe); + constrain_to_constant(compiler, ver_acc_y, acc_off_y_fe); + } // Phase 3: Per-point scalar relations for &(s_lo, s_hi, s1, s2, neg1, neg2) in &scalar_rel_inputs { @@ -235,15 +243,12 @@ pub(super) fn process_multi_point_native( constrain_equal(compiler, out_inf, all_skipped); } -/// Merged multi-point scalar multiplication for native field 
using -/// signed-bit wNAF (w=1) with shared doubling across all points. -/// -/// Instead of running separate 128-iteration loops per point (each with -/// its own doubling), this merges all points into a single loop with one -/// shared doubling per bit. Each bit costs: -/// 4C (shared double) + n_points × 8C (2×(1C select + 3C add)) +/// Multi-point scalar multiplication for native field using signed-bit wNAF +/// (w=1). /// -/// Savings: 4C × (n_points - 1) per bit ≈ 512C for 2 points on Grumpkin. +/// Called once per point with a single-element slice for per-point +/// soundness (each point gets its own accumulator and identity check). +/// Each bit costs: 4C (double) + 2 × (1C select + 3C add) = 12C per point. fn scalar_mul_merged_native_wnaf( compiler: &mut NoirToR1CSCompiler, points: &[NativePointData], diff --git a/provekit/r1cs-compiler/src/msm/non_native.rs b/provekit/r1cs-compiler/src/msm/non_native.rs index 270dd14d8..e72e232cc 100644 --- a/provekit/r1cs-compiler/src/msm/non_native.rs +++ b/provekit/r1cs-compiler/src/msm/non_native.rs @@ -31,14 +31,25 @@ //! 3. **Scalar relations**: per-point verification that (-1)^neg1·|s1| + //! (-1)^neg2·|s2|·s ≡ 0 (mod curve_order). //! 4. **Accumulation**: adds each point's scalar-mul result (via dispatch to -//! hint-verified or generic add), subtracts offset, recomposes limbs to -//! native field elements, constrains outputs. +//! hint-verified or generic add), subtracts offset, constrains outputs. +//! +//! ## I/O modes +//! +//! Two entry points share a single core implementation via `NonNativeIo`: +//! - `process_multi_point_non_native`: point coordinates as single field +//! elements (decomposed to limbs internally); outputs as single witnesses. +//! - `process_multi_point_non_native_limbed`: point coordinates pre-decomposed +//! as limbs; outputs constrained per-limb. 
use { super::{ - curve, ec_points, emit_ec_scalar_mul_hint_and_sanitize, emit_fakeglv_hint, + curve, ec_points, multi_limb_ops::{MultiLimbOps, MultiLimbParams}, - sanitize_point_scalar, scalar_relation, Limbs, + sanitize::{ + emit_ec_scalar_mul_hint_and_sanitize_multi_limb, emit_fakeglv_hint, + sanitize_point_scalar_multi_limb, + }, + scalar_relation, Limbs, MsmLimbedOutputs, }, crate::{ constraint_helpers::{ @@ -53,12 +64,31 @@ use { std::collections::BTreeMap, }; -/// Multi-point non-native MSM with merged-loop optimization. +// --------------------------------------------------------------------------- +// IO mode: single field-element outputs vs. per-limb outputs +// --------------------------------------------------------------------------- + +/// Distinguishes the two non-native MSM I/O modes. +/// +/// - `SingleFe`: point coordinates come as single field elements (decomposed to +/// limbs internally); outputs are single witnesses `(out_x, out_y, out_inf)`. +/// - `Limbed`: point coordinates arrive pre-decomposed as limbs (stride +/// `2*num_limbs + 1` per point); outputs are per-limb witnesses. +enum NonNativeIo<'a> { + SingleFe { outputs: (usize, usize, usize) }, + Limbed { outputs: &'a MsmLimbedOutputs }, +} + +// --------------------------------------------------------------------------- +// Public entry points (thin wrappers around the shared core) +// --------------------------------------------------------------------------- + +/// Multi-point non-native MSM with single field-element I/O. /// /// All points share a single set of doublings per window, saving /// `w × (n_points - 1)` doublings per window compared to separate loops. 
pub(super) fn process_multi_point_non_native<'a>( - mut compiler: &'a mut NoirToR1CSCompiler, + compiler: &'a mut NoirToR1CSCompiler, point_wits: &[usize], scalar_wits: &[usize], outputs: (usize, usize, usize), @@ -66,19 +96,94 @@ pub(super) fn process_multi_point_non_native<'a>( num_limbs: usize, limb_bits: u32, window_size: usize, + range_checks: &'a mut BTreeMap>, + curve: &CurveParams, +) { + process_non_native_core( + compiler, + point_wits, + scalar_wits, + NonNativeIo::SingleFe { outputs }, + n_points, + num_limbs, + limb_bits, + window_size, + range_checks, + curve, + ); +} + +/// Multi-point non-native MSM with per-limb I/O. +/// +/// Point coordinates are provided as limbs (stride `2*num_limbs + 1` per +/// point), avoiding the single-field-element bottleneck. Output coordinates +/// are constrained per-limb rather than recomposed. +pub(super) fn process_multi_point_non_native_limbed<'a>( + compiler: &'a mut NoirToR1CSCompiler, + point_wits: &[usize], + scalar_wits: &[usize], + outputs: &MsmLimbedOutputs, + n_points: usize, + num_limbs: usize, + limb_bits: u32, + window_size: usize, + range_checks: &'a mut BTreeMap>, + curve: &CurveParams, +) { + process_non_native_core( + compiler, + point_wits, + scalar_wits, + NonNativeIo::Limbed { outputs }, + n_points, + num_limbs, + limb_bits, + window_size, + range_checks, + curve, + ); +} + +// --------------------------------------------------------------------------- +// Core implementation +// --------------------------------------------------------------------------- + +/// Unified non-native MSM implementation. +/// +/// The only differences between single-FE and limbed I/O modes are: +/// 1. **Phase 1**: how point coordinates are extracted from `point_wits` +/// (decompose from single FE vs. read pre-decomposed limbs). +/// 2. **Phase 4**: how output coordinates are constrained (recompose to single +/// FE vs. constrain per-limb). +/// +/// Phases 2 (merged scalar mul) and 3 (scalar relations) are identical. 
+fn process_non_native_core<'a>( + mut compiler: &'a mut NoirToR1CSCompiler, + point_wits: &[usize], + scalar_wits: &[usize], + io: NonNativeIo<'_>, + n_points: usize, + num_limbs: usize, + limb_bits: u32, + window_size: usize, mut range_checks: &'a mut BTreeMap>, curve: &CurveParams, ) { - let (out_x, out_y, out_inf) = outputs; let one = compiler.witness_one(); - - // Generator constants for sanitization - let gen_x_fe = curve::curve_native_point_fe(&curve.generator.0); - let gen_y_fe = curve::curve_native_point_fe(&curve.generator.1); - let gen_x_witness = add_constant_witness(compiler, gen_x_fe); - let gen_y_witness = add_constant_witness(compiler, gen_y_fe); let zero_witness = add_constant_witness(compiler, FieldElement::ZERO); + // Generator as limbs — avoids truncation when generator coords > BN254_Fr + let gen_x_fe_limbs = decompose_to_limbs_pub(&curve.generator.0, limb_bits, num_limbs); + let gen_y_fe_limbs = decompose_to_limbs_pub(&curve.generator.1, limb_bits, num_limbs); + let gen_x_limb_wits: Vec = gen_x_fe_limbs + .iter() + .map(|&v| add_constant_witness(compiler, v)) + .collect(); + let gen_y_limb_wits: Vec = gen_y_fe_limbs + .iter() + .map(|&v| add_constant_witness(compiler, v)) + .collect(); + // Build params once for all multi-limb ops let params = MultiLimbParams::for_field_modulus(num_limbs, limb_bits, curve); let b_limb_values = curve::decompose_to_limbs(&curve.curve_b, limb_bits, num_limbs); @@ -95,15 +200,44 @@ pub(super) fn process_multi_point_non_native<'a>( // Phase 1: Per-point preprocessing for i in 0..n_points { - let san = sanitize_point_scalar( + // Extract point limbs and inf flag — differs by IO mode + let (px_limbs, py_limbs, inf_flag) = match &io { + NonNativeIo::SingleFe { .. } => { + let (px, py) = decompose_point_to_limbs( + compiler, + point_wits[3 * i], + point_wits[3 * i + 1], + num_limbs, + limb_bits, + range_checks, + ); + (px, py, point_wits[3 * i + 2]) + } + NonNativeIo::Limbed { .. 
} => { + let stride = 2 * num_limbs + 1; + let base = i * stride; + let mut px = Limbs::new(num_limbs); + let mut py = Limbs::new(num_limbs); + for j in 0..num_limbs { + px[j] = point_wits[base + j]; + py[j] = point_wits[base + num_limbs + j]; + range_checks.entry(limb_bits).or_default().push(px[j]); + range_checks.entry(limb_bits).or_default().push(py[j]); + } + (px, py, point_wits[base + 2 * num_limbs]) + } + }; + + // Sanitize at the limb level — per-limb select between input and generator + let san = sanitize_point_scalar_multi_limb( compiler, - point_wits[3 * i], - point_wits[3 * i + 1], + px_limbs, + py_limbs, scalar_wits[2 * i], scalar_wits[2 * i + 1], - point_wits[3 * i + 2], - gen_x_witness, - gen_y_witness, + inf_flag, + &gen_x_limb_wits, + &gen_y_limb_wits, zero_witness, one, ); @@ -114,26 +248,22 @@ pub(super) fn process_multi_point_non_native<'a>( Some(prev) => compiler.add_product(prev, san.is_skip), }); - let (sanitized_rx, sanitized_ry) = emit_ec_scalar_mul_hint_and_sanitize( + // EcScalarMulHint with multi-limb inputs/outputs + let (rx, ry) = emit_ec_scalar_mul_hint_and_sanitize_multi_limb( compiler, &san, - gen_x_witness, - gen_y_witness, - curve, - ); - - // Decompose points to limbs - let (px, py) = - decompose_point_to_limbs(compiler, san.px, san.py, num_limbs, limb_bits, range_checks); - let (rx, ry) = decompose_point_to_limbs( - compiler, - sanitized_rx, - sanitized_ry, + &gen_x_limb_wits, + &gen_y_limb_wits, num_limbs, limb_bits, range_checks, + curve, ); + // Sanitized px/py are already in Limbs form + let px = san.px_limbs; + let py = san.py_limbs; + // On-curve checks: use hint-verified for multi-limb, generic for single-limb if num_limbs >= 2 { ec_points::verify_on_curve_non_native(compiler, range_checks, px, py, ¶ms); @@ -203,36 +333,43 @@ pub(super) fn process_multi_point_non_native<'a>( accum_inputs.push((rx, ry, san.is_skip)); } - // Phase 2: Merged scalar mul verification (shared doublings across all points) + // Phase 2: Per-point 
scalar mul verification + // + // Each point gets its own accumulator and identity check. This ensures + // per-point soundness: b_i * (R_i - scalar_i * P_i) = O with b_i ≠ 0 + // implies R_i = scalar_i * P_i for each point independently. let half_bits = curve.glv_half_bits() as usize; - let glv_acc; { let mut ops = MultiLimbOps { compiler, range_checks, params: ¶ms, }; - let offset_x = ops.constant_limbs(&offset_x_values); - let offset_y = ops.constant_limbs(&offset_y_values); - - glv_acc = ec_points::scalar_mul_merged_glv( - &mut ops, - &merged_points, - window_size, - offset_x, - offset_y, - ); - // Identity check: acc should equal accumulated offset + // Precompute the expected accumulated offset (same for all points) let glv_num_windows = (half_bits + window_size - 1) / window_size; let glv_n_doublings = glv_num_windows * window_size; let (acc_off_x_raw, acc_off_y_raw) = curve.accumulated_offset(glv_n_doublings); - let acc_off_x_values = decompose_to_limbs_pub(&acc_off_x_raw, limb_bits, num_limbs); let acc_off_y_values = decompose_to_limbs_pub(&acc_off_y_raw, limb_bits, num_limbs); - for i in 0..num_limbs { - constrain_to_constant(ops.compiler, glv_acc.0[i], acc_off_x_values[i]); - constrain_to_constant(ops.compiler, glv_acc.1[i], acc_off_y_values[i]); + + for pt in &merged_points { + let offset_x = ops.constant_limbs(&offset_x_values); + let offset_y = ops.constant_limbs(&offset_y_values); + + let glv_acc = ec_points::scalar_mul_merged_glv( + &mut ops, + std::slice::from_ref(pt), + window_size, + offset_x, + offset_y, + ); + + // Per-point identity check + for j in 0..num_limbs { + constrain_to_constant(ops.compiler, glv_acc.0[j], acc_off_x_values[j]); + constrain_to_constant(ops.compiler, glv_acc.1[j], acc_off_y_values[j]); + } } compiler = ops.compiler; @@ -254,7 +391,7 @@ pub(super) fn process_multi_point_non_native<'a>( ); } - // Phase 4: Accumulation (offset-based, same as before) + // Phase 4: Accumulation + output constraining let all_skipped = 
all_skipped.expect("MSM must have at least one point"); let mut ops = MultiLimbOps { @@ -296,22 +433,47 @@ pub(super) fn process_multi_point_non_native<'a>( let (result_x, result_y) = ec_points::point_add_dispatch(&mut ops, acc_x, acc_y, sub_x, sub_y); compiler = ops.compiler; - if num_limbs == 1 { - let masked_result_x = select_witness(compiler, all_skipped, result_x[0], zero_witness); - let masked_result_y = select_witness(compiler, all_skipped, result_y[0], zero_witness); - constrain_equal(compiler, out_x, masked_result_x); - constrain_equal(compiler, out_y, masked_result_y); - } else { - let recomposed_x = recompose_limbs(compiler, result_x.as_slice(), limb_bits); - let recomposed_y = recompose_limbs(compiler, result_y.as_slice(), limb_bits); - let masked_result_x = select_witness(compiler, all_skipped, recomposed_x, zero_witness); - let masked_result_y = select_witness(compiler, all_skipped, recomposed_y, zero_witness); - constrain_equal(compiler, out_x, masked_result_x); - constrain_equal(compiler, out_y, masked_result_y); + // Output constraining — differs by IO mode + match &io { + NonNativeIo::SingleFe { + outputs: (out_x, out_y, out_inf), + } => { + if num_limbs == 1 { + let masked_x = select_witness(compiler, all_skipped, result_x[0], zero_witness); + let masked_y = select_witness(compiler, all_skipped, result_y[0], zero_witness); + constrain_equal(compiler, *out_x, masked_x); + constrain_equal(compiler, *out_y, masked_y); + } else { + let recomposed_x = recompose_limbs(compiler, result_x.as_slice(), limb_bits); + let recomposed_y = recompose_limbs(compiler, result_y.as_slice(), limb_bits); + let masked_x = select_witness(compiler, all_skipped, recomposed_x, zero_witness); + let masked_y = select_witness(compiler, all_skipped, recomposed_y, zero_witness); + constrain_equal(compiler, *out_x, masked_x); + constrain_equal(compiler, *out_y, masked_y); + } + constrain_equal(compiler, *out_inf, all_skipped); + } + NonNativeIo::Limbed { outputs } => { + let 
zero_limb_wits: Vec = (0..num_limbs) + .map(|_| add_constant_witness(compiler, FieldElement::ZERO)) + .collect(); + for j in 0..num_limbs { + let masked_x = + select_witness(compiler, all_skipped, result_x[j], zero_limb_wits[j]); + let masked_y = + select_witness(compiler, all_skipped, result_y[j], zero_limb_wits[j]); + constrain_equal(compiler, outputs.out_x_limbs[j], masked_x); + constrain_equal(compiler, outputs.out_y_limbs[j], masked_y); + } + constrain_equal(compiler, outputs.out_inf, all_skipped); + } } - constrain_equal(compiler, out_inf, all_skipped); } +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + /// On-curve check: verifies y^2 = x^3 + a*x + b for a single point. fn verify_on_curve( ops: &mut MultiLimbOps, @@ -388,6 +550,3 @@ fn recompose_limbs(compiler: &mut NoirToR1CSCompiler, limbs: &[usize], limb_bits .collect(); compiler.add_sum(terms) } - -// `decompose_half_scalar_bits` replaced by `super::decompose_signed_bits` -// which produces signed digits with skew correction, halving lookup tables. diff --git a/provekit/r1cs-compiler/src/msm/sanitize.rs b/provekit/r1cs-compiler/src/msm/sanitize.rs index 2ec8d4155..10ece47fd 100644 --- a/provekit/r1cs-compiler/src/msm/sanitize.rs +++ b/provekit/r1cs-compiler/src/msm/sanitize.rs @@ -1,14 +1,14 @@ -//! Degenerate-case detection, sanitization, and bit decomposition helpers -//! for MSM point-scalar pairs. +//! Degenerate-case detection and sanitization helpers for MSM point-scalar +//! pairs. 
use { - super::curve::CurveParams, + super::{curve::CurveParams, Limbs}, crate::{ constraint_helpers::{compute_boolean_or, constrain_boolean, select_witness}, msm::multi_limb_arith::compute_is_zero, noir_to_r1cs::NoirToR1CSCompiler, }, - ark_ff::{AdditiveGroup, Field}, + ark_ff::Field, provekit_common::{ witness::{SumTerm, WitnessBuilder}, FieldElement, @@ -82,6 +82,7 @@ pub(super) fn negate_y_signed_native( /// Emit an `EcScalarMulHint` and sanitize the result point. /// When `is_skip=1`, the result is swapped to the generator point. +/// Used by the native path where coordinates fit in a single field element. pub(super) fn emit_ec_scalar_mul_hint_and_sanitize( compiler: &mut NoirToR1CSCompiler, san: &SanitizedInputs, @@ -92,12 +93,14 @@ pub(super) fn emit_ec_scalar_mul_hint_and_sanitize( let hint_start = compiler.num_witnesses(); compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint { output_start: hint_start, - px: san.px, - py: san.py, + px_limbs: vec![san.px], + py_limbs: vec![san.py], s_lo: san.s_lo, s_hi: san.s_hi, curve_a: curve.curve_a, field_modulus_p: curve.field_modulus_p, + num_limbs: 1, + limb_bits: 0, }); let rx = select_witness(compiler, san.is_skip, hint_start, gen_x_witness); let ry = select_witness(compiler, san.is_skip, hint_start + 1, gen_y_witness); @@ -121,61 +124,90 @@ pub(super) fn emit_fakeglv_hint( (glv_start, glv_start + 1, glv_start + 2, glv_start + 3) } -/// Signed-bit decomposition for wNAF scalar multiplication. -/// -/// Decomposes `scalar` into `num_bits` sign-bits b_i ∈ {0,1} and a skew ∈ {0,1} -/// such that the signed digits d_i = 2*b_i - 1 ∈ {-1, +1} satisfy: -/// scalar = Σ d_i * 2^i - skew -/// -/// Reconstruction constraint (1 linear R1CS): -/// scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1} +/// Multi-limb sanitized inputs for non-native MSM. 
+pub(super) struct SanitizedInputsMultiLimb { + pub px_limbs: Limbs, + pub py_limbs: Limbs, + pub s_lo: usize, + pub s_hi: usize, + pub is_skip: usize, +} + +/// Sanitize a non-native point-scalar pair at the limb level. /// -/// All bits and skew are boolean-constrained. +/// Detects degenerate cases and replaces the point with the generator +/// (as limbs) and scalar with 1 when degenerate. This avoids truncating +/// generator coordinates that exceed BN254_Fr. +pub(super) fn sanitize_point_scalar_multi_limb( + compiler: &mut NoirToR1CSCompiler, + px_limbs: Limbs, + py_limbs: Limbs, + s_lo: usize, + s_hi: usize, + inf_flag: usize, + gen_x_limb_wits: &[usize], + gen_y_limb_wits: &[usize], + zero: usize, + one: usize, +) -> SanitizedInputsMultiLimb { + let n = px_limbs.len(); + let is_skip = detect_skip(compiler, s_lo, s_hi, inf_flag); + + let mut san_px = Limbs::new(n); + let mut san_py = Limbs::new(n); + for i in 0..n { + san_px[i] = select_witness(compiler, is_skip, px_limbs[i], gen_x_limb_wits[i]); + san_py[i] = select_witness(compiler, is_skip, py_limbs[i], gen_y_limb_wits[i]); + } + + SanitizedInputsMultiLimb { + px_limbs: san_px, + py_limbs: san_py, + s_lo: select_witness(compiler, is_skip, s_lo, one), + s_hi: select_witness(compiler, is_skip, s_hi, zero), + is_skip, + } +} + +/// Emit an `EcScalarMulHint` with multi-limb inputs/outputs and sanitize. /// -/// # Limitation -/// The prover's `SignedBitHint` solver reads the scalar as a `u128` (lower -/// 128 bits of the field element). This is correct for FakeGLV half-scalars -/// (≤128 bits for 256-bit curves) but would silently truncate if `num_bits` -/// exceeds 128. The R1CS reconstruction constraint would then fail. -pub(super) fn decompose_signed_bits( +/// When `is_skip=1`, each output limb is replaced with the corresponding +/// generator limb. Returns `(rx_limbs, ry_limbs)` as `Limbs`. 
+pub(super) fn emit_ec_scalar_mul_hint_and_sanitize_multi_limb(
 compiler: &mut NoirToR1CSCompiler,
-    scalar: usize,
-    num_bits: usize,
-) -> (Vec<usize>, usize) {
-    let start = compiler.num_witnesses();
-    compiler.add_witness_builder(WitnessBuilder::SignedBitHint {
-        output_start: start,
-        scalar,
-        num_bits,
+    san: &SanitizedInputsMultiLimb,
+    gen_x_limb_wits: &[usize],
+    gen_y_limb_wits: &[usize],
+    num_limbs: usize,
+    limb_bits: u32,
+    range_checks: &mut std::collections::BTreeMap<u32, Vec<usize>>,
+    curve: &CurveParams,
+) -> (Limbs, Limbs) {
+    let hint_start = compiler.num_witnesses();
+    compiler.add_witness_builder(WitnessBuilder::EcScalarMulHint {
+        output_start: hint_start,
+        px_limbs: san.px_limbs.as_slice().to_vec(),
+        py_limbs: san.py_limbs.as_slice().to_vec(),
+        s_lo: san.s_lo,
+        s_hi: san.s_hi,
+        curve_a: curve.curve_a,
+        field_modulus_p: curve.field_modulus_p,
+        num_limbs: num_limbs as u32,
+        limb_bits,
     });
-    let bits: Vec<usize> = (start..start + num_bits).collect();
-    let skew = start + num_bits;
-    // Boolean-constrain each bit and skew
-    for &b in &bits {
-        constrain_boolean(compiler, b);
+    let mut rx = Limbs::new(num_limbs);
+    let mut ry = Limbs::new(num_limbs);
+    for i in 0..num_limbs {
+        let rx_hint = hint_start + i;
+        let ry_hint = hint_start + num_limbs + i;
+        // Range-check hint output limbs
+        range_checks.entry(limb_bits).or_default().push(rx_hint);
+        range_checks.entry(limb_bits).or_default().push(ry_hint);
+        // Sanitize: select between hint output and generator
+        rx[i] = select_witness(compiler, san.is_skip, rx_hint, gen_x_limb_wits[i]);
+        ry[i] = select_witness(compiler, san.is_skip, ry_hint, gen_y_limb_wits[i]);
     }
-    constrain_boolean(compiler, skew);
-
-    // Reconstruction: scalar + skew + (2^n - 1) = Σ b_i * 2^{i+1}
-    // Rearranged as: scalar + skew + (2^n - 1) - Σ b_i * 2^{i+1} = 0
-    let one = compiler.witness_one();
-    let two = FieldElement::from(2u64);
-    let constant = two.pow([num_bits as u64]) - FieldElement::ONE;
-    let mut b_terms: Vec<(FieldElement, usize)> = 
bits - .iter() - .enumerate() - .map(|(i, &b)| (-two.pow([(i + 1) as u64]), b)) - .collect(); - b_terms.push((FieldElement::ONE, scalar)); - b_terms.push((FieldElement::ONE, skew)); - b_terms.push((constant, one)); - compiler - .r1cs - .add_constraint(&[(FieldElement::ONE, one)], &b_terms, &[( - FieldElement::ZERO, - one, - )]); - (bits, skew) + (rx, ry) } diff --git a/provekit/r1cs-compiler/src/msm/scalar_relation.rs b/provekit/r1cs-compiler/src/msm/scalar_relation.rs index 7cca0221c..0b6c1e29b 100644 --- a/provekit/r1cs-compiler/src/msm/scalar_relation.rs +++ b/provekit/r1cs-compiler/src/msm/scalar_relation.rs @@ -6,6 +6,7 @@ use { super::{ cost_model, curve, + multi_limb_arith::compute_is_zero, multi_limb_ops::{MultiLimbOps, MultiLimbParams}, Limbs, }, @@ -89,6 +90,12 @@ pub(super) fn verify_scalar_relation( for i in 0..num_limbs { constrain_zero(ops.compiler, effective[i]); } + + // Soundness: s2 must be non-zero. If s2=0 the relation degenerates to + // s1≡0 (mod n) which is trivially satisfiable with s1=0, leaving the + // hint-supplied result point R unconstrained. + let s2_is_zero = compute_is_zero(ops.compiler, s2_witness); + constrain_zero(ops.compiler, s2_is_zero); } /// Decompose a 256-bit scalar from two 128-bit halves into `num_limbs` limbs. diff --git a/provekit/r1cs-compiler/src/noir_to_r1cs.rs b/provekit/r1cs-compiler/src/noir_to_r1cs.rs index 144730437..961edb248 100644 --- a/provekit/r1cs-compiler/src/noir_to_r1cs.rs +++ b/provekit/r1cs-compiler/src/noir_to_r1cs.rs @@ -98,8 +98,8 @@ pub struct R1CSBreakdown { /// Compiles an ACIR circuit into an [R1CS] instance, comprising of the A, B, /// and C R1CS matrices, along with the witness vector. 
-pub(crate) struct NoirToR1CSCompiler { - pub(crate) r1cs: R1CS, +pub struct NoirToR1CSCompiler { + pub r1cs: R1CS, /// Indicates how to solve for each R1CS witness pub witness_builders: Vec, @@ -142,7 +142,7 @@ pub fn noir_to_r1cs_with_breakdown( } impl NoirToR1CSCompiler { - pub(crate) fn new() -> Self { + pub fn new() -> Self { let mut r1cs = R1CS::new(); // Grow the matrices to account for the constant one witness. r1cs.add_witnesses(1); diff --git a/provekit/r1cs-compiler/src/range_check.rs b/provekit/r1cs-compiler/src/range_check.rs index 936a33240..763576ac4 100644 --- a/provekit/r1cs-compiler/src/range_check.rs +++ b/provekit/r1cs-compiler/src/range_check.rs @@ -178,7 +178,7 @@ pub(crate) fn estimate_range_check_cost(checks: &BTreeMap) -> usize /// Uses dynamic base width optimization: all range check requests are /// collected, and the optimal decomposition base width is determined by /// minimizing the total witness count (memory cost). The search evaluates -/// every base width from [MIN_BASE_WIDTH] to [MAX_BASE_WIDTH]. For each +/// every base width from \[MIN_BASE_WIDTH\] to \[MAX_BASE_WIDTH\]. For each /// candidate, the cost model picks the cheaper of LogUp and naive for /// every atomic bucket. /// @@ -189,7 +189,7 @@ pub(crate) fn estimate_range_check_cost(checks: &BTreeMap) -> usize /// /// `range_checks` is a map from the number of bits k to the vector of /// witness indices that are to be constrained within the range [0..2^k]. 
-pub(crate) fn add_range_checks( +pub fn add_range_checks( r1cs: &mut NoirToR1CSCompiler, range_checks: BTreeMap>, ) -> Option { diff --git a/tooling/provekit-bench/Cargo.toml b/tooling/provekit-bench/Cargo.toml index 3b1993b2a..8791ea151 100644 --- a/tooling/provekit-bench/Cargo.toml +++ b/tooling/provekit-bench/Cargo.toml @@ -16,6 +16,7 @@ provekit-r1cs-compiler.workspace = true provekit-verifier.workspace = true # Noir language +acir.workspace = true nargo.workspace = true nargo_cli.workspace = true nargo_toml.workspace = true diff --git a/tooling/provekit-bench/tests/msm_witness_solving.rs b/tooling/provekit-bench/tests/msm_witness_solving.rs new file mode 100644 index 000000000..f8787c0f5 --- /dev/null +++ b/tooling/provekit-bench/tests/msm_witness_solving.rs @@ -0,0 +1,512 @@ +//! End-to-end MSM witness solving tests for non-native curves (secp256r1). +//! +//! These tests verify that the full pipeline works correctly: +//! 1. Compile MSM circuit (R1CS + witness builders) +//! 2. Set initial witness values (point coordinates as limbs + scalar) +//! 3. Solve all derived witnesses via the witness builder layer scheduler +//! 4. Check R1CS satisfaction: A·w ⊙ B·w = C·w for all constraints +//! +//! All tests use the **limbed API** (`add_msm_with_curve_limbed`) where +//! point coordinates are multi-limb witnesses, supporting arbitrary +//! secp256r1 coordinates (including those exceeding BN254 Fr). 
+ +use { + acir::native_types::WitnessMap, + ark_ff::{PrimeField, Zero}, + provekit_common::{ + witness::{ConstantOrR1CSWitness, LayerScheduler, WitnessBuilder}, + FieldElement, NoirElement, TranscriptSponge, + }, + provekit_prover::{bigint_mod::ec_scalar_mul, r1cs::solve_witness_vec}, + provekit_r1cs_compiler::{ + msm::{ + add_msm_with_curve_limbed, + cost_model::get_optimal_msm_params, + curve::{decompose_to_limbs, secp256r1_params}, + MsmLimbedOutputs, + }, + noir_to_r1cs::NoirToR1CSCompiler, + range_check::add_range_checks, + }, + std::collections::BTreeMap, + whir::transcript::{codecs::Empty, DomainSeparator, ProverState}, +}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Convert a [u64; 4] to a FieldElement. Panics if value exceeds BN254 Fr. +/// Only used for scalars (128-bit halves that always fit). +fn u256_to_fe(v: &[u64; 4]) -> FieldElement { + FieldElement::from_bigint(ark_ff::BigInt(*v)) + .unwrap_or_else(|| panic!("Value exceeds BN254 Fr: {v:?}")) +} + +/// Split a 256-bit scalar into (lo_128, hi_128) as [u64; 4] values. +fn split_scalar(s: &[u64; 4]) -> ([u64; 4], [u64; 4]) { + let lo = [s[0], s[1], 0, 0]; + let hi = [s[2], s[3], 0, 0]; + (lo, hi) +} + +/// Verify R1CS satisfaction: for each constraint row, A·w * B·w == C·w. 
+fn check_r1cs_satisfaction(
+    r1cs: &provekit_common::R1CS,
+    witness: &[FieldElement],
+) -> anyhow::Result<()> {
+    use anyhow::ensure;
+
+    ensure!(
+        witness.len() == r1cs.num_witnesses(),
+        "witness size {} != expected {}",
+        witness.len(),
+        r1cs.num_witnesses()
+    );
+
+    let a = r1cs.a() * witness;
+    let b = r1cs.b() * witness;
+    let c = r1cs.c() * witness;
+    for (row, ((a_val, b_val), c_val)) in a.into_iter().zip(b).zip(c).enumerate() {
+        ensure!(
+            a_val * b_val == c_val,
+            "Constraint {row} failed: a={a_val:?}, b={b_val:?}, a*b={:?}, c={c_val:?}",
+            a_val * b_val
+        );
+    }
+    Ok(())
+}
+
+/// Create a dummy transcript for witness solving (no challenges needed).
+fn dummy_transcript() -> ProverState {
+    let ds = DomainSeparator::protocol(&()).instance(&Empty);
+    ProverState::new(&ds, TranscriptSponge::default())
+}
+
+/// Solve all witness builders given initial witness values.
+fn solve_witnesses(
+    builders: &[WitnessBuilder],
+    num_witnesses: usize,
+    initial_values: &[(usize, FieldElement)],
+) -> Vec<FieldElement> {
+    let layers = LayerScheduler::new(builders).build_layers();
+    let mut witness: Vec<Option<FieldElement>> = vec![None; num_witnesses];
+
+    for &(idx, val) in initial_values {
+        witness[idx] = Some(val);
+    }
+
+    let acir_map = WitnessMap::<NoirElement>::new();
+    let mut transcript = dummy_transcript();
+    solve_witness_vec(&mut witness, layers, &acir_map, &mut transcript);
+
+    witness
+        .into_iter()
+        .enumerate()
+        .map(|(i, w)| w.unwrap_or_else(|| panic!("Witness {i} was not solved")))
+        .collect()
+}
+
+/// Compute the (num_limbs, limb_bits) that the compiler will use for this
+/// curve, so the test can decompose coordinates the same way. 
+fn msm_params_for_curve(
+    curve: &provekit_r1cs_compiler::msm::curve::CurveParams,
+    n_points: usize,
+) -> (usize, u32) {
+    let native_bits = FieldElement::MODULUS_BIT_SIZE;
+    let curve_bits = curve.modulus_bits();
+    let (limb_bits, _window_size) =
+        get_optimal_msm_params(native_bits, curve_bits, n_points, 256, false);
+    let num_limbs = (curve_bits as usize + limb_bits as usize - 1) / limb_bits as usize;
+    (num_limbs, limb_bits)
+}
+
+/// Decompose a [u64; 4] value into field-element limbs.
+fn u256_to_limb_fes(v: &[u64; 4], limb_bits: u32, num_limbs: usize) -> Vec<FieldElement> {
+    decompose_to_limbs(v, limb_bits, num_limbs)
+}
+
+// ---------------------------------------------------------------------------
+// Single-point limbed MSM test runner
+// ---------------------------------------------------------------------------
+
+/// Compile and solve a single-point MSM circuit using the limbed API.
+/// 
+/// When `expected_inf` is true, the expected output is point at infinity
+/// (all output limbs zero, out_inf = 1). 
+fn run_single_point_msm_test_limbed(
+    px: &[u64; 4],
+    py: &[u64; 4],
+    inf: bool,
+    scalar: &[u64; 4],
+    expected_x: &[u64; 4],
+    expected_y: &[u64; 4],
+    expected_inf: bool,
+) {
+    let curve = secp256r1_params();
+    let (num_limbs, limb_bits) = msm_params_for_curve(&curve, 1);
+    let (s_lo, s_hi) = split_scalar(scalar);
+    let stride = 2 * num_limbs + 1;
+
+    let px_fes = u256_to_limb_fes(px, limb_bits, num_limbs);
+    let py_fes = u256_to_limb_fes(py, limb_bits, num_limbs);
+    let ex_fes = u256_to_limb_fes(expected_x, limb_bits, num_limbs);
+    let ey_fes = u256_to_limb_fes(expected_y, limb_bits, num_limbs);
+
+    let mut compiler = NoirToR1CSCompiler::new();
+    let mut range_checks: BTreeMap<u32, Vec<usize>> = BTreeMap::new();
+
+    let base = compiler.num_witnesses();
+    let total_input_wits = stride + 2 + stride;
+    compiler.r1cs.add_witnesses(total_input_wits);
+
+    let points: Vec<ConstantOrR1CSWitness> = (0..stride)
+        .map(|j| ConstantOrR1CSWitness::Witness(base + j))
+        .collect();
+    let slo_w = base + stride;
+    let shi_w = base + stride + 1;
+    let scalars = vec![
+        ConstantOrR1CSWitness::Witness(slo_w),
+        ConstantOrR1CSWitness::Witness(shi_w),
+    ];
+
+    let out_base = base + stride + 2;
+    let out_x_limbs: Vec<usize> = (0..num_limbs).map(|j| out_base + j).collect();
+    let out_y_limbs: Vec<usize> = (0..num_limbs).map(|j| out_base + num_limbs + j).collect();
+    let out_inf = out_base + 2 * num_limbs;
+
+    let outputs = MsmLimbedOutputs {
+        out_x_limbs: out_x_limbs.clone(),
+        out_y_limbs: out_y_limbs.clone(),
+        out_inf,
+    };
+    let msm_ops = vec![(points, scalars, outputs)];
+    add_msm_with_curve_limbed(&mut compiler, msm_ops, &mut range_checks, &curve, num_limbs);
+    add_range_checks(&mut compiler, range_checks);
+
+    let num_witnesses = compiler.num_witnesses();
+
+    // Set initial witness values
+    let mut initial_values = vec![(0, FieldElement::from(1u64))];
+    for (j, fe) in px_fes.iter().enumerate() {
+        initial_values.push((base + j, *fe));
+    }
+    for (j, fe) in py_fes.iter().enumerate() {
+        initial_values.push((base + num_limbs + 
j, *fe)); + } + let inf_fe = if inf { + FieldElement::from(1u64) + } else { + FieldElement::zero() + }; + initial_values.push((base + 2 * num_limbs, inf_fe)); + initial_values.push((slo_w, u256_to_fe(&s_lo))); + initial_values.push((shi_w, u256_to_fe(&s_hi))); + for (j, fe) in ex_fes.iter().enumerate() { + initial_values.push((out_x_limbs[j], *fe)); + } + for (j, fe) in ey_fes.iter().enumerate() { + initial_values.push((out_y_limbs[j], *fe)); + } + let out_inf_fe = if expected_inf { + FieldElement::from(1u64) + } else { + FieldElement::zero() + }; + initial_values.push((out_inf, out_inf_fe)); + + let witness = solve_witnesses(&compiler.witness_builders, num_witnesses, &initial_values); + + check_r1cs_satisfaction(&compiler.r1cs, &witness) + .expect("R1CS satisfaction check failed (limbed)"); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +/// Single-point MSM using the secp256r1 generator directly. +/// The generator's x-coordinate exceeds BN254 Fr. +#[test] +fn test_single_point_generator() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let scalar: [u64; 4] = [7, 0, 0, 0]; + let (ex, ey) = ec_scalar_mul(&gx, &gy, &scalar, &curve.curve_a, &curve.field_modulus_p); + + run_single_point_msm_test_limbed(&gx, &gy, false, &scalar, &ex, &ey, false); +} + +/// Scalar = 1: result should equal the input point. +#[test] +fn test_scalar_one() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let scalar: [u64; 4] = [1, 0, 0, 0]; + + // 1·G = G + run_single_point_msm_test_limbed(&gx, &gy, false, &scalar, &gx, &gy, false); +} + +/// Large scalar spanning both lo and hi halves of the 256-bit representation. 
+#[test] +fn test_large_scalar() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let scalar: [u64; 4] = [0xcafebabe, 0x12345678, 0x42, 0]; + let (ex, ey) = ec_scalar_mul(&gx, &gy, &scalar, &curve.curve_a, &curve.field_modulus_p); + + run_single_point_msm_test_limbed(&gx, &gy, false, &scalar, &ex, &ey, false); +} + +/// Zero scalar: result should be point at infinity. +#[test] +fn test_zero_scalar() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let zero_scalar: [u64; 4] = [0, 0, 0, 0]; + let zero_point: [u64; 4] = [0, 0, 0, 0]; + + run_single_point_msm_test_limbed( + &gx, + &gy, + false, + &zero_scalar, + &zero_point, + &zero_point, + true, + ); +} + +/// Point at infinity as input: result should be point at infinity regardless +/// of scalar. +#[test] +fn test_point_at_infinity_input() { + let curve = secp256r1_params(); + // Use generator coords as placeholder (they're ignored due to inf=1 select) + let gx = curve.generator.0; + let gy = curve.generator.1; + let scalar: [u64; 4] = [42, 0, 0, 0]; + let zero_point: [u64; 4] = [0, 0, 0, 0]; + + run_single_point_msm_test_limbed(&gx, &gy, true, &scalar, &zero_point, &zero_point, true); +} + +/// Non-trivial point (2·G) with a moderate scalar, verifying the full +/// wNAF + FakeGLV pipeline. +#[test] +fn test_arbitrary_point_and_scalar() { + let curve = secp256r1_params(); + let gx = curve.generator.0; + let gy = curve.generator.1; + let a = &curve.curve_a; + let p = &curve.field_modulus_p; + + // P = 2·G + let (px, py) = ec_scalar_mul(&gx, &gy, &[2, 0, 0, 0], a, p); + let scalar: [u64; 4] = [17, 0, 0, 0]; + // Expected: 17·(2G) = 34G + let (ex, ey) = ec_scalar_mul(&gx, &gy, &[34, 0, 0, 0], a, p); + + run_single_point_msm_test_limbed(&px, &py, false, &scalar, &ex, &ey, false); +} + +/// Two-point MSM: s1·P1 + s2·P2 with arbitrary coordinates. 
+#[test]
+fn test_two_point_msm() {
+    let curve = secp256r1_params();
+    let gx = curve.generator.0;
+    let gy = curve.generator.1;
+    let a = &curve.curve_a;
+    let p = &curve.field_modulus_p;
+    let (num_limbs, limb_bits) = msm_params_for_curve(&curve, 2);
+    let stride = 2 * num_limbs + 1;
+
+    // P1 = 3·G, P2 = 5·G
+    let (p1x, p1y) = ec_scalar_mul(&gx, &gy, &[3, 0, 0, 0], a, p);
+    let (p2x, p2y) = ec_scalar_mul(&gx, &gy, &[5, 0, 0, 0], a, p);
+    let s1: [u64; 4] = [2, 0, 0, 0];
+    let s2: [u64; 4] = [3, 0, 0, 0];
+    // Expected: 2·(3G) + 3·(5G) = 6G + 15G = 21G
+    let (ex, ey) = ec_scalar_mul(&gx, &gy, &[21, 0, 0, 0], a, p);
+
+    let (s1_lo, s1_hi) = split_scalar(&s1);
+    let (s2_lo, s2_hi) = split_scalar(&s2);
+
+    let p1x_fes = u256_to_limb_fes(&p1x, limb_bits, num_limbs);
+    let p1y_fes = u256_to_limb_fes(&p1y, limb_bits, num_limbs);
+    let p2x_fes = u256_to_limb_fes(&p2x, limb_bits, num_limbs);
+    let p2y_fes = u256_to_limb_fes(&p2y, limb_bits, num_limbs);
+    let ex_fes = u256_to_limb_fes(&ex, limb_bits, num_limbs);
+    let ey_fes = u256_to_limb_fes(&ey, limb_bits, num_limbs);
+
+    let mut compiler = NoirToR1CSCompiler::new();
+    let mut range_checks: BTreeMap<u32, Vec<usize>> = BTreeMap::new();
+
+    let base = compiler.num_witnesses();
+    let total = 2 * stride + 4 + stride;
+    compiler.r1cs.add_witnesses(total);
+
+    let points: Vec<ConstantOrR1CSWitness> = (0..2 * stride)
+        .map(|j| ConstantOrR1CSWitness::Witness(base + j))
+        .collect();
+    let scalar_base = base + 2 * stride;
+    let scalars = vec![
+        ConstantOrR1CSWitness::Witness(scalar_base),
+        ConstantOrR1CSWitness::Witness(scalar_base + 1),
+        ConstantOrR1CSWitness::Witness(scalar_base + 2),
+        ConstantOrR1CSWitness::Witness(scalar_base + 3),
+    ];
+    let out_base = scalar_base + 4;
+    let out_x_limbs: Vec<usize> = (0..num_limbs).map(|j| out_base + j).collect();
+    let out_y_limbs: Vec<usize> = (0..num_limbs).map(|j| out_base + num_limbs + j).collect();
+    let out_inf = out_base + 2 * num_limbs;
+
+    let outputs = MsmLimbedOutputs {
+        out_x_limbs: out_x_limbs.clone(),
+        out_y_limbs: 
out_y_limbs.clone(), + out_inf, + }; + let msm_ops = vec![(points, scalars, outputs)]; + add_msm_with_curve_limbed(&mut compiler, msm_ops, &mut range_checks, &curve, num_limbs); + add_range_checks(&mut compiler, range_checks); + + let num_witnesses = compiler.num_witnesses(); + + let mut initial_values = vec![(0, FieldElement::from(1u64))]; + for (j, fe) in p1x_fes.iter().enumerate() { + initial_values.push((base + j, *fe)); + } + for (j, fe) in p1y_fes.iter().enumerate() { + initial_values.push((base + num_limbs + j, *fe)); + } + initial_values.push((base + 2 * num_limbs, FieldElement::zero())); + let p2_base = base + stride; + for (j, fe) in p2x_fes.iter().enumerate() { + initial_values.push((p2_base + j, *fe)); + } + for (j, fe) in p2y_fes.iter().enumerate() { + initial_values.push((p2_base + num_limbs + j, *fe)); + } + initial_values.push((p2_base + 2 * num_limbs, FieldElement::zero())); + initial_values.push((scalar_base, u256_to_fe(&s1_lo))); + initial_values.push((scalar_base + 1, u256_to_fe(&s1_hi))); + initial_values.push((scalar_base + 2, u256_to_fe(&s2_lo))); + initial_values.push((scalar_base + 3, u256_to_fe(&s2_hi))); + for (j, fe) in ex_fes.iter().enumerate() { + initial_values.push((out_x_limbs[j], *fe)); + } + for (j, fe) in ey_fes.iter().enumerate() { + initial_values.push((out_y_limbs[j], *fe)); + } + initial_values.push((out_inf, FieldElement::zero())); + + let witness = solve_witnesses(&compiler.witness_builders, num_witnesses, &initial_values); + + check_r1cs_satisfaction(&compiler.r1cs, &witness) + .expect("R1CS satisfaction check failed for two-point MSM"); +} + +/// Two-point MSM where one scalar is zero — only the non-zero point +/// should contribute. 
+#[test]
+fn test_two_point_one_zero_scalar() {
+    let curve = secp256r1_params();
+    let gx = curve.generator.0;
+    let gy = curve.generator.1;
+    let a = &curve.curve_a;
+    let p = &curve.field_modulus_p;
+    let (num_limbs, limb_bits) = msm_params_for_curve(&curve, 2);
+    let stride = 2 * num_limbs + 1;
+
+    // P1 = G (scalar=5), P2 = 2G (scalar=0)
+    let (p2x, p2y) = ec_scalar_mul(&gx, &gy, &[2, 0, 0, 0], a, p);
+    let s1: [u64; 4] = [5, 0, 0, 0];
+    let s2: [u64; 4] = [0, 0, 0, 0];
+    // Expected: 5·G + 0·(2G) = 5G
+    let (ex, ey) = ec_scalar_mul(&gx, &gy, &[5, 0, 0, 0], a, p);
+
+    let (s1_lo, s1_hi) = split_scalar(&s1);
+    let (s2_lo, s2_hi) = split_scalar(&s2);
+
+    let p1x_fes = u256_to_limb_fes(&gx, limb_bits, num_limbs);
+    let p1y_fes = u256_to_limb_fes(&gy, limb_bits, num_limbs);
+    let p2x_fes = u256_to_limb_fes(&p2x, limb_bits, num_limbs);
+    let p2y_fes = u256_to_limb_fes(&p2y, limb_bits, num_limbs);
+    let ex_fes = u256_to_limb_fes(&ex, limb_bits, num_limbs);
+    let ey_fes = u256_to_limb_fes(&ey, limb_bits, num_limbs);
+
+    let mut compiler = NoirToR1CSCompiler::new();
+    let mut range_checks: BTreeMap<u32, Vec<usize>> = BTreeMap::new();
+
+    let base = compiler.num_witnesses();
+    let total = 2 * stride + 4 + stride;
+    compiler.r1cs.add_witnesses(total);
+
+    let points: Vec<ConstantOrR1CSWitness> = (0..2 * stride)
+        .map(|j| ConstantOrR1CSWitness::Witness(base + j))
+        .collect();
+    let scalar_base = base + 2 * stride;
+    let scalars = vec![
+        ConstantOrR1CSWitness::Witness(scalar_base),
+        ConstantOrR1CSWitness::Witness(scalar_base + 1),
+        ConstantOrR1CSWitness::Witness(scalar_base + 2),
+        ConstantOrR1CSWitness::Witness(scalar_base + 3),
+    ];
+    let out_base = scalar_base + 4;
+    let out_x_limbs: Vec<usize> = (0..num_limbs).map(|j| out_base + j).collect();
+    let out_y_limbs: Vec<usize> = (0..num_limbs).map(|j| out_base + num_limbs + j).collect();
+    let out_inf = out_base + 2 * num_limbs;
+
+    let outputs = MsmLimbedOutputs {
+        out_x_limbs: out_x_limbs.clone(),
+        out_y_limbs: out_y_limbs.clone(),
+        out_inf,
+    };
+    let msm_ops = 
vec![(points, scalars, outputs)]; + add_msm_with_curve_limbed(&mut compiler, msm_ops, &mut range_checks, &curve, num_limbs); + add_range_checks(&mut compiler, range_checks); + + let num_witnesses = compiler.num_witnesses(); + + let mut initial_values = vec![(0, FieldElement::from(1u64))]; + // P1 limbs (generator) + for (j, fe) in p1x_fes.iter().enumerate() { + initial_values.push((base + j, *fe)); + } + for (j, fe) in p1y_fes.iter().enumerate() { + initial_values.push((base + num_limbs + j, *fe)); + } + initial_values.push((base + 2 * num_limbs, FieldElement::zero())); + // P2 limbs + let p2_base = base + stride; + for (j, fe) in p2x_fes.iter().enumerate() { + initial_values.push((p2_base + j, *fe)); + } + for (j, fe) in p2y_fes.iter().enumerate() { + initial_values.push((p2_base + num_limbs + j, *fe)); + } + initial_values.push((p2_base + 2 * num_limbs, FieldElement::zero())); + // Scalars + initial_values.push((scalar_base, u256_to_fe(&s1_lo))); + initial_values.push((scalar_base + 1, u256_to_fe(&s1_hi))); + initial_values.push((scalar_base + 2, u256_to_fe(&s2_lo))); + initial_values.push((scalar_base + 3, u256_to_fe(&s2_hi))); + // Expected output limbs + for (j, fe) in ex_fes.iter().enumerate() { + initial_values.push((out_x_limbs[j], *fe)); + } + for (j, fe) in ey_fes.iter().enumerate() { + initial_values.push((out_y_limbs[j], *fe)); + } + initial_values.push((out_inf, FieldElement::zero())); + + let witness = solve_witnesses(&compiler.witness_builders, num_witnesses, &initial_values); + + check_r1cs_satisfaction(&compiler.r1cs, &witness) + .expect("R1CS satisfaction check failed for two-point MSM with one zero scalar"); +}