Work on Linux platform code
This commit is contained in:
186
src/wireguard/config.rs
Normal file
186
src/wireguard/config.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
use super::wireguard::Wireguard;
|
||||
use super::types::bind::Bind;
|
||||
use super::types::tun::Tun;
|
||||
|
||||
/// The goal of the configuration interface is, among others,
|
||||
/// to hide the IO implementations (over which the WG device is generic),
|
||||
/// from the configuration and UAPI code.
|
||||
|
||||
/// Describes a snapshot of the state of a peer
|
||||
pub struct PeerState {
|
||||
rx_bytes: u64,
|
||||
tx_bytes: u64,
|
||||
last_handshake_time_sec: u64,
|
||||
last_handshake_time_nsec: u64,
|
||||
public_key: PublicKey,
|
||||
allowed_ips: Vec<(IpAddr, u32)>,
|
||||
}
|
||||
|
||||
pub enum ConfigError {
|
||||
NoSuchPeer
|
||||
}
|
||||
|
||||
impl ConfigError {
|
||||
|
||||
fn errno(&self) -> i32 {
|
||||
match self {
|
||||
NoSuchPeer => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Exposed configuration interface
|
||||
pub trait Configuration {
|
||||
/// Updates the private key of the device
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `sk`: The new private key (or None, if the private key should be cleared)
|
||||
fn set_private_key(&self, sk: Option<StaticSecret>);
|
||||
|
||||
/// Returns the private key of the device
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The private if set, otherwise None.
|
||||
fn get_private_key(&self) -> Option<StaticSecret>;
|
||||
|
||||
/// Returns the protocol version of the device
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An integer indicating the protocol version
|
||||
fn get_protocol_version(&self) -> usize;
|
||||
|
||||
fn set_listen_port(&self, port: u16) -> Option<ConfigError>;
|
||||
|
||||
/// Set the firewall mark (or similar, depending on platform)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `mark`: The fwmark value
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An error if this operation is not supported by the underlying
|
||||
/// "bind" implementation.
|
||||
fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError>;
|
||||
|
||||
/// Removes all peers from the device
|
||||
fn replace_peers(&self);
|
||||
|
||||
/// Remove the peer from the
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer`: The public key of the peer to remove
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// If the peer does not exists this operation is a noop
|
||||
fn remove_peer(&self, peer: PublicKey);
|
||||
|
||||
/// Adds a new peer to the device
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer`: The public key of the peer to add
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A bool indicating if the peer was added.
|
||||
///
|
||||
/// If the peer already exists this operation is a noop
|
||||
fn add_peer(&self, peer: PublicKey) -> bool;
|
||||
|
||||
/// Update the psk of a peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer`: The public key of the peer
|
||||
/// - `psk`: The new psk or None if the psk should be unset
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An error if no such peer exists
|
||||
fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError>;
|
||||
|
||||
/// Update the endpoint of the
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer': The public key of the peer
|
||||
/// - `psk`
|
||||
fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option<ConfigError>;
|
||||
|
||||
/// Update the endpoint of the
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer': The public key of the peer
|
||||
/// - `psk`
|
||||
fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option<ConfigError>;
|
||||
|
||||
/// Remove all allowed IPs from the peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer': The public key of the peer
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An error if no such peer exists
|
||||
fn replace_allowed_ips(&self, peer: PublicKey) -> Option<ConfigError>;
|
||||
|
||||
/// Add a new allowed subnet to the peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer`: The public key of the peer
|
||||
/// - `ip`: Subnet mask
|
||||
/// - `masklen`:
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An error if the peer does not exist
|
||||
///
|
||||
/// # Note:
|
||||
///
|
||||
/// The API must itself sanitize the (ip, masklen) set:
|
||||
/// The ip should be masked to remove any set bits right of the first "masklen" bits.
|
||||
fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError>;
|
||||
|
||||
/// Returns the state of all peers
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A list of structures describing the state of each peer
|
||||
fn get_peers(&self) -> Vec<PeerState>;
|
||||
}
|
||||
|
||||
impl <T : Tun, B : Bind>Configuration for Wireguard<T, B> {
|
||||
|
||||
fn set_private_key(&self, sk : Option<StaticSecret>) {
|
||||
self.set_key(sk)
|
||||
}
|
||||
|
||||
fn get_private_key(&self) -> Option<StaticSecret> {
|
||||
self.get_sk()
|
||||
}
|
||||
|
||||
fn get_protocol_version(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn set_listen_port(&self, port : u16) -> Option<ConfigError> {
|
||||
None
|
||||
}
|
||||
|
||||
fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError> {
|
||||
None
|
||||
}
|
||||
|
||||
}
|
||||
20
src/wireguard/constants.rs
Normal file
20
src/wireguard/constants.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
use std::time::Duration;
|
||||
use std::u64;
|
||||
|
||||
pub const REKEY_AFTER_MESSAGES: u64 = u64::MAX - (1 << 16);
|
||||
pub const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4);
|
||||
|
||||
pub const REKEY_AFTER_TIME: Duration = Duration::from_secs(120);
|
||||
pub const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
|
||||
pub const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
|
||||
pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
pub const MAX_TIMER_HANDSHAKES: usize = 18;
|
||||
|
||||
pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200);
|
||||
pub const TIMERS_TICK: Duration = Duration::from_millis(100);
|
||||
pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize;
|
||||
pub const TIMERS_CAPACITY: usize = 1024;
|
||||
|
||||
pub const MESSAGE_PADDING_MULTIPLE: usize = 16;
|
||||
574
src/wireguard/handshake/device.rs
Normal file
574
src/wireguard/handshake/device.rs
Normal file
@@ -0,0 +1,574 @@
|
||||
use spin::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Mutex;
|
||||
use zerocopy::AsBytes;
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
use rand::prelude::*;
|
||||
|
||||
use x25519_dalek::PublicKey;
|
||||
use x25519_dalek::StaticSecret;
|
||||
|
||||
use super::macs;
|
||||
use super::messages::{CookieReply, Initiation, Response};
|
||||
use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
|
||||
use super::noise;
|
||||
use super::peer::Peer;
|
||||
use super::ratelimiter::RateLimiter;
|
||||
use super::types::*;
|
||||
|
||||
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
|
||||
|
||||
pub struct Device {
|
||||
pub sk: StaticSecret, // static secret key
|
||||
pub pk: PublicKey, // static public key
|
||||
macs: macs::Validator, // validator for the mac fields
|
||||
pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
|
||||
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
|
||||
limiter: Mutex<RateLimiter>,
|
||||
}
|
||||
|
||||
/* A mutable reference to the device needs to be held during configuration.
|
||||
* Wrapping the device in a RwLock enables peer config after "configuration time"
|
||||
*/
|
||||
impl Device {
|
||||
/// Initialize a new handshake state machine
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `sk` - x25519 scalar representing the local private key
|
||||
pub fn new(sk: StaticSecret) -> Device {
|
||||
let pk = PublicKey::from(&sk);
|
||||
Device {
|
||||
pk,
|
||||
sk,
|
||||
macs: macs::Validator::new(pk),
|
||||
pk_map: HashMap::new(),
|
||||
id_map: RwLock::new(HashMap::new()),
|
||||
limiter: Mutex::new(RateLimiter::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the secret key of the device
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `sk` - x25519 scalar representing the local private key
|
||||
pub fn set_sk(&mut self, sk: StaticSecret) {
|
||||
// update secret and public key
|
||||
let pk = PublicKey::from(&sk);
|
||||
self.sk = sk;
|
||||
self.pk = pk;
|
||||
self.macs = macs::Validator::new(pk);
|
||||
|
||||
// recalculate the shared secrets for every peer
|
||||
let mut ids = vec![];
|
||||
for mut peer in self.pk_map.values_mut() {
|
||||
peer.reset_state().map(|id| ids.push(id));
|
||||
peer.ss = self.sk.diffie_hellman(&peer.pk)
|
||||
}
|
||||
|
||||
// release ids from aborted handshakes
|
||||
for id in ids {
|
||||
self.release(id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the secret key of the device
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A secret key (x25519 scalar)
|
||||
pub fn get_sk(&self) -> StaticSecret {
|
||||
StaticSecret::from(self.sk.to_bytes())
|
||||
}
|
||||
|
||||
/// Add a new public key to the state machine
|
||||
/// To remove public keys, you must create a new machine instance
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pk` - The public key to add
|
||||
/// * `identifier` - Associated identifier which can be used to distinguish the peers
|
||||
pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
|
||||
// check that the pk is not added twice
|
||||
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
|
||||
return Err(ConfigError::new("Duplicate public key"));
|
||||
};
|
||||
|
||||
// check that the pk is not that of the device
|
||||
if *self.pk.as_bytes() == *pk.as_bytes() {
|
||||
return Err(ConfigError::new(
|
||||
"Public key corresponds to secret key of interface",
|
||||
));
|
||||
}
|
||||
|
||||
// ensure less than 2^20 peers
|
||||
if self.pk_map.len() > MAX_PEER_PER_DEVICE {
|
||||
return Err(ConfigError::new("Too many peers for device"));
|
||||
}
|
||||
|
||||
// map the public key to the peer state
|
||||
self.pk_map
|
||||
.insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a peer by public key
|
||||
/// To remove public keys, you must create a new machine instance
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pk` - The public key of the peer to remove
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The call might fail if the public key is not found
|
||||
pub fn remove(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
|
||||
// take write-lock on receive id table
|
||||
let mut id_map = self.id_map.write();
|
||||
|
||||
// remove the peer
|
||||
self.pk_map
|
||||
.remove(pk.as_bytes())
|
||||
.ok_or(ConfigError::new("Public key not in device"))?;
|
||||
|
||||
// pruge the id map (linear scan)
|
||||
id_map.retain(|_, v| v != pk.as_bytes());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a psk to the peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pk` - The public key of the peer
|
||||
/// * `psk` - The psk to set / unset
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The call might fail if the public key is not found
|
||||
pub fn set_psk(&mut self, pk: PublicKey, psk: Option<Psk>) -> Result<(), ConfigError> {
|
||||
match self.pk_map.get_mut(pk.as_bytes()) {
|
||||
Some(mut peer) => {
|
||||
peer.psk = match psk {
|
||||
Some(v) => v,
|
||||
None => [0u8; 32],
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(ConfigError::new("No such public key")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the psk for the peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pk` - The public key of the peer
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A 32 byte array holding the PSK
|
||||
///
|
||||
/// The call might fail if the public key is not found
|
||||
pub fn get_psk(&self, pk: PublicKey) -> Result<Psk, ConfigError> {
|
||||
match self.pk_map.get(pk.as_bytes()) {
|
||||
Some(peer) => Ok(peer.psk),
|
||||
_ => Err(ConfigError::new("No such public key")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Release an id back to the pool
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` - The (sender) id to release
|
||||
pub fn release(&self, id: u32) {
|
||||
let mut m = self.id_map.write();
|
||||
debug_assert!(m.contains_key(&id), "Releasing id not allocated");
|
||||
m.remove(&id);
|
||||
}
|
||||
|
||||
/// Begin a new handshake
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pk` - Public key of peer to initiate handshake for
|
||||
pub fn begin<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
rng: &mut R,
|
||||
pk: &PublicKey,
|
||||
) -> Result<Vec<u8>, HandshakeError> {
|
||||
match self.pk_map.get(pk.as_bytes()) {
|
||||
None => Err(HandshakeError::UnknownPublicKey),
|
||||
Some(peer) => {
|
||||
let sender = self.allocate(rng, peer);
|
||||
|
||||
let mut msg = Initiation::default();
|
||||
|
||||
noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?;
|
||||
|
||||
// add macs to initation
|
||||
|
||||
peer.macs
|
||||
.lock()
|
||||
.generate(msg.noise.as_bytes(), &mut msg.macs);
|
||||
|
||||
Ok(msg.as_bytes().to_owned())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a handshake message.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `msg` - Byte slice containing the message (untrusted input)
|
||||
pub fn process<'a, R: RngCore + CryptoRng, S>(
|
||||
&self,
|
||||
rng: &mut R, // rng instance to sample randomness from
|
||||
msg: &[u8], // message buffer
|
||||
src: Option<&'a S>, // optional source endpoint, set when "under load"
|
||||
) -> Result<Output, HandshakeError>
|
||||
where
|
||||
&'a S: Into<&'a SocketAddr>,
|
||||
{
|
||||
// ensure type read in-range
|
||||
if msg.len() < 4 {
|
||||
return Err(HandshakeError::InvalidMessageFormat);
|
||||
}
|
||||
|
||||
// de-multiplex the message type field
|
||||
match LittleEndian::read_u32(msg) {
|
||||
TYPE_INITIATION => {
|
||||
// parse message
|
||||
let msg = Initiation::parse(msg)?;
|
||||
|
||||
// check mac1 field
|
||||
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
||||
|
||||
// address validation & DoS mitigation
|
||||
if let Some(src) = src {
|
||||
// obtain ref to socket addr
|
||||
let src = src.into();
|
||||
|
||||
// check mac2 field
|
||||
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
||||
let mut reply = Default::default();
|
||||
self.macs.create_cookie_reply(
|
||||
rng,
|
||||
msg.noise.f_sender.get(),
|
||||
src,
|
||||
&msg.macs,
|
||||
&mut reply,
|
||||
);
|
||||
return Ok((None, Some(reply.as_bytes().to_owned()), None));
|
||||
}
|
||||
|
||||
// check ratelimiter
|
||||
if !self.limiter.lock().unwrap().allow(&src.ip()) {
|
||||
return Err(HandshakeError::RateLimited);
|
||||
}
|
||||
}
|
||||
|
||||
// consume the initiation
|
||||
let (peer, st) = noise::consume_initiation(self, &msg.noise)?;
|
||||
|
||||
// allocate new index for response
|
||||
let sender = self.allocate(rng, peer);
|
||||
|
||||
// prepare memory for response, TODO: take slice for zero allocation
|
||||
let mut resp = Response::default();
|
||||
|
||||
// create response (release id on error)
|
||||
let keys = noise::create_response(rng, peer, sender, st, &mut resp.noise).map_err(
|
||||
|e| {
|
||||
self.release(sender);
|
||||
e
|
||||
},
|
||||
)?;
|
||||
|
||||
// add macs to response
|
||||
peer.macs
|
||||
.lock()
|
||||
.generate(resp.noise.as_bytes(), &mut resp.macs);
|
||||
|
||||
// return unconfirmed keypair and the response as vector
|
||||
Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys)))
|
||||
}
|
||||
TYPE_RESPONSE => {
|
||||
let msg = Response::parse(msg)?;
|
||||
|
||||
// check mac1 field
|
||||
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
||||
|
||||
// address validation & DoS mitigation
|
||||
if let Some(src) = src {
|
||||
// obtain ref to socket addr
|
||||
let src = src.into();
|
||||
|
||||
// check mac2 field
|
||||
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
||||
let mut reply = Default::default();
|
||||
self.macs.create_cookie_reply(
|
||||
rng,
|
||||
msg.noise.f_sender.get(),
|
||||
src,
|
||||
&msg.macs,
|
||||
&mut reply,
|
||||
);
|
||||
return Ok((None, Some(reply.as_bytes().to_owned()), None));
|
||||
}
|
||||
|
||||
// check ratelimiter
|
||||
if !self.limiter.lock().unwrap().allow(&src.ip()) {
|
||||
return Err(HandshakeError::RateLimited);
|
||||
}
|
||||
}
|
||||
|
||||
// consume inner playload
|
||||
noise::consume_response(self, &msg.noise)
|
||||
}
|
||||
TYPE_COOKIE_REPLY => {
|
||||
let msg = CookieReply::parse(msg)?;
|
||||
|
||||
// lookup peer
|
||||
let peer = self.lookup_id(msg.f_receiver.get())?;
|
||||
|
||||
// validate cookie reply
|
||||
peer.macs.lock().process(&msg)?;
|
||||
|
||||
// this prompts no new message and
|
||||
// DOES NOT cryptographically verify the peer
|
||||
Ok((None, None, None))
|
||||
}
|
||||
_ => Err(HandshakeError::InvalidMessageFormat),
|
||||
}
|
||||
}
|
||||
|
||||
// Internal function
|
||||
//
|
||||
// Return the peer associated with the public key
|
||||
pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> {
|
||||
self.pk_map
|
||||
.get(pk.as_bytes())
|
||||
.ok_or(HandshakeError::UnknownPublicKey)
|
||||
}
|
||||
|
||||
// Internal function
|
||||
//
|
||||
// Return the peer currently associated with the receiver identifier
|
||||
pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> {
|
||||
let im = self.id_map.read();
|
||||
let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
|
||||
match self.pk_map.get(pk) {
|
||||
Some(peer) => Ok(peer),
|
||||
_ => unreachable!(), // if the id-lookup succeeded, the peer should exist
|
||||
}
|
||||
}
|
||||
|
||||
// Internal function
|
||||
//
|
||||
// Allocated a new receiver identifier for the peer
|
||||
fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 {
|
||||
loop {
|
||||
let id = rng.gen();
|
||||
|
||||
// check membership with read lock
|
||||
if self.id_map.read().contains_key(&id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// take write lock and add index
|
||||
let mut m = self.id_map.write();
|
||||
if !m.contains_key(&id) {
|
||||
m.insert(id, *peer.pk.as_bytes());
|
||||
return id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::messages::*;
|
||||
use super::*;
|
||||
use hex;
|
||||
use rand::rngs::OsRng;
|
||||
use std::net::SocketAddr;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
fn setup_devices<R: RngCore + CryptoRng>(
|
||||
rng: &mut R,
|
||||
) -> (PublicKey, Device, PublicKey, Device) {
|
||||
// generate new keypairs
|
||||
|
||||
let sk1 = StaticSecret::new(rng);
|
||||
let pk1 = PublicKey::from(&sk1);
|
||||
|
||||
let sk2 = StaticSecret::new(rng);
|
||||
let pk2 = PublicKey::from(&sk2);
|
||||
|
||||
// pick random psk
|
||||
|
||||
let mut psk = [0u8; 32];
|
||||
rng.fill_bytes(&mut psk[..]);
|
||||
|
||||
// intialize devices on both ends
|
||||
|
||||
let mut dev1 = Device::new(sk1);
|
||||
let mut dev2 = Device::new(sk2);
|
||||
|
||||
dev1.add(pk2).unwrap();
|
||||
dev2.add(pk1).unwrap();
|
||||
|
||||
dev1.set_psk(pk2, Some(psk)).unwrap();
|
||||
dev2.set_psk(pk1, Some(psk)).unwrap();
|
||||
|
||||
(pk1, dev1, pk2, dev2)
|
||||
}
|
||||
|
||||
/* Test longest possible handshake interaction (7 messages):
|
||||
*
|
||||
* 1. I -> R (initation)
|
||||
* 2. I <- R (cookie reply)
|
||||
* 3. I -> R (initation)
|
||||
* 4. I <- R (response)
|
||||
* 5. I -> R (cookie reply)
|
||||
* 6. I -> R (initation)
|
||||
* 7. I <- R (response)
|
||||
*/
|
||||
#[test]
|
||||
fn handshake_under_load() {
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng);
|
||||
|
||||
let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap();
|
||||
let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap();
|
||||
|
||||
// 1. device-1 : create first initation
|
||||
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
|
||||
|
||||
// 2. device-2 : responds with CookieReply
|
||||
let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
|
||||
(None, Some(msg), None) => msg,
|
||||
_ => panic!("unexpected response"),
|
||||
};
|
||||
|
||||
// device-1 : processes CookieReply (no response)
|
||||
match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
|
||||
(None, None, None) => (),
|
||||
_ => panic!("unexpected response"),
|
||||
}
|
||||
|
||||
// avoid initation flood
|
||||
thread::sleep(Duration::from_millis(20));
|
||||
|
||||
// 3. device-1 : create second initation
|
||||
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
|
||||
|
||||
// 4. device-2 : responds with noise response
|
||||
let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
|
||||
(Some(_), Some(msg), Some(kp)) => {
|
||||
assert_eq!(kp.initiator, false);
|
||||
msg
|
||||
}
|
||||
_ => panic!("unexpected response"),
|
||||
};
|
||||
|
||||
// 5. device-1 : responds with CookieReply
|
||||
let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
|
||||
(None, Some(msg), None) => msg,
|
||||
_ => panic!("unexpected response"),
|
||||
};
|
||||
|
||||
// device-2 : processes CookieReply (no response)
|
||||
match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
|
||||
(None, None, None) => (),
|
||||
_ => panic!("unexpected response"),
|
||||
}
|
||||
|
||||
// avoid initation flood
|
||||
thread::sleep(Duration::from_millis(20));
|
||||
|
||||
// 6. device-1 : create third initation
|
||||
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
|
||||
|
||||
// 7. device-2 : responds with noise response
|
||||
let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
|
||||
(Some(_), Some(msg), Some(kp)) => {
|
||||
assert_eq!(kp.initiator, false);
|
||||
(msg, kp)
|
||||
}
|
||||
_ => panic!("unexpected response"),
|
||||
};
|
||||
|
||||
// device-1 : process noise response
|
||||
let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
|
||||
(Some(_), None, Some(kp)) => {
|
||||
assert_eq!(kp.initiator, true);
|
||||
kp
|
||||
}
|
||||
_ => panic!("unexpected response"),
|
||||
};
|
||||
|
||||
assert_eq!(kp1.send, kp2.recv);
|
||||
assert_eq!(kp1.recv, kp2.send);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handshake_no_load() {
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng);
|
||||
|
||||
// do a few handshakes (every handshake should succeed)
|
||||
|
||||
for i in 0..10 {
|
||||
println!("handshake : {}", i);
|
||||
|
||||
// create initiation
|
||||
|
||||
let msg1 = dev1.begin(&mut rng, &pk2).unwrap();
|
||||
|
||||
println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len());
|
||||
println!("msg1 = {:?}", Initiation::parse(&msg1[..]).unwrap());
|
||||
|
||||
// process initiation and create response
|
||||
|
||||
let (_, msg2, ks_r) = dev2.process(&mut rng, &msg1, None).unwrap();
|
||||
|
||||
let ks_r = ks_r.unwrap();
|
||||
let msg2 = msg2.unwrap();
|
||||
|
||||
println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len());
|
||||
println!("msg2 = {:?}", Response::parse(&msg2[..]).unwrap());
|
||||
|
||||
assert!(!ks_r.initiator, "Responders key-pair is confirmed");
|
||||
|
||||
// process response and obtain confirmed key-pair
|
||||
|
||||
let (_, msg3, ks_i) = dev1.process(&mut rng, &msg2, None).unwrap();
|
||||
let ks_i = ks_i.unwrap();
|
||||
|
||||
assert!(msg3.is_none(), "Returned message after response");
|
||||
assert!(ks_i.initiator, "Initiators key-pair is not confirmed");
|
||||
|
||||
assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv");
|
||||
assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send");
|
||||
|
||||
dev1.release(ks_i.send.id);
|
||||
dev2.release(ks_r.send.id);
|
||||
|
||||
// to avoid flood detection
|
||||
thread::sleep(Duration::from_millis(20));
|
||||
}
|
||||
|
||||
dev1.remove(pk2).unwrap();
|
||||
dev2.remove(pk1).unwrap();
|
||||
}
|
||||
}
|
||||
327
src/wireguard/handshake/macs.rs
Normal file
327
src/wireguard/handshake/macs.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
use generic_array::GenericArray;
|
||||
use rand::{CryptoRng, RngCore};
|
||||
use spin::RwLock;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
// types to coalesce into bytes
|
||||
use std::net::SocketAddr;
|
||||
use x25519_dalek::PublicKey;
|
||||
|
||||
// AEAD
|
||||
use aead::{Aead, NewAead, Payload};
|
||||
use chacha20poly1305::XChaCha20Poly1305;
|
||||
|
||||
// MAC
|
||||
use blake2::Blake2s;
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
use super::messages::{CookieReply, MacsFooter, TYPE_COOKIE_REPLY};
|
||||
use super::types::HandshakeError;
|
||||
|
||||
const LABEL_MAC1: &[u8] = b"mac1----";
|
||||
const LABEL_COOKIE: &[u8] = b"cookie--";
|
||||
|
||||
const SIZE_COOKIE: usize = 16;
|
||||
const SIZE_SECRET: usize = 32;
|
||||
const SIZE_MAC: usize = 16; // blake2s-mac128
|
||||
const SIZE_TAG: usize = 16; // xchacha20poly1305 tag
|
||||
|
||||
const COOKIE_UPDATE_INTERVAL: Duration = Duration::from_secs(120);
|
||||
|
||||
macro_rules! HASH {
|
||||
( $($input:expr),* ) => {{
|
||||
use blake2::Digest;
|
||||
let mut hsh = Blake2s::new();
|
||||
$(
|
||||
hsh.input($input);
|
||||
)*
|
||||
hsh.result()
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! MAC {
|
||||
( $key:expr, $($input:expr),* ) => {{
|
||||
use blake2::VarBlake2s;
|
||||
use digest::Input;
|
||||
use digest::VariableOutput;
|
||||
let mut tag = [0u8; SIZE_MAC];
|
||||
let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC);
|
||||
$(
|
||||
mac.input($input);
|
||||
)*
|
||||
mac.variable_result(|buf| tag.copy_from_slice(buf));
|
||||
tag
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! XSEAL {
|
||||
($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
|
||||
let ct = XChaCha20Poly1305::new(*GenericArray::from_slice($key))
|
||||
.encrypt(
|
||||
GenericArray::from_slice($nonce),
|
||||
Payload { msg: $pt, aad: $ad },
|
||||
)
|
||||
.unwrap();
|
||||
debug_assert_eq!(ct.len(), $pt.len() + SIZE_TAG);
|
||||
$ct.copy_from_slice(&ct);
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! XOPEN {
|
||||
($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
|
||||
debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG);
|
||||
XChaCha20Poly1305::new(*GenericArray::from_slice($key))
|
||||
.decrypt(
|
||||
GenericArray::from_slice($nonce),
|
||||
Payload { msg: $ct, aad: $ad },
|
||||
)
|
||||
.map_err(|_| HandshakeError::DecryptionFailure)
|
||||
.map(|pt| $pt.copy_from_slice(&pt))
|
||||
}};
|
||||
}
|
||||
|
||||
struct Cookie {
|
||||
value: [u8; 16],
|
||||
birth: Instant,
|
||||
}
|
||||
|
||||
pub struct Generator {
|
||||
mac1_key: [u8; 32],
|
||||
cookie_key: [u8; 32], // xchacha20poly key for opening cookie response
|
||||
last_mac1: Option<[u8; 16]>,
|
||||
cookie: Option<Cookie>,
|
||||
}
|
||||
|
||||
fn addr_to_mac_bytes(addr: &SocketAddr) -> Vec<u8> {
|
||||
match addr {
|
||||
SocketAddr::V4(addr) => {
|
||||
let mut res = Vec::with_capacity(4 + 2);
|
||||
res.extend(&addr.ip().octets());
|
||||
res.extend(&addr.port().to_le_bytes());
|
||||
res
|
||||
}
|
||||
SocketAddr::V6(addr) => {
|
||||
let mut res = Vec::with_capacity(16 + 2);
|
||||
res.extend(&addr.ip().octets());
|
||||
res.extend(&addr.port().to_le_bytes());
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Generator {
|
||||
/// Initalize a new mac field generator
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - pk: The public key of the peer to which the generator is associated
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A freshly initated generator
|
||||
pub fn new(pk: PublicKey) -> Generator {
|
||||
Generator {
|
||||
mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(),
|
||||
cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
|
||||
last_mac1: None,
|
||||
cookie: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a CookieReply message
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - reply: CookieReply to process
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Can fail if the cookie reply fails to validate
|
||||
/// (either indicating that it is outdated or malformed)
|
||||
pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> {
|
||||
let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?;
|
||||
let mut tau = [0u8; SIZE_COOKIE];
|
||||
XOPEN!(
|
||||
&self.cookie_key, // key
|
||||
&reply.f_nonce, // nonce
|
||||
&mac1, // ad
|
||||
&mut tau, // pt
|
||||
&reply.f_cookie // ct || tag
|
||||
)?;
|
||||
self.cookie = Some(Cookie {
|
||||
birth: Instant::now(),
|
||||
value: tau,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate both mac fields for an inner message
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - inner: A byteslice representing the inner message to be covered
|
||||
/// - macs: The destination mac footer for the resulting macs
|
||||
pub fn generate(&mut self, inner: &[u8], macs: &mut MacsFooter) {
|
||||
macs.f_mac1 = MAC!(&self.mac1_key, inner);
|
||||
macs.f_mac2 = match &self.cookie {
|
||||
Some(cookie) => {
|
||||
if cookie.birth.elapsed() > COOKIE_UPDATE_INTERVAL {
|
||||
self.cookie = None;
|
||||
[0u8; SIZE_MAC]
|
||||
} else {
|
||||
MAC!(&cookie.value, inner, macs.f_mac1)
|
||||
}
|
||||
}
|
||||
None => [0u8; SIZE_MAC],
|
||||
};
|
||||
self.last_mac1 = Some(macs.f_mac1);
|
||||
}
|
||||
}
|
||||
|
||||
struct Secret {
|
||||
value: [u8; 32],
|
||||
birth: Instant,
|
||||
}
|
||||
|
||||
pub struct Validator {
|
||||
mac1_key: [u8; 32], // mac1 key, derived from device public key
|
||||
cookie_key: [u8; 32], // xchacha20poly key for sealing cookie response
|
||||
secret: RwLock<Secret>,
|
||||
}
|
||||
|
||||
impl Validator {
|
||||
pub fn new(pk: PublicKey) -> Validator {
|
||||
Validator {
|
||||
mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(),
|
||||
cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
|
||||
secret: RwLock::new(Secret {
|
||||
value: [0u8; SIZE_SECRET],
|
||||
birth: Instant::now() - Duration::new(86400, 0),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> {
|
||||
let secret = self.secret.read();
|
||||
if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
|
||||
Some(MAC!(&secret.value, src))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_set_tau<R: RngCore + CryptoRng>(&self, rng: &mut R, src: &[u8]) -> [u8; SIZE_COOKIE] {
|
||||
// check if current value is still valid
|
||||
{
|
||||
let secret = self.secret.read();
|
||||
if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
|
||||
return MAC!(&secret.value, src);
|
||||
};
|
||||
}
|
||||
|
||||
// take write lock, check again
|
||||
{
|
||||
let mut secret = self.secret.write();
|
||||
if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
|
||||
return MAC!(&secret.value, src);
|
||||
};
|
||||
|
||||
// set new random cookie secret
|
||||
rng.fill_bytes(&mut secret.value);
|
||||
secret.birth = Instant::now();
|
||||
MAC!(&secret.value, src)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_cookie_reply<R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
rng: &mut R,
|
||||
receiver: u32, // receiver id of incoming message
|
||||
src: &SocketAddr, // source address of incoming message
|
||||
macs: &MacsFooter, // footer of incoming message
|
||||
msg: &mut CookieReply, // resulting cookie reply
|
||||
) {
|
||||
let src = addr_to_mac_bytes(src);
|
||||
msg.f_type.set(TYPE_COOKIE_REPLY as u32);
|
||||
msg.f_receiver.set(receiver);
|
||||
rng.fill_bytes(&mut msg.f_nonce);
|
||||
XSEAL!(
|
||||
&self.cookie_key, // key
|
||||
&msg.f_nonce, // nonce
|
||||
&macs.f_mac1, // ad
|
||||
&self.get_set_tau(rng, &src), // pt
|
||||
&mut msg.f_cookie // ct || tag
|
||||
);
|
||||
}
|
||||
|
||||
/// Check the mac1 field against the inner message
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - inner: The inner message covered by the mac1 field
|
||||
/// - macs: The mac footer
|
||||
pub fn check_mac1(&self, inner: &[u8], macs: &MacsFooter) -> Result<(), HandshakeError> {
|
||||
let valid_mac1: bool = MAC!(&self.mac1_key, inner).ct_eq(&macs.f_mac1).into();
|
||||
if !valid_mac1 {
|
||||
Err(HandshakeError::InvalidMac1)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_mac2(&self, inner: &[u8], src: &SocketAddr, macs: &MacsFooter) -> bool {
|
||||
let src = addr_to_mac_bytes(src);
|
||||
match self.get_tau(&src) {
|
||||
Some(tau) => MAC!(&tau, inner, macs.f_mac1).ct_eq(&macs.f_mac2).into(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use proptest::prelude::*;
|
||||
use rand::rngs::OsRng;
|
||||
use x25519_dalek::StaticSecret;
|
||||
|
||||
fn new_validator_generator() -> (Validator, Generator) {
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let sk = StaticSecret::new(&mut rng);
|
||||
let pk = PublicKey::from(&sk);
|
||||
(Validator::new(pk), Generator::new(pk))
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_cookie_reply(inner1 : Vec<u8>, inner2 : Vec<u8>, receiver : u32) {
|
||||
let mut msg = CookieReply::default();
|
||||
let mut rng = OsRng::new().expect("failed to create rng");
|
||||
let mut macs = MacsFooter::default();
|
||||
let src = "192.0.2.16:8080".parse().unwrap();
|
||||
let (validator, mut generator) = new_validator_generator();
|
||||
|
||||
// generate mac1 for first message
|
||||
generator.generate(&inner1[..], &mut macs);
|
||||
assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set");
|
||||
assert_eq!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should not be set");
|
||||
|
||||
// check validity of mac1
|
||||
validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate");
|
||||
assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate");
|
||||
validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg);
|
||||
|
||||
// consume cookie reply
|
||||
generator.process(&msg).expect("failed to process CookieReply");
|
||||
|
||||
// generate mac2 & mac2 for second message
|
||||
generator.generate(&inner2[..], &mut macs);
|
||||
assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set");
|
||||
assert_ne!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should be set");
|
||||
|
||||
// check validity of mac1 and mac2
|
||||
validator.check_mac1(&inner2[..], &macs).expect("mac1 of inner2 did not validate");
|
||||
assert!(validator.check_mac2(&inner2[..], &src, &macs), "mac2 of inner2 did not validate");
|
||||
}
|
||||
}
|
||||
}
|
||||
363
src/wireguard/handshake/messages.rs
Normal file
363
src/wireguard/handshake/messages.rs
Normal file
@@ -0,0 +1,363 @@
|
||||
#[cfg(test)]
|
||||
use hex;
|
||||
|
||||
#[cfg(test)]
|
||||
use std::fmt;
|
||||
|
||||
use std::mem;
|
||||
|
||||
use byteorder::LittleEndian;
|
||||
use zerocopy::byteorder::U32;
|
||||
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
|
||||
|
||||
use super::types::*;
|
||||
|
||||
const SIZE_MAC: usize = 16;
|
||||
const SIZE_TAG: usize = 16; // poly1305 tag
|
||||
const SIZE_XNONCE: usize = 24; // xchacha20 nonce
|
||||
const SIZE_COOKIE: usize = 16; //
|
||||
const SIZE_X25519_POINT: usize = 32; // x25519 public key
|
||||
const SIZE_TIMESTAMP: usize = 12;
|
||||
|
||||
pub const TYPE_INITIATION: u32 = 1;
|
||||
pub const TYPE_RESPONSE: u32 = 2;
|
||||
pub const TYPE_COOKIE_REPLY: u32 = 3;
|
||||
|
||||
const fn max(a: usize, b: usize) -> usize {
|
||||
let m: usize = (a > b) as usize;
|
||||
m * a + (1 - m) * b
|
||||
}
|
||||
|
||||
pub const MAX_HANDSHAKE_MSG_SIZE: usize = max(
|
||||
max(mem::size_of::<Response>(), mem::size_of::<Initiation>()),
|
||||
mem::size_of::<CookieReply>(),
|
||||
);
|
||||
|
||||
/* Handshake messsages */
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct Response {
|
||||
pub noise: NoiseResponse, // inner message covered by macs
|
||||
pub macs: MacsFooter,
|
||||
}
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct Initiation {
|
||||
pub noise: NoiseInitiation, // inner message covered by macs
|
||||
pub macs: MacsFooter,
|
||||
}
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct CookieReply {
|
||||
pub f_type: U32<LittleEndian>,
|
||||
pub f_receiver: U32<LittleEndian>,
|
||||
pub f_nonce: [u8; SIZE_XNONCE],
|
||||
pub f_cookie: [u8; SIZE_COOKIE + SIZE_TAG],
|
||||
}
|
||||
|
||||
/* Inner sub-messages */
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct MacsFooter {
|
||||
pub f_mac1: [u8; SIZE_MAC],
|
||||
pub f_mac2: [u8; SIZE_MAC],
|
||||
}
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct NoiseInitiation {
|
||||
pub f_type: U32<LittleEndian>,
|
||||
pub f_sender: U32<LittleEndian>,
|
||||
pub f_ephemeral: [u8; SIZE_X25519_POINT],
|
||||
pub f_static: [u8; SIZE_X25519_POINT + SIZE_TAG],
|
||||
pub f_timestamp: [u8; SIZE_TIMESTAMP + SIZE_TAG],
|
||||
}
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct NoiseResponse {
|
||||
pub f_type: U32<LittleEndian>,
|
||||
pub f_sender: U32<LittleEndian>,
|
||||
pub f_receiver: U32<LittleEndian>,
|
||||
pub f_ephemeral: [u8; SIZE_X25519_POINT],
|
||||
pub f_empty: [u8; SIZE_TAG],
|
||||
}
|
||||
|
||||
/* Zero copy parsing of handshake messages */
|
||||
|
||||
impl Initiation {
|
||||
pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
|
||||
let msg: LayoutVerified<B, Self> =
|
||||
LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
|
||||
|
||||
if msg.noise.f_type.get() != (TYPE_INITIATION as u32) {
|
||||
return Err(HandshakeError::InvalidMessageFormat);
|
||||
}
|
||||
|
||||
Ok(msg)
|
||||
}
|
||||
}
|
||||
|
||||
impl Response {
|
||||
pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
|
||||
let msg: LayoutVerified<B, Self> =
|
||||
LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
|
||||
|
||||
if msg.noise.f_type.get() != (TYPE_RESPONSE as u32) {
|
||||
return Err(HandshakeError::InvalidMessageFormat);
|
||||
}
|
||||
|
||||
Ok(msg)
|
||||
}
|
||||
}
|
||||
|
||||
impl CookieReply {
|
||||
pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
|
||||
let msg: LayoutVerified<B, Self> =
|
||||
LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
|
||||
|
||||
if msg.f_type.get() != (TYPE_COOKIE_REPLY as u32) {
|
||||
return Err(HandshakeError::InvalidMessageFormat);
|
||||
}
|
||||
|
||||
Ok(msg)
|
||||
}
|
||||
}
|
||||
|
||||
/* Default values */
|
||||
|
||||
impl Default for Response {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
noise: Default::default(),
|
||||
macs: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Initiation {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
noise: Default::default(),
|
||||
macs: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CookieReply {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
f_type: <U32<LittleEndian>>::new(TYPE_COOKIE_REPLY as u32),
|
||||
f_receiver: <U32<LittleEndian>>::ZERO,
|
||||
f_nonce: [0u8; SIZE_XNONCE],
|
||||
f_cookie: [0u8; SIZE_COOKIE + SIZE_TAG],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MacsFooter {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
f_mac1: [0u8; SIZE_MAC],
|
||||
f_mac2: [0u8; SIZE_MAC],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NoiseInitiation {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
f_type: <U32<LittleEndian>>::new(TYPE_INITIATION as u32),
|
||||
f_sender: <U32<LittleEndian>>::ZERO,
|
||||
f_ephemeral: [0u8; SIZE_X25519_POINT],
|
||||
f_static: [0u8; SIZE_X25519_POINT + SIZE_TAG],
|
||||
f_timestamp: [0u8; SIZE_TIMESTAMP + SIZE_TAG],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NoiseResponse {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
f_type: <U32<LittleEndian>>::new(TYPE_RESPONSE as u32),
|
||||
f_sender: <U32<LittleEndian>>::ZERO,
|
||||
f_receiver: <U32<LittleEndian>>::ZERO,
|
||||
f_ephemeral: [0u8; SIZE_X25519_POINT],
|
||||
f_empty: [0u8; SIZE_TAG],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Debug formatting (for testing purposes) */
|
||||
|
||||
#[cfg(test)]
|
||||
impl fmt::Debug for Initiation {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Initiation {{ {:?} || {:?} }}", self.noise, self.macs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl fmt::Debug for Response {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Response {{ {:?} || {:?} }}", self.noise, self.macs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl fmt::Debug for CookieReply {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"CookieReply {{ type = {}, receiver = {}, nonce = {}, cookie = {} }}",
|
||||
self.f_type,
|
||||
self.f_receiver,
|
||||
hex::encode(&self.f_nonce[..]),
|
||||
hex::encode(&self.f_cookie[..]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl fmt::Debug for NoiseInitiation {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f,
|
||||
"NoiseInitiation {{ type = {}, sender = {}, ephemeral = {}, static = {}, timestamp = {} }}",
|
||||
self.f_type.get(),
|
||||
self.f_sender.get(),
|
||||
hex::encode(&self.f_ephemeral[..]),
|
||||
hex::encode(&self.f_static[..]),
|
||||
hex::encode(&self.f_timestamp[..]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl fmt::Debug for NoiseResponse {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f,
|
||||
"NoiseResponse {{ type = {}, sender = {}, receiver = {}, ephemeral = {}, empty = |{} }}",
|
||||
self.f_type,
|
||||
self.f_sender,
|
||||
self.f_receiver,
|
||||
hex::encode(&self.f_ephemeral[..]),
|
||||
hex::encode(&self.f_empty[..])
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl fmt::Debug for MacsFooter {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Macs {{ mac1 = {}, mac2 = {} }}",
|
||||
hex::encode(&self.f_mac1[..]),
|
||||
hex::encode(&self.f_mac2[..])
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/* Equality (for testing purposes) */
|
||||
|
||||
#[cfg(test)]
|
||||
macro_rules! eq_as_bytes {
|
||||
($type:path) => {
|
||||
impl PartialEq for $type {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.as_bytes() == other.as_bytes()
|
||||
}
|
||||
}
|
||||
impl Eq for $type {}
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
eq_as_bytes!(Initiation);
|
||||
|
||||
#[cfg(test)]
|
||||
eq_as_bytes!(Response);
|
||||
|
||||
#[cfg(test)]
|
||||
eq_as_bytes!(CookieReply);
|
||||
|
||||
#[cfg(test)]
|
||||
eq_as_bytes!(MacsFooter);
|
||||
|
||||
#[cfg(test)]
|
||||
eq_as_bytes!(NoiseInitiation);
|
||||
|
||||
#[cfg(test)]
|
||||
eq_as_bytes!(NoiseResponse);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn message_response_identity() {
|
||||
let mut msg: Response = Default::default();
|
||||
|
||||
msg.noise.f_sender.set(146252);
|
||||
msg.noise.f_receiver.set(554442);
|
||||
msg.noise.f_ephemeral = [
|
||||
0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
|
||||
0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
|
||||
0x13, 0x56, 0x52, 0x1f,
|
||||
];
|
||||
msg.noise.f_empty = [
|
||||
0x60, 0x0e, 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05,
|
||||
0x2f, 0xde,
|
||||
];
|
||||
msg.macs.f_mac1 = [
|
||||
0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54,
|
||||
0x69, 0x29,
|
||||
];
|
||||
msg.macs.f_mac2 = [
|
||||
0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5,
|
||||
0xdf, 0xbb,
|
||||
];
|
||||
|
||||
let buf: Vec<u8> = msg.as_bytes().to_vec();
|
||||
let msg_p = Response::parse(&buf[..]).unwrap();
|
||||
assert_eq!(msg, *msg_p.into_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_initiate_identity() {
|
||||
let mut msg: Initiation = Default::default();
|
||||
|
||||
msg.noise.f_sender.set(575757);
|
||||
msg.noise.f_ephemeral = [
|
||||
0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
|
||||
0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
|
||||
0x13, 0x56, 0x52, 0x1f,
|
||||
];
|
||||
msg.noise.f_static = [
|
||||
0xdc, 0x33, 0x90, 0x15, 0x8f, 0x82, 0x3e, 0x06, 0x44, 0xa0, 0xde, 0x4c, 0x15, 0x6c,
|
||||
0x5d, 0xa4, 0x65, 0x99, 0xf6, 0x6c, 0xa1, 0x14, 0x77, 0xf9, 0xeb, 0x6a, 0xec, 0xc3,
|
||||
0x3c, 0xda, 0x47, 0xe1, 0x45, 0xac, 0x8d, 0x43, 0xea, 0x1b, 0x2f, 0x02, 0x45, 0x5d,
|
||||
0x86, 0x37, 0xee, 0x83, 0x6b, 0x42,
|
||||
];
|
||||
msg.noise.f_timestamp = [
|
||||
0x4f, 0x1c, 0x60, 0xec, 0x0e, 0xf6, 0x36, 0xf0, 0x78, 0x28, 0x57, 0x42, 0x60, 0x0e,
|
||||
0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, 0x2f, 0xde,
|
||||
];
|
||||
msg.macs.f_mac1 = [
|
||||
0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54,
|
||||
0x69, 0x29,
|
||||
];
|
||||
msg.macs.f_mac2 = [
|
||||
0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5,
|
||||
0xdf, 0xbb,
|
||||
];
|
||||
|
||||
let buf: Vec<u8> = msg.as_bytes().to_vec();
|
||||
let msg_p = Initiation::parse(&buf[..]).unwrap();
|
||||
assert_eq!(msg, *msg_p.into_ref());
|
||||
}
|
||||
}
|
||||
21
src/wireguard/handshake/mod.rs
Normal file
21
src/wireguard/handshake/mod.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
/* Implementation of the:
|
||||
*
|
||||
* Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s
|
||||
*
|
||||
* Protocol pattern, see: http://www.noiseprotocol.org/noise.html.
|
||||
* For documentation.
|
||||
*/
|
||||
|
||||
mod device;
|
||||
mod macs;
|
||||
mod messages;
|
||||
mod noise;
|
||||
mod peer;
|
||||
mod ratelimiter;
|
||||
mod timestamp;
|
||||
mod types;
|
||||
|
||||
// publicly exposed interface
|
||||
|
||||
pub use device::Device;
|
||||
pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
|
||||
549
src/wireguard/handshake/noise.rs
Normal file
549
src/wireguard/handshake/noise.rs
Normal file
@@ -0,0 +1,549 @@
|
||||
// DH
|
||||
use x25519_dalek::PublicKey;
|
||||
use x25519_dalek::StaticSecret;
|
||||
|
||||
// HASH & MAC
|
||||
use blake2::Blake2s;
|
||||
use hmac::Hmac;
|
||||
|
||||
// AEAD
|
||||
use aead::{Aead, NewAead, Payload};
|
||||
use chacha20poly1305::ChaCha20Poly1305;
|
||||
|
||||
use rand::{CryptoRng, RngCore};
|
||||
|
||||
use generic_array::typenum::*;
|
||||
use generic_array::*;
|
||||
|
||||
use clear_on_drop::clear::Clear;
|
||||
use clear_on_drop::clear_stack_on_return;
|
||||
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
use super::device::Device;
|
||||
use super::messages::{NoiseInitiation, NoiseResponse};
|
||||
use super::messages::{TYPE_INITIATION, TYPE_RESPONSE};
|
||||
use super::peer::{Peer, State};
|
||||
use super::timestamp;
|
||||
use super::types::*;
|
||||
|
||||
use super::super::types::{KeyPair, Key};
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
// HMAC hasher (generic construction)
|
||||
|
||||
type HMACBlake2s = Hmac<Blake2s>;
|
||||
|
||||
// convenient alias to pass state temporarily into device.rs and back
|
||||
|
||||
type TemporaryState = (u32, PublicKey, GenericArray<u8, U32>, GenericArray<u8, U32>);
|
||||
|
||||
const SIZE_CK: usize = 32;
|
||||
const SIZE_HS: usize = 32;
|
||||
const SIZE_NONCE: usize = 8;
|
||||
const SIZE_TAG: usize = 16;
|
||||
|
||||
// number of pages to clear after sensitive call
|
||||
const CLEAR_PAGES: usize = 1;
|
||||
|
||||
// C := Hash(Construction)
|
||||
const INITIAL_CK: [u8; SIZE_CK] = [
|
||||
0x60, 0xe2, 0x6d, 0xae, 0xf3, 0x27, 0xef, 0xc0, 0x2e, 0xc3, 0x35, 0xe2, 0xa0, 0x25, 0xd2, 0xd0,
|
||||
0x16, 0xeb, 0x42, 0x06, 0xf8, 0x72, 0x77, 0xf5, 0x2d, 0x38, 0xd1, 0x98, 0x8b, 0x78, 0xcd, 0x36,
|
||||
];
|
||||
|
||||
// H := Hash(C || Identifier)
|
||||
const INITIAL_HS: [u8; SIZE_HS] = [
|
||||
0x22, 0x11, 0xb3, 0x61, 0x08, 0x1a, 0xc5, 0x66, 0x69, 0x12, 0x43, 0xdb, 0x45, 0x8a, 0xd5, 0x32,
|
||||
0x2d, 0x9c, 0x6c, 0x66, 0x22, 0x93, 0xe8, 0xb7, 0x0e, 0xe1, 0x9c, 0x65, 0xba, 0x07, 0x9e, 0xf3,
|
||||
];
|
||||
|
||||
const ZERO_NONCE: [u8; 12] = [0u8; 12];
|
||||
|
||||
macro_rules! HASH {
|
||||
( $($input:expr),* ) => {{
|
||||
use blake2::Digest;
|
||||
let mut hsh = Blake2s::new();
|
||||
$(
|
||||
hsh.input($input);
|
||||
)*
|
||||
hsh.result()
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! HMAC {
|
||||
($key:expr, $($input:expr),*) => {{
|
||||
use hmac::Mac;
|
||||
let mut mac = HMACBlake2s::new_varkey($key).unwrap();
|
||||
$(
|
||||
mac.input($input);
|
||||
)*
|
||||
mac.result().code()
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! KDF1 {
|
||||
($ck:expr, $input:expr) => {{
|
||||
let mut t0 = HMAC!($ck, $input);
|
||||
let t1 = HMAC!(&t0, &[0x1]);
|
||||
t0.clear();
|
||||
t1
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! KDF2 {
|
||||
($ck:expr, $input:expr) => {{
|
||||
let mut t0 = HMAC!($ck, $input);
|
||||
let t1 = HMAC!(&t0, &[0x1]);
|
||||
let t2 = HMAC!(&t0, &t1, &[0x2]);
|
||||
t0.clear();
|
||||
(t1, t2)
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! KDF3 {
|
||||
($ck:expr, $input:expr) => {{
|
||||
let mut t0 = HMAC!($ck, $input);
|
||||
let t1 = HMAC!(&t0, &[0x1]);
|
||||
let t2 = HMAC!(&t0, &t1, &[0x2]);
|
||||
let t3 = HMAC!(&t0, &t2, &[0x3]);
|
||||
t0.clear();
|
||||
(t1, t2, t3)
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! SEAL {
|
||||
($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
|
||||
ChaCha20Poly1305::new(*GenericArray::from_slice($key))
|
||||
.encrypt(&ZERO_NONCE.into(), Payload { msg: $pt, aad: $ad })
|
||||
.map(|ct| $ct.copy_from_slice(&ct))
|
||||
.unwrap()
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! OPEN {
|
||||
($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
|
||||
ChaCha20Poly1305::new(*GenericArray::from_slice($key))
|
||||
.decrypt(&ZERO_NONCE.into(), Payload { msg: $ct, aad: $ad })
|
||||
.map_err(|_| HandshakeError::DecryptionFailure)
|
||||
.map(|pt| $pt.copy_from_slice(&pt))
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
const IDENTIFIER: &[u8] = b"WireGuard v1 zx2c4 Jason@zx2c4.com";
|
||||
const CONSTRUCTION: &[u8] = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
|
||||
|
||||
/* Sanity check precomputed initial chain key
|
||||
*/
|
||||
#[test]
|
||||
fn precomputed_chain_key() {
|
||||
assert_eq!(INITIAL_CK[..], HASH!(CONSTRUCTION)[..]);
|
||||
}
|
||||
|
||||
/* Sanity check precomputed initial hash transcript
|
||||
*/
|
||||
#[test]
|
||||
fn precomputed_hash() {
|
||||
assert_eq!(INITIAL_HS[..], HASH!(INITIAL_CK, IDENTIFIER)[..]);
|
||||
}
|
||||
|
||||
/* Sanity check the HKDF macro
|
||||
*
|
||||
* Test vectors generated using WireGuard-Go
|
||||
*/
|
||||
#[test]
|
||||
fn hkdf() {
|
||||
let tests: Vec<(Vec<u8>, Vec<u8>, [u8; 32], [u8; 32], [u8; 32])> = vec![
|
||||
(
|
||||
vec![],
|
||||
vec![],
|
||||
[
|
||||
0x83, 0x87, 0xb4, 0x6b, 0xf4, 0x3e, 0xcc, 0xfc, 0xf3, 0x49, 0x55, 0x2a, 0x09,
|
||||
0x5d, 0x83, 0x15, 0xc4, 0x05, 0x5b, 0xeb, 0x90, 0x20, 0x8f, 0xb1, 0xbe, 0x23,
|
||||
0xb8, 0x94, 0xbc, 0x2e, 0xd5, 0xd0,
|
||||
],
|
||||
[
|
||||
0x58, 0xa0, 0xe5, 0xf6, 0xfa, 0xef, 0xcc, 0xf4, 0x80, 0x7b, 0xff, 0x1f, 0x05,
|
||||
0xfa, 0x8a, 0x92, 0x17, 0x94, 0x57, 0x62, 0x04, 0x0b, 0xce, 0xc2, 0xf4, 0xb4,
|
||||
0xa6, 0x2b, 0xdf, 0xe0, 0xe8, 0x6e,
|
||||
],
|
||||
[
|
||||
0x0c, 0xe6, 0xea, 0x98, 0xec, 0x54, 0x8f, 0x8e, 0x28, 0x1e, 0x93, 0xe3, 0x2d,
|
||||
0xb6, 0x56, 0x21, 0xc4, 0x5e, 0xb1, 0x8d, 0xc6, 0xf0, 0xa7, 0xad, 0x94, 0x17,
|
||||
0x86, 0x10, 0xa2, 0xf7, 0x33, 0x8e,
|
||||
],
|
||||
),
|
||||
(
|
||||
vec![0xde, 0xad, 0xbe, 0xef],
|
||||
vec![],
|
||||
[
|
||||
0x55, 0x32, 0x9d, 0xc8, 0x0e, 0x69, 0x0f, 0xd8, 0x6b, 0xd9, 0x66, 0x1f, 0x08,
|
||||
0x51, 0xc9, 0xb3, 0x68, 0x6d, 0xf2, 0xb1, 0xfd, 0xa0, 0x34, 0x7b, 0xc3, 0xd2,
|
||||
0x79, 0x58, 0x25, 0x4b, 0x32, 0xc6,
|
||||
],
|
||||
[
|
||||
0x8d, 0xfc, 0x6d, 0x33, 0xa8, 0x11, 0x8f, 0xfe, 0x40, 0x8b, 0x31, 0xdd, 0xac,
|
||||
0x25, 0xf7, 0x2a, 0xee, 0x91, 0x15, 0xa4, 0x5b, 0x69, 0xba, 0x17, 0x6a, 0xd0,
|
||||
0x12, 0xb2, 0x43, 0x83, 0x4f, 0xee,
|
||||
],
|
||||
[
|
||||
0xd6, 0x9e, 0x85, 0x2a, 0x28, 0x96, 0x56, 0x9e, 0xa5, 0x4a, 0x67, 0x96, 0x9a,
|
||||
0xa1, 0x80, 0x02, 0x87, 0x92, 0x1d, 0xac, 0x53, 0xce, 0x6d, 0xb4, 0xb4, 0xe1,
|
||||
0x21, 0x92, 0xf2, 0x63, 0xc4, 0xc4,
|
||||
],
|
||||
),
|
||||
];
|
||||
|
||||
for (key, input, t0, t1, t2) in &tests {
|
||||
let tt0 = KDF1!(key, input);
|
||||
debug_assert_eq!(tt0[..], t0[..]);
|
||||
|
||||
let (tt0, tt1) = KDF2!(key, input);
|
||||
debug_assert_eq!(tt0[..], t0[..]);
|
||||
debug_assert_eq!(tt1[..], t1[..]);
|
||||
|
||||
let (tt0, tt1, tt2) = KDF3!(key, input);
|
||||
debug_assert_eq!(tt0[..], t0[..]);
|
||||
debug_assert_eq!(tt1[..], t1[..]);
|
||||
debug_assert_eq!(tt2[..], t2[..]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_initiation<R: RngCore + CryptoRng>(
|
||||
rng: &mut R,
|
||||
device: &Device,
|
||||
peer: &Peer,
|
||||
sender: u32,
|
||||
msg: &mut NoiseInitiation,
|
||||
) -> Result<(), HandshakeError> {
|
||||
clear_stack_on_return(CLEAR_PAGES, || {
|
||||
// initialize state
|
||||
|
||||
let ck = INITIAL_CK;
|
||||
let hs = INITIAL_HS;
|
||||
let hs = HASH!(&hs, peer.pk.as_bytes());
|
||||
|
||||
msg.f_type.set(TYPE_INITIATION as u32);
|
||||
msg.f_sender.set(sender);
|
||||
|
||||
// (E_priv, E_pub) := DH-Generate()
|
||||
|
||||
let eph_sk = StaticSecret::new(rng);
|
||||
let eph_pk = PublicKey::from(&eph_sk);
|
||||
|
||||
// C := Kdf(C, E_pub)
|
||||
|
||||
let ck = KDF1!(&ck, eph_pk.as_bytes());
|
||||
|
||||
// msg.ephemeral := E_pub
|
||||
|
||||
msg.f_ephemeral = *eph_pk.as_bytes();
|
||||
|
||||
// H := HASH(H, msg.ephemeral)
|
||||
|
||||
let hs = HASH!(&hs, msg.f_ephemeral);
|
||||
|
||||
// (C, k) := Kdf2(C, DH(E_priv, S_pub))
|
||||
|
||||
let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
|
||||
|
||||
// msg.static := Aead(k, 0, S_pub, H)
|
||||
|
||||
SEAL!(
|
||||
&key,
|
||||
&hs, // ad
|
||||
device.pk.as_bytes(), // pt
|
||||
&mut msg.f_static // ct || tag
|
||||
);
|
||||
|
||||
// H := Hash(H || msg.static)
|
||||
|
||||
let hs = HASH!(&hs, &msg.f_static[..]);
|
||||
|
||||
// (C, k) := Kdf2(C, DH(S_priv, S_pub))
|
||||
|
||||
let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
|
||||
|
||||
// msg.timestamp := Aead(k, 0, Timestamp(), H)
|
||||
|
||||
SEAL!(
|
||||
&key,
|
||||
&hs, // ad
|
||||
×tamp::now(), // pt
|
||||
&mut msg.f_timestamp // ct || tag
|
||||
);
|
||||
|
||||
// H := Hash(H || msg.timestamp)
|
||||
|
||||
let hs = HASH!(&hs, &msg.f_timestamp);
|
||||
|
||||
// update state of peer
|
||||
|
||||
*peer.state.lock() = State::InitiationSent {
|
||||
hs,
|
||||
ck,
|
||||
eph_sk,
|
||||
sender,
|
||||
};
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn consume_initiation<'a>(
|
||||
device: &'a Device,
|
||||
msg: &NoiseInitiation,
|
||||
) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
|
||||
clear_stack_on_return(CLEAR_PAGES, || {
|
||||
// initialize new state
|
||||
|
||||
let ck = INITIAL_CK;
|
||||
let hs = INITIAL_HS;
|
||||
let hs = HASH!(&hs, device.pk.as_bytes());
|
||||
|
||||
// C := Kdf(C, E_pub)
|
||||
|
||||
let ck = KDF1!(&ck, &msg.f_ephemeral);
|
||||
|
||||
// H := HASH(H, msg.ephemeral)
|
||||
|
||||
let hs = HASH!(&hs, &msg.f_ephemeral);
|
||||
|
||||
// (C, k) := Kdf2(C, DH(E_priv, S_pub))
|
||||
|
||||
let eph_r_pk = PublicKey::from(msg.f_ephemeral);
|
||||
let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
|
||||
|
||||
// msg.static := Aead(k, 0, S_pub, H)
|
||||
|
||||
let mut pk = [0u8; 32];
|
||||
|
||||
OPEN!(
|
||||
&key,
|
||||
&hs, // ad
|
||||
&mut pk, // pt
|
||||
&msg.f_static // ct || tag
|
||||
)?;
|
||||
|
||||
let peer = device.lookup_pk(&PublicKey::from(pk))?;
|
||||
|
||||
// reset initiation state
|
||||
|
||||
*peer.state.lock() = State::Reset;
|
||||
|
||||
// H := Hash(H || msg.static)
|
||||
|
||||
let hs = HASH!(&hs, &msg.f_static[..]);
|
||||
|
||||
// (C, k) := Kdf2(C, DH(S_priv, S_pub))
|
||||
|
||||
let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
|
||||
|
||||
// msg.timestamp := Aead(k, 0, Timestamp(), H)
|
||||
|
||||
let mut ts = timestamp::ZERO;
|
||||
|
||||
OPEN!(
|
||||
&key,
|
||||
&hs, // ad
|
||||
&mut ts, // pt
|
||||
&msg.f_timestamp // ct || tag
|
||||
)?;
|
||||
|
||||
// check and update timestamp
|
||||
|
||||
peer.check_replay_flood(device, &ts)?;
|
||||
|
||||
// H := Hash(H || msg.timestamp)
|
||||
|
||||
let hs = HASH!(&hs, &msg.f_timestamp);
|
||||
|
||||
// return state (to create response)
|
||||
|
||||
Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck)))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_response<R: RngCore + CryptoRng>(
|
||||
rng: &mut R,
|
||||
peer: &Peer,
|
||||
sender: u32, // sending identifier
|
||||
state: TemporaryState, // state from "consume_initiation"
|
||||
msg: &mut NoiseResponse, // resulting response
|
||||
) -> Result<KeyPair, HandshakeError> {
|
||||
clear_stack_on_return(CLEAR_PAGES, || {
|
||||
// unpack state
|
||||
|
||||
let (receiver, eph_r_pk, hs, ck) = state;
|
||||
|
||||
msg.f_type.set(TYPE_RESPONSE as u32);
|
||||
msg.f_sender.set(sender);
|
||||
msg.f_receiver.set(receiver);
|
||||
|
||||
// (E_priv, E_pub) := DH-Generate()
|
||||
|
||||
let eph_sk = StaticSecret::new(rng);
|
||||
let eph_pk = PublicKey::from(&eph_sk);
|
||||
|
||||
// C := Kdf1(C, E_pub)
|
||||
|
||||
let ck = KDF1!(&ck, eph_pk.as_bytes());
|
||||
|
||||
// msg.ephemeral := E_pub
|
||||
|
||||
msg.f_ephemeral = *eph_pk.as_bytes();
|
||||
|
||||
// H := Hash(H || msg.ephemeral)
|
||||
|
||||
let hs = HASH!(&hs, &msg.f_ephemeral);
|
||||
|
||||
// C := Kdf1(C, DH(E_priv, E_pub))
|
||||
|
||||
let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes());
|
||||
|
||||
// C := Kdf1(C, DH(E_priv, S_pub))
|
||||
|
||||
let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
|
||||
|
||||
// (C, tau, k) := Kdf3(C, Q)
|
||||
|
||||
let (ck, tau, key) = KDF3!(&ck, &peer.psk);
|
||||
|
||||
// H := Hash(H || tau)
|
||||
|
||||
let hs = HASH!(&hs, tau);
|
||||
|
||||
// msg.empty := Aead(k, 0, [], H)
|
||||
|
||||
SEAL!(
|
||||
&key,
|
||||
&hs, // ad
|
||||
&[], // pt
|
||||
&mut msg.f_empty // \epsilon || tag
|
||||
);
|
||||
|
||||
// Not strictly needed
|
||||
// let hs = HASH!(&hs, &msg.f_empty_tag);
|
||||
|
||||
// derive key-pair
|
||||
|
||||
let (key_recv, key_send) = KDF2!(&ck, &[]);
|
||||
|
||||
// return unconfirmed key-pair
|
||||
|
||||
Ok(KeyPair {
|
||||
birth: Instant::now(),
|
||||
initiator: false,
|
||||
send: Key {
|
||||
id: sender,
|
||||
key: key_send.into(),
|
||||
},
|
||||
recv: Key {
|
||||
id: receiver,
|
||||
key: key_recv.into(),
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/* The state lock is released while processing the message to
|
||||
* allow concurrent processing of potential responses to the initiation,
|
||||
* in order to better mitigate DoS from malformed response messages.
|
||||
*/
|
||||
pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
|
||||
clear_stack_on_return(CLEAR_PAGES, || {
|
||||
// retrieve peer and copy initiation state
|
||||
let peer = device.lookup_id(msg.f_receiver.get())?;
|
||||
|
||||
let (hs, ck, sender, eph_sk) = match *peer.state.lock() {
|
||||
State::InitiationSent {
|
||||
hs,
|
||||
ck,
|
||||
sender,
|
||||
ref eph_sk,
|
||||
} => Ok((hs, ck, sender, StaticSecret::from(eph_sk.to_bytes()))),
|
||||
_ => Err(HandshakeError::InvalidState),
|
||||
}?;
|
||||
|
||||
// C := Kdf1(C, E_pub)
|
||||
|
||||
let ck = KDF1!(&ck, &msg.f_ephemeral);
|
||||
|
||||
// H := Hash(H || msg.ephemeral)
|
||||
|
||||
let hs = HASH!(&hs, &msg.f_ephemeral);
|
||||
|
||||
// C := Kdf1(C, DH(E_priv, E_pub))
|
||||
|
||||
let eph_r_pk = PublicKey::from(msg.f_ephemeral);
|
||||
let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes());
|
||||
|
||||
// C := Kdf1(C, DH(E_priv, S_pub))
|
||||
|
||||
let ck = KDF1!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
|
||||
|
||||
// (C, tau, k) := Kdf3(C, Q)
|
||||
|
||||
let (ck, tau, key) = KDF3!(&ck, &peer.psk);
|
||||
|
||||
// H := Hash(H || tau)
|
||||
|
||||
let hs = HASH!(&hs, tau);
|
||||
|
||||
// msg.empty := Aead(k, 0, [], H)
|
||||
|
||||
OPEN!(
|
||||
&key,
|
||||
&hs, // ad
|
||||
&mut [], // pt
|
||||
&msg.f_empty // \epsilon || tag
|
||||
)?;
|
||||
|
||||
// derive key-pair
|
||||
|
||||
let birth = Instant::now();
|
||||
let (key_send, key_recv) = KDF2!(&ck, &[]);
|
||||
|
||||
// check for new initiation sent while lock released
|
||||
|
||||
let mut state = peer.state.lock();
|
||||
let update = match *state {
|
||||
State::InitiationSent {
|
||||
eph_sk: ref old, ..
|
||||
} => old.to_bytes().ct_eq(&eph_sk.to_bytes()).into(),
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if update {
|
||||
// null the initiation state
|
||||
// (to avoid replay of this response message)
|
||||
*state = State::Reset;
|
||||
|
||||
// return confirmed key-pair
|
||||
Ok((
|
||||
Some(peer.pk),
|
||||
None,
|
||||
Some(KeyPair {
|
||||
birth,
|
||||
initiator: true,
|
||||
send: Key {
|
||||
id: sender,
|
||||
key: key_send.into(),
|
||||
},
|
||||
recv: Key {
|
||||
id: msg.f_sender.get(),
|
||||
key: key_recv.into(),
|
||||
},
|
||||
}),
|
||||
))
|
||||
} else {
|
||||
Err(HandshakeError::InvalidState)
|
||||
}
|
||||
})
|
||||
}
|
||||
142
src/wireguard/handshake/peer.rs
Normal file
142
src/wireguard/handshake/peer.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use spin::Mutex;
|
||||
|
||||
use std::mem;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use generic_array::typenum::U32;
|
||||
use generic_array::GenericArray;
|
||||
|
||||
use x25519_dalek::PublicKey;
|
||||
use x25519_dalek::SharedSecret;
|
||||
use x25519_dalek::StaticSecret;
|
||||
|
||||
use clear_on_drop::clear::Clear;
|
||||
|
||||
use super::device::Device;
|
||||
use super::macs;
|
||||
use super::timestamp;
|
||||
use super::types::*;
|
||||
|
||||
const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20);
|
||||
|
||||
/* Represents the recomputation and state of a peer.
|
||||
*
|
||||
* This type is only for internal use and not exposed.
|
||||
*/
|
||||
pub struct Peer {
|
||||
// mutable state
|
||||
pub(crate) state: Mutex<State>,
|
||||
pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>,
|
||||
pub(crate) last_initiation_consumption: Mutex<Option<Instant>>,
|
||||
|
||||
// state related to DoS mitigation fields
|
||||
pub(crate) macs: Mutex<macs::Generator>,
|
||||
|
||||
// constant state
|
||||
pub(crate) pk: PublicKey, // public key of peer
|
||||
pub(crate) ss: SharedSecret, // precomputed DH(static, static)
|
||||
pub(crate) psk: Psk, // psk of peer
|
||||
}
|
||||
|
||||
pub enum State {
|
||||
Reset,
|
||||
InitiationSent {
|
||||
sender: u32, // assigned sender id
|
||||
eph_sk: StaticSecret,
|
||||
hs: GenericArray<u8, U32>,
|
||||
ck: GenericArray<u8, U32>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Drop for State {
|
||||
fn drop(&mut self) {
|
||||
match self {
|
||||
State::InitiationSent { hs, ck, .. } => {
|
||||
// eph_sk already cleared by dalek-x25519
|
||||
hs.clear();
|
||||
ck.clear();
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Peer {
|
||||
pub fn new(
|
||||
pk: PublicKey, // public key of peer
|
||||
ss: SharedSecret, // precomputed DH(static, static)
|
||||
) -> Self {
|
||||
Self {
|
||||
macs: Mutex::new(macs::Generator::new(pk)),
|
||||
state: Mutex::new(State::Reset),
|
||||
timestamp: Mutex::new(None),
|
||||
last_initiation_consumption: Mutex::new(None),
|
||||
pk: pk,
|
||||
ss: ss,
|
||||
psk: [0u8; 32],
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the state of the peer unconditionally
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
pub fn set_state(&self, state_new: State) {
|
||||
*self.state.lock() = state_new;
|
||||
}
|
||||
|
||||
pub fn reset_state(&self) -> Option<u32> {
|
||||
match mem::replace(&mut *self.state.lock(), State::Reset) {
|
||||
State::InitiationSent { sender, .. } => Some(sender),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the mutable state of the peer conditioned on the timestamp being newer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * st_new - The updated state of the peer
|
||||
/// * ts_new - The associated timestamp
|
||||
pub fn check_replay_flood(
|
||||
&self,
|
||||
device: &Device,
|
||||
timestamp_new: ×tamp::TAI64N,
|
||||
) -> Result<(), HandshakeError> {
|
||||
let mut state = self.state.lock();
|
||||
let mut timestamp = self.timestamp.lock();
|
||||
let mut last_initiation_consumption = self.last_initiation_consumption.lock();
|
||||
|
||||
// check replay attack
|
||||
match *timestamp {
|
||||
Some(timestamp_old) => {
|
||||
if !timestamp::compare(×tamp_old, ×tamp_new) {
|
||||
return Err(HandshakeError::OldTimestamp);
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// check flood attack
|
||||
match *last_initiation_consumption {
|
||||
Some(last) => {
|
||||
if last.elapsed() < TIME_BETWEEN_INITIATIONS {
|
||||
return Err(HandshakeError::InitiationFlood);
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
// reset state
|
||||
match *state {
|
||||
State::InitiationSent { sender, .. } => device.release(sender),
|
||||
_ => (),
|
||||
}
|
||||
|
||||
// update replay & flood protection
|
||||
*state = State::Reset;
|
||||
*timestamp = Some(*timestamp_new);
|
||||
*last_initiation_consumption = Some(Instant::now());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
199
src/wireguard/handshake/ratelimiter.rs
Normal file
199
src/wireguard/handshake/ratelimiter.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use spin;
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Condvar, Mutex};
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
const PACKETS_PER_SECOND: u64 = 20;
|
||||
const PACKETS_BURSTABLE: u64 = 5;
|
||||
const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND;
|
||||
const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE;
|
||||
|
||||
const GC_INTERVAL: Duration = Duration::from_secs(1);
|
||||
|
||||
struct Entry {
|
||||
pub last_time: Instant,
|
||||
pub tokens: u64,
|
||||
}
|
||||
|
||||
pub struct RateLimiter(Arc<RateLimiterInner>);
|
||||
|
||||
struct RateLimiterInner {
|
||||
gc_running: AtomicBool,
|
||||
gc_dropped: (Mutex<bool>, Condvar),
|
||||
table: spin::RwLock<HashMap<IpAddr, spin::Mutex<Entry>>>,
|
||||
}
|
||||
|
||||
impl Drop for RateLimiter {
|
||||
fn drop(&mut self) {
|
||||
// wake up & terminate any lingering GC thread
|
||||
let &(ref lock, ref cvar) = &self.0.gc_dropped;
|
||||
let mut dropped = lock.lock().unwrap();
|
||||
*dropped = true;
|
||||
cvar.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new() -> Self {
|
||||
RateLimiter(Arc::new(RateLimiterInner {
|
||||
gc_dropped: (Mutex::new(false), Condvar::new()),
|
||||
gc_running: AtomicBool::from(false),
|
||||
table: spin::RwLock::new(HashMap::new()),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn allow(&self, addr: &IpAddr) -> bool {
|
||||
// check if allowed
|
||||
let allowed = {
|
||||
// check for existing entry (only requires read lock)
|
||||
if let Some(entry) = self.0.table.read().get(addr) {
|
||||
// update existing entry
|
||||
let mut entry = entry.lock();
|
||||
|
||||
// add tokens earned since last time
|
||||
entry.tokens = MAX_TOKENS
|
||||
.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
|
||||
entry.last_time = Instant::now();
|
||||
|
||||
// subtract cost of packet
|
||||
if entry.tokens > PACKET_COST {
|
||||
entry.tokens -= PACKET_COST;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// add new entry (write lock)
|
||||
self.0.table.write().insert(
|
||||
*addr,
|
||||
spin::Mutex::new(Entry {
|
||||
last_time: Instant::now(),
|
||||
tokens: MAX_TOKENS - PACKET_COST,
|
||||
}),
|
||||
);
|
||||
true
|
||||
};
|
||||
|
||||
// check that GC thread is scheduled
|
||||
if !self.0.gc_running.swap(true, Ordering::Relaxed) {
|
||||
let limiter = self.0.clone();
|
||||
thread::spawn(move || {
|
||||
let &(ref lock, ref cvar) = &limiter.gc_dropped;
|
||||
let mut dropped = lock.lock().unwrap();
|
||||
while !*dropped {
|
||||
// garbage collect
|
||||
{
|
||||
let mut tw = limiter.table.write();
|
||||
tw.retain(|_, ref mut entry| {
|
||||
entry.lock().last_time.elapsed() <= GC_INTERVAL
|
||||
});
|
||||
if tw.len() == 0 {
|
||||
limiter.gc_running.store(false, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// wait until stopped or new GC (~1 every sec)
|
||||
let res = cvar.wait_timeout(dropped, GC_INTERVAL).unwrap();
|
||||
dropped = res.0;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
allowed
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std;
|
||||
|
||||
struct Result {
|
||||
allowed: bool,
|
||||
text: &'static str,
|
||||
wait: Duration,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ratelimiter() {
|
||||
let ratelimiter = RateLimiter::new();
|
||||
let mut expected = vec![];
|
||||
let ips = vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"192.168.1.1".parse().unwrap(),
|
||||
"172.167.2.3".parse().unwrap(),
|
||||
"97.231.252.215".parse().unwrap(),
|
||||
"248.97.91.167".parse().unwrap(),
|
||||
"188.208.233.47".parse().unwrap(),
|
||||
"104.2.183.179".parse().unwrap(),
|
||||
"72.129.46.120".parse().unwrap(),
|
||||
"2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
|
||||
"f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
|
||||
"b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
|
||||
"a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
|
||||
"ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
|
||||
"3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
|
||||
];
|
||||
|
||||
for _ in 0..PACKETS_BURSTABLE {
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "inital burst",
|
||||
});
|
||||
}
|
||||
|
||||
expected.push(Result {
|
||||
allowed: false,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "after burst",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, PACKET_COST as u32),
|
||||
text: "filling tokens for single packet",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: false,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "not having refilled enough",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, 2 * PACKET_COST as u32),
|
||||
text: "filling tokens for 2 * packet burst",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "second packet in 2 packet burst",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: false,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "packet following 2 packet burst",
|
||||
});
|
||||
|
||||
for item in expected {
|
||||
std::thread::sleep(item.wait);
|
||||
for ip in ips.iter() {
|
||||
if ratelimiter.allow(&ip) != item.allowed {
|
||||
panic!(
|
||||
"test failed for {} on {}. expected: {}, got: {}",
|
||||
ip, item.text, item.allowed, !item.allowed
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
32
src/wireguard/handshake/timestamp.rs
Normal file
32
src/wireguard/handshake/timestamp.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
pub type TAI64N = [u8; 12];
|
||||
|
||||
const TAI64_EPOCH: u64 = 0x400000000000000a;
|
||||
|
||||
pub const ZERO: TAI64N = [0u8; 12];
|
||||
|
||||
pub fn now() -> TAI64N {
|
||||
// get system time as duration
|
||||
let sysnow = SystemTime::now();
|
||||
let delta = sysnow.duration_since(UNIX_EPOCH).unwrap();
|
||||
|
||||
// convert to tai64n
|
||||
let tai64_secs = delta.as_secs() + TAI64_EPOCH;
|
||||
let tai64_nano = delta.subsec_nanos();
|
||||
|
||||
// serialize
|
||||
let mut res = [0u8; 12];
|
||||
res[..8].copy_from_slice(&tai64_secs.to_be_bytes()[..]);
|
||||
res[8..].copy_from_slice(&tai64_nano.to_be_bytes()[..]);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn compare(old: &TAI64N, new: &TAI64N) -> bool {
|
||||
for i in 0..12 {
|
||||
if new[i] > old[i] {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
90
src/wireguard/handshake/types.rs
Normal file
90
src/wireguard/handshake/types.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
use x25519_dalek::PublicKey;
|
||||
|
||||
use super::super::types::KeyPair;
|
||||
|
||||
/* Internal types for the noise IKpsk2 implementation */
|
||||
|
||||
// config error
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConfigError(String);
|
||||
|
||||
impl ConfigError {
|
||||
pub fn new(s: &str) -> Self {
|
||||
ConfigError(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ConfigError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "ConfigError({})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for ConfigError {
|
||||
fn description(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// handshake error
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum HandshakeError {
|
||||
DecryptionFailure,
|
||||
UnknownPublicKey,
|
||||
UnknownReceiverId,
|
||||
InvalidMessageFormat,
|
||||
OldTimestamp,
|
||||
InvalidState,
|
||||
InvalidMac1,
|
||||
RateLimited,
|
||||
InitiationFlood,
|
||||
}
|
||||
|
||||
impl fmt::Display for HandshakeError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
HandshakeError::DecryptionFailure => write!(f, "Failed to AEAD:OPEN"),
|
||||
HandshakeError::UnknownPublicKey => write!(f, "Unknown public key"),
|
||||
HandshakeError::UnknownReceiverId => {
|
||||
write!(f, "Receiver id not allocated to any handshake")
|
||||
}
|
||||
HandshakeError::InvalidMessageFormat => write!(f, "Invalid handshake message format"),
|
||||
HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
|
||||
HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
|
||||
HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
|
||||
HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter"),
|
||||
HandshakeError::InitiationFlood => {
|
||||
write!(f, "Message was dropped because of initiation flood")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for HandshakeError {
|
||||
fn description(&self) -> &str {
|
||||
"Generic Handshake Error"
|
||||
}
|
||||
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub type Output = (
|
||||
Option<PublicKey>, // external identifier associated with peer
|
||||
Option<Vec<u8>>, // message to send
|
||||
Option<KeyPair>, // resulting key-pair of successful handshake
|
||||
);
|
||||
|
||||
// preshared key
|
||||
|
||||
pub type Psk = [u8; 32];
|
||||
23
src/wireguard/mod.rs
Normal file
23
src/wireguard/mod.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
mod wireguard;
|
||||
// mod config;
|
||||
mod constants;
|
||||
mod timers;
|
||||
|
||||
mod handshake;
|
||||
mod router;
|
||||
mod types;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// The WireGuard sub-module contains a pure, configurable implementation of WireGuard.
|
||||
/// The implementation is generic over:
|
||||
///
|
||||
/// - TUN type, specifying how packets are received on the interface side: a reader/writer and MTU reporting interface.
|
||||
/// - Bind type, specifying how WireGuard messages are sent/received from the internet and what constitutes an "endpoint"
|
||||
|
||||
pub use wireguard::{Wireguard, Peer};
|
||||
|
||||
pub use types::bind;
|
||||
pub use types::tun;
|
||||
pub use types::Endpoint;
|
||||
157
src/wireguard/router/anti_replay.rs
Normal file
157
src/wireguard/router/anti_replay.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
use std::mem;
|
||||
|
||||
// Implementation of RFC 6479.
|
||||
// https://tools.ietf.org/html/rfc6479
|
||||
|
||||
#[cfg(target_pointer_width = "64")]
|
||||
type Word = u64;
|
||||
|
||||
#[cfg(target_pointer_width = "64")]
|
||||
const REDUNDANT_BIT_SHIFTS: usize = 6;
|
||||
|
||||
#[cfg(target_pointer_width = "32")]
|
||||
type Word = u32;
|
||||
|
||||
#[cfg(target_pointer_width = "32")]
|
||||
const REDUNDANT_BIT_SHIFTS: usize = 5;
|
||||
|
||||
const SIZE_OF_WORD: usize = mem::size_of::<Word>() * 8;
|
||||
|
||||
const BITMAP_BITLEN: usize = 2048;
|
||||
const BITMAP_LEN: usize = (BITMAP_BITLEN / SIZE_OF_WORD);
|
||||
const BITMAP_INDEX_MASK: u64 = BITMAP_LEN as u64 - 1;
|
||||
const BITMAP_LOC_MASK: u64 = (SIZE_OF_WORD - 1) as u64;
|
||||
const WINDOW_SIZE: u64 = (BITMAP_BITLEN - SIZE_OF_WORD) as u64;
|
||||
|
||||
pub struct AntiReplay {
|
||||
bitmap: [Word; BITMAP_LEN],
|
||||
last: u64,
|
||||
}
|
||||
|
||||
impl Default for AntiReplay {
|
||||
fn default() -> Self {
|
||||
AntiReplay::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl AntiReplay {
|
||||
pub fn new() -> Self {
|
||||
debug_assert_eq!(1 << REDUNDANT_BIT_SHIFTS, SIZE_OF_WORD);
|
||||
debug_assert_eq!(BITMAP_BITLEN % SIZE_OF_WORD, 0);
|
||||
AntiReplay {
|
||||
last: 0,
|
||||
bitmap: [0; BITMAP_LEN],
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if check is passed, i.e., not a replay or too old.
|
||||
//
|
||||
// Unlike RFC 6479, zero is allowed.
|
||||
fn check(&self, seq: u64) -> bool {
|
||||
// Larger is always good.
|
||||
if seq > self.last {
|
||||
return true;
|
||||
}
|
||||
|
||||
if self.last - seq > WINDOW_SIZE {
|
||||
return false;
|
||||
}
|
||||
|
||||
let bit_location = seq & BITMAP_LOC_MASK;
|
||||
let index = (seq >> REDUNDANT_BIT_SHIFTS) & BITMAP_INDEX_MASK;
|
||||
|
||||
self.bitmap[index as usize] & (1 << bit_location) == 0
|
||||
}
|
||||
|
||||
// Should only be called if check returns true.
|
||||
fn update_store(&mut self, seq: u64) {
|
||||
debug_assert!(self.check(seq));
|
||||
|
||||
let index = seq >> REDUNDANT_BIT_SHIFTS;
|
||||
|
||||
if seq > self.last {
|
||||
let index_cur = self.last >> REDUNDANT_BIT_SHIFTS;
|
||||
let diff = index - index_cur;
|
||||
|
||||
if diff >= BITMAP_LEN as u64 {
|
||||
self.bitmap = [0; BITMAP_LEN];
|
||||
} else {
|
||||
for i in 0..diff {
|
||||
let real_index = (index_cur + i + 1) & BITMAP_INDEX_MASK;
|
||||
self.bitmap[real_index as usize] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
self.last = seq;
|
||||
}
|
||||
|
||||
let index = index & BITMAP_INDEX_MASK;
|
||||
let bit_location = seq & BITMAP_LOC_MASK;
|
||||
self.bitmap[index as usize] |= 1 << bit_location;
|
||||
}
|
||||
|
||||
/// Checks and marks a sequence number in the replay filter
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - seq: Sequence number check for replay and add to filter
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Ok(()) if sequence number is valid (not marked and not behind the moving window).
|
||||
/// Err if the sequence number is invalid (already marked or "too old").
|
||||
pub fn update(&mut self, seq: u64) -> bool {
|
||||
if self.check(seq) {
|
||||
self.update_store(seq);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn anti_replay() {
|
||||
let mut ar = AntiReplay::new();
|
||||
|
||||
for i in 0..20000 {
|
||||
assert!(ar.update(i));
|
||||
}
|
||||
|
||||
for i in (0..20000).rev() {
|
||||
assert!(!ar.check(i));
|
||||
}
|
||||
|
||||
assert!(ar.update(65536));
|
||||
for i in (65536 - WINDOW_SIZE)..65535 {
|
||||
assert!(ar.update(i));
|
||||
}
|
||||
|
||||
for i in (65536 - 10 * WINDOW_SIZE)..65535 {
|
||||
assert!(!ar.check(i));
|
||||
}
|
||||
|
||||
assert!(ar.update(66000));
|
||||
for i in 65537..66000 {
|
||||
assert!(ar.update(i));
|
||||
}
|
||||
for i in 65537..66000 {
|
||||
assert_eq!(ar.update(i), false);
|
||||
}
|
||||
|
||||
// Test max u64.
|
||||
let next = u64::max_value();
|
||||
assert!(ar.update(next));
|
||||
assert!(!ar.check(next));
|
||||
for i in (next - WINDOW_SIZE)..next {
|
||||
assert!(ar.update(i));
|
||||
}
|
||||
for i in (next - 20 * WINDOW_SIZE)..next {
|
||||
assert!(!ar.check(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
7
src/wireguard/router/constants.rs
Normal file
7
src/wireguard/router/constants.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
// WireGuard semantics constants
|
||||
|
||||
pub const MAX_STAGED_PACKETS: usize = 128;
|
||||
|
||||
// performance constants
|
||||
|
||||
pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS;
|
||||
243
src/wireguard/router/device.rs
Normal file
243
src/wireguard/router/device.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::mpsc::sync_channel;
|
||||
use std::sync::mpsc::SyncSender;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
use log::debug;
|
||||
use spin::{Mutex, RwLock};
|
||||
use treebitmap::IpLookupTable;
|
||||
use zerocopy::LayoutVerified;
|
||||
|
||||
use super::anti_replay::AntiReplay;
|
||||
use super::constants::*;
|
||||
use super::ip::*;
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::peer::{new_peer, Peer, PeerInner};
|
||||
use super::types::{Callbacks, RouterError};
|
||||
use super::workers::{worker_parallel, JobParallel, Operation};
|
||||
use super::SIZE_MESSAGE_PREFIX;
|
||||
|
||||
use super::super::types::{bind, tun, Endpoint, KeyPair};
|
||||
|
||||
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
// inbound writer (TUN)
|
||||
pub inbound: T,
|
||||
|
||||
// outbound writer (Bind)
|
||||
pub outbound: RwLock<Option<B>>,
|
||||
|
||||
// routing
|
||||
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
|
||||
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
|
||||
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
|
||||
|
||||
// work queues
|
||||
pub queue_next: AtomicUsize, // next round-robin index
|
||||
pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
||||
}
|
||||
|
||||
pub struct EncryptionState {
|
||||
pub key: [u8; 32], // encryption key
|
||||
pub id: u32, // receiver id
|
||||
pub nonce: u64, // next available nonce
|
||||
pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
|
||||
}
|
||||
|
||||
pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
pub keypair: Arc<KeyPair>,
|
||||
pub confirmed: AtomicBool,
|
||||
pub protector: Mutex<AntiReplay>,
|
||||
pub peer: Arc<PeerInner<E, C, T, B>>,
|
||||
pub death: Instant, // time when the key can no longer be used for decryption
|
||||
}
|
||||
|
||||
pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
|
||||
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> {
|
||||
fn drop(&mut self) {
|
||||
debug!("router: dropping device");
|
||||
|
||||
// drop all queues
|
||||
{
|
||||
let mut queues = self.state.queues.lock();
|
||||
while queues.pop().is_some() {}
|
||||
}
|
||||
|
||||
// join all worker threads
|
||||
while match self.handles.pop() {
|
||||
Some(handle) => {
|
||||
handle.thread().unpark();
|
||||
handle.join().unwrap();
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
} {}
|
||||
|
||||
debug!("router: device dropped");
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: &Arc<DeviceInner<E, C, T, B>>,
|
||||
packet: &[u8],
|
||||
) -> Option<Arc<PeerInner<E, C, T, B>>> {
|
||||
// ensure version access within bounds
|
||||
if packet.len() < 1 {
|
||||
return None;
|
||||
};
|
||||
|
||||
// cast to correct IP header
|
||||
match packet[0] >> 4 {
|
||||
VERSION_IP4 => {
|
||||
// check length and cast to IPv4 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
|
||||
LayoutVerified::new_from_prefix(packet)?;
|
||||
|
||||
// lookup destination address
|
||||
device
|
||||
.ipv4
|
||||
.read()
|
||||
.longest_match(Ipv4Addr::from(header.f_destination))
|
||||
.and_then(|(_, _, p)| Some(p.clone()))
|
||||
}
|
||||
VERSION_IP6 => {
|
||||
// check length and cast to IPv6 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
|
||||
LayoutVerified::new_from_prefix(packet)?;
|
||||
|
||||
// lookup destination address
|
||||
device
|
||||
.ipv6
|
||||
.read()
|
||||
.longest_match(Ipv6Addr::from(header.f_destination))
|
||||
.and_then(|(_, _, p)| Some(p.clone()))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
|
||||
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
|
||||
// allocate shared device state
|
||||
let inner = DeviceInner {
|
||||
inbound: tun,
|
||||
outbound: RwLock::new(None),
|
||||
queues: Mutex::new(Vec::with_capacity(num_workers)),
|
||||
queue_next: AtomicUsize::new(0),
|
||||
recv: RwLock::new(HashMap::new()),
|
||||
ipv4: RwLock::new(IpLookupTable::new()),
|
||||
ipv6: RwLock::new(IpLookupTable::new()),
|
||||
};
|
||||
|
||||
// start worker threads
|
||||
let mut threads = Vec::with_capacity(num_workers);
|
||||
for _ in 0..num_workers {
|
||||
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
|
||||
inner.queues.lock().push(tx);
|
||||
threads.push(thread::spawn(move || worker_parallel(rx)));
|
||||
}
|
||||
|
||||
// return exported device handle
|
||||
Device {
|
||||
state: Arc::new(inner),
|
||||
handles: threads,
|
||||
}
|
||||
}
|
||||
|
||||
/// A new secret key has been set for the device.
|
||||
/// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
|
||||
pub fn new_sk(&self) {}
|
||||
|
||||
/// Adds a new peer to the device
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A atomic ref. counted peer (with liftime matching the device)
|
||||
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
|
||||
new_peer(self.state.clone(), opaque)
|
||||
}
|
||||
|
||||
/// Cryptkey routes and sends a plaintext message (IP packet)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - msg: IP packet to crypt-key route
|
||||
///
|
||||
pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
// ignore header prefix (for in-place transport message construction)
|
||||
let packet = &msg[SIZE_MESSAGE_PREFIX..];
|
||||
|
||||
// lookup peer based on IP packet destination address
|
||||
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
|
||||
|
||||
// schedule for encryption and transmission to peer
|
||||
if let Some(job) = peer.send_job(msg, true) {
|
||||
debug_assert_eq!(job.1.op, Operation::Encryption);
|
||||
|
||||
// add job to worker queue
|
||||
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
let queues = self.state.queues.lock();
|
||||
queues[idx % queues.len()].send(job).unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive an encrypted transport message
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - src: Source address of the packet
|
||||
/// - msg: Encrypted transport message
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
///
|
||||
pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
// parse / cast
|
||||
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
return Err(RouterError::MalformedTransportMessage);
|
||||
}
|
||||
};
|
||||
let header: LayoutVerified<&[u8], TransportHeader> = header;
|
||||
debug_assert!(
|
||||
header.f_type.get() == TYPE_TRANSPORT as u32,
|
||||
"this should be checked by the message type multiplexer"
|
||||
);
|
||||
|
||||
// lookup peer based on receiver id
|
||||
let dec = self.state.recv.read();
|
||||
let dec = dec
|
||||
.get(&header.f_receiver.get())
|
||||
.ok_or(RouterError::UnknownReceiverId)?;
|
||||
|
||||
// schedule for decryption and TUN write
|
||||
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
|
||||
debug_assert_eq!(job.1.op, Operation::Decryption);
|
||||
|
||||
// add job to worker queue
|
||||
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
let queues = self.state.queues.lock();
|
||||
queues[idx % queues.len()].send(job).unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set outbound writer
|
||||
///
|
||||
///
|
||||
pub fn set_outbound_writer(&self, new: B) {
|
||||
*self.state.outbound.write() = Some(new);
|
||||
}
|
||||
}
|
||||
26
src/wireguard/router/ip.rs
Normal file
26
src/wireguard/router/ip.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use byteorder::BigEndian;
|
||||
use zerocopy::byteorder::U16;
|
||||
use zerocopy::{AsBytes, FromBytes};
|
||||
|
||||
pub const VERSION_IP4: u8 = 4;
|
||||
pub const VERSION_IP6: u8 = 6;
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct IPv4Header {
|
||||
_f_space1: [u8; 2],
|
||||
pub f_total_len: U16<BigEndian>,
|
||||
_f_space2: [u8; 8],
|
||||
pub f_source: [u8; 4],
|
||||
pub f_destination: [u8; 4],
|
||||
}
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct IPv6Header {
|
||||
_f_space1: [u8; 4],
|
||||
pub f_len: U16<BigEndian>,
|
||||
_f_space2: [u8; 2],
|
||||
pub f_source: [u8; 16],
|
||||
pub f_destination: [u8; 16],
|
||||
}
|
||||
13
src/wireguard/router/messages.rs
Normal file
13
src/wireguard/router/messages.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use byteorder::LittleEndian;
|
||||
use zerocopy::byteorder::{U32, U64};
|
||||
use zerocopy::{AsBytes, FromBytes};
|
||||
|
||||
pub const TYPE_TRANSPORT: u32 = 4;
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct TransportHeader {
|
||||
pub f_type: U32<LittleEndian>,
|
||||
pub f_receiver: U32<LittleEndian>,
|
||||
pub f_counter: U64<LittleEndian>,
|
||||
}
|
||||
22
src/wireguard/router/mod.rs
Normal file
22
src/wireguard/router/mod.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
mod anti_replay;
|
||||
mod constants;
|
||||
mod device;
|
||||
mod ip;
|
||||
mod messages;
|
||||
mod peer;
|
||||
mod types;
|
||||
mod workers;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use messages::TransportHeader;
|
||||
use std::mem;
|
||||
|
||||
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
|
||||
pub const CAPACITY_MESSAGE_POSTFIX: usize = 16;
|
||||
|
||||
pub use messages::TYPE_TRANSPORT;
|
||||
pub use device::Device;
|
||||
pub use peer::Peer;
|
||||
pub use types::Callbacks;
|
||||
611
src/wireguard/router/peer.rs
Normal file
611
src/wireguard/router/peer.rs
Normal file
@@ -0,0 +1,611 @@
|
||||
use std::mem;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::mpsc::{sync_channel, SyncSender};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
use arraydeque::{ArrayDeque, Wrapping};
|
||||
use log::debug;
|
||||
use spin::Mutex;
|
||||
use treebitmap::address::Address;
|
||||
use treebitmap::IpLookupTable;
|
||||
use zerocopy::LayoutVerified;
|
||||
|
||||
use super::super::constants::*;
|
||||
use super::super::types::{bind, tun, Endpoint, KeyPair};
|
||||
|
||||
use super::anti_replay::AntiReplay;
|
||||
use super::device::DecryptionState;
|
||||
use super::device::DeviceInner;
|
||||
use super::device::EncryptionState;
|
||||
use super::messages::TransportHeader;
|
||||
|
||||
use futures::*;
|
||||
|
||||
use super::workers::Operation;
|
||||
use super::workers::{worker_inbound, worker_outbound};
|
||||
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
|
||||
use super::SIZE_MESSAGE_PREFIX;
|
||||
|
||||
use super::constants::*;
|
||||
use super::types::{Callbacks, RouterError};
|
||||
|
||||
pub struct KeyWheel {
|
||||
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
|
||||
current: Option<Arc<KeyPair>>, // current key state (used for encryption)
|
||||
previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
|
||||
retired: Vec<u32>, // retired ids
|
||||
}
|
||||
|
||||
pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
pub device: Arc<DeviceInner<E, C, T, B>>,
|
||||
pub opaque: C::Opaque,
|
||||
pub outbound: Mutex<SyncSender<JobOutbound>>,
|
||||
pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>,
|
||||
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>,
|
||||
pub keys: Mutex<KeyWheel>,
|
||||
pub ekey: Mutex<Option<EncryptionState>>,
|
||||
pub endpoint: Mutex<Option<E>>,
|
||||
}
|
||||
|
||||
pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
state: Arc<PeerInner<E, C, T, B>>,
|
||||
thread_outbound: Option<thread::JoinHandle<()>>,
|
||||
thread_inbound: Option<thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
fn treebit_list<A, R, E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
peer: &Arc<PeerInner<E, C, T, B>>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
|
||||
callback: Box<dyn Fn(A, u32) -> R>,
|
||||
) -> Vec<R>
|
||||
where
|
||||
A: Address,
|
||||
{
|
||||
let mut res = Vec::new();
|
||||
for subnet in table.read().iter() {
|
||||
let (ip, masklen, p) = subnet;
|
||||
if Arc::ptr_eq(&p, &peer) {
|
||||
res.push(callback(ip, masklen))
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn treebit_remove<E: Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
peer: &Peer<E, C, T, B>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
|
||||
) {
|
||||
let mut m = table.write();
|
||||
|
||||
// collect keys for value
|
||||
let mut subnets = vec![];
|
||||
for subnet in m.iter() {
|
||||
let (ip, masklen, p) = subnet;
|
||||
if Arc::ptr_eq(&p, &peer.state) {
|
||||
subnets.push((ip, masklen))
|
||||
}
|
||||
}
|
||||
|
||||
// remove all key mappings
|
||||
for (ip, masklen) in subnets {
|
||||
let r = m.remove(ip, masklen);
|
||||
debug_assert!(r.is_some());
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptionState {
|
||||
fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
|
||||
EncryptionState {
|
||||
id: keypair.send.id,
|
||||
key: keypair.send.key,
|
||||
nonce: 0,
|
||||
death: keypair.birth + REJECT_AFTER_TIME,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> {
|
||||
fn new(
|
||||
peer: &Arc<PeerInner<E, C, T, B>>,
|
||||
keypair: &Arc<KeyPair>,
|
||||
) -> DecryptionState<E, C, T, B> {
|
||||
DecryptionState {
|
||||
confirmed: AtomicBool::new(keypair.initiator),
|
||||
keypair: keypair.clone(),
|
||||
protector: spin::Mutex::new(AntiReplay::new()),
|
||||
peer: peer.clone(),
|
||||
death: keypair.birth + REJECT_AFTER_TIME,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> {
|
||||
fn drop(&mut self) {
|
||||
let peer = &self.state;
|
||||
|
||||
// remove from cryptkey router
|
||||
|
||||
treebit_remove(self, &peer.device.ipv4);
|
||||
treebit_remove(self, &peer.device.ipv6);
|
||||
|
||||
// drop channels
|
||||
|
||||
mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0);
|
||||
mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0);
|
||||
|
||||
// join with workers
|
||||
|
||||
mem::replace(&mut self.thread_inbound, None).map(|v| v.join());
|
||||
mem::replace(&mut self.thread_outbound, None).map(|v| v.join());
|
||||
|
||||
// release ids from the receiver map
|
||||
|
||||
let mut keys = peer.keys.lock();
|
||||
let mut release = Vec::with_capacity(3);
|
||||
|
||||
keys.next.as_ref().map(|k| release.push(k.recv.id));
|
||||
keys.current.as_ref().map(|k| release.push(k.recv.id));
|
||||
keys.previous.as_ref().map(|k| release.push(k.recv.id));
|
||||
|
||||
if release.len() > 0 {
|
||||
let mut recv = peer.device.recv.write();
|
||||
for id in &release {
|
||||
recv.remove(id);
|
||||
}
|
||||
}
|
||||
|
||||
// null key-material
|
||||
|
||||
keys.next = None;
|
||||
keys.current = None;
|
||||
keys.previous = None;
|
||||
|
||||
*peer.ekey.lock() = None;
|
||||
*peer.endpoint.lock() = None;
|
||||
|
||||
debug!("peer dropped & removed from device");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: Arc<DeviceInner<E, C, T, B>>,
|
||||
opaque: C::Opaque,
|
||||
) -> Peer<E, C, T, B> {
|
||||
let (out_tx, out_rx) = sync_channel(128);
|
||||
let (in_tx, in_rx) = sync_channel(128);
|
||||
|
||||
// allocate peer object
|
||||
let peer = {
|
||||
let device = device.clone();
|
||||
Arc::new(PeerInner {
|
||||
opaque,
|
||||
device,
|
||||
inbound: Mutex::new(in_tx),
|
||||
outbound: Mutex::new(out_tx),
|
||||
ekey: spin::Mutex::new(None),
|
||||
endpoint: spin::Mutex::new(None),
|
||||
keys: spin::Mutex::new(KeyWheel {
|
||||
next: None,
|
||||
current: None,
|
||||
previous: None,
|
||||
retired: vec![],
|
||||
}),
|
||||
staged_packets: spin::Mutex::new(ArrayDeque::new()),
|
||||
})
|
||||
};
|
||||
|
||||
// spawn outbound thread
|
||||
let thread_inbound = {
|
||||
let peer = peer.clone();
|
||||
let device = device.clone();
|
||||
thread::spawn(move || worker_outbound(device, peer, out_rx))
|
||||
};
|
||||
|
||||
// spawn inbound thread
|
||||
let thread_outbound = {
|
||||
let peer = peer.clone();
|
||||
let device = device.clone();
|
||||
thread::spawn(move || worker_inbound(device, peer, in_rx))
|
||||
};
|
||||
|
||||
Peer {
|
||||
state: peer,
|
||||
thread_inbound: Some(thread_inbound),
|
||||
thread_outbound: Some(thread_outbound),
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> {
|
||||
fn send_staged(&self) -> bool {
|
||||
debug!("peer.send_staged");
|
||||
let mut sent = false;
|
||||
let mut staged = self.staged_packets.lock();
|
||||
loop {
|
||||
match staged.pop_front() {
|
||||
Some(msg) => {
|
||||
sent = true;
|
||||
self.send_raw(msg);
|
||||
}
|
||||
None => break sent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Treat the msg as the payload of a transport message
|
||||
// Unlike device.send, peer.send_raw does not buffer messages when a key is not available.
|
||||
fn send_raw(&self, msg: Vec<u8>) -> bool {
|
||||
debug!("peer.send_raw");
|
||||
match self.send_job(msg, false) {
|
||||
Some(job) => {
|
||||
debug!("send_raw: got obtained send_job");
|
||||
let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
let queues = self.device.queues.lock();
|
||||
match queues[index % queues.len()].send(job) {
|
||||
Ok(_) => true,
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
|
||||
debug!("peer.confirm_key");
|
||||
{
|
||||
// take lock and check keypair = keys.next
|
||||
let mut keys = self.keys.lock();
|
||||
let next = match keys.next.as_ref() {
|
||||
Some(next) => next,
|
||||
None => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
if !Arc::ptr_eq(&next, keypair) {
|
||||
return;
|
||||
}
|
||||
|
||||
// allocate new encryption state
|
||||
let ekey = Some(EncryptionState::new(&next));
|
||||
|
||||
// rotate key-wheel
|
||||
let mut swap = None;
|
||||
mem::swap(&mut keys.next, &mut swap);
|
||||
mem::swap(&mut keys.current, &mut swap);
|
||||
mem::swap(&mut keys.previous, &mut swap);
|
||||
|
||||
// tell the world outside the router that a key was confirmed
|
||||
C::key_confirmed(&self.opaque);
|
||||
|
||||
// set new key for encryption
|
||||
*self.ekey.lock() = ekey;
|
||||
}
|
||||
|
||||
// start transmission of staged packets
|
||||
self.send_staged();
|
||||
}
|
||||
|
||||
pub fn recv_job(
|
||||
&self,
|
||||
src: E,
|
||||
dec: Arc<DecryptionState<E, C, T, B>>,
|
||||
msg: Vec<u8>,
|
||||
) -> Option<JobParallel> {
|
||||
let (tx, rx) = oneshot();
|
||||
let key = dec.keypair.recv.key;
|
||||
match self.inbound.lock().try_send((dec, src, rx)) {
|
||||
Ok(_) => Some((
|
||||
tx,
|
||||
JobBuffer {
|
||||
msg,
|
||||
key: key,
|
||||
okay: false,
|
||||
op: Operation::Decryption,
|
||||
},
|
||||
)),
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_job(&self, mut msg: Vec<u8>, stage: bool) -> Option<JobParallel> {
|
||||
debug!("peer.send_job");
|
||||
debug_assert!(
|
||||
msg.len() >= mem::size_of::<TransportHeader>(),
|
||||
"received message with size: {:}",
|
||||
msg.len()
|
||||
);
|
||||
|
||||
// parse / cast
|
||||
let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap();
|
||||
let mut header: LayoutVerified<&mut [u8], TransportHeader> = header;
|
||||
|
||||
// check if has key
|
||||
let key = {
|
||||
let mut ekey = self.ekey.lock();
|
||||
let key = match ekey.as_mut() {
|
||||
None => None,
|
||||
Some(mut state) => {
|
||||
// avoid integer overflow in nonce
|
||||
if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
|
||||
*ekey = None;
|
||||
None
|
||||
} else {
|
||||
// there should be no stacked packets lingering around
|
||||
debug!("encryption state available, nonce = {}", state.nonce);
|
||||
|
||||
// set transport message fields
|
||||
header.f_counter.set(state.nonce);
|
||||
header.f_receiver.set(state.id);
|
||||
state.nonce += 1;
|
||||
Some(state.key)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// If not suitable key was found:
|
||||
// 1. Stage packet for later transmission
|
||||
// 2. Request new key
|
||||
if key.is_none() && stage {
|
||||
self.staged_packets.lock().push_back(msg);
|
||||
C::need_key(&self.opaque);
|
||||
return None;
|
||||
};
|
||||
|
||||
key
|
||||
}?;
|
||||
|
||||
// add job to in-order queue and return sendeer to device for inclusion in worker pool
|
||||
let (tx, rx) = oneshot();
|
||||
match self.outbound.lock().try_send(rx) {
|
||||
Ok(_) => Some((
|
||||
tx,
|
||||
JobBuffer {
|
||||
msg,
|
||||
key,
|
||||
okay: false,
|
||||
op: Operation::Encryption,
|
||||
},
|
||||
)),
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> {
|
||||
/// Set the endpoint of the peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `endpoint`, socket address converted to bind endpoint
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This API still permits support for the "sticky socket" behavior,
|
||||
/// as sockets should be "unsticked" when manually updating the endpoint
|
||||
pub fn set_endpoint(&self, endpoint: E) {
|
||||
debug!("peer.set_endpoint");
|
||||
*self.state.endpoint.lock() = Some(endpoint);
|
||||
}
|
||||
|
||||
/// Returns the current endpoint of the peer (for configuration)
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// Does not convey potential "sticky socket" information
|
||||
pub fn get_endpoint(&self) -> Option<SocketAddr> {
|
||||
debug!("peer.get_endpoint");
|
||||
self.state
|
||||
.endpoint
|
||||
.lock()
|
||||
.as_ref()
|
||||
.map(|e| e.into_address())
|
||||
}
|
||||
|
||||
/// Zero all key-material related to the peer
|
||||
pub fn zero_keys(&self) {
|
||||
debug!("peer.zero_keys");
|
||||
|
||||
let mut release: Vec<u32> = Vec::with_capacity(3);
|
||||
let mut keys = self.state.keys.lock();
|
||||
|
||||
// update key-wheel
|
||||
|
||||
mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id()));
|
||||
mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id()));
|
||||
mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id()));
|
||||
keys.retired.extend(&release[..]);
|
||||
|
||||
// update inbound "recv" map
|
||||
{
|
||||
let mut recv = self.state.device.recv.write();
|
||||
for id in release {
|
||||
recv.remove(&id);
|
||||
}
|
||||
}
|
||||
|
||||
// clear encryption state
|
||||
*self.state.ekey.lock() = None;
|
||||
}
|
||||
|
||||
/// Add a new keypair
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - new: The new confirmed/unconfirmed key pair
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of ids which has been released.
|
||||
/// These should be released in the handshake module.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The number of ids to be released can be at most 3,
|
||||
/// since the only way to add additional keys to the peer is by using this method
|
||||
/// and a peer can have at most 3 keys allocated in the router at any time.
|
||||
pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
|
||||
debug!("peer.add_keypair");
|
||||
|
||||
let initiator = new.initiator;
|
||||
let release = {
|
||||
let new = Arc::new(new);
|
||||
let mut keys = self.state.keys.lock();
|
||||
let mut release = mem::replace(&mut keys.retired, vec![]);
|
||||
|
||||
// update key-wheel
|
||||
if new.initiator {
|
||||
// start using key for encryption
|
||||
*self.state.ekey.lock() = Some(EncryptionState::new(&new));
|
||||
|
||||
// move current into previous
|
||||
keys.previous = keys.current.as_ref().map(|v| v.clone());
|
||||
keys.current = Some(new.clone());
|
||||
} else {
|
||||
// store the key and await confirmation
|
||||
keys.previous = keys.next.as_ref().map(|v| v.clone());
|
||||
keys.next = Some(new.clone());
|
||||
};
|
||||
|
||||
// update incoming packet id map
|
||||
{
|
||||
debug!("peer.add_keypair: updating inbound id map");
|
||||
let mut recv = self.state.device.recv.write();
|
||||
|
||||
// purge recv map of previous id
|
||||
keys.previous.as_ref().map(|k| {
|
||||
recv.remove(&k.local_id());
|
||||
release.push(k.local_id());
|
||||
});
|
||||
|
||||
// map new id to decryption state
|
||||
debug_assert!(!recv.contains_key(&new.recv.id));
|
||||
recv.insert(
|
||||
new.recv.id,
|
||||
Arc::new(DecryptionState::new(&self.state, &new)),
|
||||
);
|
||||
}
|
||||
release
|
||||
};
|
||||
|
||||
// schedule confirmation
|
||||
if initiator {
|
||||
debug_assert!(self.state.ekey.lock().is_some());
|
||||
debug!("peer.add_keypair: is initiator, must confirm the key");
|
||||
// attempt to confirm using staged packets
|
||||
if !self.state.send_staged() {
|
||||
// fall back to keepalive packet
|
||||
let ok = self.send_keepalive();
|
||||
debug!(
|
||||
"peer.add_keypair: keepalive for confirmation, sent = {}",
|
||||
ok
|
||||
);
|
||||
}
|
||||
debug!("peer.add_keypair: key attempted confirmed");
|
||||
}
|
||||
|
||||
debug_assert!(
|
||||
release.len() <= 3,
|
||||
"since the key-wheel contains at most 3 keys"
|
||||
);
|
||||
release
|
||||
}
|
||||
|
||||
pub fn send_keepalive(&self) -> bool {
|
||||
debug!("peer.send_keepalive");
|
||||
self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
|
||||
}
|
||||
|
||||
/// Map a subnet to the peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `ip`, the mask of the subnet
|
||||
/// - `masklen`, the length of the mask
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The `ip` must not have any bits set right of `masklen`.
|
||||
/// e.g. `192.168.1.0/24` is valid, while `192.168.1.128/24` is not.
|
||||
///
|
||||
/// If an identical value already exists as part of a prior peer,
|
||||
/// the allowed IP entry will be removed from that peer and added to this peer.
|
||||
pub fn add_subnet(&self, ip: IpAddr, masklen: u32) {
|
||||
debug!("peer.add_subnet");
|
||||
match ip {
|
||||
IpAddr::V4(v4) => {
|
||||
self.state
|
||||
.device
|
||||
.ipv4
|
||||
.write()
|
||||
.insert(v4, masklen, self.state.clone())
|
||||
}
|
||||
IpAddr::V6(v6) => {
|
||||
self.state
|
||||
.device
|
||||
.ipv6
|
||||
.write()
|
||||
.insert(v6, masklen, self.state.clone())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// List subnets mapped to the peer
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of subnets, represented by as mask/size
|
||||
pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> {
|
||||
debug!("peer.list_subnets");
|
||||
let mut res = Vec::new();
|
||||
res.append(&mut treebit_list(
|
||||
&self.state,
|
||||
&self.state.device.ipv4,
|
||||
Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)),
|
||||
));
|
||||
res.append(&mut treebit_list(
|
||||
&self.state,
|
||||
&self.state.device.ipv6,
|
||||
Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)),
|
||||
));
|
||||
res
|
||||
}
|
||||
|
||||
/// Clear subnets mapped to the peer.
|
||||
/// After the call, no subnets will be cryptkey routed to the peer.
|
||||
/// Used for the UAPI command "replace_allowed_ips=true"
|
||||
pub fn remove_subnets(&self) {
|
||||
debug!("peer.remove_subnets");
|
||||
treebit_remove(self, &self.state.device.ipv4);
|
||||
treebit_remove(self, &self.state.device.ipv6);
|
||||
}
|
||||
|
||||
/// Send a raw message to the peer (used for handshake messages)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `msg`, message body to send to peer
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Unit if packet was sent, or an error indicating why sending failed
|
||||
pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
|
||||
debug!("peer.send");
|
||||
let inner = &self.state;
|
||||
match inner.endpoint.lock().as_ref() {
|
||||
Some(endpoint) => inner
|
||||
.device
|
||||
.outbound
|
||||
.read()
|
||||
.as_ref()
|
||||
.ok_or(RouterError::SendError)
|
||||
.and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)),
|
||||
None => Err(RouterError::NoEndpoint),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn purge_staged_packets(&self) {
|
||||
self.state.staged_packets.lock().clear();
|
||||
}
|
||||
}
|
||||
432
src/wireguard/router/tests.rs
Normal file
432
src/wireguard/router/tests.rs
Normal file
@@ -0,0 +1,432 @@
|
||||
use std::net::IpAddr;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use num_cpus;
|
||||
use pnet::packet::ipv4::MutableIpv4Packet;
|
||||
use pnet::packet::ipv6::MutableIpv6Packet;
|
||||
|
||||
use super::super::types::bind::*;
|
||||
use super::super::types::*;
|
||||
|
||||
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
|
||||
|
||||
extern crate test;
|
||||
|
||||
const SIZE_KEEPALIVE: usize = 32;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use env_logger;
|
||||
use log::debug;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use test::Bencher;
|
||||
|
||||
// type for tracking events inside the router module
|
||||
struct Flags {
|
||||
send: Mutex<Vec<(usize, bool, bool)>>,
|
||||
recv: Mutex<Vec<(usize, bool, bool)>>,
|
||||
need_key: Mutex<Vec<()>>,
|
||||
key_confirmed: Mutex<Vec<()>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Opaque(Arc<Flags>);
|
||||
|
||||
struct TestCallbacks();
|
||||
|
||||
impl Opaque {
|
||||
fn new() -> Opaque {
|
||||
Opaque(Arc::new(Flags {
|
||||
send: Mutex::new(vec![]),
|
||||
recv: Mutex::new(vec![]),
|
||||
need_key: Mutex::new(vec![]),
|
||||
key_confirmed: Mutex::new(vec![]),
|
||||
}))
|
||||
}
|
||||
|
||||
fn reset(&self) {
|
||||
self.0.send.lock().unwrap().clear();
|
||||
self.0.recv.lock().unwrap().clear();
|
||||
self.0.need_key.lock().unwrap().clear();
|
||||
self.0.key_confirmed.lock().unwrap().clear();
|
||||
}
|
||||
|
||||
fn send(&self) -> Option<(usize, bool, bool)> {
|
||||
self.0.send.lock().unwrap().pop()
|
||||
}
|
||||
|
||||
fn recv(&self) -> Option<(usize, bool, bool)> {
|
||||
self.0.recv.lock().unwrap().pop()
|
||||
}
|
||||
|
||||
fn need_key(&self) -> Option<()> {
|
||||
self.0.need_key.lock().unwrap().pop()
|
||||
}
|
||||
|
||||
fn key_confirmed(&self) -> Option<()> {
|
||||
self.0.key_confirmed.lock().unwrap().pop()
|
||||
}
|
||||
|
||||
// has all events been accounted for by assertions?
|
||||
fn is_empty(&self) -> bool {
|
||||
let send = self.0.send.lock().unwrap();
|
||||
let recv = self.0.recv.lock().unwrap();
|
||||
let need_key = self.0.need_key.lock().unwrap();
|
||||
let key_confirmed = self.0.key_confirmed.lock().unwrap();
|
||||
send.is_empty() && recv.is_empty() && need_key.is_empty() & key_confirmed.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Callbacks for TestCallbacks {
|
||||
type Opaque = Opaque;
|
||||
|
||||
fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
|
||||
t.0.send.lock().unwrap().push((size, data, sent))
|
||||
}
|
||||
|
||||
fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
|
||||
t.0.recv.lock().unwrap().push((size, data, sent))
|
||||
}
|
||||
|
||||
fn need_key(t: &Self::Opaque) {
|
||||
t.0.need_key.lock().unwrap().push(());
|
||||
}
|
||||
|
||||
fn key_confirmed(t: &Self::Opaque) {
|
||||
t.0.key_confirmed.lock().unwrap().push(());
|
||||
}
|
||||
}
|
||||
|
||||
// wait for scheduling
|
||||
fn wait() {
|
||||
thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
|
||||
fn init() {
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
}
|
||||
|
||||
fn make_packet(size: usize, ip: IpAddr) -> Vec<u8> {
|
||||
// create "IP packet"
|
||||
let mut msg = Vec::with_capacity(SIZE_MESSAGE_PREFIX + size + 16);
|
||||
msg.resize(SIZE_MESSAGE_PREFIX + size, 0);
|
||||
match ip {
|
||||
IpAddr::V4(ip) => {
|
||||
let mut packet = MutableIpv4Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
|
||||
packet.set_destination(ip);
|
||||
packet.set_version(4);
|
||||
}
|
||||
IpAddr::V6(ip) => {
|
||||
let mut packet = MutableIpv6Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
|
||||
packet.set_destination(ip);
|
||||
packet.set_version(6);
|
||||
}
|
||||
}
|
||||
msg
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_outbound(b: &mut Bencher) {
|
||||
struct BencherCallbacks {}
|
||||
impl Callbacks for BencherCallbacks {
|
||||
type Opaque = Arc<AtomicUsize>;
|
||||
fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) {
|
||||
t.fetch_add(size, Ordering::SeqCst);
|
||||
}
|
||||
fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
|
||||
fn need_key(_: &Self::Opaque) {}
|
||||
fn key_confirmed(_: &Self::Opaque) {}
|
||||
}
|
||||
|
||||
// create device
|
||||
let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
|
||||
let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> =
|
||||
Device::new(num_cpus::get(), tun_writer);
|
||||
|
||||
// add new peer
|
||||
let opaque = Arc::new(AtomicUsize::new(0));
|
||||
let peer = router.new_peer(opaque.clone());
|
||||
peer.add_keypair(dummy::keypair(true));
|
||||
|
||||
// add subnet to peer
|
||||
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
|
||||
let mask: IpAddr = mask.parse().unwrap();
|
||||
let ip1: IpAddr = ip.parse().unwrap();
|
||||
peer.add_subnet(mask, len);
|
||||
|
||||
// every iteration sends 10 GB
|
||||
b.iter(|| {
|
||||
opaque.store(0, Ordering::SeqCst);
|
||||
let msg = make_packet(1024, ip1);
|
||||
while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
|
||||
router.send(msg.to_vec()).unwrap();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outbound() {
|
||||
init();
|
||||
|
||||
// create device
|
||||
let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
|
||||
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
|
||||
router.set_outbound_writer(dummy::VoidBind::new());
|
||||
|
||||
let tests = vec![
|
||||
("192.168.1.0", 24, "192.168.1.20", true),
|
||||
("172.133.133.133", 32, "172.133.133.133", true),
|
||||
("172.133.133.133", 32, "172.133.133.132", false),
|
||||
(
|
||||
"2001:db8::ff00:42:0000",
|
||||
112,
|
||||
"2001:db8::ff00:42:3242",
|
||||
true,
|
||||
),
|
||||
(
|
||||
"2001:db8::ff00:42:8000",
|
||||
113,
|
||||
"2001:db8::ff00:42:0660",
|
||||
false,
|
||||
),
|
||||
(
|
||||
"2001:db8::ff00:42:8000",
|
||||
113,
|
||||
"2001:db8::ff00:42:ffff",
|
||||
true,
|
||||
),
|
||||
];
|
||||
|
||||
for (num, (mask, len, ip, okay)) in tests.iter().enumerate() {
|
||||
for set_key in vec![true, false] {
|
||||
debug!("index = {}, set_key = {}", num, set_key);
|
||||
|
||||
// add new peer
|
||||
let opaque = Opaque::new();
|
||||
let peer = router.new_peer(opaque.clone());
|
||||
let mask: IpAddr = mask.parse().unwrap();
|
||||
if set_key {
|
||||
peer.add_keypair(dummy::keypair(true));
|
||||
}
|
||||
|
||||
// map subnet to peer
|
||||
peer.add_subnet(mask, *len);
|
||||
|
||||
// create "IP packet"
|
||||
let msg = make_packet(1024, ip.parse().unwrap());
|
||||
|
||||
// cryptkey route the IP packet
|
||||
let res = router.send(msg);
|
||||
|
||||
// allow some scheduling
|
||||
wait();
|
||||
|
||||
if *okay {
|
||||
// cryptkey routing succeeded
|
||||
assert!(res.is_ok(), "crypt-key routing should succeed");
|
||||
assert_eq!(
|
||||
opaque.need_key().is_some(),
|
||||
!set_key,
|
||||
"should have requested a new key, if no encryption state was set"
|
||||
);
|
||||
assert_eq!(
|
||||
opaque.send().is_some(),
|
||||
set_key,
|
||||
"transmission should have been attempted"
|
||||
);
|
||||
assert!(
|
||||
opaque.recv().is_none(),
|
||||
"no messages should have been marked as received"
|
||||
);
|
||||
} else {
|
||||
// no such cryptkey route
|
||||
assert!(res.is_err(), "crypt-key routing should fail");
|
||||
assert!(
|
||||
opaque.need_key().is_none(),
|
||||
"should not request a new-key if crypt-key routing failed"
|
||||
);
|
||||
assert_eq!(
|
||||
opaque.send(),
|
||||
if set_key {
|
||||
Some((SIZE_KEEPALIVE, false, false))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
"transmission should only happen if key was set (keepalive)",
|
||||
);
|
||||
assert!(
|
||||
opaque.recv().is_none(),
|
||||
"no messages should have been marked as received",
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bidirectional() {
|
||||
init();
|
||||
|
||||
let tests = [
|
||||
(
|
||||
false, // confirm with keepalive
|
||||
("192.168.1.0", 24, "192.168.1.20", true),
|
||||
("172.133.133.133", 32, "172.133.133.133", true),
|
||||
),
|
||||
(
|
||||
true, // confirm with staged packet
|
||||
("192.168.1.0", 24, "192.168.1.20", true),
|
||||
("172.133.133.133", 32, "172.133.133.133", true),
|
||||
),
|
||||
(
|
||||
false, // confirm with keepalive
|
||||
(
|
||||
"2001:db8::ff00:42:8000",
|
||||
113,
|
||||
"2001:db8::ff00:42:ffff",
|
||||
true,
|
||||
),
|
||||
(
|
||||
"2001:db8::ff40:42:8000",
|
||||
113,
|
||||
"2001:db8::ff40:42:ffff",
|
||||
true,
|
||||
),
|
||||
),
|
||||
(
|
||||
false, // confirm with staged packet
|
||||
(
|
||||
"2001:db8::ff00:42:8000",
|
||||
113,
|
||||
"2001:db8::ff00:42:ffff",
|
||||
true,
|
||||
),
|
||||
(
|
||||
"2001:db8::ff40:42:8000",
|
||||
113,
|
||||
"2001:db8::ff40:42:ffff",
|
||||
true,
|
||||
),
|
||||
),
|
||||
];
|
||||
|
||||
for (stage, p1, p2) in tests.iter() {
|
||||
let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
|
||||
dummy::PairBind::pair();
|
||||
|
||||
// create matching device
|
||||
let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false);
|
||||
let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false);
|
||||
|
||||
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
|
||||
router1.set_outbound_writer(bind_writer1);
|
||||
|
||||
let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2);
|
||||
router2.set_outbound_writer(bind_writer2);
|
||||
|
||||
// prepare opaque values for tracing callbacks
|
||||
|
||||
let opaq1 = Opaque::new();
|
||||
let opaq2 = Opaque::new();
|
||||
|
||||
// create peers with matching keypairs and assign subnets
|
||||
|
||||
let (mask, len, _ip, _okay) = p1;
|
||||
let peer1 = router1.new_peer(opaq1.clone());
|
||||
let mask: IpAddr = mask.parse().unwrap();
|
||||
peer1.add_subnet(mask, *len);
|
||||
peer1.add_keypair(dummy::keypair(false));
|
||||
|
||||
let (mask, len, _ip, _okay) = p2;
|
||||
let peer2 = router2.new_peer(opaq2.clone());
|
||||
let mask: IpAddr = mask.parse().unwrap();
|
||||
peer2.add_subnet(mask, *len);
|
||||
peer2.set_endpoint(dummy::UnitEndpoint::new());
|
||||
|
||||
if *stage {
|
||||
// stage a packet which can be used for confirmation (in place of a keepalive)
|
||||
let (_mask, _len, ip, _okay) = p2;
|
||||
let msg = make_packet(1024, ip.parse().unwrap());
|
||||
router2.send(msg).expect("failed to sent staged packet");
|
||||
|
||||
wait();
|
||||
assert!(opaq2.recv().is_none());
|
||||
assert!(
|
||||
opaq2.send().is_none(),
|
||||
"sending should fail as not key is set"
|
||||
);
|
||||
assert!(
|
||||
opaq2.need_key().is_some(),
|
||||
"a new key should be requested since a packet was attempted transmitted"
|
||||
);
|
||||
assert!(opaq2.is_empty(), "callbacks should only run once");
|
||||
}
|
||||
|
||||
// this should cause a key-confirmation packet (keepalive or staged packet)
|
||||
// this also causes peer1 to learn the "endpoint" for peer2
|
||||
assert!(peer1.get_endpoint().is_none());
|
||||
peer2.add_keypair(dummy::keypair(true));
|
||||
|
||||
wait();
|
||||
assert!(opaq2.send().is_some());
|
||||
assert!(opaq2.is_empty(), "events on peer2 should be 'send'");
|
||||
assert!(opaq1.is_empty(), "nothing should happened on peer1");
|
||||
|
||||
// read confirming message received by the other end ("across the internet")
|
||||
let mut buf = vec![0u8; 2048];
|
||||
let (len, from) = bind_reader1.read(&mut buf).unwrap();
|
||||
buf.truncate(len);
|
||||
router1.recv(from, buf).unwrap();
|
||||
|
||||
wait();
|
||||
assert!(opaq1.recv().is_some());
|
||||
assert!(opaq1.key_confirmed().is_some());
|
||||
assert!(
|
||||
opaq1.is_empty(),
|
||||
"events on peer1 should be 'recv' and 'key_confirmed'"
|
||||
);
|
||||
assert!(peer1.get_endpoint().is_some());
|
||||
assert!(opaq2.is_empty(), "nothing should happened on peer2");
|
||||
|
||||
// now that peer1 has an endpoint
|
||||
// route packets : peer1 -> peer2
|
||||
|
||||
for _ in 0..10 {
|
||||
assert!(
|
||||
opaq1.is_empty(),
|
||||
"we should have asserted a value for every callback on peer1"
|
||||
);
|
||||
assert!(
|
||||
opaq2.is_empty(),
|
||||
"we should have asserted a value for every callback on peer2"
|
||||
);
|
||||
|
||||
// pass IP packet to router
|
||||
let (_mask, _len, ip, _okay) = p1;
|
||||
let msg = make_packet(1024, ip.parse().unwrap());
|
||||
router1.send(msg).unwrap();
|
||||
|
||||
wait();
|
||||
assert!(opaq1.send().is_some());
|
||||
assert!(opaq1.recv().is_none());
|
||||
assert!(opaq1.need_key().is_none());
|
||||
|
||||
// receive ("across the internet") on the other end
|
||||
let mut buf = vec![0u8; 2048];
|
||||
let (len, from) = bind_reader2.read(&mut buf).unwrap();
|
||||
buf.truncate(len);
|
||||
router2.recv(from, buf).unwrap();
|
||||
|
||||
wait();
|
||||
assert!(opaq2.send().is_none());
|
||||
assert!(opaq2.recv().is_some());
|
||||
assert!(opaq2.need_key().is_none());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
65
src/wireguard/router/types.rs
Normal file
65
src/wireguard/router/types.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
pub trait Opaque: Send + Sync + 'static {}
|
||||
|
||||
impl<T> Opaque for T where T: Send + Sync + 'static {}
|
||||
|
||||
/// A send/recv callback takes 3 arguments:
|
||||
///
|
||||
/// * `0`, a reference to the opaque value assigned to the peer
|
||||
/// * `1`, a bool indicating whether the message contained data (not just keepalive)
|
||||
/// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?)
|
||||
pub trait Callback<T>: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
|
||||
|
||||
impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
|
||||
|
||||
/// A key callback takes 1 argument
|
||||
///
|
||||
/// * `0`, a reference to the opaque value assigned to the peer
|
||||
pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {}
|
||||
|
||||
impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
|
||||
|
||||
pub trait Callbacks: Send + Sync + 'static {
|
||||
type Opaque: Opaque;
|
||||
fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
|
||||
fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
|
||||
fn need_key(opaque: &Self::Opaque);
|
||||
fn key_confirmed(opaque: &Self::Opaque);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum RouterError {
|
||||
NoCryptKeyRoute,
|
||||
MalformedIPHeader,
|
||||
MalformedTransportMessage,
|
||||
UnknownReceiverId,
|
||||
NoEndpoint,
|
||||
SendError,
|
||||
}
|
||||
|
||||
impl fmt::Display for RouterError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
|
||||
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
|
||||
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
|
||||
RouterError::UnknownReceiverId => {
|
||||
write!(f, "No decryption state associated with receiver id")
|
||||
}
|
||||
RouterError::NoEndpoint => write!(f, "No endpoint for peer"),
|
||||
RouterError::SendError => write!(f, "Failed to send packet on bind"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for RouterError {
|
||||
fn description(&self) -> &str {
|
||||
"Generic Handshake Error"
|
||||
}
|
||||
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
None
|
||||
}
|
||||
}
|
||||
305
src/wireguard/router/workers.rs
Normal file
305
src/wireguard/router/workers.rs
Normal file
@@ -0,0 +1,305 @@
|
||||
use std::mem;
|
||||
use std::sync::mpsc::Receiver;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::sync::oneshot;
|
||||
use futures::*;
|
||||
|
||||
use log::debug;
|
||||
|
||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::atomic::Ordering;
|
||||
use zerocopy::{AsBytes, LayoutVerified};
|
||||
|
||||
use super::device::{DecryptionState, DeviceInner};
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::peer::PeerInner;
|
||||
use super::types::Callbacks;
|
||||
|
||||
use super::super::types::{Endpoint, tun, bind};
|
||||
use super::ip::*;
|
||||
|
||||
const SIZE_TAG: usize = 16;
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
pub enum Operation {
|
||||
Encryption,
|
||||
Decryption,
|
||||
}
|
||||
|
||||
pub struct JobBuffer {
|
||||
pub msg: Vec<u8>, // message buffer (nonce and receiver id set)
|
||||
pub key: [u8; 32], // chacha20poly1305 key
|
||||
pub okay: bool, // state of the job
|
||||
pub op: Operation, // should be buffer be encrypted / decrypted?
|
||||
}
|
||||
|
||||
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
|
||||
|
||||
#[allow(type_alias_bounds)]
|
||||
pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
|
||||
Arc<DecryptionState<E, C, T, B>>,
|
||||
E,
|
||||
oneshot::Receiver<JobBuffer>,
|
||||
);
|
||||
|
||||
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
|
||||
|
||||
#[inline(always)]
|
||||
fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: &Arc<DeviceInner<E, C, T, B>>,
|
||||
peer: &Arc<PeerInner<E, C, T, B>>,
|
||||
packet: &[u8],
|
||||
) -> Option<usize> {
|
||||
match packet[0] >> 4 {
|
||||
VERSION_IP4 => {
|
||||
// check length and cast to IPv4 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
|
||||
LayoutVerified::new_from_prefix(packet)?;
|
||||
|
||||
// check IPv4 source address
|
||||
device
|
||||
.ipv4
|
||||
.read()
|
||||
.longest_match(Ipv4Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| {
|
||||
if Arc::ptr_eq(p, &peer) {
|
||||
Some(header.f_total_len.get() as usize)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
VERSION_IP6 => {
|
||||
// check length and cast to IPv6 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
|
||||
LayoutVerified::new_from_prefix(packet)?;
|
||||
|
||||
// check IPv6 source address
|
||||
device
|
||||
.ipv6
|
||||
.read()
|
||||
.longest_match(Ipv6Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| {
|
||||
if Arc::ptr_eq(p, &peer) {
|
||||
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: Arc<DeviceInner<E, C, T, B>>, // related device
|
||||
peer: Arc<PeerInner<E, C, T, B>>, // related peer
|
||||
receiver: Receiver<JobInbound<E, C, T, B>>,
|
||||
) {
|
||||
loop {
|
||||
// fetch job
|
||||
let (state, endpoint, rx) = match receiver.recv() {
|
||||
Ok(v) => v,
|
||||
_ => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
debug!("inbound worker: obtained job");
|
||||
|
||||
// wait for job to complete
|
||||
let _ = rx
|
||||
.map(|buf| {
|
||||
debug!("inbound worker: job complete");
|
||||
if buf.okay {
|
||||
// cast transport header
|
||||
let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
|
||||
match LayoutVerified::new_from_prefix(&buf.msg[..]) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
debug!("inbound worker: failed to parse message");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
debug_assert!(
|
||||
packet.len() >= CHACHA20_POLY1305.tag_len(),
|
||||
"this should be checked earlier in the pipeline (decryption should fail)"
|
||||
);
|
||||
|
||||
// check for replay
|
||||
if !state.protector.lock().update(header.f_counter.get()) {
|
||||
debug!("inbound worker: replay detected");
|
||||
return;
|
||||
}
|
||||
|
||||
// check for confirms key
|
||||
if !state.confirmed.swap(true, Ordering::SeqCst) {
|
||||
debug!("inbound worker: message confirms key");
|
||||
peer.confirm_key(&state.keypair);
|
||||
}
|
||||
|
||||
// update endpoint
|
||||
*peer.endpoint.lock() = Some(endpoint);
|
||||
|
||||
// calculate length of IP packet + padding
|
||||
let length = packet.len() - SIZE_TAG;
|
||||
debug!("inbound worker: plaintext length = {}", length);
|
||||
|
||||
// check if should be written to TUN
|
||||
let mut sent = false;
|
||||
if length > 0 {
|
||||
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
|
||||
debug_assert!(inner_len <= length, "should be validated");
|
||||
if inner_len <= length {
|
||||
sent = match device.inbound.write(&packet[..inner_len]) {
|
||||
Err(e) => {
|
||||
debug!("failed to write inbound packet to TUN: {:?}", e);
|
||||
false
|
||||
}
|
||||
Ok(_) => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("inbound worker: received keepalive")
|
||||
}
|
||||
|
||||
// trigger callback
|
||||
C::recv(&peer.opaque, buf.msg.len(), length == 0, sent);
|
||||
} else {
|
||||
debug!("inbound worker: authentication failure")
|
||||
}
|
||||
})
|
||||
.wait();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: Arc<DeviceInner<E, C, T, B>>, // related device
|
||||
peer: Arc<PeerInner<E, C, T, B>>, // related peer
|
||||
receiver: Receiver<JobOutbound>,
|
||||
) {
|
||||
loop {
|
||||
// fetch job
|
||||
let rx = match receiver.recv() {
|
||||
Ok(v) => v,
|
||||
_ => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
debug!("outbound worker: obtained job");
|
||||
|
||||
// wait for job to complete
|
||||
let _ = rx
|
||||
.map(|buf| {
|
||||
debug!("outbound worker: job complete");
|
||||
if buf.okay {
|
||||
// write to UDP bind
|
||||
let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
|
||||
let send : &Option<B> = &*device.outbound.read();
|
||||
if let Some(writer) = send.as_ref() {
|
||||
match writer.write(&buf.msg[..], dst) {
|
||||
Err(e) => {
|
||||
debug!("failed to send outbound packet: {:?}", e);
|
||||
false
|
||||
}
|
||||
Ok(_) => true,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// trigger callback
|
||||
C::send(
|
||||
&peer.opaque,
|
||||
buf.msg.len(),
|
||||
buf.msg.len() > SIZE_TAG + mem::size_of::<TransportHeader>(),
|
||||
xmit,
|
||||
);
|
||||
}
|
||||
})
|
||||
.wait();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn worker_parallel(receiver: Receiver<JobParallel>) {
|
||||
loop {
|
||||
// fetch next job
|
||||
let (tx, mut buf) = match receiver.recv() {
|
||||
Err(_) => {
|
||||
return;
|
||||
}
|
||||
Ok(val) => val,
|
||||
};
|
||||
debug!("parallel worker: obtained job");
|
||||
|
||||
// make space for tag (TODO: consider moving this out)
|
||||
if buf.op == Operation::Encryption {
|
||||
buf.msg.extend([0u8; SIZE_TAG].iter());
|
||||
}
|
||||
|
||||
// cast and check size of packet
|
||||
let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
|
||||
match LayoutVerified::new_from_prefix(&mut buf.msg[..]) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
debug_assert!(
|
||||
false,
|
||||
"parallel worker: failed to parse message (insufficient size)"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
debug_assert!(packet.len() >= CHACHA20_POLY1305.tag_len());
|
||||
|
||||
// do the weird ring AEAD dance
|
||||
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap());
|
||||
|
||||
// create a nonce object
|
||||
let mut nonce = [0u8; 12];
|
||||
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
|
||||
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
|
||||
let nonce = Nonce::assume_unique_for_key(nonce);
|
||||
|
||||
match buf.op {
|
||||
Operation::Encryption => {
|
||||
debug!("parallel worker: process encryption");
|
||||
|
||||
// set the type field
|
||||
header.f_type.set(TYPE_TRANSPORT);
|
||||
|
||||
// encrypt content of transport message in-place
|
||||
let end = packet.len() - SIZE_TAG;
|
||||
let tag = key
|
||||
.seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
|
||||
.unwrap();
|
||||
|
||||
// append tag
|
||||
packet[end..].copy_from_slice(tag.as_ref());
|
||||
|
||||
buf.okay = true;
|
||||
}
|
||||
Operation::Decryption => {
|
||||
debug!("parallel worker: process decryption");
|
||||
|
||||
// opening failure is signaled by fault state
|
||||
buf.okay = match key.open_in_place(nonce, Aad::empty(), packet) {
|
||||
Ok(_) => true,
|
||||
Err(_) => false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// pass ownership to consumer
|
||||
let okay = tx.send(buf);
|
||||
debug!(
|
||||
"parallel worker: passing ownership to sequential worker: {}",
|
||||
okay.is_ok()
|
||||
);
|
||||
}
|
||||
}
|
||||
46
src/wireguard/tests.rs
Normal file
46
src/wireguard/tests.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
use super::types::tun::Tun;
|
||||
use super::types::{bind, dummy, tun};
|
||||
use super::wireguard::Wireguard;
|
||||
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
fn init() {
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
}
|
||||
|
||||
/* Create and configure two matching pure instances of WireGuard
|
||||
*
|
||||
*/
|
||||
#[test]
|
||||
fn test_pure_wireguard() {
|
||||
init();
|
||||
|
||||
// create WG instances for fake TUN devices
|
||||
|
||||
let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true);
|
||||
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
|
||||
Wireguard::new(vec![tun_reader1], tun_writer1, mtu1);
|
||||
|
||||
let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true);
|
||||
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
|
||||
Wireguard::new(vec![tun_reader2], tun_writer2, mtu2);
|
||||
|
||||
// create pair bind to connect the interfaces "over the internet"
|
||||
|
||||
let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair();
|
||||
|
||||
wg1.set_writer(bind_writer1);
|
||||
wg2.set_writer(bind_writer2);
|
||||
|
||||
wg1.add_reader(bind_reader1);
|
||||
wg2.add_reader(bind_reader2);
|
||||
|
||||
// generate (public, pivate) key pairs
|
||||
|
||||
// configure cryptkey router
|
||||
|
||||
// create IP packets
|
||||
|
||||
thread::sleep(Duration::from_millis(500));
|
||||
}
|
||||
234
src/wireguard/timers.rs
Normal file
234
src/wireguard/timers.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use log::info;
|
||||
|
||||
use hjul::{Runner, Timer};
|
||||
|
||||
use super::constants::*;
|
||||
use super::router::Callbacks;
|
||||
use super::types::{bind, tun};
|
||||
use super::wireguard::{Peer, PeerInner};
|
||||
|
||||
pub struct Timers {
|
||||
handshake_pending: AtomicBool,
|
||||
handshake_attempts: AtomicUsize,
|
||||
|
||||
retransmit_handshake: Timer,
|
||||
send_keepalive: Timer,
|
||||
send_persistent_keepalive: Timer,
|
||||
sent_lastminute_handshake: AtomicBool,
|
||||
zero_key_material: Timer,
|
||||
new_handshake: Timer,
|
||||
need_another_keepalive: AtomicBool,
|
||||
}
|
||||
|
||||
impl Timers {
|
||||
#[inline(always)]
|
||||
fn need_another_keepalive(&self) -> bool {
|
||||
self.need_another_keepalive.swap(false, Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
impl <T: tun::Tun, B: bind::Bind>Peer<T, B> {
|
||||
/* should be called after an authenticated data packet is sent */
|
||||
pub fn timers_data_sent(&self) {
|
||||
self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
|
||||
}
|
||||
|
||||
/* should be called after an authenticated data packet is received */
|
||||
pub fn timers_data_received(&self) {
|
||||
if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) {
|
||||
self.timers().need_another_keepalive.store(true, Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
/* Should be called after any type of authenticated packet is sent, whether:
|
||||
* - keepalive
|
||||
* - data
|
||||
* - handshake
|
||||
*/
|
||||
pub fn timers_any_authenticated_packet_sent(&self) {
|
||||
self.timers().send_keepalive.stop()
|
||||
}
|
||||
|
||||
/* Should be called after any type of authenticated packet is received, whether:
|
||||
* - keepalive
|
||||
* - data
|
||||
* - handshake
|
||||
*/
|
||||
pub fn timers_any_authenticated_packet_received(&self) {
|
||||
self.timers().new_handshake.stop();
|
||||
}
|
||||
|
||||
/* Should be called after a handshake initiation message is sent. */
|
||||
pub fn timers_handshake_initiated(&self) {
|
||||
self.timers().send_keepalive.stop();
|
||||
self.timers().retransmit_handshake.reset(REKEY_TIMEOUT);
|
||||
}
|
||||
|
||||
/* Should be called after a handshake response message is received and processed
|
||||
* or when getting key confirmation via the first data message.
|
||||
*/
|
||||
pub fn timers_handshake_complete(&self) {
|
||||
self.timers().handshake_attempts.store(0, Ordering::SeqCst);
|
||||
self.timers().sent_lastminute_handshake.store(false, Ordering::SeqCst);
|
||||
// TODO: Store time in peer for config
|
||||
// self.walltime_last_handshake
|
||||
}
|
||||
|
||||
/* Should be called after an ephemeral key is created, which is before sending a
|
||||
* handshake response or after receiving a handshake response.
|
||||
*/
|
||||
pub fn timers_session_derived(&self) {
|
||||
self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3);
|
||||
}
|
||||
|
||||
/* Should be called before a packet with authentication, whether
|
||||
* keepalive, data, or handshake is sent, or after one is received.
|
||||
*/
|
||||
pub fn timers_any_authenticated_packet_traversal(&self) {
|
||||
let keepalive = self.state.keepalive.load(Ordering::Acquire);
|
||||
if keepalive > 0 {
|
||||
self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Timers {
|
||||
pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers
|
||||
where
|
||||
T: tun::Tun,
|
||||
B: bind::Bind,
|
||||
{
|
||||
// create a timer instance for the provided peer
|
||||
Timers {
|
||||
handshake_pending: AtomicBool::new(false),
|
||||
need_another_keepalive: AtomicBool::new(false),
|
||||
sent_lastminute_handshake: AtomicBool::new(false),
|
||||
handshake_attempts: AtomicUsize::new(0),
|
||||
retransmit_handshake: {
|
||||
let peer = peer.clone();
|
||||
runner.timer(move || {
|
||||
if peer.timers().handshake_retry() {
|
||||
info!("Retransmit handshake for {}", peer);
|
||||
peer.new_handshake();
|
||||
} else {
|
||||
info!("Failed to complete handshake for {}", peer);
|
||||
peer.router.purge_staged_packets();
|
||||
peer.timers().send_keepalive.stop();
|
||||
peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3);
|
||||
}
|
||||
})
|
||||
},
|
||||
send_keepalive: {
|
||||
let peer = peer.clone();
|
||||
runner.timer(move || {
|
||||
peer.router.send_keepalive();
|
||||
if peer.timers().need_another_keepalive() {
|
||||
peer.timers().send_keepalive.start(KEEPALIVE_TIMEOUT);
|
||||
}
|
||||
})
|
||||
},
|
||||
new_handshake: {
|
||||
let peer = peer.clone();
|
||||
runner.timer(move || {
|
||||
info!(
|
||||
"Retrying handshake with {}, because we stopped hearing back after {} seconds",
|
||||
peer,
|
||||
(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
|
||||
);
|
||||
peer.new_handshake();
|
||||
peer.timers.read().handshake_begun();
|
||||
})
|
||||
},
|
||||
zero_key_material: {
|
||||
let peer = peer.clone();
|
||||
runner.timer(move || {
|
||||
peer.router.zero_keys();
|
||||
})
|
||||
},
|
||||
send_persistent_keepalive: {
|
||||
let peer = peer.clone();
|
||||
runner.timer(move || {
|
||||
let keepalive = peer.state.keepalive.load(Ordering::Acquire);
|
||||
if keepalive > 0 {
|
||||
peer.router.send_keepalive();
|
||||
peer.timers().send_keepalive.stop();
|
||||
peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64));
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handshake_begun(&self) {
|
||||
self.handshake_pending.store(true, Ordering::SeqCst);
|
||||
self.handshake_attempts.store(0, Ordering::SeqCst);
|
||||
self.retransmit_handshake.reset(REKEY_TIMEOUT);
|
||||
}
|
||||
|
||||
fn handshake_retry(&self) -> bool {
|
||||
if self.handshake_attempts.fetch_add(1, Ordering::SeqCst) <= MAX_TIMER_HANDSHAKES {
|
||||
self.retransmit_handshake.reset(REKEY_TIMEOUT);
|
||||
true
|
||||
} else {
|
||||
self.handshake_pending.store(false, Ordering::SeqCst);
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn updated_persistent_keepalive(&self, keepalive: usize) {
|
||||
if keepalive > 0 {
|
||||
self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dummy(runner: &Runner) -> Timers {
|
||||
Timers {
|
||||
handshake_pending: AtomicBool::new(false),
|
||||
need_another_keepalive: AtomicBool::new(false),
|
||||
sent_lastminute_handshake: AtomicBool::new(false),
|
||||
handshake_attempts: AtomicUsize::new(0),
|
||||
retransmit_handshake: runner.timer(|| {}),
|
||||
new_handshake: runner.timer(|| {}),
|
||||
send_keepalive: runner.timer(|| {}),
|
||||
send_persistent_keepalive: runner.timer(|| {}),
|
||||
zero_key_material: runner.timer(|| {})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handshake_sent(&self) {
|
||||
self.send_keepalive.stop();
|
||||
}
|
||||
}
|
||||
|
||||
/* Instance of the router callbacks */
|
||||
|
||||
pub struct Events<T, B>(PhantomData<(T, B)>);
|
||||
|
||||
impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
|
||||
type Opaque = Arc<PeerInner<B>>;
|
||||
|
||||
fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
|
||||
peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn recv(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
|
||||
peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn need_key(peer: &Self::Opaque) {
|
||||
let timers = peer.timers();
|
||||
if !timers.handshake_pending.swap(true, Ordering::SeqCst) {
|
||||
timers.handshake_attempts.store(0, Ordering::SeqCst);
|
||||
timers.new_handshake.fire();
|
||||
}
|
||||
}
|
||||
|
||||
fn key_confirmed(peer: &Self::Opaque) {
|
||||
peer.timers().retransmit_handshake.stop();
|
||||
}
|
||||
}
|
||||
23
src/wireguard/types/bind.rs
Normal file
23
src/wireguard/types/bind.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use super::Endpoint;
|
||||
use std::error::Error;
|
||||
|
||||
pub trait Reader<E: Endpoint>: Send + Sync {
|
||||
type Error: Error;
|
||||
|
||||
fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>;
|
||||
}
|
||||
|
||||
pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static {
|
||||
type Error: Error;
|
||||
|
||||
fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
pub trait Bind: Send + Sync + 'static {
|
||||
type Error: Error;
|
||||
type Endpoint: Endpoint;
|
||||
|
||||
/* Until Rust gets type equality constraints these have to be generic */
|
||||
type Writer: Writer<Self::Endpoint>;
|
||||
type Reader: Reader<Self::Endpoint>;
|
||||
}
|
||||
323
src/wireguard/types/dummy.rs
Normal file
323
src/wireguard/types/dummy.rs
Normal file
@@ -0,0 +1,323 @@
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::marker;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Instant;
|
||||
use std::sync::atomic::{Ordering, AtomicUsize};
|
||||
|
||||
use super::*;
|
||||
|
||||
/* This submodule provides pure/dummy implementations of the IO interfaces
|
||||
* for use in unit tests thoughout the project.
|
||||
*/
|
||||
|
||||
/* Error implementation */
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BindError {
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
impl Error for BindError {
|
||||
fn description(&self) -> &str {
|
||||
"Generic Bind Error"
|
||||
}
|
||||
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for BindError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
BindError::Disconnected => write!(f, "PairBind disconnected"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* TUN implementation */
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TunError {
|
||||
Disconnected
|
||||
}
|
||||
|
||||
impl Error for TunError {
|
||||
fn description(&self) -> &str {
|
||||
"Generic Tun Error"
|
||||
}
|
||||
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for TunError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Not Possible")
|
||||
}
|
||||
}
|
||||
|
||||
/* Endpoint implementation */
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct UnitEndpoint {}
|
||||
|
||||
impl Endpoint for UnitEndpoint {
|
||||
fn from_address(_: SocketAddr) -> UnitEndpoint {
|
||||
UnitEndpoint {}
|
||||
}
|
||||
|
||||
fn into_address(&self) -> SocketAddr {
|
||||
"127.0.0.1:8080".parse().unwrap()
|
||||
}
|
||||
|
||||
fn clear_src(&self) {}
|
||||
}
|
||||
|
||||
impl UnitEndpoint {
|
||||
pub fn new() -> UnitEndpoint {
|
||||
UnitEndpoint {}
|
||||
}
|
||||
}
|
||||
|
||||
/* */
|
||||
|
||||
pub struct TunTest {}
|
||||
|
||||
pub struct TunFakeIO {
|
||||
store: bool,
|
||||
tx: SyncSender<Vec<u8>>,
|
||||
rx: Receiver<Vec<u8>>
|
||||
}
|
||||
|
||||
pub struct TunReader {
|
||||
rx: Receiver<Vec<u8>>
|
||||
}
|
||||
|
||||
pub struct TunWriter {
|
||||
store: bool,
|
||||
tx: Mutex<SyncSender<Vec<u8>>>
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TunMTU {
|
||||
mtu: Arc<AtomicUsize>
|
||||
}
|
||||
|
||||
impl tun::Reader for TunReader {
|
||||
type Error = TunError;
|
||||
|
||||
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
|
||||
match self.rx.recv() {
|
||||
Ok(m) => {
|
||||
buf[offset..].copy_from_slice(&m[..]);
|
||||
Ok(m.len())
|
||||
}
|
||||
Err(_) => Err(TunError::Disconnected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl tun::Writer for TunWriter {
|
||||
type Error = TunError;
|
||||
|
||||
fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
|
||||
if self.store {
|
||||
let m = src.to_owned();
|
||||
match self.tx.lock().unwrap().send(m) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(TunError::Disconnected)
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl tun::MTU for TunMTU {
|
||||
fn mtu(&self) -> usize {
|
||||
self.mtu.load(Ordering::Acquire)
|
||||
}
|
||||
}
|
||||
|
||||
impl tun::Tun for TunTest {
|
||||
type Writer = TunWriter;
|
||||
type Reader = TunReader;
|
||||
type MTU = TunMTU;
|
||||
type Error = TunError;
|
||||
}
|
||||
|
||||
impl TunFakeIO {
|
||||
pub fn write(&self, msg : Vec<u8>) {
|
||||
if self.store {
|
||||
self.tx.send(msg).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read(&self) -> Vec<u8> {
|
||||
self.rx.recv().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl TunTest {
|
||||
pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) {
|
||||
|
||||
let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) };
|
||||
let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) };
|
||||
|
||||
let fake = TunFakeIO{tx: tx1, rx: rx2, store};
|
||||
let reader = TunReader{rx : rx1};
|
||||
let writer = TunWriter{tx : Mutex::new(tx2), store};
|
||||
let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))};
|
||||
|
||||
(fake, reader, writer, mtu)
|
||||
}
|
||||
}
|
||||
|
||||
/* Void Bind */
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct VoidBind {}
|
||||
|
||||
impl bind::Reader<UnitEndpoint> for VoidBind {
|
||||
type Error = BindError;
|
||||
|
||||
fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
|
||||
Ok((0, UnitEndpoint {}))
|
||||
}
|
||||
}
|
||||
|
||||
impl bind::Writer<UnitEndpoint> for VoidBind {
|
||||
type Error = BindError;
|
||||
|
||||
fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl bind::Bind for VoidBind {
|
||||
type Error = BindError;
|
||||
type Endpoint = UnitEndpoint;
|
||||
|
||||
type Reader = VoidBind;
|
||||
type Writer = VoidBind;
|
||||
}
|
||||
|
||||
impl VoidBind {
|
||||
pub fn new() -> VoidBind {
|
||||
VoidBind {}
|
||||
}
|
||||
}
|
||||
|
||||
/* Pair Bind */
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PairReader<E> {
|
||||
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
|
||||
_marker: marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
impl bind::Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
|
||||
type Error = BindError;
|
||||
fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
|
||||
let vec = self
|
||||
.recv
|
||||
.lock()
|
||||
.unwrap()
|
||||
.recv()
|
||||
.map_err(|_| BindError::Disconnected)?;
|
||||
let len = vec.len();
|
||||
buf[..len].copy_from_slice(&vec[..]);
|
||||
Ok((vec.len(), UnitEndpoint {}))
|
||||
}
|
||||
}
|
||||
|
||||
impl bind::Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
|
||||
type Error = BindError;
|
||||
fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
|
||||
let owned = buf.to_owned();
|
||||
match self.send.lock().unwrap().send(owned) {
|
||||
Err(_) => Err(BindError::Disconnected),
|
||||
Ok(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PairWriter<E> {
|
||||
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
|
||||
_marker: marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PairBind {}
|
||||
|
||||
impl PairBind {
|
||||
pub fn pair<E>() -> (
|
||||
(PairReader<E>, PairWriter<E>),
|
||||
(PairReader<E>, PairWriter<E>),
|
||||
) {
|
||||
let (tx1, rx1) = sync_channel(128);
|
||||
let (tx2, rx2) = sync_channel(128);
|
||||
(
|
||||
(
|
||||
PairReader {
|
||||
recv: Arc::new(Mutex::new(rx1)),
|
||||
_marker: marker::PhantomData,
|
||||
},
|
||||
PairWriter {
|
||||
send: Arc::new(Mutex::new(tx2)),
|
||||
_marker: marker::PhantomData,
|
||||
},
|
||||
),
|
||||
(
|
||||
PairReader {
|
||||
recv: Arc::new(Mutex::new(rx2)),
|
||||
_marker: marker::PhantomData,
|
||||
},
|
||||
PairWriter {
|
||||
send: Arc::new(Mutex::new(tx1)),
|
||||
_marker: marker::PhantomData,
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl bind::Bind for PairBind {
|
||||
type Error = BindError;
|
||||
type Endpoint = UnitEndpoint;
|
||||
type Reader = PairReader<Self::Endpoint>;
|
||||
type Writer = PairWriter<Self::Endpoint>;
|
||||
}
|
||||
|
||||
pub fn keypair(initiator: bool) -> KeyPair {
|
||||
let k1 = Key {
|
||||
key: [0x53u8; 32],
|
||||
id: 0x646e6573,
|
||||
};
|
||||
let k2 = Key {
|
||||
key: [0x52u8; 32],
|
||||
id: 0x76636572,
|
||||
};
|
||||
if initiator {
|
||||
KeyPair {
|
||||
birth: Instant::now(),
|
||||
initiator: true,
|
||||
send: k1,
|
||||
recv: k2,
|
||||
}
|
||||
} else {
|
||||
KeyPair {
|
||||
birth: Instant::now(),
|
||||
initiator: false,
|
||||
send: k2,
|
||||
recv: k1,
|
||||
}
|
||||
}
|
||||
}
|
||||
7
src/wireguard/types/endpoint.rs
Normal file
7
src/wireguard/types/endpoint.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub trait Endpoint: Send + 'static {
|
||||
fn from_address(addr: SocketAddr) -> Self;
|
||||
fn into_address(&self) -> SocketAddr;
|
||||
fn clear_src(&self);
|
||||
}
|
||||
36
src/wireguard/types/keys.rs
Normal file
36
src/wireguard/types/keys.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use clear_on_drop::clear::Clear;
|
||||
use std::time::Instant;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Key {
|
||||
pub key: [u8; 32],
|
||||
pub id: u32,
|
||||
}
|
||||
|
||||
// zero key on drop
|
||||
impl Drop for Key {
|
||||
fn drop(&mut self) {
|
||||
self.key.clear()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl PartialEq for Key {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.id == other.id && self.key[..] == other.key[..]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KeyPair {
|
||||
pub birth: Instant, // when was the key-pair created
|
||||
pub initiator: bool, // has the key-pair been confirmed?
|
||||
pub send: Key, // key for outbound messages
|
||||
pub recv: Key, // key for inbound messages
|
||||
}
|
||||
|
||||
impl KeyPair {
|
||||
pub fn local_id(&self) -> u32 {
|
||||
self.recv.id
|
||||
}
|
||||
}
|
||||
10
src/wireguard/types/mod.rs
Normal file
10
src/wireguard/types/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
mod endpoint;
|
||||
mod keys;
|
||||
pub mod tun;
|
||||
pub mod bind;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod dummy;
|
||||
|
||||
pub use endpoint::Endpoint;
|
||||
pub use keys::{Key, KeyPair};
|
||||
56
src/wireguard/types/tun.rs
Normal file
56
src/wireguard/types/tun.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use std::error::Error;
|
||||
|
||||
pub trait Writer: Send + Sync + 'static {
|
||||
type Error: Error;
|
||||
|
||||
/// Receive a cryptkey routed IP packet
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - src: Buffer containing the IP packet to be written
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Unit type or an error
|
||||
fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
pub trait Reader: Send + 'static {
|
||||
type Error: Error;
|
||||
|
||||
/// Reads an IP packet into dst[offset:] from the tunnel device
|
||||
///
|
||||
/// The reason for providing space for a prefix
|
||||
/// is to efficiently accommodate platforms on which the packet is prefaced by a header.
|
||||
/// This space is later used to construct the transport message inplace.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - buf: Destination buffer (enough space for MTU bytes + header)
|
||||
/// - offset: Offset for the beginning of the IP packet
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The size of the IP packet (ignoring the header) or an std::error::Error instance:
|
||||
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error>;
|
||||
}
|
||||
|
||||
pub trait MTU: Send + Sync + Clone + 'static {
|
||||
/// Returns the MTU of the device
|
||||
///
|
||||
/// This function needs to be efficient (called for every read).
|
||||
/// The goto implementation strategy is to .load an atomic variable,
|
||||
/// then use e.g. netlink to update the variable in a separate thread.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The MTU of the interface in bytes
|
||||
fn mtu(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait Tun: Send + Sync + 'static {
|
||||
type Writer: Writer;
|
||||
type Reader: Reader;
|
||||
type MTU: MTU;
|
||||
type Error: Error;
|
||||
}
|
||||
407
src/wireguard/wireguard.rs
Normal file
407
src/wireguard/wireguard.rs
Normal file
@@ -0,0 +1,407 @@
|
||||
use super::constants::*;
|
||||
use super::handshake;
|
||||
use super::router;
|
||||
use super::timers::{Events, Timers};
|
||||
|
||||
use super::types::bind::Reader as BindReader;
|
||||
use super::types::bind::{Bind, Writer};
|
||||
use super::types::tun::{Reader, Tun, MTU};
|
||||
use super::types::Endpoint;
|
||||
|
||||
use hjul::Runner;
|
||||
|
||||
use std::fmt;
|
||||
use std::ops::Deref;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use log::debug;
|
||||
use rand::rngs::OsRng;
|
||||
use spin::{Mutex, RwLock, RwLockReadGuard};
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use crossbeam_channel::{bounded, Sender};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
const SIZE_HANDSHAKE_QUEUE: usize = 128;
|
||||
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
|
||||
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
|
||||
|
||||
pub struct Peer<T: Tun, B: Bind> {
|
||||
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
|
||||
pub state: Arc<PeerInner<B>>,
|
||||
}
|
||||
|
||||
impl<T: Tun, B: Bind> Clone for Peer<T, B> {
|
||||
fn clone(&self) -> Peer<T, B> {
|
||||
Peer {
|
||||
router: self.router.clone(),
|
||||
state: self.state.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PeerInner<B: Bind> {
|
||||
pub keepalive: AtomicUsize, // keepalive interval
|
||||
pub rx_bytes: AtomicU64,
|
||||
pub tx_bytes: AtomicU64,
|
||||
pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
|
||||
pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this.
|
||||
pub timers: RwLock<Timers>, //
|
||||
}
|
||||
|
||||
impl<B: Bind> PeerInner<B> {
|
||||
#[inline(always)]
|
||||
pub fn timers(&self) -> RwLockReadGuard<Timers> {
|
||||
self.timers.read()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "peer()")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Tun, B: Bind> Deref for Peer<T, B> {
|
||||
type Target = PeerInner<B>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Bind> PeerInner<B> {
|
||||
pub fn new_handshake(&self) {
|
||||
// TODO: clear endpoint source address ("unsticky")
|
||||
self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
struct Handshake {
|
||||
device: handshake::Device,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
pub enum HandshakeJob<E> {
|
||||
Message(Vec<u8>, E),
|
||||
New(PublicKey),
|
||||
}
|
||||
|
||||
struct WireguardInner<T: Tun, B: Bind> {
|
||||
// provides access to the MTU value of the tun device
|
||||
// (otherwise owned solely by the router and a dedicated read IO thread)
|
||||
mtu: T::MTU,
|
||||
send: RwLock<Option<B::Writer>>,
|
||||
|
||||
// identify and configuration map
|
||||
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
|
||||
|
||||
// cryptkey router
|
||||
router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
|
||||
|
||||
// handshake related state
|
||||
handshake: RwLock<Handshake>,
|
||||
under_load: AtomicBool,
|
||||
pending: AtomicUsize, // num of pending handshake packets in queue
|
||||
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
|
||||
}
|
||||
|
||||
pub struct Wireguard<T: Tun, B: Bind> {
|
||||
runner: Runner,
|
||||
state: Arc<WireguardInner<T, B>>,
|
||||
}
|
||||
|
||||
/* Returns the padded length of a message:
|
||||
*
|
||||
* # Arguments
|
||||
*
|
||||
* - `size` : Size of unpadded message
|
||||
* - `mtu` : Maximum transmission unit of the device
|
||||
*
|
||||
* # Returns
|
||||
*
|
||||
* The padded length (always less than or equal to the MTU)
|
||||
*/
|
||||
#[inline(always)]
|
||||
const fn padding(size: usize, mtu: usize) -> usize {
|
||||
#[inline(always)]
|
||||
const fn min(a: usize, b: usize) -> usize {
|
||||
let m = (a > b) as usize;
|
||||
a * m + (1 - m) * b
|
||||
}
|
||||
let pad = MESSAGE_PADDING_MULTIPLE;
|
||||
min(mtu, size + (pad - size % pad) % pad)
|
||||
}
|
||||
|
||||
impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
pub fn set_key(&self, sk: Option<StaticSecret>) {
|
||||
let mut handshake = self.state.handshake.write();
|
||||
match sk {
|
||||
None => {
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
handshake.device.set_sk(StaticSecret::new(&mut rng));
|
||||
handshake.active = false;
|
||||
}
|
||||
Some(sk) => {
|
||||
handshake.device.set_sk(sk);
|
||||
handshake.active = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_sk(&self) -> Option<StaticSecret> {
|
||||
let handshake = self.state.handshake.read();
|
||||
if handshake.active {
|
||||
Some(handshake.device.get_sk())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
|
||||
let state = Arc::new(PeerInner {
|
||||
pk,
|
||||
queue: Mutex::new(self.state.queue.lock().clone()),
|
||||
keepalive: AtomicUsize::new(0),
|
||||
rx_bytes: AtomicU64::new(0),
|
||||
tx_bytes: AtomicU64::new(0),
|
||||
timers: RwLock::new(Timers::dummy(&self.runner)),
|
||||
});
|
||||
|
||||
let router = Arc::new(self.state.router.new_peer(state.clone()));
|
||||
|
||||
let peer = Peer { router, state };
|
||||
|
||||
/* The need for dummy timers arises from the chicken-egg
|
||||
* problem of the timer callbacks being able to set timers themselves.
|
||||
*
|
||||
* This is in fact the only place where the write lock is ever taken.
|
||||
*/
|
||||
*peer.timers.write() = Timers::new(&self.runner, peer.clone());
|
||||
peer
|
||||
}
|
||||
|
||||
/* Begin consuming messages from the reader.
|
||||
*
|
||||
* Any previous reader thread is stopped by closing the previous reader,
|
||||
* which unblocks the thread and causes an error on reader.read
|
||||
*/
|
||||
pub fn add_reader(&self, reader: B::Reader) {
|
||||
let wg = self.state.clone();
|
||||
thread::spawn(move || {
|
||||
let mut last_under_load =
|
||||
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
|
||||
|
||||
loop {
|
||||
// create vector big enough for any message given current MTU
|
||||
let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
|
||||
let mut msg: Vec<u8> = Vec::with_capacity(size);
|
||||
msg.resize(size, 0);
|
||||
|
||||
// read UDP packet into vector
|
||||
let (size, src) = match reader.read(&mut msg) {
|
||||
Err(e) => {
|
||||
debug!("Bind reader closed with {}", e);
|
||||
return;
|
||||
}
|
||||
Ok(v) => v,
|
||||
};
|
||||
msg.truncate(size);
|
||||
|
||||
// message type de-multiplexer
|
||||
if msg.len() < std::mem::size_of::<u32>() {
|
||||
continue;
|
||||
}
|
||||
match LittleEndian::read_u32(&msg[..]) {
|
||||
handshake::TYPE_COOKIE_REPLY
|
||||
| handshake::TYPE_INITIATION
|
||||
| handshake::TYPE_RESPONSE => {
|
||||
// update under_load flag
|
||||
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
|
||||
last_under_load = Instant::now();
|
||||
wg.under_load.store(true, Ordering::SeqCst);
|
||||
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
|
||||
wg.under_load.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
wg.queue
|
||||
.lock()
|
||||
.send(HandshakeJob::Message(msg, src))
|
||||
.unwrap();
|
||||
}
|
||||
router::TYPE_TRANSPORT => {
|
||||
// transport message
|
||||
let _ = wg.router.recv(src, msg).map_err(|e| {
|
||||
debug!("Failed to handle incoming transport message: {}", e);
|
||||
});
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn set_writer(&self, writer: B::Writer) {
|
||||
// TODO: Consider unifying these and avoid Clone requirement on writer
|
||||
*self.state.send.write() = Some(writer.clone());
|
||||
self.state.router.set_outbound_writer(writer);
|
||||
}
|
||||
|
||||
pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
|
||||
// create device state
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||
let wg = Arc::new(WireguardInner {
|
||||
mtu: mtu.clone(),
|
||||
peers: RwLock::new(HashMap::new()),
|
||||
send: RwLock::new(None),
|
||||
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
|
||||
pending: AtomicUsize::new(0),
|
||||
handshake: RwLock::new(Handshake {
|
||||
device: handshake::Device::new(StaticSecret::new(&mut rng)),
|
||||
active: false,
|
||||
}),
|
||||
under_load: AtomicBool::new(false),
|
||||
queue: Mutex::new(tx),
|
||||
});
|
||||
|
||||
// start handshake workers
|
||||
for _ in 0..num_cpus::get() {
|
||||
let wg = wg.clone();
|
||||
let rx = rx.clone();
|
||||
thread::spawn(move || {
|
||||
// prepare OsRng instance for this thread
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
|
||||
// process elements from the handshake queue
|
||||
for job in rx {
|
||||
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||
let state = wg.handshake.read();
|
||||
if !state.active {
|
||||
continue;
|
||||
}
|
||||
|
||||
match job {
|
||||
HandshakeJob::Message(msg, src) => {
|
||||
// feed message to handshake device
|
||||
let src_validate = (&src).into_address(); // TODO avoid
|
||||
|
||||
// process message
|
||||
match state.device.process(
|
||||
&mut rng,
|
||||
&msg[..],
|
||||
if wg.under_load.load(Ordering::Relaxed) {
|
||||
Some(&src_validate)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
) {
|
||||
Ok((pk, resp, keypair)) => {
|
||||
// send response
|
||||
let mut resp_len: u64 = 0;
|
||||
if let Some(msg) = resp {
|
||||
resp_len = msg.len() as u64;
|
||||
let send: &Option<B::Writer> = &*wg.send.read();
|
||||
if let Some(writer) = send.as_ref() {
|
||||
let _ = writer.write(&msg[..], &src).map_err(|e| {
|
||||
debug!(
|
||||
"handshake worker, failed to send response, error = {}",
|
||||
e
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// update timers
|
||||
if let Some(pk) = pk {
|
||||
// authenticated handshake packet received
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
// add to rx_bytes and tx_bytes
|
||||
let req_len = msg.len() as u64;
|
||||
peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
|
||||
peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
|
||||
|
||||
// update endpoint
|
||||
peer.router.set_endpoint(src);
|
||||
|
||||
// add keypair to peer
|
||||
keypair.map(|kp| {
|
||||
// free any unused ids
|
||||
for id in peer.router.add_keypair(kp) {
|
||||
state.device.release(id);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => debug!("handshake worker, error = {:?}", e),
|
||||
}
|
||||
}
|
||||
HandshakeJob::New(pk) => {
|
||||
let _ = state.device.begin(&mut rng, &pk).map(|msg| {
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
let _ = peer.router.send(&msg[..]).map_err(|e| {
|
||||
debug!("handshake worker, failed to send handshake initiation, error = {}", e)
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// start TUN read IO threads (multiple threads to support multi-queue interfaces)
|
||||
debug_assert!(
|
||||
readers.len() > 0,
|
||||
"attempted to create WG device without TUN readers"
|
||||
);
|
||||
while let Some(reader) = readers.pop() {
|
||||
let wg = wg.clone();
|
||||
let mtu = mtu.clone();
|
||||
thread::spawn(move || loop {
|
||||
// create vector big enough for any transport message (based on MTU)
|
||||
let mtu = mtu.mtu();
|
||||
let size = mtu + router::SIZE_MESSAGE_PREFIX;
|
||||
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
|
||||
msg.resize(size, 0);
|
||||
|
||||
// read a new IP packet
|
||||
let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) {
|
||||
Ok(payload) => payload,
|
||||
Err(e) => {
|
||||
debug!("TUN worker, failed to read from tun device: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
|
||||
|
||||
// truncate padding
|
||||
let payload = padding(payload, mtu);
|
||||
msg.truncate(router::SIZE_MESSAGE_PREFIX + payload);
|
||||
debug_assert!(payload <= mtu);
|
||||
debug_assert_eq!(
|
||||
if payload < mtu {
|
||||
(msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
|
||||
} else {
|
||||
0
|
||||
},
|
||||
0
|
||||
);
|
||||
|
||||
// crypt-key route
|
||||
let e = wg.router.send(msg);
|
||||
debug!("TUN worker, router returned {:?}", e);
|
||||
});
|
||||
}
|
||||
|
||||
Wireguard {
|
||||
state: wg,
|
||||
runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user