dualring-rs/src/lib.rs
2025-03-23 09:09:02 +01:00

623 lines
16 KiB
Rust

#![feature(array_chunks)]
#![feature(split_array)]
// Intermediate `let` bindings are kept so that every protocol value has an explicit name
#![allow(clippy::let_and_return)]
// The signing transcript nests vectors several levels deep (execution, mask, ring member, index)
#![allow(clippy::type_complexity)]
mod util;
use bytemuck::{bytes_of, bytes_of_mut, cast_slice, cast_slice_mut};
use num_modular::{ModularCoreOps, ModularPow, ModularUnaryOps};
use rand::seq::SliceRandom;
use rand_core::CryptoRngCore;
use sha3::{
Shake128,
digest::{ExtendableOutput, Update, XofReader},
};
use std::marker::PhantomData;
/// Field modulus: the Mersenne prime `2^127 - 1` (`0x7fff…ffff`).
const P: u128 = (1 << 127) - 1;
/// Legendre PRF `L_K(a)`: maps `a` to one bit from its quadratic character
/// modulo [`P`].
///
/// By Euler's criterion, `a^((P-1)/2) mod P` is `1` for a non-zero square,
/// `P - 1` for a non-square and `0` when `a ≡ 0 (mod P)`. Computing
/// `(1 - symbol) mod P` and halving maps these to `0`, `1` and `0`.
fn legendre_prf(a: u128) -> u128 {
    // Derived from P instead of hard-coding a 32-digit literal that could
    // silently drift out of sync with the modulus.
    const EULER_EXPONENT: u128 = (P - 1) / 2;
    let symbol = a.powm(EULER_EXPONENT, &P);
    1.subm(symbol, &P) / 2
}
fn random_fp(mut rng: impl CryptoRngCore) -> u128 {
let mut bytes = [0; 16];
rng.fill_bytes(&mut bytes);
// approximate modulo to avoid complex multi-register arithmetic
bytes[0] &= 0x7f;
u128::from_be_bytes(bytes)
// Generated number is in [0..P]
// Proba(0 = x mod P) = 2^-126
// so we are in in the ring [0..P-1] with overwhelming probability
}
/// Parameter set of the DualRing signer.
///
/// Field names follow the scheme's notation (B, M, N, tau). TODO(review):
/// confirm against the paper.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DualRing {
    /// B: number of challenge indices per ring member and execution. It is
    /// also the number of masking values `r` per party share.
    b: usize,
    /// M: number of masks per execution. One of them (`kbar`) is selected by
    /// hashing during signing.
    m: usize,
    /// N: number of party shares per execution. One party index (`ibar`) is
    /// selected by hashing during signing.
    n: usize,
    /// tau: number of parallel executions.
    tau: usize,
}
impl DualRing {
/// "Fast" parameter set: `tau = 74` executions of `n = 16` party shares,
/// with `m = 127` masks per execution and `b = 10` challenge indices.
pub fn new_fast() -> Self {
    let (b, m, n, tau) = (10, 127, 16, 74);
    Self { b, m, n, tau }
}
/// Generates a key pair.
///
/// The secret key is a random field element. The public key is
/// `power_residue_prf_l(secret)`: `2 * 254` PRF outputs.
pub fn keygen(rng: impl CryptoRngCore) -> (u128, [u128; 2 * 254]) {
    let secret = random_fp(rng);
    (secret, Self::power_residue_prf_l(secret))
}
/// Signs `m` on behalf of the ring of public keys `ring`, using the secret
/// key `sk` of the member at position `sk_index`.
///
/// The body follows the numbered phases of the signing algorithm:
///
/// - **Phase 1:** sample a salt and random challenge seeds for every ring
///   member except the signer. Then set up `tau` executions. Each execution:
///   - derives `n` party shares (`k`, `a`, `b`, `c`, `r`) from a seed tree;
///   - corrects share 0 so the `k` shares sum to `sk` (`dk`) and
///     `sum(c) = sum(a) * sum(b)` (`dc`);
///   - computes the commitments `t`.
/// - **Phase 2:** hash everything into `h1`. The signer's challenge is chosen
///   so that all ring challenges XOR to the first 16 bytes of `h1`. The masked
///   indices are then committed in one shuffled Merkle tree per mask.
/// - **Phase 3:** `h2` selects one mask `kbar_e` in `[0, m)` per execution.
/// - **Phase 4:** openings `o_e[b] = r_sum[b] * (sk + I_pi[b])`.
/// - **Phase 5:** `h3` yields `epsilon_e` and `lambda_e`.
/// - **Phase 6:** per-party `alpha`, `beta`, `gamma` values, hashed into `h4`.
/// - **Phase 7:** `h4` selects a party index `ibar_e` in `[0, n)` per
///   execution.
///
/// NOTE(review): phase 8 (assembling the signature) is not implemented, so
/// nothing is returned yet.
///
/// `sk_index` must be `< ring.len()`, because ring-sized vectors are indexed
/// with it.
pub fn sign(
    &self,
    ring: &[[u128; 2 * 254]],
    sk: u128,
    sk_index: usize,
    m: &[u8],
    mut rng: impl CryptoRngCore,
) {
    // Phase 1
    let mut salt = [0; 32];
    rng.fill_bytes(&mut salt);
    // Random challenge seed for every ring member except the signer. The
    // signer's slot stays zero until its challenge is derived in phase 2.
    let challenges: Vec<[u8; 16]> = (0..ring.len())
        .map(|i| {
            let mut challenge = [0; 16];
            if i != sk_index {
                rng.fill_bytes(&mut challenge);
            }
            challenge
        })
        .collect();
    /// Per-party state of one execution.
    struct Share {
        k: u128,
        a: u128,
        b: u128,
        c: u128,
        r: Vec<u128>,
        c_mask: Vec<[u8; 32]>,
        c_mpc: [u8; 32],
    }
    /// State of one of the `tau` executions.
    struct Execution {
        sd_mask: Vec<u128>,
        mask: Vec<Vec<u128>>,
        /// [b]
        r_sum: Vec<u128>,
        t: Vec<u128>,
        dk: u128,
        dc: u128,
        shares: Vec<Share>,
        indices: Vec<Vec<u32>>,
        /// [k][j][b]
        d_indices: Vec<Vec<Vec<u128>>>,
        epsilon: u128,
        lambda: Vec<u128>,
    }
    let mut executions: Vec<Execution> = (0..self.tau)
        .map(|e| {
            // Mask seeds: one per mask k, each expanded into n party masks.
            let mut sd_mask_e = vec![0u128; self.m];
            rng.fill_bytes(bytes_of_mut(&mut sd_mask_e[0]));
            Self::expand_tree(&mut sd_mask_e);
            let mask_e: Vec<Vec<u128>> = (0..self.m)
                .map(|k| {
                    let mut mask_e_k = vec![0u128; self.n];
                    mask_e_k[0] = sd_mask_e[k];
                    Self::expand_tree(&mut mask_e_k);
                    mask_e_k
                })
                .collect();
            // Per-party MPC seeds.
            let mut sd_mpc_e = vec![0u128; self.n];
            rng.fill_bytes(bytes_of_mut(&mut sd_mpc_e[0]));
            Self::expand_tree(&mut sd_mpc_e);
            let mut shares: Vec<Share> = (0..self.n)
                .map(|i| {
                    // Each party seed expands to k, a, b, c followed by `b` values of r.
                    let mut buf = vec![0; (4 + self.b) * size_of::<u128>()];
                    Self::expand(&sd_mpc_e[i].to_be_bytes(), &mut buf);
                    let mut buf_iter = buf.array_chunks();
                    let k_e_i = u128::from_be_bytes(*buf_iter.next().unwrap());
                    let a_e_i = u128::from_be_bytes(*buf_iter.next().unwrap());
                    let b_e_i = u128::from_be_bytes(*buf_iter.next().unwrap());
                    let c_e_i = u128::from_be_bytes(*buf_iter.next().unwrap());
                    let r_e_i: Vec<u128> = (0..self.b)
                        .zip(buf_iter)
                        .map(|(_bi, r_e_i_bi)| u128::from_be_bytes(*r_e_i_bi))
                        .collect();
                    let c_mpc_e_i = Self::hash([
                        &salt[..],
                        &e.to_be_bytes(),
                        &i.to_be_bytes(),
                        &sd_mpc_e[i].to_be_bytes(),
                    ]);
                    let c_mask_e_i: Vec<[u8; 32]> = mask_e
                        .iter()
                        .enumerate()
                        .map(|(k, mask_e_k)| {
                            Self::hash([
                                &salt[..],
                                &e.to_be_bytes(),
                                &k.to_be_bytes(),
                                &i.to_be_bytes(),
                                &mask_e_k[i].to_be_bytes(),
                            ])
                        })
                        .collect();
                    Share {
                        k: k_e_i,
                        a: a_e_i,
                        b: b_e_i,
                        c: c_e_i,
                        r: r_e_i,
                        c_mask: c_mask_e_i,
                        c_mpc: c_mpc_e_i,
                    }
                })
                .collect();
            // Correct share 0 so that the k shares sum to sk.
            let dk_e: u128 = sk.subm(
                shares
                    .iter()
                    .map(|share| share.k)
                    .fold(0, |x, y| x.addm(y, &P)),
                &P,
            );
            shares[0].k = shares[0].k.addm(dk_e, &P);
            let a_e_sum = shares
                .iter()
                .map(|share| share.a)
                .fold(0, |x, y| x.addm(y, &P));
            let b_e_sum = shares
                .iter()
                .map(|share| share.b)
                .fold(0, |x, y| x.addm(y, &P));
            let c_e_sum = shares
                .iter()
                .map(|share| share.c)
                .fold(0, |x, y| x.addm(y, &P));
            // Correct share 0 so that sum(c) == sum(a) * sum(b).
            let dc_e = a_e_sum.mulm(b_e_sum, &P).subm(c_e_sum, &P);
            shares[0].c = shares[0].c.addm(dc_e, &P);
            // Challenge indices of every non-signer ring member. The signer's
            // entry is a zero placeholder until phase 2 overwrites it.
            let indices_e: Vec<Vec<u32>> = challenges
                .iter()
                .enumerate()
                .map(|(j, challenge)| {
                    let mut indices_e_j = vec![0; self.b];
                    if j != sk_index {
                        Self::expand_indices(challenge, 2 * 254, &mut indices_e_j);
                    }
                    indices_e_j
                })
                .collect();
            let (r_sum_e, t_e): (Vec<u128>, Vec<u128>) = (0..self.b)
                .map(|b| {
                    let r_e_b = shares
                        .iter()
                        .map(|share| share.r[b])
                        .fold(0, |x, y| x.addm(y, &P));
                    // Subtract each other member's public-key entry at its
                    // own challenge index (`ring[j]`, not the signer's key).
                    // The signer's challenge only exists after phase 2, so
                    // its (placeholder) term is excluded from the commitment.
                    let t_e_b = legendre_prf(r_e_b).subm(
                        indices_e
                            .iter()
                            .enumerate()
                            .filter(|(j, _)| *j != sk_index)
                            .map(|(j, indices_e_j)| ring[j][indices_e_j[b] as usize])
                            .fold(0, |x, y| x.addm(y, &P)),
                        &P,
                    );
                    (r_e_b, t_e_b)
                })
                .collect();
            Execution {
                sd_mask: sd_mask_e,
                mask: mask_e,
                r_sum: r_sum_e,
                t: t_e,
                dk: dk_e,
                dc: dc_e,
                shares,
                indices: indices_e,
                d_indices: Vec::new(),
                epsilon: 0,
                lambda: Vec::new(),
            }
        })
        .collect();
    // Transcript of phase 1: commitments t, corrections dk/dc and all share commitments.
    let s1 = executions.iter().flat_map(|e| {
        [cast_slice(&e.t), bytes_of(&e.dk), bytes_of(&e.dc)]
            .into_iter()
            .chain(
                e.shares
                    .iter()
                    .flat_map(|share| [cast_slice(&share.c_mask), &share.c_mpc].into_iter()),
            )
    });
    // Phase 2
    let h1 = Self::hash(
        [&salt[..], m]
            .into_iter()
            .chain(ring.iter().map(|pk| cast_slice(pk)))
            .chain(s1),
    );
    // challenges[sk_index] is zero, so XOR-ing it in is harmless. Afterwards the
    // XOR of all ring challenges equals the first 16 bytes of h1.
    let signer_challenge = u128::from_be_bytes(*h1.split_array_ref::<16>().0)
        ^ challenges
            .iter()
            .map(|c| u128::from_be_bytes(*c))
            .fold(0, |x, y| x ^ y);
    // TODO optimize: hold only 1 copy of indices_e_pi for all e
    // (they are all the same if I get it right from part 2, line 4)
    for e in executions.iter_mut() {
        Self::expand_indices(
            &signer_challenge.to_be_bytes(),
            2 * 254,
            &mut e.indices[sk_index],
        );
    }
    let s2: Vec<[u8; 32]> = executions
        .iter_mut()
        .enumerate()
        .map(|(e_index, e)| {
            let (c_d_e, d_indices_e, accs_e): (
                Vec<Vec<[u8; 32]>>,
                Vec<Vec<Vec<u128>>>,
                Vec<[u8; 32]>,
            ) = e
                .mask
                .iter()
                .enumerate()
                .map(|(k, mask_e_k)| {
                    let mask_e_k_sum = mask_e_k.iter().fold(0, |x, y| x.addm(y, &P));
                    let (c_d_e_k, d_indices_e_k): (Vec<[u8; 32]>, Vec<Vec<u128>>) = e
                        .indices
                        .iter()
                        .map(|indices_e_j| {
                            let d_indices_e_k_j: Vec<u128> = indices_e_j
                                .iter()
                                .map(|index_e_j_b| {
                                    (*index_e_j_b as u128).subm(mask_e_k_sum, &P)
                                })
                                .collect();
                            let c_d_e_k_j = Self::hash([
                                &salt[..],
                                cast_slice(&[e_index as u32, k as u32]),
                                cast_slice(&d_indices_e_k_j),
                            ]);
                            (c_d_e_k_j, d_indices_e_k_j)
                        })
                        .collect();
                    // Shuffle the ring order of the leaves before building the tree.
                    let mut phi: Vec<usize> = (0..ring.len()).collect();
                    phi.shuffle(&mut rng);
                    let acc_e_k =
                        Self::compute_tree_root(phi.iter().map(|i| c_d_e_k[*i]).collect());
                    (c_d_e_k, d_indices_e_k, acc_e_k)
                })
                .collect();
            e.d_indices = d_indices_e;
            let acc_e = Self::hash(
                [salt.as_slice(), &e_index.to_be_bytes()]
                    .into_iter()
                    .chain(accs_e.iter().map(|acc_e_k| acc_e_k.as_slice())),
            );
            acc_e
        })
        .collect();
    // Phase 3
    let h2 = Self::hash(
        [h1.as_slice()]
            .into_iter()
            .chain(s2.iter().map(|i| i.as_slice())),
    );
    let mut kbar = vec![0; self.tau];
    Self::expand_indices(&h2, self.m as u32, &mut kbar);
    // Phase 4
    let s3: Vec<Vec<u128>> = executions
        .iter()
        .map(|e| {
            let o_e: Vec<u128> = e.indices[sk_index]
                .iter()
                .zip(e.r_sum.iter())
                .map(|(indices_e_pi_b, r_e_b)| {
                    let o_e_b = r_e_b.mulm(sk.addm(*indices_e_pi_b as u128, &P), &P);
                    o_e_b
                })
                .collect();
            o_e
        })
        .collect();
    // Phase 5
    let h3 = Self::hash(
        [h2.as_slice()]
            .into_iter()
            .chain(s3.iter().map(|o_e| cast_slice(o_e))),
    );
    let mut buf = vec![0; self.tau * (self.b + 1) * size_of::<u128>()];
    Self::expand(&h3, &mut buf);
    let mut buf_iter = buf.array_chunks();
    for e in executions.iter_mut() {
        // `& P` clears the top bit: an approximate reduction into [0, P].
        e.epsilon = u128::from_be_bytes(*buf_iter.next().unwrap()) & P;
        e.lambda = (0..self.b)
            .map(|_b| u128::from_be_bytes(*buf_iter.next().unwrap()) & P)
            .collect();
    }
    // Phase 6
    let mut h4_hasher = Shake128::default();
    h4_hasher.update(&h3);
    for ((e, kbar_e), o_e) in executions.iter().zip(kbar.iter()).zip(s3.iter()) {
        let (alpha_e, beta_e): (Vec<u128>, Vec<u128>) = e
            .shares
            .iter()
            .map(|share| {
                let alpha_e_i = share.a.addm(share.k.mulm(e.epsilon, &P), &P);
                let beta_e_i = share.b.addm(
                    share
                        .r
                        .iter()
                        .zip(e.lambda.iter())
                        .map(|(r_e_i_b, lambda_e_b)| r_e_i_b.mulm(lambda_e_b, &P))
                        .fold(0, |x, y| x.addm(y, &P)),
                    &P,
                );
                (alpha_e_i, beta_e_i)
            })
            .collect();
        let alpha_e_sum = alpha_e.iter().fold(0, |x, y| x.addm(y, &P));
        let beta_e_sum = beta_e.iter().fold(0, |x, y| x.addm(y, &P));
        let (z_e, mut zp_e): (Vec<u128>, Vec<u128>) = e
            .shares
            .iter()
            .zip(e.mask[*kbar_e as usize].iter())
            .enumerate()
            .map(|(i, (share_e_i, mask_e_kbar_e_i))| {
                let (z_e_i, zp_e_i) = e
                    .lambda
                    .iter()
                    .zip(share_e_i.r.iter())
                    .zip(e.d_indices[*kbar_e as usize][sk_index].iter())
                    .map(|((lambda_e_b, r_e_i_b), d_index_e_kbar_pi_b)| {
                        let factor = lambda_e_b.mulm(r_e_i_b, &P);
                        (
                            factor.mulm(d_index_e_kbar_pi_b, &P),
                            factor.mulm(
                                if i == 0 {
                                    mask_e_kbar_e_i.addm(d_index_e_kbar_pi_b, &P)
                                } else {
                                    *mask_e_kbar_e_i
                                },
                                &P,
                            ),
                        )
                    })
                    .fold((0, 0), |x, y| (x.0.subm(y.0, &P), x.1.subm(y.1, &P)));
                (z_e_i, zp_e_i)
            })
            .collect();
        let z_e_sum = z_e.iter().fold(0, |x, y| x.addm(y, &P));
        let d_z_e = z_e_sum.subm(zp_e.iter().fold(0, |x, y| x.addm(y, &P)), &P);
        zp_e[0] = e
            .lambda
            .iter()
            .zip(o_e.iter())
            .map(|(lambda_e_b, o_e_b)| lambda_e_b.mulm(o_e_b, &P))
            .fold(zp_e[0].addm(d_z_e, &P), |x, y| x.addm(y, &P));
        let gamma_e: Vec<u128> = e
            .shares
            .iter()
            .zip(zp_e.iter())
            .map(|(share_e_i, zp_e_i)| {
                let gamma_e_i = alpha_e_sum
                    .mulm(share_e_i.b, &P)
                    .addm(beta_e_sum.mulm(share_e_i.a, &P), &P)
                    .subm(share_e_i.c, &P)
                    .addm(e.epsilon.mulm(zp_e_i, &P), &P);
                gamma_e_i
            })
            .collect();
        // The per-party vectors are hashed, so hashing their sums as well adds nothing.
        h4_hasher.update(cast_slice(&alpha_e));
        h4_hasher.update(cast_slice(&beta_e));
        h4_hasher.update(cast_slice(&gamma_e));
    }
    // Phase 7
    let mut h4_reader = h4_hasher.finalize_xof();
    let mut h4 = [0; 32];
    h4_reader.read(&mut h4);
    let mut ibar = vec![0; self.tau];
    Self::expand_indices(&h4, self.n as u32, &mut ibar);
    // Phase 8
    // TODO(review): assemble and return the signature.
}
/// Power-residue PRF `L_K^k(a)`.
///
/// Only the quadratic case `k = 2` is implemented, where this is exactly the
/// Legendre PRF.
fn power_residue_prf(a: u128) -> u128 {
    let quadratic_symbol_bit = legendre_prf(a);
    quadratic_symbol_bit
}
fn power_residue_prf_l(a: u128) -> [u128; 2 * 254] {
// TODO generate uniform list
#[rustfmt::skip]
const LIST: [u128; 2*254] = [0; 2*254];
let mut res = [0; 2 * 254];
for (res_i, l_i) in res.iter_mut().zip(LIST.iter()) {
*res_i = Self::power_residue_prf(a + l_i);
}
res
}
/// Fills `output` with the SHAKE128 extendable output of `input`.
fn expand(input: &[u8], output: &mut [u8]) {
    // TODO generic
    let mut xof = Shake128::default();
    xof.update(input);
    xof.finalize_xof().read(output);
}
/// Derives `output.len()` indices in `[0, modulus)` from `input` with SHAKE128.
///
/// Each index is taken from 4 consecutive XOF bytes, interpreted as a
/// little-endian `u32`, then reduced modulo `modulus`. The byte order is fixed
/// explicitly so that every platform derives the same indices from the same
/// hash.
///
/// NOTE(review): plain `%` has a modulo bias on the order of
/// `modulus / 2^32`. That is negligible for the small moduli used with
/// `new_fast`, but this is not an exactly uniform sampler.
///
/// # Panics
/// Panics if `modulus == 0` and `output` is non-empty.
fn expand_indices(input: &[u8], modulus: u32, output: &mut [u32]) {
    // TODO generic
    let mut hasher = Shake128::default();
    hasher.update(input);
    let mut reader = hasher.finalize_xof();
    reader.read(cast_slice_mut(output));
    for index in output.iter_mut() {
        // `cast_slice_mut` stored the XOF bytes in host byte order. Interpret
        // them as little-endian (a no-op on little-endian hosts).
        *index = u32::from_le(*index) % modulus;
    }
}
/// Hashes the concatenation of all `input` slices into 32 bytes of SHAKE128
/// output.
fn hash<'a>(input: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
    // TODO generic
    let absorbed = input
        .into_iter()
        .fold(Shake128::default(), |mut sponge, chunk| {
            sponge.update(chunk);
            sponge
        });
    let mut out = [0; 32];
    absorbed.finalize_xof().read(&mut out);
    out
}
/// Deterministically fills a seed tree stored in array (heap) layout.
///
/// The seed is the first element. Node `i` has children `2i + 1` and
/// `2i + 2`, which receive the SHAKE128 expansion (via `expand`) of the
/// parent's 16 bytes. Afterwards every element, internal nodes included, is
/// derived from the seed.
///
/// Any length is accepted:
/// - For an even length, the last parent has a single child. That child gets
///   the first 16 bytes of the same expansion.
/// - Lengths 0 and 1 are left unchanged.
fn expand_tree(tree: &mut [u128]) {
    let len = tree.len();
    // Exactly the nodes `parent < len / 2` have at least one child in the
    // slice. An upper bound of `len / 2 - 1` would leave the last children
    // zero (and underflow for `len < 2`).
    for parent in 0..len / 2 {
        let first_child = 2 * parent + 1;
        let child_count = (len - first_child).min(2);
        let (ancestors, rest) = tree.split_at_mut(first_child);
        Self::expand(
            bytes_of(&ancestors[parent]),
            cast_slice_mut(&mut rest[..child_count]),
        );
    }
}
/// Compute Merkle tree root
// TODO salt?
fn compute_tree_root(mut input: Vec<[u8; 32]>) -> [u8; 32] {
while input.len() > 1 {
if input.len() % 2 == 1 {
input.push(*input.last().unwrap());
}
for i in (0..input.len()).step_by(2) {
let mut hasher = Shake128::default();
hasher.update(&input[i]);
hasher.update(&input[i + 1]);
let mut reader = hasher.finalize_xof();
reader.read(&mut input[i / 2]);
}
input.truncate(input.len() / 2);
}
input[0]
}
/// Compute authentication path from given input's index to root
/// Returns the siblings of the path, from leaf to root
fn compute_tree_proof(mut input: Vec<[u8; 32]>, mut index: usize) -> Vec<[u8; 32]> {
let mut path = Vec::new();
// TODO do not compute useless branches
// input.len().next_power_of_two()-1
while input.len() > 1 {
if input.len() % 2 == 1 {
input.push(*input.last().unwrap());
}
path.push(input[index ^ 1]);
for i in (0..input.len()).step_by(2) {
let mut hasher = Shake128::default();
hasher.update(&input[i]);
hasher.update(&input[i + 1]);
let mut reader = hasher.finalize_xof();
reader.read(&mut input[i / 2]);
}
input.truncate(input.len() / 2);
index /= 2;
}
path
}
fn tree_root_from_path(mut leaf: [u8; 32], mut index: usize, path: &[[u8; 32]]) -> [u8; 32] {
for sibling in path {
let mut hasher = Shake128::default();
if index % 2 == 0 {
hasher.update(&leaf);
hasher.update(sibling);
} else {
hasher.update(sibling);
hasher.update(&leaf);
}
let mut reader = hasher.finalize_xof();
reader.read(&mut leaf);
index /= 2;
}
leaf
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// For every tree size from 1 to 31 leaves (odd and even), each leaf's
    /// authentication path must lead back to the root.
    #[test]
    fn test_merkle_tree() {
        let mut rng = rand::thread_rng();
        for leaf_count in 1usize..32 {
            let leaves: Vec<[u8; 32]> = std::iter::repeat_with(|| rng.r#gen())
                .take(leaf_count)
                .collect();
            let expected_root = DualRing::compute_tree_root(leaves.clone());
            for (position, leaf) in leaves.iter().enumerate() {
                let proof = DualRing::compute_tree_proof(leaves.clone(), position);
                let rebuilt = DualRing::tree_root_from_path(*leaf, position, &proof);
                assert_eq!(expected_root, rebuilt);
            }
        }
    }

    /// Smoke test: signing with a ring of 10 freshly generated keys must not
    /// panic.
    #[test]
    fn test_sign() {
        let mut rng = rand::thread_rng();
        let scheme = DualRing::new_fast();
        let message = b"Hello world!";
        let ring_size = 10;
        let (secret_keys, public_keys): (Vec<u128>, Vec<[u128; 2 * 254]>) = (0..ring_size)
            .map(|_| DualRing::keygen(&mut rng))
            .unzip();
        let signer_index = 4;
        scheme.sign(
            &public_keys,
            secret_keys[signer_index],
            signer_index,
            message,
            &mut rng,
        );
    }
}