Added key_confirmed callback

This commit is contained in:
Mathias Hall-Andersen
2019-09-28 18:01:55 +02:00
parent 794933d6dd
commit edfd2f235a
8 changed files with 218 additions and 115 deletions

View File

@@ -11,3 +11,8 @@ pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
pub const MAX_TIMER_HANDSHAKES: usize = 18; pub const MAX_TIMER_HANDSHAKES: usize = 18;
pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200);
pub const TIMERS_TICK: Duration = Duration::from_millis(100);
pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize;
pub const TIMERS_CAPACITY: usize = 1024;

View File

@@ -12,4 +12,7 @@ mod timers;
mod types; mod types;
mod wireguard; mod wireguard;
#[test]
fn test_pure_wireguard() {}
fn main() {} fn main() {}

View File

@@ -60,6 +60,8 @@ pub struct Device<C: Callbacks, T: Tun, B: Bind> {
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> { impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
fn drop(&mut self) { fn drop(&mut self) {
debug!("router: dropping device");
// drop all queues // drop all queues
{ {
let mut queues = self.state.queues.lock(); let mut queues = self.state.queues.lock();
@@ -76,7 +78,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
_ => false, _ => false,
} {} } {}
debug!("device dropped"); debug!("router: device dropped");
} }
} }
@@ -175,7 +177,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?; let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
// schedule for encryption and transmission to peer // schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg) { if let Some(job) = peer.send_job(msg, true) {
debug_assert_eq!(job.1.op, Operation::Encryption); debug_assert_eq!(job.1.op, Operation::Encryption);
// add job to worker queue // add job to worker queue

View File

@@ -217,6 +217,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
fn send_staged(&self) -> bool { fn send_staged(&self) -> bool {
debug!("peer.send_staged");
let mut sent = false; let mut sent = false;
let mut staged = self.staged_packets.lock(); let mut staged = self.staged_packets.lock();
loop { loop {
@@ -230,8 +231,11 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
} }
} }
// Treat the msg as the payload of a transport message
// Unlike device.send, peer.send_raw does not buffer messages when a key is not available.
fn send_raw(&self, msg: Vec<u8>) -> bool { fn send_raw(&self, msg: Vec<u8>) -> bool {
match self.send_job(msg) { debug!("peer.send_raw");
match self.send_job(msg, false) {
Some(job) => { Some(job) => {
debug!("send_raw: got obtained send_job"); debug!("send_raw: got obtained send_job");
let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst); let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst);
@@ -246,30 +250,36 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
} }
pub fn confirm_key(&self, keypair: &Arc<KeyPair>) { pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
// take lock and check keypair = keys.next debug!("peer.confirm_key");
let mut keys = self.keys.lock(); {
let next = match keys.next.as_ref() { // take lock and check keypair = keys.next
Some(next) => next, let mut keys = self.keys.lock();
None => { let next = match keys.next.as_ref() {
Some(next) => next,
None => {
return;
}
};
if !Arc::ptr_eq(&next, keypair) {
return; return;
} }
};
if !Arc::ptr_eq(&next, keypair) { // allocate new encryption state
return; let ekey = Some(EncryptionState::new(&next));
// rotate key-wheel
let mut swap = None;
mem::swap(&mut keys.next, &mut swap);
mem::swap(&mut keys.current, &mut swap);
mem::swap(&mut keys.previous, &mut swap);
// tell the world outside the router that a key was confirmed
C::key_confirmed(&self.opaque);
// set new key for encryption
*self.ekey.lock() = ekey;
} }
// allocate new encryption state
let ekey = Some(EncryptionState::new(&next));
// rotate key-wheel
let mut swap = None;
mem::swap(&mut keys.next, &mut swap);
mem::swap(&mut keys.current, &mut swap);
mem::swap(&mut keys.previous, &mut swap);
// set new encryption key
*self.ekey.lock() = ekey;
// start transmission of staged packets // start transmission of staged packets
self.send_staged(); self.send_staged();
} }
@@ -296,7 +306,8 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
} }
} }
pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> { pub fn send_job(&self, mut msg: Vec<u8>, stage: bool) -> Option<JobParallel> {
debug!("peer.send_job");
debug_assert!( debug_assert!(
msg.len() >= mem::size_of::<TransportHeader>(), msg.len() >= mem::size_of::<TransportHeader>(),
"received message with size: {:}", "received message with size: {:}",
@@ -319,7 +330,6 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
None None
} else { } else {
// there should be no stacked packets lingering around // there should be no stacked packets lingering around
debug_assert_eq!(self.staged_packets.lock().len(), 0);
debug!("encryption state available, nonce = {}", state.nonce); debug!("encryption state available, nonce = {}", state.nonce);
// set transport message fields // set transport message fields
@@ -334,7 +344,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
// If not suitable key was found: // If not suitable key was found:
// 1. Stage packet for later transmission // 1. Stage packet for later transmission
// 2. Request new key // 2. Request new key
if key.is_none() { if key.is_none() && stage {
self.staged_packets.lock().push_back(msg); self.staged_packets.lock().push_back(msg);
C::need_key(&self.opaque); C::need_key(&self.opaque);
return None; return None;
@@ -372,6 +382,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
/// This API still permits support for the "sticky socket" behavior, /// This API still permits support for the "sticky socket" behavior,
/// as sockets should be "unsticked" when manually updating the endpoint /// as sockets should be "unsticked" when manually updating the endpoint
pub fn set_endpoint(&self, address: SocketAddr) { pub fn set_endpoint(&self, address: SocketAddr) {
debug!("peer.set_endpoint");
*self.state.endpoint.lock() = Some(B::Endpoint::from_address(address)); *self.state.endpoint.lock() = Some(B::Endpoint::from_address(address));
} }
@@ -381,6 +392,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
/// ///
/// Does not convey potential "sticky socket" information /// Does not convey potential "sticky socket" information
pub fn get_endpoint(&self) -> Option<SocketAddr> { pub fn get_endpoint(&self) -> Option<SocketAddr> {
debug!("peer.get_endpoint");
self.state self.state
.endpoint .endpoint
.lock() .lock()
@@ -390,6 +402,8 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
/// Zero all key-material related to the peer /// Zero all key-material related to the peer
pub fn zero_keys(&self) { pub fn zero_keys(&self) {
debug!("peer.zero_keys");
let mut release: Vec<u32> = Vec::with_capacity(3); let mut release: Vec<u32> = Vec::with_capacity(3);
let mut keys = self.state.keys.lock(); let mut keys = self.state.keys.lock();
@@ -429,57 +443,74 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
/// since the only way to add additional keys to the peer is by using this method /// since the only way to add additional keys to the peer is by using this method
/// and a peer can have at most 3 keys allocated in the router at any time. /// and a peer can have at most 3 keys allocated in the router at any time.
pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> { pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
let new = Arc::new(new); debug!("peer.add_keypair");
let mut keys = self.state.keys.lock();
let mut release = mem::replace(&mut keys.retired, vec![]);
// update key-wheel let initiator = new.initiator;
if new.initiator { let release = {
// start using key for encryption let new = Arc::new(new);
*self.state.ekey.lock() = Some(EncryptionState::new(&new)); let mut keys = self.state.keys.lock();
let mut release = mem::replace(&mut keys.retired, vec![]);
// move current into previous // update key-wheel
keys.previous = keys.current.as_ref().map(|v| v.clone()); if new.initiator {
keys.current = Some(new.clone()); // start using key for encryption
} else { *self.state.ekey.lock() = Some(EncryptionState::new(&new));
// store the key and await confirmation
keys.previous = keys.next.as_ref().map(|v| v.clone()); // move current into previous
keys.next = Some(new.clone()); keys.previous = keys.current.as_ref().map(|v| v.clone());
keys.current = Some(new.clone());
} else {
// store the key and await confirmation
keys.previous = keys.next.as_ref().map(|v| v.clone());
keys.next = Some(new.clone());
};
// update incoming packet id map
{
debug!("peer.add_keypair: updating inbound id map");
let mut recv = self.state.device.recv.write();
// purge recv map of previous id
keys.previous.as_ref().map(|k| {
recv.remove(&k.local_id());
release.push(k.local_id());
});
// map new id to decryption state
debug_assert!(!recv.contains_key(&new.recv.id));
recv.insert(
new.recv.id,
Arc::new(DecryptionState::new(&self.state, &new)),
);
}
release
}; };
// update incoming packet id map
{
let mut recv = self.state.device.recv.write();
// purge recv map of previous id
keys.previous.as_ref().map(|k| {
recv.remove(&k.local_id());
release.push(k.local_id());
});
// map new id to decryption state
debug_assert!(!recv.contains_key(&new.recv.id));
recv.insert(
new.recv.id,
Arc::new(DecryptionState::new(&self.state, &new)),
);
}
// schedule confirmation // schedule confirmation
if new.initiator { if initiator {
// fall back to keepalive packet debug_assert!(self.state.ekey.lock().is_some());
debug!("peer.add_keypair: is initiator, must confirm the key");
// attempt to confirm using staged packets
if !self.state.send_staged() { if !self.state.send_staged() {
let ok = self.keepalive(); // fall back to keepalive packet
debug!("keepalive for confirmation, sent = {}", ok); let ok = self.send_keepalive();
debug!(
"peer.add_keypair: keepalive for confirmation, sent = {}",
ok
);
} }
debug!("peer.add_keypair: key attempted confirmed");
} }
debug_assert!(release.len() <= 3); debug_assert!(
release.len() <= 3,
"since the key-wheel contains at most 3 keys"
);
release release
} }
pub fn keepalive(&self) -> bool { pub fn send_keepalive(&self) -> bool {
debug!("send keepalive"); debug!("peer.send_keepalive");
self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
} }
@@ -498,6 +529,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
/// If an identical value already exists as part of a prior peer, /// If an identical value already exists as part of a prior peer,
/// the allowed IP entry will be removed from that peer and added to this peer. /// the allowed IP entry will be removed from that peer and added to this peer.
pub fn add_subnet(&self, ip: IpAddr, masklen: u32) { pub fn add_subnet(&self, ip: IpAddr, masklen: u32) {
debug!("peer.add_subnet");
match ip { match ip {
IpAddr::V4(v4) => { IpAddr::V4(v4) => {
self.state self.state
@@ -522,6 +554,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
/// ///
/// A vector of subnets, represented by as mask/size /// A vector of subnets, represented by as mask/size
pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> { pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> {
debug!("peer.list_subnets");
let mut res = Vec::new(); let mut res = Vec::new();
res.append(&mut treebit_list( res.append(&mut treebit_list(
&self.state, &self.state,
@@ -540,6 +573,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
/// After the call, no subnets will be cryptkey routed to the peer. /// After the call, no subnets will be cryptkey routed to the peer.
/// Used for the UAPI command "replace_allowed_ips=true" /// Used for the UAPI command "replace_allowed_ips=true"
pub fn remove_subnets(&self) { pub fn remove_subnets(&self) {
debug!("peer.remove_subnets");
treebit_remove(self, &self.state.device.ipv4); treebit_remove(self, &self.state.device.ipv4);
treebit_remove(self, &self.state.device.ipv6); treebit_remove(self, &self.state.device.ipv6);
} }
@@ -554,6 +588,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
/// ///
/// Unit if packet was sent, or an error indicating why sending failed /// Unit if packet was sent, or an error indicating why sending failed
pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
debug!("peer.send");
let inner = &self.state; let inner = &self.state;
match inner.endpoint.lock().as_ref() { match inner.endpoint.lock().as_ref() {
Some(endpoint) => inner Some(endpoint) => inner

View File

@@ -1,7 +1,7 @@
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::Ordering;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
@@ -228,6 +228,7 @@ mod tests {
send: Mutex<Vec<(usize, bool, bool)>>, send: Mutex<Vec<(usize, bool, bool)>>,
recv: Mutex<Vec<(usize, bool, bool)>>, recv: Mutex<Vec<(usize, bool, bool)>>,
need_key: Mutex<Vec<()>>, need_key: Mutex<Vec<()>>,
key_confirmed: Mutex<Vec<()>>,
} }
#[derive(Clone)] #[derive(Clone)]
@@ -241,6 +242,7 @@ mod tests {
send: Mutex::new(vec![]), send: Mutex::new(vec![]),
recv: Mutex::new(vec![]), recv: Mutex::new(vec![]),
need_key: Mutex::new(vec![]), need_key: Mutex::new(vec![]),
key_confirmed: Mutex::new(vec![]),
})) }))
} }
@@ -248,6 +250,7 @@ mod tests {
self.0.send.lock().unwrap().clear(); self.0.send.lock().unwrap().clear();
self.0.recv.lock().unwrap().clear(); self.0.recv.lock().unwrap().clear();
self.0.need_key.lock().unwrap().clear(); self.0.need_key.lock().unwrap().clear();
self.0.key_confirmed.lock().unwrap().clear();
} }
fn send(&self) -> Option<(usize, bool, bool)> { fn send(&self) -> Option<(usize, bool, bool)> {
@@ -262,11 +265,17 @@ mod tests {
self.0.need_key.lock().unwrap().pop() self.0.need_key.lock().unwrap().pop()
} }
fn key_confirmed(&self) -> Option<()> {
self.0.key_confirmed.lock().unwrap().pop()
}
// has all events been accounted for by assertions?
fn is_empty(&self) -> bool { fn is_empty(&self) -> bool {
let send = self.0.send.lock().unwrap(); let send = self.0.send.lock().unwrap();
let recv = self.0.recv.lock().unwrap(); let recv = self.0.recv.lock().unwrap();
let need_key = self.0.need_key.lock().unwrap(); let need_key = self.0.need_key.lock().unwrap();
send.is_empty() && recv.is_empty() && need_key.is_empty() let key_confirmed = self.0.key_confirmed.lock().unwrap();
send.is_empty() && recv.is_empty() && need_key.is_empty() & key_confirmed.is_empty()
} }
} }
@@ -284,6 +293,15 @@ mod tests {
fn need_key(t: &Self::Opaque) { fn need_key(t: &Self::Opaque) {
t.0.need_key.lock().unwrap().push(()); t.0.need_key.lock().unwrap().push(());
} }
fn key_confirmed(t: &Self::Opaque) {
t.0.key_confirmed.lock().unwrap().push(());
}
}
// wait for scheduling
fn wait() {
thread::sleep(Duration::from_millis(50));
} }
fn init() { fn init() {
@@ -319,6 +337,7 @@ mod tests {
} }
fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
fn need_key(_: &Self::Opaque) {} fn need_key(_: &Self::Opaque) {}
fn key_confirmed(_: &Self::Opaque) {}
} }
// create device // create device
@@ -336,7 +355,7 @@ mod tests {
let ip1: IpAddr = ip.parse().unwrap(); let ip1: IpAddr = ip.parse().unwrap();
peer.add_subnet(mask, len); peer.add_subnet(mask, len);
// every iteration sends 50 GB // every iteration sends 10 GB
b.iter(|| { b.iter(|| {
opaque.store(0, Ordering::SeqCst); opaque.store(0, Ordering::SeqCst);
let msg = make_packet(1024, ip1); let msg = make_packet(1024, ip1);
@@ -400,7 +419,7 @@ mod tests {
let res = router.send(msg); let res = router.send(msg);
// allow some scheduling // allow some scheduling
thread::sleep(Duration::from_millis(20)); wait();
if *okay { if *okay {
// cryptkey routing succeeded // cryptkey routing succeeded
@@ -444,12 +463,8 @@ mod tests {
} }
} }
fn wait() {
thread::sleep(Duration::from_millis(20));
}
#[test] #[test]
fn test_outbound_inbound() { fn test_bidirectional() {
init(); init();
let tests = [ let tests = [
@@ -463,15 +478,42 @@ mod tests {
("192.168.1.0", 24, "192.168.1.20", true), ("192.168.1.0", 24, "192.168.1.20", true),
("172.133.133.133", 32, "172.133.133.133", true), ("172.133.133.133", 32, "172.133.133.133", true),
), ),
(
false, // confirm with keepalive
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:ffff",
true,
),
(
"2001:db8::ff40:42:8000",
113,
"2001:db8::ff40:42:ffff",
true,
),
),
(
false, // confirm with staged packet
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:ffff",
true,
),
(
"2001:db8::ff40:42:8000",
113,
"2001:db8::ff40:42:ffff",
true,
),
),
]; ];
for (stage, p1, p2) in tests.iter() { for (stage, p1, p2) in tests.iter() {
let (bind1, bind2) = bind_pair();
// create matching devices // create matching devices
let (bind1, bind2) = bind_pair();
let router1: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind1.clone()); let router1: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind1.clone());
let router2: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind2.clone()); let router2: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind2.clone());
// prepare opaque values for tracing callbacks // prepare opaque values for tracing callbacks
@@ -519,9 +561,7 @@ mod tests {
wait(); wait();
assert!(opaq2.send().is_some()); assert!(opaq2.send().is_some());
assert!(opaq2.recv().is_none()); assert!(opaq2.is_empty(), "events on peer2 should be 'send'");
assert!(opaq2.need_key().is_none());
assert!(opaq2.is_empty());
assert!(opaq1.is_empty(), "nothing should happened on peer1"); assert!(opaq1.is_empty(), "nothing should happened on peer1");
// read confirming message received by the other end ("across the internet") // read confirming message received by the other end ("across the internet")
@@ -531,14 +571,16 @@ mod tests {
router1.recv(from, buf).unwrap(); router1.recv(from, buf).unwrap();
wait(); wait();
assert!(opaq1.send().is_none());
assert!(opaq1.recv().is_some()); assert!(opaq1.recv().is_some());
assert!(opaq1.need_key().is_none()); assert!(opaq1.key_confirmed().is_some());
assert!(opaq1.is_empty()); assert!(
opaq1.is_empty(),
"events on peer1 should be 'recv' and 'key_confirmed'"
);
assert!(peer1.get_endpoint().is_some()); assert!(peer1.get_endpoint().is_some());
assert!(opaq2.is_empty(), "nothing should happened on peer2"); assert!(opaq2.is_empty(), "nothing should happened on peer2");
// how that peer1 has an endpoint // now that peer1 has an endpoint
// route packets : peer1 -> peer2 // route packets : peer1 -> peer2
for _ in 0..10 { for _ in 0..10 {
@@ -572,8 +614,6 @@ mod tests {
assert!(opaq2.recv().is_some()); assert!(opaq2.recv().is_some());
assert!(opaq2.need_key().is_none()); assert!(opaq2.need_key().is_none());
} }
// route packets : peer2 -> peer1
} }
} }
} }

View File

@@ -23,9 +23,10 @@ impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
pub trait Callbacks: Send + Sync + 'static { pub trait Callbacks: Send + Sync + 'static {
type Opaque: Opaque; type Opaque: Opaque;
fn send(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
fn recv(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
fn need_key(_opaque: &Self::Opaque) {} fn need_key(opaque: &Self::Opaque);
fn key_confirmed(opaque: &Self::Opaque);
} }
#[derive(Debug)] #[derive(Debug)]

View File

@@ -1,5 +1,6 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use hjul::{Runner, Timer}; use hjul::{Runner, Timer};
@@ -7,7 +8,7 @@ use hjul::{Runner, Timer};
use crate::constants::*; use crate::constants::*;
use crate::router::Callbacks; use crate::router::Callbacks;
use crate::types::{Bind, Tun}; use crate::types::{Bind, Tun};
use crate::wireguard::Peer; use crate::wireguard::{Peer, PeerInner};
pub struct Timers { pub struct Timers {
handshake_pending: AtomicBool, handshake_pending: AtomicBool,
@@ -47,7 +48,7 @@ impl Timers {
send_keepalive: { send_keepalive: {
let peer = peer.clone(); let peer = peer.clone();
runner.timer(move || { runner.timer(move || {
peer.router.keepalive(); peer.router.send_keepalive();
let keepalive = peer.keepalive.load(Ordering::Acquire); let keepalive = peer.keepalive.load(Ordering::Acquire);
if keepalive > 0 { if keepalive > 0 {
peer.timers peer.timers
@@ -103,21 +104,26 @@ impl Timers {
pub struct Events<T, B>(PhantomData<(T, B)>); pub struct Events<T, B>(PhantomData<(T, B)>);
impl<T: Tun, B: Bind> Callbacks for Events<T, B> { impl<T: Tun, B: Bind> Callbacks for Events<T, B> {
type Opaque = Peer<T, B>; type Opaque = Arc<PeerInner<B>>;
fn send(peer: &Peer<T, B>, size: usize, data: bool, sent: bool) { fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed); peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
} }
fn recv(peer: &Peer<T, B>, size: usize, data: bool, sent: bool) { fn recv(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed); peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
} }
fn need_key(peer: &Peer<T, B>) { fn need_key(peer: &Self::Opaque) {
let timers = peer.timers.read(); let timers = peer.timers.read();
if !timers.handshake_pending.swap(true, Ordering::SeqCst) { if !timers.handshake_pending.swap(true, Ordering::SeqCst) {
timers.handshake_attempts.store(0, Ordering::SeqCst); timers.handshake_attempts.store(0, Ordering::SeqCst);
timers.new_handshake.fire(); timers.new_handshake.fire();
} }
} }
fn key_confirmed(peer: &Self::Opaque) {
let timers = peer.timers.read();
timers.retransmit_handshake.stop();
}
} }

View File

@@ -1,8 +1,12 @@
use crate::constants::*;
use crate::handshake; use crate::handshake;
use crate::router; use crate::router;
use crate::timers::{Events, Timers}; use crate::timers::{Events, Timers};
use crate::types::{Bind, Endpoint, Tun}; use crate::types::{Bind, Endpoint, Tun};
use hjul::Runner;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::thread;
@@ -22,28 +26,32 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
pub type Peer<T: Tun, B: Bind> = Arc<PeerInner<T, B>>; #[derive(Clone)]
pub struct Peer<T: Tun, B: Bind> {
pub router: Arc<router::Peer<Events<T, B>, T, B>>,
pub state: Arc<PeerInner<B>>,
}
pub struct PeerInner<T: Tun, B: Bind> { pub struct PeerInner<B: Bind> {
pub keepalive: AtomicUsize, // keepalive interval pub keepalive: AtomicUsize, // keepalive interval
pub rx_bytes: AtomicU64, pub rx_bytes: AtomicU64,
pub tx_bytes: AtomicU64, pub tx_bytes: AtomicU64,
pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this.
pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
pub router: router::Peer<Events<T, B>, T, B>, // router peer pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this.
pub timers: RwLock<Timers>, // pub timers: RwLock<Timers>, //
} }
impl<T: Tun, B: Bind> PeerInner<T, B> { impl<T: Tun, B: Bind> Deref for Peer<T, B> {
pub fn new_handshake(&self) { type Target = PeerInner<B>;
self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); fn deref(&self) -> &Self::Target {
&self.state
} }
} }
macro_rules! timers { impl<B: Bind> PeerInner<B> {
($peer:expr) => { pub fn new_handshake(&self) {
$peer.timers.read() self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
}; }
} }
struct Handshake { struct Handshake {
@@ -74,6 +82,7 @@ struct WireguardInner<T: Tun, B: Bind> {
} }
pub struct Wireguard<T: Tun, B: Bind> { pub struct Wireguard<T: Tun, B: Bind> {
runner: Runner,
state: Arc<WireguardInner<T, B>>, state: Arc<WireguardInner<T, B>>,
} }
@@ -93,19 +102,18 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
} }
/*
fn new_peer(&self, pk: PublicKey) -> Peer<T, B> { fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
let router = self.state.router.new_peer(); let state = Arc::new(PeerInner {
Arc::new(PeerInner {
pk, pk,
queue: Mutex::new(self.state.queue.lock().clone()), queue: Mutex::new(self.state.queue.lock().clone()),
keepalive: AtomicUsize::new(0), keepalive: AtomicUsize::new(0),
rx_bytes: AtomicU64::new(0), rx_bytes: AtomicU64::new(0),
tx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0),
}) timers: RwLock::new(Timers::dummy(&self.runner)),
});
let router = Arc::new(self.state.router.new_peer(state.clone()));
Peer { router, state }
} }
*/
fn new(tun: T, bind: B) -> Wireguard<T, B> { fn new(tun: T, bind: B) -> Wireguard<T, B> {
// create device state // create device state
@@ -189,7 +197,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
peer.router.send(&msg[..]); peer.router.send(&msg[..]);
timers!(peer).handshake_sent(); peer.timers.read().handshake_sent();
} }
} }
} }
@@ -270,6 +278,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}); });
} }
Wireguard { state: wg } Wireguard {
state: wg,
runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),
}
} }
} }