Feature-complete for implementing real crypto

This commit is contained in:
Pascal Engélibert 2025-08-29 22:58:23 +02:00
commit f34f931030
13 changed files with 2684 additions and 306 deletions

1
.gitignore vendored
View file

@ -1,2 +1 @@
/target
*.sage.py

184
Cargo.lock generated
View file

@ -8,13 +8,71 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "cfg-if"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "getrandom"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "gwrizienn"
version = "0.1.0"
dependencies = [
"num-traits",
"rand",
"rand_core",
"rayon",
"zeroize",
]
[[package]]
name = "libc"
version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
[[package]]
name = "num-traits"
version = "0.2.19"
@ -23,3 +81,129 @@ checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "ppv-lite86"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro2"
version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "rayon"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "syn"
version = "2.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "unicode-ident"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "wasi"
version = "0.11.1+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
[[package]]
name = "zerocopy"
version = "0.8.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"

View file

@ -4,4 +4,16 @@ version = "0.1.0"
edition = "2024"
[dependencies]
num-traits = "0.2.19"
num-traits = "0.2"
rand = { version = "0.8", optional = true }
rand_core = { version = "0.6", optional = true }
zeroize = { version = "1", optional = true, default-features = false }
[dev-dependencies]
rayon = "1.8"
[features]
default = ["rand", "zeroize"]
rand = ["dep:rand", "dep:rand_core"]
zeroize = ["dep:zeroize"]

View file

@ -12,10 +12,17 @@ Goals:
Non-goals:
* Generic (it's generic but only for primitive types)
* Complete (it's simple because it's not a complete implementation of modern algebra)
* Dynamic (vector dimensions and moduli are strongly typed)
Supported:
* ring Zq
* ring Zq/(x^N+1) with additive operations
* ring Zq/(x^N+1) with multiplicative operations if q=p or q=2p with p prime and 2N divides p-1
* vectors and matrices of the above rings
## Name
It's Breton for "root", because we use roots of unity to compute the NTT for faster polynomial multiplication.
Pronounce _grizienn_. It's Breton for "root", because we use roots of unity to compute the NTT for faster O(N log N) polynomial multiplication.
## License

77
examples/dilithium.rs Normal file
View file

@ -0,0 +1,77 @@
//! Example of simplified Dilithium with q=8380417
use gwrizienn::{
matrix::Matrix,
ntt::{Ntt, NttInv},
ring::Ring,
vector::Vector,
*,
};
// Implement Zq
ring!(Zq, u32, u64, i64, 8380417);
// Implement Rq = Zq/(x^256+1)
// zeta=1753 is the first 512-th root of unity mod q
poly!(Rq, 256, Zq, u32, u64, u32, u64, 8380417, 1753);
fn high_bits<const N: usize>(mut v: Vector<Rq, N>) -> Vector<Rq, N> {
for vi in v.0.iter_mut() {
for vij in vi.0.iter_mut() {
vij.0 -= vij.0 % 190464;
}
}
v
}
fn main() {
let mut rng = rand::thread_rng();
let uniform = Zq::uniform();
let ball_c = Zq::uniform_ball(1);
let ball_s = Zq::uniform_ball(2);
let ball_y = Zq::uniform_ball(131071);
// generate secret key
let a = Matrix::<Rq, 4, 4>::random(uniform, &mut rng);
let s1 = Vector::<Rq, 4>::random(ball_s, &mut rng);
let s2 = Vector::<Rq, 4>::random(ball_s, &mut rng);
// random value for signing
let y = Vector::<Rq, 4>::random(ball_y, &mut rng);
// challenge
let c = Rq::random(ball_c, &mut rng);
// use NTT for fast multiplication
let a = a.ntt();
let s1 = s1.ntt();
let s2 = s2.ntt();
let y = y.ntt();
let c = c.ntt();
// generate public key
let t = &a * &s1 + s2;
// commitment
let w = &a * &y;
// proof
let z = y + s1 * &c;
// verify
assert_eq!(
high_bits((&a * &z - t * &c).ntt_inv()),
high_bits(w.ntt_inv())
);
// let uniform = Zq::uniform();
// let ball_c = Zq::uniform_ball(1);
// let ball_s = Zq::uniform_ball(2);
// let ball_y = Zq::uniform_ball(131071);
//
// let a = Matrix::<Rq, 4, 4>::random(uniform, &mut rng).ntt();
// let s1 = Vector::<Rq, 4>::random(ball_s, &mut rng).ntt();
// let s2 = Vector::<Rq, 4>::random(ball_s, &mut rng).ntt();
// let y = Vector::<Rq, 4>::random(ball_y, &mut rng).ntt();
// let c = Rq::random(ball_c, &mut rng).ntt();
//
// let t = &a * &s1 + s2;
// let w = &a * &y;
// let z = y + s1 * &c;
}

101
examples/ntwe.rs Normal file
View file

@ -0,0 +1,101 @@
//! Example of simplified NTWE scheme (Gärtner 2024)
use gwrizienn::{
matrix::Matrix,
ntt::{Ntt, NttInv},
ring::{Lift, Ring},
vector::{Vector, VectorRef},
*,
};
use num_traits::{Inv, One, Zero};
// Implement Zq
ring!(Zq, u32, u32, i32, 50177);
// Implement Rq = Zq/(x^256+1)
// zeta=66 is the first 512-th root of unity mod q
poly!(Rq, 256, Zq, u32, u32, u32, u64, 50177, 66);
// Implement Zb and Rb with a big modulus for NTT in Z2q
// chosen because > N*q^2
ring!(Zb, u64, u128, i128, 644539222529);
poly!(
Rb,
256,
Zb,
u64,
u128,
u64,
u128,
644539222529,
483489047161
);
//Implement Z2q
ring!(Z2q, u32, u64, i64, 100354);
// Implement R2q
// A different macro must be used because 2q is not prime.
poly2!(
Rq,
Rb,
R2q,
256,
Z2q,
u32,
u64,
u32,
u64,
i64,
i64,
50177,
644539222529
);
const L: usize = 3;
const M: usize = 2;
fn main() {
let mut rng = rand::thread_rng();
let uniform = Zq::uniform();
let ball_c = Z2q::uniform_positive_semiball(1);
let ball_s = Zq::uniform_ball(1);
let ball_y = Zb::uniform_ball(55);
// generate secret key s = [f s0 e]
let mut s = Vector::<Rq, { L + M + 1 }>::random(ball_s, &mut rng);
let f0 = &mut s.0[0];
*f0 *= Zq(2);
*f0 += Zq::one();
let s_ntt = s.clone().ntt();
let f = &s_ntt.0[0];
let s0: VectorRef<_, L> = s_ntt.get_sub(1);
let e: VectorRef<_, M> = s_ntt.get_sub(1 + L);
// generate public key
let a0 = Matrix::<Rq, M, L>::random(uniform, &mut rng);
let b = (&a0.clone().ntt() * s0 + e) * &f.clone().inv();
let mut a = Matrix::<R2q, M, { L + M + 1 }>::zero();
a.set_column(0, (b.ntt_inv().lift() * Z2q(100352)).get_ref());
a[0][0] += Z2q(50177);
a.set_columns(1, &(a0.lift() * Z2q(2)));
for i in 0..M {
a[i][1 + L + i] += Z2q(2);
}
let a: Matrix<Rb, M, { L + M + 1 }> = a.lift();
let a = a.ntt();
// random value for signing
let y = Vector::<Rb, { L + M + 1 }>::random(ball_y, &mut rng);
let y = y.ntt();
// commitment
let w: Vector<R2q, M> = (&a * &y).ntt_inv().lift();
// challenge
let c = R2q::random(ball_c, &mut rng);
let cb: Rb = c.lift();
// proof
let s: Vector<Rb, { L + M + 1 }> = s.lift().lift();
let z = y + s.ntt() * &cb.ntt();
// verify
let mut w2: Vector<R2q, M> = (&a * &z).ntt_inv().lift();
w2[0] -= c * Z2q(50177);
assert_eq!(w, w2);
}

View file

@ -1,7 +1,33 @@
#![no_std]
#![warn(missing_docs)]
#![deny(non_ascii_idents)]
#![deny(unnameable_types)]
#![deny(unreachable_pub)]
#![deny(unstable_features)]
#![warn(unused_qualifications)]
#![allow(clippy::tabs_in_doc_comments)]
//! Modular and polynomial arithmetic.
pub mod matrix;
pub mod ntt;
pub mod poly;
//pub mod r7681;
pub mod ring;
pub mod tuple;
pub mod vector;
/// Something that can be sampled using a distribution D over T
#[cfg(feature = "rand")]
pub trait Random<D: rand::distributions::Distribution<T>, T> {
/// Sample an element of T from distribution `distr`
fn random<R: rand::Rng>(distr: D, rng: &mut R) -> Self;
}
/// Type having a unity, or multiplicative neutral element
///
/// This trait is useful because `num_traits::One` requires `Mul<Self>`, which is not implemented for polynomials.
pub trait One {
/// the unity
fn one() -> Self;
/// is it the unity?
fn is_one(&self) -> bool;
}

View file

@ -1 +1,401 @@
//! Matrices of ring elements or polynomials.
use crate::vector::{Dot, Vector, VectorRef};
use core::{
mem::MaybeUninit,
ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
};
use num_traits::Zero;
/// Matrix of E of M rows and N columns
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Matrix<E, const M: usize, const N: usize>(pub [Vector<E, N>; M]);
/// Referenced matrix of E of M rows and N columns
#[derive(Debug, Eq, Hash, PartialEq)]
pub struct MatrixRef<'a, E, const M: usize, const N: usize>(pub &'a [Vector<E, N>; M]);
/// Mutable referenced matrix of E of M rows and N columns
#[derive(Debug, Eq, Hash, PartialEq)]
pub struct MatrixMut<'a, E, const M: usize, const N: usize>(pub &'a mut [Vector<E, N>; M]);
impl<'a, E, const M: usize, const N: usize> Clone for MatrixRef<'a, E, M, N> {
fn clone(&self) -> Self {
*self
}
}
impl<'a, E, const M: usize, const N: usize> Copy for MatrixRef<'a, E, M, N> {}
impl<E: Clone, const M: usize, const N: usize> Matrix<E, M, N> {
/// Set column `idx`
pub fn set_column(&mut self, idx: usize, v: VectorRef<E, M>) {
for (row, x) in self.0.iter_mut().zip(v.0.iter()) {
row.0[idx] = x.clone();
}
}
/// Set `S` columns starting at column `idx` to values from the matrix `v`
pub fn set_columns<const S: usize>(&mut self, idx: usize, v: &Matrix<E, M, S>) {
for (s_row, v_row) in self.0.iter_mut().zip(v.0.iter()) {
for (s_x, v_x) in s_row.0[idx..idx + S].iter_mut().zip(v_row.0.iter()) {
*s_x = v_x.clone();
}
}
}
}
impl<E, const M: usize, const N: usize> Zero for Matrix<E, M, N>
where
Vector<E, N>: Zero,
Self: Add<Self, Output = Self>,
{
fn zero() -> Self {
let mut z = MaybeUninit::<[Vector<E, N>; M]>::uninit();
unsafe {
for i in z.assume_init_mut() {
*i = Vector::zero();
}
}
Self(unsafe { z.assume_init() })
}
fn is_zero(&self) -> bool {
self.0.iter().all(Zero::is_zero)
}
}
impl<E, const N: usize> crate::One for Matrix<E, N, N>
where
E: crate::One + Zero,
Self: Zero,
{
fn one() -> Self {
let mut id = Self::zero();
for (i, a) in id.0.iter_mut().enumerate() {
a.0[i] = E::one();
}
id
}
fn is_one(&self) -> bool {
self.0.iter().enumerate().all(|(i, row)| {
row.0
.iter()
.enumerate()
.all(|(j, v)| if i == j { v.is_one() } else { v.is_zero() })
})
}
}
impl<E, const M: usize, const N: usize> Add<Self> for Matrix<E, M, N>
where
Vector<E, N>: AddAssign<Vector<E, N>>,
{
type Output = Self;
fn add(mut self, rhs: Self) -> Self {
self.0.iter_mut().zip(rhs.0).for_each(|(x, y)| *x += y);
self
}
}
impl<'a, E, const M: usize, const N: usize> Add<&'a Self> for Matrix<E, M, N>
where
Vector<E, N>: 'a + AddAssign<&'a Vector<E, N>>,
{
type Output = Self;
fn add(mut self, rhs: &'a Self) -> Self {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x += y);
self
}
}
impl<'a, E, const M: usize, const N: usize> Add<MatrixRef<'a, E, M, N>> for Matrix<E, M, N>
where
Vector<E, N>: 'a + AddAssign<&'a Vector<E, N>>,
{
type Output = Self;
fn add(mut self, rhs: MatrixRef<'a, E, M, N>) -> Self {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x += y);
self
}
}
impl<E, const M: usize, const N: usize> AddAssign<Self> for Matrix<E, M, N>
where
Vector<E, N>: AddAssign<Vector<E, N>>,
{
fn add_assign(&mut self, rhs: Self) {
self.0.iter_mut().zip(rhs.0).for_each(|(x, y)| *x += y);
}
}
impl<'a, E, const M: usize, const N: usize> AddAssign<&'a Self> for Matrix<E, M, N>
where
Vector<E, N>: 'a + AddAssign<&'a Vector<E, N>>,
{
fn add_assign(&mut self, rhs: &'a Self) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x += y);
}
}
impl<'a, E, const M: usize, const N: usize> AddAssign<MatrixRef<'a, E, M, N>> for Matrix<E, M, N>
where
Vector<E, N>: 'a + AddAssign<&'a Vector<E, N>>,
{
fn add_assign(&mut self, rhs: MatrixRef<'a, E, M, N>) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x += y);
}
}
impl<E, const M: usize, const N: usize> Sub<Self> for Matrix<E, M, N>
where
Vector<E, N>: SubAssign<Vector<E, N>>,
{
type Output = Self;
fn sub(mut self, rhs: Self) -> Self {
self.0.iter_mut().zip(rhs.0).for_each(|(x, y)| *x -= y);
self
}
}
impl<E, const M: usize, const N: usize> SubAssign<Self> for Matrix<E, M, N>
where
Vector<E, N>: SubAssign<Vector<E, N>>,
{
fn sub_assign(&mut self, rhs: Self) {
self.0.iter_mut().zip(rhs.0).for_each(|(x, y)| *x -= y);
}
}
impl<'a, E, const M: usize, const N: usize> Sub<&'a Self> for Matrix<E, M, N>
where
Vector<E, N>: 'a + SubAssign<&'a Vector<E, N>>,
{
type Output = Self;
fn sub(mut self, rhs: &'a Self) -> Self {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x -= y);
self
}
}
impl<'a, E, const M: usize, const N: usize> SubAssign<&'a Self> for Matrix<E, M, N>
where
Vector<E, N>: 'a + SubAssign<&'a Vector<E, N>>,
{
fn sub_assign(&mut self, rhs: &'a Self) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x -= y);
}
}
impl<'a, E, const M: usize, const N: usize> Sub<MatrixRef<'a, E, M, N>> for Matrix<E, M, N>
where
Vector<E, N>: 'a + SubAssign<&'a Vector<E, N>>,
{
type Output = Self;
fn sub(mut self, rhs: MatrixRef<'a, E, M, N>) -> Self {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x -= y);
self
}
}
impl<'a, E, const M: usize, const N: usize> SubAssign<MatrixRef<'a, E, M, N>> for Matrix<E, M, N>
where
Vector<E, N>: 'a + SubAssign<&'a Vector<E, N>>,
{
fn sub_assign(&mut self, rhs: MatrixRef<'a, E, M, N>) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x -= y);
}
}
impl<E, O, const M: usize, const N: usize> crate::ntt::Ntt for Matrix<E, M, N>
where
E: crate::ntt::Ntt<Output = O>,
{
type Output = Matrix<O, M, N>;
fn ntt(self) -> Matrix<O, M, N> {
let mut z = MaybeUninit::<[Vector<O, N>; M]>::uninit();
unsafe {
for (i, j) in z.assume_init_mut().iter_mut().zip(self.0) {
*i = j.ntt();
}
}
Matrix(unsafe { z.assume_init() })
}
}
impl<E, O, const M: usize, const N: usize> crate::ntt::NttInv for Matrix<E, M, N>
where
E: crate::ntt::NttInv<Output = O>,
{
type Output = Matrix<O, M, N>;
fn ntt_inv(self) -> Matrix<O, M, N> {
let mut z = MaybeUninit::<[Vector<O, N>; M]>::uninit();
unsafe {
for (i, j) in z.assume_init_mut().iter_mut().zip(self.0) {
*i = j.ntt_inv();
}
}
Matrix(unsafe { z.assume_init() })
}
}
// TODO only for safe in-place liftable elements
impl<Fr, To, T, U, const M: usize, const N: usize> crate::ring::Lift<Matrix<To, M, N>, T, U>
for Matrix<Fr, M, N>
where
Fr: crate::ring::Lift<To, T, U>,
{
fn lift(self) -> Matrix<To, M, N> {
let mut z = MaybeUninit::<[Vector<To, N>; M]>::uninit();
unsafe {
for (i, j) in z.assume_init_mut().iter_mut().zip(self.0) {
*i = j.lift();
}
}
Matrix(unsafe { z.assume_init() })
}
}
impl<E, F, const M: usize, const N: usize> Mul<F> for Matrix<E, M, N>
where
E: MulAssign<F>,
F: Copy,
{
type Output = Self;
fn mul(mut self, rhs: F) -> Self {
for row in self.0.iter_mut() {
for v in row.0.iter_mut() {
*v *= rhs;
}
}
self
}
}
impl<'a, E, const M: usize, const N: usize> Mul<&'a Vector<E, N>> for &'a Matrix<E, M, N>
where
&'a Vector<E, N>: Dot<&'a Vector<E, N>, Output = E>,
Vector<E, M>: Zero,
{
type Output = Vector<E, M>;
fn mul(self, rhs: &'a Vector<E, N>) -> Vector<E, M> {
let mut res = Self::Output::zero();
res.0
.iter_mut()
.zip(self.0.iter())
.for_each(|(r, s)| *r = s.dot(rhs));
res
}
}
impl<'a, E, const M: usize, const N: usize> Mul<VectorRef<'a, E, N>> for &'a Matrix<E, M, N>
where
&'a Vector<E, N>: Dot<VectorRef<'a, E, N>, Output = E>,
Vector<E, M>: Zero,
{
type Output = Vector<E, M>;
fn mul(self, rhs: VectorRef<'a, E, N>) -> Vector<E, M> {
let mut res = Self::Output::zero();
res.0
.iter_mut()
.zip(self.0.iter())
.for_each(|(r, s)| *r = s.dot(rhs));
res
}
}
#[cfg(feature = "rand")]
impl<T, D, E, const M: usize, const N: usize> crate::Random<D, T> for Matrix<E, M, N>
where
D: Clone + rand::distributions::Distribution<T>,
E: crate::Random<D, T>,
{
fn random<R: rand::Rng>(distr: D, rng: &mut R) -> Self {
let mut z = MaybeUninit::<[Vector<E, N>; M]>::uninit();
unsafe {
for i in z.assume_init_mut() {
*i = Vector::random(distr.clone(), rng);
}
}
Self(unsafe { z.assume_init() })
}
}
impl<E, const M: usize, const N: usize> core::ops::Index<usize> for Matrix<E, M, N> {
type Output = Vector<E, N>;
fn index(&self, idx: usize) -> &Vector<E, N> {
&self.0[idx]
}
}
impl<E, const M: usize, const N: usize> core::ops::IndexMut<usize> for Matrix<E, M, N> {
fn index_mut(&mut self, idx: usize) -> &mut Vector<E, N> {
&mut self.0[idx]
}
}
#[cfg(feature = "zeroize")]
impl<E: zeroize::Zeroize, const M: usize, const N: usize> zeroize::Zeroize for Matrix<E, M, N> {
fn zeroize(&mut self) {
self.0.zeroize();
}
}
impl<E: core::fmt::Display, const M: usize, const N: usize> core::fmt::Display for Matrix<E, M, N> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "[")?;
for i in self.0.iter() {
write!(f, "{},", i)?;
}
write!(f, "]")
}
}
#[cfg(test)]
#[allow(unreachable_pub)]
#[allow(unnameable_types)]
mod test {
use super::*;
use crate::{ntt::Ntt, poly, ring};
ring!(Z50177, u32, u32, i64, 50177);
poly!(R50177, 256, Z50177, u32, u32, u32, u64, 50177, 66);
#[test]
fn test_basic() {
Matrix::<R50177, 2, 2>::zero();
}
#[test]
fn test_mul() {
let a = Matrix::<R50177, 2, 2>::zero();
let b = Vector::<R50177, 2>::zero();
let x = R50177::one();
let _ = a.clone().ntt() * &x.ntt();
let _ = &a.clone().ntt() * &b.ntt();
}
}

View file

@ -1,72 +1,84 @@
use num_traits::Zero;
use std::ops::{Add, AddAssign, Sub, SubAssign};
//! Number-Theoretic Transform (NTT) is kind of a discrete Fourier Transform that transforms a polynomial to the NTT domain, where multiplication can be done efficiently in O(N) instead of O(N^2). Addition can also be performed efficiently in the NTT domain. The `NttDomain` struct ensures mathematical soundness and domains are not used together.
use core::{
marker::PhantomData,
ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign},
};
use num_traits::{Inv, Zero};
/// Element E in the NTT domain
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct NttDomain<E>(pub E);
pub struct NttDomain<E, EE, const N: usize>(pub E, pub PhantomData<[EE; N]>);
/// An element that can be transformed by the NTT
pub trait Ntt: Sized {
/// Type in the NTT domain
type Output: Sized;
/// Apply the NTT
fn ntt(self) -> Self::Output;
}
/// An element in the NTT domain that can be transformed back by the inverse NTT
pub trait NttInv: Sized {
/// Type out of the NTT domain
type Output: Sized;
/// Apply the inverse NTT
fn ntt_inv(self) -> Self::Output;
}
impl<E> Add<NttDomain<E>> for NttDomain<E>
impl<E, EE, const N: usize> Add<NttDomain<E, EE, N>> for NttDomain<E, EE, N>
where
E: Add<E, Output = E>,
{
type Output = NttDomain<E>;
fn add(self, rhs: NttDomain<E>) -> Self::Output {
NttDomain(self.0 + rhs.0)
type Output = NttDomain<E, EE, N>;
fn add(self, rhs: NttDomain<E, EE, N>) -> Self::Output {
NttDomain(self.0 + rhs.0, Default::default())
}
}
impl<'a, E> Add<&'a NttDomain<E>> for NttDomain<E>
impl<'a, E, EE, const N: usize> Add<&'a NttDomain<E, EE, N>> for NttDomain<E, EE, N>
where
E: 'a + Add<&'a E, Output = E>,
{
type Output = NttDomain<E>;
fn add(self, rhs: &'a NttDomain<E>) -> Self::Output {
NttDomain(self.0 + &rhs.0)
type Output = NttDomain<E, EE, N>;
fn add(self, rhs: &'a NttDomain<E, EE, N>) -> Self::Output {
NttDomain(self.0 + &rhs.0, Default::default())
}
}
impl<E> Sub<NttDomain<E>> for NttDomain<E>
impl<E, EE, const N: usize> Sub<NttDomain<E, EE, N>> for NttDomain<E, EE, N>
where
E: Sub<E, Output = E>,
{
type Output = NttDomain<E>;
fn sub(self, rhs: NttDomain<E>) -> Self::Output {
NttDomain(self.0 - rhs.0)
type Output = NttDomain<E, EE, N>;
fn sub(self, rhs: NttDomain<E, EE, N>) -> Self::Output {
NttDomain(self.0 - rhs.0, Default::default())
}
}
impl<'a, E> Sub<&'a NttDomain<E>> for NttDomain<E>
impl<'a, E, EE, const N: usize> Sub<&'a NttDomain<E, EE, N>> for NttDomain<E, EE, N>
where
E: 'a + Sub<&'a E, Output = E>,
{
type Output = NttDomain<E>;
fn sub(self, rhs: &'a NttDomain<E>) -> Self::Output {
NttDomain(self.0 - &rhs.0)
type Output = NttDomain<E, EE, N>;
fn sub(self, rhs: &'a NttDomain<E, EE, N>) -> Self::Output {
NttDomain(self.0 - &rhs.0, Default::default())
}
}
impl<E> Zero for NttDomain<E>
impl<E, EE, const N: usize> Zero for NttDomain<E, EE, N>
where
E: Zero,
{
fn zero() -> Self {
NttDomain(E::zero())
NttDomain(E::zero(), Default::default())
}
fn is_zero(&self) -> bool {
self.0.is_zero()
}
}
impl<E> AddAssign for NttDomain<E>
impl<E, EE, const N: usize> AddAssign for NttDomain<E, EE, N>
where
E: AddAssign,
{
@ -75,7 +87,7 @@ where
}
}
impl<'a, E> AddAssign<&'a Self> for NttDomain<E>
impl<'a, E, EE, const N: usize> AddAssign<&'a Self> for NttDomain<E, EE, N>
where
E: 'a + AddAssign<&'a E>,
{
@ -84,7 +96,7 @@ where
}
}
impl<E> SubAssign for NttDomain<E>
impl<E, EE, const N: usize> SubAssign for NttDomain<E, EE, N>
where
E: SubAssign,
{
@ -93,7 +105,7 @@ where
}
}
impl<'a, E> SubAssign<&'a Self> for NttDomain<E>
impl<'a, E, EE, const N: usize> SubAssign<&'a Self> for NttDomain<E, EE, N>
where
E: 'a + SubAssign<&'a E>,
{
@ -101,3 +113,207 @@ where
self.0 -= &rhs.0;
}
}
impl<E, EE, const N: usize> Mul for NttDomain<E, EE, N>
where
E: crate::poly::Poly<N, Element = EE>,
EE: Copy + MulAssign,
{
type Output = Self;
fn mul(mut self, rhs: Self) -> Self {
self.0
.coefficients_mut()
.iter_mut()
.zip(rhs.0.coefficients().iter())
.for_each(|(x, y)| *x *= *y);
self
}
}
impl<E, EE, const N: usize> Mul<&Self> for NttDomain<E, EE, N>
where
E: crate::poly::Poly<N, Element = EE>,
EE: Copy + MulAssign,
{
type Output = Self;
fn mul(mut self, rhs: &Self) -> Self {
self.0
.coefficients_mut()
.iter_mut()
.zip(rhs.0.coefficients().iter())
.for_each(|(x, y)| *x *= *y);
self
}
}
impl<E, EE, const N: usize> MulAssign for NttDomain<E, EE, N>
where
E: crate::poly::Poly<N, Element = EE>,
EE: Copy + MulAssign,
{
fn mul_assign(&mut self, rhs: Self) {
self.0
.coefficients_mut()
.iter_mut()
.zip(rhs.0.coefficients().iter())
.for_each(|(x, y)| *x *= *y);
}
}
impl<E, EE, const N: usize> MulAssign<&Self> for NttDomain<E, EE, N>
where
E: crate::poly::Poly<N, Element = EE>,
EE: Copy + MulAssign,
{
fn mul_assign(&mut self, rhs: &Self) {
self.0
.coefficients_mut()
.iter_mut()
.zip(rhs.0.coefficients().iter())
.for_each(|(x, y)| *x *= *y);
}
}
impl<E, EE, const N: usize> Mul<EE> for NttDomain<E, EE, N>
where
E: crate::poly::Poly<N, Element = EE>,
EE: Copy + MulAssign,
{
type Output = Self;
fn mul(mut self, rhs: EE) -> Self {
self.0.coefficients_mut().iter_mut().for_each(|x| *x *= rhs);
self
}
}
impl<E, EE, const N: usize> MulAssign<EE> for NttDomain<E, EE, N>
where
E: crate::poly::Poly<N, Element = EE>,
EE: Copy + MulAssign,
{
fn mul_assign(&mut self, rhs: EE) {
self.0.coefficients_mut().iter_mut().for_each(|x| *x *= rhs);
}
}
impl<E, EE, const N: usize> Inv for NttDomain<E, EE, N>
where
E: crate::poly::Poly<N, Element = EE>,
EE: Copy + Inv<Output = EE>,
{
type Output = Self;
fn inv(mut self) -> Self {
self.0
.coefficients_mut()
.iter_mut()
.for_each(|x| *x = x.inv());
self
}
}
impl<E, EE, const N: usize> NttInv for NttDomain<E, EE, N>
where
E: crate::poly::InternalNttInv<Input = E, Output = E>,
{
type Output = E;
fn ntt_inv(self) -> E {
E::__ntt_inv(self.0)
}
}
impl<E, EE, const N: usize> crate::poly::Poly<N> for NttDomain<E, EE, N>
where
E: crate::poly::Poly<N, Element = EE>,
EE: crate::ring::Ring<<E as crate::poly::Poly<N>>::T> + Copy,
<E as crate::poly::Poly<N>>::T: Clone,
{
type T = <E as crate::poly::Poly<N>>::T;
type Element = EE;
fn coefficients(&self) -> &[EE; N] {
self.0.coefficients()
}
fn coefficients_mut(&mut self) -> &mut [EE; N] {
self.0.coefficients_mut()
}
fn from_scalar(e: EE) -> Self {
Self(E::from_coefficients([e; N]), Default::default())
}
fn to_coefficients(self) -> [EE; N] {
self.0.to_coefficients()
}
fn from_coefficients(v: [EE; N]) -> Self {
Self(E::from_coefficients(v), Default::default())
}
}
#[cfg(feature = "rand")]
impl<T, E, EE, const N: usize> crate::Random<crate::ring::UniformZq<T, EE>, T>
for NttDomain<E, EE, N>
where
E: crate::Random<crate::ring::UniformZq<T, EE>, T>,
crate::ring::UniformZq<T, EE>: rand::distributions::Distribution<T>,
{
fn random<R: rand::Rng>(distr: crate::ring::UniformZq<T, EE>, rng: &mut R) -> Self {
Self(E::random(distr, rng), Default::default())
}
}
impl<E, EE, const N: usize> Index<usize> for NttDomain<E, EE, N>
where
E: Index<usize, Output = EE>,
{
type Output = EE;
fn index(&self, idx: usize) -> &EE {
&self.0[idx]
}
}
impl<E, EE, const N: usize> IndexMut<usize> for NttDomain<E, EE, N>
where
E: Index<usize, Output = EE> + IndexMut<usize>,
{
fn index_mut(&mut self, idx: usize) -> &mut EE {
&mut self.0[idx]
}
}
#[cfg(feature = "zeroize")]
impl<E: zeroize::Zeroize, EE, const N: usize> zeroize::Zeroize for NttDomain<E, EE, N> {
fn zeroize(&mut self) {
self.0.zeroize();
}
}
impl<E: core::fmt::Display, EE, const N: usize> core::fmt::Display for NttDomain<E, EE, N> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.0.fmt(f)
}
}
/// Compute self + a * b in a way that may be faster
pub trait Fma<A, B> {
/// Compute self + a * b in a way that may be faster
///
/// You are encouraged to benchmark this against the naive operations.
/// Compiler optimizations may render this function uninteresting.
fn fma(&mut self, a: A, b: B);
}
impl<E, EE, const N: usize> Fma<&Self, &Self> for NttDomain<E, EE, N>
where
E: crate::poly::Poly<N, Element = EE>,
EE: Copy + Mul<Output = EE> + AddAssign,
{
fn fma(&mut self, a: &Self, b: &Self) {
self.0
.coefficients_mut()
.iter_mut()
.zip(a.0.coefficients().iter().zip(b.0.coefficients().iter()))
.for_each(|(si, (ai, bi))| *si += *ai * *bi);
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,28 +1,162 @@
pub trait Ring {
type T;
//! Macros and traits for modular rings Z/qZ.
use num_traits::cast::AsPrimitive;
/// An element of ring Z/qZ represented by type `T`
pub trait Ring<T>: AsRef<T> + AsMut<T> + Sized {
/// q
fn modulus() -> T;
/// q/2 rounded down
fn modulus_half() -> T;
/// New element from a scalar, safely reducing modulo q
fn from_scalar(v: T) -> Self;
/// New element from a scalar, without reduction modulo q
///
/// # Safety
/// If v >= q, following operations may produce wrong results or panic.
unsafe fn from_scalar_unchecked(v: T) -> Self;
/// New element from a scalar, assuming 0 <= v < 2q
fn from_scalar_small_reduce(v: T) -> Self;
/// positive representative in [0..q-1]
fn to_scalar(self) -> T;
/// Sample a uniform element in the ring
fn sample_uniform(rng: impl rand::Rng) -> Self;
/// Uniform distribution in the ring
fn uniform() -> UniformZq<T, Self>;
/// Uniform distribution in a ball of maximum infinity norm `max`
fn uniform_ball(max_norm: T) -> UniformZq<T, Self>;
/// Uniform distribution in a positive semiball of maximum infinity norm `max`
fn uniform_positive_semiball(max_norm: T) -> UniformZq<T, Self>;
/// Centered representative in [-q/2, q/2]
fn center<TI>(self) -> TI
where
Self: Copy,
T: AsPrimitive<TI> + Copy + Ord,
TI: 'static + Copy + core::ops::Sub<TI, Output = TI>,
{
if self.to_scalar() <= Self::modulus_half() {
self.to_scalar().as_()
} else {
self.to_scalar().as_() - Self::modulus().as_()
}
}
}
/// Implement a modular ring
///
/// - `$Zq`: the struct that will be created
/// - `$T`: (defined) scalar type big enough to represent 2q
/// - `$TT`: (defined) scalar type big enough to represent (q-1)^2
/// - `$TI`: (defined) signed scalar type big enough to represent (q-1)^2
/// - `$q`: prime modulus
///
/// ```
/// gwrizienn::ring!(Z50177, u32, u64, i64, 50177);
/// assert_eq!(Z50177(50176) + Z50177(43), Z50177(42));
/// assert_eq!(Z50177(123) * Z50177(123).inverse(), Z50177(1));
/// ```
#[macro_export]
macro_rules! ring {
($Zq:ident, $T:ident, $TT:ident, $TI:ident, $q:expr) => {
/// Element of Z/qZ
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(transparent)]
pub struct $Zq(pub $T);
impl $crate::ring::Ring for $Zq {
type T = $T;
impl AsRef<$T> for $Zq {
fn as_ref(&self) -> &$T {
&self.0
}
}
impl AsMut<$T> for $Zq {
fn as_mut(&mut self) -> &mut $T {
&mut self.0
}
}
impl $crate::ring::Ring<$T> for $Zq {
fn modulus() -> $T {
$q
}
fn modulus_half() -> $T {
$q / 2
}
fn from_scalar(v: $T) -> Self {
Self(v % $q)
}
fn from_scalar_small_reduce(v: $T) -> Self {
Self(Self::small_reduce(v))
}
fn to_scalar(self) -> $T {
self.0
}
unsafe fn from_scalar_unchecked(v: $T) -> Self {
Self(v)
}
//#[cfg(feature = "rand")]
//#[allow(dead_code)]
fn sample_uniform(mut rng: impl rand::Rng) -> Self {
Self(rng.gen_range(0..$q))
}
//#[cfg(feature = "rand")]
//#[allow(dead_code)]
fn uniform() -> $crate::ring::UniformZq<$T, Self> {
use rand::distributions::uniform::UniformSampler;
$crate::ring::UniformZq::new_inclusive(Self(0), Self($q - 1))
}
/// Uniform distribution in a ball of maximum infinity norm `max`
//#[cfg(feature = "rand")]
//#[allow(dead_code)]
fn uniform_ball(max_norm: $T) -> $crate::ring::UniformZq<$T, Self> {
use rand::distributions::uniform::UniformSampler;
$crate::ring::UniformZq::new_inclusive(Self($q - max_norm), Self($q + max_norm))
}
/// Uniform distribution in a positive semiball of maximum infinity norm `max`
//#[cfg(feature = "rand")]
//#[allow(dead_code)]
fn uniform_positive_semiball(max_norm: $T) -> $crate::ring::UniformZq<$T, Self> {
use rand::distributions::uniform::UniformSampler;
$crate::ring::UniformZq::new_inclusive(Self(0), Self(max_norm))
}
}
impl $Zq {
pub const fn modulus() -> $T {
$q
}
/// Create an element
#[allow(dead_code)]
pub const fn new(v: $T) -> Self {
Self(v % $q)
}
/// Reduce v modulo q, assuming that 0 <= v < 2q
#[allow(dead_code)]
const fn small_reduce(v: $T) -> $T {
if v < $q { v } else { v - $q }
}
/// Compute the inverse modulo q
///
/// May output garbage if `self` is not invertible.
#[allow(dead_code)]
const fn inverse(self) -> Self {
let (mut ro, mut r) = ($q, self.0);
let (mut to, mut t) = (0, 1);
@ -38,36 +172,48 @@ macro_rules! ring {
}
}
impl std::ops::Add for $Zq {
impl num_traits::AsPrimitive<$T> for $Zq {
fn as_(self) -> $T {
self.0
}
}
impl num_traits::AsPrimitive<$Zq> for $T {
fn as_(self) -> $Zq {
$Zq(self)
}
}
impl core::ops::Add for $Zq {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self(Self::small_reduce(self.0 + rhs.0))
}
}
impl std::ops::AddAssign for $Zq {
impl core::ops::AddAssign for $Zq {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl std::ops::Sub for $Zq {
impl core::ops::Sub for $Zq {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self(Self::small_reduce($q + self.0 - rhs.0))
}
}
impl std::ops::SubAssign for $Zq {
impl core::ops::SubAssign for $Zq {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl std::ops::Neg for $Zq {
impl core::ops::Neg for $Zq {
type Output = Self;
fn neg(self) -> Self {
Self(Self::small_reduce($q - self.0))
if self.0 == 0 { self } else { Self($q - self.0) }
}
}
@ -80,15 +226,26 @@ macro_rules! ring {
}
}
impl std::ops::Mul for $Zq {
impl core::ops::Mul for $Zq {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
let product = self.0 as $TT * rhs.0 as $TT;
Self((product % $q as $TT) as $T)
// Barrett
// let a = self.0 as $TT;
// let b = rhs.0 as $TT;
// const r: $TT = ($q as $TT).next_power_of_two() as _;
// let res = (a * b - a * (b * r / $q as $TT) / r * $q as $TT) as $T;
// Self(if res >= $q {
// res - $q
// } else {
// res
// })
}
}
impl std::ops::MulAssign for $Zq {
impl core::ops::MulAssign for $Zq {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
@ -109,12 +266,168 @@ macro_rules! ring {
self.inverse()
}
}
// TODO how to check our feature from where the macro is called?
//#[cfg(feature = "rand")]
impl rand::distributions::uniform::SampleUniform for $Zq {
type Sampler = $crate::ring::UniformZq<$T, $Zq>;
}
//#[cfg(feature = "rand")]
impl<D: rand::distributions::Distribution<Self>> $crate::Random<D, Self> for $Zq {
fn random<R: rand::Rng>(distr: D, rng: &mut R) -> Self {
distr.sample(rng)
}
}
impl<D: rand::distributions::Distribution<$TI>> $crate::Random<D, $TI> for $Zq {
fn random<R: rand::Rng>(distr: D, rng: &mut R) -> Self {
let r = distr.sample(rng);
if r < 0 {
Self(($q + r) as $T)
} else {
Self(r as $T)
}
}
}
impl $crate::ring::Norm2 for $Zq {
type Output = $T;
type OutputSquared = $TT;
fn norm2(&self) -> $T {
use $crate::ring::Ring;
self.center::<$TI>().unsigned_abs() as $T
}
fn norm2_squared(&self) -> $TT {
use $crate::ring::Ring;
let v: $TI = self.center::<$TI>();
(v * v) as $TT
}
}
impl $crate::ring::NormInf for $Zq {
type Output = $T;
fn norm_inf(&self) -> $T {
use $crate::ring::Ring;
self.center::<$TI>().abs() as _
}
}
impl core::fmt::Display for $Zq {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.0.fmt(f)
}
}
};
}
crate::ring!(Z2, u8, u8, i8, 2);
/// Uniform distribution over a ring
#[cfg(feature = "rand")]
#[derive(Clone, Copy, Debug)]
pub struct UniformZq<T, Zq>(
pub(crate) rand::distributions::uniform::UniformInt<T>,
pub(crate) core::marker::PhantomData<Zq>,
);
#[cfg(feature = "rand")]
impl<T: rand::distributions::uniform::SampleUniform, Zq: Copy + Ring<T>>
rand::distributions::uniform::UniformSampler for UniformZq<T, Zq>
where
rand::distributions::uniform::UniformInt<T>:
rand::distributions::uniform::UniformSampler<X = T>,
{
type X = Zq;
fn new<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
{
UniformZq(
rand::distributions::uniform::UniformInt::<T>::new(
low.borrow().to_scalar(),
high.borrow().to_scalar(),
),
Default::default(),
)
}
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
{
UniformZq(
rand::distributions::uniform::UniformInt::<T>::new_inclusive(
low.borrow().to_scalar(),
high.borrow().to_scalar(),
),
Default::default(),
)
}
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
Zq::from_scalar_small_reduce(self.0.sample(rng))
}
}
impl<T, Zq> rand::distributions::Distribution<Zq> for UniformZq<T, Zq>
where
Self: rand::distributions::uniform::UniformSampler<X = Zq>,
Zq: rand::distributions::uniform::SampleUniform,
{
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Zq {
rand::distributions::uniform::UniformSampler::sample(self, rng)
}
}
//crate::ring!(Z2, u8, u8, i8, 2);
/// Lifting between types
pub trait Lift<To, T, U> {
/// Lift to another modular ring
///
/// Injective only if the destination ring is bigger.
fn lift(self) -> To;
//unsafe fn lift_unchecked(self) -> To;
}
/*impl<T, Fr: Ring<T>, To: Ring<T>> Lift<To, T, T> for Fr {
fn lift(self) -> To {
To::from_scalar(self.to_scalar())
}
unsafe fn lift_unchecked(self) -> To {
To::from_scalar_unchecked(self.to_scalar())
}
}*/
/*impl<T: Into<U>, U, Fr: Ring<T>, To: Ring<U>> Lift<To, T, U> for Fr {
fn lift(self) -> To {
To::from_scalar(self.to_scalar().into())
}
/*unsafe fn lift_unchecked(self) -> To {
To::from_scalar_unchecked(self.to_scalar().into())
}*/
}*/
/// Euclidean norm
pub trait Norm2 {
/// Output type
type Output;
/// Output type
type OutputSquared;
/// Euclidean norm
fn norm2(&self) -> Self::Output;
/// Square of the euclidean norm (may be faster than the norm, depending on implementation)
fn norm2_squared(&self) -> Self::OutputSquared;
}
/// Infinity norm
pub trait NormInf {
/// Output type
type Output;
/// Infinity norm
fn norm_inf(&self) -> Self::Output;
}
#[cfg(test)]
#[allow(unreachable_pub, unnameable_types)]
mod test {
use super::*;
@ -129,5 +442,10 @@ mod test {
assert_eq!(Z50177::new(1), Z50177::new(50178));
assert_eq!(Z50177::new(123) * Z50177::new(123).inv(), Z50177::one());
assert_eq!(Z50177::new(1).norm2_squared(), 1);
assert_eq!(Z50177::new(50176).norm2_squared(), 1);
assert_eq!(Z50177::new(2).norm2_squared(), 4);
assert_eq!(Z50177::new(50175).norm2_squared(), 4);
}
}

View file

@ -1,22 +1,82 @@
use std::{
//! Vector types that can contain modular scalars (`Ring`) or polynomials (`Poly`).
//! Matrices must not be implemented as vectors of vectors; use the `matrix` module instead.
//use crate::ntt::Fma;
use core::{
mem::MaybeUninit,
ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign},
};
use num_traits::Zero;
use num_traits::{One, Zero};
/// Vector of E
pub trait VectorTrait<E>: Index<usize, Output = E> + IndexMut<usize> {}
impl<E, const N: usize> VectorTrait<E> for Vector<E, N> {}
/// Vector of dimension N
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Vector<E, const N: usize>(pub [E; N]);
/// Reference to a vector
#[derive(Debug, Eq, Hash, PartialEq)]
pub struct VectorRef<'a, E, const N: usize>(pub &'a [E; N]);
/// Mutable reference to a vector
#[derive(Debug, Eq, Hash, PartialEq)]
pub struct VectorMut<'a, E, const N: usize>(pub &'a mut [E; N]);
impl<'a, E, const N: usize> Clone for VectorRef<'a, E, N> {
fn clone(&self) -> Self {
*self
}
}
impl<'a, E, const N: usize> Copy for VectorRef<'a, E, N> {}
/*impl<const N: usize> Vector<Poly, N> {
pub fn norm_inf(&self) -> Zq {
Zq(self.0.iter().map(|i| i.norm_inf().0).max().unwrap())
}
}*/
impl<E, const N: usize> Vector<E, N> {
/// Concatenate two vectors
pub fn concat<const M: usize, const MN: usize>(self, other: Vector<E, M>) -> Vector<E, MN> {
assert_eq!(M + N, MN);
let mut z = MaybeUninit::<[E; MN]>::uninit();
unsafe {
for (new, old) in z
.assume_init_mut()
.iter_mut()
.zip(self.0.into_iter().chain(other.0.into_iter()))
{
*new = old;
}
}
Vector(unsafe { z.assume_init() })
}
/// Get a reference to a subvector
pub fn get_sub<'a, const M: usize>(&'a self, idx: usize) -> VectorRef<'a, E, M>
where
&'a [E]: TryInto<&'a [E; M]>,
{
VectorRef(if let Ok(sub) = &self.0[idx..idx + M].try_into() {
sub
} else {
panic!("Subvector out of bounds")
})
}
/// Get a reference to the vector
pub fn get_ref(&self) -> VectorRef<'_, E, N> {
VectorRef(&self.0)
}
}
impl<E, const N: usize> Zero for Vector<E, N>
where
E: AddAssign + Zero,
E: Zero,
Self: Add<Self, Output = Self>,
{
fn zero() -> Self {
let mut z = MaybeUninit::<[E; N]>::uninit();
@ -57,6 +117,53 @@ where
}
}
impl<'a, E, const N: usize> Add<VectorRef<'a, E, N>> for Vector<E, N>
where
E: 'a + AddAssign<&'a E>,
{
type Output = Self;
fn add(mut self, rhs: VectorRef<'a, E, N>) -> Self {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x += y);
self
}
}
impl<E, const N: usize> AddAssign<Self> for Vector<E, N>
where
E: AddAssign<E>,
{
fn add_assign(&mut self, rhs: Self) {
self.0.iter_mut().zip(rhs.0).for_each(|(x, y)| *x += y);
}
}
impl<'a, E, const N: usize> AddAssign<&'a Self> for Vector<E, N>
where
E: 'a + AddAssign<&'a E>,
{
fn add_assign(&mut self, rhs: &'a Self) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x += y);
}
}
impl<'a, E, const N: usize> AddAssign<VectorRef<'a, E, N>> for Vector<E, N>
where
E: 'a + AddAssign<&'a E>,
{
fn add_assign(&mut self, rhs: VectorRef<'a, E, N>) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x += y);
}
}
impl<E, const N: usize> Sub<Self> for Vector<E, N>
where
E: SubAssign<E>,
@ -68,6 +175,15 @@ where
}
}
impl<E, const N: usize> SubAssign<Self> for Vector<E, N>
where
E: SubAssign<E>,
{
fn sub_assign(&mut self, rhs: Self) {
self.0.iter_mut().zip(rhs.0).for_each(|(x, y)| *x -= y);
}
}
impl<'a, E, const N: usize> Sub<&'a Self> for Vector<E, N>
where
E: 'a + SubAssign<&'a E>,
@ -82,26 +198,98 @@ where
}
}
/// Vector-scalar product
impl<'a, E, const N: usize> Mul<&'a E> for Vector<E, N>
impl<'a, E, const N: usize> SubAssign<&'a Self> for Vector<E, N>
where
E: 'a + Clone + std::fmt::Debug + Zero + MulAssign<&'a E>,
E: 'a + SubAssign<&'a E>,
{
fn sub_assign(&mut self, rhs: &'a Self) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x -= y);
}
}
impl<'a, E, const N: usize> Sub<VectorRef<'a, E, N>> for Vector<E, N>
where
E: 'a + SubAssign<&'a E>,
{
type Output = Self;
fn sub(mut self, rhs: VectorRef<'a, E, N>) -> Self {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x -= y);
self
}
}
impl<'a, E, const N: usize> SubAssign<VectorRef<'a, E, N>> for Vector<E, N>
where
E: 'a + SubAssign<&'a E>,
{
fn sub_assign(&mut self, rhs: VectorRef<'a, E, N>) {
self.0
.iter_mut()
.zip(rhs.0.iter())
.for_each(|(x, y)| *x -= y);
}
}
/// Vector-scalar product
impl<E, F: Copy, const N: usize> Mul<F> for Vector<E, N>
where
E: MulAssign<F>,
{
type Output = Vector<E, N>;
fn mul(mut self, rhs: &'a E) -> Vector<E, N> {
fn mul(mut self, rhs: F) -> Vector<E, N> {
self.0.iter_mut().for_each(|s| *s *= rhs);
self
}
}
/// Inner product
// TODO optimize (reduce and copy less often)
impl<'a, E, const N: usize> Mul<&'a Self> for Vector<E, N>
/// Vector-scalar product
impl<E, F: Copy, const N: usize> MulAssign<F> for Vector<E, N>
where
E: 'a + Clone + Zero + MulAssign<&'a E> + Add<E>,
E: MulAssign<F>,
{
fn mul_assign(&mut self, rhs: F) {
self.0.iter_mut().for_each(|s| *s *= rhs);
}
}
/// Inner product
pub trait Dot<Rhs> {
/// Result type (should be the vector's element type)
type Output;
/// Inner product
fn dot(self, rhs: Rhs) -> Self::Output;
}
impl<E, const N: usize> Dot<Self> for Vector<E, N>
where
E: Zero + MulAssign<E> + Add<E>,
{
type Output = E;
fn mul(self, rhs: &'a Self) -> E {
// TODO optimize (reduce and copy less often)
fn dot(self, rhs: Self) -> E {
self.0
.into_iter()
.zip(rhs.0)
.fold(E::zero(), |e, (mut x, y)| {
x *= y;
e + x
})
}
}
impl<'a, E, const N: usize> Dot<&'a Self> for Vector<E, N>
where
E: 'a + Zero + MulAssign<&'a E> + Add<E>,
{
type Output = E;
// TODO optimize (reduce and copy less often)
fn dot(self, rhs: &'a Self) -> E {
self.0
.into_iter()
.zip(rhs.0.iter())
@ -112,74 +300,222 @@ where
}
}
/// Inner product
// TODO optimize (reduce and copy less often)
impl<'a, E, const N: usize> Mul<Self> for &'a Vector<E, N>
impl<'a, E, const N: usize> Dot<VectorRef<'a, E, N>> for Vector<E, N>
where
E: 'a + Clone + Zero + Mul<&'a E, Output = E> + Add<E>,
E: 'a + Zero + MulAssign<&'a E> + Add<E>,
{
type Output = E;
fn mul(self, rhs: Self) -> E {
// TODO optimize (reduce and copy less often)
fn dot(self, rhs: VectorRef<'a, E, N>) -> E {
self.0
.into_iter()
.zip(rhs.0.iter())
.fold(E::zero(), |e, (mut x, y)| {
x *= y;
e + x
})
}
}
impl<'a, E, const N: usize> Dot<Self> for &'a Vector<E, N>
where
E: 'a + Clone + Mul<&'a E, Output = E> + Add<E, Output = E> + Zero,
//E: Fma<&'a E, &'a E>
{
type Output = E;
fn dot(self, rhs: Self) -> E {
self.0
.iter()
.zip(rhs.0.iter())
.fold(E::zero(), |e, (x, y)| e + x.clone() * y)
//.fold(E::zero(), |mut e, (x, y)| {
// e.fma(x, y);
// e
//})
}
}
/// Matrix-vector product
impl<'a, E, const M: usize, const N: usize> Mul<&'a Vector<E, M>> for Vector<Vector<E, M>, N>
impl<'a, E, const N: usize> Dot<VectorRef<'a, E, N>> for &'a Vector<E, N>
where
E: 'a + Clone + std::fmt::Debug + Zero + MulAssign<&'a E> + Add<E> + AddAssign<E>,
E: 'a + Clone + Mul<&'a E, Output = E> + Add<E, Output = E> + Zero,
//E: Fma<&'a E, &'a E>
{
type Output = Vector<E, N>;
fn mul(self, rhs: &'a Vector<E, M>) -> Vector<E, N> {
let mut res = Self::Output::zero();
res.0.iter_mut().zip(self.0).for_each(|(r, s)| *r = s * rhs);
res
type Output = E;
fn dot(self, rhs: VectorRef<'a, E, N>) -> E {
self.0
.iter()
.zip(rhs.0.iter())
.fold(E::zero(), |e, (x, y)| e + x.clone() * y)
//.fold(E::zero(), |mut e, (x, y)| {
// e.fma(x, y);
// e
//})
}
}
/// Matrix-vector product
impl<'a, E, const M: usize, const N: usize> Mul<&'a Vector<E, M>> for &'a Vector<Vector<E, M>, N>
where
E: 'a + Clone + std::fmt::Debug + Zero + Mul<&'a E, Output = E> + Add<E> + AddAssign<E>,
{
type Output = Vector<E, N>;
fn mul(self, rhs: &'a Vector<E, M>) -> Vector<E, N> {
let mut res = Self::Output::zero();
res.0
.iter_mut()
.zip(self.0.iter())
.for_each(|(r, s)| *r = s * rhs);
res
impl<E, const N: usize> Index<usize> for Vector<E, N> {
type Output = E;
fn index(&self, idx: usize) -> &E {
&self.0[idx]
}
}
impl<E, const N: usize> Vector<Vector<E, N>, N>
impl<E, const N: usize> IndexMut<usize> for Vector<E, N> {
fn index_mut(&mut self, idx: usize) -> &mut E {
&mut self.0[idx]
}
}
#[cfg(feature = "rand")]
impl<T, D, E, const N: usize> crate::Random<D, T> for Vector<E, N>
where
E: Zero + One,
Vector<E, N>: AddAssign + Zero,
D: Clone + rand::distributions::Distribution<T>,
E: crate::Random<D, T>,
{
pub fn id() -> Self {
let mut id = Self::zero();
for (i, a) in id.0.iter_mut().enumerate() {
a.0[i] = E::one();
fn random<R: rand::Rng>(distr: D, rng: &mut R) -> Self {
let mut z = MaybeUninit::<[E; N]>::uninit();
unsafe {
for i in z.assume_init_mut() {
*i = E::random(distr.clone(), rng);
}
}
id
Self(unsafe { z.assume_init() })
}
}
impl<E, O, const N: usize> crate::ntt::Ntt for Vector<E, N>
where
E: crate::ntt::Ntt<Output = O>,
{
type Output = Vector<O, N>;
fn ntt(self) -> Vector<O, N> {
let mut z = MaybeUninit::<[O; N]>::uninit();
unsafe {
for (i, j) in z.assume_init_mut().iter_mut().zip(self.0) {
*i = j.ntt();
}
}
Vector(unsafe { z.assume_init() })
}
}
impl<E, O, const N: usize> crate::ntt::NttInv for Vector<E, N>
where
E: crate::ntt::NttInv<Output = O>,
{
type Output = Vector<O, N>;
fn ntt_inv(self) -> Vector<O, N> {
let mut z = MaybeUninit::<[O; N]>::uninit();
unsafe {
for (i, j) in z.assume_init_mut().iter_mut().zip(self.0) {
*i = j.ntt_inv();
}
}
Vector(unsafe { z.assume_init() })
}
}
impl<Fr, To, T, U, const N: usize> crate::ring::Lift<Vector<To, N>, T, U> for Vector<Fr, N>
where
Fr: crate::ring::Lift<To, T, U>,
{
fn lift(self) -> Vector<To, N> {
let mut z = MaybeUninit::<[To; N]>::uninit();
unsafe {
for (i, j) in z.assume_init_mut().iter_mut().zip(self.0) {
*i = j.lift();
}
}
Vector(unsafe { z.assume_init() })
}
/*unsafe fn lift_unchecked(self) -> $R2q {
let mut ret: $R2q = core::mem::transmute(self);
ret.0.iter_mut().for_each(|x| if x.0 > $q / 2 {
x.0 += $q;
});
ret
}*/
}
impl<T, TT, E, const N: usize> crate::ring::Norm2 for Vector<E, N>
where
TT: Zero + AddAssign<TT>,
E: crate::ring::Norm2<Output = T, OutputSquared = TT>,
{
type Output = T;
type OutputSquared = TT;
fn norm2(&self) -> T {
todo!()
}
fn norm2_squared(&self) -> TT {
let mut ret = TT::zero();
for i in self.0.iter() {
ret += i.norm2_squared();
}
ret
}
}
impl<T, E, const N: usize> crate::ring::NormInf for Vector<E, N>
where
T: Ord,
E: crate::ring::NormInf<Output = T>,
{
type Output = T;
fn norm_inf(&self) -> T {
self.0
.iter()
.map(crate::ring::NormInf::norm_inf)
.max()
.unwrap()
}
}
impl<E, const N: usize> core::ops::Neg for Vector<E, N>
where
Self: Zero + Sub<Self, Output = Self>,
{
type Output = Self;
fn neg(self) -> Self {
Self::zero() - self
}
}
#[cfg(feature = "zeroize")]
impl<E: zeroize::Zeroize, const N: usize> zeroize::Zeroize for Vector<E, N> {
fn zeroize(&mut self) {
self.0.zeroize();
}
}
impl<E: core::fmt::Display, const N: usize> core::fmt::Display for Vector<E, N> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "[")?;
for i in self.0.iter() {
write!(f, "{},", i)?;
}
write!(f, "]")
}
}
impl<'a, E, const N: usize> From<VectorRef<'a, E, N>> for Vector<E, N>
where
&'a [E; N]: Into<[E; N]>,
{
fn from(s: VectorRef<'a, E, N>) -> Vector<E, N> {
Vector(s.0.into())
}
}
#[cfg(test)]
#[allow(unreachable_pub, unnameable_types)]
mod test {
use super::*;
use crate::{poly, ring};
use num_traits::Inv;
crate::ring!(Z50177, u32, u64, i64, 50177);
poly!(R50177, 256, Z50177, u32, u64, 50177, 66);
ring!(Z50177, u32, u32, i64, 50177);
poly!(R50177, 256, Z50177, u32, u32, u32, u64, 50177, 66);
#[test]
fn test_basic() {

View file

@ -1,39 +0,0 @@
q = 50177
Fq.<yq> = PolynomialRing(ZZ.quotient(q))
Rq = Fq.quotient(yq**256+1, "xq")
xq = Rq.gen()
F2q.<y2q> = PolynomialRing(ZZ.quotient(q*2))
R2q = F2q.quotient(y2q**256+1, "x2q")
x2q = R2q.gen()
F2.<y2> = PolynomialRing(ZZ.quotient(2))
R2 = F2.quotient(y2**256+1, "x2")
x2 = R2.gen()
def center(x, m):
x = x % m
if x < m/2:
return x
else:
return x - m
def f(x):
return (R2([int(i)%2 for i in x.list()]), Rq([int(i)%q for i in x.list()]))
def g(a, b):
al = a.list()
bl = b.list()
return R2q([int(al[i])*q-int(bl[i])*2*(q//2) for i in range(len(al))])
def add(ab, cd):
return (ab[0]+cd[0], ab[1]+cd[1])
def mul(ab, cd):
return (ab[0]*cd[0], ab[1]*cd[1])
a = -x2q**2 + 3*x2q + 2
b = 4*x2q + x2q**4
print(g(*mul(f(a), f(b))))
print(a*b)