diff --git a/crates/seec-channel/src/multi.rs b/crates/seec-channel/src/multi.rs index 268d075..3b39a43 100644 --- a/crates/seec-channel/src/multi.rs +++ b/crates/seec-channel/src/multi.rs @@ -1,4 +1,6 @@ -use crate::{multi, sub_channel, tcp, CommunicationError, Receiver, ReceiverT, Sender, SenderT}; +use crate::{ + channel, multi, sub_channel, tcp, CommunicationError, Receiver, ReceiverT, Sender, SenderT, +}; use async_trait::async_trait; use futures::future::join; use futures::stream::FuturesUnordered; @@ -54,14 +56,14 @@ pub struct MultiReceiver { } impl MultiSender { - pub async fn send_to(&self, to: impl IntoIterator, msg: T) -> Result<(), Error> { + pub async fn send_to(&self, to: impl IntoIterator, msg: T) -> Result<(), Error> { let mut fu = FuturesUnordered::new(); for to in to { debug!(to, "Sending"); let sender = self .senders - .get(to) - .ok_or_else(|| Error::UnknownParty(*to))?; + .get(&to) + .ok_or_else(|| Error::UnknownParty(to))?; fu.push(sender.send(msg.clone())); } let mut errors = vec![]; @@ -81,7 +83,7 @@ impl MultiSender { #[instrument(level = "debug", skip(self, msg), ret)] pub async fn send_all(&self, msg: T) -> Result<(), Error> { - self.send_to(self.senders.keys(), msg).await + self.send_to(self.senders.keys().copied(), msg).await } pub fn sender(&self, to: u32) -> Option<&Sender> { @@ -100,6 +102,14 @@ pub struct MsgFrom { } impl MultiReceiver { + pub async fn recv_from_single(&mut self, from: u32) -> Result { + let receiver = self + .receivers + .get_mut(&from) + .ok_or(Error::UnknownParty(from))?; + Ok(map_recv_fut((&from, receiver)).await?.into_msg()) + } + pub fn recv_from( &mut self, from: &HashSet, @@ -337,6 +347,44 @@ where Ok(MultiReceiver { receivers }) } +impl MsgFrom { + pub fn into_msg(self) -> T { + self.msg + } +} + +pub fn new_local(parties: usize) -> Vec<(MultiSender, MultiReceiver)> { + let mut res: Vec<(MultiSender, MultiReceiver)> = + (0..parties).map(|_| Default::default()).collect(); + for party in 0..parties { + for other in 0..parties { + if party == other { + continue; + } + let (sender, receiver) = channel(128); + res[party].0.senders.insert(other as u32, sender); + res[other].1.receivers.insert(party as u32, receiver); + } + } + res +} + +impl Default for MultiSender { + fn default() -> Self { + Self { + senders: Default::default(), + } + } +} + +impl Default for MultiReceiver { + fn default() -> Self { + Self { + receivers: Default::default(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/seec/src/executor.rs b/crates/seec/src/executor.rs index 56bf30e..44ca939 100644 --- a/crates/seec/src/executor.rs +++ b/crates/seec/src/executor.rs @@ -42,6 +42,7 @@ pub type DynFDSetup<'c, P, Idx> = Box< + 'c, >; +#[derive(Debug, Clone)] pub struct GateOutputs { data: Vec>, // Used as a sanity check in debug builds. Stores for which gates we have set the output, @@ -502,6 +503,10 @@ impl GateOutputs { pub fn iter(&self) -> impl Iterator> { self.data.iter() } + + pub fn into_iter(self) -> impl Iterator> { + self.data.into_iter() + } } impl Input { @@ -625,6 +630,24 @@ impl GateOutputs { } } +impl Default for GateOutputs { + fn default() -> Self { + Self { + data: vec![], + output_set: Default::default(), + } + } +} + +impl FromIterator> for GateOutputs { + fn from_iter>>(iter: T) -> Self { + Self { + data: iter.into_iter().collect(), + output_set: Default::default(), + } + } +} + #[cfg(test)] mod tests { use crate::circuit::base_circuit::BaseGate; diff --git a/crates/seec/src/protocols/aby2.rs b/crates/seec/src/protocols/aby2.rs index 3febfa2..f1a0d9b 100644 --- a/crates/seec/src/protocols/aby2.rs +++ b/crates/seec/src/protocols/aby2.rs @@ -1,3 +1,4 @@ +use crate::bristol::circuit; use crate::circuit::base_circuit::BaseGate; use crate::circuit::{ExecutableCircuit, GateIdx}; use crate::common::BitVec; @@ -9,11 +10,15 @@ use crate::protocols::{ boolean_gmw, FunctionDependentSetup, Gate, Protocol, ScalarDim, SetupStorage, ShareStorage, }; use crate::secret::Secret; +use crate::utils::take_arr; use crate::{bristol, executor, CircuitBuilder}; use ahash::AHashMap; use async_trait::async_trait; +use itertools::Itertools; use rand::{Rng, SeedableRng}; use rand_chacha::ChaChaRng; +use seec_channel::multi::{MultiReceiver, MultiSender}; +use seec_channel::ReceiverT; use serde::{Deserialize, Serialize}; use std::collections::hash_map::Entry; use std::collections::HashMap; @@ -21,7 +26,6 @@ use std::convert::Infallible; use std::error::Error; use std::fmt::Debug; use std::ops::Not; -use itertools::Itertools; pub struct BooleanAby2 { delta_sharing_state: DeltaSharing, @@ -62,9 +66,7 @@ pub enum Msg { #[derive(Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Debug)] pub enum BooleanGate { Base(BaseGate), - And { - n: u8 - }, + And { n: u8 }, Xor, Inv, } @@ -99,6 +101,33 @@ pub struct AbySetupProvider { setup_data: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AstraSetupMsg(BitVec); + +#[derive(Debug, Copy, Clone)] +pub enum InputBy { + P0, + P1, +} +pub struct AstraSetupHelper { + sender: MultiSender, + receiver: MultiReceiver, + // shared rng with p0 + priv_seed_p0: [u8; 32], + // shared rng with p1 + priv_seed_p1: [u8; 32], + joint_seed_0: [u8; 32], + joint_seed_1: [u8; 32], +} +pub struct AstraSetupProvider { + // The normal parties have party id 0 and 1. For the helper, there is a dedicated struct + party_id: usize, + sender: MultiSender, + receiver: MultiReceiver, + rng: ChaChaRng, + setup_data: Option, +} + impl Protocol for BooleanAby2 { const SIMD_SUPPORT: bool = false; type Msg = Msg; @@ -119,7 +148,7 @@ impl Protocol for BooleanAby2 { let delta: BitVec = interactive_gates .zip(gate_outputs) .map(|(gate, output)| { - assert!(matches!(gate, BooleanGate::And {n: 2})); + assert!(matches!(gate, BooleanGate::And { n: 2 })); let inputs = inputs.by_ref().take(gate.input_size()); gate.compute_delta_share(party_id, inputs, preprocessing_data, output) }) @@ -180,6 +209,7 @@ impl Protocol for BooleanAby2 { let output = gate.setup_output_share(gate_input_iter, rng); storage.set(sc_gate_id, output); } + println!("{_party_id}: {storage:?}"); storage } @@ -196,7 +226,7 @@ impl BooleanGate { output_share: Share, ) -> bool { assert!(matches!(party_id, 0 | 1)); - assert!(matches!(self, BooleanGate::And {n: 2})); + assert!(matches!(self, BooleanGate::And { n: 2 })); let a = inputs.next().expect("Empty input"); let b = inputs.next().expect("Insufficient input"); let plain_ab = a.public & b.public; @@ -238,14 +268,18 @@ impl BooleanGate { | BaseGate::Debug | BaseGate::Identity => inputs.next().expect("Empty input"), BaseGate::Constant(c) => { - // TODO is it correct to just output the stored constant here? + assert_eq!( + c.private, false, + "Private part of constant gate share must be 0" + ); + // return constant as the public part is simply the constant c.clone() - }, + } BaseGate::ConnectToMainFromSimd(_) => { unimplemented!("SIMD currently not supported for ABY2") } }, - BooleanGate::And {..} => { + BooleanGate::And { .. } => { // input is not actually needed at this stage Share { private: rng.gen(), @@ -260,7 +294,7 @@ impl BooleanGate { a } BooleanGate::Inv => { - // TODO correctness? + // private share part does not change for Inv gates inputs.next().expect("Empty input") } } @@ -271,7 +305,7 @@ impl BooleanGate { input_shares: impl Iterator>, setup_sub_circ_cache: &mut AHashMap>, Secret>, ) -> Vec> { - let &BooleanGate::And {n} = self else { + let &BooleanGate::And { n } = self else { assert!(self.is_non_interactive(), "Unhandled interactive gate"); panic!("Called setup_data_circ on non_interactive gate") }; @@ -315,10 +349,7 @@ impl Gate for BooleanGate { type DimTy = ScalarDim; fn is_interactive(&self) -> bool { - matches!( - self, - BooleanGate::And {..} - ) + matches!(self, BooleanGate::And { .. }) } fn input_size(&self) -> usize { @@ -326,7 +357,7 @@ impl Gate for BooleanGate { BooleanGate::Base(base_gate) => base_gate.input_size(), BooleanGate::Inv => 1, BooleanGate::Xor => 2, - BooleanGate::And {n} => *n as usize, + BooleanGate::And { n } => *n as usize, } } @@ -354,7 +385,7 @@ impl Gate for BooleanGate { c.clone() } BooleanGate::Base(base) => base.evaluate_non_interactive(party_id, inputs.by_ref()), - BooleanGate::And {..} => { + BooleanGate::And { .. } => { panic!("Called evaluate_non_interactive on Gate::And") } BooleanGate::Xor => { @@ -390,7 +421,7 @@ impl From> for BooleanGate { impl From<&bristol::Gate> for BooleanGate { fn from(gate: &bristol::Gate) -> Self { match gate { - bristol::Gate::And(_) => Self::And { n: 2}, + bristol::Gate::And(_) => Self::And { n: 2 }, bristol::Gate::Xor(_) => Self::Xor, bristol::Gate::Inv(_) => Self::Inv, } @@ -495,14 +526,14 @@ impl SetupStorage for SetupData { self.eval_shares.len() } - fn reserve(&mut self, additional: usize) { - self.eval_shares.reserve(additional); - } fn split_off_last(&mut self, count: usize) -> Self { Self { eval_shares: self.eval_shares.split_off(self.len() - count), } } + fn reserve(&mut self, additional: usize) { + self.eval_shares.reserve(additional); + } fn append(&mut self, mut other: Self) { self.eval_shares.append(&mut other.eval_shares); @@ -699,7 +730,7 @@ where .interactive_iter() .zip(setup_outputs) .map(|((gate, _gate_id), setup_out)| match gate { - BooleanGate::And {..} => { + BooleanGate::And { .. } => { let shares = setup_out .into_iter() .map(|out_id| executor_gate_outputs.get(out_id.as_usize())) @@ -722,13 +753,191 @@ where } } +impl AstraSetupHelper { + pub fn new( + sender: MultiSender, + receiver: MultiReceiver, + priv_seed_p0: [u8; 32], + priv_seed_p1: [u8; 32], + joint_seed_0: [u8; 32], + joint_seed_1: [u8; 32], + ) -> Self { + Self { + sender, + receiver, + priv_seed_p0, + priv_seed_p1, + joint_seed_0, + joint_seed_1, + } + } + + pub async fn setup( + self, + circuit: &ExecutableCircuit, + input_map: HashMap, + ) { + let p0_gate_outputs = + self.setup_gate_outputs(0, circuit, self.priv_seed_p0, self.joint_seed_1, &input_map); + let p1_gate_outputs = + self.setup_gate_outputs(1, circuit, self.priv_seed_p1, self.joint_seed_0, &input_map); + + let mut rng_p0 = ChaChaRng::from_seed(self.priv_seed_p0); + // synchronized with the AstraSetupProvider but different than the stream used for the gate + // outputs before + rng_p0.set_stream(1); + + // TODO this could potentially be optimized as it reconstructs all lambda values + // but we only need those that are an input to an interactive gate + let rec_gate_outputs: Vec<_> = p0_gate_outputs + .into_iter() + .zip(p1_gate_outputs.into_iter()) + .map(|(p0_out, p1_out)| { + let p0_storage = p0_out.into_scalar().expect("SIMD unsupported"); + let p1_storage = p1_out.into_scalar().expect("SIMD unsupported"); + // we only need to reconstruct the private parts which were initialized + p0_storage.private ^ p1_storage.private + }) + .collect(); + + let mut msg = BitVec::with_capacity(circuit.interactive_count()); + + for (gate, _gate_id, parents) in circuit.interactive_with_parents_iter() { + match gate { + BooleanGate::And { n } => { + assert_eq!(2, n, "Astra setup currently supports 2 input ANDs"); + let inputs: [bool; 2] = take_arr(&mut parents.take(2).map(|scg| { + rec_gate_outputs[scg.circuit_id as usize][scg.gate_id.as_usize()] + })); + let lambda_xy = inputs[0] & inputs[1]; + let lambda_xy_0: bool = rng_p0.gen(); + let lambda_xy_1 = lambda_xy ^ lambda_xy_0; + msg.push(lambda_xy_1); + } + ni => unreachable!("non interactive gate {ni:?}"), + } + } + self.sender + .send_to([1], AstraSetupMsg(msg)) + .await + .expect("failed to send setup message") + } + + fn setup_gate_outputs( + &self, + party_id: usize, + circuit: &ExecutableCircuit, + local_seed: [u8; 32], + remote_seed: [u8; 32], + input_map: &HashMap, + ) -> GateOutputs { + // The idea is to reuse the `BooleanAby2` setup_gate_outputs method with the correct + // rngs to generate the correct values for the helper + + let input_position_share_type_map = input_map + .iter() + .map(|(&pos, by)| { + let st = match (party_id, by) { + (0, InputBy::P0) | (1, InputBy::P1) => ShareType::Local, + (0, InputBy::P1) | (1, InputBy::P0) => ShareType::Remote, + (id, _) => panic!("Unsupported party id {id}"), + }; + (pos, st) + }) + .collect(); + + let mut p = BooleanAby2 { + delta_sharing_state: DeltaSharing { + private_rng: ChaChaRng::from_seed(local_seed), + // not used + local_joint_rng: ChaChaRng::from_seed(local_seed), + remote_joint_rng: ChaChaRng::from_seed(remote_seed), + input_position_share_type_map, + }, + }; + p.setup_gate_outputs(party_id, circuit) + } +} + +impl AstraSetupProvider { + pub fn new( + party_id: usize, + sender: MultiSender, + receiver: MultiReceiver, + seed: [u8; 32], + ) -> Self { + let mut rng = ChaChaRng::from_seed(seed); + // We use the next stream of this RNG so that it is synchronized with the helper + rng.set_stream(1); + Self { + party_id, + sender, + receiver, + rng, + setup_data: None, + } + } +} + +#[async_trait] +impl FunctionDependentSetup for AstraSetupProvider +where + Idx: GateIdx, +{ + type Output = SetupData; + type Error = Infallible; + + async fn setup( + &mut self, + _shares: &GateOutputs, + circuit: &ExecutableCircuit, + ) -> Result<(), Self::Error> { + if self.party_id == 0 { + let lambda_values: Vec<_> = (0..circuit.interactive_count()) + .map(|_| EvalShares { + shares: BitVec::repeat(self.rng.gen(), 1), + }) + .collect(); + self.setup_data = Some(SetupData::from_raw(lambda_values)); + } else if self.party_id == 1 { + let msg = self + .receiver + .recv_from_single(2) + .await + .expect("Recv message from helper"); + let setup_data = msg + .0 + .into_iter() + .map(|eval_share| EvalShares { + shares: BitVec::repeat(eval_share, 1), + }) + .collect(); + self.setup_data = Some(SetupData::from_raw(setup_data)); + } else { + panic!("Illegal party id {}", self.party_id) + } + Ok(()) + } + + async fn request_setup_output(&mut self, count: usize) -> Result { + Ok(self + .setup_data + .as_mut() + .expect("setup must be called before request_setup_output") + .split_off_last(count)) + } +} #[cfg(test)] mod tests { + use super::BooleanGate as BG; + use super::*; use crate::circuit::BaseCircuit; use crate::mul_triple::boolean::InsecureMTProvider; - use super::*; - use super::BooleanGate as BG; + use crate::private_test_utils::init_tracing; + use crate::Circuit; + use rand::thread_rng; + use seec_channel::multi; // #[tokio::test] // async fn multi_and() { @@ -745,4 +954,4 @@ mod tests { // let setup0 = AbySetupProvider::new(0, InsecureMTProvider::default(), ch0.0, ch0.1); // let setup1 = AbySetupProvider::new(1, InsecureMTProvider::default(), ch1.0, ch1.1); // } -} \ No newline at end of file +} diff --git a/crates/seec/src/utils.rs b/crates/seec/src/utils.rs index 4be873e..ca2b627 100644 --- a/crates/seec/src/utils.rs +++ b/crates/seec/src/utils.rs @@ -159,3 +159,8 @@ impl From> for BoxError { Self(value) } } + +#[allow(unused)] +pub(crate) fn take_arr(iter: &mut I) -> [I::Item; N] { + array::from_fn(|_| iter.next().expect("Input array has insufficient elements")) +} diff --git a/crates/seec/tests/boolean_aby2.rs b/crates/seec/tests/boolean_aby2.rs index dc11b72..1dfdb15 100644 --- a/crates/seec/tests/boolean_aby2.rs +++ b/crates/seec/tests/boolean_aby2.rs @@ -1,11 +1,16 @@ -use rand::{thread_rng, Rng}; +use rand::{thread_rng, Rng, SeedableRng}; +use rand_chacha::ChaChaRng; use seec::circuit::ExecutableCircuit; use seec::common::BitVec; -use seec::executor::{Executor, Input}; +use seec::executor::{Executor, GateOutputs, Input}; use seec::mul_triple::boolean::insecure_provider::InsecureMTProvider; use seec::private_test_utils::init_tracing; -use seec::protocols::aby2::{AbySetupProvider, BooleanAby2, DeltaSharing, ShareType}; +use seec::protocols::aby2::{ + AbySetupProvider, AstraSetupHelper, AstraSetupProvider, BooleanAby2, DeltaSharing, InputBy, + ShareType, +}; use seec::Circuit; +use seec_channel::multi; #[tokio::test(flavor = "multi_thread")] async fn eval_8_bit_adder() -> anyhow::Result<()> { @@ -81,3 +86,97 @@ async fn eval_8_bit_adder() -> anyhow::Result<()> { Ok(()) } + +#[tokio::test] +async fn astra_setup() -> anyhow::Result<()> { + let _g = init_tracing(); + let mut channels = multi::new_local(3); + let helper_ch = channels.pop().unwrap(); + let p1_ch = channels.pop().unwrap(); + let p0_ch = channels.pop().unwrap(); + let priv_seed_p0: [u8; 32] = thread_rng().gen(); + let priv_seed_p1: [u8; 32] = thread_rng().gen(); + let seed_p0: [u8; 32] = thread_rng().gen(); + let seed_p1: [u8; 32] = thread_rng().gen(); + let helper = AstraSetupHelper::new( + helper_ch.0, + helper_ch.1, + priv_seed_p0, + priv_seed_p1, + seed_p0, + seed_p1, + ); + + let astra_setup0 = AstraSetupProvider::new(0, p0_ch.0, p0_ch.1, priv_seed_p0); + let astra_setup1 = AstraSetupProvider::new(1, p1_ch.0, p1_ch.1, priv_seed_p1); + + let circ = ExecutableCircuit::DynLayers( + Circuit::load_bristol("test_resources/bristol-circuits/int_add8_depth.bristol").unwrap(), + ); + + let input_map = (0..8) + .map(|i| (i, InputBy::P0)) + .chain((8..16).map(|i| (i, InputBy::P1))) + .collect(); + let circ_clone = circ.clone(); + let jh = tokio::spawn(async move { helper.setup(&circ_clone, input_map).await }); + + let share_map1 = (0..8) + .map(|pos| (pos, ShareType::Local)) + .chain((8..16).map(|pos| (pos, ShareType::Remote))) + .collect(); + let share_map2 = (0..8) + .map(|pos| (pos, ShareType::Remote)) + .chain((8..16).map(|pos| (pos, ShareType::Local))) + .collect(); + + let mut sharing_state1 = DeltaSharing::new(priv_seed_p0, seed_p0, seed_p1, share_map1); + let mut sharing_state2 = DeltaSharing::new(priv_seed_p1, seed_p1, seed_p0, share_map2); + let state1 = BooleanAby2::new(sharing_state1.clone()); + let state2 = BooleanAby2::new(sharing_state2.clone()); + + let (mut ex1, mut ex2): (Executor, Executor) = + tokio::try_join!( + Executor::new_with_state(state1, &circ, 0, astra_setup0), + Executor::new_with_state(state2, &circ, 1, astra_setup1) + ) + .unwrap(); + let (shared_30, plain_delta_30) = sharing_state1.share(BitVec::from_element(30_u8)); + let (shared_12, plain_delta_12) = sharing_state2.share(BitVec::from_element(12_u8)); + + let inp1 = { + let mut inp = shared_30; + inp.extend(sharing_state1.plain_delta_to_share(plain_delta_12)); + inp + }; + let inp2 = { + let mut inp = sharing_state2.plain_delta_to_share(plain_delta_30); + inp.extend(shared_12); + inp + }; + + let reconstruct: BitVec = inp1 + .clone() + .into_iter() + .zip(inp2.clone()) + .map(|(sh1, sh2)| { + assert_eq!(sh1.get_public(), sh2.get_public()); + sh1.get_public() ^ sh1.get_private() ^ sh2.get_private() + }) + .collect(); + assert_eq!(BitVec::from_slice(&[30_u8, 12]), reconstruct); + + let (mut ch1, mut ch2) = seec_channel::in_memory::new_pair(16); + + let (out0, out1) = tokio::try_join!( + ex1.execute(Input::Scalar(inp1), &mut ch1.0, &mut ch1.1), + ex2.execute(Input::Scalar(inp2), &mut ch2.0, &mut ch2.1), + )?; + + let out0 = out0.into_scalar().unwrap(); + let out1 = out1.into_scalar().unwrap(); + let out_bits: BitVec = DeltaSharing::reconstruct(out0, out1); + assert_eq!(BitVec::from_element(42_u8), out_bits); + + Ok(()) +}