Refactoring timer code:

- Remove the Events struct
- Implement Callbacks on the PeerInner, eliminating an Arc.
Author: Mathias Hall-Andersen
Date: 2020-05-10 21:23:34 +02:00
parent 985fd088f8
commit 6c386146a7
9 changed files with 186 additions and 175 deletions
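The core of the change: the router was previously instantiated with a zero-sized Events<T, B> marker whose Callbacks::Opaque was an Arc<PeerInner<T, B>>, and a separate Peer wrapper held that Arc next to an Arc<router::PeerHandle>. After this commit, PeerInner<T, B> implements Callbacks directly with Opaque = Self, so the router's per-peer allocation owns the WireGuard peer state and both the wrapper struct and the extra Arc disappear. A minimal, self-contained sketch of the before/after shape (the trait and fields are heavily simplified and are not the crate's real signatures):

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

// The router calls back into device code through a trait of roughly this
// shape (simplified: the real trait has more methods and arguments).
trait Callbacks {
    type Opaque;
    fn recv(opaque: &Self::Opaque, size: usize);
}

struct PeerInner {
    rx_bytes: AtomicU64,
}

// Before: a zero-sized marker type implemented Callbacks, and the opaque
// state handed to the router was an extra Arc around PeerInner.
struct Events;
impl Callbacks for Events {
    type Opaque = Arc<PeerInner>;
    fn recv(opaque: &Self::Opaque, size: usize) {
        opaque.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
    }
}

// After: PeerInner implements Callbacks with Opaque = Self, so the router's
// per-peer state *is* the opaque value and the extra Arc disappears.
impl Callbacks for PeerInner {
    type Opaque = Self;
    fn recv(opaque: &Self::Opaque, size: usize) {
        opaque.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
    }
}

fn main() {
    let peer = PeerInner { rx_bytes: AtomicU64::new(0) };
    <PeerInner as Callbacks>::recv(&peer, 148);
    assert_eq!(peer.rx_bytes.load(Ordering::Relaxed), 148);
}

With Opaque = Self, callers reach the device-level peer state through PeerHandle::opaque(), which is why the configuration and worker diffs below replace peer.state.… and peer.router.… with peer.opaque().… and plain peer.… calls.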

View File

@@ -310,25 +310,25 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
     fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) {
         if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
-            peer.router.set_endpoint(B::Endpoint::from_address(addr));
+            peer.set_endpoint(B::Endpoint::from_address(addr));
         }
     }

     fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) {
         if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
-            peer.set_persistent_keepalive_interval(secs);
+            peer.opaque().set_persistent_keepalive_interval(secs);
         }
     }

     fn replace_allowed_ips(&self, peer: &PublicKey) {
         if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
-            peer.router.remove_allowed_ips();
+            peer.remove_allowed_ips();
         }
     }

     fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) {
         if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
-            peer.router.add_allowed_ip(ip, masklen);
+            peer.add_allowed_ip(ip, masklen);
         }
     }

@@ -337,26 +337,26 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
        let peers = cfg.wireguard.list_peers();
        let mut state = Vec::with_capacity(peers.len());

-       for p in peers {
+       for (pk, p) in peers {
            // convert the system time to (secs, nano) since epoch
-           let last_handshake_time = (*p.walltime_last_handshake.lock()).and_then(|t| {
+           let last_handshake_time = (*p.opaque().walltime_last_handshake.lock()).and_then(|t| {
                let duration = t
                    .duration_since(SystemTime::UNIX_EPOCH)
                    .unwrap_or(Duration::from_secs(0));
                Some((duration.as_secs(), duration.subsec_nanos() as u64))
            });

-           if let Some(psk) = cfg.wireguard.get_psk(&p.pk) {
+           if let Some(psk) = cfg.wireguard.get_psk(&pk) {
                // extract state into PeerState
                state.push(PeerState {
                    preshared_key: psk,
-                   endpoint: p.router.get_endpoint(),
-                   rx_bytes: p.rx_bytes.load(Ordering::Relaxed),
-                   tx_bytes: p.tx_bytes.load(Ordering::Relaxed),
-                   persistent_keepalive_interval: p.get_keepalive_interval(),
-                   allowed_ips: p.router.list_allowed_ips(),
+                   endpoint: p.get_endpoint(),
+                   rx_bytes: p.opaque().rx_bytes.load(Ordering::Relaxed),
+                   tx_bytes: p.opaque().tx_bytes.load(Ordering::Relaxed),
+                   persistent_keepalive_interval: p.opaque().get_keepalive_interval(),
+                   allowed_ips: p.list_allowed_ips(),
                    last_handshake_time,
-                   public_key: p.pk,
+                   public_key: pk,
                })
            }
        }

View File

@@ -100,7 +100,7 @@ fn main() {
     // daemonize
     if !foreground {
         let daemonize = Daemonize::new()
-            .pid_file(format!("/tmp/wgrs-{}.pid", name))
+            .pid_file(format!("/tmp/wireguard-rs-{}.pid", name))
             .chown_pid_file(true)
             .working_directory("/tmp")
             .user("nobody")

@@ -170,7 +170,7 @@ fn main() {
            Err(err) => {
                log::info!("UAPI connection error: {}", err);
                profiler_stop();
-               exit(0);
+               exit(-1);
            }
        }
    });

View File

@@ -20,9 +20,6 @@ mod workers;
 #[cfg(test)]
 mod tests;

-// represents a peer
-pub use peer::Peer;
-
 // represents a WireGuard interface
 pub use wireguard::WireGuard;

View File

@@ -1,5 +1,4 @@
-use super::router;
-use super::timers::{Events, Timers};
+use super::timers::Timers;
 use super::tun::Tun;
 use super::udp::UDP;

@@ -9,9 +8,7 @@ use super::wireguard::WireGuard;
 use super::workers::HandshakeJob;

 use std::fmt;
-use std::ops::Deref;
 use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
-use std::sync::Arc;
 use std::time::{Instant, SystemTime};

 use spin::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};

@@ -31,7 +28,7 @@ pub struct PeerInner<T: Tun, B: UDP> {
     pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer?

     // stats and configuration
-    pub pk: PublicKey,       // public key
+    pub pk: PublicKey,       // public key (TODO: there has to be a way to remove this)
     pub rx_bytes: AtomicU64, // received bytes
     pub tx_bytes: AtomicU64, // transmitted bytes

@@ -39,20 +36,6 @@ pub struct PeerInner<T: Tun, B: UDP> {
     pub timers: RwLock<Timers>,
 }

-pub struct Peer<T: Tun, B: UDP> {
-    pub router: Arc<router::PeerHandle<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
-    pub state: Arc<PeerInner<T, B>>,
-}
-
-impl<T: Tun, B: UDP> Clone for Peer<T, B> {
-    fn clone(&self) -> Peer<T, B> {
-        Peer {
-            router: self.router.clone(),
-            state: self.state.clone(),
-        }
-    }
-}
-
 impl<T: Tun, B: UDP> PeerInner<T, B> {
     /* Queue a handshake request for the parallel workers
      * (if one does not already exist)

@@ -104,33 +87,3 @@ impl<T: Tun, B: UDP> fmt::Display for PeerInner<T, B> {
         write!(f, "peer(id = {})", self.id)
     }
 }
-
-impl<T: Tun, B: UDP> fmt::Display for Peer<T, B> {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "peer(id = {})", self.id)
-    }
-}
-
-impl<T: Tun, B: UDP> Deref for Peer<T, B> {
-    type Target = PeerInner<T, B>;
-    fn deref(&self) -> &Self::Target {
-        &self.state
-    }
-}
-
-impl<T: Tun, B: UDP> Peer<T, B> {
-    /// Bring the peer down. Causing:
-    ///
-    /// - Timers to be stopped and disabled.
-    /// - All keystate to be zeroed
-    pub fn down(&self) {
-        self.stop_timers();
-        self.router.down();
-    }
-
-    /// Bring the peer up.
-    pub fn up(&self) {
-        self.router.up();
-        self.start_timers();
-    }
-}

View File

@@ -22,6 +22,7 @@ use core::sync::atomic::AtomicBool;
 use alloc::sync::Arc;

 // TODO: consider no_std alternatives
+use std::fmt;
 use std::net::{IpAddr, SocketAddr};

 use arraydeque::{ArrayDeque, Wrapping};

@@ -46,6 +47,14 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E
     pub endpoint: Mutex<Option<E>>,
 }

+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for PeerInner<E, C, T, B> {
+    type Target = C::Opaque;
+    fn deref(&self) -> &Self::Target {
+        &self.opaque
+    }
+}
+
 pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
     inner: Arc<PeerInner<E, C, T, B>>,
 }

@@ -87,6 +96,16 @@ pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<
     peer: Peer<E, C, T, B>,
 }

+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone
+    for PeerHandle<E, C, T, B>
+{
+    fn clone(&self) -> Self {
+        PeerHandle {
+            peer: self.peer.clone(),
+        }
+    }
+}
+
 impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref
     for PeerHandle<E, C, T, B>
 {

@@ -96,6 +115,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref
     }
 }

+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> fmt::Display
+    for PeerHandle<E, C, T, B>
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "PeerHandle(format: TODO)")
+    }
+}
+
 impl EncryptionState {
     fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
         EncryptionState {

@@ -338,6 +365,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
         *self.peer.endpoint.lock() = Some(endpoint);
     }

+    pub fn opaque(&self) -> &C::Opaque {
+        &self.opaque
+    }
+
     /// Returns the current endpoint of the peer (for configuration)
     ///
     /// # Note

View File

@@ -123,17 +123,13 @@ fn test_pure_wireguard() {
     let peer2 = wg1.lookup_peer(&pk2).unwrap();
     let peer1 = wg2.lookup_peer(&pk1).unwrap();

-    peer1
-        .router
-        .add_allowed_ip("192.168.1.0".parse().unwrap(), 24);
+    peer1.add_allowed_ip("192.168.1.0".parse().unwrap(), 24);

-    peer2
-        .router
-        .add_allowed_ip("192.168.2.0".parse().unwrap(), 24);
+    peer2.add_allowed_ip("192.168.2.0".parse().unwrap(), 24);

     // set endpoint (the other should be learned dynamically)
-    peer2.router.set_endpoint(dummy::UnitEndpoint::new());
+    peer2.set_endpoint(dummy::UnitEndpoint::new());

     let num_packets = 20;

View File

@@ -1,17 +1,19 @@
-use std::marker::PhantomData;
 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
 use std::sync::Arc;
 use std::time::{Duration, Instant, SystemTime};

-use hjul::{Runner, Timer};
 use log::debug;

+use hjul::Timer;
+use x25519_dalek::PublicKey;
+
 use super::constants::*;
-use super::peer::{Peer, PeerInner};
+use super::peer::PeerInner;
 use super::router::{message_data_len, Callbacks};
 use super::tun::Tun;
 use super::types::KeyPair;
 use super::udp::UDP;
+use super::WireGuard;

 pub struct Timers {
     // only updated during configuration

@@ -229,7 +231,35 @@ impl<T: Tun, B: UDP> PeerInner<T, B> {
 }

 impl Timers {
-    pub fn new<T: Tun, B: UDP>(runner: &Runner, running: bool, peer: Peer<T, B>) -> Timers {
+    pub fn new<T: Tun, B: UDP>(
+        wg: WireGuard<T, B>, // WireGuard device
+        pk: PublicKey,       // public key of peer
+        running: bool,       // timers started
+    ) -> Timers {
+        macro_rules! fetch_peer {
+            ( $wg:expr, $pk:expr ) => {
+                match $wg.lookup_peer(&$pk) {
+                    None => {
+                        return;
+                    }
+                    Some(peer) => peer,
+                }
+            };
+        }
+
+        macro_rules! fetch_timer {
+            ( $peer:expr ) => {{
+                let timers = $peer.timers();
+                if timers.enabled {
+                    timers
+                } else {
+                    return;
+                }
+            }};
+        }
+
+        let runner = wg.runner.lock();
+
         // create a timer instance for the provided peer
         Timers {
             enabled: running,

@@ -238,21 +268,16 @@
             sent_lastminute_handshake: AtomicBool::new(false),
             handshake_attempts: AtomicUsize::new(0),
             retransmit_handshake: {
-                let peer = peer.clone();
+                let wg = wg.clone();
+                let pk = pk.clone();
                 runner.timer(move || {
+                    // fetch peer by public key
+                    let peer = fetch_peer!(wg, pk);
+                    let timers = fetch_timer!(peer);
                     log::trace!("{} : timer fired (retransmit_handshake)", peer);

-                    // ignore if timers are disabled
-                    let timers = peer.timers();
-                    if !timers.enabled {
-                        return;
-                    }
-
                     // check if handshake attempts remaining
-                    let attempts = peer
-                        .timers()
-                        .handshake_attempts
-                        .fetch_add(1, Ordering::SeqCst);
+                    let attempts = timers.handshake_attempts.fetch_add(1, Ordering::SeqCst);
                     if attempts > MAX_TIMER_HANDSHAKES {
                         debug!(
                             "Handshake for peer {} did not complete after {} attempts, giving up",

@@ -261,7 +286,7 @@
                         );
                         timers.send_keepalive.stop();
                         timers.zero_key_material.start(REJECT_AFTER_TIME * 3);
-                        peer.router.purge_staged_packets();
+                        peer.purge_staged_packets();
                     } else {
                         debug!(
                             "Handshake for {} did not complete after {} seconds, retrying (try {})",

@@ -270,56 +295,72 @@
                             attempts
                         );
                         timers.retransmit_handshake.reset(REKEY_TIMEOUT);
-                        peer.router.clear_src();
+                        peer.clear_src();
                         peer.packet_send_queued_handshake_initiation(true);
                     }
                 })
             },
             send_keepalive: {
-                let peer = peer.clone();
+                let wg = wg.clone();
+                let pk = pk.clone();
                 runner.timer(move || {
+                    // fetch peer by public key
+                    let peer = fetch_peer!(wg, pk);
+                    let timers = fetch_timer!(peer);
                     log::trace!("{} : timer fired (send_keepalive)", peer);

-                    // ignore if timers are disabled
-                    let timers = peer.timers();
-                    if !timers.enabled {
-                        return;
-                    }
-
-                    peer.router.send_keepalive();
+                    // send keepalive and schedule next keepalive
+                    peer.send_keepalive();
                     if timers.need_another_keepalive() {
                         timers.send_keepalive.start(KEEPALIVE_TIMEOUT);
                     }
                 })
             },
             new_handshake: {
-                let peer = peer.clone();
+                let wg = wg.clone();
+                let pk = pk.clone();
                 runner.timer(move || {
+                    // fetch peer by public key
+                    let peer = fetch_peer!(wg, pk);
+                    let _timers = fetch_timer!(peer);
                     log::trace!("{} : timer fired (new_handshake)", peer);

+                    // clear source and retry
                     log::debug!(
                         "Retrying handshake with {} because we stopped hearing back after {} seconds",
                         peer,
                         (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
                     );
-                    peer.router.clear_src();
+                    peer.clear_src();
                     peer.packet_send_queued_handshake_initiation(false);
                 })
             },
             zero_key_material: {
-                let peer = peer.clone();
+                let wg = wg.clone();
+                let pk = pk.clone();
                 runner.timer(move || {
+                    // fetch peer by public key
+                    let peer = fetch_peer!(wg, pk);
+                    let _timers = fetch_timer!(peer);
                     log::trace!("{} : timer fired (zero_key_material)", peer);
-                    peer.router.zero_keys();
+
+                    // null all key-material
+                    peer.zero_keys();
                 })
             },
             send_persistent_keepalive: {
-                let peer = peer.clone();
+                let wg = wg.clone();
+                let pk = pk.clone();
                 runner.timer(move || {
+                    // fetch peer by public key
+                    let peer = fetch_peer!(wg, pk);
+                    let timers = fetch_timer!(peer);
                     log::trace!("{} : timer fired (send_persistent_keepalive)", peer);

-                    let timers = peer.timers();
-                    if timers.enabled && timers.keepalive_interval > 0 {
+                    // send and schedule persistent keepalive
+                    if timers.keepalive_interval > 0 {
                         timers.send_keepalive.stop();
-                        peer.router.send_keepalive();
+                        peer.send_keepalive();
                         log::trace!("{} : keepalive queued", peer);
                         timers
                             .send_persistent_keepalive

@@ -329,28 +370,10 @@
             },
         }
     }
-
-    pub fn dummy(runner: &Runner) -> Timers {
-        Timers {
-            enabled: false,
-            keepalive_interval: 0,
-            need_another_keepalive: AtomicBool::new(false),
-            sent_lastminute_handshake: AtomicBool::new(false),
-            handshake_attempts: AtomicUsize::new(0),
-            retransmit_handshake: runner.timer(|| {}),
-            new_handshake: runner.timer(|| {}),
-            send_keepalive: runner.timer(|| {}),
-            send_persistent_keepalive: runner.timer(|| {}),
-            zero_key_material: runner.timer(|| {}),
-        }
-    }
 }

-/* instance of the router callbacks */
-pub struct Events<T, B>(PhantomData<(T, B)>);
-
-impl<T: Tun, B: UDP> Callbacks for Events<T, B> {
-    type Opaque = Arc<PeerInner<T, B>>;
+impl<T: Tun, B: UDP> Callbacks for PeerInner<T, B> {
+    type Opaque = Self;

     /* Called after the router encrypts a transport message destined for the peer.
      * This method is called, even if the encrypted payload is empty (keepalive)
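Note how the timer closures above no longer capture a cloned Peer: each captures the WireGuard handle plus the peer's public key and resolves the peer with fetch_peer! when it fires, bailing out early if the peer has been removed or its timers are disabled (fetch_timer!). Because the closures no longer need an already-constructed peer, the Timers::dummy placeholder and the later overwrite under the write lock can be dropped. A minimal sketch of this lookup-on-fire pattern, using hypothetical Device and Peer types rather than the crate's real ones:

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

type PublicKey = [u8; 32];

struct Peer {
    // timers, counters, ... (omitted in this sketch)
}

#[derive(Clone)]
struct Device {
    peers: Arc<Mutex<HashMap<PublicKey, Arc<Peer>>>>,
}

impl Device {
    fn lookup_peer(&self, pk: &PublicKey) -> Option<Arc<Peer>> {
        self.peers.lock().unwrap().get(pk).cloned()
    }

    // The callback captures (device handle, public key) instead of a strong
    // Peer clone, and looks the peer up each time the timer fires.
    fn timer_callback(&self, pk: PublicKey) -> impl Fn() + Send + 'static {
        let device = self.clone();
        move || {
            let peer = match device.lookup_peer(&pk) {
                Some(peer) => peer,
                None => return, // peer was removed: the timer is a no-op
            };
            // ... perform the actual timer action on `peer` here ...
            let _ = peer;
        }
    }
}

fn main() {
    let device = Device { peers: Arc::new(Mutex::new(HashMap::new())) };
    let callback = device.timer_callback([0u8; 32]);
    callback(); // no peer registered for this key, so nothing happens
}

The cost is one map lookup per timer firing; in exchange the timers hold no strong reference to the peer, so removing the peer from the map is enough to make its pending timer callbacks no-ops.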

View File

@@ -1,8 +1,8 @@
 use super::constants::*;
 use super::handshake;
-use super::peer::{Peer, PeerInner};
+use super::peer::PeerInner;
 use super::router;
-use super::timers::{Events, Timers};
+use super::timers::Timers;

 use super::queue::ParallelQueue;
 use super::workers::HandshakeJob;

@@ -45,10 +45,12 @@ pub struct WireguardInner<T: Tun, B: UDP> {
     pub mtu: AtomicUsize,

     // peer map
-    pub peers: RwLock<handshake::Device<Peer<T, B>>>,
+    pub peers: RwLock<
+        handshake::Device<router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer>>,
+    >,

     // cryptokey router
-    pub router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
+    pub router: router::Device<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer>,

     // handshake related state
     pub last_under_load: Mutex<Instant>,

@@ -136,6 +138,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
         // set all peers down (stops timers)
         for (_, peer) in self.peers.write().iter() {
+            peer.stop_timers();
             peer.down();
         }

@@ -162,6 +165,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
         // set all peers up (restarts timers)
         for (_, peer) in self.peers.write().iter() {
             peer.up();
+            peer.start_timers();
         }

         *enabled = true;

@@ -175,16 +179,24 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
         let _ = self.peers.write().remove(pk);
     }

-    pub fn lookup_peer(&self, pk: &PublicKey) -> Option<Peer<T, B>> {
-        self.peers.read().get(pk).map(|p| p.clone())
+    pub fn lookup_peer(
+        &self,
+        pk: &PublicKey,
+    ) -> Option<router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer>> {
+        self.peers.read().get(pk).map(|handle| handle.clone())
     }

-    pub fn list_peers(&self) -> Vec<Peer<T, B>> {
+    pub fn list_peers(
+        &self,
+    ) -> Vec<(
+        PublicKey,
+        router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer>,
+    )> {
         let peers = self.peers.read();
         let mut list = Vec::with_capacity(peers.len());
         for (k, v) in peers.iter() {
-            debug_assert!(k.as_bytes() == v.pk.as_bytes());
-            list.push(v.clone());
+            debug_assert!(k.as_bytes() == v.opaque().pk.as_bytes());
+            list.push((k.clone(), v.clone()));
         }
         list
     }

@@ -215,36 +227,27 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
             return false;
         }

-        let state = Arc::new(PeerInner {
-            id: OsRng.gen(),
-            pk,
-            wg: self.clone(),
-            walltime_last_handshake: Mutex::new(None),
-            last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON),
-            handshake_queued: AtomicBool::new(false),
-            rx_bytes: AtomicU64::new(0),
-            tx_bytes: AtomicU64::new(0),
-            timers: RwLock::new(Timers::dummy(&*self.runner.lock())),
-        });
-
-        // create a router peer
-        let router = Arc::new(self.router.new_peer(state.clone()));
-
-        // form WireGuard peer
-        let peer = Peer { router, state };
-
         // prevent up/down while inserting
-        let enabled = self.enabled.read();
+        let enabled = *self.enabled.read();

-        /* The need for dummy timers arises from the chicken-egg
-         * problem of the timer callbacks being able to set timers themselves.
-         *
-         * This is in fact the only place where the write lock is ever taken.
-         * TODO: Consider the ease of using atomic pointers instead.
-         */
-        *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone());
+        // create timers (lookup by public key)
+        let timers = Timers::new::<T, B>(self.clone(), pk.clone(), enabled);

-        // finally, add the peer to the wireguard device
+        // create new router peer
+        let peer: router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer> =
+            self.router.new_peer(PeerInner {
+                id: OsRng.gen(),
+                pk,
+                wg: self.clone(),
+                walltime_last_handshake: Mutex::new(None),
+                last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON),
+                handshake_queued: AtomicBool::new(false),
+                rx_bytes: AtomicU64::new(0),
+                tx_bytes: AtomicU64::new(0),
+                timers: RwLock::new(timers),
+            });
+
+        // finally, add the peer to the handshake device
         peers.add(pk, peer).is_ok()
     }

@@ -288,6 +291,10 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
         // create handshake queue
         let (tx, mut rxs) = ParallelQueue::new(cpus, 128);

+        // create router
+        let router: router::Device<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer> =
+            router::Device::new(num_cpus::get(), writer);
+
         // create arc to state
         let wg = WireGuard {
             inner: Arc::new(WireguardInner {

@@ -296,7 +303,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
                 id: OsRng.gen(),
                 mtu: AtomicUsize::new(0),
                 last_under_load: Mutex::new(Instant::now() - TIME_HORIZON),
-                router: router::Device::new(num_cpus::get(), writer),
+                router,
                 pending: AtomicUsize::new(0),
                 peers: RwLock::new(handshake::Device::new()),
                 runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)),

View File

@@ -209,23 +209,25 @@ pub fn handshake_worker<T: Tun, B: UDP>(
                        // add to rx_bytes and tx_bytes
                        let req_len = msg.len() as u64;
-                       peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
-                       peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
+                       peer.opaque().rx_bytes.fetch_add(req_len, Ordering::Relaxed);
+                       peer.opaque()
+                           .tx_bytes
+                           .fetch_add(resp_len, Ordering::Relaxed);

                        // update endpoint
-                       peer.router.set_endpoint(src);
+                       peer.set_endpoint(src);

                        if resp_len > 0 {
                            // update timers after sending handshake response
                            debug!("{} : handshake worker, handshake response sent", wg);
-                           peer.state.sent_handshake_response();
+                           peer.opaque().sent_handshake_response();
                        } else {
                            // update timers after receiving handshake response
                            debug!(
                                "{} : handshake worker, handshake response was received",
                                wg
                            );
-                           peer.state.timers_handshake_complete();
+                           peer.opaque().timers_handshake_complete();
                        }

                        // add any new keypair to peer

@@ -233,10 +235,10 @@ pub fn handshake_worker<T: Tun, B: UDP>(
                            debug!("{} : handshake worker, new keypair for {}", wg, peer);

                            // this means that a handshake response was processed or sent
-                           peer.timers_session_derived();
+                           peer.opaque().timers_session_derived();

                            // free any unused ids
-                           for id in peer.router.add_keypair(kp) {
+                           for id in peer.add_keypair(kp) {
                                device.release(id);
                            }
                        });

@@ -252,13 +254,15 @@ pub fn handshake_worker<T: Tun, B: UDP>(
                        wg, peer
                    );
                    let device = wg.peers.read();
-                   let _ = device.begin(&mut OsRng, &peer.pk).map(|msg| {
-                       let _ = peer.router.send_raw(&msg[..]).map_err(|e| {
+                   let _ = device.begin(&mut OsRng, &pk).map(|msg| {
+                       let _ = peer.send_raw(&msg[..]).map_err(|e| {
                            debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
                        });
-                       peer.state.sent_handshake_initiation();
+                       peer.opaque().sent_handshake_initiation();
                    });
-                   peer.handshake_queued.store(false, Ordering::SeqCst);
+                   peer.opaque()
+                       .handshake_queued
+                       .store(false, Ordering::SeqCst);
                }
            }
        }