First full test of pure WireGuard

This commit is contained in:
Mathias Hall-Andersen
2019-10-28 14:48:24 +01:00
parent 3e829c04d1
commit 4ff328b7da
12 changed files with 242 additions and 62 deletions

2
Cargo.lock generated
View File

@@ -1608,6 +1608,8 @@ dependencies = [
"pnet 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)", "pnet 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)",
"proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)", "proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
"ring 0.16.7 (registry+https://github.com/rust-lang/crates.io-index)", "ring 0.16.7 (registry+https://github.com/rust-lang/crates.io-index)",
"spin 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "spin 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
"subtle 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)", "subtle 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@@ -46,3 +46,5 @@ features = ["nightly"]
[dev-dependencies] [dev-dependencies]
proptest = "0.9.4" proptest = "0.9.4"
pnet = "^0.22" pnet = "^0.22"
rand_chacha = "0.2.1"
rand_core = "0.5"

View File

@@ -1,7 +1,12 @@
use hex;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use std::marker; use std::marker;
use log::debug;
use rand::rngs::OsRng;
use rand::Rng;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
@@ -95,6 +100,7 @@ impl VoidBind {
#[derive(Clone)] #[derive(Clone)]
pub struct PairReader<E> { pub struct PairReader<E> {
id: u32,
recv: Arc<Mutex<Receiver<Vec<u8>>>>, recv: Arc<Mutex<Receiver<Vec<u8>>>>,
_marker: marker::PhantomData<E>, _marker: marker::PhantomData<E>,
} }
@@ -110,13 +116,25 @@ impl Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
.map_err(|_| BindError::Disconnected)?; .map_err(|_| BindError::Disconnected)?;
let len = vec.len(); let len = vec.len();
buf[..len].copy_from_slice(&vec[..]); buf[..len].copy_from_slice(&vec[..]);
Ok((vec.len(), UnitEndpoint {})) debug!(
"dummy({}): read ({}, {})",
self.id,
len,
hex::encode(&buf[..len])
);
Ok((len, UnitEndpoint {}))
} }
} }
impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> { impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
type Error = BindError; type Error = BindError;
fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
debug!(
"dummy({}): write ({}, {})",
self.id,
buf.len(),
hex::encode(buf)
);
let owned = buf.to_owned(); let owned = buf.to_owned();
match self.send.lock().unwrap().send(owned) { match self.send.lock().unwrap().send(owned) {
Err(_) => Err(BindError::Disconnected), Err(_) => Err(BindError::Disconnected),
@@ -127,6 +145,7 @@ impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
#[derive(Clone)] #[derive(Clone)]
pub struct PairWriter<E> { pub struct PairWriter<E> {
id: u32,
send: Arc<Mutex<SyncSender<Vec<u8>>>>, send: Arc<Mutex<SyncSender<Vec<u8>>>>,
_marker: marker::PhantomData<E>, _marker: marker::PhantomData<E>,
} }
@@ -139,25 +158,33 @@ impl PairBind {
(PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>),
(PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>),
) { ) {
let mut rng = OsRng::new().unwrap();
let id1: u32 = rng.gen();
let id2: u32 = rng.gen();
let (tx1, rx1) = sync_channel(128); let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128); let (tx2, rx2) = sync_channel(128);
( (
( (
PairReader { PairReader {
id: id1,
recv: Arc::new(Mutex::new(rx1)), recv: Arc::new(Mutex::new(rx1)),
_marker: marker::PhantomData, _marker: marker::PhantomData,
}, },
PairWriter { PairWriter {
id: id1,
send: Arc::new(Mutex::new(tx2)), send: Arc::new(Mutex::new(tx2)),
_marker: marker::PhantomData, _marker: marker::PhantomData,
}, },
), ),
( (
PairReader { PairReader {
id: id2,
recv: Arc::new(Mutex::new(rx2)), recv: Arc::new(Mutex::new(rx2)),
_marker: marker::PhantomData, _marker: marker::PhantomData,
}, },
PairWriter { PairWriter {
id: id2,
send: Arc::new(Mutex::new(tx1)), send: Arc::new(Mutex::new(tx1)),
_marker: marker::PhantomData, _marker: marker::PhantomData,
}, },

View File

@@ -1,3 +1,8 @@
use hex;
use log::debug;
use rand::rngs::OsRng;
use rand::Rng;
use std::cmp::min; use std::cmp::min;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
@@ -61,16 +66,19 @@ impl fmt::Display for TunError {
pub struct TunTest {} pub struct TunTest {}
pub struct TunFakeIO { pub struct TunFakeIO {
id: u32,
store: bool, store: bool,
tx: SyncSender<Vec<u8>>, tx: SyncSender<Vec<u8>>,
rx: Receiver<Vec<u8>>, rx: Receiver<Vec<u8>>,
} }
pub struct TunReader { pub struct TunReader {
id: u32,
rx: Receiver<Vec<u8>>, rx: Receiver<Vec<u8>>,
} }
pub struct TunWriter { pub struct TunWriter {
id: u32,
store: bool, store: bool,
tx: Mutex<SyncSender<Vec<u8>>>, tx: Mutex<SyncSender<Vec<u8>>>,
} }
@@ -88,6 +96,12 @@ impl Reader for TunReader {
Ok(msg) => { Ok(msg) => {
let n = min(buf.len() - offset, msg.len()); let n = min(buf.len() - offset, msg.len());
buf[offset..offset + n].copy_from_slice(&msg[..n]); buf[offset..offset + n].copy_from_slice(&msg[..n]);
debug!(
"dummy::TUN({}) : read ({}, {})",
self.id,
n,
hex::encode(&buf[offset..offset + n])
);
Ok(n) Ok(n)
} }
Err(_) => Err(TunError::Disconnected), Err(_) => Err(TunError::Disconnected),
@@ -99,6 +113,12 @@ impl Writer for TunWriter {
type Error = TunError; type Error = TunError;
fn write(&self, src: &[u8]) -> Result<(), Self::Error> { fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
debug!(
"dummy::TUN({}) : write ({}, {})",
self.id,
src.len(),
hex::encode(src)
);
if self.store { if self.store {
let m = src.to_owned(); let m = src.to_owned();
match self.tx.lock().unwrap().send(m) { match self.tx.lock().unwrap().send(m) {
@@ -149,13 +169,18 @@ impl TunTest {
sync_channel(1) sync_channel(1)
}; };
let mut rng = OsRng::new().unwrap();
let id: u32 = rng.gen();
let fake = TunFakeIO { let fake = TunFakeIO {
id,
tx: tx1, tx: tx1,
rx: rx2, rx: rx2,
store, store,
}; };
let reader = TunReader { rx: rx1 }; let reader = TunReader { id, rx: rx1 };
let writer = TunWriter { let writer = TunWriter {
id,
tx: Mutex::new(tx2), tx: Mutex::new(tx2),
store, store,
}; };

View File

@@ -12,6 +12,8 @@ use chacha20poly1305::ChaCha20Poly1305;
use rand::{CryptoRng, RngCore}; use rand::{CryptoRng, RngCore};
use log::debug;
use generic_array::typenum::*; use generic_array::typenum::*;
use generic_array::*; use generic_array::*;
@@ -27,7 +29,7 @@ use super::peer::{Peer, State};
use super::timestamp; use super::timestamp;
use super::types::*; use super::types::*;
use super::super::types::{KeyPair, Key}; use super::super::types::{Key, KeyPair};
use std::time::Instant; use std::time::Instant;
@@ -222,6 +224,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
sender: u32, sender: u32,
msg: &mut NoiseInitiation, msg: &mut NoiseInitiation,
) -> Result<(), HandshakeError> { ) -> Result<(), HandshakeError> {
debug!("create initation");
clear_stack_on_return(CLEAR_PAGES, || { clear_stack_on_return(CLEAR_PAGES, || {
// initialize state // initialize state
@@ -300,6 +303,7 @@ pub fn consume_initiation<'a>(
device: &'a Device, device: &'a Device,
msg: &NoiseInitiation, msg: &NoiseInitiation,
) -> Result<(&'a Peer, TemporaryState), HandshakeError> { ) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
debug!("consume initation");
clear_stack_on_return(CLEAR_PAGES, || { clear_stack_on_return(CLEAR_PAGES, || {
// initialize new state // initialize new state
@@ -377,6 +381,7 @@ pub fn create_response<R: RngCore + CryptoRng>(
state: TemporaryState, // state from "consume_initiation" state: TemporaryState, // state from "consume_initiation"
msg: &mut NoiseResponse, // resulting response msg: &mut NoiseResponse, // resulting response
) -> Result<KeyPair, HandshakeError> { ) -> Result<KeyPair, HandshakeError> {
debug!("create response");
clear_stack_on_return(CLEAR_PAGES, || { clear_stack_on_return(CLEAR_PAGES, || {
// unpack state // unpack state
@@ -457,6 +462,7 @@ pub fn create_response<R: RngCore + CryptoRng>(
* in order to better mitigate DoS from malformed response messages. * in order to better mitigate DoS from malformed response messages.
*/ */
pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> { pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
debug!("consume response");
clear_stack_on_return(CLEAR_PAGES, || { clear_stack_on_return(CLEAR_PAGES, || {
// retrieve peer and copy initiation state // retrieve peer and copy initiation state
let peer = device.lookup_id(msg.f_receiver.get())?; let peer = device.lookup_id(msg.f_receiver.get())?;

View File

@@ -89,13 +89,7 @@ fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>, device: &Arc<DeviceInner<E, C, T, B>>,
packet: &[u8], packet: &[u8],
) -> Option<Arc<PeerInner<E, C, T, B>>> { ) -> Option<Arc<PeerInner<E, C, T, B>>> {
// ensure version access within bounds match packet.get(0)? >> 4 {
if packet.len() < 1 {
return None;
};
// cast to correct IP header
match packet[0] >> 4 {
VERSION_IP4 => { VERSION_IP4 => {
// check length and cast to IPv4 header // check length and cast to IPv4 header
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
@@ -176,7 +170,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C,
let packet = &msg[SIZE_MESSAGE_PREFIX..]; let packet = &msg[SIZE_MESSAGE_PREFIX..];
// lookup peer based on IP packet destination address // lookup peer based on IP packet destination address
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?; let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?;
// schedule for encryption and transmission to peer // schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg, true) { if let Some(job) = peer.send_job(msg, true) {

View File

@@ -531,8 +531,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
/// ///
/// If an identical value already exists as part of a prior peer, /// If an identical value already exists as part of a prior peer,
/// the allowed IP entry will be removed from that peer and added to this peer. /// the allowed IP entry will be removed from that peer and added to this peer.
pub fn add_subnet(&self, ip: IpAddr, masklen: u32) { pub fn add_allowed_ips(&self, ip: IpAddr, masklen: u32) {
debug!("peer.add_subnet"); debug!("peer.add_allowed_ips");
match ip { match ip {
IpAddr::V4(v4) => { IpAddr::V4(v4) => {
self.state self.state
@@ -556,8 +556,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
/// # Returns /// # Returns
/// ///
/// A vector of subnets, represented by as mask/size /// A vector of subnets, represented by as mask/size
pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> { pub fn list_allowed_ips(&self) -> Vec<(IpAddr, u32)> {
debug!("peer.list_subnets"); debug!("peer.list_allowed_ips");
let mut res = Vec::new(); let mut res = Vec::new();
res.append(&mut treebit_list( res.append(&mut treebit_list(
&self.state, &self.state,
@@ -575,8 +575,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
/// Clear subnets mapped to the peer. /// Clear subnets mapped to the peer.
/// After the call, no subnets will be cryptkey routed to the peer. /// After the call, no subnets will be cryptkey routed to the peer.
/// Used for the UAPI command "replace_allowed_ips=true" /// Used for the UAPI command "replace_allowed_ips=true"
pub fn remove_subnets(&self) { pub fn remove_allowed_ips(&self) {
debug!("peer.remove_subnets"); debug!("peer.remove_allowed_ips");
treebit_remove(self, &self.state.device.ipv4); treebit_remove(self, &self.state.device.ipv4);
treebit_remove(self, &self.state.device.ipv6); treebit_remove(self, &self.state.device.ipv6);
} }

View File

@@ -157,7 +157,7 @@ mod tests {
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20"); let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
let mask: IpAddr = mask.parse().unwrap(); let mask: IpAddr = mask.parse().unwrap();
let ip1: IpAddr = ip.parse().unwrap(); let ip1: IpAddr = ip.parse().unwrap();
peer.add_subnet(mask, len); peer.add_allowed_ips(mask, len);
// every iteration sends 10 GB // every iteration sends 10 GB
b.iter(|| { b.iter(|| {
@@ -215,7 +215,7 @@ mod tests {
} }
// map subnet to peer // map subnet to peer
peer.add_subnet(mask, *len); peer.add_allowed_ips(mask, *len);
// create "IP packet" // create "IP packet"
let msg = make_packet(1024, ip.parse().unwrap()); let msg = make_packet(1024, ip.parse().unwrap());
@@ -339,13 +339,13 @@ mod tests {
let (mask, len, _ip, _okay) = p1; let (mask, len, _ip, _okay) = p1;
let peer1 = router1.new_peer(opaq1.clone()); let peer1 = router1.new_peer(opaq1.clone());
let mask: IpAddr = mask.parse().unwrap(); let mask: IpAddr = mask.parse().unwrap();
peer1.add_subnet(mask, *len); peer1.add_allowed_ips(mask, *len);
peer1.add_keypair(dummy_keypair(false)); peer1.add_keypair(dummy_keypair(false));
let (mask, len, _ip, _okay) = p2; let (mask, len, _ip, _okay) = p2;
let peer2 = router2.new_peer(opaq2.clone()); let peer2 = router2.new_peer(opaq2.clone());
let mask: IpAddr = mask.parse().unwrap(); let mask: IpAddr = mask.parse().unwrap();
peer2.add_subnet(mask, *len); peer2.add_allowed_ips(mask, *len);
peer2.set_endpoint(dummy::UnitEndpoint::new()); peer2.set_endpoint(dummy::UnitEndpoint::new());
if *stage { if *stage {

View File

@@ -31,7 +31,7 @@ pub trait Callbacks: Send + Sync + 'static {
#[derive(Debug)] #[derive(Debug)]
pub enum RouterError { pub enum RouterError {
NoCryptKeyRoute, NoCryptoKeyRoute,
MalformedIPHeader, MalformedIPHeader,
MalformedTransportMessage, MalformedTransportMessage,
UnknownReceiverId, UnknownReceiverId,
@@ -42,7 +42,7 @@ pub enum RouterError {
impl fmt::Display for RouterError { impl fmt::Display for RouterError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"), RouterError::NoCryptoKeyRoute => write!(f, "No cryptokey route configured for subnet"),
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"), RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"), RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
RouterError::UnknownReceiverId => { RouterError::UnknownReceiverId => {

View File

@@ -5,13 +5,23 @@ use std::net::IpAddr;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use rand::rngs::OsRng; use hex;
use rand_chacha::ChaCha8Rng;
use rand_core::{RngCore, SeedableRng};
use x25519_dalek::{PublicKey, StaticSecret}; use x25519_dalek::{PublicKey, StaticSecret};
use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet; use pnet::packet::ipv6::MutableIpv6Packet;
fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec<u8> { fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
// expand pseudo random payload
let mut rng: _ = ChaCha8Rng::seed_from_u64(id);
let mut p: Vec<u8> = vec![];
for _ in 0..size {
p.push(rng.next_u32() as u8);
}
// create "IP packet" // create "IP packet"
let mut msg = Vec::with_capacity(size); let mut msg = Vec::with_capacity(size);
msg.resize(size, 0); msg.resize(size, 0);
@@ -19,21 +29,25 @@ fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec<u8> {
IpAddr::V4(dst) => { IpAddr::V4(dst) => {
let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap(); let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst); packet.set_destination(dst);
packet.set_total_length(size as u16);
packet.set_source(if let IpAddr::V4(src) = src { packet.set_source(if let IpAddr::V4(src) = src {
src src
} else { } else {
panic!("src.version != dst.version") panic!("src.version != dst.version")
}); });
packet.set_payload(&p[..]);
packet.set_version(4); packet.set_version(4);
} }
IpAddr::V6(dst) => { IpAddr::V6(dst) => {
let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap(); let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst); packet.set_destination(dst);
packet.set_payload_length((size - MutableIpv6Packet::minimum_packet_size()) as u16);
packet.set_source(if let IpAddr::V6(src) = src { packet.set_source(if let IpAddr::V6(src) = src {
src src
} else { } else {
panic!("src.version != dst.version") panic!("src.version != dst.version")
}); });
packet.set_payload(&p[..]);
packet.set_version(6); packet.set_version(6);
} }
} }
@@ -55,7 +69,7 @@ fn wait() {
fn test_pure_wireguard() { fn test_pure_wireguard() {
init(); init();
// create WG instances for fake TUN devices // create WG instances for dummy TUN devices
let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true); let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true);
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> = let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
@@ -77,10 +91,20 @@ fn test_pure_wireguard() {
// generate (public, pivate) key pairs // generate (public, pivate) key pairs
let mut rng = OsRng::new().unwrap(); let sk1 = StaticSecret::from([
let sk1 = StaticSecret::new(&mut rng); 0x3f, 0x69, 0x86, 0xd1, 0xc0, 0xec, 0x25, 0xa0, 0x9c, 0x8e, 0x56, 0xb5, 0x1d, 0xb7, 0x3c,
let sk2 = StaticSecret::new(&mut rng); 0xed, 0x56, 0x8e, 0x59, 0x9d, 0xd9, 0xc3, 0x98, 0x67, 0x74, 0x69, 0x90, 0xc3, 0x43, 0x36,
0x78, 0x89,
]);
let sk2 = StaticSecret::from([
0xfb, 0xd1, 0xd6, 0xe4, 0x65, 0x06, 0xd2, 0xe5, 0xc5, 0xdf, 0x6e, 0xab, 0x51, 0x71, 0xd8,
0x70, 0xb5, 0xb7, 0x77, 0x51, 0xb4, 0xbe, 0xfb, 0xbc, 0x88, 0x62, 0x40, 0xca, 0x2c, 0xc2,
0x66, 0xe2,
]);
let pk1 = PublicKey::from(&sk1); let pk1 = PublicKey::from(&sk1);
let pk2 = PublicKey::from(&sk2); let pk2 = PublicKey::from(&sk2);
wg1.new_peer(pk2); wg1.new_peer(pk2);
@@ -94,21 +118,79 @@ fn test_pure_wireguard() {
let peer2 = wg1.lookup_peer(&pk2).unwrap(); let peer2 = wg1.lookup_peer(&pk2).unwrap();
let peer1 = wg2.lookup_peer(&pk1).unwrap(); let peer1 = wg2.lookup_peer(&pk1).unwrap();
peer1.router.add_subnet("192.168.2.0".parse().unwrap(), 24); peer1
peer2.router.add_subnet("192.168.1.0".parse().unwrap(), 24); .router
.add_allowed_ips("192.168.1.0".parse().unwrap(), 24);
// set endpoints peer2
.router
.add_allowed_ips("192.168.2.0".parse().unwrap(), 24);
// set endpoint (the other should be learned dynamically)
peer1.router.set_endpoint(dummy::UnitEndpoint::new());
peer2.router.set_endpoint(dummy::UnitEndpoint::new()); peer2.router.set_endpoint(dummy::UnitEndpoint::new());
// create IP packets (causing a new handshake) let num_packets = 20;
let packet_p1_to_p2 = make_packet( // send IP packets (causing a new handshake)
1000,
"192.168.2.20".parse().unwrap(), // src
"192.168.1.10".parse().unwrap(), // dst
);
fake1.write(packet_p1_to_p2); {
let mut packets: Vec<Vec<u8>> = Vec::with_capacity(num_packets);
for id in 0..num_packets {
packets.push(make_packet(
50 + 50 * id as usize, // size
"192.168.1.20".parse().unwrap(), // src
"192.168.2.10".parse().unwrap(), // dst
id as u64, // prng seed
));
}
let mut backup = packets.clone();
while let Some(p) = packets.pop() {
fake1.write(p);
}
wait();
while let Some(p) = backup.pop() {
assert_eq!(
hex::encode(fake2.read()),
hex::encode(p),
"Failed to receive valid IPv4 packet unmodified and in-order"
);
}
}
// send IP packets (other direction)
{
let mut packets: Vec<Vec<u8>> = Vec::with_capacity(num_packets);
for id in 0..num_packets {
packets.push(make_packet(
50 + 50 * id as usize, // size
"192.168.2.10".parse().unwrap(), // src
"192.168.1.20".parse().unwrap(), // dst
(id + 100) as u64, // prng seed
));
}
let mut backup = packets.clone();
while let Some(p) = packets.pop() {
fake2.write(p);
}
wait();
while let Some(p) = backup.pop() {
assert_eq!(
hex::encode(fake1.read()),
hex::encode(p),
"Failed to receive valid IPv4 packet unmodified and in-order"
);
}
}
} }

View File

@@ -7,10 +7,10 @@ use log::info;
use hjul::{Runner, Timer}; use hjul::{Runner, Timer};
use super::{bind, tun};
use super::constants::*; use super::constants::*;
use super::router::{Callbacks, message_data_len}; use super::router::{message_data_len, Callbacks};
use super::wireguard::{Peer, PeerInner}; use super::wireguard::{Peer, PeerInner};
use super::{bind, tun};
pub struct Timers { pub struct Timers {
handshake_pending: AtomicBool, handshake_pending: AtomicBool,
@@ -32,16 +32,20 @@ impl Timers {
} }
} }
impl <B: bind::Bind>PeerInner<B> { impl<B: bind::Bind> PeerInner<B> {
/* should be called after an authenticated data packet is sent */ /* should be called after an authenticated data packet is sent */
pub fn timers_data_sent(&self) { pub fn timers_data_sent(&self) {
self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); self.timers()
.new_handshake
.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
} }
/* should be called after an authenticated data packet is received */ /* should be called after an authenticated data packet is received */
pub fn timers_data_received(&self) { pub fn timers_data_received(&self) {
if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) { if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) {
self.timers().need_another_keepalive.store(true, Ordering::SeqCst) self.timers()
.need_another_keepalive
.store(true, Ordering::SeqCst)
} }
} }
@@ -74,7 +78,9 @@ impl <B: bind::Bind>PeerInner<B> {
*/ */
pub fn timers_handshake_complete(&self) { pub fn timers_handshake_complete(&self) {
self.timers().handshake_attempts.store(0, Ordering::SeqCst); self.timers().handshake_attempts.store(0, Ordering::SeqCst);
self.timers().sent_lastminute_handshake.store(false, Ordering::SeqCst); self.timers()
.sent_lastminute_handshake
.store(false, Ordering::SeqCst);
// TODO: Store time in peer for config // TODO: Store time in peer for config
// self.walltime_last_handshake // self.walltime_last_handshake
} }
@@ -92,7 +98,9 @@ impl <B: bind::Bind>PeerInner<B> {
pub fn timers_any_authenticated_packet_traversal(&self) { pub fn timers_any_authenticated_packet_traversal(&self) {
let keepalive = self.keepalive.load(Ordering::Acquire); let keepalive = self.keepalive.load(Ordering::Acquire);
if keepalive > 0 { if keepalive > 0 {
self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); self.timers()
.send_persistent_keepalive
.reset(Duration::from_secs(keepalive as u64));
} }
} }
@@ -149,11 +157,7 @@ impl Timers {
new_handshake: { new_handshake: {
let peer = peer.clone(); let peer = peer.clone();
runner.timer(move || { runner.timer(move || {
info!( info!("Initiate new handshake with {}", peer);
"Retrying handshake with {}, because we stopped hearing back after {} seconds",
peer,
(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
);
peer.new_handshake(); peer.new_handshake();
peer.timers.read().handshake_begun(); peer.timers.read().handshake_begun();
}) })
@@ -171,10 +175,12 @@ impl Timers {
if keepalive > 0 { if keepalive > 0 {
peer.router.send_keepalive(); peer.router.send_keepalive();
peer.timers().send_keepalive.stop(); peer.timers().send_keepalive.stop();
peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64)); peer.timers()
.send_persistent_keepalive
.start(Duration::from_secs(keepalive as u64));
} }
}) })
} },
} }
} }
@@ -196,7 +202,8 @@ impl Timers {
pub fn updated_persistent_keepalive(&self, keepalive: usize) { pub fn updated_persistent_keepalive(&self, keepalive: usize) {
if keepalive > 0 { if keepalive > 0 {
self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); self.send_persistent_keepalive
.reset(Duration::from_secs(keepalive as u64));
} }
} }
@@ -210,7 +217,7 @@ impl Timers {
new_handshake: runner.timer(|| {}), new_handshake: runner.timer(|| {}),
send_keepalive: runner.timer(|| {}), send_keepalive: runner.timer(|| {}),
send_persistent_keepalive: runner.timer(|| {}), send_persistent_keepalive: runner.timer(|| {}),
zero_key_material: runner.timer(|| {}) zero_key_material: runner.timer(|| {}),
} }
} }

View File

@@ -21,6 +21,7 @@ use std::collections::HashMap;
use log::debug; use log::debug;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rand::Rng;
use spin::{Mutex, RwLock, RwLockReadGuard}; use spin::{Mutex, RwLock, RwLockReadGuard};
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
@@ -37,6 +38,8 @@ pub struct Peer<T: Tun, B: Bind> {
} }
pub struct PeerInner<B: Bind> { pub struct PeerInner<B: Bind> {
pub id: u64,
pub keepalive: AtomicUsize, // keepalive interval pub keepalive: AtomicUsize, // keepalive interval
pub rx_bytes: AtomicU64, pub rx_bytes: AtomicU64,
pub tx_bytes: AtomicU64, pub tx_bytes: AtomicU64,
@@ -50,6 +53,9 @@ pub struct PeerInner<B: Bind> {
} }
pub struct WireguardInner<T: Tun, B: Bind> { pub struct WireguardInner<T: Tun, B: Bind> {
// identifier (for logging)
id: u32,
// provides access to the MTU value of the tun device // provides access to the MTU value of the tun device
// (otherwise owned solely by the router and a dedicated read IO thread) // (otherwise owned solely by the router and a dedicated read IO thread)
mtu: T::MTU, mtu: T::MTU,
@@ -96,7 +102,13 @@ impl<B: Bind> PeerInner<B> {
impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> { impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "peer()") write!(f, "peer(id = {})", self.id)
}
}
impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "wireguard({:x})", self.id)
} }
} }
@@ -209,7 +221,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
pub fn new_peer(&self, pk: PublicKey) { pub fn new_peer(&self, pk: PublicKey) {
let mut rng = OsRng::new().unwrap();
let state = Arc::new(PeerInner { let state = Arc::new(PeerInner {
id: rng.gen(),
pk, pk,
last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), last_handshake: Mutex::new(SystemTime::UNIX_EPOCH),
handshake_queued: AtomicBool::new(false), handshake_queued: AtomicBool::new(false),
@@ -277,11 +291,17 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
handshake::TYPE_COOKIE_REPLY handshake::TYPE_COOKIE_REPLY
| handshake::TYPE_INITIATION | handshake::TYPE_INITIATION
| handshake::TYPE_RESPONSE => { | handshake::TYPE_RESPONSE => {
debug!("{} : reader, received handshake message", wg);
let pending = wg.pending.fetch_add(1, Ordering::SeqCst);
// update under_load flag // update under_load flag
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { if pending > THRESHOLD_UNDER_LOAD {
debug!("{} : reader, set under load (pending = {})", wg, pending);
last_under_load = Instant::now(); last_under_load = Instant::now();
wg.under_load.store(true, Ordering::SeqCst); wg.under_load.store(true, Ordering::SeqCst);
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD { } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
debug!("{} : reader, clear under load", wg);
wg.under_load.store(false, Ordering::SeqCst); wg.under_load.store(false, Ordering::SeqCst);
} }
@@ -291,6 +311,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
.unwrap(); .unwrap();
} }
router::TYPE_TRANSPORT => { router::TYPE_TRANSPORT => {
debug!("{} : reader, received transport message", wg);
// transport message // transport message
let _ = wg.router.recv(src, msg).map_err(|e| { let _ = wg.router.recv(src, msg).map_err(|e| {
debug!("Failed to handle incoming transport message: {}", e); debug!("Failed to handle incoming transport message: {}", e);
@@ -313,6 +335,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner { let wg = Arc::new(WireguardInner {
id: rng.gen(),
mtu: mtu.clone(), mtu: mtu.clone(),
peers: RwLock::new(HashMap::new()), peers: RwLock::new(HashMap::new()),
send: RwLock::new(None), send: RwLock::new(None),
@@ -331,12 +354,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let wg = wg.clone(); let wg = wg.clone();
let rx = rx.clone(); let rx = rx.clone();
thread::spawn(move || { thread::spawn(move || {
debug!("{} : handshake worker, started", wg);
// prepare OsRng instance for this thread // prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue // process elements from the handshake queue
for job in rx { for job in rx {
wg.pending.fetch_sub(1, Ordering::SeqCst);
let state = wg.handshake.read(); let state = wg.handshake.read();
if !state.active { if !state.active {
continue; continue;
@@ -344,6 +368,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
match job { match job {
HandshakeJob::Message(msg, src) => { HandshakeJob::Message(msg, src) => {
wg.pending.fetch_sub(1, Ordering::SeqCst);
// feed message to handshake device // feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid let src_validate = (&src).into_address(); // TODO avoid
@@ -352,6 +378,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
&mut rng, &mut rng,
&msg[..], &msg[..],
if wg.under_load.load(Ordering::Relaxed) { if wg.under_load.load(Ordering::Relaxed) {
debug!("{} : handshake worker, under load", wg);
Some(&src_validate) Some(&src_validate)
} else { } else {
None None
@@ -364,9 +391,14 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
resp_len = msg.len() as u64; resp_len = msg.len() as u64;
let send: &Option<B::Writer> = &*wg.send.read(); let send: &Option<B::Writer> = &*wg.send.read();
if let Some(writer) = send.as_ref() { if let Some(writer) = send.as_ref() {
debug!(
"{} : handshake worker, send response ({} bytes)",
wg, resp_len
);
let _ = writer.write(&msg[..], &src).map_err(|e| { let _ = writer.write(&msg[..], &src).map_err(|e| {
debug!( debug!(
"handshake worker, failed to send response, error = {}", "{} : handshake worker, failed to send response, error = {}",
wg,
e e
) )
}); });
@@ -387,11 +419,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// update timers after sending handshake response // update timers after sending handshake response
if resp_len > 0 { if resp_len > 0 {
debug!("{} : handshake worker, handshake response sent", wg);
peer.state.sent_handshake_response(); peer.state.sent_handshake_response();
} }
// add resulting keypair to peer // add resulting keypair to peer
keypair.map(|kp| { keypair.map(|kp| {
debug!("{} : handshake worker, new keypair", wg);
// free any unused ids // free any unused ids
for id in peer.router.add_keypair(kp) { for id in peer.router.add_keypair(kp) {
state.device.release(id); state.device.release(id);
@@ -400,14 +434,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
} }
} }
Err(e) => debug!("handshake worker, error = {:?}", e), Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e),
} }
} }
HandshakeJob::New(pk) => { HandshakeJob::New(pk) => {
debug!("{} : handshake worker, new handshake requested", wg);
let _ = state.device.begin(&mut rng, &pk).map(|msg| { let _ = state.device.begin(&mut rng, &pk).map(|msg| {
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
let _ = peer.router.send(&msg[..]).map_err(|e| { let _ = peer.router.send(&msg[..]).map_err(|e| {
debug!("handshake worker, failed to send handshake initiation, error = {}", e) debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
}); });
peer.state.sent_handshake_initiation(); peer.state.sent_handshake_initiation();
} }