Make under_load global for WireGuard device
This commit is contained in:
@@ -27,24 +27,24 @@ pub struct PeerState {
|
||||
pub preshared_key: [u8; 32], // 0^32 is the "default value" (though treated like any other psk)
|
||||
}
|
||||
|
||||
pub struct WireguardConfig<T: tun::Tun, B: udp::PlatformUDP>(Arc<Mutex<Inner<T, B>>>);
|
||||
pub struct WireGuardConfig<T: tun::Tun, B: udp::PlatformUDP>(Arc<Mutex<Inner<T, B>>>);
|
||||
|
||||
struct Inner<T: tun::Tun, B: udp::PlatformUDP> {
|
||||
wireguard: Wireguard<T, B>,
|
||||
wireguard: WireGuard<T, B>,
|
||||
port: u16,
|
||||
bind: Option<B::Owner>,
|
||||
fwmark: Option<u32>,
|
||||
}
|
||||
|
||||
impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> {
|
||||
impl<T: tun::Tun, B: udp::PlatformUDP> WireGuardConfig<T, B> {
|
||||
fn lock(&self) -> MutexGuard<Inner<T, B>> {
|
||||
self.0.lock().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> {
|
||||
pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
|
||||
WireguardConfig(Arc::new(Mutex::new(Inner {
|
||||
impl<T: tun::Tun, B: udp::PlatformUDP> WireGuardConfig<T, B> {
|
||||
pub fn new(wg: WireGuard<T, B>) -> WireGuardConfig<T, B> {
|
||||
WireGuardConfig(Arc::new(Mutex::new(Inner {
|
||||
wireguard: wg,
|
||||
port: 0,
|
||||
bind: None,
|
||||
@@ -53,9 +53,9 @@ impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: tun::Tun, B: udp::PlatformUDP> Clone for WireguardConfig<T, B> {
|
||||
impl<T: tun::Tun, B: udp::PlatformUDP> Clone for WireGuardConfig<T, B> {
|
||||
fn clone(&self) -> Self {
|
||||
WireguardConfig(self.0.clone())
|
||||
WireGuardConfig(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ pub trait Configuration {
|
||||
fn get_fwmark(&self) -> Option<u32>;
|
||||
}
|
||||
|
||||
impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> {
|
||||
impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
|
||||
fn up(&self, mtu: usize) {
|
||||
self.lock().wireguard.up(mtu);
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ pub mod uapi;
|
||||
|
||||
use super::platform::Endpoint;
|
||||
use super::platform::{tun, udp};
|
||||
use super::wireguard::Wireguard;
|
||||
use super::wireguard::WireGuard;
|
||||
|
||||
pub use error::ConfigError;
|
||||
|
||||
pub use config::Configuration;
|
||||
pub use config::WireguardConfig;
|
||||
pub use config::WireGuardConfig;
|
||||
|
||||
@@ -25,6 +25,8 @@ use platform::tun::{PlatformTun, Status};
|
||||
use platform::uapi::{BindUAPI, PlatformUAPI};
|
||||
use platform::*;
|
||||
|
||||
use wireguard::WireGuard;
|
||||
|
||||
#[cfg(feature = "profiler")]
|
||||
fn profiler_stop() {
|
||||
println!("Stopping profiler");
|
||||
@@ -118,7 +120,7 @@ fn main() {
|
||||
profiler_start(name.as_str());
|
||||
|
||||
// create WireGuard device
|
||||
let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(writer);
|
||||
let wg: WireGuard<plt::Tun, plt::UDP> = WireGuard::new(writer);
|
||||
|
||||
// add all Tun readers
|
||||
while let Some(reader) = readers.pop() {
|
||||
@@ -126,7 +128,7 @@ fn main() {
|
||||
}
|
||||
|
||||
// wrap in configuration interface
|
||||
let cfg = configuration::WireguardConfig::new(wg.clone());
|
||||
let cfg = configuration::WireGuardConfig::new(wg.clone());
|
||||
|
||||
// start Tun event thread
|
||||
{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
mod udp;
|
||||
mod endpoint;
|
||||
mod tun;
|
||||
mod udp;
|
||||
|
||||
/* A pure dummy platform available during "test-time"
|
||||
*
|
||||
|
||||
@@ -252,15 +252,12 @@ impl Device {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `msg` - Byte slice containing the message (untrusted input)
|
||||
pub fn process<'a, R: RngCore + CryptoRng, S>(
|
||||
pub fn process<'a, R: RngCore + CryptoRng>(
|
||||
&self,
|
||||
rng: &mut R, // rng instance to sample randomness from
|
||||
msg: &[u8], // message buffer
|
||||
src: Option<&'a S>, // optional source endpoint, set when "under load"
|
||||
) -> Result<Output, HandshakeError>
|
||||
where
|
||||
&'a S: Into<&'a SocketAddr>,
|
||||
{
|
||||
src: Option<SocketAddr>, // optional source endpoint, set when "under load"
|
||||
) -> Result<Output, HandshakeError> {
|
||||
// ensure type read in-range
|
||||
if msg.len() < 4 {
|
||||
return Err(HandshakeError::InvalidMessageFormat);
|
||||
@@ -286,16 +283,13 @@ impl Device {
|
||||
|
||||
// address validation & DoS mitigation
|
||||
if let Some(src) = src {
|
||||
// obtain ref to socket addr
|
||||
let src = src.into();
|
||||
|
||||
// check mac2 field
|
||||
if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
||||
if !keyst.macs.check_mac2(msg.noise.as_bytes(), &src, &msg.macs) {
|
||||
let mut reply = Default::default();
|
||||
keyst.macs.create_cookie_reply(
|
||||
rng,
|
||||
msg.noise.f_sender.get(),
|
||||
src,
|
||||
&src,
|
||||
&msg.macs,
|
||||
&mut reply,
|
||||
);
|
||||
@@ -344,12 +338,12 @@ impl Device {
|
||||
let src = src.into();
|
||||
|
||||
// check mac2 field
|
||||
if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
||||
if !keyst.macs.check_mac2(msg.noise.as_bytes(), &src, &msg.macs) {
|
||||
let mut reply = Default::default();
|
||||
keyst.macs.create_cookie_reply(
|
||||
rng,
|
||||
msg.noise.f_sender.get(),
|
||||
src,
|
||||
&src,
|
||||
&msg.macs,
|
||||
&mut reply,
|
||||
);
|
||||
|
||||
@@ -69,13 +69,13 @@ fn handshake_under_load() {
|
||||
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
|
||||
|
||||
// 2. device-2 : responds with CookieReply
|
||||
let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
|
||||
let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() {
|
||||
(None, Some(msg), None) => msg,
|
||||
_ => panic!("unexpected response"),
|
||||
};
|
||||
|
||||
// device-1 : processes CookieReply (no response)
|
||||
match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
|
||||
match dev1.process(&mut rng, &msg_cookie, Some(src2)).unwrap() {
|
||||
(None, None, None) => (),
|
||||
_ => panic!("unexpected response"),
|
||||
}
|
||||
@@ -87,7 +87,7 @@ fn handshake_under_load() {
|
||||
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
|
||||
|
||||
// 4. device-2 : responds with noise response
|
||||
let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
|
||||
let msg_response = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() {
|
||||
(Some(_), Some(msg), Some(kp)) => {
|
||||
assert_eq!(kp.initiator, false);
|
||||
msg
|
||||
@@ -96,13 +96,13 @@ fn handshake_under_load() {
|
||||
};
|
||||
|
||||
// 5. device-1 : responds with CookieReply
|
||||
let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
|
||||
let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() {
|
||||
(None, Some(msg), None) => msg,
|
||||
_ => panic!("unexpected response"),
|
||||
};
|
||||
|
||||
// device-2 : processes CookieReply (no response)
|
||||
match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
|
||||
match dev2.process(&mut rng, &msg_cookie, Some(src1)).unwrap() {
|
||||
(None, None, None) => (),
|
||||
_ => panic!("unexpected response"),
|
||||
}
|
||||
@@ -114,7 +114,7 @@ fn handshake_under_load() {
|
||||
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
|
||||
|
||||
// 7. device-2 : responds with noise response
|
||||
let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
|
||||
let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() {
|
||||
(Some(_), Some(msg), Some(kp)) => {
|
||||
assert_eq!(kp.initiator, false);
|
||||
(msg, kp)
|
||||
@@ -123,7 +123,7 @@ fn handshake_under_load() {
|
||||
};
|
||||
|
||||
// device-1 : process noise response
|
||||
let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
|
||||
let kp2 = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() {
|
||||
(Some(_), None, Some(kp)) => {
|
||||
assert_eq!(kp.initiator, true);
|
||||
kp
|
||||
|
||||
@@ -24,7 +24,7 @@ mod tests;
|
||||
pub use peer::Peer;
|
||||
|
||||
// represents a WireGuard interface
|
||||
pub use wireguard::Wireguard;
|
||||
pub use wireguard::WireGuard;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use types::dummy_keypair;
|
||||
|
||||
@@ -3,8 +3,8 @@ use super::timers::{Events, Timers};
|
||||
|
||||
use super::tun::Tun;
|
||||
use super::udp::UDP;
|
||||
use super::Wireguard;
|
||||
|
||||
use super::wireguard::WireGuard;
|
||||
use super::constants::REKEY_TIMEOUT;
|
||||
use super::workers::HandshakeJob;
|
||||
|
||||
@@ -23,7 +23,7 @@ pub struct PeerInner<T: Tun, B: UDP> {
|
||||
pub id: u64,
|
||||
|
||||
// wireguard device state
|
||||
pub wg: Wireguard<T, B>,
|
||||
pub wg: WireGuard<T, B>,
|
||||
|
||||
// handshake state
|
||||
pub walltime_last_handshake: Mutex<Option<SystemTime>>, // walltime for last handshake (for UAPI status)
|
||||
|
||||
@@ -50,6 +50,7 @@ mod tests {
|
||||
}))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn reset(&self) {
|
||||
self.0.send.lock().unwrap().clear();
|
||||
self.0.recv.lock().unwrap().clear();
|
||||
@@ -103,7 +104,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// wait for scheduling
|
||||
// wait for scheduling (VERY conservative)
|
||||
fn wait() {
|
||||
thread::sleep(Duration::from_millis(30));
|
||||
}
|
||||
|
||||
@@ -1,258 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use log::{debug, trace};
|
||||
|
||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
||||
|
||||
use crossbeam_channel::Receiver;
|
||||
use std::sync::atomic::Ordering;
|
||||
use zerocopy::{AsBytes, LayoutVerified};
|
||||
|
||||
use super::device::{DecryptionState, DeviceInner};
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::peer::PeerInner;
|
||||
use super::types::Callbacks;
|
||||
|
||||
use super::REJECT_AFTER_MESSAGES;
|
||||
|
||||
use super::super::types::KeyPair;
|
||||
use super::super::{tun, udp, Endpoint};
|
||||
|
||||
pub const SIZE_TAG: usize = 16;
|
||||
|
||||
pub struct JobEncryption {
|
||||
pub msg: Vec<u8>,
|
||||
pub keypair: Arc<KeyPair>,
|
||||
pub counter: u64,
|
||||
}
|
||||
|
||||
pub struct JobDecryption {
|
||||
pub msg: Vec<u8>,
|
||||
pub keypair: Arc<KeyPair>,
|
||||
}
|
||||
|
||||
pub enum JobParallel {
|
||||
Encryption(oneshot::Sender<JobEncryption>, JobEncryption),
|
||||
Decryption(oneshot::Sender<Option<JobDecryption>>, JobDecryption),
|
||||
}
|
||||
|
||||
#[allow(type_alias_bounds)]
|
||||
pub type JobInbound<E, C, T, B: udp::Writer<E>> = (
|
||||
Arc<DecryptionState<E, C, T, B>>,
|
||||
E,
|
||||
oneshot::Receiver<Option<JobDecryption>>,
|
||||
);
|
||||
|
||||
pub type JobOutbound = oneshot::Receiver<JobEncryption>;
|
||||
|
||||
/* TODO: Replace with run-queue
|
||||
*/
|
||||
pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
device: Arc<DeviceInner<E, C, T, B>>, // related device
|
||||
peer: Arc<PeerInner<E, C, T, B>>, // related peer
|
||||
receiver: Receiver<JobInbound<E, C, T, B>>,
|
||||
) {
|
||||
loop {
|
||||
// fetch job
|
||||
let (state, endpoint, rx) = match receiver.recv() {
|
||||
Ok(v) => v,
|
||||
_ => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
debug!("inbound worker: obtained job");
|
||||
|
||||
// wait for job to complete
|
||||
let _ = rx
|
||||
.map(|buf| {
|
||||
debug!("inbound worker: job complete");
|
||||
if let Some(buf) = buf {
|
||||
// cast transport header
|
||||
let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
|
||||
match LayoutVerified::new_from_prefix(&buf.msg[..]) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
debug!("inbound worker: failed to parse message");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
debug_assert!(
|
||||
packet.len() >= CHACHA20_POLY1305.tag_len(),
|
||||
"this should be checked earlier in the pipeline (decryption should fail)"
|
||||
);
|
||||
|
||||
// check for replay
|
||||
if !state.protector.lock().update(header.f_counter.get()) {
|
||||
debug!("inbound worker: replay detected");
|
||||
return;
|
||||
}
|
||||
|
||||
// check for confirms key
|
||||
if !state.confirmed.swap(true, Ordering::SeqCst) {
|
||||
debug!("inbound worker: message confirms key");
|
||||
peer.confirm_key(&state.keypair);
|
||||
}
|
||||
|
||||
// update endpoint
|
||||
*peer.endpoint.lock() = Some(endpoint);
|
||||
|
||||
// calculate length of IP packet + padding
|
||||
let length = packet.len() - SIZE_TAG;
|
||||
debug!("inbound worker: plaintext length = {}", length);
|
||||
|
||||
// check if should be written to TUN
|
||||
let mut sent = false;
|
||||
if length > 0 {
|
||||
if let Some(inner_len) = device.table.check_route(&peer, &packet[..length])
|
||||
{
|
||||
// TODO: Consider moving the cryptkey route check to parallel decryption worker
|
||||
debug_assert!(inner_len <= length, "should be validated earlier");
|
||||
if inner_len <= length {
|
||||
sent = match device.inbound.write(&packet[..inner_len]) {
|
||||
Err(e) => {
|
||||
debug!("failed to write inbound packet to TUN: {:?}", e);
|
||||
false
|
||||
}
|
||||
Ok(_) => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("inbound worker: received keepalive")
|
||||
}
|
||||
|
||||
// trigger callback
|
||||
C::recv(&peer.opaque, buf.msg.len(), sent, &buf.keypair);
|
||||
} else {
|
||||
debug!("inbound worker: authentication failure")
|
||||
}
|
||||
})
|
||||
.wait();
|
||||
}
|
||||
}
|
||||
|
||||
/* TODO: Replace with run-queue
|
||||
*/
|
||||
pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
peer: Arc<PeerInner<E, C, T, B>>,
|
||||
receiver: Receiver<JobOutbound>,
|
||||
) {
|
||||
loop {
|
||||
// fetch job
|
||||
let rx = match receiver.recv() {
|
||||
Ok(v) => v,
|
||||
_ => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
debug!("outbound worker: obtained job");
|
||||
|
||||
// wait for job to complete
|
||||
let _ = rx
|
||||
.map(|buf| {
|
||||
debug!("outbound worker: job complete");
|
||||
|
||||
// send to peer
|
||||
let xmit = peer.send(&buf.msg[..]).is_ok();
|
||||
|
||||
// trigger callback
|
||||
C::send(&peer.opaque, buf.msg.len(), xmit, &buf.keypair, buf.counter);
|
||||
})
|
||||
.wait();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn worker_parallel(receiver: Receiver<JobParallel>) {
|
||||
loop {
|
||||
// fetch next job
|
||||
let job = match receiver.recv() {
|
||||
Err(_) => {
|
||||
return;
|
||||
}
|
||||
Ok(val) => val,
|
||||
};
|
||||
trace!("parallel worker: obtained job");
|
||||
|
||||
// handle job
|
||||
match job {
|
||||
JobParallel::Encryption(tx, mut job) => {
|
||||
job.msg.extend([0u8; SIZE_TAG].iter());
|
||||
|
||||
// cast to header (should never fail)
|
||||
let (mut header, body): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
|
||||
LayoutVerified::new_from_prefix(&mut job.msg[..])
|
||||
.expect("earlier code should ensure that there is ample space");
|
||||
|
||||
// set header fields
|
||||
debug_assert!(
|
||||
job.counter < REJECT_AFTER_MESSAGES,
|
||||
"should be checked when assigning counters"
|
||||
);
|
||||
header.f_type.set(TYPE_TRANSPORT);
|
||||
header.f_receiver.set(job.keypair.send.id);
|
||||
header.f_counter.set(job.counter);
|
||||
|
||||
// create a nonce object
|
||||
let mut nonce = [0u8; 12];
|
||||
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
|
||||
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
|
||||
let nonce = Nonce::assume_unique_for_key(nonce);
|
||||
|
||||
// do the weird ring AEAD dance
|
||||
let key = LessSafeKey::new(
|
||||
UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(),
|
||||
);
|
||||
|
||||
// encrypt content of transport message in-place
|
||||
let end = body.len() - SIZE_TAG;
|
||||
let tag = key
|
||||
.seal_in_place_separate_tag(nonce, Aad::empty(), &mut body[..end])
|
||||
.unwrap();
|
||||
|
||||
// append tag
|
||||
body[end..].copy_from_slice(tag.as_ref());
|
||||
|
||||
// pass ownership
|
||||
let _ = tx.send(job);
|
||||
}
|
||||
JobParallel::Decryption(tx, mut job) => {
|
||||
// cast to header (could fail)
|
||||
let layout: Option<(LayoutVerified<&mut [u8], TransportHeader>, &mut [u8])> =
|
||||
LayoutVerified::new_from_prefix(&mut job.msg[..]);
|
||||
|
||||
let _ = tx.send(match layout {
|
||||
Some((header, body)) => {
|
||||
debug_assert_eq!(
|
||||
header.f_type.get(),
|
||||
TYPE_TRANSPORT,
|
||||
"type and reserved bits should be checked by message de-multiplexer"
|
||||
);
|
||||
if header.f_counter.get() < REJECT_AFTER_MESSAGES {
|
||||
// create a nonce object
|
||||
let mut nonce = [0u8; 12];
|
||||
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
|
||||
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
|
||||
let nonce = Nonce::assume_unique_for_key(nonce);
|
||||
|
||||
// do the weird ring AEAD dance
|
||||
let key = LessSafeKey::new(
|
||||
UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.recv.key[..])
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// attempt to open (and authenticate) the body
|
||||
match key.open_in_place(nonce, Aad::empty(), body) {
|
||||
Ok(_) => Some(job),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,6 @@
|
||||
use super::dummy;
|
||||
use super::wireguard::Wireguard;
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use hex;
|
||||
|
||||
use rand_chacha::ChaCha8Rng;
|
||||
use rand_core::{RngCore, SeedableRng};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
@@ -14,6 +8,9 @@ use x25519_dalek::{PublicKey, StaticSecret};
|
||||
use pnet::packet::ipv4::MutableIpv4Packet;
|
||||
use pnet::packet::ipv6::MutableIpv6Packet;
|
||||
|
||||
use super::dummy;
|
||||
use super::wireguard::WireGuard;
|
||||
|
||||
pub fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
|
||||
// expand pseudo random payload
|
||||
let mut rng: _ = ChaCha8Rng::seed_from_u64(id);
|
||||
@@ -58,10 +55,6 @@ fn init() {
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
}
|
||||
|
||||
fn wait() {
|
||||
thread::sleep(Duration::from_millis(500));
|
||||
}
|
||||
|
||||
/* Create and configure two matching pure instances of WireGuard
|
||||
*/
|
||||
#[test]
|
||||
@@ -71,12 +64,12 @@ fn test_pure_wireguard() {
|
||||
// create WG instances for dummy TUN devices
|
||||
|
||||
let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true);
|
||||
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(tun_writer1);
|
||||
let wg1: WireGuard<dummy::TunTest, dummy::PairBind> = WireGuard::new(tun_writer1);
|
||||
wg1.add_tun_reader(tun_reader1);
|
||||
wg1.up(1500);
|
||||
|
||||
let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true);
|
||||
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(tun_writer2);
|
||||
let wg2: WireGuard<dummy::TunTest, dummy::PairBind> = WireGuard::new(tun_writer2);
|
||||
wg2.add_tun_reader(tun_reader2);
|
||||
wg2.up(1500);
|
||||
|
||||
|
||||
@@ -58,33 +58,33 @@ pub struct WireguardInner<T: Tun, B: UDP> {
|
||||
|
||||
// handshake related state
|
||||
pub handshake: RwLock<handshake::Device>,
|
||||
pub last_under_load: AtomicUsize,
|
||||
pub pending: AtomicUsize, // num of pending handshake packets in queue
|
||||
pub last_under_load: Mutex<Instant>,
|
||||
pub pending: AtomicUsize, // number of pending handshake packets in queue
|
||||
pub queue: ParallelQueue<HandshakeJob<B::Endpoint>>,
|
||||
}
|
||||
|
||||
pub struct Wireguard<T: Tun, B: UDP> {
|
||||
pub struct WireGuard<T: Tun, B: UDP> {
|
||||
inner: Arc<WireguardInner<T, B>>,
|
||||
}
|
||||
|
||||
pub struct WaitCounter(StdMutex<usize>, Condvar);
|
||||
|
||||
impl<T: Tun, B: UDP> fmt::Display for Wireguard<T, B> {
|
||||
impl<T: Tun, B: UDP> fmt::Display for WireGuard<T, B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "wireguard({:x})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Tun, B: UDP> Deref for Wireguard<T, B> {
|
||||
impl<T: Tun, B: UDP> Deref for WireGuard<T, B> {
|
||||
type Target = WireguardInner<T, B>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Tun, B: UDP> Clone for Wireguard<T, B> {
|
||||
impl<T: Tun, B: UDP> Clone for WireGuard<T, B> {
|
||||
fn clone(&self) -> Self {
|
||||
Wireguard {
|
||||
WireGuard {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ impl WaitCounter {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Tun, B: UDP> Wireguard<T, B> {
|
||||
impl<T: Tun, B: UDP> WireGuard<T, B> {
|
||||
/// Brings the WireGuard device down.
|
||||
/// Usually called when the associated interface is brought down.
|
||||
///
|
||||
@@ -307,7 +307,7 @@ impl<T: Tun, B: UDP> Wireguard<T, B> {
|
||||
self.tun_readers.wait();
|
||||
}
|
||||
|
||||
pub fn new(writer: T::Writer) -> Wireguard<T, B> {
|
||||
pub fn new(writer: T::Writer) -> WireGuard<T, B> {
|
||||
// workers equal to number of physical cores
|
||||
let cpus = num_cpus::get();
|
||||
|
||||
@@ -318,14 +318,14 @@ impl<T: Tun, B: UDP> Wireguard<T, B> {
|
||||
let (tx, mut rxs) = ParallelQueue::new(cpus, 128);
|
||||
|
||||
// create arc to state
|
||||
let wg = Wireguard {
|
||||
let wg = WireGuard {
|
||||
inner: Arc::new(WireguardInner {
|
||||
enabled: RwLock::new(false),
|
||||
tun_readers: WaitCounter::new(),
|
||||
id: rng.gen(),
|
||||
mtu: AtomicUsize::new(0),
|
||||
peers: RwLock::new(HashMap::new()),
|
||||
last_under_load: AtomicUsize::new(0), // TODO
|
||||
last_under_load: Mutex::new(Instant::now() - TIME_HORIZON),
|
||||
send: RwLock::new(None),
|
||||
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
|
||||
pending: AtomicUsize::new(0),
|
||||
|
||||
@@ -25,7 +25,7 @@ use super::handshake::MAX_HANDSHAKE_MSG_SIZE;
|
||||
use super::handshake::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
|
||||
use super::router::{CAPACITY_MESSAGE_POSTFIX, SIZE_MESSAGE_PREFIX, TYPE_TRANSPORT};
|
||||
|
||||
use super::Wireguard;
|
||||
use super::wireguard::WireGuard;
|
||||
|
||||
pub enum HandshakeJob<E> {
|
||||
Message(Vec<u8>, E),
|
||||
@@ -54,7 +54,7 @@ const fn padding(size: usize, mtu: usize) -> usize {
|
||||
min(mtu, size + (pad - size % pad) % pad)
|
||||
}
|
||||
|
||||
pub fn tun_worker<T: Tun, B: UDP>(wg: &Wireguard<T, B>, reader: T::Reader) {
|
||||
pub fn tun_worker<T: Tun, B: UDP>(wg: &WireGuard<T, B>, reader: T::Reader) {
|
||||
loop {
|
||||
// create vector big enough for any transport message (based on MTU)
|
||||
let mtu = wg.mtu.load(Ordering::Relaxed);
|
||||
@@ -100,7 +100,7 @@ pub fn tun_worker<T: Tun, B: UDP>(wg: &Wireguard<T, B>, reader: T::Reader) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn udp_worker<T: Tun, B: UDP>(wg: &Wireguard<T, B>, reader: B::Reader) {
|
||||
pub fn udp_worker<T: Tun, B: UDP>(wg: &WireGuard<T, B>, reader: B::Reader) {
|
||||
let mut last_under_load = Instant::now() - TIME_HORIZON;
|
||||
|
||||
loop {
|
||||
@@ -160,7 +160,7 @@ pub fn udp_worker<T: Tun, B: UDP>(wg: &Wireguard<T, B>, reader: B::Reader) {
|
||||
}
|
||||
|
||||
pub fn handshake_worker<T: Tun, B: UDP>(
|
||||
wg: &Wireguard<T, B>,
|
||||
wg: &WireGuard<T, B>,
|
||||
rx: Receiver<HandshakeJob<B::Endpoint>>,
|
||||
) {
|
||||
debug!("{} : handshake worker, started", wg);
|
||||
@@ -170,30 +170,38 @@ pub fn handshake_worker<T: Tun, B: UDP>(
|
||||
|
||||
// process elements from the handshake queue
|
||||
for job in rx {
|
||||
// decrement pending pakcets (under_load)
|
||||
// check if under load
|
||||
let job: HandshakeJob<B::Endpoint> = job;
|
||||
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||
let pending = wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||
let mut under_load = false;
|
||||
|
||||
// demultiplex staged handshake jobs and handshake messages
|
||||
// immediate go under load if too many handshakes pending
|
||||
if pending > THRESHOLD_UNDER_LOAD {
|
||||
*wg.last_under_load.lock() = Instant::now();
|
||||
under_load = true;
|
||||
}
|
||||
|
||||
// remain under load for a while
|
||||
if !under_load {
|
||||
let elapsed = wg.last_under_load.lock().elapsed();
|
||||
if elapsed > DURATION_UNDER_LOAD {
|
||||
under_load = true;
|
||||
}
|
||||
}
|
||||
|
||||
// de-multiplex staged handshake jobs and handshake messages
|
||||
match job {
|
||||
HandshakeJob::Message(msg, src) => {
|
||||
// feed message to handshake device
|
||||
let src_validate = (&src).into_address(); // TODO avoid
|
||||
|
||||
// process message
|
||||
let device = wg.handshake.read();
|
||||
match device.process(
|
||||
&mut rng,
|
||||
&msg[..],
|
||||
None,
|
||||
/*
|
||||
if wg.under_load.load(Ordering::Relaxed) {
|
||||
debug!("{} : handshake worker, under load", wg);
|
||||
Some(&src_validate)
|
||||
if under_load {
|
||||
Some(src.into_address())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
*/
|
||||
) {
|
||||
Ok((pk, resp, keypair)) => {
|
||||
// send response (might be cookie reply or handshake response)
|
||||
|
||||
Reference in New Issue
Block a user