Refactoring timer code:

- Remove the Events struct
- Implement Callbacks on the PeerInner, elimiting an Arc.
This commit is contained in:
Mathias Hall-Andersen
2020-05-10 21:23:34 +02:00
parent 985fd088f8
commit 6c386146a7
9 changed files with 186 additions and 175 deletions

View File

@@ -310,25 +310,25 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) {
if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
peer.router.set_endpoint(B::Endpoint::from_address(addr));
peer.set_endpoint(B::Endpoint::from_address(addr));
}
}
fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) {
if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
peer.set_persistent_keepalive_interval(secs);
peer.opaque().set_persistent_keepalive_interval(secs);
}
}
fn replace_allowed_ips(&self, peer: &PublicKey) {
if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
peer.router.remove_allowed_ips();
peer.remove_allowed_ips();
}
}
fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) {
if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
peer.router.add_allowed_ip(ip, masklen);
peer.add_allowed_ip(ip, masklen);
}
}
@@ -337,26 +337,26 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
let peers = cfg.wireguard.list_peers();
let mut state = Vec::with_capacity(peers.len());
for p in peers {
for (pk, p) in peers {
// convert the system time to (secs, nano) since epoch
let last_handshake_time = (*p.walltime_last_handshake.lock()).and_then(|t| {
let last_handshake_time = (*p.opaque().walltime_last_handshake.lock()).and_then(|t| {
let duration = t
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0));
Some((duration.as_secs(), duration.subsec_nanos() as u64))
});
if let Some(psk) = cfg.wireguard.get_psk(&p.pk) {
if let Some(psk) = cfg.wireguard.get_psk(&pk) {
// extract state into PeerState
state.push(PeerState {
preshared_key: psk,
endpoint: p.router.get_endpoint(),
rx_bytes: p.rx_bytes.load(Ordering::Relaxed),
tx_bytes: p.tx_bytes.load(Ordering::Relaxed),
persistent_keepalive_interval: p.get_keepalive_interval(),
allowed_ips: p.router.list_allowed_ips(),
endpoint: p.get_endpoint(),
rx_bytes: p.opaque().rx_bytes.load(Ordering::Relaxed),
tx_bytes: p.opaque().tx_bytes.load(Ordering::Relaxed),
persistent_keepalive_interval: p.opaque().get_keepalive_interval(),
allowed_ips: p.list_allowed_ips(),
last_handshake_time,
public_key: p.pk,
public_key: pk,
})
}
}

View File

@@ -100,7 +100,7 @@ fn main() {
// daemonize
if !foreground {
let daemonize = Daemonize::new()
.pid_file(format!("/tmp/wgrs-{}.pid", name))
.pid_file(format!("/tmp/wireguard-rs-{}.pid", name))
.chown_pid_file(true)
.working_directory("/tmp")
.user("nobody")
@@ -170,7 +170,7 @@ fn main() {
Err(err) => {
log::info!("UAPI connection error: {}", err);
profiler_stop();
exit(0);
exit(-1);
}
}
});

View File

@@ -20,9 +20,6 @@ mod workers;
#[cfg(test)]
mod tests;
// represents a peer
pub use peer::Peer;
// represents a WireGuard interface
pub use wireguard::WireGuard;

View File

@@ -1,5 +1,4 @@
use super::router;
use super::timers::{Events, Timers};
use super::timers::Timers;
use super::tun::Tun;
use super::udp::UDP;
@@ -9,9 +8,7 @@ use super::wireguard::WireGuard;
use super::workers::HandshakeJob;
use std::fmt;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Instant, SystemTime};
use spin::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
@@ -31,7 +28,7 @@ pub struct PeerInner<T: Tun, B: UDP> {
pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer?
// stats and configuration
pub pk: PublicKey, // public key
pub pk: PublicKey, // public key (TODO: there has to be a way to remove this)
pub rx_bytes: AtomicU64, // received bytes
pub tx_bytes: AtomicU64, // transmitted bytes
@@ -39,20 +36,6 @@ pub struct PeerInner<T: Tun, B: UDP> {
pub timers: RwLock<Timers>,
}
pub struct Peer<T: Tun, B: UDP> {
pub router: Arc<router::PeerHandle<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
pub state: Arc<PeerInner<T, B>>,
}
impl<T: Tun, B: UDP> Clone for Peer<T, B> {
fn clone(&self) -> Peer<T, B> {
Peer {
router: self.router.clone(),
state: self.state.clone(),
}
}
}
impl<T: Tun, B: UDP> PeerInner<T, B> {
/* Queue a handshake request for the parallel workers
* (if one does not already exist)
@@ -104,33 +87,3 @@ impl<T: Tun, B: UDP> fmt::Display for PeerInner<T, B> {
write!(f, "peer(id = {})", self.id)
}
}
impl<T: Tun, B: UDP> fmt::Display for Peer<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "peer(id = {})", self.id)
}
}
impl<T: Tun, B: UDP> Deref for Peer<T, B> {
type Target = PeerInner<T, B>;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl<T: Tun, B: UDP> Peer<T, B> {
/// Bring the peer down. Causing:
///
/// - Timers to be stopped and disabled.
/// - All keystate to be zeroed
pub fn down(&self) {
self.stop_timers();
self.router.down();
}
/// Bring the peer up.
pub fn up(&self) {
self.router.up();
self.start_timers();
}
}

View File

@@ -22,6 +22,7 @@ use core::sync::atomic::AtomicBool;
use alloc::sync::Arc;
// TODO: consider no_std alternatives
use std::fmt;
use std::net::{IpAddr, SocketAddr};
use arraydeque::{ArrayDeque, Wrapping};
@@ -46,6 +47,14 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E
pub endpoint: Mutex<Option<E>>,
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for PeerInner<E, C, T, B> {
type Target = C::Opaque;
fn deref(&self) -> &Self::Target {
&self.opaque
}
}
pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
inner: Arc<PeerInner<E, C, T, B>>,
}
@@ -87,6 +96,16 @@ pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<
peer: Peer<E, C, T, B>,
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone
for PeerHandle<E, C, T, B>
{
fn clone(&self) -> Self {
PeerHandle {
peer: self.peer.clone(),
}
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref
for PeerHandle<E, C, T, B>
{
@@ -96,6 +115,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> fmt::Display
for PeerHandle<E, C, T, B>
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "PeerHandle(format: TODO)")
}
}
impl EncryptionState {
fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
EncryptionState {
@@ -338,6 +365,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
*self.peer.endpoint.lock() = Some(endpoint);
}
pub fn opaque(&self) -> &C::Opaque {
&self.opaque
}
/// Returns the current endpoint of the peer (for configuration)
///
/// # Note

View File

@@ -123,17 +123,13 @@ fn test_pure_wireguard() {
let peer2 = wg1.lookup_peer(&pk2).unwrap();
let peer1 = wg2.lookup_peer(&pk1).unwrap();
peer1
.router
.add_allowed_ip("192.168.1.0".parse().unwrap(), 24);
peer1.add_allowed_ip("192.168.1.0".parse().unwrap(), 24);
peer2
.router
.add_allowed_ip("192.168.2.0".parse().unwrap(), 24);
peer2.add_allowed_ip("192.168.2.0".parse().unwrap(), 24);
// set endpoint (the other should be learned dynamically)
peer2.router.set_endpoint(dummy::UnitEndpoint::new());
peer2.set_endpoint(dummy::UnitEndpoint::new());
let num_packets = 20;

View File

@@ -1,17 +1,19 @@
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use hjul::{Runner, Timer};
use log::debug;
use hjul::Timer;
use x25519_dalek::PublicKey;
use super::constants::*;
use super::peer::{Peer, PeerInner};
use super::peer::PeerInner;
use super::router::{message_data_len, Callbacks};
use super::tun::Tun;
use super::types::KeyPair;
use super::udp::UDP;
use super::WireGuard;
pub struct Timers {
// only updated during configuration
@@ -229,7 +231,35 @@ impl<T: Tun, B: UDP> PeerInner<T, B> {
}
impl Timers {
pub fn new<T: Tun, B: UDP>(runner: &Runner, running: bool, peer: Peer<T, B>) -> Timers {
pub fn new<T: Tun, B: UDP>(
wg: WireGuard<T, B>, // WireGuard device
pk: PublicKey, // public key of peer
running: bool, // timers started
) -> Timers {
macro_rules! fetch_peer {
( $wg:expr, $pk:expr ) => {
match $wg.lookup_peer(&$pk) {
None => {
return;
}
Some(peer) => peer,
}
};
}
macro_rules! fetch_timer {
( $peer:expr ) => {{
let timers = $peer.timers();
if timers.enabled {
timers
} else {
return;
}
}};
}
let runner = wg.runner.lock();
// create a timer instance for the provided peer
Timers {
enabled: running,
@@ -238,21 +268,16 @@ impl Timers {
sent_lastminute_handshake: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
retransmit_handshake: {
let peer = peer.clone();
let wg = wg.clone();
let pk = pk.clone();
runner.timer(move || {
// fetch peer by public key
let peer = fetch_peer!(wg, pk);
let timers = fetch_timer!(peer);
log::trace!("{} : timer fired (retransmit_handshake)", peer);
// ignore if timers are disabled
let timers = peer.timers();
if !timers.enabled {
return;
}
// check if handshake attempts remaining
let attempts = peer
.timers()
.handshake_attempts
.fetch_add(1, Ordering::SeqCst);
let attempts = timers.handshake_attempts.fetch_add(1, Ordering::SeqCst);
if attempts > MAX_TIMER_HANDSHAKES {
debug!(
"Handshake for peer {} did not complete after {} attempts, giving up",
@@ -261,7 +286,7 @@ impl Timers {
);
timers.send_keepalive.stop();
timers.zero_key_material.start(REJECT_AFTER_TIME * 3);
peer.router.purge_staged_packets();
peer.purge_staged_packets();
} else {
debug!(
"Handshake for {} did not complete after {} seconds, retrying (try {})",
@@ -270,56 +295,72 @@ impl Timers {
attempts
);
timers.retransmit_handshake.reset(REKEY_TIMEOUT);
peer.router.clear_src();
peer.clear_src();
peer.packet_send_queued_handshake_initiation(true);
}
})
},
send_keepalive: {
let peer = peer.clone();
let wg = wg.clone();
let pk = pk.clone();
runner.timer(move || {
// fetch peer by public key
let peer = fetch_peer!(wg, pk);
let timers = fetch_timer!(peer);
log::trace!("{} : timer fired (send_keepalive)", peer);
// ignore if timers are disabled
let timers = peer.timers();
if !timers.enabled {
return;
}
peer.router.send_keepalive();
// send keepalive and schedule next keepalive
peer.send_keepalive();
if timers.need_another_keepalive() {
timers.send_keepalive.start(KEEPALIVE_TIMEOUT);
}
})
},
new_handshake: {
let peer = peer.clone();
let wg = wg.clone();
let pk = pk.clone();
runner.timer(move || {
// fetch peer by public key
let peer = fetch_peer!(wg, pk);
let _timers = fetch_timer!(peer);
log::trace!("{} : timer fired (new_handshake)", peer);
// clear source and retry
log::debug!(
"Retrying handshake with {} because we stopped hearing back after {} seconds",
peer,
(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
);
peer.router.clear_src();
peer.clear_src();
peer.packet_send_queued_handshake_initiation(false);
})
},
zero_key_material: {
let peer = peer.clone();
let wg = wg.clone();
let pk = pk.clone();
runner.timer(move || {
// fetch peer by public key
let peer = fetch_peer!(wg, pk);
let _timers = fetch_timer!(peer);
log::trace!("{} : timer fired (zero_key_material)", peer);
peer.router.zero_keys();
// null all key-material
peer.zero_keys();
})
},
send_persistent_keepalive: {
let peer = peer.clone();
let wg = wg.clone();
let pk = pk.clone();
runner.timer(move || {
// fetch peer by public key
let peer = fetch_peer!(wg, pk);
let timers = fetch_timer!(peer);
log::trace!("{} : timer fired (send_persistent_keepalive)", peer);
let timers = peer.timers();
if timers.enabled && timers.keepalive_interval > 0 {
// send and schedule persistent keepalive
if timers.keepalive_interval > 0 {
timers.send_keepalive.stop();
peer.router.send_keepalive();
peer.send_keepalive();
log::trace!("{} : keepalive queued", peer);
timers
.send_persistent_keepalive
@@ -329,28 +370,10 @@ impl Timers {
},
}
}
pub fn dummy(runner: &Runner) -> Timers {
Timers {
enabled: false,
keepalive_interval: 0,
need_another_keepalive: AtomicBool::new(false),
sent_lastminute_handshake: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
retransmit_handshake: runner.timer(|| {}),
new_handshake: runner.timer(|| {}),
send_keepalive: runner.timer(|| {}),
send_persistent_keepalive: runner.timer(|| {}),
zero_key_material: runner.timer(|| {}),
}
}
}
/* instance of the router callbacks */
pub struct Events<T, B>(PhantomData<(T, B)>);
impl<T: Tun, B: UDP> Callbacks for Events<T, B> {
type Opaque = Arc<PeerInner<T, B>>;
impl<T: Tun, B: UDP> Callbacks for PeerInner<T, B> {
type Opaque = Self;
/* Called after the router encrypts a transport message destined for the peer.
* This method is called, even if the encrypted payload is empty (keepalive)

View File

@@ -1,8 +1,8 @@
use super::constants::*;
use super::handshake;
use super::peer::{Peer, PeerInner};
use super::peer::PeerInner;
use super::router;
use super::timers::{Events, Timers};
use super::timers::Timers;
use super::queue::ParallelQueue;
use super::workers::HandshakeJob;
@@ -45,10 +45,12 @@ pub struct WireguardInner<T: Tun, B: UDP> {
pub mtu: AtomicUsize,
// peer map
pub peers: RwLock<handshake::Device<Peer<T, B>>>,
pub peers: RwLock<
handshake::Device<router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer>>,
>,
// cryptokey router
pub router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
pub router: router::Device<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer>,
// handshake related state
pub last_under_load: Mutex<Instant>,
@@ -136,6 +138,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
// set all peers down (stops timers)
for (_, peer) in self.peers.write().iter() {
peer.stop_timers();
peer.down();
}
@@ -162,6 +165,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
// set all peers up (restarts timers)
for (_, peer) in self.peers.write().iter() {
peer.up();
peer.start_timers();
}
*enabled = true;
@@ -175,16 +179,24 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
let _ = self.peers.write().remove(pk);
}
pub fn lookup_peer(&self, pk: &PublicKey) -> Option<Peer<T, B>> {
self.peers.read().get(pk).map(|p| p.clone())
pub fn lookup_peer(
&self,
pk: &PublicKey,
) -> Option<router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer>> {
self.peers.read().get(pk).map(|handle| handle.clone())
}
pub fn list_peers(&self) -> Vec<Peer<T, B>> {
pub fn list_peers(
&self,
) -> Vec<(
PublicKey,
router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer>,
)> {
let peers = self.peers.read();
let mut list = Vec::with_capacity(peers.len());
for (k, v) in peers.iter() {
debug_assert!(k.as_bytes() == v.pk.as_bytes());
list.push(v.clone());
debug_assert!(k.as_bytes() == v.opaque().pk.as_bytes());
list.push((k.clone(), v.clone()));
}
list
}
@@ -215,7 +227,15 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
return false;
}
let state = Arc::new(PeerInner {
// prevent up/down while inserting
let enabled = *self.enabled.read();
// create timers (lookup by public key)
let timers = Timers::new::<T, B>(self.clone(), pk.clone(), enabled);
// create new router peer
let peer: router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer> =
self.router.new_peer(PeerInner {
id: OsRng.gen(),
pk,
wg: self.clone(),
@@ -224,27 +244,10 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
handshake_queued: AtomicBool::new(false),
rx_bytes: AtomicU64::new(0),
tx_bytes: AtomicU64::new(0),
timers: RwLock::new(Timers::dummy(&*self.runner.lock())),
timers: RwLock::new(timers),
});
// create a router peer
let router = Arc::new(self.router.new_peer(state.clone()));
// form WireGuard peer
let peer = Peer { router, state };
// prevent up/down while inserting
let enabled = self.enabled.read();
/* The need for dummy timers arises from the chicken-egg
* problem of the timer callbacks being able to set timers themselves.
*
* This is in fact the only place where the write lock is ever taken.
* TODO: Consider the ease of using atomic pointers instead.
*/
*peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone());
// finally, add the peer to the wireguard device
// finally, add the peer to the handshake device
peers.add(pk, peer).is_ok()
}
@@ -288,6 +291,10 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
// create handshake queue
let (tx, mut rxs) = ParallelQueue::new(cpus, 128);
// create router
let router: router::Device<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer> =
router::Device::new(num_cpus::get(), writer);
// create arc to state
let wg = WireGuard {
inner: Arc::new(WireguardInner {
@@ -296,7 +303,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
id: OsRng.gen(),
mtu: AtomicUsize::new(0),
last_under_load: Mutex::new(Instant::now() - TIME_HORIZON),
router: router::Device::new(num_cpus::get(), writer),
router,
pending: AtomicUsize::new(0),
peers: RwLock::new(handshake::Device::new()),
runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)),

View File

@@ -209,23 +209,25 @@ pub fn handshake_worker<T: Tun, B: UDP>(
// add to rx_bytes and tx_bytes
let req_len = msg.len() as u64;
peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
peer.opaque().rx_bytes.fetch_add(req_len, Ordering::Relaxed);
peer.opaque()
.tx_bytes
.fetch_add(resp_len, Ordering::Relaxed);
// update endpoint
peer.router.set_endpoint(src);
peer.set_endpoint(src);
if resp_len > 0 {
// update timers after sending handshake response
debug!("{} : handshake worker, handshake response sent", wg);
peer.state.sent_handshake_response();
peer.opaque().sent_handshake_response();
} else {
// update timers after receiving handshake response
debug!(
"{} : handshake worker, handshake response was received",
wg
);
peer.state.timers_handshake_complete();
peer.opaque().timers_handshake_complete();
}
// add any new keypair to peer
@@ -233,10 +235,10 @@ pub fn handshake_worker<T: Tun, B: UDP>(
debug!("{} : handshake worker, new keypair for {}", wg, peer);
// this means that a handshake response was processed or sent
peer.timers_session_derived();
peer.opaque().timers_session_derived();
// free any unused ids
for id in peer.router.add_keypair(kp) {
for id in peer.add_keypair(kp) {
device.release(id);
}
});
@@ -252,13 +254,15 @@ pub fn handshake_worker<T: Tun, B: UDP>(
wg, peer
);
let device = wg.peers.read();
let _ = device.begin(&mut OsRng, &peer.pk).map(|msg| {
let _ = peer.router.send_raw(&msg[..]).map_err(|e| {
let _ = device.begin(&mut OsRng, &pk).map(|msg| {
let _ = peer.send_raw(&msg[..]).map_err(|e| {
debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
});
peer.state.sent_handshake_initiation();
peer.opaque().sent_handshake_initiation();
});
peer.handshake_queued.store(false, Ordering::SeqCst);
peer.opaque()
.handshake_queued
.store(false, Ordering::SeqCst);
}
}
}