Added zero_key to peer

This commit is contained in:
Mathias Hall-Andersen
2019-09-21 17:22:03 +02:00
parent 6311aa3402
commit 5cc1083499
6 changed files with 190 additions and 66 deletions

View File

@@ -64,10 +64,16 @@ impl Device {
self.macs = macs::Validator::new(pk); self.macs = macs::Validator::new(pk);
// recalculate the shared secrets for every peer // recalculate the shared secrets for every peer
for &mut peer in self.pk_map.values_mut() { let mut ids = vec![];
peer.reset_state().map(|id| self.release(id)); for mut peer in self.pk_map.values_mut() {
peer.reset_state().map(|id| ids.push(id));
peer.ss = self.sk.diffie_hellman(&peer.pk) peer.ss = self.sk.diffie_hellman(&peer.pk)
} }
// release ids from aborted handshakes
for id in ids {
self.release(id)
}
} }
/// Add a new public key to the state machine /// Add a new public key to the state machine

View File

@@ -8,6 +8,7 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
mod constants; mod constants;
mod handshake; mod handshake;
mod router; mod router;
mod timers;
mod types; mod types;
mod wireguard; mod wireguard;

View File

@@ -36,7 +36,7 @@ pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed) next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
current: Option<Arc<KeyPair>>, // current key state (used for encryption) current: Option<Arc<KeyPair>>, // current key state (used for encryption)
previous: Option<Arc<KeyPair>>, // old key state (used for decryption) previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
retired: Option<u32>, // retired id (previous id, after confirming key-pair) retired: Vec<u32>, // retired ids
} }
pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> { pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
@@ -188,7 +188,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
next: None, next: None,
current: None, current: None,
previous: None, previous: None,
retired: None, retired: vec![],
}), }),
staged_packets: spin::Mutex::new(ArrayDeque::new()), staged_packets: spin::Mutex::new(ArrayDeque::new()),
}) })
@@ -375,6 +375,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
*self.state.endpoint.lock() = Some(B::Endpoint::from_address(address)); *self.state.endpoint.lock() = Some(B::Endpoint::from_address(address));
} }
/// Returns the current endpoint of the peer (for configuration)
///
/// # Note
///
/// Does not convey potential "sticky socket" information
pub fn get_endpoint(&self) -> Option<SocketAddr> { pub fn get_endpoint(&self) -> Option<SocketAddr> {
self.state self.state
.endpoint .endpoint
@@ -383,6 +388,30 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
.map(|e| e.into_address()) .map(|e| e.into_address())
} }
/// Zero all key-material related to the peer
pub fn zero_keys(&self) {
let mut release: Vec<u32> = Vec::with_capacity(3);
let mut keys = self.state.keys.lock();
// update key-wheel
mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id()));
mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id()));
mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id()));
keys.retired.extend(&release[..]);
// update inbound "recv" map
{
let mut recv = self.state.device.recv.write();
for id in release {
recv.remove(&id);
}
}
// clear encryption state
*self.state.ekey.lock() = None;
}
/// Add a new keypair /// Add a new keypair
/// ///
/// # Arguments /// # Arguments
@@ -393,14 +422,16 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
/// ///
/// A vector of ids which has been released. /// A vector of ids which has been released.
/// These should be released in the handshake module. /// These should be released in the handshake module.
///
/// # Note
///
/// The number of ids to be released can be at most 3,
/// since the only way to add additional keys to the peer is by using this method
/// and a peer can have at most 3 keys allocated in the router at any time.
pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> { pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
let mut keys = self.state.keys.lock();
let mut release = Vec::with_capacity(2);
let new = Arc::new(new); let new = Arc::new(new);
let mut keys = self.state.keys.lock();
// collect ids to be released let mut release = mem::replace(&mut keys.retired, vec![]);
keys.retired.map(|v| release.push(v));
keys.previous.as_ref().map(|k| release.push(k.recv.id));
// update key-wheel // update key-wheel
if new.initiator { if new.initiator {
@@ -420,10 +451,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
{ {
let mut recv = self.state.device.recv.write(); let mut recv = self.state.device.recv.write();
// purge recv map of released ids // purge recv map of previous id
for id in &release { keys.previous.as_ref().map(|k| {
recv.remove(&id); recv.remove(&k.local_id());
} release.push(k.local_id());
});
// map new id to decryption state // map new id to decryption state
debug_assert!(!recv.contains_key(&new.recv.id)); debug_assert!(!recv.contains_key(&new.recv.id));
@@ -442,7 +474,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
} }
} }
// return the released id (for handshake state machine) debug_assert!(release.len() <= 3);
release release
} }

65
src/timers.rs Normal file
View File

@@ -0,0 +1,65 @@
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use hjul::{Runner, Timer};
use crate::router::Callbacks;
const ZERO_DURATION: Duration = Duration::from_micros(0);
pub struct TimersInner {
handshake_pending: AtomicBool,
handshake_attempts: AtomicUsize,
retransmit_handshake: Timer,
send_keepalive: Timer,
zero_key_material: Timer,
new_handshake: Timer,
// stats
rx_bytes: AtomicU64,
tx_bytes: AtomicU64,
}
impl TimersInner {
pub fn new(runner: &Runner) -> Timers {
Arc::new(TimersInner {
handshake_pending: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
retransmit_handshake: runner.timer(|| {}),
new_handshake: runner.timer(|| {}),
send_keepalive: runner.timer(|| {}),
zero_key_material: runner.timer(|| {}),
rx_bytes: AtomicU64::new(0),
tx_bytes: AtomicU64::new(0),
})
}
pub fn handshake_sent(&self) {
self.send_keepalive.stop();
}
}
pub type Timers = Arc<TimersInner>;
pub struct Events();
impl Callbacks for Events {
type Opaque = Timers;
fn send(t: &Timers, size: usize, data: bool, sent: bool) {
t.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
}
fn recv(t: &Timers, size: usize, data: bool, sent: bool) {
t.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
}
fn need_key(t: &Timers) {
if !t.handshake_pending.swap(true, Ordering::SeqCst) {
t.handshake_attempts.store(0, Ordering::SeqCst);
t.new_handshake.reset(ZERO_DURATION);
}
}
}

View File

@@ -28,3 +28,9 @@ pub struct KeyPair {
pub send: Key, // key for outbound messages pub send: Key, // key for outbound messages
pub recv: Key, // key for inbound messages pub recv: Key, // key for inbound messages
} }
impl KeyPair {
pub fn local_id(&self) -> u32 {
self.recv.id
}
}

View File

@@ -1,5 +1,6 @@
use crate::handshake; use crate::handshake;
use crate::router; use crate::router;
use crate::timers::{Events, Timers};
use crate::types::{Bind, Endpoint, Tun}; use crate::types::{Bind, Endpoint, Tun};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
@@ -21,28 +22,19 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
#[derive(Clone)] type Peer<T: Tun, B: Bind> = Arc<PeerInner<T, B>>;
pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
pub struct PeerInner<T: Tun, B: Bind> { pub struct PeerInner<T: Tun, B: Bind> {
router: router::Peer<Events, T, B>, queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
timers: Timers, router: router::Peer<Events, T, B>, // router peer
rx: AtomicU64, timers: Option<Timers>, //
tx: AtomicU64,
} }
pub struct Timers {} impl<T: Tun, B: Bind> PeerInner<T, B> {
#[inline(always)]
pub struct Events(); fn timers(&self) -> &Timers {
self.timers.as_ref().unwrap()
impl router::Callbacks for Events { }
type Opaque = Timers;
fn send(t: &Timers, size: usize, data: bool, sent: bool) {}
fn recv(t: &Timers, size: usize, data: bool, sent: bool) {}
fn need_key(t: &Timers) {}
} }
struct Handshake { struct Handshake {
@@ -50,6 +42,11 @@ struct Handshake {
active: bool, active: bool,
} }
enum HandshakeJob<E> {
Message(Vec<u8>, E),
New(PublicKey),
}
struct WireguardInner<T: Tun, B: Bind> { struct WireguardInner<T: Tun, B: Bind> {
// identify and configuration map // identify and configuration map
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>, peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
@@ -61,7 +58,7 @@ struct WireguardInner<T: Tun, B: Bind> {
handshake: RwLock<Handshake>, handshake: RwLock<Handshake>,
under_load: AtomicBool, under_load: AtomicBool,
pending: AtomicUsize, // num of pending handshake packets in queue pending: AtomicUsize, // num of pending handshake packets in queue
queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>, queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
// IO // IO
bind: B, bind: B,
@@ -90,7 +87,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
fn new(tun: T, bind: B) -> Wireguard<T, B> { fn new(tun: T, bind: B) -> Wireguard<T, B> {
// create device state // create device state
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner { let wg = Arc::new(WireguardInner {
peers: RwLock::new(HashMap::new()), peers: RwLock::new(HashMap::new()),
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()), router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
@@ -114,50 +111,64 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue // process elements from the handshake queue
for (msg, src) in rx { for job in rx {
wg.pending.fetch_sub(1, Ordering::SeqCst); wg.pending.fetch_sub(1, Ordering::SeqCst);
// feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid
let state = wg.handshake.read(); let state = wg.handshake.read();
if !state.active { if !state.active {
continue; continue;
} }
// process message match job {
match state.device.process( HandshakeJob::Message(msg, src) => {
&mut rng, // feed message to handshake device
&msg[..], let src_validate = (&src).into_address(); // TODO avoid
if wg.under_load.load(Ordering::Relaxed) {
Some(&src_validate) // process message
} else { match state.device.process(
None &mut rng,
}, &msg[..],
) { if wg.under_load.load(Ordering::Relaxed) {
Ok((pk, msg, keypair)) => { Some(&src_validate)
// send response } else {
if let Some(msg) = msg { None
let _ = bind.send(&msg[..], &src).map_err(|e| { },
debug!( ) {
Ok((pk, msg, keypair)) => {
// send response
if let Some(msg) = msg {
let _ = bind.send(&msg[..], &src).map_err(|e| {
debug!(
"handshake worker, failed to send response, error = {:?}", "handshake worker, failed to send response, error = {:?}",
e e
) )
}); });
} }
// update timers // update timers
if let Some(pk) = pk { if let Some(pk) = pk {
// add keypair to peer and free any unused ids if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
if let Some(keypair) = keypair { // update endpoint (DISCUSS: right semantics?)
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { peer.router.set_endpoint(src_validate);
for id in peer.0.router.add_keypair(keypair) {
state.device.release(id); // add keypair to peer and free any unused ids
if let Some(keypair) = keypair {
for id in peer.router.add_keypair(keypair) {
state.device.release(id);
}
}
} }
} }
} }
Err(e) => debug!("handshake worker, error = {:?}", e),
}
}
HandshakeJob::New(pk) => {
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
peer.router.send(&msg[..]);
peer.timers().handshake_sent();
} }
} }
Err(e) => debug!("handshake worker, error = {:?}", e),
} }
} }
}); });
@@ -197,7 +208,10 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
wg.under_load.store(false, Ordering::SeqCst); wg.under_load.store(false, Ordering::SeqCst);
} }
wg.queue.lock().send((msg, src)).unwrap(); wg.queue
.lock()
.send(HandshakeJob::Message(msg, src))
.unwrap();
} }
router::TYPE_TRANSPORT => { router::TYPE_TRANSPORT => {
// transport message // transport message
@@ -223,7 +237,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
msg.truncate(size); msg.truncate(size);
// pad message to multiple of 16 // pad message to multiple of 16 bytes
while msg.len() < mtu && msg.len() % 16 != 0 { while msg.len() < mtu && msg.len() % 16 != 0 {
msg.push(0); msg.push(0);
} }