Change router job to accommodate keep_key_fresh

This commit is contained in:
Mathias Hall-Andersen
2019-10-30 12:01:12 +01:00
parent e04a11a8ca
commit afc96611a5
5 changed files with 139 additions and 140 deletions

View File

@@ -19,7 +19,7 @@ use super::constants::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerInner}; use super::peer::{new_peer, Peer, PeerInner};
use super::types::{Callbacks, RouterError}; use super::types::{Callbacks, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation}; use super::workers::{worker_parallel, JobParallel};
use super::SIZE_MESSAGE_PREFIX; use super::SIZE_MESSAGE_PREFIX;
use super::route::get_route; use super::route::get_route;
@@ -44,10 +44,9 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write
} }
pub struct EncryptionState { pub struct EncryptionState {
pub key: [u8; 32], // encryption key pub keypair: Arc<KeyPair>, // keypair
pub id: u32, // receiver id pub nonce: u64, // next available nonce
pub nonce: u64, // next available nonce pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
} }
pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
@@ -143,8 +142,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C,
// schedule for encryption and transmission to peer // schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg, true) { if let Some(job) = peer.send_job(msg, true) {
debug_assert_eq!(job.1.op, Operation::Encryption);
// add job to worker queue // add job to worker queue
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
let queues = self.state.queues.lock(); let queues = self.state.queues.lock();
@@ -186,8 +183,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C,
// schedule for decryption and TUN write // schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
debug_assert_eq!(job.1.op, Operation::Decryption);
// add job to worker queue // add job to worker queue
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
let queues = self.state.queues.lock(); let queues = self.state.queues.lock();

View File

@@ -14,6 +14,9 @@ mod tests;
use messages::TransportHeader; use messages::TransportHeader;
use std::mem; use std::mem;
use super::constants::REJECT_AFTER_MESSAGES;
use super::constants::REKEY_AFTER_MESSAGES;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>(); pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG; pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG;

View File

@@ -24,9 +24,8 @@ use super::messages::TransportHeader;
use futures::*; use futures::*;
use super::workers::Operation;
use super::workers::{worker_inbound, worker_outbound}; use super::workers::{worker_inbound, worker_outbound};
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel}; use super::workers::{JobDecryption, JobEncryption, JobInbound, JobOutbound, JobParallel};
use super::SIZE_MESSAGE_PREFIX; use super::SIZE_MESSAGE_PREFIX;
use super::constants::*; use super::constants::*;
@@ -99,9 +98,8 @@ fn treebit_remove<E: Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind
impl EncryptionState { impl EncryptionState {
fn new(keypair: &Arc<KeyPair>) -> EncryptionState { fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
EncryptionState { EncryptionState {
id: keypair.send.id,
key: keypair.send.key,
nonce: 0, nonce: 0,
keypair: keypair.clone(),
death: keypair.birth + REJECT_AFTER_TIME, death: keypair.birth + REJECT_AFTER_TIME,
} }
} }
@@ -294,22 +292,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E,
msg: Vec<u8>, msg: Vec<u8>,
) -> Option<JobParallel> { ) -> Option<JobParallel> {
let (tx, rx) = oneshot(); let (tx, rx) = oneshot();
let key = dec.keypair.recv.key; let keypair = dec.keypair.clone();
match self.inbound.lock().try_send((dec, src, rx)) { match self.inbound.lock().try_send((dec, src, rx)) {
Ok(_) => Some(( Ok(_) => Some(JobParallel::Decryption(tx, JobDecryption { msg, keypair })),
tx,
JobBuffer {
msg,
key: key,
okay: false,
op: Operation::Decryption,
},
)),
Err(_) => None, Err(_) => None,
} }
} }
pub fn send_job(&self, mut msg: Vec<u8>, stage: bool) -> Option<JobParallel> { pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<JobParallel> {
debug!("peer.send_job"); debug!("peer.send_job");
debug_assert!( debug_assert!(
msg.len() >= mem::size_of::<TransportHeader>(), msg.len() >= mem::size_of::<TransportHeader>(),
@@ -317,29 +307,24 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E,
msg.len() msg.len()
); );
// parse / cast
let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap();
let mut header: LayoutVerified<&mut [u8], TransportHeader> = header;
// check if has key // check if has key
let key = { let (keypair, counter) = {
let mut ekey = self.ekey.lock(); let keypair = {
let key = match ekey.as_mut() { // TODO: consider using atomic ptr for ekey state
None => None, let mut ekey = self.ekey.lock();
Some(mut state) => { match ekey.as_mut() {
// avoid integer overflow in nonce None => None,
if state.nonce >= REJECT_AFTER_MESSAGES - 1 { Some(mut state) => {
*ekey = None; // avoid integer overflow in nonce
None if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
} else { *ekey = None;
// there should be no stacked packets lingering around None
debug!("encryption state available, nonce = {}", state.nonce); } else {
debug!("encryption state available, nonce = {}", state.nonce);
// set transport message fields let counter = state.nonce;
header.f_counter.set(state.nonce); state.nonce += 1;
header.f_receiver.set(state.id); Some((state.keypair.clone(), counter))
state.nonce += 1; }
Some(state.key)
} }
} }
}; };
@@ -347,25 +332,24 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E,
// If not suitable key was found: // If not suitable key was found:
// 1. Stage packet for later transmission // 1. Stage packet for later transmission
// 2. Request new key // 2. Request new key
if key.is_none() && stage { if keypair.is_none() && stage {
self.staged_packets.lock().push_back(msg); self.staged_packets.lock().push_back(msg);
C::need_key(&self.opaque); C::need_key(&self.opaque);
return None; return None;
}; };
key keypair
}?; }?;
// add job to in-order queue and return sendeer to device for inclusion in worker pool // add job to in-order queue and return sender to device for inclusion in worker pool
let (tx, rx) = oneshot(); let (tx, rx) = oneshot();
match self.outbound.lock().try_send(rx) { match self.outbound.lock().try_send(rx) {
Ok(_) => Some(( Ok(_) => Some(JobParallel::Encryption(
tx, tx,
JobBuffer { JobEncryption {
msg, msg,
key, counter,
okay: false, keypair,
op: Operation::Encryption,
}, },
)), )),
Err(_) => None, Err(_) => None,

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use futures::sync::oneshot; use futures::sync::oneshot;
use futures::*; use futures::*;
use log::debug; use log::{debug, trace};
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
@@ -16,34 +16,40 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::PeerInner; use super::peer::PeerInner;
use super::route::check_route; use super::route::check_route;
use super::types::Callbacks; use super::types::Callbacks;
use super::REJECT_AFTER_MESSAGES;
use super::super::types::KeyPair;
use super::super::{bind, tun, Endpoint}; use super::super::{bind, tun, Endpoint};
pub const SIZE_TAG: usize = 16; pub const SIZE_TAG: usize = 16;
#[derive(PartialEq, Debug)] #[derive(Debug)]
pub enum Operation { pub struct JobEncryption {
Encryption, pub msg: Vec<u8>,
Decryption, pub keypair: Arc<KeyPair>,
pub counter: u64,
} }
pub struct JobBuffer { #[derive(Debug)]
pub msg: Vec<u8>, // message buffer (nonce and receiver id set) pub struct JobDecryption {
pub key: [u8; 32], // chacha20poly1305 key pub msg: Vec<u8>,
pub okay: bool, // state of the job pub keypair: Arc<KeyPair>,
pub op: Operation, // should be buffer be encrypted / decrypted?
} }
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer); #[derive(Debug)]
pub enum JobParallel {
Encryption(oneshot::Sender<JobEncryption>, JobEncryption),
Decryption(oneshot::Sender<Option<JobDecryption>>, JobDecryption),
}
#[allow(type_alias_bounds)] #[allow(type_alias_bounds)]
pub type JobInbound<E, C, T, B: bind::Writer<E>> = ( pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
Arc<DecryptionState<E, C, T, B>>, Arc<DecryptionState<E, C, T, B>>,
E, E,
oneshot::Receiver<JobBuffer>, oneshot::Receiver<Option<JobDecryption>>,
); );
pub type JobOutbound = oneshot::Receiver<JobBuffer>; pub type JobOutbound = oneshot::Receiver<JobEncryption>;
pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device device: Arc<DeviceInner<E, C, T, B>>, // related device
@@ -64,7 +70,7 @@ pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer
let _ = rx let _ = rx
.map(|buf| { .map(|buf| {
debug!("inbound worker: job complete"); debug!("inbound worker: job complete");
if buf.okay { if let Some(buf) = buf {
// cast transport header // cast transport header
let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
match LayoutVerified::new_from_prefix(&buf.msg[..]) { match LayoutVerified::new_from_prefix(&buf.msg[..]) {
@@ -134,6 +140,10 @@ pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write
peer: Arc<PeerInner<E, C, T, B>>, // related peer peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobOutbound>, receiver: Receiver<JobOutbound>,
) { ) {
fn keep_key_fresh(keypair: &KeyPair, counter: u64) -> bool {
false
}
loop { loop {
// fetch job // fetch job
let rx = match receiver.recv() { let rx = match receiver.recv() {
@@ -148,27 +158,30 @@ pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write
let _ = rx let _ = rx
.map(|buf| { .map(|buf| {
debug!("outbound worker: job complete"); debug!("outbound worker: job complete");
if buf.okay { // write to UDP bind
// write to UDP bind let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() { let send: &Option<B> = &*device.outbound.read();
let send: &Option<B> = &*device.outbound.read(); if let Some(writer) = send.as_ref() {
if let Some(writer) = send.as_ref() { match writer.write(&buf.msg[..], dst) {
match writer.write(&buf.msg[..], dst) { Err(e) => {
Err(e) => { debug!("failed to send outbound packet: {:?}", e);
debug!("failed to send outbound packet: {:?}", e); false
false
}
Ok(_) => true,
} }
} else { Ok(_) => true,
false
} }
} else { } else {
false false
}; }
} else {
false
};
// trigger callback // trigger callback
C::send(&peer.opaque, buf.msg.len(), xmit); C::send(&peer.opaque, buf.msg.len(), xmit);
// keep_key_fresh semantics
if keep_key_fresh(&buf.keypair, buf.counter) {
C::need_key(&peer.opaque);
} }
}) })
.wait(); .wait();
@@ -178,76 +191,85 @@ pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write
pub fn worker_parallel(receiver: Receiver<JobParallel>) { pub fn worker_parallel(receiver: Receiver<JobParallel>) {
loop { loop {
// fetch next job // fetch next job
let (tx, mut buf) = match receiver.recv() { let job = match receiver.recv() {
Err(_) => { Err(_) => {
return; return;
} }
Ok(val) => val, Ok(val) => val,
}; };
debug!("parallel worker: obtained job"); trace!("parallel worker: obtained job");
// make space for tag (TODO: consider moving this out) // handle job
if buf.op == Operation::Encryption { match job {
buf.msg.extend([0u8; SIZE_TAG].iter()); JobParallel::Encryption(tx, mut job) => {
} job.msg.extend([0u8; SIZE_TAG].iter());
// cast and check size of packet // cast to header (should never fail)
let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = let (mut header, body): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
match LayoutVerified::new_from_prefix(&mut buf.msg[..]) { LayoutVerified::new_from_prefix(&mut job.msg[..])
Some(v) => v, .expect("earlier code should ensure that there is ample space");
None => {
debug_assert!(
false,
"parallel worker: failed to parse message (insufficient size)"
);
continue;
}
};
debug_assert!(packet.len() >= CHACHA20_POLY1305.tag_len());
// do the weird ring AEAD dance // set header fields
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap());
// create a nonce object
let mut nonce = [0u8; 12];
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
let nonce = Nonce::assume_unique_for_key(nonce);
match buf.op {
Operation::Encryption => {
debug!("parallel worker: process encryption");
// set the type field
header.f_type.set(TYPE_TRANSPORT); header.f_type.set(TYPE_TRANSPORT);
header.f_receiver.set(job.keypair.send.id);
header.f_counter.set(job.counter);
// create a nonce object
let mut nonce = [0u8; 12];
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
let nonce = Nonce::assume_unique_for_key(nonce);
// do the weird ring AEAD dance
let key = LessSafeKey::new(
UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(),
);
// encrypt content of transport message in-place // encrypt content of transport message in-place
let end = packet.len() - SIZE_TAG; let end = body.len() - SIZE_TAG;
let tag = key let tag = key
.seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) .seal_in_place_separate_tag(nonce, Aad::empty(), &mut body[..end])
.unwrap(); .unwrap();
// append tag // append tag
packet[end..].copy_from_slice(tag.as_ref()); body[end..].copy_from_slice(tag.as_ref());
buf.okay = true; // pass ownership
let _ = tx.send(job);
} }
Operation::Decryption => { JobParallel::Decryption(tx, mut job) => {
debug!("parallel worker: process decryption"); // cast to header (could fail)
let layout: Option<(LayoutVerified<&mut [u8], TransportHeader>, &mut [u8])> =
LayoutVerified::new_from_prefix(&mut job.msg[..]);
// opening failure is signaled by fault state let _ = tx.send(match layout {
buf.okay = match key.open_in_place(nonce, Aad::empty(), packet) { Some((header, body)) => {
Ok(_) => true, debug_assert_eq!(header.f_type.get(), TYPE_TRANSPORT);
Err(_) => false, if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
}; None
} else {
// create a nonce object
let mut nonce = [0u8; 12];
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
let nonce = Nonce::assume_unique_for_key(nonce);
// do the weird ring AEAD dance
let key = LessSafeKey::new(
UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.recv.key[..])
.unwrap(),
);
// attempt to open (and authenticate) the body
match key.open_in_place(nonce, Aad::empty(), body) {
Ok(_) => Some(job),
Err(_) => None,
}
}
}
None => None,
});
} }
} }
// pass ownership to consumer
let okay = tx.send(buf);
debug!(
"parallel worker: passing ownership to sequential worker: {}",
okay.is_ok()
);
} }
} }

View File

@@ -77,7 +77,6 @@ fn wait() {
} }
/* Create and configure two matching pure instances of WireGuard /* Create and configure two matching pure instances of WireGuard
*
*/ */
#[test] #[test]
fn test_pure_wireguard() { fn test_pure_wireguard() {
@@ -166,8 +165,6 @@ fn test_pure_wireguard() {
fake1.write(p); fake1.write(p);
} }
wait();
while let Some(p) = backup.pop() { while let Some(p) = backup.pop() {
assert_eq!( assert_eq!(
hex::encode(fake2.read()), hex::encode(fake2.read()),
@@ -197,8 +194,6 @@ fn test_pure_wireguard() {
fake2.write(p); fake2.write(p);
} }
wait();
while let Some(p) = backup.pop() { while let Some(p) = backup.pop() {
assert_eq!( assert_eq!(
hex::encode(fake1.read()), hex::encode(fake1.read()),