WIP: TUN IO worker

Also removed the type parameters from the handshake device.
This commit is contained in:
Mathias Hall-Andersen
2019-09-18 15:31:10 +02:00
parent dfe4a22920
commit 6311aa3402
6 changed files with 217 additions and 134 deletions

View File

@@ -21,11 +21,11 @@ use super::types::*;
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
pub struct Device<T> {
pub struct Device {
pub sk: StaticSecret, // static secret key
pub pk: PublicKey, // static public key
macs: macs::Validator, // validator for the mac fields
pk_map: HashMap<[u8; 32], Peer<T>>, // public key -> peer state
pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
limiter: Mutex<RateLimiter>,
}
@@ -33,16 +33,13 @@ pub struct Device<T> {
/* A mutable reference to the device needs to be held during configuration.
* Wrapping the device in a RwLock enables peer config after "configuration time"
*/
impl<T> Device<T>
where
T: Clone,
{
impl Device {
/// Initialize a new handshake state machine
///
/// # Arguments
///
/// * `sk` - x25519 scalar representing the local private key
pub fn new(sk: StaticSecret) -> Device<T> {
pub fn new(sk: StaticSecret) -> Device {
let pk = PublicKey::from(&sk);
Device {
pk,
@@ -54,6 +51,25 @@ where
}
}
/// Update the secret key of the device
///
/// # Arguments
///
/// * `sk` - x25519 scalar representing the local private key
pub fn set_sk(&mut self, sk: StaticSecret) {
// update secret and public key
let pk = PublicKey::from(&sk);
self.sk = sk;
self.pk = pk;
self.macs = macs::Validator::new(pk);
// recalculate the shared secrets for every peer
for &mut peer in self.pk_map.values_mut() {
peer.reset_state().map(|id| self.release(id));
peer.ss = self.sk.diffie_hellman(&peer.pk)
}
}
/// Add a new public key to the state machine
/// To remove public keys, you must create a new machine instance
///
@@ -61,7 +77,7 @@ where
///
/// * `pk` - The public key to add
/// * `identifier` - Associated identifier which can be used to distinguish the peers
pub fn add(&mut self, pk: PublicKey, identifier: T) -> Result<(), ConfigError> {
pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
// check that the pk is not added twice
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
return Err(ConfigError::new("Duplicate public key"));
@@ -80,10 +96,8 @@ where
}
// map the public key to the peer state
self.pk_map.insert(
*pk.as_bytes(),
Peer::new(identifier, pk, self.sk.diffie_hellman(&pk)),
);
self.pk_map
.insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk)));
Ok(())
}
@@ -204,7 +218,7 @@ where
rng: &mut R, // rng instance to sample randomness from
msg: &[u8], // message buffer
src: Option<&'a S>, // optional source endpoint, set when "under load"
) -> Result<Output<T>, HandshakeError>
) -> Result<Output, HandshakeError>
where
&'a S: Into<&'a SocketAddr>,
{
@@ -269,11 +283,7 @@ where
.generate(resp.noise.as_bytes(), &mut resp.macs);
// return unconfirmed keypair and the response as vector
Ok((
Some(peer.identifier.clone()),
Some(resp.as_bytes().to_owned()),
Some(keys),
))
Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys)))
}
TYPE_RESPONSE => {
let msg = Response::parse(msg)?;
@@ -328,7 +338,7 @@ where
// Internal function
//
// Return the peer associated with the public key
pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer<T>, HandshakeError> {
pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> {
self.pk_map
.get(pk.as_bytes())
.ok_or(HandshakeError::UnknownPublicKey)
@@ -337,7 +347,7 @@ where
// Internal function
//
// Return the peer currently associated with the receiver identifier
pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer<T>, HandshakeError> {
pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> {
let im = self.id_map.read();
let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
match self.pk_map.get(pk) {
@@ -349,7 +359,7 @@ where
// Internal function
//
// Allocated a new receiver identifier for the peer
fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer<T>) -> u32 {
fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 {
loop {
let id = rng.gen();
@@ -380,7 +390,7 @@ mod tests {
fn setup_devices<R: RngCore + CryptoRng>(
rng: &mut R,
) -> (PublicKey, Device<usize>, PublicKey, Device<usize>) {
) -> (PublicKey, Device, PublicKey, Device) {
// generate new keypairs
let sk1 = StaticSecret::new(rng);
@@ -399,8 +409,8 @@ mod tests {
let mut dev1 = Device::new(sk1);
let mut dev2 = Device::new(sk2);
dev1.add(pk2, 1337).unwrap();
dev2.add(pk1, 2600).unwrap();
dev1.add(pk2).unwrap();
dev2.add(pk1).unwrap();
dev1.set_psk(pk2, Some(psk)).unwrap();
dev2.set_psk(pk1, Some(psk)).unwrap();

View File

@@ -215,10 +215,10 @@ mod tests {
}
}
pub fn create_initiation<T: Clone, R: RngCore + CryptoRng>(
pub fn create_initiation<R: RngCore + CryptoRng>(
rng: &mut R,
device: &Device<T>,
peer: &Peer<T>,
device: &Device,
peer: &Peer,
sender: u32,
msg: &mut NoiseInitiation,
) -> Result<(), HandshakeError> {
@@ -296,10 +296,10 @@ pub fn create_initiation<T: Clone, R: RngCore + CryptoRng>(
})
}
pub fn consume_initiation<'a, T: Clone>(
device: &'a Device<T>,
pub fn consume_initiation<'a>(
device: &'a Device,
msg: &NoiseInitiation,
) -> Result<(&'a Peer<T>, TemporaryState), HandshakeError> {
) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
clear_stack_on_return(CLEAR_PAGES, || {
// initialize new state
@@ -370,9 +370,9 @@ pub fn consume_initiation<'a, T: Clone>(
})
}
pub fn create_response<T: Clone, R: RngCore + CryptoRng>(
pub fn create_response<R: RngCore + CryptoRng>(
rng: &mut R,
peer: &Peer<T>,
peer: &Peer,
sender: u32, // sending identifier
state: TemporaryState, // state from "consume_initiation"
msg: &mut NoiseResponse, // resulting response
@@ -456,10 +456,7 @@ pub fn create_response<T: Clone, R: RngCore + CryptoRng>(
* allow concurrent processing of potential responses to the initiation,
* in order to better mitigate DoS from malformed response messages.
*/
pub fn consume_response<T: Clone>(
device: &Device<T>,
msg: &NoiseResponse,
) -> Result<Output<T>, HandshakeError> {
pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
clear_stack_on_return(CLEAR_PAGES, || {
// retrieve peer and copy initiation state
let peer = device.lookup_id(msg.f_receiver.get())?;
@@ -530,7 +527,7 @@ pub fn consume_response<T: Clone>(
// return confirmed key-pair
Ok((
Some(peer.identifier.clone()),
Some(peer.pk),
None,
Some(KeyPair {
birth,

View File

@@ -1,5 +1,7 @@
use lazy_static::lazy_static;
use spin::Mutex;
use std::mem;
use std::time::{Duration, Instant};
use generic_array::typenum::U32;
@@ -24,10 +26,7 @@ lazy_static! {
*
* This type is only for internal use and not exposed.
*/
pub struct Peer<T> {
// external identifier
pub(crate) identifier: T,
pub struct Peer {
// mutable state
pub(crate) state: Mutex<State>,
pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>,
@@ -65,18 +64,13 @@ impl Drop for State {
}
}
impl<T> Peer<T>
where
T: Clone,
{
impl Peer {
pub fn new(
identifier: T, // external identifier
pk: PublicKey, // public key of peer
ss: SharedSecret, // precomputed DH(static, static)
) -> Self {
Self {
macs: Mutex::new(macs::Generator::new(pk)),
identifier: identifier,
state: Mutex::new(State::Reset),
timestamp: Mutex::new(None),
last_initiation_consumption: Mutex::new(None),
@@ -94,6 +88,13 @@ where
*self.state.lock() = state_new;
}
pub fn reset_state(&self) -> Option<u32> {
match mem::replace(&mut *self.state.lock(), State::Reset) {
State::InitiationSent { sender, .. } => Some(sender),
_ => None,
}
}
/// Set the mutable state of the peer conditioned on the timestamp being newer
///
/// # Arguments
@@ -102,7 +103,7 @@ where
/// * ts_new - The associated timestamp
pub fn check_replay_flood(
&self,
device: &Device<T>,
device: &Device,
timestamp_new: &timestamp::TAI64N,
) -> Result<(), HandshakeError> {
let mut state = self.state.lock();

View File

@@ -1,6 +1,8 @@
use std::error::Error;
use std::fmt;
use x25519_dalek::PublicKey;
use crate::types::KeyPair;
/* Internal types for the noise IKpsk2 implementation */
@@ -77,10 +79,10 @@ impl Error for HandshakeError {
}
}
pub type Output<T> = (
Option<T>, // external identifier associated with peer
Option<Vec<u8>>, // message to send
Option<KeyPair>, // resulting key-pair of successful handshake
pub type Output = (
Option<PublicKey>, // external identifier associated with peer
Option<Vec<u8>>, // message to send
Option<KeyPair>, // resulting key-pair of successful handshake
);
// preshared key

View File

@@ -121,7 +121,6 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
}
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> {
// allocate shared device state
let mut inner = DeviceInner {
@@ -149,6 +148,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
}
}
/// A new secret key has been set for the device.
/// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
pub fn new_sk(&self) {}
/// Adds a new peer to the device
///
/// # Returns

View File

@@ -2,17 +2,20 @@ use crate::handshake;
use crate::router;
use crate::types::{Bind, Endpoint, Tun};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use std::collections::HashMap;
use log::debug;
use rand::rngs::OsRng;
use spin::{Mutex, RwLock};
use byteorder::{ByteOrder, LittleEndian};
use crossbeam_channel::bounded;
use x25519_dalek::StaticSecret;
use crossbeam_channel::{bounded, Sender};
use x25519_dalek::{PublicKey, StaticSecret};
const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
@@ -22,8 +25,10 @@ const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
pub struct PeerInner<T: Tun, B: Bind> {
peer: router::Peer<Events, T, B>,
router: router::Peer<Events, T, B>,
timers: Timers,
rx: AtomicU64,
tx: AtomicU64,
}
pub struct Timers {}
@@ -40,39 +45,137 @@ impl router::Callbacks for Events {
fn need_key(t: &Timers) {}
}
struct Handshake {
device: handshake::Device,
active: bool,
}
struct WireguardInner<T: Tun, B: Bind> {
// identify and configuration map
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
// cryptkey routing
router: router::Device<Events, T, B>,
// handshake related state
handshake: RwLock<Handshake>,
under_load: AtomicBool,
pending: AtomicUsize, // num of pending handshake packets in queue
queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>,
// IO
bind: B,
}
pub struct Wireguard<T: Tun, B: Bind> {
router: Arc<router::Device<Events, T, B>>,
handshake: Option<Arc<handshake::Device<()>>>,
state: Arc<WireguardInner<T, B>>,
}
impl<T: Tun, B: Bind> Wireguard<T, B> {
fn start(&self) {}
fn set_key(&self, sk: Option<StaticSecret>) {
let mut handshake = self.state.handshake.write();
match sk {
None => {
let mut rng = OsRng::new().unwrap();
handshake.device.set_sk(StaticSecret::new(&mut rng));
handshake.active = false;
}
Some(sk) => {
handshake.device.set_sk(sk);
handshake.active = true;
}
}
}
fn new(tun: T, bind: B, sk: StaticSecret) -> Wireguard<T, B> {
let router = Arc::new(router::Device::new(
num_cpus::get(),
tun.clone(),
bind.clone(),
));
fn new(tun: T, bind: B) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
peers: RwLock::new(HashMap::new()),
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
pending: AtomicUsize::new(0),
handshake: RwLock::new(Handshake {
device: handshake::Device::new(StaticSecret::new(&mut rng)),
active: false,
}),
under_load: AtomicBool::new(false),
bind: bind.clone(),
queue: Mutex::new(tx),
});
let handshake_staged = Arc::new(AtomicUsize::new(0));
let handshake_device: Arc<handshake::Device<Peer<T, B>>> =
Arc::new(handshake::Device::new(sk));
// start handshake workers
for _ in 0..num_cpus::get() {
let wg = wg.clone();
let rx = rx.clone();
let bind = bind.clone();
thread::spawn(move || {
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue
for (msg, src) in rx {
wg.pending.fetch_sub(1, Ordering::SeqCst);
// feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid
let state = wg.handshake.read();
if !state.active {
continue;
}
// process message
match state.device.process(
&mut rng,
&msg[..],
if wg.under_load.load(Ordering::Relaxed) {
Some(&src_validate)
} else {
None
},
) {
Ok((pk, msg, keypair)) => {
// send response
if let Some(msg) = msg {
let _ = bind.send(&msg[..], &src).map_err(|e| {
debug!(
"handshake worker, failed to send response, error = {:?}",
e
)
});
}
// update timers
if let Some(pk) = pk {
// add keypair to peer and free any unused ids
if let Some(keypair) = keypair {
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
for id in peer.0.router.add_keypair(keypair) {
state.device.release(id);
}
}
}
}
}
Err(e) => debug!("handshake worker, error = {:?}", e),
}
}
});
}
// start UDP read IO thread
let (handshake_tx, handshake_rx) = bounded(128);
{
let wg = wg.clone();
let tun = tun.clone();
let bind = bind.clone();
thread::spawn(move || {
let mut under_load =
let mut last_under_load =
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
loop {
// read UDP packet into vector
let size = tun.mtu() + 148; // maximum message size
let mut msg: Vec<u8> =
Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0);
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
msg.truncate(size);
@@ -86,20 +189,22 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
handshake::TYPE_COOKIE_REPLY
| handshake::TYPE_INITIATION
| handshake::TYPE_RESPONSE => {
// detect if under load
if handshake_staged.fetch_add(1, Ordering::SeqCst)
> THRESHOLD_UNDER_LOAD
{
under_load = Instant::now()
// update under_load flag
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
last_under_load = Instant::now();
wg.under_load.store(true, Ordering::SeqCst);
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
wg.under_load.store(false, Ordering::SeqCst);
}
// pass source address along if under load
handshake_tx
.send((msg, src, under_load.elapsed() < DURATION_UNDER_LOAD))
.unwrap();
wg.queue.lock().send((msg, src)).unwrap();
}
router::TYPE_TRANSPORT => {
// transport message
// pad the message
let _ = wg.router.recv(src, msg);
}
_ => (),
}
@@ -107,62 +212,27 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
});
}
// start handshake workers
for _ in 0..num_cpus::get() {
let bind = bind.clone();
let handshake_rx = handshake_rx.clone();
let handshake_device = handshake_device.clone();
thread::spawn(move || {
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
// start TUN read IO thread
{
let wg = wg.clone();
thread::spawn(move || loop {
// read a new IP packet
let mtu = tun.mtu();
let size = mtu + 148;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
msg.truncate(size);
// process elements from the handshake queue
for (msg, src, under_load) in handshake_rx {
// feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid
match handshake_device.process(
&mut rng,
&msg[..],
if under_load {
Some(&src_validate)
} else {
None
},
) {
Ok((identity, msg, keypair)) => {
// send response
if let Some(msg) = msg {
let _ = bind.send(&msg[..], &src).map_err(|e| {
debug!(
"handshake worker, failed to send response, error = {:?}",
e
)
});
}
// update timers
if let Some(identity) = identity {
// add keypair to peer and free any unused ids
if let Some(keypair) = keypair {
for id in identity.0.peer.add_keypair(keypair) {
handshake_device.release(id);
}
}
}
}
Err(e) => debug!("handshake worker, error = {:?}", e),
}
// pad message to multiple of 16
while msg.len() < mtu && msg.len() % 16 != 0 {
msg.push(0);
}
// crypt-key route
let _ = wg.router.send(msg);
});
}
// start TUN read IO thread
thread::spawn(move || {});
Wireguard {
router,
handshake: None,
}
Wireguard { state: wg }
}
}