Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions benches/const_monty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ fn bench_montgomery_sqrt<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
const_prime_monty_params!(
P256Field,
U256,
"ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
"ffffffff00000001000000000000000000000000ffffffffffffffffffffffff",
6
);
assert_eq!(P256Field::PRIME_PARAMS.s().get(), 1);
type ConstForm = crypto_bigint::modular::ConstMontyForm<P256Field, { U256::LIMBS }>;
Expand All @@ -227,7 +228,8 @@ fn bench_montgomery_sqrt<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
const_prime_monty_params!(
P256Scalar,
U256,
"ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
"ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551",
7
);
assert_eq!(P256Scalar::PRIME_PARAMS.s().get(), 4);
type ConstForm = crypto_bigint::modular::ConstMontyForm<P256Scalar, { U256::LIMBS }>;
Expand All @@ -254,7 +256,8 @@ fn bench_montgomery_sqrt<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
const_prime_monty_params!(
K256Scalar,
U256,
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",
7
);
assert_eq!(K256Scalar::PRIME_PARAMS.s().get(), 6);
type ConstForm = crypto_bigint::modular::ConstMontyForm<K256Scalar, { U256::LIMBS }>;
Expand Down
21 changes: 9 additions & 12 deletions src/modular/const_monty_form/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ macro_rules! const_monty_params {
}

/// Create a type representing a prime modulus which impls the [`ConstPrimeMontyParams`]
/// trait with the given name, type, value (in big endian hex), and optional documentation
/// string.
/// trait with the given name, type, value (in big endian hex), multiplicative generator,
/// and optional documentation string.
///
/// # Usage
///
Expand All @@ -58,35 +58,32 @@ macro_rules! const_monty_params {
/// MyModulus,
/// U256,
/// "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001",
/// 7,
/// "Docs for my modulus"
/// );
/// ```
///
/// The modulus _must_ be odd and prime, or this will panic.
#[macro_export]
macro_rules! const_prime_monty_params {
($name:ident, $uint_type:ty, $value:expr) => {
($name:ident, $uint_type:ty, $value:expr, $generator:literal) => {
$crate::const_prime_monty_params!(
$name,
$uint_type,
$value,
$generator,
"Modulus which impls `ConstPrimeMontyParams`"
);
};
($name:ident, $uint_type:ty, $value:expr, $doc:expr) => {
$crate::const_monty_params!(
$name,
$uint_type,
$value,
"Modulus which impls `ConstPrimeMontyParams`"
);
($name:ident, $uint_type:ty, $value:expr, $generator:literal, $doc:expr) => {
$crate::const_monty_params!($name, $uint_type, $value, $doc);

impl $crate::modular::ConstPrimeMontyParams<{ <$uint_type>::LIMBS }> for $name {
const PRIME_PARAMS: $crate::modular::PrimeParams<{ <$uint_type>::LIMBS }> =
$crate::modular::PrimeParams::new_vartime(
&<$name as $crate::modular::ConstMontyParams<{ <$uint_type>::LIMBS }>>::PARAMS,
)
.expect("cannot derive prime parameters");
$generator,
);
}
};
}
Expand Down
3 changes: 2 additions & 1 deletion src/modular/const_monty_form/sqrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ mod tests {
const_prime_monty_params!(
P256Field,
U256,
"ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
"ffffffff00000001000000000000000000000000ffffffffffffffffffffffff",
6
);
assert_eq!(P256Field::PRIME_PARAMS.s().get(), 1);
type ConstForm = ConstMontyForm<P256Field, { U256::LIMBS }>;
Expand Down
95 changes: 40 additions & 55 deletions src/modular/prime_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use core::num::NonZeroU32;
use ctutils::{CtAssignSlice, CtEqSlice, CtSelectUsingCtAssign};

use super::{FixedMontyForm, FixedMontyParams};
use crate::{Choice, CtAssign, CtEq, Odd, Uint};
use crate::{Choice, CtAssign, CtEq, OddUint, Uint};

#[cfg(feature = "subtle")]
use crate::CtSelect;
Expand All @@ -13,8 +13,10 @@ use crate::CtSelect;
/// with a prime modulus.
#[derive(Debug, Copy, Clone)]
pub struct PrimeParams<const LIMBS: usize> {
/// A constant such that the modulus `p = t•2^s+1` for `s > 0` and some odd `t`.
/// The largest power of two that divides `(modulus - 1)`.
pub(super) s: NonZeroU32,
/// The result of dividing `modulus - 1` by `2^s`.
pub(super) t: OddUint<LIMBS>,
/// The smallest primitive root of the modulus.
pub(super) generator: NonZeroU32,
/// The exponent to use in computing a modular square root.
Expand All @@ -28,18 +30,21 @@ pub struct PrimeParams<const LIMBS: usize> {
impl<const LIMBS: usize> PrimeParams<LIMBS> {
/// Instantiates a new set of [`PrimeParams`] given [`FixedMontyParams`] for a prime modulus.
///
/// This method will return `None` if the modulus is determined to be non-prime, however
/// this is not an exhaustive check and non-prime values can be accepted.
/// The value `generator` must be a multiplicative generator (ie. primitive element) of the
/// finite field, having order `modulus-1`. Its powers generate all nonzero elements of the
/// field: `generator^0, generator^1, ..., generator^(modulus-2)` enumerate `[1, modulus-1]`.
#[must_use]
#[allow(clippy::unwrap_in_result, clippy::missing_panics_doc)]
pub const fn new_vartime(params: &FixedMontyParams<LIMBS>) -> Option<Self> {
pub const fn new_vartime(params: &FixedMontyParams<LIMBS>, generator: u32) -> Self {
let p = params.modulus();
let p_minus_one = p.as_ref().set_bit_vartime(0, false);
let generator = NonZeroU32::new(generator).expect("invalid generator");
let s = NonZeroU32::new(p_minus_one.trailing_zeros_vartime()).expect("ensured non-zero");

let Some(generator) = find_primitive_root(p) else {
return None;
};
let t = p
.as_ref()
.shr(s.get())
.to_odd()
.expect_copied("ensured odd");

// if s=1 and p is a power of a prime then -1 is always a root of unity
let (exp, root) = if s.get() == 1 {
Expand All @@ -57,19 +62,21 @@ impl<const LIMBS: usize> PrimeParams<LIMBS> {
FixedMontyForm::new(&Uint::from_u32(generator.get()), params).pow_vartime(&t);
// root^(2^(s-1)) must be equal to -1
let check = root.square_repeat_vartime(s.get() - 1);
if !Uint::eq(&check.retrieve(), &p_minus_one).to_bool_vartime() {
return None;
}
assert!(
Uint::eq(&check.retrieve(), &p_minus_one).to_bool_vartime(),
"error calculating root of unity: invalid generator"
);
(exp, root)
};

Some(Self {
Self {
s,
t,
generator,
sqrt_exp: exp,
monty_root_unity: root.to_montgomery(),
monty_root_unity_p2: root.square().to_montgomery(),
})
}
}

/// Get the constant 'generator' used in modular square root calculation.
Expand All @@ -83,6 +90,12 @@ impl<const LIMBS: usize> PrimeParams<LIMBS> {
pub const fn s(&self) -> NonZeroU32 {
self.s
}

/// Get the constant 't' used in modular square root calculation.
#[must_use]
pub const fn t(&self) -> OddUint<LIMBS> {
self.t
}
}

impl<const LIMBS: usize> CtAssign for PrimeParams<LIMBS> {
Expand Down Expand Up @@ -128,41 +141,6 @@ impl<const LIMBS: usize> subtle::ConditionallySelectable for PrimeParams<LIMBS>
}
}

#[allow(clippy::unwrap_in_result)]
const fn find_primitive_root<const LIMBS: usize>(p: &Odd<Uint<LIMBS>>) -> Option<NonZeroU32> {
// A primitive root exists iff p is 1, 2, 4, q^k or 2q^k, k > 0, q is an odd prime.
// Find a quadratic non-residue (primitive roots are non-residue for powers of a prime)
let mut g = NonZeroU32::new(2u32).expect("ensured non-zero");
let (mut skip_root, mut skip_square) = (2u32, 4u32);
loop {
// Either the modulus is prime and g is quadratic non-residue, or
// the modulus is composite.
let g_uint = Uint::<1>::from_u32(g.get());
match g_uint.jacobi_symbol_vartime(p) as i8 {
-1 => {
break Some(g);
}
0 => {
// Modulus is composite
return None;
}
_ => loop {
let Some(g2) = g.checked_add(1) else {
return None;
};
g = g2;
if g.get() == skip_square {
// Skip obviously square values (4, 9, 16..)
skip_root += 1;
skip_square = skip_root.saturating_pow(2);
} else {
break;
}
},
}
}
}

#[cfg(test)]
mod tests {
use super::PrimeParams;
Expand All @@ -172,29 +150,36 @@ mod tests {
fn check_expected() {
let monty_params =
MontyParams::new_vartime(Odd::<U128>::from_be_hex("e38af050d74b8567f73c8713cbc7bc47"));
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
let prime_params = PrimeParams::new_vartime(&monty_params, 5);
assert_eq!(prime_params.s.get(), 1);
assert_eq!(prime_params.generator.get(), 5);
}

#[should_panic]
#[test]
fn check_invalid_generator() {
let monty_params =
MontyParams::new_vartime(Odd::<U128>::from_be_hex("e38af050d74b8567f73c8713cbc7bc47"));
let _ = PrimeParams::new_vartime(&monty_params, 0);
}

#[should_panic]
#[test]
fn check_non_prime() {
let monty_params =
MontyParams::new_vartime(Odd::<U128>::from_be_hex("e38af050d74b8567f73c8713cbc7bc01"));
assert!(PrimeParams::new_vartime(&monty_params).is_none());
let _ = PrimeParams::new_vartime(&monty_params, 5);
}

#[test]
fn check_equality() {
let monty_params_1 =
MontyParams::new_vartime(Odd::<U128>::from_be_hex("e38af050d74b8567f73c8713cbc7bc47"));
let prime_params_1 =
PrimeParams::new_vartime(&monty_params_1).expect("failed creating params");
let prime_params_1 = PrimeParams::new_vartime(&monty_params_1, 5);

let monty_params_2 =
MontyParams::new_vartime(Odd::<U128>::from_be_hex("f2799d643ab7ff983437c3a86cdb1beb"));
let prime_params_2 =
PrimeParams::new_vartime(&monty_params_2).expect("failed creating params");
let prime_params_2 = PrimeParams::new_vartime(&monty_params_2, 5);

assert!(CtEq::ct_eq(&prime_params_1, &prime_params_1).to_bool_vartime());
#[cfg(feature = "subtle")]
Expand Down
24 changes: 12 additions & 12 deletions src/modular/sqrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,21 @@ pub const fn sqrt_montgomery_form<const LIMBS: usize>(
let neg_zeta_b = monty_eq(&neg_zeta, &ru_2);
let zeta_d = monty_eq(&zeta, &ru_6);

// m = B if -zeta in (B, C), else 1
// m = B if -zeta in {B, C}, else 1
let mut m = monty_select(
&FixedMontyForm::one(ru.params()),
&ru_2,
neg_zeta_b.or(monty_eq(&neg_zeta, &ru_4)),
);
// m = C if zeta in (-1, D)
// m = C if zeta in {-1, D}
m = monty_select(
&m,
&ru_4,
Uint::eq(neg_zeta.as_montgomery(), monty_params.one()).or(zeta_d),
);
// m = D if zeta in (B, C)
// m = D if zeta in {B, C}
m = monty_select(&m, &ru_6, zeta_b.or(monty_eq(&zeta, &ru_4)));
// m = m•ru if zeta or -zeta in (B, D)
// m = m•ru if zeta or -zeta in {B, D}
m = monty_select(
&m,
&m.mul(&ru),
Expand Down Expand Up @@ -208,9 +208,9 @@ mod tests {
let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
"ffffffff00000001000000000000000000000000ffffffffffffffffffffffff",
));
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
let prime_params = PrimeParams::new_vartime(&monty_params, 6);
assert_eq!(prime_params.s.get(), 1);
assert_eq!(prime_params.generator.get(), 3);
assert_eq!(prime_params.generator.get(), 6);
assert_eq!(
root_of_unity(&monty_params, &prime_params),
U256::from_be_hex("FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFE")
Expand All @@ -226,7 +226,7 @@ mod tests {
let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed",
));
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
let prime_params = PrimeParams::new_vartime(&monty_params, 2);
assert_eq!(prime_params.s.get(), 2);
assert_eq!(prime_params.generator.get(), 2);
assert_eq!(
Expand All @@ -244,7 +244,7 @@ mod tests {
let monty_params = FixedMontyParams::new_vartime(Odd::<U576>::from_be_hex(
"00000000000001fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409",
));
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
let prime_params = PrimeParams::new_vartime(&monty_params, 3);
assert_eq!(prime_params.s.get(), 3);
assert_eq!(prime_params.generator.get(), 3);
assert_eq!(
Expand All @@ -264,7 +264,7 @@ mod tests {
let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
"ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551",
));
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
let prime_params = PrimeParams::new_vartime(&monty_params, 7);
assert_eq!(prime_params.s.get(), 4);
assert_eq!(prime_params.generator.get(), 7);
assert_eq!(
Expand All @@ -282,12 +282,12 @@ mod tests {
let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",
));
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
let prime_params = PrimeParams::new_vartime(&monty_params, 7);
assert_eq!(prime_params.s.get(), 6);
assert_eq!(prime_params.generator.get(), 5);
assert_eq!(prime_params.generator.get(), 7);
assert_eq!(
root_of_unity(&monty_params, &prime_params),
U256::from_be_hex("0D1F8EAB98DCD1ACA7DC810E065710CBB96E9ABEBBE451FA15B4F83D2D2AD232")
U256::from_be_hex("0C1DC060E7A91986DF9879A3FBC483A898BDEAB680756045992F4B5402B052F2")
);

test_monty_sqrt(monty_params, prime_params);
Expand Down