Work on Linux platform code

This commit is contained in:
Mathias Hall-Andersen
2019-10-13 22:26:12 +02:00
parent 6000cbf7e4
commit a08fd4002b
36 changed files with 293 additions and 52 deletions

186
src/wireguard/config.rs Normal file
View File

@@ -0,0 +1,186 @@
use std::net::{IpAddr, SocketAddr};
use x25519_dalek::{PublicKey, StaticSecret};
use super::wireguard::Wireguard;
use super::types::bind::Bind;
use super::types::tun::Tun;
/// The goal of the configuration interface is, among others,
/// to hide the IO implementations (over which the WG device is generic),
/// from the configuration and UAPI code.
/// Describes a snapshot of the state of a peer
pub struct PeerState {
rx_bytes: u64,
tx_bytes: u64,
last_handshake_time_sec: u64,
last_handshake_time_nsec: u64,
public_key: PublicKey,
allowed_ips: Vec<(IpAddr, u32)>,
}
pub enum ConfigError {
NoSuchPeer
}
impl ConfigError {
fn errno(&self) -> i32 {
match self {
NoSuchPeer => 1,
}
}
}
/// Exposed configuration interface
pub trait Configuration {
/// Updates the private key of the device
///
/// # Arguments
///
/// - `sk`: The new private key (or None, if the private key should be cleared)
fn set_private_key(&self, sk: Option<StaticSecret>);
/// Returns the private key of the device
///
/// # Returns
///
/// The private if set, otherwise None.
fn get_private_key(&self) -> Option<StaticSecret>;
/// Returns the protocol version of the device
///
/// # Returns
///
/// An integer indicating the protocol version
fn get_protocol_version(&self) -> usize;
fn set_listen_port(&self, port: u16) -> Option<ConfigError>;
/// Set the firewall mark (or similar, depending on platform)
///
/// # Arguments
///
/// - `mark`: The fwmark value
///
/// # Returns
///
/// An error if this operation is not supported by the underlying
/// "bind" implementation.
fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError>;
/// Removes all peers from the device
fn replace_peers(&self);
/// Remove the peer from the
///
/// # Arguments
///
/// - `peer`: The public key of the peer to remove
///
/// # Returns
///
/// If the peer does not exists this operation is a noop
fn remove_peer(&self, peer: PublicKey);
/// Adds a new peer to the device
///
/// # Arguments
///
/// - `peer`: The public key of the peer to add
///
/// # Returns
///
/// A bool indicating if the peer was added.
///
/// If the peer already exists this operation is a noop
fn add_peer(&self, peer: PublicKey) -> bool;
/// Update the psk of a peer
///
/// # Arguments
///
/// - `peer`: The public key of the peer
/// - `psk`: The new psk or None if the psk should be unset
///
/// # Returns
///
/// An error if no such peer exists
fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError>;
/// Update the endpoint of the
///
/// # Arguments
///
/// - `peer': The public key of the peer
/// - `psk`
fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option<ConfigError>;
/// Update the endpoint of the
///
/// # Arguments
///
/// - `peer': The public key of the peer
/// - `psk`
fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option<ConfigError>;
/// Remove all allowed IPs from the peer
///
/// # Arguments
///
/// - `peer': The public key of the peer
///
/// # Returns
///
/// An error if no such peer exists
fn replace_allowed_ips(&self, peer: PublicKey) -> Option<ConfigError>;
/// Add a new allowed subnet to the peer
///
/// # Arguments
///
/// - `peer`: The public key of the peer
/// - `ip`: Subnet mask
/// - `masklen`:
///
/// # Returns
///
/// An error if the peer does not exist
///
/// # Note:
///
/// The API must itself sanitize the (ip, masklen) set:
/// The ip should be masked to remove any set bits right of the first "masklen" bits.
fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError>;
/// Returns the state of all peers
///
/// # Returns
///
/// A list of structures describing the state of each peer
fn get_peers(&self) -> Vec<PeerState>;
}
impl <T : Tun, B : Bind>Configuration for Wireguard<T, B> {
fn set_private_key(&self, sk : Option<StaticSecret>) {
self.set_key(sk)
}
fn get_private_key(&self) -> Option<StaticSecret> {
self.get_sk()
}
fn get_protocol_version(&self) -> usize {
1
}
fn set_listen_port(&self, port : u16) -> Option<ConfigError> {
None
}
fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError> {
None
}
}

View File

@@ -0,0 +1,20 @@
use std::time::Duration;
use std::u64;
pub const REKEY_AFTER_MESSAGES: u64 = u64::MAX - (1 << 16);
pub const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4);
pub const REKEY_AFTER_TIME: Duration = Duration::from_secs(120);
pub const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
pub const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
pub const MAX_TIMER_HANDSHAKES: usize = 18;
pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200);
pub const TIMERS_TICK: Duration = Duration::from_millis(100);
pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize;
pub const TIMERS_CAPACITY: usize = 1024;
pub const MESSAGE_PADDING_MULTIPLE: usize = 16;

View File

@@ -0,0 +1,574 @@
use spin::RwLock;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Mutex;
use zerocopy::AsBytes;
use byteorder::{ByteOrder, LittleEndian};
use rand::prelude::*;
use x25519_dalek::PublicKey;
use x25519_dalek::StaticSecret;
use super::macs;
use super::messages::{CookieReply, Initiation, Response};
use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
use super::noise;
use super::peer::Peer;
use super::ratelimiter::RateLimiter;
use super::types::*;
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
pub struct Device {
pub sk: StaticSecret, // static secret key
pub pk: PublicKey, // static public key
macs: macs::Validator, // validator for the mac fields
pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
limiter: Mutex<RateLimiter>,
}
/* A mutable reference to the device needs to be held during configuration.
* Wrapping the device in a RwLock enables peer config after "configuration time"
*/
impl Device {
/// Initialize a new handshake state machine
///
/// # Arguments
///
/// * `sk` - x25519 scalar representing the local private key
pub fn new(sk: StaticSecret) -> Device {
let pk = PublicKey::from(&sk);
Device {
pk,
sk,
macs: macs::Validator::new(pk),
pk_map: HashMap::new(),
id_map: RwLock::new(HashMap::new()),
limiter: Mutex::new(RateLimiter::new()),
}
}
/// Update the secret key of the device
///
/// # Arguments
///
/// * `sk` - x25519 scalar representing the local private key
pub fn set_sk(&mut self, sk: StaticSecret) {
// update secret and public key
let pk = PublicKey::from(&sk);
self.sk = sk;
self.pk = pk;
self.macs = macs::Validator::new(pk);
// recalculate the shared secrets for every peer
let mut ids = vec![];
for mut peer in self.pk_map.values_mut() {
peer.reset_state().map(|id| ids.push(id));
peer.ss = self.sk.diffie_hellman(&peer.pk)
}
// release ids from aborted handshakes
for id in ids {
self.release(id)
}
}
/// Return the secret key of the device
///
/// # Returns
///
/// A secret key (x25519 scalar)
pub fn get_sk(&self) -> StaticSecret {
StaticSecret::from(self.sk.to_bytes())
}
/// Add a new public key to the state machine
/// To remove public keys, you must create a new machine instance
///
/// # Arguments
///
/// * `pk` - The public key to add
/// * `identifier` - Associated identifier which can be used to distinguish the peers
pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
// check that the pk is not added twice
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
return Err(ConfigError::new("Duplicate public key"));
};
// check that the pk is not that of the device
if *self.pk.as_bytes() == *pk.as_bytes() {
return Err(ConfigError::new(
"Public key corresponds to secret key of interface",
));
}
// ensure less than 2^20 peers
if self.pk_map.len() > MAX_PEER_PER_DEVICE {
return Err(ConfigError::new("Too many peers for device"));
}
// map the public key to the peer state
self.pk_map
.insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk)));
Ok(())
}
/// Remove a peer by public key
/// To remove public keys, you must create a new machine instance
///
/// # Arguments
///
/// * `pk` - The public key of the peer to remove
///
/// # Returns
///
/// The call might fail if the public key is not found
pub fn remove(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
// take write-lock on receive id table
let mut id_map = self.id_map.write();
// remove the peer
self.pk_map
.remove(pk.as_bytes())
.ok_or(ConfigError::new("Public key not in device"))?;
// pruge the id map (linear scan)
id_map.retain(|_, v| v != pk.as_bytes());
Ok(())
}
/// Add a psk to the peer
///
/// # Arguments
///
/// * `pk` - The public key of the peer
/// * `psk` - The psk to set / unset
///
/// # Returns
///
/// The call might fail if the public key is not found
pub fn set_psk(&mut self, pk: PublicKey, psk: Option<Psk>) -> Result<(), ConfigError> {
match self.pk_map.get_mut(pk.as_bytes()) {
Some(mut peer) => {
peer.psk = match psk {
Some(v) => v,
None => [0u8; 32],
};
Ok(())
}
_ => Err(ConfigError::new("No such public key")),
}
}
/// Return the psk for the peer
///
/// # Arguments
///
/// * `pk` - The public key of the peer
///
/// # Returns
///
/// A 32 byte array holding the PSK
///
/// The call might fail if the public key is not found
pub fn get_psk(&self, pk: PublicKey) -> Result<Psk, ConfigError> {
match self.pk_map.get(pk.as_bytes()) {
Some(peer) => Ok(peer.psk),
_ => Err(ConfigError::new("No such public key")),
}
}
/// Release an id back to the pool
///
/// # Arguments
///
/// * `id` - The (sender) id to release
pub fn release(&self, id: u32) {
let mut m = self.id_map.write();
debug_assert!(m.contains_key(&id), "Releasing id not allocated");
m.remove(&id);
}
/// Begin a new handshake
///
/// # Arguments
///
/// * `pk` - Public key of peer to initiate handshake for
pub fn begin<R: RngCore + CryptoRng>(
&self,
rng: &mut R,
pk: &PublicKey,
) -> Result<Vec<u8>, HandshakeError> {
match self.pk_map.get(pk.as_bytes()) {
None => Err(HandshakeError::UnknownPublicKey),
Some(peer) => {
let sender = self.allocate(rng, peer);
let mut msg = Initiation::default();
noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?;
// add macs to initation
peer.macs
.lock()
.generate(msg.noise.as_bytes(), &mut msg.macs);
Ok(msg.as_bytes().to_owned())
}
}
}
/// Process a handshake message.
///
/// # Arguments
///
/// * `msg` - Byte slice containing the message (untrusted input)
pub fn process<'a, R: RngCore + CryptoRng, S>(
&self,
rng: &mut R, // rng instance to sample randomness from
msg: &[u8], // message buffer
src: Option<&'a S>, // optional source endpoint, set when "under load"
) -> Result<Output, HandshakeError>
where
&'a S: Into<&'a SocketAddr>,
{
// ensure type read in-range
if msg.len() < 4 {
return Err(HandshakeError::InvalidMessageFormat);
}
// de-multiplex the message type field
match LittleEndian::read_u32(msg) {
TYPE_INITIATION => {
// parse message
let msg = Initiation::parse(msg)?;
// check mac1 field
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
// address validation & DoS mitigation
if let Some(src) = src {
// obtain ref to socket addr
let src = src.into();
// check mac2 field
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default();
self.macs.create_cookie_reply(
rng,
msg.noise.f_sender.get(),
src,
&msg.macs,
&mut reply,
);
return Ok((None, Some(reply.as_bytes().to_owned()), None));
}
// check ratelimiter
if !self.limiter.lock().unwrap().allow(&src.ip()) {
return Err(HandshakeError::RateLimited);
}
}
// consume the initiation
let (peer, st) = noise::consume_initiation(self, &msg.noise)?;
// allocate new index for response
let sender = self.allocate(rng, peer);
// prepare memory for response, TODO: take slice for zero allocation
let mut resp = Response::default();
// create response (release id on error)
let keys = noise::create_response(rng, peer, sender, st, &mut resp.noise).map_err(
|e| {
self.release(sender);
e
},
)?;
// add macs to response
peer.macs
.lock()
.generate(resp.noise.as_bytes(), &mut resp.macs);
// return unconfirmed keypair and the response as vector
Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys)))
}
TYPE_RESPONSE => {
let msg = Response::parse(msg)?;
// check mac1 field
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
// address validation & DoS mitigation
if let Some(src) = src {
// obtain ref to socket addr
let src = src.into();
// check mac2 field
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default();
self.macs.create_cookie_reply(
rng,
msg.noise.f_sender.get(),
src,
&msg.macs,
&mut reply,
);
return Ok((None, Some(reply.as_bytes().to_owned()), None));
}
// check ratelimiter
if !self.limiter.lock().unwrap().allow(&src.ip()) {
return Err(HandshakeError::RateLimited);
}
}
// consume inner playload
noise::consume_response(self, &msg.noise)
}
TYPE_COOKIE_REPLY => {
let msg = CookieReply::parse(msg)?;
// lookup peer
let peer = self.lookup_id(msg.f_receiver.get())?;
// validate cookie reply
peer.macs.lock().process(&msg)?;
// this prompts no new message and
// DOES NOT cryptographically verify the peer
Ok((None, None, None))
}
_ => Err(HandshakeError::InvalidMessageFormat),
}
}
// Internal function
//
// Return the peer associated with the public key
pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> {
self.pk_map
.get(pk.as_bytes())
.ok_or(HandshakeError::UnknownPublicKey)
}
// Internal function
//
// Return the peer currently associated with the receiver identifier
pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> {
let im = self.id_map.read();
let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
match self.pk_map.get(pk) {
Some(peer) => Ok(peer),
_ => unreachable!(), // if the id-lookup succeeded, the peer should exist
}
}
// Internal function
//
// Allocated a new receiver identifier for the peer
fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 {
loop {
let id = rng.gen();
// check membership with read lock
if self.id_map.read().contains_key(&id) {
continue;
}
// take write lock and add index
let mut m = self.id_map.write();
if !m.contains_key(&id) {
m.insert(id, *peer.pk.as_bytes());
return id;
}
}
}
}
#[cfg(test)]
mod tests {
use super::super::messages::*;
use super::*;
use hex;
use rand::rngs::OsRng;
use std::net::SocketAddr;
use std::thread;
use std::time::Duration;
fn setup_devices<R: RngCore + CryptoRng>(
rng: &mut R,
) -> (PublicKey, Device, PublicKey, Device) {
// generate new keypairs
let sk1 = StaticSecret::new(rng);
let pk1 = PublicKey::from(&sk1);
let sk2 = StaticSecret::new(rng);
let pk2 = PublicKey::from(&sk2);
// pick random psk
let mut psk = [0u8; 32];
rng.fill_bytes(&mut psk[..]);
// intialize devices on both ends
let mut dev1 = Device::new(sk1);
let mut dev2 = Device::new(sk2);
dev1.add(pk2).unwrap();
dev2.add(pk1).unwrap();
dev1.set_psk(pk2, Some(psk)).unwrap();
dev2.set_psk(pk1, Some(psk)).unwrap();
(pk1, dev1, pk2, dev2)
}
/* Test longest possible handshake interaction (7 messages):
*
* 1. I -> R (initation)
* 2. I <- R (cookie reply)
* 3. I -> R (initation)
* 4. I <- R (response)
* 5. I -> R (cookie reply)
* 6. I -> R (initation)
* 7. I <- R (response)
*/
#[test]
fn handshake_under_load() {
let mut rng = OsRng::new().unwrap();
let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng);
let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap();
let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap();
// 1. device-1 : create first initation
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
// 2. device-2 : responds with CookieReply
let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
(None, Some(msg), None) => msg,
_ => panic!("unexpected response"),
};
// device-1 : processes CookieReply (no response)
match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
(None, None, None) => (),
_ => panic!("unexpected response"),
}
// avoid initation flood
thread::sleep(Duration::from_millis(20));
// 3. device-1 : create second initation
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
// 4. device-2 : responds with noise response
let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
(Some(_), Some(msg), Some(kp)) => {
assert_eq!(kp.initiator, false);
msg
}
_ => panic!("unexpected response"),
};
// 5. device-1 : responds with CookieReply
let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
(None, Some(msg), None) => msg,
_ => panic!("unexpected response"),
};
// device-2 : processes CookieReply (no response)
match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
(None, None, None) => (),
_ => panic!("unexpected response"),
}
// avoid initation flood
thread::sleep(Duration::from_millis(20));
// 6. device-1 : create third initation
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
// 7. device-2 : responds with noise response
let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
(Some(_), Some(msg), Some(kp)) => {
assert_eq!(kp.initiator, false);
(msg, kp)
}
_ => panic!("unexpected response"),
};
// device-1 : process noise response
let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
(Some(_), None, Some(kp)) => {
assert_eq!(kp.initiator, true);
kp
}
_ => panic!("unexpected response"),
};
assert_eq!(kp1.send, kp2.recv);
assert_eq!(kp1.recv, kp2.send);
}
#[test]
fn handshake_no_load() {
let mut rng = OsRng::new().unwrap();
let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng);
// do a few handshakes (every handshake should succeed)
for i in 0..10 {
println!("handshake : {}", i);
// create initiation
let msg1 = dev1.begin(&mut rng, &pk2).unwrap();
println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len());
println!("msg1 = {:?}", Initiation::parse(&msg1[..]).unwrap());
// process initiation and create response
let (_, msg2, ks_r) = dev2.process(&mut rng, &msg1, None).unwrap();
let ks_r = ks_r.unwrap();
let msg2 = msg2.unwrap();
println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len());
println!("msg2 = {:?}", Response::parse(&msg2[..]).unwrap());
assert!(!ks_r.initiator, "Responders key-pair is confirmed");
// process response and obtain confirmed key-pair
let (_, msg3, ks_i) = dev1.process(&mut rng, &msg2, None).unwrap();
let ks_i = ks_i.unwrap();
assert!(msg3.is_none(), "Returned message after response");
assert!(ks_i.initiator, "Initiators key-pair is not confirmed");
assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv");
assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send");
dev1.release(ks_i.send.id);
dev2.release(ks_r.send.id);
// to avoid flood detection
thread::sleep(Duration::from_millis(20));
}
dev1.remove(pk2).unwrap();
dev2.remove(pk1).unwrap();
}
}

View File

@@ -0,0 +1,327 @@
use generic_array::GenericArray;
use rand::{CryptoRng, RngCore};
use spin::RwLock;
use std::time::{Duration, Instant};
// types to coalesce into bytes
use std::net::SocketAddr;
use x25519_dalek::PublicKey;
// AEAD
use aead::{Aead, NewAead, Payload};
use chacha20poly1305::XChaCha20Poly1305;
// MAC
use blake2::Blake2s;
use subtle::ConstantTimeEq;
use super::messages::{CookieReply, MacsFooter, TYPE_COOKIE_REPLY};
use super::types::HandshakeError;
const LABEL_MAC1: &[u8] = b"mac1----";
const LABEL_COOKIE: &[u8] = b"cookie--";
const SIZE_COOKIE: usize = 16;
const SIZE_SECRET: usize = 32;
const SIZE_MAC: usize = 16; // blake2s-mac128
const SIZE_TAG: usize = 16; // xchacha20poly1305 tag
const COOKIE_UPDATE_INTERVAL: Duration = Duration::from_secs(120);
macro_rules! HASH {
( $($input:expr),* ) => {{
use blake2::Digest;
let mut hsh = Blake2s::new();
$(
hsh.input($input);
)*
hsh.result()
}};
}
macro_rules! MAC {
( $key:expr, $($input:expr),* ) => {{
use blake2::VarBlake2s;
use digest::Input;
use digest::VariableOutput;
let mut tag = [0u8; SIZE_MAC];
let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC);
$(
mac.input($input);
)*
mac.variable_result(|buf| tag.copy_from_slice(buf));
tag
}};
}
macro_rules! XSEAL {
($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
let ct = XChaCha20Poly1305::new(*GenericArray::from_slice($key))
.encrypt(
GenericArray::from_slice($nonce),
Payload { msg: $pt, aad: $ad },
)
.unwrap();
debug_assert_eq!(ct.len(), $pt.len() + SIZE_TAG);
$ct.copy_from_slice(&ct);
}};
}
macro_rules! XOPEN {
($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG);
XChaCha20Poly1305::new(*GenericArray::from_slice($key))
.decrypt(
GenericArray::from_slice($nonce),
Payload { msg: $ct, aad: $ad },
)
.map_err(|_| HandshakeError::DecryptionFailure)
.map(|pt| $pt.copy_from_slice(&pt))
}};
}
struct Cookie {
value: [u8; 16],
birth: Instant,
}
pub struct Generator {
mac1_key: [u8; 32],
cookie_key: [u8; 32], // xchacha20poly key for opening cookie response
last_mac1: Option<[u8; 16]>,
cookie: Option<Cookie>,
}
fn addr_to_mac_bytes(addr: &SocketAddr) -> Vec<u8> {
match addr {
SocketAddr::V4(addr) => {
let mut res = Vec::with_capacity(4 + 2);
res.extend(&addr.ip().octets());
res.extend(&addr.port().to_le_bytes());
res
}
SocketAddr::V6(addr) => {
let mut res = Vec::with_capacity(16 + 2);
res.extend(&addr.ip().octets());
res.extend(&addr.port().to_le_bytes());
res
}
}
}
impl Generator {
/// Initalize a new mac field generator
///
/// # Arguments
///
/// - pk: The public key of the peer to which the generator is associated
///
/// # Returns
///
/// A freshly initated generator
pub fn new(pk: PublicKey) -> Generator {
Generator {
mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(),
cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
last_mac1: None,
cookie: None,
}
}
/// Process a CookieReply message
///
/// # Arguments
///
/// - reply: CookieReply to process
///
/// # Returns
///
/// Can fail if the cookie reply fails to validate
/// (either indicating that it is outdated or malformed)
pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> {
let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?;
let mut tau = [0u8; SIZE_COOKIE];
XOPEN!(
&self.cookie_key, // key
&reply.f_nonce, // nonce
&mac1, // ad
&mut tau, // pt
&reply.f_cookie // ct || tag
)?;
self.cookie = Some(Cookie {
birth: Instant::now(),
value: tau,
});
Ok(())
}
/// Generate both mac fields for an inner message
///
/// # Arguments
///
/// - inner: A byteslice representing the inner message to be covered
/// - macs: The destination mac footer for the resulting macs
pub fn generate(&mut self, inner: &[u8], macs: &mut MacsFooter) {
macs.f_mac1 = MAC!(&self.mac1_key, inner);
macs.f_mac2 = match &self.cookie {
Some(cookie) => {
if cookie.birth.elapsed() > COOKIE_UPDATE_INTERVAL {
self.cookie = None;
[0u8; SIZE_MAC]
} else {
MAC!(&cookie.value, inner, macs.f_mac1)
}
}
None => [0u8; SIZE_MAC],
};
self.last_mac1 = Some(macs.f_mac1);
}
}
struct Secret {
value: [u8; 32],
birth: Instant,
}
pub struct Validator {
mac1_key: [u8; 32], // mac1 key, derived from device public key
cookie_key: [u8; 32], // xchacha20poly key for sealing cookie response
secret: RwLock<Secret>,
}
impl Validator {
pub fn new(pk: PublicKey) -> Validator {
Validator {
mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(),
cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
secret: RwLock::new(Secret {
value: [0u8; SIZE_SECRET],
birth: Instant::now() - Duration::new(86400, 0),
}),
}
}
fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> {
let secret = self.secret.read();
if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
Some(MAC!(&secret.value, src))
} else {
None
}
}
fn get_set_tau<R: RngCore + CryptoRng>(&self, rng: &mut R, src: &[u8]) -> [u8; SIZE_COOKIE] {
// check if current value is still valid
{
let secret = self.secret.read();
if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
return MAC!(&secret.value, src);
};
}
// take write lock, check again
{
let mut secret = self.secret.write();
if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
return MAC!(&secret.value, src);
};
// set new random cookie secret
rng.fill_bytes(&mut secret.value);
secret.birth = Instant::now();
MAC!(&secret.value, src)
}
}
pub fn create_cookie_reply<R: RngCore + CryptoRng>(
&self,
rng: &mut R,
receiver: u32, // receiver id of incoming message
src: &SocketAddr, // source address of incoming message
macs: &MacsFooter, // footer of incoming message
msg: &mut CookieReply, // resulting cookie reply
) {
let src = addr_to_mac_bytes(src);
msg.f_type.set(TYPE_COOKIE_REPLY as u32);
msg.f_receiver.set(receiver);
rng.fill_bytes(&mut msg.f_nonce);
XSEAL!(
&self.cookie_key, // key
&msg.f_nonce, // nonce
&macs.f_mac1, // ad
&self.get_set_tau(rng, &src), // pt
&mut msg.f_cookie // ct || tag
);
}
/// Check the mac1 field against the inner message
///
/// # Arguments
///
/// - inner: The inner message covered by the mac1 field
/// - macs: The mac footer
pub fn check_mac1(&self, inner: &[u8], macs: &MacsFooter) -> Result<(), HandshakeError> {
let valid_mac1: bool = MAC!(&self.mac1_key, inner).ct_eq(&macs.f_mac1).into();
if !valid_mac1 {
Err(HandshakeError::InvalidMac1)
} else {
Ok(())
}
}
pub fn check_mac2(&self, inner: &[u8], src: &SocketAddr, macs: &MacsFooter) -> bool {
let src = addr_to_mac_bytes(src);
match self.get_tau(&src) {
Some(tau) => MAC!(&tau, inner, macs.f_mac1).ct_eq(&macs.f_mac2).into(),
None => false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
use rand::rngs::OsRng;
use x25519_dalek::StaticSecret;
fn new_validator_generator() -> (Validator, Generator) {
let mut rng = OsRng::new().unwrap();
let sk = StaticSecret::new(&mut rng);
let pk = PublicKey::from(&sk);
(Validator::new(pk), Generator::new(pk))
}
proptest! {
#[test]
fn test_cookie_reply(inner1 : Vec<u8>, inner2 : Vec<u8>, receiver : u32) {
let mut msg = CookieReply::default();
let mut rng = OsRng::new().expect("failed to create rng");
let mut macs = MacsFooter::default();
let src = "192.0.2.16:8080".parse().unwrap();
let (validator, mut generator) = new_validator_generator();
// generate mac1 for first message
generator.generate(&inner1[..], &mut macs);
assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set");
assert_eq!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should not be set");
// check validity of mac1
validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate");
assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate");
validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg);
// consume cookie reply
generator.process(&msg).expect("failed to process CookieReply");
// generate mac2 & mac2 for second message
generator.generate(&inner2[..], &mut macs);
assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set");
assert_ne!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should be set");
// check validity of mac1 and mac2
validator.check_mac1(&inner2[..], &macs).expect("mac1 of inner2 did not validate");
assert!(validator.check_mac2(&inner2[..], &src, &macs), "mac2 of inner2 did not validate");
}
}
}

View File

@@ -0,0 +1,363 @@
#[cfg(test)]
use hex;
#[cfg(test)]
use std::fmt;
use std::mem;
use byteorder::LittleEndian;
use zerocopy::byteorder::U32;
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
use super::types::*;
const SIZE_MAC: usize = 16;
const SIZE_TAG: usize = 16; // poly1305 tag
const SIZE_XNONCE: usize = 24; // xchacha20 nonce
const SIZE_COOKIE: usize = 16; //
const SIZE_X25519_POINT: usize = 32; // x25519 public key
const SIZE_TIMESTAMP: usize = 12;
pub const TYPE_INITIATION: u32 = 1;
pub const TYPE_RESPONSE: u32 = 2;
pub const TYPE_COOKIE_REPLY: u32 = 3;
const fn max(a: usize, b: usize) -> usize {
let m: usize = (a > b) as usize;
m * a + (1 - m) * b
}
pub const MAX_HANDSHAKE_MSG_SIZE: usize = max(
max(mem::size_of::<Response>(), mem::size_of::<Initiation>()),
mem::size_of::<CookieReply>(),
);
/* Handshake messsages */
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct Response {
pub noise: NoiseResponse, // inner message covered by macs
pub macs: MacsFooter,
}
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct Initiation {
pub noise: NoiseInitiation, // inner message covered by macs
pub macs: MacsFooter,
}
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct CookieReply {
pub f_type: U32<LittleEndian>,
pub f_receiver: U32<LittleEndian>,
pub f_nonce: [u8; SIZE_XNONCE],
pub f_cookie: [u8; SIZE_COOKIE + SIZE_TAG],
}
/* Inner sub-messages */
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct MacsFooter {
pub f_mac1: [u8; SIZE_MAC],
pub f_mac2: [u8; SIZE_MAC],
}
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct NoiseInitiation {
pub f_type: U32<LittleEndian>,
pub f_sender: U32<LittleEndian>,
pub f_ephemeral: [u8; SIZE_X25519_POINT],
pub f_static: [u8; SIZE_X25519_POINT + SIZE_TAG],
pub f_timestamp: [u8; SIZE_TIMESTAMP + SIZE_TAG],
}
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct NoiseResponse {
pub f_type: U32<LittleEndian>,
pub f_sender: U32<LittleEndian>,
pub f_receiver: U32<LittleEndian>,
pub f_ephemeral: [u8; SIZE_X25519_POINT],
pub f_empty: [u8; SIZE_TAG],
}
/* Zero copy parsing of handshake messages */
impl Initiation {
pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
let msg: LayoutVerified<B, Self> =
LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
if msg.noise.f_type.get() != (TYPE_INITIATION as u32) {
return Err(HandshakeError::InvalidMessageFormat);
}
Ok(msg)
}
}
impl Response {
pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
let msg: LayoutVerified<B, Self> =
LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
if msg.noise.f_type.get() != (TYPE_RESPONSE as u32) {
return Err(HandshakeError::InvalidMessageFormat);
}
Ok(msg)
}
}
impl CookieReply {
pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
let msg: LayoutVerified<B, Self> =
LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
if msg.f_type.get() != (TYPE_COOKIE_REPLY as u32) {
return Err(HandshakeError::InvalidMessageFormat);
}
Ok(msg)
}
}
/* Default values */
impl Default for Response {
fn default() -> Self {
Self {
noise: Default::default(),
macs: Default::default(),
}
}
}
impl Default for Initiation {
fn default() -> Self {
Self {
noise: Default::default(),
macs: Default::default(),
}
}
}
impl Default for CookieReply {
fn default() -> Self {
Self {
f_type: <U32<LittleEndian>>::new(TYPE_COOKIE_REPLY as u32),
f_receiver: <U32<LittleEndian>>::ZERO,
f_nonce: [0u8; SIZE_XNONCE],
f_cookie: [0u8; SIZE_COOKIE + SIZE_TAG],
}
}
}
impl Default for MacsFooter {
fn default() -> Self {
Self {
f_mac1: [0u8; SIZE_MAC],
f_mac2: [0u8; SIZE_MAC],
}
}
}
impl Default for NoiseInitiation {
fn default() -> Self {
Self {
f_type: <U32<LittleEndian>>::new(TYPE_INITIATION as u32),
f_sender: <U32<LittleEndian>>::ZERO,
f_ephemeral: [0u8; SIZE_X25519_POINT],
f_static: [0u8; SIZE_X25519_POINT + SIZE_TAG],
f_timestamp: [0u8; SIZE_TIMESTAMP + SIZE_TAG],
}
}
}
impl Default for NoiseResponse {
fn default() -> Self {
Self {
f_type: <U32<LittleEndian>>::new(TYPE_RESPONSE as u32),
f_sender: <U32<LittleEndian>>::ZERO,
f_receiver: <U32<LittleEndian>>::ZERO,
f_ephemeral: [0u8; SIZE_X25519_POINT],
f_empty: [0u8; SIZE_TAG],
}
}
}
/* Debug formatting (for testing purposes) */
#[cfg(test)]
impl fmt::Debug for Initiation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Initiation {{ {:?} || {:?} }}", self.noise, self.macs)
}
}
#[cfg(test)]
impl fmt::Debug for Response {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Response {{ {:?} || {:?} }}", self.noise, self.macs)
}
}
#[cfg(test)]
impl fmt::Debug for CookieReply {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"CookieReply {{ type = {}, receiver = {}, nonce = {}, cookie = {} }}",
self.f_type,
self.f_receiver,
hex::encode(&self.f_nonce[..]),
hex::encode(&self.f_cookie[..]),
)
}
}
#[cfg(test)]
impl fmt::Debug for NoiseInitiation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f,
"NoiseInitiation {{ type = {}, sender = {}, ephemeral = {}, static = {}, timestamp = {} }}",
self.f_type.get(),
self.f_sender.get(),
hex::encode(&self.f_ephemeral[..]),
hex::encode(&self.f_static[..]),
hex::encode(&self.f_timestamp[..]),
)
}
}
#[cfg(test)]
impl fmt::Debug for NoiseResponse {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f,
"NoiseResponse {{ type = {}, sender = {}, receiver = {}, ephemeral = {}, empty = |{} }}",
self.f_type,
self.f_sender,
self.f_receiver,
hex::encode(&self.f_ephemeral[..]),
hex::encode(&self.f_empty[..])
)
}
}
#[cfg(test)]
impl fmt::Debug for MacsFooter {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Macs {{ mac1 = {}, mac2 = {} }}",
hex::encode(&self.f_mac1[..]),
hex::encode(&self.f_mac2[..])
)
}
}
/* Equality (for testing purposes) */
#[cfg(test)]
macro_rules! eq_as_bytes {
($type:path) => {
impl PartialEq for $type {
fn eq(&self, other: &Self) -> bool {
self.as_bytes() == other.as_bytes()
}
}
impl Eq for $type {}
};
}
#[cfg(test)]
eq_as_bytes!(Initiation);
#[cfg(test)]
eq_as_bytes!(Response);
#[cfg(test)]
eq_as_bytes!(CookieReply);
#[cfg(test)]
eq_as_bytes!(MacsFooter);
#[cfg(test)]
eq_as_bytes!(NoiseInitiation);
#[cfg(test)]
eq_as_bytes!(NoiseResponse);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn message_response_identity() {
let mut msg: Response = Default::default();
msg.noise.f_sender.set(146252);
msg.noise.f_receiver.set(554442);
msg.noise.f_ephemeral = [
0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
0x13, 0x56, 0x52, 0x1f,
];
msg.noise.f_empty = [
0x60, 0x0e, 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05,
0x2f, 0xde,
];
msg.macs.f_mac1 = [
0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54,
0x69, 0x29,
];
msg.macs.f_mac2 = [
0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5,
0xdf, 0xbb,
];
let buf: Vec<u8> = msg.as_bytes().to_vec();
let msg_p = Response::parse(&buf[..]).unwrap();
assert_eq!(msg, *msg_p.into_ref());
}
#[test]
fn message_initiate_identity() {
let mut msg: Initiation = Default::default();
msg.noise.f_sender.set(575757);
msg.noise.f_ephemeral = [
0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
0x13, 0x56, 0x52, 0x1f,
];
msg.noise.f_static = [
0xdc, 0x33, 0x90, 0x15, 0x8f, 0x82, 0x3e, 0x06, 0x44, 0xa0, 0xde, 0x4c, 0x15, 0x6c,
0x5d, 0xa4, 0x65, 0x99, 0xf6, 0x6c, 0xa1, 0x14, 0x77, 0xf9, 0xeb, 0x6a, 0xec, 0xc3,
0x3c, 0xda, 0x47, 0xe1, 0x45, 0xac, 0x8d, 0x43, 0xea, 0x1b, 0x2f, 0x02, 0x45, 0x5d,
0x86, 0x37, 0xee, 0x83, 0x6b, 0x42,
];
msg.noise.f_timestamp = [
0x4f, 0x1c, 0x60, 0xec, 0x0e, 0xf6, 0x36, 0xf0, 0x78, 0x28, 0x57, 0x42, 0x60, 0x0e,
0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, 0x2f, 0xde,
];
msg.macs.f_mac1 = [
0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54,
0x69, 0x29,
];
msg.macs.f_mac2 = [
0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5,
0xdf, 0xbb,
];
let buf: Vec<u8> = msg.as_bytes().to_vec();
let msg_p = Initiation::parse(&buf[..]).unwrap();
assert_eq!(msg, *msg_p.into_ref());
}
}

View File

@@ -0,0 +1,21 @@
/* Implementation of the:
*
* Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s
*
* Protocol pattern, see: http://www.noiseprotocol.org/noise.html.
* For documentation.
*/
mod device;
mod macs;
mod messages;
mod noise;
mod peer;
mod ratelimiter;
mod timestamp;
mod types;
// publicly exposed interface
pub use device::Device;
pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};

View File

@@ -0,0 +1,549 @@
// DH
use x25519_dalek::PublicKey;
use x25519_dalek::StaticSecret;
// HASH & MAC
use blake2::Blake2s;
use hmac::Hmac;
// AEAD
use aead::{Aead, NewAead, Payload};
use chacha20poly1305::ChaCha20Poly1305;
use rand::{CryptoRng, RngCore};
use generic_array::typenum::*;
use generic_array::*;
use clear_on_drop::clear::Clear;
use clear_on_drop::clear_stack_on_return;
use subtle::ConstantTimeEq;
use super::device::Device;
use super::messages::{NoiseInitiation, NoiseResponse};
use super::messages::{TYPE_INITIATION, TYPE_RESPONSE};
use super::peer::{Peer, State};
use super::timestamp;
use super::types::*;
use super::super::types::{KeyPair, Key};
use std::time::Instant;
// HMAC hasher (generic construction)
type HMACBlake2s = Hmac<Blake2s>;
// convenient alias to pass state temporarily into device.rs and back
type TemporaryState = (u32, PublicKey, GenericArray<u8, U32>, GenericArray<u8, U32>);
const SIZE_CK: usize = 32;
const SIZE_HS: usize = 32;
const SIZE_NONCE: usize = 8;
const SIZE_TAG: usize = 16;
// number of pages to clear after sensitive call
const CLEAR_PAGES: usize = 1;
// C := Hash(Construction)
const INITIAL_CK: [u8; SIZE_CK] = [
0x60, 0xe2, 0x6d, 0xae, 0xf3, 0x27, 0xef, 0xc0, 0x2e, 0xc3, 0x35, 0xe2, 0xa0, 0x25, 0xd2, 0xd0,
0x16, 0xeb, 0x42, 0x06, 0xf8, 0x72, 0x77, 0xf5, 0x2d, 0x38, 0xd1, 0x98, 0x8b, 0x78, 0xcd, 0x36,
];
// H := Hash(C || Identifier)
const INITIAL_HS: [u8; SIZE_HS] = [
0x22, 0x11, 0xb3, 0x61, 0x08, 0x1a, 0xc5, 0x66, 0x69, 0x12, 0x43, 0xdb, 0x45, 0x8a, 0xd5, 0x32,
0x2d, 0x9c, 0x6c, 0x66, 0x22, 0x93, 0xe8, 0xb7, 0x0e, 0xe1, 0x9c, 0x65, 0xba, 0x07, 0x9e, 0xf3,
];
const ZERO_NONCE: [u8; 12] = [0u8; 12];
macro_rules! HASH {
( $($input:expr),* ) => {{
use blake2::Digest;
let mut hsh = Blake2s::new();
$(
hsh.input($input);
)*
hsh.result()
}};
}
macro_rules! HMAC {
($key:expr, $($input:expr),*) => {{
use hmac::Mac;
let mut mac = HMACBlake2s::new_varkey($key).unwrap();
$(
mac.input($input);
)*
mac.result().code()
}};
}
macro_rules! KDF1 {
($ck:expr, $input:expr) => {{
let mut t0 = HMAC!($ck, $input);
let t1 = HMAC!(&t0, &[0x1]);
t0.clear();
t1
}};
}
macro_rules! KDF2 {
($ck:expr, $input:expr) => {{
let mut t0 = HMAC!($ck, $input);
let t1 = HMAC!(&t0, &[0x1]);
let t2 = HMAC!(&t0, &t1, &[0x2]);
t0.clear();
(t1, t2)
}};
}
macro_rules! KDF3 {
($ck:expr, $input:expr) => {{
let mut t0 = HMAC!($ck, $input);
let t1 = HMAC!(&t0, &[0x1]);
let t2 = HMAC!(&t0, &t1, &[0x2]);
let t3 = HMAC!(&t0, &t2, &[0x3]);
t0.clear();
(t1, t2, t3)
}};
}
macro_rules! SEAL {
($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
ChaCha20Poly1305::new(*GenericArray::from_slice($key))
.encrypt(&ZERO_NONCE.into(), Payload { msg: $pt, aad: $ad })
.map(|ct| $ct.copy_from_slice(&ct))
.unwrap()
};
}
macro_rules! OPEN {
($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
ChaCha20Poly1305::new(*GenericArray::from_slice($key))
.decrypt(&ZERO_NONCE.into(), Payload { msg: $ct, aad: $ad })
.map_err(|_| HandshakeError::DecryptionFailure)
.map(|pt| $pt.copy_from_slice(&pt))
};
}
#[cfg(test)]
mod tests {
use super::*;
const IDENTIFIER: &[u8] = b"WireGuard v1 zx2c4 Jason@zx2c4.com";
const CONSTRUCTION: &[u8] = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
/* Sanity check precomputed initial chain key
*/
#[test]
fn precomputed_chain_key() {
assert_eq!(INITIAL_CK[..], HASH!(CONSTRUCTION)[..]);
}
/* Sanity check precomputed initial hash transcript
*/
#[test]
fn precomputed_hash() {
assert_eq!(INITIAL_HS[..], HASH!(INITIAL_CK, IDENTIFIER)[..]);
}
/* Sanity check the HKDF macro
*
* Test vectors generated using WireGuard-Go
*/
#[test]
fn hkdf() {
let tests: Vec<(Vec<u8>, Vec<u8>, [u8; 32], [u8; 32], [u8; 32])> = vec![
(
vec![],
vec![],
[
0x83, 0x87, 0xb4, 0x6b, 0xf4, 0x3e, 0xcc, 0xfc, 0xf3, 0x49, 0x55, 0x2a, 0x09,
0x5d, 0x83, 0x15, 0xc4, 0x05, 0x5b, 0xeb, 0x90, 0x20, 0x8f, 0xb1, 0xbe, 0x23,
0xb8, 0x94, 0xbc, 0x2e, 0xd5, 0xd0,
],
[
0x58, 0xa0, 0xe5, 0xf6, 0xfa, 0xef, 0xcc, 0xf4, 0x80, 0x7b, 0xff, 0x1f, 0x05,
0xfa, 0x8a, 0x92, 0x17, 0x94, 0x57, 0x62, 0x04, 0x0b, 0xce, 0xc2, 0xf4, 0xb4,
0xa6, 0x2b, 0xdf, 0xe0, 0xe8, 0x6e,
],
[
0x0c, 0xe6, 0xea, 0x98, 0xec, 0x54, 0x8f, 0x8e, 0x28, 0x1e, 0x93, 0xe3, 0x2d,
0xb6, 0x56, 0x21, 0xc4, 0x5e, 0xb1, 0x8d, 0xc6, 0xf0, 0xa7, 0xad, 0x94, 0x17,
0x86, 0x10, 0xa2, 0xf7, 0x33, 0x8e,
],
),
(
vec![0xde, 0xad, 0xbe, 0xef],
vec![],
[
0x55, 0x32, 0x9d, 0xc8, 0x0e, 0x69, 0x0f, 0xd8, 0x6b, 0xd9, 0x66, 0x1f, 0x08,
0x51, 0xc9, 0xb3, 0x68, 0x6d, 0xf2, 0xb1, 0xfd, 0xa0, 0x34, 0x7b, 0xc3, 0xd2,
0x79, 0x58, 0x25, 0x4b, 0x32, 0xc6,
],
[
0x8d, 0xfc, 0x6d, 0x33, 0xa8, 0x11, 0x8f, 0xfe, 0x40, 0x8b, 0x31, 0xdd, 0xac,
0x25, 0xf7, 0x2a, 0xee, 0x91, 0x15, 0xa4, 0x5b, 0x69, 0xba, 0x17, 0x6a, 0xd0,
0x12, 0xb2, 0x43, 0x83, 0x4f, 0xee,
],
[
0xd6, 0x9e, 0x85, 0x2a, 0x28, 0x96, 0x56, 0x9e, 0xa5, 0x4a, 0x67, 0x96, 0x9a,
0xa1, 0x80, 0x02, 0x87, 0x92, 0x1d, 0xac, 0x53, 0xce, 0x6d, 0xb4, 0xb4, 0xe1,
0x21, 0x92, 0xf2, 0x63, 0xc4, 0xc4,
],
),
];
for (key, input, t0, t1, t2) in &tests {
let tt0 = KDF1!(key, input);
debug_assert_eq!(tt0[..], t0[..]);
let (tt0, tt1) = KDF2!(key, input);
debug_assert_eq!(tt0[..], t0[..]);
debug_assert_eq!(tt1[..], t1[..]);
let (tt0, tt1, tt2) = KDF3!(key, input);
debug_assert_eq!(tt0[..], t0[..]);
debug_assert_eq!(tt1[..], t1[..]);
debug_assert_eq!(tt2[..], t2[..]);
}
}
}
pub fn create_initiation<R: RngCore + CryptoRng>(
rng: &mut R,
device: &Device,
peer: &Peer,
sender: u32,
msg: &mut NoiseInitiation,
) -> Result<(), HandshakeError> {
clear_stack_on_return(CLEAR_PAGES, || {
// initialize state
let ck = INITIAL_CK;
let hs = INITIAL_HS;
let hs = HASH!(&hs, peer.pk.as_bytes());
msg.f_type.set(TYPE_INITIATION as u32);
msg.f_sender.set(sender);
// (E_priv, E_pub) := DH-Generate()
let eph_sk = StaticSecret::new(rng);
let eph_pk = PublicKey::from(&eph_sk);
// C := Kdf(C, E_pub)
let ck = KDF1!(&ck, eph_pk.as_bytes());
// msg.ephemeral := E_pub
msg.f_ephemeral = *eph_pk.as_bytes();
// H := HASH(H, msg.ephemeral)
let hs = HASH!(&hs, msg.f_ephemeral);
// (C, k) := Kdf2(C, DH(E_priv, S_pub))
let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
// msg.static := Aead(k, 0, S_pub, H)
SEAL!(
&key,
&hs, // ad
device.pk.as_bytes(), // pt
&mut msg.f_static // ct || tag
);
// H := Hash(H || msg.static)
let hs = HASH!(&hs, &msg.f_static[..]);
// (C, k) := Kdf2(C, DH(S_priv, S_pub))
let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
// msg.timestamp := Aead(k, 0, Timestamp(), H)
SEAL!(
&key,
&hs, // ad
&timestamp::now(), // pt
&mut msg.f_timestamp // ct || tag
);
// H := Hash(H || msg.timestamp)
let hs = HASH!(&hs, &msg.f_timestamp);
// update state of peer
*peer.state.lock() = State::InitiationSent {
hs,
ck,
eph_sk,
sender,
};
Ok(())
})
}
pub fn consume_initiation<'a>(
device: &'a Device,
msg: &NoiseInitiation,
) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
clear_stack_on_return(CLEAR_PAGES, || {
// initialize new state
let ck = INITIAL_CK;
let hs = INITIAL_HS;
let hs = HASH!(&hs, device.pk.as_bytes());
// C := Kdf(C, E_pub)
let ck = KDF1!(&ck, &msg.f_ephemeral);
// H := HASH(H, msg.ephemeral)
let hs = HASH!(&hs, &msg.f_ephemeral);
// (C, k) := Kdf2(C, DH(E_priv, S_pub))
let eph_r_pk = PublicKey::from(msg.f_ephemeral);
let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
// msg.static := Aead(k, 0, S_pub, H)
let mut pk = [0u8; 32];
OPEN!(
&key,
&hs, // ad
&mut pk, // pt
&msg.f_static // ct || tag
)?;
let peer = device.lookup_pk(&PublicKey::from(pk))?;
// reset initiation state
*peer.state.lock() = State::Reset;
// H := Hash(H || msg.static)
let hs = HASH!(&hs, &msg.f_static[..]);
// (C, k) := Kdf2(C, DH(S_priv, S_pub))
let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
// msg.timestamp := Aead(k, 0, Timestamp(), H)
let mut ts = timestamp::ZERO;
OPEN!(
&key,
&hs, // ad
&mut ts, // pt
&msg.f_timestamp // ct || tag
)?;
// check and update timestamp
peer.check_replay_flood(device, &ts)?;
// H := Hash(H || msg.timestamp)
let hs = HASH!(&hs, &msg.f_timestamp);
// return state (to create response)
Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck)))
})
}
pub fn create_response<R: RngCore + CryptoRng>(
rng: &mut R,
peer: &Peer,
sender: u32, // sending identifier
state: TemporaryState, // state from "consume_initiation"
msg: &mut NoiseResponse, // resulting response
) -> Result<KeyPair, HandshakeError> {
clear_stack_on_return(CLEAR_PAGES, || {
// unpack state
let (receiver, eph_r_pk, hs, ck) = state;
msg.f_type.set(TYPE_RESPONSE as u32);
msg.f_sender.set(sender);
msg.f_receiver.set(receiver);
// (E_priv, E_pub) := DH-Generate()
let eph_sk = StaticSecret::new(rng);
let eph_pk = PublicKey::from(&eph_sk);
// C := Kdf1(C, E_pub)
let ck = KDF1!(&ck, eph_pk.as_bytes());
// msg.ephemeral := E_pub
msg.f_ephemeral = *eph_pk.as_bytes();
// H := Hash(H || msg.ephemeral)
let hs = HASH!(&hs, &msg.f_ephemeral);
// C := Kdf1(C, DH(E_priv, E_pub))
let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes());
// C := Kdf1(C, DH(E_priv, S_pub))
let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
// (C, tau, k) := Kdf3(C, Q)
let (ck, tau, key) = KDF3!(&ck, &peer.psk);
// H := Hash(H || tau)
let hs = HASH!(&hs, tau);
// msg.empty := Aead(k, 0, [], H)
SEAL!(
&key,
&hs, // ad
&[], // pt
&mut msg.f_empty // \epsilon || tag
);
// Not strictly needed
// let hs = HASH!(&hs, &msg.f_empty_tag);
// derive key-pair
let (key_recv, key_send) = KDF2!(&ck, &[]);
// return unconfirmed key-pair
Ok(KeyPair {
birth: Instant::now(),
initiator: false,
send: Key {
id: sender,
key: key_send.into(),
},
recv: Key {
id: receiver,
key: key_recv.into(),
},
})
})
}
/* The state lock is released while processing the message to
* allow concurrent processing of potential responses to the initiation,
* in order to better mitigate DoS from malformed response messages.
*/
pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
clear_stack_on_return(CLEAR_PAGES, || {
// retrieve peer and copy initiation state
let peer = device.lookup_id(msg.f_receiver.get())?;
let (hs, ck, sender, eph_sk) = match *peer.state.lock() {
State::InitiationSent {
hs,
ck,
sender,
ref eph_sk,
} => Ok((hs, ck, sender, StaticSecret::from(eph_sk.to_bytes()))),
_ => Err(HandshakeError::InvalidState),
}?;
// C := Kdf1(C, E_pub)
let ck = KDF1!(&ck, &msg.f_ephemeral);
// H := Hash(H || msg.ephemeral)
let hs = HASH!(&hs, &msg.f_ephemeral);
// C := Kdf1(C, DH(E_priv, E_pub))
let eph_r_pk = PublicKey::from(msg.f_ephemeral);
let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes());
// C := Kdf1(C, DH(E_priv, S_pub))
let ck = KDF1!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
// (C, tau, k) := Kdf3(C, Q)
let (ck, tau, key) = KDF3!(&ck, &peer.psk);
// H := Hash(H || tau)
let hs = HASH!(&hs, tau);
// msg.empty := Aead(k, 0, [], H)
OPEN!(
&key,
&hs, // ad
&mut [], // pt
&msg.f_empty // \epsilon || tag
)?;
// derive key-pair
let birth = Instant::now();
let (key_send, key_recv) = KDF2!(&ck, &[]);
// check for new initiation sent while lock released
let mut state = peer.state.lock();
let update = match *state {
State::InitiationSent {
eph_sk: ref old, ..
} => old.to_bytes().ct_eq(&eph_sk.to_bytes()).into(),
_ => false,
};
if update {
// null the initiation state
// (to avoid replay of this response message)
*state = State::Reset;
// return confirmed key-pair
Ok((
Some(peer.pk),
None,
Some(KeyPair {
birth,
initiator: true,
send: Key {
id: sender,
key: key_send.into(),
},
recv: Key {
id: msg.f_sender.get(),
key: key_recv.into(),
},
}),
))
} else {
Err(HandshakeError::InvalidState)
}
})
}

View File

@@ -0,0 +1,142 @@
use spin::Mutex;
use std::mem;
use std::time::{Duration, Instant};
use generic_array::typenum::U32;
use generic_array::GenericArray;
use x25519_dalek::PublicKey;
use x25519_dalek::SharedSecret;
use x25519_dalek::StaticSecret;
use clear_on_drop::clear::Clear;
use super::device::Device;
use super::macs;
use super::timestamp;
use super::types::*;
const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20);
/* Represents the recomputation and state of a peer.
*
* This type is only for internal use and not exposed.
*/
pub struct Peer {
// mutable state
pub(crate) state: Mutex<State>,
pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>,
pub(crate) last_initiation_consumption: Mutex<Option<Instant>>,
// state related to DoS mitigation fields
pub(crate) macs: Mutex<macs::Generator>,
// constant state
pub(crate) pk: PublicKey, // public key of peer
pub(crate) ss: SharedSecret, // precomputed DH(static, static)
pub(crate) psk: Psk, // psk of peer
}
pub enum State {
Reset,
InitiationSent {
sender: u32, // assigned sender id
eph_sk: StaticSecret,
hs: GenericArray<u8, U32>,
ck: GenericArray<u8, U32>,
},
}
impl Drop for State {
fn drop(&mut self) {
match self {
State::InitiationSent { hs, ck, .. } => {
// eph_sk already cleared by dalek-x25519
hs.clear();
ck.clear();
}
_ => (),
}
}
}
impl Peer {
pub fn new(
pk: PublicKey, // public key of peer
ss: SharedSecret, // precomputed DH(static, static)
) -> Self {
Self {
macs: Mutex::new(macs::Generator::new(pk)),
state: Mutex::new(State::Reset),
timestamp: Mutex::new(None),
last_initiation_consumption: Mutex::new(None),
pk: pk,
ss: ss,
psk: [0u8; 32],
}
}
/// Set the state of the peer unconditionally
///
/// # Arguments
///
pub fn set_state(&self, state_new: State) {
*self.state.lock() = state_new;
}
pub fn reset_state(&self) -> Option<u32> {
match mem::replace(&mut *self.state.lock(), State::Reset) {
State::InitiationSent { sender, .. } => Some(sender),
_ => None,
}
}
/// Set the mutable state of the peer conditioned on the timestamp being newer
///
/// # Arguments
///
/// * st_new - The updated state of the peer
/// * ts_new - The associated timestamp
pub fn check_replay_flood(
&self,
device: &Device,
timestamp_new: &timestamp::TAI64N,
) -> Result<(), HandshakeError> {
let mut state = self.state.lock();
let mut timestamp = self.timestamp.lock();
let mut last_initiation_consumption = self.last_initiation_consumption.lock();
// check replay attack
match *timestamp {
Some(timestamp_old) => {
if !timestamp::compare(&timestamp_old, &timestamp_new) {
return Err(HandshakeError::OldTimestamp);
}
}
_ => (),
};
// check flood attack
match *last_initiation_consumption {
Some(last) => {
if last.elapsed() < TIME_BETWEEN_INITIATIONS {
return Err(HandshakeError::InitiationFlood);
}
}
_ => (),
}
// reset state
match *state {
State::InitiationSent { sender, .. } => device.release(sender),
_ => (),
}
// update replay & flood protection
*state = State::Reset;
*timestamp = Some(*timestamp_new);
*last_initiation_consumption = Some(Instant::now());
Ok(())
}
}

View File

@@ -0,0 +1,199 @@
use spin;
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::{Duration, Instant};
const PACKETS_PER_SECOND: u64 = 20;
const PACKETS_BURSTABLE: u64 = 5;
const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND;
const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE;
const GC_INTERVAL: Duration = Duration::from_secs(1);
struct Entry {
pub last_time: Instant,
pub tokens: u64,
}
pub struct RateLimiter(Arc<RateLimiterInner>);
struct RateLimiterInner {
gc_running: AtomicBool,
gc_dropped: (Mutex<bool>, Condvar),
table: spin::RwLock<HashMap<IpAddr, spin::Mutex<Entry>>>,
}
impl Drop for RateLimiter {
fn drop(&mut self) {
// wake up & terminate any lingering GC thread
let &(ref lock, ref cvar) = &self.0.gc_dropped;
let mut dropped = lock.lock().unwrap();
*dropped = true;
cvar.notify_all();
}
}
impl RateLimiter {
pub fn new() -> Self {
RateLimiter(Arc::new(RateLimiterInner {
gc_dropped: (Mutex::new(false), Condvar::new()),
gc_running: AtomicBool::from(false),
table: spin::RwLock::new(HashMap::new()),
}))
}
pub fn allow(&self, addr: &IpAddr) -> bool {
// check if allowed
let allowed = {
// check for existing entry (only requires read lock)
if let Some(entry) = self.0.table.read().get(addr) {
// update existing entry
let mut entry = entry.lock();
// add tokens earned since last time
entry.tokens = MAX_TOKENS
.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
entry.last_time = Instant::now();
// subtract cost of packet
if entry.tokens > PACKET_COST {
entry.tokens -= PACKET_COST;
return true;
} else {
return false;
}
}
// add new entry (write lock)
self.0.table.write().insert(
*addr,
spin::Mutex::new(Entry {
last_time: Instant::now(),
tokens: MAX_TOKENS - PACKET_COST,
}),
);
true
};
// check that GC thread is scheduled
if !self.0.gc_running.swap(true, Ordering::Relaxed) {
let limiter = self.0.clone();
thread::spawn(move || {
let &(ref lock, ref cvar) = &limiter.gc_dropped;
let mut dropped = lock.lock().unwrap();
while !*dropped {
// garbage collect
{
let mut tw = limiter.table.write();
tw.retain(|_, ref mut entry| {
entry.lock().last_time.elapsed() <= GC_INTERVAL
});
if tw.len() == 0 {
limiter.gc_running.store(false, Ordering::Relaxed);
return;
}
}
// wait until stopped or new GC (~1 every sec)
let res = cvar.wait_timeout(dropped, GC_INTERVAL).unwrap();
dropped = res.0;
}
});
}
allowed
}
}
#[cfg(test)]
mod tests {
use super::*;
use std;
struct Result {
allowed: bool,
text: &'static str,
wait: Duration,
}
#[test]
fn test_ratelimiter() {
let ratelimiter = RateLimiter::new();
let mut expected = vec![];
let ips = vec![
"127.0.0.1".parse().unwrap(),
"192.168.1.1".parse().unwrap(),
"172.167.2.3".parse().unwrap(),
"97.231.252.215".parse().unwrap(),
"248.97.91.167".parse().unwrap(),
"188.208.233.47".parse().unwrap(),
"104.2.183.179".parse().unwrap(),
"72.129.46.120".parse().unwrap(),
"2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
"f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
"b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
"a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
"ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
"3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
];
for _ in 0..PACKETS_BURSTABLE {
expected.push(Result {
allowed: true,
wait: Duration::new(0, 0),
text: "inital burst",
});
}
expected.push(Result {
allowed: false,
wait: Duration::new(0, 0),
text: "after burst",
});
expected.push(Result {
allowed: true,
wait: Duration::new(0, PACKET_COST as u32),
text: "filling tokens for single packet",
});
expected.push(Result {
allowed: false,
wait: Duration::new(0, 0),
text: "not having refilled enough",
});
expected.push(Result {
allowed: true,
wait: Duration::new(0, 2 * PACKET_COST as u32),
text: "filling tokens for 2 * packet burst",
});
expected.push(Result {
allowed: true,
wait: Duration::new(0, 0),
text: "second packet in 2 packet burst",
});
expected.push(Result {
allowed: false,
wait: Duration::new(0, 0),
text: "packet following 2 packet burst",
});
for item in expected {
std::thread::sleep(item.wait);
for ip in ips.iter() {
if ratelimiter.allow(&ip) != item.allowed {
panic!(
"test failed for {} on {}. expected: {}, got: {}",
ip, item.text, item.allowed, !item.allowed
)
}
}
}
}
}

View File

@@ -0,0 +1,32 @@
use std::time::{SystemTime, UNIX_EPOCH};
pub type TAI64N = [u8; 12];
const TAI64_EPOCH: u64 = 0x400000000000000a;
pub const ZERO: TAI64N = [0u8; 12];
pub fn now() -> TAI64N {
// get system time as duration
let sysnow = SystemTime::now();
let delta = sysnow.duration_since(UNIX_EPOCH).unwrap();
// convert to tai64n
let tai64_secs = delta.as_secs() + TAI64_EPOCH;
let tai64_nano = delta.subsec_nanos();
// serialize
let mut res = [0u8; 12];
res[..8].copy_from_slice(&tai64_secs.to_be_bytes()[..]);
res[8..].copy_from_slice(&tai64_nano.to_be_bytes()[..]);
res
}
pub fn compare(old: &TAI64N, new: &TAI64N) -> bool {
for i in 0..12 {
if new[i] > old[i] {
return true;
}
}
return false;
}

View File

@@ -0,0 +1,90 @@
use std::error::Error;
use std::fmt;
use x25519_dalek::PublicKey;
use super::super::types::KeyPair;
/* Internal types for the noise IKpsk2 implementation */
// config error
#[derive(Debug)]
pub struct ConfigError(String);
impl ConfigError {
pub fn new(s: &str) -> Self {
ConfigError(s.to_string())
}
}
impl fmt::Display for ConfigError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ConfigError({})", self.0)
}
}
impl Error for ConfigError {
fn description(&self) -> &str {
&self.0
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
// handshake error
#[derive(Debug)]
pub enum HandshakeError {
DecryptionFailure,
UnknownPublicKey,
UnknownReceiverId,
InvalidMessageFormat,
OldTimestamp,
InvalidState,
InvalidMac1,
RateLimited,
InitiationFlood,
}
impl fmt::Display for HandshakeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
HandshakeError::DecryptionFailure => write!(f, "Failed to AEAD:OPEN"),
HandshakeError::UnknownPublicKey => write!(f, "Unknown public key"),
HandshakeError::UnknownReceiverId => {
write!(f, "Receiver id not allocated to any handshake")
}
HandshakeError::InvalidMessageFormat => write!(f, "Invalid handshake message format"),
HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter"),
HandshakeError::InitiationFlood => {
write!(f, "Message was dropped because of initiation flood")
}
}
}
}
impl Error for HandshakeError {
fn description(&self) -> &str {
"Generic Handshake Error"
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
pub type Output = (
Option<PublicKey>, // external identifier associated with peer
Option<Vec<u8>>, // message to send
Option<KeyPair>, // resulting key-pair of successful handshake
);
// preshared key
pub type Psk = [u8; 32];

23
src/wireguard/mod.rs Normal file
View File

@@ -0,0 +1,23 @@
mod wireguard;
// mod config;
mod constants;
mod timers;
mod handshake;
mod router;
mod types;
#[cfg(test)]
mod tests;
/// The WireGuard sub-module contains a pure, configurable implementation of WireGuard.
/// The implementation is generic over:
///
/// - TUN type, specifying how packets are received on the interface side: a reader/writer and MTU reporting interface.
/// - Bind type, specifying how WireGuard messages are sent/received from the internet and what constitutes an "endpoint"
pub use wireguard::{Wireguard, Peer};
pub use types::bind;
pub use types::tun;
pub use types::Endpoint;

View File

@@ -0,0 +1,157 @@
use std::mem;
// Implementation of RFC 6479.
// https://tools.ietf.org/html/rfc6479
#[cfg(target_pointer_width = "64")]
type Word = u64;
#[cfg(target_pointer_width = "64")]
const REDUNDANT_BIT_SHIFTS: usize = 6;
#[cfg(target_pointer_width = "32")]
type Word = u32;
#[cfg(target_pointer_width = "32")]
const REDUNDANT_BIT_SHIFTS: usize = 5;
const SIZE_OF_WORD: usize = mem::size_of::<Word>() * 8;
const BITMAP_BITLEN: usize = 2048;
const BITMAP_LEN: usize = (BITMAP_BITLEN / SIZE_OF_WORD);
const BITMAP_INDEX_MASK: u64 = BITMAP_LEN as u64 - 1;
const BITMAP_LOC_MASK: u64 = (SIZE_OF_WORD - 1) as u64;
const WINDOW_SIZE: u64 = (BITMAP_BITLEN - SIZE_OF_WORD) as u64;
pub struct AntiReplay {
bitmap: [Word; BITMAP_LEN],
last: u64,
}
impl Default for AntiReplay {
fn default() -> Self {
AntiReplay::new()
}
}
impl AntiReplay {
pub fn new() -> Self {
debug_assert_eq!(1 << REDUNDANT_BIT_SHIFTS, SIZE_OF_WORD);
debug_assert_eq!(BITMAP_BITLEN % SIZE_OF_WORD, 0);
AntiReplay {
last: 0,
bitmap: [0; BITMAP_LEN],
}
}
// Returns true if check is passed, i.e., not a replay or too old.
//
// Unlike RFC 6479, zero is allowed.
fn check(&self, seq: u64) -> bool {
// Larger is always good.
if seq > self.last {
return true;
}
if self.last - seq > WINDOW_SIZE {
return false;
}
let bit_location = seq & BITMAP_LOC_MASK;
let index = (seq >> REDUNDANT_BIT_SHIFTS) & BITMAP_INDEX_MASK;
self.bitmap[index as usize] & (1 << bit_location) == 0
}
// Should only be called if check returns true.
fn update_store(&mut self, seq: u64) {
debug_assert!(self.check(seq));
let index = seq >> REDUNDANT_BIT_SHIFTS;
if seq > self.last {
let index_cur = self.last >> REDUNDANT_BIT_SHIFTS;
let diff = index - index_cur;
if diff >= BITMAP_LEN as u64 {
self.bitmap = [0; BITMAP_LEN];
} else {
for i in 0..diff {
let real_index = (index_cur + i + 1) & BITMAP_INDEX_MASK;
self.bitmap[real_index as usize] = 0;
}
}
self.last = seq;
}
let index = index & BITMAP_INDEX_MASK;
let bit_location = seq & BITMAP_LOC_MASK;
self.bitmap[index as usize] |= 1 << bit_location;
}
/// Checks and marks a sequence number in the replay filter
///
/// # Arguments
///
/// - seq: Sequence number check for replay and add to filter
///
/// # Returns
///
/// Ok(()) if sequence number is valid (not marked and not behind the moving window).
/// Err if the sequence number is invalid (already marked or "too old").
pub fn update(&mut self, seq: u64) -> bool {
if self.check(seq) {
self.update_store(seq);
true
} else {
false
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn anti_replay() {
let mut ar = AntiReplay::new();
for i in 0..20000 {
assert!(ar.update(i));
}
for i in (0..20000).rev() {
assert!(!ar.check(i));
}
assert!(ar.update(65536));
for i in (65536 - WINDOW_SIZE)..65535 {
assert!(ar.update(i));
}
for i in (65536 - 10 * WINDOW_SIZE)..65535 {
assert!(!ar.check(i));
}
assert!(ar.update(66000));
for i in 65537..66000 {
assert!(ar.update(i));
}
for i in 65537..66000 {
assert_eq!(ar.update(i), false);
}
// Test max u64.
let next = u64::max_value();
assert!(ar.update(next));
assert!(!ar.check(next));
for i in (next - WINDOW_SIZE)..next {
assert!(ar.update(i));
}
for i in (next - 20 * WINDOW_SIZE)..next {
assert!(!ar.check(i));
}
}
}

View File

@@ -0,0 +1,7 @@
// WireGuard semantics constants
pub const MAX_STAGED_PACKETS: usize = 128;
// performance constants
pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS;

View File

@@ -0,0 +1,243 @@
use std::collections::HashMap;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::sync_channel;
use std::sync::mpsc::SyncSender;
use std::sync::Arc;
use std::thread;
use std::time::Instant;
use log::debug;
use spin::{Mutex, RwLock};
use treebitmap::IpLookupTable;
use zerocopy::LayoutVerified;
use super::anti_replay::AntiReplay;
use super::constants::*;
use super::ip::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerInner};
use super::types::{Callbacks, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation};
use super::SIZE_MESSAGE_PREFIX;
use super::super::types::{bind, tun, Endpoint, KeyPair};
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
// inbound writer (TUN)
pub inbound: T,
// outbound writer (Bind)
pub outbound: RwLock<Option<B>>,
// routing
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
// work queues
pub queue_next: AtomicUsize, // next round-robin index
pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
}
pub struct EncryptionState {
pub key: [u8; 32], // encryption key
pub id: u32, // receiver id
pub nonce: u64, // next available nonce
pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
}
pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
pub keypair: Arc<KeyPair>,
pub confirmed: AtomicBool,
pub protector: Mutex<AntiReplay>,
pub peer: Arc<PeerInner<E, C, T, B>>,
pub death: Instant, // time when the key can no longer be used for decryption
}
pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> {
fn drop(&mut self) {
debug!("router: dropping device");
// drop all queues
{
let mut queues = self.state.queues.lock();
while queues.pop().is_some() {}
}
// join all worker threads
while match self.handles.pop() {
Some(handle) => {
handle.thread().unpark();
handle.join().unwrap();
true
}
_ => false,
} {}
debug!("router: device dropped");
}
}
#[inline(always)]
fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
packet: &[u8],
) -> Option<Arc<PeerInner<E, C, T, B>>> {
// ensure version access within bounds
if packet.len() < 1 {
return None;
};
// cast to correct IP header
match packet[0] >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// lookup destination address
device
.ipv4
.read()
.longest_match(Ipv4Addr::from(header.f_destination))
.and_then(|(_, _, p)| Some(p.clone()))
}
VERSION_IP6 => {
// check length and cast to IPv6 header
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// lookup destination address
device
.ipv6
.read()
.longest_match(Ipv6Addr::from(header.f_destination))
.and_then(|(_, _, p)| Some(p.clone()))
}
_ => None,
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
// allocate shared device state
let inner = DeviceInner {
inbound: tun,
outbound: RwLock::new(None),
queues: Mutex::new(Vec::with_capacity(num_workers)),
queue_next: AtomicUsize::new(0),
recv: RwLock::new(HashMap::new()),
ipv4: RwLock::new(IpLookupTable::new()),
ipv6: RwLock::new(IpLookupTable::new()),
};
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
inner.queues.lock().push(tx);
threads.push(thread::spawn(move || worker_parallel(rx)));
}
// return exported device handle
Device {
state: Arc::new(inner),
handles: threads,
}
}
/// A new secret key has been set for the device.
/// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
pub fn new_sk(&self) {}
/// Adds a new peer to the device
///
/// # Returns
///
/// A atomic ref. counted peer (with liftime matching the device)
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
new_peer(self.state.clone(), opaque)
}
/// Cryptkey routes and sends a plaintext message (IP packet)
///
/// # Arguments
///
/// - msg: IP packet to crypt-key route
///
pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> {
// ignore header prefix (for in-place transport message construction)
let packet = &msg[SIZE_MESSAGE_PREFIX..];
// lookup peer based on IP packet destination address
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
// schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg, true) {
debug_assert_eq!(job.1.op, Operation::Encryption);
// add job to worker queue
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
let queues = self.state.queues.lock();
queues[idx % queues.len()].send(job).unwrap();
}
Ok(())
}
/// Receive an encrypted transport message
///
/// # Arguments
///
/// - src: Source address of the packet
/// - msg: Encrypted transport message
///
/// # Returns
///
///
pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
// parse / cast
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
Some(v) => v,
None => {
return Err(RouterError::MalformedTransportMessage);
}
};
let header: LayoutVerified<&[u8], TransportHeader> = header;
debug_assert!(
header.f_type.get() == TYPE_TRANSPORT as u32,
"this should be checked by the message type multiplexer"
);
// lookup peer based on receiver id
let dec = self.state.recv.read();
let dec = dec
.get(&header.f_receiver.get())
.ok_or(RouterError::UnknownReceiverId)?;
// schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
debug_assert_eq!(job.1.op, Operation::Decryption);
// add job to worker queue
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
let queues = self.state.queues.lock();
queues[idx % queues.len()].send(job).unwrap();
}
Ok(())
}
/// Set outbound writer
///
///
pub fn set_outbound_writer(&self, new: B) {
*self.state.outbound.write() = Some(new);
}
}

View File

@@ -0,0 +1,26 @@
use byteorder::BigEndian;
use zerocopy::byteorder::U16;
use zerocopy::{AsBytes, FromBytes};
pub const VERSION_IP4: u8 = 4;
pub const VERSION_IP6: u8 = 6;
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct IPv4Header {
_f_space1: [u8; 2],
pub f_total_len: U16<BigEndian>,
_f_space2: [u8; 8],
pub f_source: [u8; 4],
pub f_destination: [u8; 4],
}
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct IPv6Header {
_f_space1: [u8; 4],
pub f_len: U16<BigEndian>,
_f_space2: [u8; 2],
pub f_source: [u8; 16],
pub f_destination: [u8; 16],
}

View File

@@ -0,0 +1,13 @@
use byteorder::LittleEndian;
use zerocopy::byteorder::{U32, U64};
use zerocopy::{AsBytes, FromBytes};
pub const TYPE_TRANSPORT: u32 = 4;
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct TransportHeader {
pub f_type: U32<LittleEndian>,
pub f_receiver: U32<LittleEndian>,
pub f_counter: U64<LittleEndian>,
}

View File

@@ -0,0 +1,22 @@
mod anti_replay;
mod constants;
mod device;
mod ip;
mod messages;
mod peer;
mod types;
mod workers;
#[cfg(test)]
mod tests;
use messages::TransportHeader;
use std::mem;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
pub const CAPACITY_MESSAGE_POSTFIX: usize = 16;
pub use messages::TYPE_TRANSPORT;
pub use device::Device;
pub use peer::Peer;
pub use types::Callbacks;

View File

@@ -0,0 +1,611 @@
use std::mem;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::mpsc::{sync_channel, SyncSender};
use std::sync::Arc;
use std::thread;
use arraydeque::{ArrayDeque, Wrapping};
use log::debug;
use spin::Mutex;
use treebitmap::address::Address;
use treebitmap::IpLookupTable;
use zerocopy::LayoutVerified;
use super::super::constants::*;
use super::super::types::{bind, tun, Endpoint, KeyPair};
use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
use super::device::DeviceInner;
use super::device::EncryptionState;
use super::messages::TransportHeader;
use futures::*;
use super::workers::Operation;
use super::workers::{worker_inbound, worker_outbound};
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
use super::SIZE_MESSAGE_PREFIX;
use super::constants::*;
use super::types::{Callbacks, RouterError};
pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
current: Option<Arc<KeyPair>>, // current key state (used for encryption)
previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
retired: Vec<u32>, // retired ids
}
pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
pub device: Arc<DeviceInner<E, C, T, B>>,
pub opaque: C::Opaque,
pub outbound: Mutex<SyncSender<JobOutbound>>,
pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>,
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>,
pub keys: Mutex<KeyWheel>,
pub ekey: Mutex<Option<EncryptionState>>,
pub endpoint: Mutex<Option<E>>,
}
pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
state: Arc<PeerInner<E, C, T, B>>,
thread_outbound: Option<thread::JoinHandle<()>>,
thread_inbound: Option<thread::JoinHandle<()>>,
}
fn treebit_list<A, R, E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
peer: &Arc<PeerInner<E, C, T, B>>,
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
callback: Box<dyn Fn(A, u32) -> R>,
) -> Vec<R>
where
A: Address,
{
let mut res = Vec::new();
for subnet in table.read().iter() {
let (ip, masklen, p) = subnet;
if Arc::ptr_eq(&p, &peer) {
res.push(callback(ip, masklen))
}
}
res
}
fn treebit_remove<E: Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
peer: &Peer<E, C, T, B>,
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
) {
let mut m = table.write();
// collect keys for value
let mut subnets = vec![];
for subnet in m.iter() {
let (ip, masklen, p) = subnet;
if Arc::ptr_eq(&p, &peer.state) {
subnets.push((ip, masklen))
}
}
// remove all key mappings
for (ip, masklen) in subnets {
let r = m.remove(ip, masklen);
debug_assert!(r.is_some());
}
}
impl EncryptionState {
fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
EncryptionState {
id: keypair.send.id,
key: keypair.send.key,
nonce: 0,
death: keypair.birth + REJECT_AFTER_TIME,
}
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> {
fn new(
peer: &Arc<PeerInner<E, C, T, B>>,
keypair: &Arc<KeyPair>,
) -> DecryptionState<E, C, T, B> {
DecryptionState {
confirmed: AtomicBool::new(keypair.initiator),
keypair: keypair.clone(),
protector: spin::Mutex::new(AntiReplay::new()),
peer: peer.clone(),
death: keypair.birth + REJECT_AFTER_TIME,
}
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> {
fn drop(&mut self) {
let peer = &self.state;
// remove from cryptkey router
treebit_remove(self, &peer.device.ipv4);
treebit_remove(self, &peer.device.ipv6);
// drop channels
mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0);
mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0);
// join with workers
mem::replace(&mut self.thread_inbound, None).map(|v| v.join());
mem::replace(&mut self.thread_outbound, None).map(|v| v.join());
// release ids from the receiver map
let mut keys = peer.keys.lock();
let mut release = Vec::with_capacity(3);
keys.next.as_ref().map(|k| release.push(k.recv.id));
keys.current.as_ref().map(|k| release.push(k.recv.id));
keys.previous.as_ref().map(|k| release.push(k.recv.id));
if release.len() > 0 {
let mut recv = peer.device.recv.write();
for id in &release {
recv.remove(id);
}
}
// null key-material
keys.next = None;
keys.current = None;
keys.previous = None;
*peer.ekey.lock() = None;
*peer.endpoint.lock() = None;
debug!("peer dropped & removed from device");
}
}
pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>,
opaque: C::Opaque,
) -> Peer<E, C, T, B> {
let (out_tx, out_rx) = sync_channel(128);
let (in_tx, in_rx) = sync_channel(128);
// allocate peer object
let peer = {
let device = device.clone();
Arc::new(PeerInner {
opaque,
device,
inbound: Mutex::new(in_tx),
outbound: Mutex::new(out_tx),
ekey: spin::Mutex::new(None),
endpoint: spin::Mutex::new(None),
keys: spin::Mutex::new(KeyWheel {
next: None,
current: None,
previous: None,
retired: vec![],
}),
staged_packets: spin::Mutex::new(ArrayDeque::new()),
})
};
// spawn outbound thread
let thread_inbound = {
let peer = peer.clone();
let device = device.clone();
thread::spawn(move || worker_outbound(device, peer, out_rx))
};
// spawn inbound thread
let thread_outbound = {
let peer = peer.clone();
let device = device.clone();
thread::spawn(move || worker_inbound(device, peer, in_rx))
};
Peer {
state: peer,
thread_inbound: Some(thread_inbound),
thread_outbound: Some(thread_outbound),
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> {
fn send_staged(&self) -> bool {
debug!("peer.send_staged");
let mut sent = false;
let mut staged = self.staged_packets.lock();
loop {
match staged.pop_front() {
Some(msg) => {
sent = true;
self.send_raw(msg);
}
None => break sent,
}
}
}
// Treat the msg as the payload of a transport message
// Unlike device.send, peer.send_raw does not buffer messages when a key is not available.
fn send_raw(&self, msg: Vec<u8>) -> bool {
debug!("peer.send_raw");
match self.send_job(msg, false) {
Some(job) => {
debug!("send_raw: got obtained send_job");
let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst);
let queues = self.device.queues.lock();
match queues[index % queues.len()].send(job) {
Ok(_) => true,
Err(_) => false,
}
}
None => false,
}
}
pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
debug!("peer.confirm_key");
{
// take lock and check keypair = keys.next
let mut keys = self.keys.lock();
let next = match keys.next.as_ref() {
Some(next) => next,
None => {
return;
}
};
if !Arc::ptr_eq(&next, keypair) {
return;
}
// allocate new encryption state
let ekey = Some(EncryptionState::new(&next));
// rotate key-wheel
let mut swap = None;
mem::swap(&mut keys.next, &mut swap);
mem::swap(&mut keys.current, &mut swap);
mem::swap(&mut keys.previous, &mut swap);
// tell the world outside the router that a key was confirmed
C::key_confirmed(&self.opaque);
// set new key for encryption
*self.ekey.lock() = ekey;
}
// start transmission of staged packets
self.send_staged();
}
pub fn recv_job(
&self,
src: E,
dec: Arc<DecryptionState<E, C, T, B>>,
msg: Vec<u8>,
) -> Option<JobParallel> {
let (tx, rx) = oneshot();
let key = dec.keypair.recv.key;
match self.inbound.lock().try_send((dec, src, rx)) {
Ok(_) => Some((
tx,
JobBuffer {
msg,
key: key,
okay: false,
op: Operation::Decryption,
},
)),
Err(_) => None,
}
}
pub fn send_job(&self, mut msg: Vec<u8>, stage: bool) -> Option<JobParallel> {
debug!("peer.send_job");
debug_assert!(
msg.len() >= mem::size_of::<TransportHeader>(),
"received message with size: {:}",
msg.len()
);
// parse / cast
let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap();
let mut header: LayoutVerified<&mut [u8], TransportHeader> = header;
// check if has key
let key = {
let mut ekey = self.ekey.lock();
let key = match ekey.as_mut() {
None => None,
Some(mut state) => {
// avoid integer overflow in nonce
if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
*ekey = None;
None
} else {
// there should be no stacked packets lingering around
debug!("encryption state available, nonce = {}", state.nonce);
// set transport message fields
header.f_counter.set(state.nonce);
header.f_receiver.set(state.id);
state.nonce += 1;
Some(state.key)
}
}
};
// If not suitable key was found:
// 1. Stage packet for later transmission
// 2. Request new key
if key.is_none() && stage {
self.staged_packets.lock().push_back(msg);
C::need_key(&self.opaque);
return None;
};
key
}?;
// add job to in-order queue and return sendeer to device for inclusion in worker pool
let (tx, rx) = oneshot();
match self.outbound.lock().try_send(rx) {
Ok(_) => Some((
tx,
JobBuffer {
msg,
key,
okay: false,
op: Operation::Encryption,
},
)),
Err(_) => None,
}
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> {
/// Set the endpoint of the peer
///
/// # Arguments
///
/// - `endpoint`, socket address converted to bind endpoint
///
/// # Note
///
/// This API still permits support for the "sticky socket" behavior,
/// as sockets should be "unsticked" when manually updating the endpoint
pub fn set_endpoint(&self, endpoint: E) {
debug!("peer.set_endpoint");
*self.state.endpoint.lock() = Some(endpoint);
}
/// Returns the current endpoint of the peer (for configuration)
///
/// # Note
///
/// Does not convey potential "sticky socket" information
pub fn get_endpoint(&self) -> Option<SocketAddr> {
debug!("peer.get_endpoint");
self.state
.endpoint
.lock()
.as_ref()
.map(|e| e.into_address())
}
/// Zero all key-material related to the peer
pub fn zero_keys(&self) {
debug!("peer.zero_keys");
let mut release: Vec<u32> = Vec::with_capacity(3);
let mut keys = self.state.keys.lock();
// update key-wheel
mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id()));
mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id()));
mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id()));
keys.retired.extend(&release[..]);
// update inbound "recv" map
{
let mut recv = self.state.device.recv.write();
for id in release {
recv.remove(&id);
}
}
// clear encryption state
*self.state.ekey.lock() = None;
}
/// Add a new keypair
///
/// # Arguments
///
/// - new: The new confirmed/unconfirmed key pair
///
/// # Returns
///
/// A vector of ids which has been released.
/// These should be released in the handshake module.
///
/// # Note
///
/// The number of ids to be released can be at most 3,
/// since the only way to add additional keys to the peer is by using this method
/// and a peer can have at most 3 keys allocated in the router at any time.
pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
debug!("peer.add_keypair");
let initiator = new.initiator;
let release = {
let new = Arc::new(new);
let mut keys = self.state.keys.lock();
let mut release = mem::replace(&mut keys.retired, vec![]);
// update key-wheel
if new.initiator {
// start using key for encryption
*self.state.ekey.lock() = Some(EncryptionState::new(&new));
// move current into previous
keys.previous = keys.current.as_ref().map(|v| v.clone());
keys.current = Some(new.clone());
} else {
// store the key and await confirmation
keys.previous = keys.next.as_ref().map(|v| v.clone());
keys.next = Some(new.clone());
};
// update incoming packet id map
{
debug!("peer.add_keypair: updating inbound id map");
let mut recv = self.state.device.recv.write();
// purge recv map of previous id
keys.previous.as_ref().map(|k| {
recv.remove(&k.local_id());
release.push(k.local_id());
});
// map new id to decryption state
debug_assert!(!recv.contains_key(&new.recv.id));
recv.insert(
new.recv.id,
Arc::new(DecryptionState::new(&self.state, &new)),
);
}
release
};
// schedule confirmation
if initiator {
debug_assert!(self.state.ekey.lock().is_some());
debug!("peer.add_keypair: is initiator, must confirm the key");
// attempt to confirm using staged packets
if !self.state.send_staged() {
// fall back to keepalive packet
let ok = self.send_keepalive();
debug!(
"peer.add_keypair: keepalive for confirmation, sent = {}",
ok
);
}
debug!("peer.add_keypair: key attempted confirmed");
}
debug_assert!(
release.len() <= 3,
"since the key-wheel contains at most 3 keys"
);
release
}
pub fn send_keepalive(&self) -> bool {
debug!("peer.send_keepalive");
self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
}
/// Map a subnet to the peer
///
/// # Arguments
///
/// - `ip`, the mask of the subnet
/// - `masklen`, the length of the mask
///
/// # Note
///
/// The `ip` must not have any bits set right of `masklen`.
/// e.g. `192.168.1.0/24` is valid, while `192.168.1.128/24` is not.
///
/// If an identical value already exists as part of a prior peer,
/// the allowed IP entry will be removed from that peer and added to this peer.
pub fn add_subnet(&self, ip: IpAddr, masklen: u32) {
debug!("peer.add_subnet");
match ip {
IpAddr::V4(v4) => {
self.state
.device
.ipv4
.write()
.insert(v4, masklen, self.state.clone())
}
IpAddr::V6(v6) => {
self.state
.device
.ipv6
.write()
.insert(v6, masklen, self.state.clone())
}
};
}
/// List subnets mapped to the peer
///
/// # Returns
///
/// A vector of subnets, represented by as mask/size
pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> {
debug!("peer.list_subnets");
let mut res = Vec::new();
res.append(&mut treebit_list(
&self.state,
&self.state.device.ipv4,
Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)),
));
res.append(&mut treebit_list(
&self.state,
&self.state.device.ipv6,
Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)),
));
res
}
/// Clear subnets mapped to the peer.
/// After the call, no subnets will be cryptkey routed to the peer.
/// Used for the UAPI command "replace_allowed_ips=true"
pub fn remove_subnets(&self) {
debug!("peer.remove_subnets");
treebit_remove(self, &self.state.device.ipv4);
treebit_remove(self, &self.state.device.ipv6);
}
/// Send a raw message to the peer (used for handshake messages)
///
/// # Arguments
///
/// - `msg`, message body to send to peer
///
/// # Returns
///
/// Unit if packet was sent, or an error indicating why sending failed
pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
debug!("peer.send");
let inner = &self.state;
match inner.endpoint.lock().as_ref() {
Some(endpoint) => inner
.device
.outbound
.read()
.as_ref()
.ok_or(RouterError::SendError)
.and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)),
None => Err(RouterError::NoEndpoint),
}
}
pub fn purge_staged_packets(&self) {
self.state.staged_packets.lock().clear();
}
}

View File

@@ -0,0 +1,432 @@
use std::net::IpAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;
use num_cpus;
use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
use super::super::types::bind::*;
use super::super::types::*;
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
extern crate test;
const SIZE_KEEPALIVE: usize = 32;
#[cfg(test)]
mod tests {
use super::*;
use env_logger;
use log::debug;
use std::sync::atomic::AtomicUsize;
use test::Bencher;
// type for tracking events inside the router module
struct Flags {
send: Mutex<Vec<(usize, bool, bool)>>,
recv: Mutex<Vec<(usize, bool, bool)>>,
need_key: Mutex<Vec<()>>,
key_confirmed: Mutex<Vec<()>>,
}
#[derive(Clone)]
struct Opaque(Arc<Flags>);
struct TestCallbacks();
impl Opaque {
fn new() -> Opaque {
Opaque(Arc::new(Flags {
send: Mutex::new(vec![]),
recv: Mutex::new(vec![]),
need_key: Mutex::new(vec![]),
key_confirmed: Mutex::new(vec![]),
}))
}
fn reset(&self) {
self.0.send.lock().unwrap().clear();
self.0.recv.lock().unwrap().clear();
self.0.need_key.lock().unwrap().clear();
self.0.key_confirmed.lock().unwrap().clear();
}
fn send(&self) -> Option<(usize, bool, bool)> {
self.0.send.lock().unwrap().pop()
}
fn recv(&self) -> Option<(usize, bool, bool)> {
self.0.recv.lock().unwrap().pop()
}
fn need_key(&self) -> Option<()> {
self.0.need_key.lock().unwrap().pop()
}
fn key_confirmed(&self) -> Option<()> {
self.0.key_confirmed.lock().unwrap().pop()
}
// has all events been accounted for by assertions?
fn is_empty(&self) -> bool {
let send = self.0.send.lock().unwrap();
let recv = self.0.recv.lock().unwrap();
let need_key = self.0.need_key.lock().unwrap();
let key_confirmed = self.0.key_confirmed.lock().unwrap();
send.is_empty() && recv.is_empty() && need_key.is_empty() & key_confirmed.is_empty()
}
}
impl Callbacks for TestCallbacks {
type Opaque = Opaque;
fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
t.0.send.lock().unwrap().push((size, data, sent))
}
fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
t.0.recv.lock().unwrap().push((size, data, sent))
}
fn need_key(t: &Self::Opaque) {
t.0.need_key.lock().unwrap().push(());
}
fn key_confirmed(t: &Self::Opaque) {
t.0.key_confirmed.lock().unwrap().push(());
}
}
// wait for scheduling
fn wait() {
thread::sleep(Duration::from_millis(50));
}
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
fn make_packet(size: usize, ip: IpAddr) -> Vec<u8> {
// create "IP packet"
let mut msg = Vec::with_capacity(SIZE_MESSAGE_PREFIX + size + 16);
msg.resize(SIZE_MESSAGE_PREFIX + size, 0);
match ip {
IpAddr::V4(ip) => {
let mut packet = MutableIpv4Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
packet.set_destination(ip);
packet.set_version(4);
}
IpAddr::V6(ip) => {
let mut packet = MutableIpv6Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
packet.set_destination(ip);
packet.set_version(6);
}
}
msg
}
#[bench]
fn bench_outbound(b: &mut Bencher) {
struct BencherCallbacks {}
impl Callbacks for BencherCallbacks {
type Opaque = Arc<AtomicUsize>;
fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) {
t.fetch_add(size, Ordering::SeqCst);
}
fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
fn need_key(_: &Self::Opaque) {}
fn key_confirmed(_: &Self::Opaque) {}
}
// create device
let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> =
Device::new(num_cpus::get(), tun_writer);
// add new peer
let opaque = Arc::new(AtomicUsize::new(0));
let peer = router.new_peer(opaque.clone());
peer.add_keypair(dummy::keypair(true));
// add subnet to peer
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
let mask: IpAddr = mask.parse().unwrap();
let ip1: IpAddr = ip.parse().unwrap();
peer.add_subnet(mask, len);
// every iteration sends 10 GB
b.iter(|| {
opaque.store(0, Ordering::SeqCst);
let msg = make_packet(1024, ip1);
while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
router.send(msg.to_vec()).unwrap();
}
});
}
#[test]
fn test_outbound() {
init();
// create device
let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
router.set_outbound_writer(dummy::VoidBind::new());
let tests = vec![
("192.168.1.0", 24, "192.168.1.20", true),
("172.133.133.133", 32, "172.133.133.133", true),
("172.133.133.133", 32, "172.133.133.132", false),
(
"2001:db8::ff00:42:0000",
112,
"2001:db8::ff00:42:3242",
true,
),
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:0660",
false,
),
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:ffff",
true,
),
];
for (num, (mask, len, ip, okay)) in tests.iter().enumerate() {
for set_key in vec![true, false] {
debug!("index = {}, set_key = {}", num, set_key);
// add new peer
let opaque = Opaque::new();
let peer = router.new_peer(opaque.clone());
let mask: IpAddr = mask.parse().unwrap();
if set_key {
peer.add_keypair(dummy::keypair(true));
}
// map subnet to peer
peer.add_subnet(mask, *len);
// create "IP packet"
let msg = make_packet(1024, ip.parse().unwrap());
// cryptkey route the IP packet
let res = router.send(msg);
// allow some scheduling
wait();
if *okay {
// cryptkey routing succeeded
assert!(res.is_ok(), "crypt-key routing should succeed");
assert_eq!(
opaque.need_key().is_some(),
!set_key,
"should have requested a new key, if no encryption state was set"
);
assert_eq!(
opaque.send().is_some(),
set_key,
"transmission should have been attempted"
);
assert!(
opaque.recv().is_none(),
"no messages should have been marked as received"
);
} else {
// no such cryptkey route
assert!(res.is_err(), "crypt-key routing should fail");
assert!(
opaque.need_key().is_none(),
"should not request a new-key if crypt-key routing failed"
);
assert_eq!(
opaque.send(),
if set_key {
Some((SIZE_KEEPALIVE, false, false))
} else {
None
},
"transmission should only happen if key was set (keepalive)",
);
assert!(
opaque.recv().is_none(),
"no messages should have been marked as received",
);
}
}
}
}
#[test]
fn test_bidirectional() {
init();
let tests = [
(
false, // confirm with keepalive
("192.168.1.0", 24, "192.168.1.20", true),
("172.133.133.133", 32, "172.133.133.133", true),
),
(
true, // confirm with staged packet
("192.168.1.0", 24, "192.168.1.20", true),
("172.133.133.133", 32, "172.133.133.133", true),
),
(
false, // confirm with keepalive
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:ffff",
true,
),
(
"2001:db8::ff40:42:8000",
113,
"2001:db8::ff40:42:ffff",
true,
),
),
(
false, // confirm with staged packet
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:ffff",
true,
),
(
"2001:db8::ff40:42:8000",
113,
"2001:db8::ff40:42:ffff",
true,
),
),
];
for (stage, p1, p2) in tests.iter() {
let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
dummy::PairBind::pair();
// create matching device
let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false);
let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false);
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
router1.set_outbound_writer(bind_writer1);
let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2);
router2.set_outbound_writer(bind_writer2);
// prepare opaque values for tracing callbacks
let opaq1 = Opaque::new();
let opaq2 = Opaque::new();
// create peers with matching keypairs and assign subnets
let (mask, len, _ip, _okay) = p1;
let peer1 = router1.new_peer(opaq1.clone());
let mask: IpAddr = mask.parse().unwrap();
peer1.add_subnet(mask, *len);
peer1.add_keypair(dummy::keypair(false));
let (mask, len, _ip, _okay) = p2;
let peer2 = router2.new_peer(opaq2.clone());
let mask: IpAddr = mask.parse().unwrap();
peer2.add_subnet(mask, *len);
peer2.set_endpoint(dummy::UnitEndpoint::new());
if *stage {
// stage a packet which can be used for confirmation (in place of a keepalive)
let (_mask, _len, ip, _okay) = p2;
let msg = make_packet(1024, ip.parse().unwrap());
router2.send(msg).expect("failed to sent staged packet");
wait();
assert!(opaq2.recv().is_none());
assert!(
opaq2.send().is_none(),
"sending should fail as not key is set"
);
assert!(
opaq2.need_key().is_some(),
"a new key should be requested since a packet was attempted transmitted"
);
assert!(opaq2.is_empty(), "callbacks should only run once");
}
// this should cause a key-confirmation packet (keepalive or staged packet)
// this also causes peer1 to learn the "endpoint" for peer2
assert!(peer1.get_endpoint().is_none());
peer2.add_keypair(dummy::keypair(true));
wait();
assert!(opaq2.send().is_some());
assert!(opaq2.is_empty(), "events on peer2 should be 'send'");
assert!(opaq1.is_empty(), "nothing should happened on peer1");
// read confirming message received by the other end ("across the internet")
let mut buf = vec![0u8; 2048];
let (len, from) = bind_reader1.read(&mut buf).unwrap();
buf.truncate(len);
router1.recv(from, buf).unwrap();
wait();
assert!(opaq1.recv().is_some());
assert!(opaq1.key_confirmed().is_some());
assert!(
opaq1.is_empty(),
"events on peer1 should be 'recv' and 'key_confirmed'"
);
assert!(peer1.get_endpoint().is_some());
assert!(opaq2.is_empty(), "nothing should happened on peer2");
// now that peer1 has an endpoint
// route packets : peer1 -> peer2
for _ in 0..10 {
assert!(
opaq1.is_empty(),
"we should have asserted a value for every callback on peer1"
);
assert!(
opaq2.is_empty(),
"we should have asserted a value for every callback on peer2"
);
// pass IP packet to router
let (_mask, _len, ip, _okay) = p1;
let msg = make_packet(1024, ip.parse().unwrap());
router1.send(msg).unwrap();
wait();
assert!(opaq1.send().is_some());
assert!(opaq1.recv().is_none());
assert!(opaq1.need_key().is_none());
// receive ("across the internet") on the other end
let mut buf = vec![0u8; 2048];
let (len, from) = bind_reader2.read(&mut buf).unwrap();
buf.truncate(len);
router2.recv(from, buf).unwrap();
wait();
assert!(opaq2.send().is_none());
assert!(opaq2.recv().is_some());
assert!(opaq2.need_key().is_none());
}
}
}
}

View File

@@ -0,0 +1,65 @@
use std::error::Error;
use std::fmt;
pub trait Opaque: Send + Sync + 'static {}
impl<T> Opaque for T where T: Send + Sync + 'static {}
/// A send/recv callback takes 3 arguments:
///
/// * `0`, a reference to the opaque value assigned to the peer
/// * `1`, a bool indicating whether the message contained data (not just keepalive)
/// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?)
pub trait Callback<T>: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
/// A key callback takes 1 argument
///
/// * `0`, a reference to the opaque value assigned to the peer
pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {}
impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
pub trait Callbacks: Send + Sync + 'static {
type Opaque: Opaque;
fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
fn need_key(opaque: &Self::Opaque);
fn key_confirmed(opaque: &Self::Opaque);
}
#[derive(Debug)]
pub enum RouterError {
NoCryptKeyRoute,
MalformedIPHeader,
MalformedTransportMessage,
UnknownReceiverId,
NoEndpoint,
SendError,
}
impl fmt::Display for RouterError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
RouterError::UnknownReceiverId => {
write!(f, "No decryption state associated with receiver id")
}
RouterError::NoEndpoint => write!(f, "No endpoint for peer"),
RouterError::SendError => write!(f, "Failed to send packet on bind"),
}
}
}
impl Error for RouterError {
fn description(&self) -> &str {
"Generic Handshake Error"
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}

View File

@@ -0,0 +1,305 @@
use std::mem;
use std::sync::mpsc::Receiver;
use std::sync::Arc;
use futures::sync::oneshot;
use futures::*;
use log::debug;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::Ordering;
use zerocopy::{AsBytes, LayoutVerified};
use super::device::{DecryptionState, DeviceInner};
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::PeerInner;
use super::types::Callbacks;
use super::super::types::{Endpoint, tun, bind};
use super::ip::*;
const SIZE_TAG: usize = 16;
#[derive(PartialEq, Debug)]
pub enum Operation {
Encryption,
Decryption,
}
pub struct JobBuffer {
pub msg: Vec<u8>, // message buffer (nonce and receiver id set)
pub key: [u8; 32], // chacha20poly1305 key
pub okay: bool, // state of the job
pub op: Operation, // should be buffer be encrypted / decrypted?
}
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
#[allow(type_alias_bounds)]
pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
Arc<DecryptionState<E, C, T, B>>,
E,
oneshot::Receiver<JobBuffer>,
);
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
#[inline(always)]
fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
peer: &Arc<PeerInner<E, C, T, B>>,
packet: &[u8],
) -> Option<usize> {
match packet[0] >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// check IPv4 source address
device
.ipv4
.read()
.longest_match(Ipv4Addr::from(header.f_source))
.and_then(|(_, _, p)| {
if Arc::ptr_eq(p, &peer) {
Some(header.f_total_len.get() as usize)
} else {
None
}
})
}
VERSION_IP6 => {
// check length and cast to IPv6 header
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// check IPv6 source address
device
.ipv6
.read()
.longest_match(Ipv6Addr::from(header.f_source))
.and_then(|(_, _, p)| {
if Arc::ptr_eq(p, &peer) {
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
} else {
None
}
})
}
_ => None,
}
}
pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device
peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobInbound<E, C, T, B>>,
) {
loop {
// fetch job
let (state, endpoint, rx) = match receiver.recv() {
Ok(v) => v,
_ => {
return;
}
};
debug!("inbound worker: obtained job");
// wait for job to complete
let _ = rx
.map(|buf| {
debug!("inbound worker: job complete");
if buf.okay {
// cast transport header
let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
match LayoutVerified::new_from_prefix(&buf.msg[..]) {
Some(v) => v,
None => {
debug!("inbound worker: failed to parse message");
return;
}
};
debug_assert!(
packet.len() >= CHACHA20_POLY1305.tag_len(),
"this should be checked earlier in the pipeline (decryption should fail)"
);
// check for replay
if !state.protector.lock().update(header.f_counter.get()) {
debug!("inbound worker: replay detected");
return;
}
// check for confirms key
if !state.confirmed.swap(true, Ordering::SeqCst) {
debug!("inbound worker: message confirms key");
peer.confirm_key(&state.keypair);
}
// update endpoint
*peer.endpoint.lock() = Some(endpoint);
// calculate length of IP packet + padding
let length = packet.len() - SIZE_TAG;
debug!("inbound worker: plaintext length = {}", length);
// check if should be written to TUN
let mut sent = false;
if length > 0 {
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
debug_assert!(inner_len <= length, "should be validated");
if inner_len <= length {
sent = match device.inbound.write(&packet[..inner_len]) {
Err(e) => {
debug!("failed to write inbound packet to TUN: {:?}", e);
false
}
Ok(_) => true,
}
}
}
} else {
debug!("inbound worker: received keepalive")
}
// trigger callback
C::recv(&peer.opaque, buf.msg.len(), length == 0, sent);
} else {
debug!("inbound worker: authentication failure")
}
})
.wait();
}
}
pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device
peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobOutbound>,
) {
loop {
// fetch job
let rx = match receiver.recv() {
Ok(v) => v,
_ => {
return;
}
};
debug!("outbound worker: obtained job");
// wait for job to complete
let _ = rx
.map(|buf| {
debug!("outbound worker: job complete");
if buf.okay {
// write to UDP bind
let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
let send : &Option<B> = &*device.outbound.read();
if let Some(writer) = send.as_ref() {
match writer.write(&buf.msg[..], dst) {
Err(e) => {
debug!("failed to send outbound packet: {:?}", e);
false
}
Ok(_) => true,
}
} else {
false
}
} else {
false
};
// trigger callback
C::send(
&peer.opaque,
buf.msg.len(),
buf.msg.len() > SIZE_TAG + mem::size_of::<TransportHeader>(),
xmit,
);
}
})
.wait();
}
}
pub fn worker_parallel(receiver: Receiver<JobParallel>) {
loop {
// fetch next job
let (tx, mut buf) = match receiver.recv() {
Err(_) => {
return;
}
Ok(val) => val,
};
debug!("parallel worker: obtained job");
// make space for tag (TODO: consider moving this out)
if buf.op == Operation::Encryption {
buf.msg.extend([0u8; SIZE_TAG].iter());
}
// cast and check size of packet
let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
match LayoutVerified::new_from_prefix(&mut buf.msg[..]) {
Some(v) => v,
None => {
debug_assert!(
false,
"parallel worker: failed to parse message (insufficient size)"
);
continue;
}
};
debug_assert!(packet.len() >= CHACHA20_POLY1305.tag_len());
// do the weird ring AEAD dance
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap());
// create a nonce object
let mut nonce = [0u8; 12];
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
let nonce = Nonce::assume_unique_for_key(nonce);
match buf.op {
Operation::Encryption => {
debug!("parallel worker: process encryption");
// set the type field
header.f_type.set(TYPE_TRANSPORT);
// encrypt content of transport message in-place
let end = packet.len() - SIZE_TAG;
let tag = key
.seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
.unwrap();
// append tag
packet[end..].copy_from_slice(tag.as_ref());
buf.okay = true;
}
Operation::Decryption => {
debug!("parallel worker: process decryption");
// opening failure is signaled by fault state
buf.okay = match key.open_in_place(nonce, Aad::empty(), packet) {
Ok(_) => true,
Err(_) => false,
};
}
}
// pass ownership to consumer
let okay = tx.send(buf);
debug!(
"parallel worker: passing ownership to sequential worker: {}",
okay.is_ok()
);
}
}

46
src/wireguard/tests.rs Normal file
View File

@@ -0,0 +1,46 @@
use super::types::tun::Tun;
use super::types::{bind, dummy, tun};
use super::wireguard::Wireguard;
use std::thread;
use std::time::Duration;
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
/* Create and configure two matching pure instances of WireGuard
*
*/
#[test]
fn test_pure_wireguard() {
init();
// create WG instances for fake TUN devices
let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true);
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader1], tun_writer1, mtu1);
let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true);
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader2], tun_writer2, mtu2);
// create pair bind to connect the interfaces "over the internet"
let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair();
wg1.set_writer(bind_writer1);
wg2.set_writer(bind_writer2);
wg1.add_reader(bind_reader1);
wg2.add_reader(bind_reader2);
// generate (public, pivate) key pairs
// configure cryptkey router
// create IP packets
thread::sleep(Duration::from_millis(500));
}

234
src/wireguard/timers.rs Normal file
View File

@@ -0,0 +1,234 @@
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use log::info;
use hjul::{Runner, Timer};
use super::constants::*;
use super::router::Callbacks;
use super::types::{bind, tun};
use super::wireguard::{Peer, PeerInner};
pub struct Timers {
handshake_pending: AtomicBool,
handshake_attempts: AtomicUsize,
retransmit_handshake: Timer,
send_keepalive: Timer,
send_persistent_keepalive: Timer,
sent_lastminute_handshake: AtomicBool,
zero_key_material: Timer,
new_handshake: Timer,
need_another_keepalive: AtomicBool,
}
impl Timers {
#[inline(always)]
fn need_another_keepalive(&self) -> bool {
self.need_another_keepalive.swap(false, Ordering::SeqCst)
}
}
impl <T: tun::Tun, B: bind::Bind>Peer<T, B> {
/* should be called after an authenticated data packet is sent */
pub fn timers_data_sent(&self) {
self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
}
/* should be called after an authenticated data packet is received */
pub fn timers_data_received(&self) {
if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) {
self.timers().need_another_keepalive.store(true, Ordering::SeqCst)
}
}
/* Should be called after any type of authenticated packet is sent, whether:
* - keepalive
* - data
* - handshake
*/
pub fn timers_any_authenticated_packet_sent(&self) {
self.timers().send_keepalive.stop()
}
/* Should be called after any type of authenticated packet is received, whether:
* - keepalive
* - data
* - handshake
*/
pub fn timers_any_authenticated_packet_received(&self) {
self.timers().new_handshake.stop();
}
/* Should be called after a handshake initiation message is sent. */
pub fn timers_handshake_initiated(&self) {
self.timers().send_keepalive.stop();
self.timers().retransmit_handshake.reset(REKEY_TIMEOUT);
}
/* Should be called after a handshake response message is received and processed
* or when getting key confirmation via the first data message.
*/
pub fn timers_handshake_complete(&self) {
self.timers().handshake_attempts.store(0, Ordering::SeqCst);
self.timers().sent_lastminute_handshake.store(false, Ordering::SeqCst);
// TODO: Store time in peer for config
// self.walltime_last_handshake
}
/* Should be called after an ephemeral key is created, which is before sending a
* handshake response or after receiving a handshake response.
*/
pub fn timers_session_derived(&self) {
self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3);
}
/* Should be called before a packet with authentication, whether
* keepalive, data, or handshake is sent, or after one is received.
*/
pub fn timers_any_authenticated_packet_traversal(&self) {
let keepalive = self.state.keepalive.load(Ordering::Acquire);
if keepalive > 0 {
self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
}
}
}
impl Timers {
pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers
where
T: tun::Tun,
B: bind::Bind,
{
// create a timer instance for the provided peer
Timers {
handshake_pending: AtomicBool::new(false),
need_another_keepalive: AtomicBool::new(false),
sent_lastminute_handshake: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
retransmit_handshake: {
let peer = peer.clone();
runner.timer(move || {
if peer.timers().handshake_retry() {
info!("Retransmit handshake for {}", peer);
peer.new_handshake();
} else {
info!("Failed to complete handshake for {}", peer);
peer.router.purge_staged_packets();
peer.timers().send_keepalive.stop();
peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3);
}
})
},
send_keepalive: {
let peer = peer.clone();
runner.timer(move || {
peer.router.send_keepalive();
if peer.timers().need_another_keepalive() {
peer.timers().send_keepalive.start(KEEPALIVE_TIMEOUT);
}
})
},
new_handshake: {
let peer = peer.clone();
runner.timer(move || {
info!(
"Retrying handshake with {}, because we stopped hearing back after {} seconds",
peer,
(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
);
peer.new_handshake();
peer.timers.read().handshake_begun();
})
},
zero_key_material: {
let peer = peer.clone();
runner.timer(move || {
peer.router.zero_keys();
})
},
send_persistent_keepalive: {
let peer = peer.clone();
runner.timer(move || {
let keepalive = peer.state.keepalive.load(Ordering::Acquire);
if keepalive > 0 {
peer.router.send_keepalive();
peer.timers().send_keepalive.stop();
peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64));
}
})
}
}
}
fn handshake_begun(&self) {
self.handshake_pending.store(true, Ordering::SeqCst);
self.handshake_attempts.store(0, Ordering::SeqCst);
self.retransmit_handshake.reset(REKEY_TIMEOUT);
}
fn handshake_retry(&self) -> bool {
if self.handshake_attempts.fetch_add(1, Ordering::SeqCst) <= MAX_TIMER_HANDSHAKES {
self.retransmit_handshake.reset(REKEY_TIMEOUT);
true
} else {
self.handshake_pending.store(false, Ordering::SeqCst);
false
}
}
pub fn updated_persistent_keepalive(&self, keepalive: usize) {
if keepalive > 0 {
self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
}
}
pub fn dummy(runner: &Runner) -> Timers {
Timers {
handshake_pending: AtomicBool::new(false),
need_another_keepalive: AtomicBool::new(false),
sent_lastminute_handshake: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
retransmit_handshake: runner.timer(|| {}),
new_handshake: runner.timer(|| {}),
send_keepalive: runner.timer(|| {}),
send_persistent_keepalive: runner.timer(|| {}),
zero_key_material: runner.timer(|| {})
}
}
pub fn handshake_sent(&self) {
self.send_keepalive.stop();
}
}
/* Instance of the router callbacks */
pub struct Events<T, B>(PhantomData<(T, B)>);
impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
type Opaque = Arc<PeerInner<B>>;
fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
}
fn recv(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
}
fn need_key(peer: &Self::Opaque) {
let timers = peer.timers();
if !timers.handshake_pending.swap(true, Ordering::SeqCst) {
timers.handshake_attempts.store(0, Ordering::SeqCst);
timers.new_handshake.fire();
}
}
fn key_confirmed(peer: &Self::Opaque) {
peer.timers().retransmit_handshake.stop();
}
}

View File

@@ -0,0 +1,23 @@
use super::Endpoint;
use std::error::Error;
pub trait Reader<E: Endpoint>: Send + Sync {
type Error: Error;
fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>;
}
pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static {
type Error: Error;
fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>;
}
pub trait Bind: Send + Sync + 'static {
type Error: Error;
type Endpoint: Endpoint;
/* Until Rust gets type equality constraints these have to be generic */
type Writer: Writer<Self::Endpoint>;
type Reader: Reader<Self::Endpoint>;
}

View File

@@ -0,0 +1,323 @@
use std::error::Error;
use std::fmt;
use std::marker;
use std::net::SocketAddr;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Instant;
use std::sync::atomic::{Ordering, AtomicUsize};
use super::*;
/* This submodule provides pure/dummy implementations of the IO interfaces
* for use in unit tests thoughout the project.
*/
/* Error implementation */
#[derive(Debug)]
pub enum BindError {
Disconnected,
}
impl Error for BindError {
fn description(&self) -> &str {
"Generic Bind Error"
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
impl fmt::Display for BindError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BindError::Disconnected => write!(f, "PairBind disconnected"),
}
}
}
/* TUN implementation */
#[derive(Debug)]
pub enum TunError {
Disconnected
}
impl Error for TunError {
fn description(&self) -> &str {
"Generic Tun Error"
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
impl fmt::Display for TunError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Not Possible")
}
}
/* Endpoint implementation */
#[derive(Clone, Copy)]
pub struct UnitEndpoint {}
impl Endpoint for UnitEndpoint {
fn from_address(_: SocketAddr) -> UnitEndpoint {
UnitEndpoint {}
}
fn into_address(&self) -> SocketAddr {
"127.0.0.1:8080".parse().unwrap()
}
fn clear_src(&self) {}
}
impl UnitEndpoint {
pub fn new() -> UnitEndpoint {
UnitEndpoint {}
}
}
/* */
pub struct TunTest {}
pub struct TunFakeIO {
store: bool,
tx: SyncSender<Vec<u8>>,
rx: Receiver<Vec<u8>>
}
pub struct TunReader {
rx: Receiver<Vec<u8>>
}
pub struct TunWriter {
store: bool,
tx: Mutex<SyncSender<Vec<u8>>>
}
#[derive(Clone)]
pub struct TunMTU {
mtu: Arc<AtomicUsize>
}
impl tun::Reader for TunReader {
type Error = TunError;
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
match self.rx.recv() {
Ok(m) => {
buf[offset..].copy_from_slice(&m[..]);
Ok(m.len())
}
Err(_) => Err(TunError::Disconnected)
}
}
}
impl tun::Writer for TunWriter {
type Error = TunError;
fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
if self.store {
let m = src.to_owned();
match self.tx.lock().unwrap().send(m) {
Ok(_) => Ok(()),
Err(_) => Err(TunError::Disconnected)
}
} else {
Ok(())
}
}
}
impl tun::MTU for TunMTU {
fn mtu(&self) -> usize {
self.mtu.load(Ordering::Acquire)
}
}
impl tun::Tun for TunTest {
type Writer = TunWriter;
type Reader = TunReader;
type MTU = TunMTU;
type Error = TunError;
}
impl TunFakeIO {
pub fn write(&self, msg : Vec<u8>) {
if self.store {
self.tx.send(msg).unwrap();
}
}
pub fn read(&self) -> Vec<u8> {
self.rx.recv().unwrap()
}
}
impl TunTest {
pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) {
let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) };
let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) };
let fake = TunFakeIO{tx: tx1, rx: rx2, store};
let reader = TunReader{rx : rx1};
let writer = TunWriter{tx : Mutex::new(tx2), store};
let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))};
(fake, reader, writer, mtu)
}
}
/* Void Bind */
#[derive(Clone, Copy)]
pub struct VoidBind {}
impl bind::Reader<UnitEndpoint> for VoidBind {
type Error = BindError;
fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
Ok((0, UnitEndpoint {}))
}
}
impl bind::Writer<UnitEndpoint> for VoidBind {
type Error = BindError;
fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
Ok(())
}
}
impl bind::Bind for VoidBind {
type Error = BindError;
type Endpoint = UnitEndpoint;
type Reader = VoidBind;
type Writer = VoidBind;
}
impl VoidBind {
pub fn new() -> VoidBind {
VoidBind {}
}
}
/* Pair Bind */
#[derive(Clone)]
pub struct PairReader<E> {
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
_marker: marker::PhantomData<E>,
}
impl bind::Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
type Error = BindError;
fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
let vec = self
.recv
.lock()
.unwrap()
.recv()
.map_err(|_| BindError::Disconnected)?;
let len = vec.len();
buf[..len].copy_from_slice(&vec[..]);
Ok((vec.len(), UnitEndpoint {}))
}
}
impl bind::Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
type Error = BindError;
fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
let owned = buf.to_owned();
match self.send.lock().unwrap().send(owned) {
Err(_) => Err(BindError::Disconnected),
Ok(_) => Ok(()),
}
}
}
#[derive(Clone)]
pub struct PairWriter<E> {
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
_marker: marker::PhantomData<E>,
}
#[derive(Clone)]
pub struct PairBind {}
impl PairBind {
pub fn pair<E>() -> (
(PairReader<E>, PairWriter<E>),
(PairReader<E>, PairWriter<E>),
) {
let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128);
(
(
PairReader {
recv: Arc::new(Mutex::new(rx1)),
_marker: marker::PhantomData,
},
PairWriter {
send: Arc::new(Mutex::new(tx2)),
_marker: marker::PhantomData,
},
),
(
PairReader {
recv: Arc::new(Mutex::new(rx2)),
_marker: marker::PhantomData,
},
PairWriter {
send: Arc::new(Mutex::new(tx1)),
_marker: marker::PhantomData,
},
),
)
}
}
impl bind::Bind for PairBind {
type Error = BindError;
type Endpoint = UnitEndpoint;
type Reader = PairReader<Self::Endpoint>;
type Writer = PairWriter<Self::Endpoint>;
}
pub fn keypair(initiator: bool) -> KeyPair {
let k1 = Key {
key: [0x53u8; 32],
id: 0x646e6573,
};
let k2 = Key {
key: [0x52u8; 32],
id: 0x76636572,
};
if initiator {
KeyPair {
birth: Instant::now(),
initiator: true,
send: k1,
recv: k2,
}
} else {
KeyPair {
birth: Instant::now(),
initiator: false,
send: k2,
recv: k1,
}
}
}

View File

@@ -0,0 +1,7 @@
use std::net::SocketAddr;
pub trait Endpoint: Send + 'static {
fn from_address(addr: SocketAddr) -> Self;
fn into_address(&self) -> SocketAddr;
fn clear_src(&self);
}

View File

@@ -0,0 +1,36 @@
use clear_on_drop::clear::Clear;
use std::time::Instant;
#[derive(Debug, Clone)]
pub struct Key {
pub key: [u8; 32],
pub id: u32,
}
// zero key on drop
impl Drop for Key {
fn drop(&mut self) {
self.key.clear()
}
}
#[cfg(test)]
impl PartialEq for Key {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.key[..] == other.key[..]
}
}
#[derive(Debug, Clone)]
pub struct KeyPair {
pub birth: Instant, // when was the key-pair created
pub initiator: bool, // has the key-pair been confirmed?
pub send: Key, // key for outbound messages
pub recv: Key, // key for inbound messages
}
impl KeyPair {
pub fn local_id(&self) -> u32 {
self.recv.id
}
}

View File

@@ -0,0 +1,10 @@
mod endpoint;
mod keys;
pub mod tun;
pub mod bind;
#[cfg(test)]
pub mod dummy;
pub use endpoint::Endpoint;
pub use keys::{Key, KeyPair};

View File

@@ -0,0 +1,56 @@
use std::error::Error;
pub trait Writer: Send + Sync + 'static {
type Error: Error;
/// Receive a cryptkey routed IP packet
///
/// # Arguments
///
/// - src: Buffer containing the IP packet to be written
///
/// # Returns
///
/// Unit type or an error
fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
}
pub trait Reader: Send + 'static {
type Error: Error;
/// Reads an IP packet into dst[offset:] from the tunnel device
///
/// The reason for providing space for a prefix
/// is to efficiently accommodate platforms on which the packet is prefaced by a header.
/// This space is later used to construct the transport message inplace.
///
/// # Arguments
///
/// - buf: Destination buffer (enough space for MTU bytes + header)
/// - offset: Offset for the beginning of the IP packet
///
/// # Returns
///
/// The size of the IP packet (ignoring the header) or an std::error::Error instance:
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error>;
}
pub trait MTU: Send + Sync + Clone + 'static {
/// Returns the MTU of the device
///
/// This function needs to be efficient (called for every read).
/// The goto implementation strategy is to .load an atomic variable,
/// then use e.g. netlink to update the variable in a separate thread.
///
/// # Returns
///
/// The MTU of the interface in bytes
fn mtu(&self) -> usize;
}
pub trait Tun: Send + Sync + 'static {
type Writer: Writer;
type Reader: Reader;
type MTU: MTU;
type Error: Error;
}

407
src/wireguard/wireguard.rs Normal file
View File

@@ -0,0 +1,407 @@
use super::constants::*;
use super::handshake;
use super::router;
use super::timers::{Events, Timers};
use super::types::bind::Reader as BindReader;
use super::types::bind::{Bind, Writer};
use super::types::tun::{Reader, Tun, MTU};
use super::types::Endpoint;
use hjul::Runner;
use std::fmt;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use std::collections::HashMap;
use log::debug;
use rand::rngs::OsRng;
use spin::{Mutex, RwLock, RwLockReadGuard};
use byteorder::{ByteOrder, LittleEndian};
use crossbeam_channel::{bounded, Sender};
use x25519_dalek::{PublicKey, StaticSecret};
const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
pub struct Peer<T: Tun, B: Bind> {
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
pub state: Arc<PeerInner<B>>,
}
impl<T: Tun, B: Bind> Clone for Peer<T, B> {
fn clone(&self) -> Peer<T, B> {
Peer {
router: self.router.clone(),
state: self.state.clone(),
}
}
}
pub struct PeerInner<B: Bind> {
pub keepalive: AtomicUsize, // keepalive interval
pub rx_bytes: AtomicU64,
pub tx_bytes: AtomicU64,
pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this.
pub timers: RwLock<Timers>, //
}
impl<B: Bind> PeerInner<B> {
#[inline(always)]
pub fn timers(&self) -> RwLockReadGuard<Timers> {
self.timers.read()
}
}
impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "peer()")
}
}
impl<T: Tun, B: Bind> Deref for Peer<T, B> {
type Target = PeerInner<B>;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl<B: Bind> PeerInner<B> {
pub fn new_handshake(&self) {
// TODO: clear endpoint source address ("unsticky")
self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
}
}
struct Handshake {
device: handshake::Device,
active: bool,
}
pub enum HandshakeJob<E> {
Message(Vec<u8>, E),
New(PublicKey),
}
struct WireguardInner<T: Tun, B: Bind> {
// provides access to the MTU value of the tun device
// (otherwise owned solely by the router and a dedicated read IO thread)
mtu: T::MTU,
send: RwLock<Option<B::Writer>>,
// identify and configuration map
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
// cryptkey router
router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
// handshake related state
handshake: RwLock<Handshake>,
under_load: AtomicBool,
pending: AtomicUsize, // num of pending handshake packets in queue
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
}
pub struct Wireguard<T: Tun, B: Bind> {
runner: Runner,
state: Arc<WireguardInner<T, B>>,
}
/* Returns the padded length of a message:
*
* # Arguments
*
* - `size` : Size of unpadded message
* - `mtu` : Maximum transmission unit of the device
*
* # Returns
*
* The padded length (always less than or equal to the MTU)
*/
#[inline(always)]
const fn padding(size: usize, mtu: usize) -> usize {
#[inline(always)]
const fn min(a: usize, b: usize) -> usize {
let m = (a > b) as usize;
a * m + (1 - m) * b
}
let pad = MESSAGE_PADDING_MULTIPLE;
min(mtu, size + (pad - size % pad) % pad)
}
impl<T: Tun, B: Bind> Wireguard<T, B> {
pub fn set_key(&self, sk: Option<StaticSecret>) {
let mut handshake = self.state.handshake.write();
match sk {
None => {
let mut rng = OsRng::new().unwrap();
handshake.device.set_sk(StaticSecret::new(&mut rng));
handshake.active = false;
}
Some(sk) => {
handshake.device.set_sk(sk);
handshake.active = true;
}
}
}
pub fn get_sk(&self) -> Option<StaticSecret> {
let handshake = self.state.handshake.read();
if handshake.active {
Some(handshake.device.get_sk())
} else {
None
}
}
pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
let state = Arc::new(PeerInner {
pk,
queue: Mutex::new(self.state.queue.lock().clone()),
keepalive: AtomicUsize::new(0),
rx_bytes: AtomicU64::new(0),
tx_bytes: AtomicU64::new(0),
timers: RwLock::new(Timers::dummy(&self.runner)),
});
let router = Arc::new(self.state.router.new_peer(state.clone()));
let peer = Peer { router, state };
/* The need for dummy timers arises from the chicken-egg
* problem of the timer callbacks being able to set timers themselves.
*
* This is in fact the only place where the write lock is ever taken.
*/
*peer.timers.write() = Timers::new(&self.runner, peer.clone());
peer
}
/* Begin consuming messages from the reader.
*
* Any previous reader thread is stopped by closing the previous reader,
* which unblocks the thread and causes an error on reader.read
*/
pub fn add_reader(&self, reader: B::Reader) {
let wg = self.state.clone();
thread::spawn(move || {
let mut last_under_load =
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
loop {
// create vector big enough for any message given current MTU
let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0);
// read UDP packet into vector
let (size, src) = match reader.read(&mut msg) {
Err(e) => {
debug!("Bind reader closed with {}", e);
return;
}
Ok(v) => v,
};
msg.truncate(size);
// message type de-multiplexer
if msg.len() < std::mem::size_of::<u32>() {
continue;
}
match LittleEndian::read_u32(&msg[..]) {
handshake::TYPE_COOKIE_REPLY
| handshake::TYPE_INITIATION
| handshake::TYPE_RESPONSE => {
// update under_load flag
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
last_under_load = Instant::now();
wg.under_load.store(true, Ordering::SeqCst);
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
wg.under_load.store(false, Ordering::SeqCst);
}
wg.queue
.lock()
.send(HandshakeJob::Message(msg, src))
.unwrap();
}
router::TYPE_TRANSPORT => {
// transport message
let _ = wg.router.recv(src, msg).map_err(|e| {
debug!("Failed to handle incoming transport message: {}", e);
});
}
_ => (),
}
}
});
}
pub fn set_writer(&self, writer: B::Writer) {
// TODO: Consider unifying these and avoid Clone requirement on writer
*self.state.send.write() = Some(writer.clone());
self.state.router.set_outbound_writer(writer);
}
pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
mtu: mtu.clone(),
peers: RwLock::new(HashMap::new()),
send: RwLock::new(None),
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
pending: AtomicUsize::new(0),
handshake: RwLock::new(Handshake {
device: handshake::Device::new(StaticSecret::new(&mut rng)),
active: false,
}),
under_load: AtomicBool::new(false),
queue: Mutex::new(tx),
});
// start handshake workers
for _ in 0..num_cpus::get() {
let wg = wg.clone();
let rx = rx.clone();
thread::spawn(move || {
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue
for job in rx {
wg.pending.fetch_sub(1, Ordering::SeqCst);
let state = wg.handshake.read();
if !state.active {
continue;
}
match job {
HandshakeJob::Message(msg, src) => {
// feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid
// process message
match state.device.process(
&mut rng,
&msg[..],
if wg.under_load.load(Ordering::Relaxed) {
Some(&src_validate)
} else {
None
},
) {
Ok((pk, resp, keypair)) => {
// send response
let mut resp_len: u64 = 0;
if let Some(msg) = resp {
resp_len = msg.len() as u64;
let send: &Option<B::Writer> = &*wg.send.read();
if let Some(writer) = send.as_ref() {
let _ = writer.write(&msg[..], &src).map_err(|e| {
debug!(
"handshake worker, failed to send response, error = {}",
e
)
});
}
}
// update timers
if let Some(pk) = pk {
// authenticated handshake packet received
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
// add to rx_bytes and tx_bytes
let req_len = msg.len() as u64;
peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
// update endpoint
peer.router.set_endpoint(src);
// add keypair to peer
keypair.map(|kp| {
// free any unused ids
for id in peer.router.add_keypair(kp) {
state.device.release(id);
}
});
}
}
}
Err(e) => debug!("handshake worker, error = {:?}", e),
}
}
HandshakeJob::New(pk) => {
let _ = state.device.begin(&mut rng, &pk).map(|msg| {
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
let _ = peer.router.send(&msg[..]).map_err(|e| {
debug!("handshake worker, failed to send handshake initiation, error = {}", e)
});
}
});
}
}
}
});
}
// start TUN read IO threads (multiple threads to support multi-queue interfaces)
debug_assert!(
readers.len() > 0,
"attempted to create WG device without TUN readers"
);
while let Some(reader) = readers.pop() {
let wg = wg.clone();
let mtu = mtu.clone();
thread::spawn(move || loop {
// create vector big enough for any transport message (based on MTU)
let mtu = mtu.mtu();
let size = mtu + router::SIZE_MESSAGE_PREFIX;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
msg.resize(size, 0);
// read a new IP packet
let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) {
Ok(payload) => payload,
Err(e) => {
debug!("TUN worker, failed to read from tun device: {}", e);
return;
}
};
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
// truncate padding
let payload = padding(payload, mtu);
msg.truncate(router::SIZE_MESSAGE_PREFIX + payload);
debug_assert!(payload <= mtu);
debug_assert_eq!(
if payload < mtu {
(msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
} else {
0
},
0
);
// crypt-key route
let e = wg.router.send(msg);
debug!("TUN worker, router returned {:?}", e);
});
}
Wireguard {
state: wg,
runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),
}
}
}