diff --git a/Cargo.lock b/Cargo.lock index 740b4bd..3c5092e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2091,6 +2091,7 @@ dependencies = [ "tokio-serde", "tokio-util", "tracing", + "tracing-subscriber", ] [[package]] diff --git a/crates/seec-channel/Cargo.toml b/crates/seec-channel/Cargo.toml index 2fba45c..25816aa 100644 --- a/crates/seec-channel/Cargo.toml +++ b/crates/seec-channel/Cargo.toml @@ -12,7 +12,7 @@ async-stream = "0.3.5" bytes = "1.6.0" futures = "0.3.30" pin-project = "1.1.5" -serde = { version = "1.0.197" } +serde = { version = "1.0.197" , features = ["derive"]} erased-serde = "0.4.4" thiserror = "1.0.58" tokio = { version = "1.36.0", features = ["macros", "net"] } @@ -36,6 +36,7 @@ criterion = { version = "0.5.1", features = ["async_tokio"] } serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] } +tracing-subscriber = { version = "0.3.18", features = ["env-filter"]} [[bench]] diff --git a/crates/seec-channel/src/lib.rs b/crates/seec-channel/src/lib.rs index 98e950b..7467135 100644 --- a/crates/seec-channel/src/lib.rs +++ b/crates/seec-channel/src/lib.rs @@ -1,15 +1,19 @@ //! Channel abstraction for communication use crate::util::{Counter, TrackingReader, TrackingWriter}; use async_trait::async_trait; + use remoc::rch::{base, mpsc}; use remoc::{codec, ConnectError, RemoteSend}; use serde::{Deserialize, Serialize}; + use tokio::io; use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::debug; pub use seec_channel_macros::sub_channels_for; pub mod in_memory; +pub mod multi; pub mod tcp; pub mod tls; pub mod util; @@ -27,29 +31,35 @@ pub type Channel = (Sender, Receiver); pub struct SyncMsg; #[async_trait] -pub trait SenderT { - async fn send(&mut self, item: T) -> Result<(), E>; +pub trait SenderT { + type Error; + async fn send(&mut self, item: T) -> Result<(), Self::Error>; } #[async_trait] -pub trait ReceiverT { - async fn recv(&mut self) -> Result, E>; +pub trait ReceiverT { + type Error; + async fn recv(&mut self) -> Result, Self::Error>; } #[derive(thiserror::Error, Debug)] pub enum CommunicationError { #[error("Error sending initial value")] - BaseSend(base::SendErrorKind), + BaseSend(#[source] base::SendError<()>), #[error("Error receiving value on base channel")] BaseRecv(#[from] base::RecvError), #[error("Error sending value on mpsc channel")] - Send(mpsc::SendError<()>), + Send(#[source] mpsc::SendError<()>), #[error("Error receiving value on mpsc channel")] Recv(#[from] mpsc::RecvError), + #[error("Error in Multi-Sender/Receiver")] + Multi(#[from] multi::Error), #[error("Unexpected termination. Remote is closed.")] RemoteClosed, #[error("Received out of order message")] UnexpectedMessage, + #[error("Unabel to establish multi-sub-channel with party {0}")] + MultiSubChannel(u32, #[source] Box), } pub fn channel( @@ -65,62 +75,65 @@ pub fn channel( } #[tracing::instrument(skip_all)] -pub async fn sub_channel( - sender: &mut impl SenderT, - receiver: &mut impl ReceiverT, +pub async fn sub_channel( + sender: &mut S, + receiver: &mut R, local_buffer: usize, ) -> Result<(Sender, Receiver), CommunicationError> where - Receiver: Into, - Msg: Into>> + RemoteSend, + S: SenderT, + R: ReceiverT, + Sender: Into, + Msg: Into>> + RemoteSend, SubMsg: RemoteSend, - CommunicationError: From + From, + CommunicationError: From + From, { - tracing::debug!("Establishing new sub_channel"); - let (sub_sender, remote_sub_receiver) = channel(local_buffer); - sender.send(remote_sub_receiver.into()).await?; - tracing::debug!("Sent remote_sub_receiver"); + debug!("Establishing new sub_channel"); + let (remote_sub_sender, sub_receiver) = channel(local_buffer); + sender.send(remote_sub_sender.into()).await?; + debug!("Sent remote_sub_receiver"); let msg = receiver .recv() .await? .ok_or(CommunicationError::RemoteClosed)?; - let sub_receiver = msg.into().ok_or(CommunicationError::UnexpectedMessage)?; - tracing::debug!("Received sub_receiver"); + let sub_sender = msg.into().ok_or(CommunicationError::UnexpectedMessage)?; + debug!("Received sub_receiver"); Ok((sub_sender, sub_receiver)) } #[tracing::instrument(skip_all)] -pub async fn sub_channel_with( - sender: &mut impl SenderT, - receiver: &mut impl ReceiverT, +pub async fn sub_channel_with( + sender: &mut S, + receiver: &mut R, local_buffer: usize, - wrap_fn: impl FnOnce(Receiver) -> Msg, - extract_fn: impl FnOnce(Msg) -> Option>, + wrap_fn: impl FnOnce(Sender) -> Msg, + extract_fn: impl FnOnce(Msg) -> Option>, ) -> Result<(Sender, Receiver), CommunicationError> where + S: SenderT, + R: ReceiverT, Msg: RemoteSend, SubMsg: RemoteSend, - CommunicationError: From + From, + CommunicationError: From + From, { - tracing::debug!("Establishing new sub_channel"); - let (sub_sender, remote_sub_receiver) = channel(local_buffer); - sender.send(wrap_fn(remote_sub_receiver)).await?; - tracing::debug!("Sent remote_sub_receiver"); + debug!("Establishing new sub_channel"); + let (remote_sub_sender, sub_receiver) = channel(local_buffer); + sender.send(wrap_fn(remote_sub_sender)).await?; + debug!("Sent remote_sub_receiver"); let msg = receiver .recv() .await? .ok_or(CommunicationError::RemoteClosed)?; - let sub_receiver = extract_fn(msg).ok_or(CommunicationError::UnexpectedMessage)?; - tracing::debug!("Received sub_receiver"); + let sub_sender = extract_fn(msg).ok_or(CommunicationError::UnexpectedMessage)?; + debug!("Received sub_receiver"); Ok((sub_sender, sub_receiver)) } -pub async fn sync( - sender: &mut impl SenderT, - receiver: &mut impl ReceiverT, -) -> Result<(), CommunicationError> +pub async fn sync(sender: &mut S, receiver: &mut R) -> Result<(), CommunicationError> where - CommunicationError: From + From, + S: SenderT, + R: ReceiverT, + CommunicationError: From + From, { sender.send(SyncMsg).await?; // ignore receiving a None @@ -132,54 +145,56 @@ where } #[async_trait] -impl SenderT> for base::Sender +impl SenderT for base::Sender where T: RemoteSend, Codec: codec::Codec, { - async fn send(&mut self, item: T) -> Result<(), base::SendError> { + type Error = base::SendError; + async fn send(&mut self, item: T) -> Result<(), Self::Error> { base::Sender::send(self, item).await } } #[async_trait] -impl ReceiverT for base::Receiver +impl ReceiverT for base::Receiver where T: RemoteSend, Codec: codec::Codec, { - async fn recv(&mut self) -> Result, base::RecvError> { + type Error = base::RecvError; + async fn recv(&mut self) -> Result, Self::Error> { base::Receiver::recv(self).await } } #[async_trait] -impl SenderT> - for mpsc::Sender +impl SenderT for mpsc::Sender where T: RemoteSend, Codec: codec::Codec, { - async fn send(&mut self, item: T) -> Result<(), mpsc::SendError> { + type Error = mpsc::SendError; + async fn send(&mut self, item: T) -> Result<(), Self::Error> { mpsc::Sender::send(self, item).await } } #[async_trait] -impl ReceiverT - for mpsc::Receiver +impl ReceiverT for mpsc::Receiver where T: RemoteSend, Codec: codec::Codec, { - async fn recv(&mut self) -> Result, mpsc::RecvError> { + type Error = mpsc::RecvError; + async fn recv(&mut self) -> Result, Self::Error> { mpsc::Receiver::recv(self).await } } impl From> for CommunicationError { fn from(err: base::SendError) -> Self { - CommunicationError::BaseSend(err.kind) + CommunicationError::BaseSend(err.without_item()) } } @@ -216,7 +231,10 @@ where 8096, ) .await?; + tokio::spawn(conn); + debug!("Established remoc connection"); + Ok((tx, bytes_written, rx, bytes_read)) } diff --git a/crates/seec-channel/src/multi.rs b/crates/seec-channel/src/multi.rs new file mode 100644 index 0000000..268d075 --- /dev/null +++ b/crates/seec-channel/src/multi.rs @@ -0,0 +1,519 @@ +use crate::{multi, sub_channel, tcp, CommunicationError, Receiver, ReceiverT, Sender, SenderT}; +use async_trait::async_trait; +use futures::future::join; +use futures::stream::FuturesUnordered; +use futures::Stream; +use futures::StreamExt; +use remoc::rch::{base, mpsc}; +use remoc::RemoteSend; +use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::hash::Hash; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::net::{TcpListener, ToSocketAddrs}; +use tokio::task::JoinSet; +use tracing::{debug, error, instrument, Instrument}; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("unable to establish TCP remoc connection")] + Tcp(#[from] tcp::Error), + #[error("internal error in multi-party connection establishment")] + Internal(#[from] tokio::task::JoinError), + #[error("error when sending initial message")] + InitialMessageFailed(#[from] base::SendError<()>), + #[error("missing initial message")] + MissingInitialMsg, + #[error("received multiple initial message with equal party id")] + DuplicateInitialMsg, + #[error("unable to multi-send message")] + MultiSend(Vec>), + #[error("unable to multi-recv message")] + MultiRecv(Option), + #[error("unknown party id")] + UnknownParty(u32), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(bound = "T: RemoteSend")] +pub struct InitialMsg { + party_id: u32, + sender: Sender, +} + +#[derive(Debug)] +pub struct MultiSender { + senders: HashMap>, +} + +#[derive(Debug)] +pub struct MultiReceiver { + receivers: HashMap>, +} + +impl MultiSender { + pub async fn send_to(&self, to: impl IntoIterator, msg: T) -> Result<(), Error> { + let mut fu = FuturesUnordered::new(); + for to in to { + debug!(to, "Sending"); + let sender = self + .senders + .get(to) + .ok_or_else(|| Error::UnknownParty(*to))?; + fu.push(sender.send(msg.clone())); + } + let mut errors = vec![]; + loop { + match fu.next().await { + None => break, + Some(Ok(())) => continue, + Some(Err(err)) => errors.push(err.without_item()), + } + } + if errors.is_empty() { + Ok(()) + } else { + Err(Error::MultiSend(errors)) + } + } + + #[instrument(level = "debug", skip(self, msg), ret)] + pub async fn send_all(&self, msg: T) -> Result<(), Error> { + self.send_to(self.senders.keys(), msg).await + } + + pub fn sender(&self, to: u32) -> Option<&Sender> { + self.senders.get(&to) + } + + pub fn senders(&self) -> impl Iterator)> { + self.senders.iter() + } +} + +#[derive(Debug, Eq, PartialEq, Hash)] +pub struct MsgFrom { + from: u32, + msg: T, +} + +impl MultiReceiver { + pub fn recv_from( + &mut self, + from: &HashSet, + ) -> impl Stream, Error>> + '_ { + // this is unfortunately O(|receivers|) instead of O(|from|), but I doubt, + // that this has a noticeable perf impact + self.receivers + .iter_mut() + .filter(|(id, _)| from.contains(*id)) + .map(map_recv_fut) + .collect::>() + } + + pub fn recv_all(&mut self) -> impl Stream, Error>> + '_ { + self.receivers + .iter_mut() + .map(map_recv_fut) + .collect::>() + } + + pub fn receiver(&mut self, from: u32) -> Option<&mut Receiver> { + self.receivers.get_mut(&from) + } + + pub fn receivers(&mut self) -> impl Iterator)> { + self.receivers.iter_mut() + } +} + +#[inline] +async fn map_recv_fut( + (from, receiver): (&u32, &mut Receiver), +) -> Result, Error> { + debug!(from); + match receiver.recv().await { + Ok(Some(msg)) => { + debug!(from, "Received msg"); + Ok(MsgFrom { from: *from, msg }) + } + Ok(None) => Err(Error::MultiRecv(None)), + Err(err) => Err(Error::MultiRecv(Some(err))), + } +} + +#[tracing::instrument(skip_all)] +pub async fn multi_sub_channel( + sender: &MultiSender, + receiver: &mut MultiReceiver, + local_buffer: usize, +) -> Result<(MultiSender, MultiReceiver), CommunicationError> +where + Sender: Into, + Msg: Into>> + RemoteSend + Clone, + SubMsg: RemoteSend, + CommunicationError: std::convert::From + std::convert::From, +{ + struct SenderMutWrapper<'a, T>(&'a Sender); + #[async_trait] + impl<'a, T: RemoteSend> SenderT for SenderMutWrapper<'a, T> { + type Error = as SenderT>::Error; + + async fn send(&mut self, item: T) -> Result<(), Self::Error> { + self.0.send(item).await + } + } + let mut fu: FuturesUnordered<_> = receiver + .receivers() + .map(|(from, receiver)| { + let sender = sender + .sender(*from) + .expect("has receiver for {from} but no sender"); + let mut sender = SenderMutWrapper(sender); + async move { + let ch = sub_channel(&mut sender, receiver, local_buffer).await; + (*from, ch) + } + }) + .collect(); + let mut senders = HashMap::new(); + let mut receivers = HashMap::new(); + while let Some((remote_id, res)) = fu.next().await { + match res { + Ok((sender, receiver)) => { + senders.insert(remote_id, sender); + receivers.insert(remote_id, receiver); + } + Err(err) => { + return Err(CommunicationError::MultiSubChannel( + remote_id, + Box::new(err), + )) + } + } + } + let multi_sender = MultiSender { senders }; + let multi_receiver = MultiReceiver { receivers }; + Ok((multi_sender, multi_receiver)) +} + +#[async_trait] +impl SenderT for MultiSender { + type Error = Error; + + async fn send(&mut self, item: T) -> Result<(), Self::Error> { + self.send_all(item).await + } +} + +#[async_trait] +impl ReceiverT for MultiReceiver { + type Error = Error; + + async fn recv(&mut self) -> Result, Self::Error> { + todo!() + } +} + +/// remotes must include local_addr +#[instrument] +pub async fn connect( + local_addr: SocketAddr, + remotes: &[SocketAddr], + timeout: Duration, +) -> Result<(MultiSender, MultiReceiver), Error> { + let listener = TcpListener::bind(local_addr) + .await + .map_err(tcp::Error::Io)?; + let mut my_party_id = None; + let remotes: Vec<_> = remotes + .iter() + .cloned() + .enumerate() + .filter(|(id, addr)| { + let is_local = addr == &local_addr; + if is_local { + my_party_id = Some(*id as u32); + } + // filter out local address + !is_local + }) + .collect(); + let Some(my_party_id) = my_party_id else { + panic!("remotes must contain local_addr to ensure correct party ids"); + }; + let (listen, connect) = join( + listen_for_remotes::(listener, remotes.len()), + connect_to_remotes(my_party_id, remotes, timeout), + ) + .await; + Ok((listen?, connect?)) +} + +#[instrument(level = "debug", skip_all)] +async fn listen_for_remotes( + listener: TcpListener, + mut num_remotes: usize, +) -> Result, Error> { + let mut senders = HashMap::new(); + loop { + // local addr is part of remotes + if num_remotes == 0 { + break; + } + match listener.accept().await { + Ok((stream, addr)) => { + // we're not interested in the base channel, as we use the ones from + // connecting to the remotes. We only need to spawn the connectio + // which is done internally in this method + let (_, _, mut base_receiver, _) = + tcp::establish_remoc_connection_tcp::>(stream).await?; + debug!(%addr, "Established connection to remote"); + match base_receiver.recv().await { + Ok(Some(InitialMsg { party_id, sender })) => { + if senders.insert(party_id, sender).is_some() { + return Err(Error::DuplicateInitialMsg); + } + } + _ => return Err(Error::MissingInitialMsg), + } + num_remotes -= 1; + } + Err(err) => { + error!("Error during TCP connection establishment. {err:#?}") + } + } + } + debug!("Listened for all remotes"); + Ok(MultiSender { senders }) +} + +#[instrument(level = "debug", skip(remotes, timeout))] +async fn connect_to_remotes( + my_party_id: u32, + remotes: impl IntoIterator, + timeout: Duration, +) -> Result, Error> +where + T: RemoteSend, + S: ToSocketAddrs + Debug + Sync + Send + 'static, +{ + let mut join_set = JoinSet::new(); + for (id, remote) in remotes { + join_set.spawn( + async move { + let ch = tcp::connect_with_timeout(remote, timeout).await; + ch.map(|ch| (id, ch)) + } + .in_current_span(), + ); + } + let mut receivers = HashMap::new(); + while let Some(conn_res) = join_set.join_next().await { + match conn_res { + Ok(Ok((id, (mut base_sender, _, _, _)))) => { + let id = id as u32; + let (sender, receiver) = super::channel(128); + receivers.insert(id, receiver); + // we send our party id over the channel, so the other side knows with + // which party it communicates + base_sender + .send(InitialMsg { + party_id: my_party_id, + sender, + }) + .await + .map_err(|err| err.without_item())?; + } + Ok(Err(err)) => { + return Err(err.into()); + } + Err(err) => return Err(err.into()), + } + } + debug!("Connected to all remotes"); + Ok(MultiReceiver { receivers }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::util::init_tracing; + use futures::stream::FuturesOrdered; + use futures::TryStreamExt; + use std::net::Ipv4Addr; + + #[tokio::test] + async fn create_multi_channel() { + let _g = init_tracing(); + let base_port: u16 = 7712; + let parties = 10; + let parties_addrs: Vec<_> = (0..parties) + .map(|p| SocketAddr::from((Ipv4Addr::LOCALHOST, base_port + p))) + .collect(); + let mut join_set = JoinSet::new(); + for party in 0..parties { + let parties_addrs = parties_addrs.clone(); + join_set.spawn(async move { + connect::<()>( + parties_addrs[party as usize], + &parties_addrs, + Duration::from_millis(200), + ) + .await + }); + } + let mut cnt = 0; + loop { + match join_set.join_next().await { + None => break, + Some(Ok(Ok(_))) => { + cnt += 1; + } + Some(Ok(Err(err))) => { + panic!("{err:?}"); + } + Some(Err(err)) => { + panic!("{err:?}"); + } + } + } + assert_eq!(cnt, parties); + } + + #[tokio::test] + async fn send_receive_via_multi_channel() { + let _g = init_tracing(); + let base_port: u16 = 7612; + let parties = 5; + let parties_addrs: Vec<_> = (0..parties) + .map(|p| SocketAddr::from((Ipv4Addr::LOCALHOST, base_port + p))) + .collect(); + let mut join_set = JoinSet::new(); + for party in 0..parties { + let parties_addrs = parties_addrs.clone(); + join_set.spawn(async move { + let ch = connect::( + parties_addrs[party as usize], + &parties_addrs, + Duration::from_millis(200), + ) + .await; + (party, ch) + }); + } + let mut multi_channels = HashMap::new(); + loop { + match join_set.join_next().await { + Some(Ok((id, Ok(ch)))) => { + multi_channels.insert(id, ch); + } + None => break, + Some(err) => { + panic!("{err:#?}") + } + } + } + + multi_channels + .get(&0) + .unwrap() + .0 + .send_all("hello there".to_string()) + .await + .unwrap(); + + for (id, (_, mreceiver)) in multi_channels.iter_mut().filter(|(id, _)| **id != 0) { + debug!(id, "Listening on"); + let res = mreceiver + .recv_from(&FromIterator::from_iter([0])) + .next() + .await + .unwrap() + .unwrap(); + assert_eq!( + MsgFrom { + from: 0, + msg: String::from("hello there") + }, + res + ); + } + } + + #[tokio::test] + async fn test_multi_sub_channel() { + type SubMsg = u8; + #[derive(Clone, Serialize, Deserialize)] + struct Msg { + sender: Sender, + } + impl From for Option> { + fn from(value: Msg) -> Self { + Some(value.sender) + } + } + + impl From> for Msg { + fn from(sender: Sender) -> Msg { + Msg { sender } + } + } + + let _g = init_tracing(); + + let base_port: u16 = 7512; + let parties = 10; + let parties_addrs: Vec<_> = (0..parties) + .map(|p| SocketAddr::from((Ipv4Addr::LOCALHOST, base_port + p))) + .collect(); + + let fu: FuturesOrdered<_> = (0..parties) + .map(|party| { + let parties_addrs = &parties_addrs[..]; + async move { + connect::( + parties_addrs[party as usize], + parties_addrs, + Duration::from_millis(100), + ) + .await + .unwrap() + } + }) + .collect(); + + let mut channels: Vec<_> = fu.collect().await; + + let mut sub_chs: Vec<_> = channels + .iter_mut() + .map(|(sender, receiver)| async move { + multi_sub_channel::(sender, receiver, 128) + .await + .unwrap() + }) + .collect::>() + .collect() + .await; + + for (id, (sender, _)) in sub_chs.iter().enumerate() { + sender.send_all(id as u8).await.unwrap(); + } + + for (id, (_, receiver)) in sub_chs.iter_mut().enumerate() { + let v: HashSet<_> = receiver.recv_all().try_collect().await.unwrap(); + + let expected: HashSet<_> = (0..parties) + .filter(|p| *p as usize != id) + .map(|id| MsgFrom { + from: id as u32, + msg: id as u8, + }) + .collect(); + + assert_eq!(expected, v); + } + } +} diff --git a/crates/seec-channel/src/tcp.rs b/crates/seec-channel/src/tcp.rs index 80a7b92..6d951da 100644 --- a/crates/seec-channel/src/tcp.rs +++ b/crates/seec-channel/src/tcp.rs @@ -120,7 +120,7 @@ pub async fn new_local_pair( Ok((ch1, ch2)) } -async fn establish_remoc_connection_tcp( +pub(crate) async fn establish_remoc_connection_tcp( socket: TcpStream, ) -> Result, Error> { // send data ASAP diff --git a/crates/seec-channel/src/tls.rs b/crates/seec-channel/src/tls.rs index f16a5ac..578e813 100644 --- a/crates/seec-channel/src/tls.rs +++ b/crates/seec-channel/src/tls.rs @@ -1,5 +1,5 @@ -use crate::util::TrackingReadWrite; -use crate::TrackingChannel; +use crate::util::{Counter, TrackingReadWrite}; +use crate::{BaseReceiver, BaseSender, TrackingChannel}; use remoc::{ConnectError, RemoteSend}; use rustls::pki_types::{CertificateDer, InvalidDnsNameError, PrivateKeyDer, ServerName}; use rustls::version::TLS13; @@ -11,7 +11,7 @@ use std::io; use std::io::BufReader; use std::path::Path; use std::sync::Arc; -use tokio::io::split; +use tokio::io::{split, AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_rustls::{TlsAcceptor, TlsConnector}; use tracing::info; @@ -45,25 +45,12 @@ pub async fn listen( certificate_chain_file: impl AsRef + Debug, ) -> Result, Error> { info!("Listening for connections"); - let certs = load_certs(certificate_chain_file.as_ref())?; - let key = load_key(private_key_file.as_ref())?; - let config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13]) - .with_no_client_auth() - .with_single_cert(certs, key)?; - let acceptor = TlsAcceptor::from(Arc::new(config)); let listener = TcpListener::bind(addr).await?; - let (socket, remote_addr) = listener.accept().await?; + let (stream, remote_addr) = listener.accept().await?; info!(?remote_addr, "Accepted TCP connection to remote"); - socket.set_nodelay(true)?; - let (socket_read, socket_write) = socket.into_split(); - let tracking_channel = TrackingReadWrite::new(socket_read, socket_write); - let write_counter = tracking_channel.bytes_written(); - let read_counter = tracking_channel.bytes_read(); - let tls_stream = acceptor.accept(tracking_channel).await?; - info!(?remote_addr, "Established TLS connection to remote"); - let (tls_reader, tls_writer) = split(tls_stream); - let (sender, _, receiver, _) = - super::establish_remoc_connection(tls_reader, tls_writer).await?; + let (tracking_stream, write_counter, read_counter) = tracking_stream(stream)?; + let (sender, receiver) = + tls_accept(tracking_stream, private_key_file, certificate_chain_file).await?; // return the counters that include tls overhead // TODO it might be nice to have both counters Ok((sender, write_counter, receiver, read_counter)) @@ -74,6 +61,63 @@ pub async fn connect( domain: &str, remote_addr: impl ToSocketAddrs + Debug, ) -> Result, Error> { + info!("Connecting to remote"); + let stream = TcpStream::connect(remote_addr).await?; + info!("Established TCP connection to server"); + let (tracking_stream, write_counter, read_counter) = tracking_stream(stream)?; + let (sender, receiver) = tls_connect(domain, tracking_stream).await?; + Ok((sender, write_counter, receiver, read_counter)) +} + +fn tracking_stream( + tcp_stream: TcpStream, +) -> Result< + ( + impl AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + Counter, + Counter, + ), + Error, +> { + tcp_stream.set_nodelay(true)?; + let (socket_read, socket_write) = tcp_stream.into_split(); + let tracking_channel = TrackingReadWrite::new(socket_read, socket_write); + let write_counter = tracking_channel.bytes_written(); + let read_counter = tracking_channel.bytes_read(); + Ok((tracking_channel, write_counter, read_counter)) +} + +async fn tls_accept( + tcp_stream: IO, + private_key_file: impl AsRef + Debug, + certificate_chain_file: impl AsRef + Debug, +) -> Result<(BaseSender, BaseReceiver), Error> +where + T: RemoteSend, + IO: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, +{ + let certs = load_certs(certificate_chain_file.as_ref())?; + let key = load_key(private_key_file.as_ref())?; + let config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13]) + .with_no_client_auth() + .with_single_cert(certs, key)?; + let acceptor = TlsAcceptor::from(Arc::new(config)); + let tls_stream = acceptor.accept(tcp_stream).await?; + info!("Established TLS connection to remote"); + let (tls_reader, tls_writer) = split(tls_stream); + let (sender, _, receiver, _) = + super::establish_remoc_connection(tls_reader, tls_writer).await?; + Ok((sender, receiver)) +} + +async fn tls_connect( + domain: &str, + tcp_stream: IO, +) -> Result<(BaseSender, BaseReceiver), Error> +where + T: RemoteSend, + IO: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, +{ let domain = ServerName::try_from(domain.to_string())?; let mut root_cert_store = rustls::RootCertStore::empty(); let (added, ignored) = root_cert_store.add_parsable_certificates(load_native_certs()?); @@ -83,35 +127,10 @@ pub async fn connect( .with_root_certificates(root_cert_store) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(config)); - - info!("Connecting to remote"); - let stream = TcpStream::connect(remote_addr).await?; - stream.set_nodelay(true)?; - let (socket_read, socket_write) = stream.into_split(); - let tracking_channel = TrackingReadWrite::new(socket_read, socket_write); - let write_counter = tracking_channel.bytes_written(); - let read_counter = tracking_channel.bytes_read(); - info!("Established TCP connection to server"); - let tls_stream = connector.connect(domain, tracking_channel).await?; + let tls_stream = connector.connect(domain, tcp_stream).await?; info!("Established TLS connection to server"); let (tls_reader, tls_writer) = split(tls_stream); let (sender, _, receiver, _) = super::establish_remoc_connection(tls_reader, tls_writer).await?; - Ok((sender, write_counter, receiver, read_counter)) + Ok((sender, receiver)) } - -// #[tracing::instrument(err)] -// pub async fn server( -// addr: impl ToSocketAddrs + Debug, -// ) -> Result, crate::tcp::Error>>, io::Error> { -// info!("Starting Tcp Server"); -// let listener = TcpListener::bind(addr).await?; -// let s = stream! { -// loop { -// let (socket, _) = listener.accept().await?; -// yield establish_remoc_connection_tls(socket).await; -// -// } -// }; -// Ok(s) -// } diff --git a/crates/seec-channel/src/util.rs b/crates/seec-channel/src/util.rs index 99efd7e..e35b40d 100644 --- a/crates/seec-channel/src/util.rs +++ b/crates/seec-channel/src/util.rs @@ -618,3 +618,14 @@ impl DivAssign for CountPair { self.rcvd /= rhs; } } + +#[cfg(test)] +pub(crate) fn init_tracing() -> tracing::dispatcher::DefaultGuard { + use tracing_subscriber::fmt::format::FmtSpan; + use tracing_subscriber::util::SubscriberInitExt; + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_test_writer() + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + .set_default() +} diff --git a/crates/seec/examples/aes_cbc.rs b/crates/seec/examples/aes_cbc.rs index bb2e77c..a3ea9eb 100644 --- a/crates/seec/examples/aes_cbc.rs +++ b/crates/seec/examples/aes_cbc.rs @@ -105,7 +105,7 @@ enum Msg { iv: [usize; 2], key: [usize; 2], }, - OtChannel(seec_channel::Receiver), + OtChannel(seec_channel::Sender), Ack, } @@ -159,7 +159,7 @@ async fn execute(args: &ExecuteArgs) -> Result<()> { 128, Msg::OtChannel, |msg| match msg { - Msg::OtChannel(receiver) => Some(receiver), + Msg::OtChannel(sender) => Some(sender), _ => None, }, ) diff --git a/crates/seec/examples/privmail_sc.rs b/crates/seec/examples/privmail_sc.rs index 804af37..b1da2b3 100644 --- a/crates/seec/examples/privmail_sc.rs +++ b/crates/seec/examples/privmail_sc.rs @@ -225,7 +225,7 @@ async fn main() -> anyhow::Result<()> { &mut sender, &mut receiver, 64, - seec_channel::Receiver, + seec_channel::Sender, Message ) .await?; diff --git a/crates/seec/src/bench.rs b/crates/seec/src/bench.rs index 289fa0a..d5612d3 100644 --- a/crates/seec/src/bench.rs +++ b/crates/seec/src/bench.rs @@ -20,7 +20,7 @@ use rand::rngs::OsRng; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; use seec_channel::util::{Phase, RunResult, Statistics}; -use seec_channel::{sub_channels_for, Channel, Receiver}; +use seec_channel::{sub_channels_for, Channel, Sender}; use serde::{Deserialize, Serialize}; use std::fmt::Debug; use std::fs::File; @@ -36,7 +36,7 @@ type DynMTP

= pub trait BenchProtocol: Protocol + Default + Debug { fn insecure_setup() -> DynMTP; - fn ot_setup(ch: Channel>) -> DynMTP; + fn ot_setup(ch: Channel>) -> DynMTP; fn stored(path: &Path) -> DynMTP; } @@ -45,7 +45,7 @@ impl BenchProtocol for BooleanGmw { Box::new(ErasedError(boolean::InsecureMTProvider::default())) } - fn ot_setup(ch: Channel>) -> DynMTP { + fn ot_setup(ch: Channel>) -> DynMTP { let ot_sender = zappot::ot_ext::Sender::default(); let ot_recv = zappot::ot_ext::Receiver::default(); let mtp = boolean::OtMTProvider::new(OsRng, ot_sender, ot_recv, ch.0, ch.1); @@ -68,7 +68,7 @@ where mixed_gmw::InsecureMixedSetup::default().into_dyn() } - fn ot_setup(_ch: Channel>) -> DynMTP { + fn ot_setup(_ch: Channel>) -> DynMTP { todo!() } @@ -238,7 +238,7 @@ where &mut sender, &mut receiver, 128, - Receiver, + Sender, Message

) .await diff --git a/crates/seec/src/mul_triple/arithmetic/ot_ext.rs b/crates/seec/src/mul_triple/arithmetic/ot_ext.rs index ecc777f..02abbac 100644 --- a/crates/seec/src/mul_triple/arithmetic/ot_ext.rs +++ b/crates/seec/src/mul_triple/arithmetic/ot_ext.rs @@ -19,8 +19,8 @@ pub struct OtMTProvider { rng: RNG, ot_sender: OtS, ot_receiver: OtR, - ch_sender: seec_channel::Sender>, - ch_receiver: seec_channel::Receiver>, + ch_sender: seec_channel::Sender>, + ch_receiver: seec_channel::Receiver>, precomputed_mts: Option>, } @@ -31,8 +31,8 @@ impl>, - ch_receiver: seec_channel::Receiver>, + ch_sender: seec_channel::Sender>, + ch_receiver: seec_channel::Receiver>, ) -> Self { Self { rng, diff --git a/crates/seec/src/mul_triple/boolean/ot_ext.rs b/crates/seec/src/mul_triple/boolean/ot_ext.rs index e3e39e2..2fdd4ef 100644 --- a/crates/seec/src/mul_triple/boolean/ot_ext.rs +++ b/crates/seec/src/mul_triple/boolean/ot_ext.rs @@ -13,7 +13,7 @@ use thiserror::Error; use zappot::traits::{ExtROTReceiver, ExtROTSender}; use zappot::util::aes_rng::AesRng; -pub type Msg = seec_channel::Receiver; +pub type Msg = seec_channel::Sender; /// Message for default ot ext pub type DefaultMsg = Msg<::Msg>; diff --git a/crates/seec/src/private_test_utils.rs b/crates/seec/src/private_test_utils.rs index d4eba27..5c168ee 100644 --- a/crates/seec/src/private_test_utils.rs +++ b/crates/seec/src/private_test_utils.rs @@ -335,7 +335,7 @@ where } TestChannel::Tcp => { let (mut t1, mut t2) = - seec_channel::tcp::new_local_pair::>(None).await?; + seec_channel::tcp::new_local_pair::>(None).await?; let (mut sub_t1, mut sub_t2) = tokio::try_join!( sub_channel(&mut t1.0, &mut t1.2, 2), sub_channel(&mut t2.0, &mut t2.2, 2) diff --git a/crates/seec/tests/mt_providers.rs b/crates/seec/tests/mt_providers.rs index 8407e59..7c77fd2 100644 --- a/crates/seec/tests/mt_providers.rs +++ b/crates/seec/tests/mt_providers.rs @@ -36,7 +36,7 @@ async fn trusted_mt_provider() -> anyhow::Result<()> { let mut ex2 = Executor::::new(&circuit, 1, mt_provider_2).await?; let input_a = BitVec::repeat(false, 256); let input_b = BitVec::repeat(false, 256); - let (mut t1, mut t2) = tcp::new_local_pair::>(None).await?; + let (mut t1, mut t2) = tcp::new_local_pair::>(None).await?; let (mut t1, mut t2) = tokio::try_join!( sub_channel(&mut t1.0, &mut t1.2, 8), sub_channel(&mut t2.0, &mut t2.2, 8) @@ -92,7 +92,7 @@ async fn trusted_seed_mt_provider() -> anyhow::Result<()> { let mut ex2 = Executor::::new(&circuit, 1, mt_provider_2).await?; let input_a = BitVec::repeat(false, 256); let input_b = BitVec::repeat(false, 256); - let (mut t1, mut t2) = tcp::new_local_pair::>(None).await?; + let (mut t1, mut t2) = tcp::new_local_pair::>(None).await?; let (mut t1, mut t2) = tokio::try_join!( sub_channel(&mut t1.0, &mut t1.2, 8), sub_channel(&mut t2.0, &mut t2.2, 8) diff --git a/crates/zappot/examples/alsz_ot_extension.rs b/crates/zappot/examples/alsz_ot_extension.rs index ae4035e..1ac7671 100644 --- a/crates/zappot/examples/alsz_ot_extension.rs +++ b/crates/zappot/examples/alsz_ot_extension.rs @@ -35,7 +35,7 @@ async fn sender(args: Args) -> (Vec<[Block; 2]>, usize, usize) { // Create a channel by listening on a socket address. Once another party connect, this // returns the channel let (mut base_sender, send_cnt, mut base_receiver, recv_cnt) = - seec_channel::tcp::listen::>(("127.0.0.1", args.port)) + seec_channel::tcp::listen::>(("127.0.0.1", args.port)) .await .expect("Error listening for channel connection"); let (ch_sender, mut ch_receiver) = sub_channel(&mut base_sender, &mut base_receiver, 128) @@ -58,7 +58,7 @@ async fn receiver(args: Args) -> (Vec, BitVec) { // to create the base_ots let mut receiver = Receiver::new(base_ot::Sender); let (mut base_sender, _, mut base_receiver, _) = - seec_channel::tcp::connect::>(("127.0.0.1", args.port)) + seec_channel::tcp::connect::>(("127.0.0.1", args.port)) .await .expect("Error listening for channel connection"); let (ch_sender, mut ch_receiver) = sub_channel(&mut base_sender, &mut base_receiver, 128) diff --git a/crates/zappot/examples/co_base_ot.rs b/crates/zappot/examples/co_base_ot.rs index 6022036..3cd67e1 100644 --- a/crates/zappot/examples/co_base_ot.rs +++ b/crates/zappot/examples/co_base_ot.rs @@ -30,7 +30,7 @@ async fn sender(args: Args) -> Vec<[Block; 2]> { // Create a channel by listening on a socket address. Once another party connect, this // returns the channel let (mut base_sender, _, mut base_receiver, _) = - seec_channel::tcp::listen::>(("127.0.0.1", args.port)) + seec_channel::tcp::listen::>(("127.0.0.1", args.port)) .await .expect("Error listening for channel connection"); let (ch_sender, mut ch_receiver) = sub_channel(&mut base_sender, &mut base_receiver, 8) @@ -51,7 +51,7 @@ async fn receiver(args: Args) -> (Vec, BitVec) { let mut receiver = Receiver::new(); // Connect to the sender on the listened on port let (mut base_sender, _, mut base_receiver, _) = - seec_channel::tcp::connect::>(("127.0.0.1", args.port)) + seec_channel::tcp::connect::>(("127.0.0.1", args.port)) .await .expect("Error listening for channel connection"); let (ch_sender, mut ch_receiver) = sub_channel(&mut base_sender, &mut base_receiver, 8) diff --git a/crates/zappot/examples/silent_ot.rs b/crates/zappot/examples/silent_ot.rs index c42d598..226b6b0 100644 --- a/crates/zappot/examples/silent_ot.rs +++ b/crates/zappot/examples/silent_ot.rs @@ -38,7 +38,7 @@ async fn sender(args: Args) -> (Vec<[Block; 2]>, usize, usize) { // Create a channel by listening on a socket address. Once another party connect, this // returns the channel let (mut base_sender, bytes_sent, mut base_receiver, bytes_rcv) = - seec_channel::tcp::listen::>(("127.0.0.1", args.port)) + seec_channel::tcp::listen::>(("127.0.0.1", args.port)) .await .expect("Error listening for channel connection"); tracing::debug!("Before sub channel"); @@ -71,7 +71,7 @@ async fn receiver(args: Args) -> (Vec, BitVec) { // Create a secure RNG to use in the protocol let mut rng = OsRng; let (mut base_sender, _, mut base_receiver, _) = - seec_channel::tcp::connect::>(("127.0.0.1", args.port)) + seec_channel::tcp::connect::>(("127.0.0.1", args.port)) .await .expect("Error listening for channel connection"); tracing::debug!("Before sub channel"); diff --git a/crates/zappot/src/silent_ot/mod.rs b/crates/zappot/src/silent_ot/mod.rs index 68b8bd4..a29adfd 100644 --- a/crates/zappot/src/silent_ot/mod.rs +++ b/crates/zappot/src/silent_ot/mod.rs @@ -121,8 +121,8 @@ pub enum MultType { /// Message sent during SilentOT evaluation. pub enum Msg { #[serde(bound = "")] - BaseOTChannel(seec_channel::Receiver), - Pprf(seec_channel::Receiver), + BaseOTChannel(seec_channel::Sender), + Pprf(seec_channel::Sender), GapValues(Vec), }