First full test of pure WireGuard
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -1608,6 +1608,8 @@ dependencies = [
|
||||
"pnet 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"ring 0.16.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"spin 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"subtle 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
|
||||
@@ -46,3 +46,5 @@ features = ["nightly"]
|
||||
[dev-dependencies]
|
||||
proptest = "0.9.4"
|
||||
pnet = "^0.22"
|
||||
rand_chacha = "0.2.1"
|
||||
rand_core = "0.5"
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
use hex;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::marker;
|
||||
|
||||
use log::debug;
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
|
||||
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
@@ -95,6 +100,7 @@ impl VoidBind {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PairReader<E> {
|
||||
id: u32,
|
||||
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
|
||||
_marker: marker::PhantomData<E>,
|
||||
}
|
||||
@@ -110,13 +116,25 @@ impl Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
|
||||
.map_err(|_| BindError::Disconnected)?;
|
||||
let len = vec.len();
|
||||
buf[..len].copy_from_slice(&vec[..]);
|
||||
Ok((vec.len(), UnitEndpoint {}))
|
||||
debug!(
|
||||
"dummy({}): read ({}, {})",
|
||||
self.id,
|
||||
len,
|
||||
hex::encode(&buf[..len])
|
||||
);
|
||||
Ok((len, UnitEndpoint {}))
|
||||
}
|
||||
}
|
||||
|
||||
impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
|
||||
type Error = BindError;
|
||||
fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
|
||||
debug!(
|
||||
"dummy({}): write ({}, {})",
|
||||
self.id,
|
||||
buf.len(),
|
||||
hex::encode(buf)
|
||||
);
|
||||
let owned = buf.to_owned();
|
||||
match self.send.lock().unwrap().send(owned) {
|
||||
Err(_) => Err(BindError::Disconnected),
|
||||
@@ -127,6 +145,7 @@ impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PairWriter<E> {
|
||||
id: u32,
|
||||
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
|
||||
_marker: marker::PhantomData<E>,
|
||||
}
|
||||
@@ -139,25 +158,33 @@ impl PairBind {
|
||||
(PairReader<E>, PairWriter<E>),
|
||||
(PairReader<E>, PairWriter<E>),
|
||||
) {
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let id1: u32 = rng.gen();
|
||||
let id2: u32 = rng.gen();
|
||||
|
||||
let (tx1, rx1) = sync_channel(128);
|
||||
let (tx2, rx2) = sync_channel(128);
|
||||
(
|
||||
(
|
||||
PairReader {
|
||||
id: id1,
|
||||
recv: Arc::new(Mutex::new(rx1)),
|
||||
_marker: marker::PhantomData,
|
||||
},
|
||||
PairWriter {
|
||||
id: id1,
|
||||
send: Arc::new(Mutex::new(tx2)),
|
||||
_marker: marker::PhantomData,
|
||||
},
|
||||
),
|
||||
(
|
||||
PairReader {
|
||||
id: id2,
|
||||
recv: Arc::new(Mutex::new(rx2)),
|
||||
_marker: marker::PhantomData,
|
||||
},
|
||||
PairWriter {
|
||||
id: id2,
|
||||
send: Arc::new(Mutex::new(tx1)),
|
||||
_marker: marker::PhantomData,
|
||||
},
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
use hex;
|
||||
use log::debug;
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
|
||||
use std::cmp::min;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
@@ -61,16 +66,19 @@ impl fmt::Display for TunError {
|
||||
pub struct TunTest {}
|
||||
|
||||
pub struct TunFakeIO {
|
||||
id: u32,
|
||||
store: bool,
|
||||
tx: SyncSender<Vec<u8>>,
|
||||
rx: Receiver<Vec<u8>>,
|
||||
}
|
||||
|
||||
pub struct TunReader {
|
||||
id: u32,
|
||||
rx: Receiver<Vec<u8>>,
|
||||
}
|
||||
|
||||
pub struct TunWriter {
|
||||
id: u32,
|
||||
store: bool,
|
||||
tx: Mutex<SyncSender<Vec<u8>>>,
|
||||
}
|
||||
@@ -88,6 +96,12 @@ impl Reader for TunReader {
|
||||
Ok(msg) => {
|
||||
let n = min(buf.len() - offset, msg.len());
|
||||
buf[offset..offset + n].copy_from_slice(&msg[..n]);
|
||||
debug!(
|
||||
"dummy::TUN({}) : read ({}, {})",
|
||||
self.id,
|
||||
n,
|
||||
hex::encode(&buf[offset..offset + n])
|
||||
);
|
||||
Ok(n)
|
||||
}
|
||||
Err(_) => Err(TunError::Disconnected),
|
||||
@@ -99,6 +113,12 @@ impl Writer for TunWriter {
|
||||
type Error = TunError;
|
||||
|
||||
fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
|
||||
debug!(
|
||||
"dummy::TUN({}) : write ({}, {})",
|
||||
self.id,
|
||||
src.len(),
|
||||
hex::encode(src)
|
||||
);
|
||||
if self.store {
|
||||
let m = src.to_owned();
|
||||
match self.tx.lock().unwrap().send(m) {
|
||||
@@ -149,13 +169,18 @@ impl TunTest {
|
||||
sync_channel(1)
|
||||
};
|
||||
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let id: u32 = rng.gen();
|
||||
|
||||
let fake = TunFakeIO {
|
||||
id,
|
||||
tx: tx1,
|
||||
rx: rx2,
|
||||
store,
|
||||
};
|
||||
let reader = TunReader { rx: rx1 };
|
||||
let reader = TunReader { id, rx: rx1 };
|
||||
let writer = TunWriter {
|
||||
id,
|
||||
tx: Mutex::new(tx2),
|
||||
store,
|
||||
};
|
||||
|
||||
@@ -12,6 +12,8 @@ use chacha20poly1305::ChaCha20Poly1305;
|
||||
|
||||
use rand::{CryptoRng, RngCore};
|
||||
|
||||
use log::debug;
|
||||
|
||||
use generic_array::typenum::*;
|
||||
use generic_array::*;
|
||||
|
||||
@@ -27,7 +29,7 @@ use super::peer::{Peer, State};
|
||||
use super::timestamp;
|
||||
use super::types::*;
|
||||
|
||||
use super::super::types::{KeyPair, Key};
|
||||
use super::super::types::{Key, KeyPair};
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
@@ -222,6 +224,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
|
||||
sender: u32,
|
||||
msg: &mut NoiseInitiation,
|
||||
) -> Result<(), HandshakeError> {
|
||||
debug!("create initation");
|
||||
clear_stack_on_return(CLEAR_PAGES, || {
|
||||
// initialize state
|
||||
|
||||
@@ -300,6 +303,7 @@ pub fn consume_initiation<'a>(
|
||||
device: &'a Device,
|
||||
msg: &NoiseInitiation,
|
||||
) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
|
||||
debug!("consume initation");
|
||||
clear_stack_on_return(CLEAR_PAGES, || {
|
||||
// initialize new state
|
||||
|
||||
@@ -377,6 +381,7 @@ pub fn create_response<R: RngCore + CryptoRng>(
|
||||
state: TemporaryState, // state from "consume_initiation"
|
||||
msg: &mut NoiseResponse, // resulting response
|
||||
) -> Result<KeyPair, HandshakeError> {
|
||||
debug!("create response");
|
||||
clear_stack_on_return(CLEAR_PAGES, || {
|
||||
// unpack state
|
||||
|
||||
@@ -457,6 +462,7 @@ pub fn create_response<R: RngCore + CryptoRng>(
|
||||
* in order to better mitigate DoS from malformed response messages.
|
||||
*/
|
||||
pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
|
||||
debug!("consume response");
|
||||
clear_stack_on_return(CLEAR_PAGES, || {
|
||||
// retrieve peer and copy initiation state
|
||||
let peer = device.lookup_id(msg.f_receiver.get())?;
|
||||
|
||||
@@ -89,13 +89,7 @@ fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: &Arc<DeviceInner<E, C, T, B>>,
|
||||
packet: &[u8],
|
||||
) -> Option<Arc<PeerInner<E, C, T, B>>> {
|
||||
// ensure version access within bounds
|
||||
if packet.len() < 1 {
|
||||
return None;
|
||||
};
|
||||
|
||||
// cast to correct IP header
|
||||
match packet[0] >> 4 {
|
||||
match packet.get(0)? >> 4 {
|
||||
VERSION_IP4 => {
|
||||
// check length and cast to IPv4 header
|
||||
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
|
||||
@@ -176,7 +170,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C,
|
||||
let packet = &msg[SIZE_MESSAGE_PREFIX..];
|
||||
|
||||
// lookup peer based on IP packet destination address
|
||||
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
|
||||
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?;
|
||||
|
||||
// schedule for encryption and transmission to peer
|
||||
if let Some(job) = peer.send_job(msg, true) {
|
||||
|
||||
@@ -531,8 +531,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
|
||||
///
|
||||
/// If an identical value already exists as part of a prior peer,
|
||||
/// the allowed IP entry will be removed from that peer and added to this peer.
|
||||
pub fn add_subnet(&self, ip: IpAddr, masklen: u32) {
|
||||
debug!("peer.add_subnet");
|
||||
pub fn add_allowed_ips(&self, ip: IpAddr, masklen: u32) {
|
||||
debug!("peer.add_allowed_ips");
|
||||
match ip {
|
||||
IpAddr::V4(v4) => {
|
||||
self.state
|
||||
@@ -556,8 +556,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of subnets, represented by as mask/size
|
||||
pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> {
|
||||
debug!("peer.list_subnets");
|
||||
pub fn list_allowed_ips(&self) -> Vec<(IpAddr, u32)> {
|
||||
debug!("peer.list_allowed_ips");
|
||||
let mut res = Vec::new();
|
||||
res.append(&mut treebit_list(
|
||||
&self.state,
|
||||
@@ -575,8 +575,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
|
||||
/// Clear subnets mapped to the peer.
|
||||
/// After the call, no subnets will be cryptkey routed to the peer.
|
||||
/// Used for the UAPI command "replace_allowed_ips=true"
|
||||
pub fn remove_subnets(&self) {
|
||||
debug!("peer.remove_subnets");
|
||||
pub fn remove_allowed_ips(&self) {
|
||||
debug!("peer.remove_allowed_ips");
|
||||
treebit_remove(self, &self.state.device.ipv4);
|
||||
treebit_remove(self, &self.state.device.ipv6);
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ mod tests {
|
||||
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
|
||||
let mask: IpAddr = mask.parse().unwrap();
|
||||
let ip1: IpAddr = ip.parse().unwrap();
|
||||
peer.add_subnet(mask, len);
|
||||
peer.add_allowed_ips(mask, len);
|
||||
|
||||
// every iteration sends 10 GB
|
||||
b.iter(|| {
|
||||
@@ -215,7 +215,7 @@ mod tests {
|
||||
}
|
||||
|
||||
// map subnet to peer
|
||||
peer.add_subnet(mask, *len);
|
||||
peer.add_allowed_ips(mask, *len);
|
||||
|
||||
// create "IP packet"
|
||||
let msg = make_packet(1024, ip.parse().unwrap());
|
||||
@@ -339,13 +339,13 @@ mod tests {
|
||||
let (mask, len, _ip, _okay) = p1;
|
||||
let peer1 = router1.new_peer(opaq1.clone());
|
||||
let mask: IpAddr = mask.parse().unwrap();
|
||||
peer1.add_subnet(mask, *len);
|
||||
peer1.add_allowed_ips(mask, *len);
|
||||
peer1.add_keypair(dummy_keypair(false));
|
||||
|
||||
let (mask, len, _ip, _okay) = p2;
|
||||
let peer2 = router2.new_peer(opaq2.clone());
|
||||
let mask: IpAddr = mask.parse().unwrap();
|
||||
peer2.add_subnet(mask, *len);
|
||||
peer2.add_allowed_ips(mask, *len);
|
||||
peer2.set_endpoint(dummy::UnitEndpoint::new());
|
||||
|
||||
if *stage {
|
||||
|
||||
@@ -31,7 +31,7 @@ pub trait Callbacks: Send + Sync + 'static {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum RouterError {
|
||||
NoCryptKeyRoute,
|
||||
NoCryptoKeyRoute,
|
||||
MalformedIPHeader,
|
||||
MalformedTransportMessage,
|
||||
UnknownReceiverId,
|
||||
@@ -42,7 +42,7 @@ pub enum RouterError {
|
||||
impl fmt::Display for RouterError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
|
||||
RouterError::NoCryptoKeyRoute => write!(f, "No cryptokey route configured for subnet"),
|
||||
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
|
||||
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
|
||||
RouterError::UnknownReceiverId => {
|
||||
|
||||
@@ -5,13 +5,23 @@ use std::net::IpAddr;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use rand::rngs::OsRng;
|
||||
use hex;
|
||||
|
||||
use rand_chacha::ChaCha8Rng;
|
||||
use rand_core::{RngCore, SeedableRng};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
use pnet::packet::ipv4::MutableIpv4Packet;
|
||||
use pnet::packet::ipv6::MutableIpv6Packet;
|
||||
|
||||
fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec<u8> {
|
||||
fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
|
||||
// expand pseudo random payload
|
||||
let mut rng: _ = ChaCha8Rng::seed_from_u64(id);
|
||||
let mut p: Vec<u8> = vec![];
|
||||
for _ in 0..size {
|
||||
p.push(rng.next_u32() as u8);
|
||||
}
|
||||
|
||||
// create "IP packet"
|
||||
let mut msg = Vec::with_capacity(size);
|
||||
msg.resize(size, 0);
|
||||
@@ -19,21 +29,25 @@ fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec<u8> {
|
||||
IpAddr::V4(dst) => {
|
||||
let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap();
|
||||
packet.set_destination(dst);
|
||||
packet.set_total_length(size as u16);
|
||||
packet.set_source(if let IpAddr::V4(src) = src {
|
||||
src
|
||||
} else {
|
||||
panic!("src.version != dst.version")
|
||||
});
|
||||
packet.set_payload(&p[..]);
|
||||
packet.set_version(4);
|
||||
}
|
||||
IpAddr::V6(dst) => {
|
||||
let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap();
|
||||
packet.set_destination(dst);
|
||||
packet.set_payload_length((size - MutableIpv6Packet::minimum_packet_size()) as u16);
|
||||
packet.set_source(if let IpAddr::V6(src) = src {
|
||||
src
|
||||
} else {
|
||||
panic!("src.version != dst.version")
|
||||
});
|
||||
packet.set_payload(&p[..]);
|
||||
packet.set_version(6);
|
||||
}
|
||||
}
|
||||
@@ -55,7 +69,7 @@ fn wait() {
|
||||
fn test_pure_wireguard() {
|
||||
init();
|
||||
|
||||
// create WG instances for fake TUN devices
|
||||
// create WG instances for dummy TUN devices
|
||||
|
||||
let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true);
|
||||
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
|
||||
@@ -77,10 +91,20 @@ fn test_pure_wireguard() {
|
||||
|
||||
// generate (public, pivate) key pairs
|
||||
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let sk1 = StaticSecret::new(&mut rng);
|
||||
let sk2 = StaticSecret::new(&mut rng);
|
||||
let sk1 = StaticSecret::from([
|
||||
0x3f, 0x69, 0x86, 0xd1, 0xc0, 0xec, 0x25, 0xa0, 0x9c, 0x8e, 0x56, 0xb5, 0x1d, 0xb7, 0x3c,
|
||||
0xed, 0x56, 0x8e, 0x59, 0x9d, 0xd9, 0xc3, 0x98, 0x67, 0x74, 0x69, 0x90, 0xc3, 0x43, 0x36,
|
||||
0x78, 0x89,
|
||||
]);
|
||||
|
||||
let sk2 = StaticSecret::from([
|
||||
0xfb, 0xd1, 0xd6, 0xe4, 0x65, 0x06, 0xd2, 0xe5, 0xc5, 0xdf, 0x6e, 0xab, 0x51, 0x71, 0xd8,
|
||||
0x70, 0xb5, 0xb7, 0x77, 0x51, 0xb4, 0xbe, 0xfb, 0xbc, 0x88, 0x62, 0x40, 0xca, 0x2c, 0xc2,
|
||||
0x66, 0xe2,
|
||||
]);
|
||||
|
||||
let pk1 = PublicKey::from(&sk1);
|
||||
|
||||
let pk2 = PublicKey::from(&sk2);
|
||||
|
||||
wg1.new_peer(pk2);
|
||||
@@ -94,21 +118,79 @@ fn test_pure_wireguard() {
|
||||
let peer2 = wg1.lookup_peer(&pk2).unwrap();
|
||||
let peer1 = wg2.lookup_peer(&pk1).unwrap();
|
||||
|
||||
peer1.router.add_subnet("192.168.2.0".parse().unwrap(), 24);
|
||||
peer2.router.add_subnet("192.168.1.0".parse().unwrap(), 24);
|
||||
peer1
|
||||
.router
|
||||
.add_allowed_ips("192.168.1.0".parse().unwrap(), 24);
|
||||
|
||||
// set endpoints
|
||||
peer2
|
||||
.router
|
||||
.add_allowed_ips("192.168.2.0".parse().unwrap(), 24);
|
||||
|
||||
// set endpoint (the other should be learned dynamically)
|
||||
|
||||
peer1.router.set_endpoint(dummy::UnitEndpoint::new());
|
||||
peer2.router.set_endpoint(dummy::UnitEndpoint::new());
|
||||
|
||||
// create IP packets (causing a new handshake)
|
||||
let num_packets = 20;
|
||||
|
||||
let packet_p1_to_p2 = make_packet(
|
||||
1000,
|
||||
"192.168.2.20".parse().unwrap(), // src
|
||||
"192.168.1.10".parse().unwrap(), // dst
|
||||
// send IP packets (causing a new handshake)
|
||||
|
||||
{
|
||||
let mut packets: Vec<Vec<u8>> = Vec::with_capacity(num_packets);
|
||||
|
||||
for id in 0..num_packets {
|
||||
packets.push(make_packet(
|
||||
50 + 50 * id as usize, // size
|
||||
"192.168.1.20".parse().unwrap(), // src
|
||||
"192.168.2.10".parse().unwrap(), // dst
|
||||
id as u64, // prng seed
|
||||
));
|
||||
}
|
||||
|
||||
let mut backup = packets.clone();
|
||||
|
||||
while let Some(p) = packets.pop() {
|
||||
fake1.write(p);
|
||||
}
|
||||
|
||||
wait();
|
||||
|
||||
while let Some(p) = backup.pop() {
|
||||
assert_eq!(
|
||||
hex::encode(fake2.read()),
|
||||
hex::encode(p),
|
||||
"Failed to receive valid IPv4 packet unmodified and in-order"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fake1.write(packet_p1_to_p2);
|
||||
// send IP packets (other direction)
|
||||
|
||||
{
|
||||
let mut packets: Vec<Vec<u8>> = Vec::with_capacity(num_packets);
|
||||
|
||||
for id in 0..num_packets {
|
||||
packets.push(make_packet(
|
||||
50 + 50 * id as usize, // size
|
||||
"192.168.2.10".parse().unwrap(), // src
|
||||
"192.168.1.20".parse().unwrap(), // dst
|
||||
(id + 100) as u64, // prng seed
|
||||
));
|
||||
}
|
||||
|
||||
let mut backup = packets.clone();
|
||||
|
||||
while let Some(p) = packets.pop() {
|
||||
fake2.write(p);
|
||||
}
|
||||
|
||||
wait();
|
||||
|
||||
while let Some(p) = backup.pop() {
|
||||
assert_eq!(
|
||||
hex::encode(fake1.read()),
|
||||
hex::encode(p),
|
||||
"Failed to receive valid IPv4 packet unmodified and in-order"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,10 @@ use log::info;
|
||||
|
||||
use hjul::{Runner, Timer};
|
||||
|
||||
use super::{bind, tun};
|
||||
use super::constants::*;
|
||||
use super::router::{Callbacks, message_data_len};
|
||||
use super::router::{message_data_len, Callbacks};
|
||||
use super::wireguard::{Peer, PeerInner};
|
||||
use super::{bind, tun};
|
||||
|
||||
pub struct Timers {
|
||||
handshake_pending: AtomicBool,
|
||||
@@ -32,16 +32,20 @@ impl Timers {
|
||||
}
|
||||
}
|
||||
|
||||
impl <B: bind::Bind>PeerInner<B> {
|
||||
impl<B: bind::Bind> PeerInner<B> {
|
||||
/* should be called after an authenticated data packet is sent */
|
||||
pub fn timers_data_sent(&self) {
|
||||
self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
|
||||
self.timers()
|
||||
.new_handshake
|
||||
.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
|
||||
}
|
||||
|
||||
/* should be called after an authenticated data packet is received */
|
||||
pub fn timers_data_received(&self) {
|
||||
if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) {
|
||||
self.timers().need_another_keepalive.store(true, Ordering::SeqCst)
|
||||
self.timers()
|
||||
.need_another_keepalive
|
||||
.store(true, Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +78,9 @@ impl <B: bind::Bind>PeerInner<B> {
|
||||
*/
|
||||
pub fn timers_handshake_complete(&self) {
|
||||
self.timers().handshake_attempts.store(0, Ordering::SeqCst);
|
||||
self.timers().sent_lastminute_handshake.store(false, Ordering::SeqCst);
|
||||
self.timers()
|
||||
.sent_lastminute_handshake
|
||||
.store(false, Ordering::SeqCst);
|
||||
// TODO: Store time in peer for config
|
||||
// self.walltime_last_handshake
|
||||
}
|
||||
@@ -92,7 +98,9 @@ impl <B: bind::Bind>PeerInner<B> {
|
||||
pub fn timers_any_authenticated_packet_traversal(&self) {
|
||||
let keepalive = self.keepalive.load(Ordering::Acquire);
|
||||
if keepalive > 0 {
|
||||
self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
|
||||
self.timers()
|
||||
.send_persistent_keepalive
|
||||
.reset(Duration::from_secs(keepalive as u64));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,11 +157,7 @@ impl Timers {
|
||||
new_handshake: {
|
||||
let peer = peer.clone();
|
||||
runner.timer(move || {
|
||||
info!(
|
||||
"Retrying handshake with {}, because we stopped hearing back after {} seconds",
|
||||
peer,
|
||||
(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
|
||||
);
|
||||
info!("Initiate new handshake with {}", peer);
|
||||
peer.new_handshake();
|
||||
peer.timers.read().handshake_begun();
|
||||
})
|
||||
@@ -171,10 +175,12 @@ impl Timers {
|
||||
if keepalive > 0 {
|
||||
peer.router.send_keepalive();
|
||||
peer.timers().send_keepalive.stop();
|
||||
peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64));
|
||||
peer.timers()
|
||||
.send_persistent_keepalive
|
||||
.start(Duration::from_secs(keepalive as u64));
|
||||
}
|
||||
})
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,7 +202,8 @@ impl Timers {
|
||||
|
||||
pub fn updated_persistent_keepalive(&self, keepalive: usize) {
|
||||
if keepalive > 0 {
|
||||
self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
|
||||
self.send_persistent_keepalive
|
||||
.reset(Duration::from_secs(keepalive as u64));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,7 +217,7 @@ impl Timers {
|
||||
new_handshake: runner.timer(|| {}),
|
||||
send_keepalive: runner.timer(|| {}),
|
||||
send_persistent_keepalive: runner.timer(|| {}),
|
||||
zero_key_material: runner.timer(|| {})
|
||||
zero_key_material: runner.timer(|| {}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ use std::collections::HashMap;
|
||||
|
||||
use log::debug;
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
use spin::{Mutex, RwLock, RwLockReadGuard};
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
@@ -37,6 +38,8 @@ pub struct Peer<T: Tun, B: Bind> {
|
||||
}
|
||||
|
||||
pub struct PeerInner<B: Bind> {
|
||||
pub id: u64,
|
||||
|
||||
pub keepalive: AtomicUsize, // keepalive interval
|
||||
pub rx_bytes: AtomicU64,
|
||||
pub tx_bytes: AtomicU64,
|
||||
@@ -50,6 +53,9 @@ pub struct PeerInner<B: Bind> {
|
||||
}
|
||||
|
||||
pub struct WireguardInner<T: Tun, B: Bind> {
|
||||
// identifier (for logging)
|
||||
id: u32,
|
||||
|
||||
// provides access to the MTU value of the tun device
|
||||
// (otherwise owned solely by the router and a dedicated read IO thread)
|
||||
mtu: T::MTU,
|
||||
@@ -96,7 +102,13 @@ impl<B: Bind> PeerInner<B> {
|
||||
|
||||
impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "peer()")
|
||||
write!(f, "peer(id = {})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "wireguard({:x})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,7 +221,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
}
|
||||
|
||||
pub fn new_peer(&self, pk: PublicKey) {
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let state = Arc::new(PeerInner {
|
||||
id: rng.gen(),
|
||||
pk,
|
||||
last_handshake: Mutex::new(SystemTime::UNIX_EPOCH),
|
||||
handshake_queued: AtomicBool::new(false),
|
||||
@@ -277,11 +291,17 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
handshake::TYPE_COOKIE_REPLY
|
||||
| handshake::TYPE_INITIATION
|
||||
| handshake::TYPE_RESPONSE => {
|
||||
debug!("{} : reader, received handshake message", wg);
|
||||
|
||||
let pending = wg.pending.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
// update under_load flag
|
||||
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
|
||||
if pending > THRESHOLD_UNDER_LOAD {
|
||||
debug!("{} : reader, set under load (pending = {})", wg, pending);
|
||||
last_under_load = Instant::now();
|
||||
wg.under_load.store(true, Ordering::SeqCst);
|
||||
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
|
||||
debug!("{} : reader, clear under load", wg);
|
||||
wg.under_load.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
@@ -291,6 +311,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
.unwrap();
|
||||
}
|
||||
router::TYPE_TRANSPORT => {
|
||||
debug!("{} : reader, received transport message", wg);
|
||||
|
||||
// transport message
|
||||
let _ = wg.router.recv(src, msg).map_err(|e| {
|
||||
debug!("Failed to handle incoming transport message: {}", e);
|
||||
@@ -313,6 +335,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||
let wg = Arc::new(WireguardInner {
|
||||
id: rng.gen(),
|
||||
mtu: mtu.clone(),
|
||||
peers: RwLock::new(HashMap::new()),
|
||||
send: RwLock::new(None),
|
||||
@@ -331,12 +354,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
let wg = wg.clone();
|
||||
let rx = rx.clone();
|
||||
thread::spawn(move || {
|
||||
debug!("{} : handshake worker, started", wg);
|
||||
|
||||
// prepare OsRng instance for this thread
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
|
||||
// process elements from the handshake queue
|
||||
for job in rx {
|
||||
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||
let state = wg.handshake.read();
|
||||
if !state.active {
|
||||
continue;
|
||||
@@ -344,6 +368,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
|
||||
match job {
|
||||
HandshakeJob::Message(msg, src) => {
|
||||
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||
|
||||
// feed message to handshake device
|
||||
let src_validate = (&src).into_address(); // TODO avoid
|
||||
|
||||
@@ -352,6 +378,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
&mut rng,
|
||||
&msg[..],
|
||||
if wg.under_load.load(Ordering::Relaxed) {
|
||||
debug!("{} : handshake worker, under load", wg);
|
||||
Some(&src_validate)
|
||||
} else {
|
||||
None
|
||||
@@ -364,9 +391,14 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
resp_len = msg.len() as u64;
|
||||
let send: &Option<B::Writer> = &*wg.send.read();
|
||||
if let Some(writer) = send.as_ref() {
|
||||
debug!(
|
||||
"{} : handshake worker, send response ({} bytes)",
|
||||
wg, resp_len
|
||||
);
|
||||
let _ = writer.write(&msg[..], &src).map_err(|e| {
|
||||
debug!(
|
||||
"handshake worker, failed to send response, error = {}",
|
||||
"{} : handshake worker, failed to send response, error = {}",
|
||||
wg,
|
||||
e
|
||||
)
|
||||
});
|
||||
@@ -387,11 +419,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
|
||||
// update timers after sending handshake response
|
||||
if resp_len > 0 {
|
||||
debug!("{} : handshake worker, handshake response sent", wg);
|
||||
peer.state.sent_handshake_response();
|
||||
}
|
||||
|
||||
// add resulting keypair to peer
|
||||
keypair.map(|kp| {
|
||||
debug!("{} : handshake worker, new keypair", wg);
|
||||
// free any unused ids
|
||||
for id in peer.router.add_keypair(kp) {
|
||||
state.device.release(id);
|
||||
@@ -400,14 +434,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => debug!("handshake worker, error = {:?}", e),
|
||||
Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e),
|
||||
}
|
||||
}
|
||||
HandshakeJob::New(pk) => {
|
||||
debug!("{} : handshake worker, new handshake requested", wg);
|
||||
let _ = state.device.begin(&mut rng, &pk).map(|msg| {
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
let _ = peer.router.send(&msg[..]).map_err(|e| {
|
||||
debug!("handshake worker, failed to send handshake initiation, error = {}", e)
|
||||
debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
|
||||
});
|
||||
peer.state.sent_handshake_initiation();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user