diff --git a/skyscraper/bn254-multiplier/gen_multiples.py b/skyscraper/bn254-multiplier/gen_multiples.py
new file mode 100644
index 000000000..67c86127e
--- /dev/null
+++ b/skyscraper/bn254-multiplier/gen_multiples.py
@@ -0,0 +1,46 @@
+"""Generate Rust const lookup tables for multiples of the BN254 scalar field prime."""
+
+p = 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001
+
+
+def int_to_limbs(size, n, count):
+    """Split non-negative `n` into `count` little-endian limbs of `size` bits.
+
+    Every limb except the most significant one holds exactly `size` bits. The
+    most significant limb keeps all remaining bits instead of being masked:
+    3*p..5*p need 256 bits, so masking the fifth 51-bit limb (5 * 51 = 255)
+    would silently drop bit 255 and the table entry would no longer equal k*p.
+    """
+    if n < 0:
+        raise ValueError("n must be non-negative")
+    mask = 2**size - 1
+    limbs = []
+    for i in range(count):
+        is_top = i == count - 1
+        limbs.append(n if is_top else n & mask)
+        n >>= size
+    return limbs
+
+
+def main():
+    """Print the 64-bit and 51-bit limb tables of k*p, k in 0..=5, as Rust consts."""
+    multiples_64 = [int_to_limbs(64, k * p, 4) for k in range(0, 6)]
+    multiples_51 = [int_to_limbs(51, k * p, 5) for k in range(0, 6)]
+
+    # Print 64-bit table (for constants.rs)
+    print("pub const U64_P_MULTIPLES: [[u64; 4]; 6] = [")
+    for k, limbs in enumerate(multiples_64):
+        fmt = ", ".join(f"0x{limb:016x}" for limb in limbs)
+        print(f" [{fmt}], // {k}P")
+    print("];")
+
+    # Print 51-bit table (for rne/constants.rs)
+    print("\npub const U51_P_MULTIPLES: [[u64; 5]; 6] = [")
+    for k, limbs in enumerate(multiples_51):
+        fmt = ", ".join(f"0x{limb:013x}" for limb in limbs)
+        print(f" [{fmt}], // {k}P")
+    print("];")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/skyscraper/bn254-multiplier/src/constants.rs b/skyscraper/bn254-multiplier/src/constants.rs
index b49971136..b831593f3 100644
--- a/skyscraper/bn254-multiplier/src/constants.rs
+++ b/skyscraper/bn254-multiplier/src/constants.rs
@@ -1,17 +1,47 @@
-pub const U64_NP0: u64 = 0xc2e1f593efffffff;
+pub const U64_P: [u64; 4] = U64_P_MULTIPLES[1];
 
-pub const U64_P: [u64; 4] = [
-    0x43e1f593f0000001,
-    0x2833e84879b97091,
-    0xb85045b68181585d,
-    0x30644e72e131a029,
-];
+pub const U64_2P: [u64; 4] = U64_P_MULTIPLES[2];
 
-pub const U64_2P: [u64; 4] = [
-    0x87c3eb27e0000002,
-    0x5067d090f372e122,
-    0x70a08b6d0302b0ba,
-    0x60c89ce5c2634053,
+/// Lookup table: `U64_P_MULTIPLES[k]` = `k * P` for k in 0..=5.
+/// Index 0 is all-zeros; use as `x - U64_P_MULTIPLES[k]` to subtract k copies
+/// of P.
+pub const U64_P_MULTIPLES: [[u64; 4]; 6] = [
+    [
+        0x0000000000000000,
+        0x0000000000000000,
+        0x0000000000000000,
+        0x0000000000000000,
+    ], // 0P
+    [
+        0x43e1f593f0000001,
+        0x2833e84879b97091,
+        0xb85045b68181585d,
+        0x30644e72e131a029,
+    ], // 1P
+    [
+        0x87c3eb27e0000002,
+        0x5067d090f372e122,
+        0x70a08b6d0302b0ba,
+        0x60c89ce5c2634053,
+    ], // 2P
+    [
+        0xcba5e0bbd0000003,
+        0x789bb8d96d2c51b3,
+        0x28f0d12384840917,
+        0x912ceb58a394e07d,
+    ], // 3P
+    [
+        0x0f87d64fc0000004,
+        0xa0cfa121e6e5c245,
+        0xe14116da06056174,
+        0xc19139cb84c680a6,
+    ], // 4P
+    [
+        0x5369cbe3b0000005,
+        0xc903896a609f32d6,
+        0x99915c908786b9d1,
+        0xf1f5883e65f820d0,
+    ], // 5P
 ];
 
 // R mod P
diff --git a/skyscraper/bn254-multiplier/src/lib.rs b/skyscraper/bn254-multiplier/src/lib.rs
index 454d01945..3b4d99bfa 100644
--- a/skyscraper/bn254-multiplier/src/lib.rs
+++ b/skyscraper/bn254-multiplier/src/lib.rs
@@ -13,7 +13,7 @@ pub mod rtz;
 pub mod constants;
 pub mod rne;
 mod scalar;
-mod utils;
+pub mod utils;
 
 #[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI
 mod test_utils;
diff --git a/skyscraper/bn254-multiplier/src/rne/constants.rs b/skyscraper/bn254-multiplier/src/rne/constants.rs
index 77862defc..462e8938e 100644
--- a/skyscraper/bn254-multiplier/src/rne/constants.rs
+++ b/skyscraper/bn254-multiplier/src/rne/constants.rs
@@ -6,12 +6,53 @@ use crate::pow_2;
 pub const U51_NP0: u64 = 0x1f593efffffff;
 
 /// BN254 scalar field prime
-pub const U51_P: [u64; 5] = [
-    0x1f593f0000001,
-    0x10f372e12287c,
-    0x6056174a0cfa1,
-    0x014dc2822db40,
-    0x30644e72e131a,
+pub const U51_P: [u64; 5] = U51_P_MULTIPLES[1];
+
+/// Lookup table: `U51_P_MULTIPLES[k]` = `k * P` for k in 0..=5, in 51-bit
+/// limbs. For k >= 3 the top limb also carries bit 255, so it spans 52 bits.
+pub const U51_P_MULTIPLES: [[u64; 5]; 6] = [
+    [
+        0x0000000000000,
+        0x0000000000000,
+        0x0000000000000,
+        0x0000000000000,
+        0x0000000000000,
+    ], // 0P
+    [
+        0x1f593f0000001,
+        0x10f372e12287c,
+        0x6056174a0cfa1,
+        0x014dc2822db40,
+        0x30644e72e131a,
+    ], // 1P
+    [
+        0x3eb27e0000002,
+        0x21e6e5c2450f8,
+        0x40ac2e9419f42,
+        0x029b85045b681,
+        0x60c89ce5c2634,
+    ], // 2P
+    [
+        0x5e0bbd0000003,
+        0x32da58a367974,
+        0x210245de26ee3,
+        0x03e94786891c2,
+        0x912ceb58a394e,
+    ], // 3P
+    [
+        0x7d64fc0000004,
+        0x43cdcb848a1f0,
+        0x01585d2833e84,
+        0x05370a08b6d03,
+        0xc19139cb84c68,
+    ], // 4P
+    [
+        0x1cbe3b0000005,
+        0x54c13e65aca6d,
+        0x61ae747240e25,
+        0x0684cc8ae4843,
+        0xf1f5883e65f82,
+    ], // 5P
 ];
 
 /// Bit mask for 51-bit limbs.
diff --git a/skyscraper/bn254-multiplier/src/utils.rs b/skyscraper/bn254-multiplier/src/utils.rs
index ee3ac57b7..c8cc68ef1 100644
--- a/skyscraper/bn254-multiplier/src/utils.rs
+++ b/skyscraper/bn254-multiplier/src/utils.rs
@@ -1,4 +1,4 @@
-use crate::constants::U64_2P;
+use crate::constants::{self, U64_2P};
 
 /// Macro to extract a subarray from an array.
 ///
@@ -97,3 +97,136 @@ pub const fn carrying_mul_add(a: u64, b: u64, add: u64, carry: u64) -> (u64, u64
     let c: u128 = widening_mul(a, b) + carry as u128 + add as u128;
     (c as u64, (c >> 64) as u64)
 }
+
+/// Precomputed magic constants for fast approximate floor division by the
+/// BN254 prime P, following Warren's "Hacker's Delight" (integer division by
+/// constants). Guarantees `div_p(val) ≤ ⌊val / P⌋` for all inputs within the
+/// declared bit range.
+pub struct MulShift {
+    mul: u64,
+    shift: u32,
+}
+
+impl MulShift {
+    /// Computes magic multiplication and shift constants for approximate floor
+    /// division by P, such that `div_p(val) ≤ ⌊val / P⌋` for all `val` within
+    /// `max_bit_size` bits. `precision` controls how many of the top bits of
+    /// `val` are used — higher precision yields a tighter approximation.
+    pub const fn new(max_bit_size: u32, precision: u32) -> Self {
+        // Generate magic numbers for division by bn254's prime for a range of top-bit
+        // widths.
+
+        // Based on Warren's "Hacker's Delight" (integer
+        // division by constants) to find a magic multiplier for each bit-width.
+
+        use crate::constants::U64_P;
+
+        let d = (U64_P[3] >> (max_bit_size - 192 - precision)) + 1; // d = divisor = ceil(p / 2^(max_bit_size-precision))
+        let nc = 2_u64.pow(precision) - 1 - (2_u64.pow(precision) % d); // nc = largest value s.t. nc mod d == d-1
+        let mut s = precision; // start at precision; s < precision values are skipped
+        let m;
+        loop {
+            // s = shift exponent
+            if 2_u64.pow(s) > nc * (d - 1 - (2_u64.pow(s) - 1) % d) {
+                m = (2_u64.pow(s) + d - 1 - (2_u64.pow(s) - 1) % d) / d; // m = magic multiplier
+                break;
+            }
+            s += 1;
+            assert!(s < 64, "no magic multiplier found");
+        }
+
+        MulShift { mul: m, shift: s }
+    }
+
+    #[inline(always)]
+    /// Returns an under-approximation of ⌊val / P⌋ (result ≤ true quotient).
+    /// `val` must fit within the `max_bit_size` passed to [`MulShift::new`].
+    pub const fn div_p(&self, val: u64) -> u64 {
+        // assumes systems can handle multiplication by 64 bits without performance
+        // penalty.
+        (val * self.mul) >> self.shift
+    }
+}
+
+/// Approximate floor division by P using the upper 6 bits of `x`.
+/// `x` must be the upper limb of a u256 in 64-bit radix.
+///
+/// Returns a value ≤ ⌊x / P⌋. This is the most precise
+/// approximation achievable without multiplication on ARM64 and x86.
+///
+/// Tradeoff: the limited range of this division, \[0,4\] (instead of \[0,5\]
+/// for u256), leads to a larger value after subtraction reduction.
+/// subtraction reduction output: [0, 1+ε] with ε < 0.3.
+#[inline(always)]
+pub fn div_p_6b(x: u64) -> u64 {
+    let upper_bits = x >> (64 - 6);
+    // const to force compile time evaluation
+    const MULSHIFT: MulShift = MulShift::new(256, 6);
+    MULSHIFT.div_p(upper_bits)
+}
+
+/// Approximate floor division by P using the upper 32 bits of `x`.
+/// `x` must be the upper limb of a u256 in 64-bit radix.
+///
+/// Returns a value ≤ ⌊x / P⌋. This is the most precise
+/// approximation achievable with a 32bx32b->64b multiplier.
+#[inline(always)]
+pub fn div_p_32b(x: u64) -> u64 {
+    let upper_bits = x >> (64 - 32);
+    // const to force compile time evaluation
+    const MULSHIFT: MulShift = MulShift::new(256, 32);
+    MULSHIFT.div_p(upper_bits)
+}
+
+/// Subtracts an approximate multiple of P from `x` using `div_p` on the high
+/// limb.
+///
+/// The result is not fully reduced; the output range depends on the precision
+/// of the supplied `div_p` — see [`div_p_6b`] and [`div_p_32b`].
+#[inline(always)]
+pub fn subtraction_reduce<F: Fn(u64) -> u64>(div_p: F, x: [u64; 4]) -> [u64; 4] {
+    // No clamping as the max value of x can't go past 5. Which is the maximum of
+    // the table.
+    let q = div_p(x[3]) as usize;
+    sub(x, constants::U64_P_MULTIPLES[q])
+}
+
+#[cfg(kani)]
+mod proofs {
+    use {
+        super::{constants::U64_2P, div_p_32b, div_p_6b},
+        crate::constants::U64_P_MULTIPLES,
+    };
+
+    /// Lexicographic ≤ on little-endian 256-bit integers.
+    fn le256(a: [u64; 4], b: [u64; 4]) -> bool {
+        for i in (0..4).rev() {
+            if a[i] != b[i] {
+                return a[i] < b[i];
+            }
+        }
+        true
+    }
+
+    /// TODO: tighter bounds
+    #[kani::proof]
+    fn div_p_32b_underapprox() {
+        let x: u64 = kani::any();
+        let q = div_p_32b(x);
+
+        let r = U64_P_MULTIPLES[q as usize][3];
+        assert!(x >= r);
+        assert!(le256([0, 0, 0, x - r], U64_2P));
+    }
+
+    #[kani::proof]
+    // TODO tighter bounds
+    fn div_p_6b_underapprox() {
+        let x: u64 = kani::any();
+        let q = div_p_6b(x);
+
+        let r = U64_P_MULTIPLES[q as usize][3];
+        assert!(x >= r);
+        assert!(le256([0, 0, 0, x - r], U64_2P));
+    }
+}
diff --git a/skyscraper/bn254-multiplier/sub_reduce.py b/skyscraper/bn254-multiplier/sub_reduce.py
new file mode 100644
index 000000000..fddc95ffc
--- /dev/null
+++ b/skyscraper/bn254-multiplier/sub_reduce.py
@@ -0,0 +1,126 @@
+"""Sub-reduction strategies for bn254 modular arithmetic."""
+
+import argparse
+
+p = 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001
+
+
+def shift_sum():
+    """
+    Generate truth table for top 3 bit reduction strategy.
+
+    The formula ((i >> 2) & i) + (i >> 1) computes this
+    approximation using only shifts and adds — no multiplication needed.
+    """
+    # Keep track of potential erroneous conditions
+    neq = 0
+    neg_max = 0
+    neg_min = 0
+
+    power = 3
+
+    max_pval = 0
+    truth_table = list()
+
+    for i in range(0, 2**power):
+        val_min = i << (256 - power)
+        subp = val_min // p
+        shift_subp = ((i >> 2) & i) + (i >> 1)
+
+        val_max = val_min + (1 << (256 - power)) - 1
+        rem_max = val_max - (shift_subp) * p
+
+        rem_min = val_min - (shift_subp) * p
+
+        # Validation: track the maximum remainder relative to p
+        max_pval = max(max_pval, rem_max)
+
+        truth_table.append((bin(i)[2:].zfill(power), shift_subp))
+
+        # Check for erroneous situations
+        if subp != shift_subp:
+            print(hex(val_min), subp, shift_subp)
+            neq += 1
+
+        if rem_max < 0:
+            neg_max += 1
+
+        if rem_min < 0:
+            neg_min += 1
+
+    print(f"{'bits':>5} {'subtractions':>12}")
+    for e, r in truth_table:
+        print(f"{e:>5} {r:>12}")
+
+    print(
+        f"\nmax_remainder/p={max_pval / p:.4f} mismatches={neq} neg_max={neg_max} neg_min={neg_min}"
+    )
+
+
+def warren_magic(max_bit_size, apply_sub: bool = False):
+    """
+    Generate magic numbers for division by bn254's prime for a range of top-bit widths.
+
+    Based on Warren's "Hacker's Delight" (integer
+    division by constants) to find a magic multiplier for each bit-width.
+
+    Returns a list of tuples (w, m_bits, sub, shift, m) where:
+    - w: number of top bits of the value
+    - m_bits: number of bits in the magic multiplier
+    - sub: whether the "subtract and shift" variant is used (m exceeded 2^w,
+      so we store m - 2^w and compensate at runtime)
+    - shift: the number of bits to right-shift the product (called 's' in Warren)
+    - m: the magic multiplier
+    """
+    res = list()
+    for w in range(0, 65):
+        d = (p >> (max_bit_size - w)) + 1  # d = divisor = ceil(p / 2^(max_bit_size-w))
+        nc = 2**w - 1 - (2**w % d)  # nc = largest value s.t. nc mod d == d-1
+        for s in range(0, 128):  # s = shift exponent
+            if 2**s > nc * (d - 1 - (2**s - 1) % d):
+                if s < w:
+                    print(w, s, d.bit_length())
+                m = (2**s + d - 1 - (2**s - 1) % d) // d  # m = magic multiplier
+                sub = False
+
+                # "Subtract and shift" variant: when m >= 2^w it won't fit in w bits,
+                # so subtract 2^w and compensate at runtime.
+                # Skip this part via apply_sub if the register can hold more bits than w.
+                if apply_sub and m >= 2**w:
+                    sub = True
+                    m = m - 2**w
+                res.append(
+                    (
+                        w,
+                        m.bit_length(),
+                        sub,
+                        s,
+                        m,
+                    )
+                )
+                break
+    return res
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Sub-reduction strategies for bn254.")
+    parser.add_argument(
+        "--sub",
+        action="store_true",
+        help="Apply the 'subtract and shift' variant when m >= 2^w.",
+    )
+    args = parser.parse_args()
+
+    print("shift sum")
+    shift_sum()
+
+    print("\n warren")
+    print(f"{'w':>3} {'m_bits':>6} {'sub':>5} {'shift':>5} {'m':>20}")
+    for w, m_bits, sub, shift, m in warren_magic(257, apply_sub=args.sub):
+        print(f"{w:>3} {m_bits:>6} {str(sub):>5} {shift:>5} {hex(m):>20}")
+
+    for w, m_bits, sub, shift, m in warren_magic(256, apply_sub=args.sub):
+        print(f"{w:>3} {m_bits:>6} {str(sub):>5} {shift:>5} {hex(m):>20}")
+
+    for w, m_bits, sub, shift, m in warren_magic(255, apply_sub=args.sub):
+        print(f"{w:>3} {m_bits:>6} {str(sub):>5} {shift:>5} {hex(m):>20}")