diff --git a/benches/const_monty.rs b/benches/const_monty.rs index 7405896bd..700a79ad7 100644 --- a/benches/const_monty.rs +++ b/benches/const_monty.rs @@ -202,7 +202,8 @@ fn bench_montgomery_sqrt(group: &mut BenchmarkGroup<'_, M>) { const_prime_monty_params!( P256Field, U256, - "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff" + "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff", + 6 ); assert_eq!(P256Field::PRIME_PARAMS.s().get(), 1); type ConstForm = crypto_bigint::modular::ConstMontyForm; @@ -227,7 +228,8 @@ fn bench_montgomery_sqrt(group: &mut BenchmarkGroup<'_, M>) { const_prime_monty_params!( P256Scalar, U256, - "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551" + "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551", + 7 ); assert_eq!(P256Scalar::PRIME_PARAMS.s().get(), 4); type ConstForm = crypto_bigint::modular::ConstMontyForm; @@ -254,7 +256,8 @@ fn bench_montgomery_sqrt(group: &mut BenchmarkGroup<'_, M>) { const_prime_monty_params!( K256Scalar, U256, - "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141" + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", + 7 ); assert_eq!(K256Scalar::PRIME_PARAMS.s().get(), 6); type ConstForm = crypto_bigint::modular::ConstMontyForm; diff --git a/src/modular/const_monty_form/macros.rs b/src/modular/const_monty_form/macros.rs index d62b9e4a6..47036979c 100644 --- a/src/modular/const_monty_form/macros.rs +++ b/src/modular/const_monty_form/macros.rs @@ -46,8 +46,8 @@ macro_rules! const_monty_params { } /// Create a type representing a prime modulus which impls the [`ConstPrimeMontyParams`] -/// trait with the given name, type, value (in big endian hex), and optional documentation -/// string. +/// trait with the given name, type, value (in big endian hex), multiplicative generator, +/// and optional documentation string. /// /// # Usage /// @@ -58,6 +58,7 @@ macro_rules! const_monty_params { /// MyModulus, /// U256, /// "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001", +/// 7, /// "Docs for my modulus" /// ); /// ``` @@ -65,28 +66,24 @@ macro_rules! const_monty_params { /// The modulus _must_ be odd and prime, or this will panic. #[macro_export] macro_rules! const_prime_monty_params { - ($name:ident, $uint_type:ty, $value:expr) => { + ($name:ident, $uint_type:ty, $value:expr, $generator:literal) => { $crate::const_prime_monty_params!( $name, $uint_type, $value, + $generator, "Modulus which impls `ConstPrimeMontyParams`" ); }; - ($name:ident, $uint_type:ty, $value:expr, $doc:expr) => { - $crate::const_monty_params!( - $name, - $uint_type, - $value, - "Modulus which impls `ConstPrimeMontyParams`" - ); + ($name:ident, $uint_type:ty, $value:expr, $generator:literal, $doc:expr) => { + $crate::const_monty_params!($name, $uint_type, $value, $doc); impl $crate::modular::ConstPrimeMontyParams<{ <$uint_type>::LIMBS }> for $name { const PRIME_PARAMS: $crate::modular::PrimeParams<{ <$uint_type>::LIMBS }> = $crate::modular::PrimeParams::new_vartime( &<$name as $crate::modular::ConstMontyParams<{ <$uint_type>::LIMBS }>>::PARAMS, - ) - .expect("cannot derive prime parameters"); + $generator, + ); } }; } diff --git a/src/modular/const_monty_form/sqrt.rs b/src/modular/const_monty_form/sqrt.rs index 49f2116cc..2368961da 100644 --- a/src/modular/const_monty_form/sqrt.rs +++ b/src/modular/const_monty_form/sqrt.rs @@ -39,7 +39,8 @@ mod tests { const_prime_monty_params!( P256Field, U256, - "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff" + "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff", + 6 ); assert_eq!(P256Field::PRIME_PARAMS.s().get(), 1); type ConstForm = ConstMontyForm; diff --git a/src/modular/prime_params.rs b/src/modular/prime_params.rs index e33b39f07..83f8ee0f9 100644 --- a/src/modular/prime_params.rs +++ b/src/modular/prime_params.rs @@ -4,7 +4,7 @@ use core::num::NonZeroU32; use ctutils::{CtAssignSlice, CtEqSlice, CtSelectUsingCtAssign}; use super::{FixedMontyForm, FixedMontyParams}; -use crate::{Choice, CtAssign, CtEq, Odd, Uint}; +use crate::{Choice, CtAssign, CtEq, OddUint, Uint}; #[cfg(feature = "subtle")] use crate::CtSelect; @@ -13,8 +13,10 @@ use crate::CtSelect; /// with a prime modulus. #[derive(Debug, Copy, Clone)] pub struct PrimeParams { - /// A constant such that the modulus `p = t•2^s+1` for `s > 0` and some odd `t`. + /// The largest power of two that divides `(modulus - 1)`. pub(super) s: NonZeroU32, + /// The result of dividing `modulus - 1` by `2^s`. + pub(super) t: OddUint, /// The smallest primitive root of the modulus. pub(super) generator: NonZeroU32, /// The exponent to use in computing a modular square root. @@ -28,18 +30,21 @@ pub struct PrimeParams { impl PrimeParams { /// Instantiates a new set of [`PrimeParams`] given [`FixedMontyParams`] for a prime modulus. /// - /// This method will return `None` if the modulus is determined to be non-prime, however - /// this is not an exhaustive check and non-prime values can be accepted. + /// The value `generator` must be a multiplicative generator (ie. primitive element) of the + /// finite field, having order `modulus-1`. Its powers generate all nonzero elements of the + /// field: `generator^0, generator^1, ..., generator^(modulus-2)` enumerate `[1, modulus-1]`. #[must_use] #[allow(clippy::unwrap_in_result, clippy::missing_panics_doc)] - pub const fn new_vartime(params: &FixedMontyParams) -> Option { + pub const fn new_vartime(params: &FixedMontyParams, generator: u32) -> Self { let p = params.modulus(); let p_minus_one = p.as_ref().set_bit_vartime(0, false); + let generator = NonZeroU32::new(generator).expect("invalid generator"); let s = NonZeroU32::new(p_minus_one.trailing_zeros_vartime()).expect("ensured non-zero"); - - let Some(generator) = find_primitive_root(p) else { - return None; - }; + let t = p + .as_ref() + .shr(s.get()) + .to_odd() + .expect_copied("ensured odd"); // if s=1 and p is a power of a prime then -1 is always a root of unity let (exp, root) = if s.get() == 1 { @@ -57,19 +62,21 @@ impl PrimeParams { FixedMontyForm::new(&Uint::from_u32(generator.get()), params).pow_vartime(&t); // root^(2^(s-1)) must be equal to -1 let check = root.square_repeat_vartime(s.get() - 1); - if !Uint::eq(&check.retrieve(), &p_minus_one).to_bool_vartime() { - return None; - } + assert!( + Uint::eq(&check.retrieve(), &p_minus_one).to_bool_vartime(), + "error calculating root of unity: invalid generator" + ); (exp, root) }; - Some(Self { + Self { s, + t, generator, sqrt_exp: exp, monty_root_unity: root.to_montgomery(), monty_root_unity_p2: root.square().to_montgomery(), - }) + } } /// Get the constant 'generator' used in modular square root calculation. @@ -83,6 +90,12 @@ impl PrimeParams { pub const fn s(&self) -> NonZeroU32 { self.s } + + /// Get the constant 't' used in modular square root calculation. + #[must_use] + pub const fn t(&self) -> OddUint { + self.t + } } impl CtAssign for PrimeParams { @@ -128,41 +141,6 @@ impl subtle::ConditionallySelectable for PrimeParams } } -#[allow(clippy::unwrap_in_result)] -const fn find_primitive_root(p: &Odd>) -> Option { - // A primitive root exists iff p is 1, 2, 4, q^k or 2q^k, k > 0, q is an odd prime. - // Find a quadratic non-residue (primitive roots are non-residue for powers of a prime) - let mut g = NonZeroU32::new(2u32).expect("ensured non-zero"); - let (mut skip_root, mut skip_square) = (2u32, 4u32); - loop { - // Either the modulus is prime and g is quadratic non-residue, or - // the modulus is composite. - let g_uint = Uint::<1>::from_u32(g.get()); - match g_uint.jacobi_symbol_vartime(p) as i8 { - -1 => { - break Some(g); - } - 0 => { - // Modulus is composite - return None; - } - _ => loop { - let Some(g2) = g.checked_add(1) else { - return None; - }; - g = g2; - if g.get() == skip_square { - // Skip obviously square values (4, 9, 16..) - skip_root += 1; - skip_square = skip_root.saturating_pow(2); - } else { - break; - } - }, - } - } -} - #[cfg(test)] mod tests { use super::PrimeParams; @@ -172,29 +150,36 @@ mod tests { fn check_expected() { let monty_params = MontyParams::new_vartime(Odd::::from_be_hex("e38af050d74b8567f73c8713cbc7bc47")); - let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params"); + let prime_params = PrimeParams::new_vartime(&monty_params, 5); assert_eq!(prime_params.s.get(), 1); assert_eq!(prime_params.generator.get(), 5); } + #[should_panic] + #[test] + fn check_invalid_generator() { + let monty_params = + MontyParams::new_vartime(Odd::::from_be_hex("e38af050d74b8567f73c8713cbc7bc47")); + let _ = PrimeParams::new_vartime(&monty_params, 0); + } + + #[should_panic] #[test] fn check_non_prime() { let monty_params = MontyParams::new_vartime(Odd::::from_be_hex("e38af050d74b8567f73c8713cbc7bc01")); - assert!(PrimeParams::new_vartime(&monty_params).is_none()); + let _ = PrimeParams::new_vartime(&monty_params, 5); } #[test] fn check_equality() { let monty_params_1 = MontyParams::new_vartime(Odd::::from_be_hex("e38af050d74b8567f73c8713cbc7bc47")); - let prime_params_1 = - PrimeParams::new_vartime(&monty_params_1).expect("failed creating params"); + let prime_params_1 = PrimeParams::new_vartime(&monty_params_1, 5); let monty_params_2 = MontyParams::new_vartime(Odd::::from_be_hex("f2799d643ab7ff983437c3a86cdb1beb")); - let prime_params_2 = - PrimeParams::new_vartime(&monty_params_2).expect("failed creating params"); + let prime_params_2 = PrimeParams::new_vartime(&monty_params_2, 5); assert!(CtEq::ct_eq(&prime_params_1, &prime_params_1).to_bool_vartime()); #[cfg(feature = "subtle")] diff --git a/src/modular/sqrt.rs b/src/modular/sqrt.rs index 3e40eaa35..4977c4d78 100644 --- a/src/modular/sqrt.rs +++ b/src/modular/sqrt.rs @@ -71,21 +71,21 @@ pub const fn sqrt_montgomery_form( let neg_zeta_b = monty_eq(&neg_zeta, &ru_2); let zeta_d = monty_eq(&zeta, &ru_6); - // m = B if -zeta in (B, C), else 1 + // m = B if -zeta in {B, C}, else 1 let mut m = monty_select( &FixedMontyForm::one(ru.params()), &ru_2, neg_zeta_b.or(monty_eq(&neg_zeta, &ru_4)), ); - // m = C if zeta in (-1, D) + // m = C if zeta in {-1, D} m = monty_select( &m, &ru_4, Uint::eq(neg_zeta.as_montgomery(), monty_params.one()).or(zeta_d), ); - // m = D if zeta in (B, C) + // m = D if zeta in {B, C} m = monty_select(&m, &ru_6, zeta_b.or(monty_eq(&zeta, &ru_4))); - // m = m•ru if zeta or -zeta in (B, D) + // m = m•ru if zeta or -zeta in {B, D} m = monty_select( &m, &m.mul(&ru), @@ -208,9 +208,9 @@ mod tests { let monty_params = FixedMontyParams::new_vartime(Odd::::from_be_hex( "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff", )); - let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params"); + let prime_params = PrimeParams::new_vartime(&monty_params, 6); assert_eq!(prime_params.s.get(), 1); - assert_eq!(prime_params.generator.get(), 3); + assert_eq!(prime_params.generator.get(), 6); assert_eq!( root_of_unity(&monty_params, &prime_params), U256::from_be_hex("FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFE") @@ -226,7 +226,7 @@ mod tests { let monty_params = FixedMontyParams::new_vartime(Odd::::from_be_hex( "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed", )); - let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params"); + let prime_params = PrimeParams::new_vartime(&monty_params, 2); assert_eq!(prime_params.s.get(), 2); assert_eq!(prime_params.generator.get(), 2); assert_eq!( @@ -244,7 +244,7 @@ mod tests { let monty_params = FixedMontyParams::new_vartime(Odd::::from_be_hex( "00000000000001fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409", )); - let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params"); + let prime_params = PrimeParams::new_vartime(&monty_params, 3); assert_eq!(prime_params.s.get(), 3); assert_eq!(prime_params.generator.get(), 3); assert_eq!( @@ -264,7 +264,7 @@ mod tests { let monty_params = FixedMontyParams::new_vartime(Odd::::from_be_hex( "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551", )); - let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params"); + let prime_params = PrimeParams::new_vartime(&monty_params, 7); assert_eq!(prime_params.s.get(), 4); assert_eq!(prime_params.generator.get(), 7); assert_eq!( @@ -282,12 +282,12 @@ mod tests { let monty_params = FixedMontyParams::new_vartime(Odd::::from_be_hex( "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", )); - let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params"); + let prime_params = PrimeParams::new_vartime(&monty_params, 7); assert_eq!(prime_params.s.get(), 6); - assert_eq!(prime_params.generator.get(), 5); + assert_eq!(prime_params.generator.get(), 7); assert_eq!( root_of_unity(&monty_params, &prime_params), - U256::from_be_hex("0D1F8EAB98DCD1ACA7DC810E065710CBB96E9ABEBBE451FA15B4F83D2D2AD232") + U256::from_be_hex("0C1DC060E7A91986DF9879A3FBC483A898BDEAB680756045992F4B5402B052F2") ); test_monty_sqrt(monty_params, prime_params);