Remove peer from cryptkey router on drop

This commit is contained in:
Mathias Hall-Andersen
2019-08-17 16:31:08 +02:00
parent 5aeea9b619
commit 78ab1a93e6
9 changed files with 242 additions and 100 deletions

11
src/constants.rs Normal file
View File

@@ -0,0 +1,11 @@
use std::time::Duration;
use std::u64;
pub const REKEY_AFTER_MESSAGES: u64 = u64::MAX - (1 << 16);
pub const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4);
pub const REKEY_AFTER_TIME: Duration = Duration::from_secs(120);
pub const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
pub const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);

View File

@@ -1,7 +1,7 @@
#![feature(test)] #![feature(test)]
mod constants;
mod handshake; mod handshake;
mod platform;
mod router; mod router;
mod types; mod types;

View File

@@ -1,10 +0,0 @@
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
pub trait Tun: Send + Sync {
type Error;
fn new(mtu: Arc<AtomicUsize>) -> Self;
fn read(&self, dst: &mut [u8]) -> Result<usize, Self::Error>;
fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
}

View File

@@ -1,11 +0,0 @@
/* Often times an a file descriptor in an atomic might suffice.
*/
pub trait Bind<Endpoint>: Send + Sync {
type Error;
fn new() -> Self;
fn set_port(&self, port: u16) -> Result<(), Self::Error>;
fn get_port(&self) -> u16;
fn recv(&self, dst: &mut [u8]) -> Endpoint;
fn send(&self, src: &[u8], dst: &Endpoint);
}

View File

@@ -1,37 +1,38 @@
use arraydeque::{ArrayDeque, Wrapping}; use arraydeque::{ArrayDeque, Wrapping};
use treebitmap::address::Address;
use treebitmap::IpLookupTable; use treebitmap::IpLookupTable;
use crossbeam_deque::{Injector, Steal}; use crossbeam_deque::{Injector, Steal};
use std::collections::HashMap; use std::collections::HashMap;
use std::mem;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::mpsc::SyncSender;
use std::sync::{Arc, Mutex, Weak}; use std::sync::{Arc, Mutex, Weak};
use std::thread; use std::thread;
use std::time::{Duration, Instant}; use std::time::Instant;
use spin; use spin;
use super::super::constants::*;
use super::super::types::KeyPair; use super::super::types::KeyPair;
use super::anti_replay::AntiReplay; use super::anti_replay::AntiReplay;
use std::u64; use std::u64;
const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4);
const MAX_STAGED_PACKETS: usize = 128; const MAX_STAGED_PACKETS: usize = 128;
struct DeviceInner { struct DeviceInner {
stopped: AtomicBool, stopped: AtomicBool,
injector: Injector<()>, // parallel enc/dec task injector injector: Injector<()>, // parallel enc/dec task injector
threads: Vec<thread::JoinHandle<()>>, threads: Vec<thread::JoinHandle<()>>, // join handles of worker threads
recv: spin::RwLock<HashMap<u32, DecryptionState>>, recv: spin::RwLock<HashMap<u32, DecryptionState>>, // receiver id -> decryption state
ipv4: IpLookupTable<Ipv4Addr, Weak<PeerInner>>, ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner>>>, // ipv4 cryptkey routing
ipv6: IpLookupTable<Ipv6Addr, Weak<PeerInner>>, ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner>>>, // ipv6 cryptkey routing
} }
struct PeerInner { struct PeerInner {
stopped: AtomicBool, stopped: AtomicBool,
device: Arc<DeviceInner>,
thread_outbound: spin::Mutex<thread::JoinHandle<()>>, thread_outbound: spin::Mutex<thread::JoinHandle<()>>,
thread_inbound: spin::Mutex<thread::JoinHandle<()>>, thread_inbound: spin::Mutex<thread::JoinHandle<()>>,
inorder_outbound: SyncSender<()>, inorder_outbound: SyncSender<()>,
@@ -40,7 +41,7 @@ struct PeerInner {
rx_bytes: AtomicU64, // received bytes rx_bytes: AtomicU64, // received bytes
tx_bytes: AtomicU64, // transmitted bytes tx_bytes: AtomicU64, // transmitted bytes
keys: spin::Mutex<KeyWheel>, // key-wheel keys: spin::Mutex<KeyWheel>, // key-wheel
ekey: spin::Mutex<EncryptionState>, // encryption state ekey: spin::Mutex<Option<EncryptionState>>, // encryption state
endpoint: spin::Mutex<Option<Arc<SocketAddr>>>, endpoint: spin::Mutex<Option<Arc<SocketAddr>>>,
} }
@@ -68,26 +69,104 @@ struct KeyWheel {
pub struct Peer(Arc<PeerInner>); pub struct Peer(Arc<PeerInner>);
pub struct Device(DeviceInner); pub struct Device(DeviceInner);
fn treebit_list<A, R>(
peer: &Peer,
table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner>>>,
callback: Box<dyn Fn(A, u32) -> R>,
) -> Vec<R>
where
A: Address,
{
let mut res = Vec::new();
for subnet in table.read().iter() {
let (ip, masklen, p) = subnet;
if let Some(p) = p.upgrade() {
if Arc::ptr_eq(&p, &peer.0) {
res.push(callback(ip, masklen))
}
}
}
res
}
fn treebit_remove<A>(peer: &Peer, table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner>>>)
where
A: Address,
{
let mut m = table.write();
// collect keys for value
let mut subnets = vec![];
for subnet in m.iter() {
let (ip, masklen, p) = subnet;
if let Some(p) = p.upgrade() {
if Arc::ptr_eq(&p, &peer.0) {
subnets.push((ip, masklen))
}
}
}
// remove all key mappings
for subnet in subnets {
let r = m.remove(subnet.0, subnet.1);
debug_assert!(r.is_some());
}
}
impl Drop for Peer { impl Drop for Peer {
fn drop(&mut self) { fn drop(&mut self) {
// mark peer as stopped // mark peer as stopped
let inner = &self.0; let peer = &self.0;
inner.stopped.store(true, Ordering::SeqCst); peer.stopped.store(true, Ordering::SeqCst);
// unpark threads to stop // remove from cryptkey router
inner.thread_inbound.lock().thread().unpark(); treebit_remove(self, &peer.device.ipv4);
inner.thread_outbound.lock().thread().unpark(); treebit_remove(self, &peer.device.ipv6);
// unpark threads
peer.thread_inbound.lock().thread().unpark();
peer.thread_outbound.lock().thread().unpark();
// collect ids to release
let mut keys = peer.keys.lock();
let mut release = Vec::with_capacity(3);
keys.next.map(|k| release.push(k.recv.id));
keys.current.map(|k| release.push(k.recv.id));
keys.previous.map(|k| release.push(k.recv.id));
// remove from receive id map
if release.len() > 0 {
let mut recv = peer.device.recv.write();
for id in &release {
recv.remove(id);
}
}
// null key-material (TODO: extend)
keys.next = None;
keys.current = None;
keys.previous = None;
*peer.ekey.lock() = None;
*peer.endpoint.lock() = None;
} }
} }
impl Drop for Device { impl Drop for Device {
fn drop(&mut self) { fn drop(&mut self) {
// mark device as stopped // mark device as stopped
let inner = &self.0; let device = &self.0;
inner.stopped.store(true, Ordering::SeqCst); device.stopped.store(true, Ordering::SeqCst);
// eat all parallel jobs // eat all parallel jobs
while inner.injector.steal() != Steal::Empty {} while device.injector.steal() != Steal::Empty {}
// unpark all threads
for handle in &device.threads {
handle.thread().unpark();
}
} }
} }
@@ -97,12 +176,12 @@ impl Peer {
} }
pub fn keypair_confirm(&self, ks: Arc<KeyPair>) { pub fn keypair_confirm(&self, ks: Arc<KeyPair>) {
*self.0.ekey.lock() = EncryptionState { *self.0.ekey.lock() = Some(EncryptionState {
id: ks.send.id, id: ks.send.id,
key: ks.send.key, key: ks.send.key,
nonce: 0, nonce: 0,
death: ks.birth + Duration::from_millis(1337), // todo death: ks.birth + REJECT_AFTER_TIME,
}; });
} }
fn keypair_add(&self, new: KeyPair) -> Option<u32> { fn keypair_add(&self, new: KeyPair) -> Option<u32> {
@@ -112,12 +191,12 @@ impl Peer {
// update key-wheel // update key-wheel
if new.confirmed { if new.confirmed {
// start using key for encryption // start using key for encryption
*self.0.ekey.lock() = EncryptionState { *self.0.ekey.lock() = Some(EncryptionState {
id: new.send.id, id: new.send.id,
key: new.send.key, key: new.send.key,
nonce: 0, nonce: 0,
death: new.birth + Duration::from_millis(1337), // todo death: new.birth + REJECT_AFTER_TIME,
}; });
// move current into previous // move current into previous
keys.previous = keys.current; keys.previous = keys.current;
@@ -148,42 +227,39 @@ impl Device {
stopped: AtomicBool::new(false), stopped: AtomicBool::new(false),
injector: Injector::new(), injector: Injector::new(),
recv: spin::RwLock::new(HashMap::new()), recv: spin::RwLock::new(HashMap::new()),
ipv4: IpLookupTable::new(), ipv4: spin::RwLock::new(IpLookupTable::new()),
ipv6: IpLookupTable::new(), ipv6: spin::RwLock::new(IpLookupTable::new()),
}) })
} }
pub fn add_subnet(&mut self, ip: IpAddr, masklen: u32, peer: Peer) { pub fn add_subnet(&mut self, ip: IpAddr, masklen: u32, peer: Peer) {
match ip { match ip {
IpAddr::V4(v4) => self.0.ipv4.insert(v4, masklen, Arc::downgrade(&peer.0)), IpAddr::V4(v4) => self
IpAddr::V6(v6) => self.0.ipv6.insert(v6, masklen, Arc::downgrade(&peer.0)), .0
.ipv4
.write()
.insert(v4, masklen, Arc::downgrade(&peer.0)),
IpAddr::V6(v6) => self
.0
.ipv6
.write()
.insert(v6, masklen, Arc::downgrade(&peer.0)),
}; };
} }
pub fn subnets(&self, peer: Peer) -> Vec<(IpAddr, u32)> { pub fn list_subnets(&self, peer: Peer) -> Vec<(IpAddr, u32)> {
let mut subnets = Vec::new(); let mut res = Vec::new();
res.append(&mut treebit_list(
// extract ipv4 entries &peer,
for subnet in self.0.ipv4.iter() { &self.0.ipv4,
let (ip, masklen, p) = subnet; Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)),
if let Some(p) = p.upgrade() { ));
if Arc::ptr_eq(&p, &peer.0) { res.append(&mut treebit_list(
subnets.push((IpAddr::V4(ip), masklen)) &peer,
} &self.0.ipv6,
} Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)),
} ));
res
// extract ipv6 entries
for subnet in self.0.ipv6.iter() {
let (ip, masklen, p) = subnet;
if let Some(p) = p.upgrade() {
if Arc::ptr_eq(&p, &peer.0) {
subnets.push((IpAddr::V6(ip), masklen))
}
}
}
subnets
} }
pub fn keypair_add(&self, peer: Peer, new: KeyPair) -> Option<u32> { pub fn keypair_add(&self, peer: Peer, new: KeyPair) -> Option<u32> {
@@ -208,7 +284,7 @@ impl Device {
key: new.recv.key, key: new.recv.key,
protector: Arc::new(spin::Mutex::new(AntiReplay::new())), protector: Arc::new(spin::Mutex::new(AntiReplay::new())),
peer: Arc::downgrade(&peer.0), peer: Arc::downgrade(&peer.0),
death: new.birth + Duration::from_millis(2600), // todo death: new.birth + REJECT_AFTER_TIME,
}, },
); );

26
src/types/keys.rs Normal file
View File

@@ -0,0 +1,26 @@
use std::time::Instant;
/* This file holds types passed between components.
* Whenever a type cannot be held local to a single module.
*/
#[derive(Debug, Clone, Copy)]
pub struct Key {
pub key: [u8; 32],
pub id: u32,
}
#[cfg(test)]
impl PartialEq for Key {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.key[..] == other.key[..]
}
}
#[derive(Debug, Clone, Copy)]
pub struct KeyPair {
pub birth: Instant, // when was the key-pair created
pub confirmed: bool, // has the key-pair been confirmed?
pub send: Key, // key for outbound messages
pub recv: Key, // key for inbound messages
}

View File

@@ -1,26 +1,7 @@
use std::time::Instant; mod keys;
mod tun;
mod udp;
/* This file holds types passed between components. pub use keys::{Key, KeyPair};
* Whenever a type cannot be held local to a single module. pub use tun::Tun;
*/ pub use udp::Bind;
#[derive(Debug, Clone, Copy)]
pub struct Key {
pub key: [u8; 32],
pub id: u32,
}
#[cfg(test)]
impl PartialEq for Key {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.key[..] == other.key[..]
}
}
#[derive(Debug, Clone, Copy)]
pub struct KeyPair {
pub birth: Instant, // when was the key-pair created
pub confirmed: bool, // has the key-pair been confirmed?
pub send: Key, // key for outbound messages
pub recv: Key, // key for inbound messages
}

43
src/types/tun.rs Normal file
View File

@@ -0,0 +1,43 @@
use std::error;
pub trait Tun: Send + Sync {
type Error: error::Error;
/// Returns the MTU of the device
///
/// This function needs to be efficient (called for every read).
/// The goto implementation stragtegy is to .load an atomic variable,
/// then use e.g. netlink to update the variable in a seperate thread.
///
/// # Returns
///
/// The MTU of the interface in bytes
fn mtu(&self) -> usize;
/// Reads an IP packet into dst[offset:] from the tunnel device
///
/// The reason for providing space for a prefix
/// is to efficiently accommodate platforms on which the packet is prefaced by a header.
/// This space is later used to construct the transport message inplace.
///
/// # Arguments
///
/// - dst: Destination buffer (enough space for MTU bytes + header)
/// - offset: Offset for the beginning of the IP packet
///
/// # Returns
///
/// The size of the IP packet (ignoring the header) or an std::error::Error instance:
fn read(&self, dst: &mut [u8], offset: usize) -> Result<usize, Self::Error>;
/// Writes an IP packet to the tunnel device
///
/// # Arguments
///
/// - src: Buffer containing the IP packet to be written
///
/// # Returns
///
/// Unit type or an error
fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
}

26
src/types/udp.rs Normal file
View File

@@ -0,0 +1,26 @@
use std::error;
/* Often times an a file descriptor in an atomic might suffice.
*/
pub trait Bind<Endpoint>: Send + Sync {
type Error : error::Error;
fn new() -> Self;
/// Updates the port of the Bind
///
/// # Arguments
///
/// - port, The new port to bind to. 0 means any available port.
///
/// # Returns
///
/// The unit type or an error, if binding fails
fn set_port(&self, port: u16) -> Result<(), Self::Error>;
/// Returns the current port of the bind
fn get_port(&self) -> u16;
fn recv(&self, dst: &mut [u8]) -> Endpoint;
fn send(&self, src: &[u8], dst: &Endpoint);
}