Moving away from peer threads
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#![feature(test)]
|
||||
#![feature(weak_into_raw)]
|
||||
#![allow(dead_code)]
|
||||
|
||||
use log;
|
||||
|
||||
@@ -6,9 +6,7 @@ use rand::Rng;
|
||||
use std::cmp::min;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -359,31 +359,9 @@ impl PlatformTun for LinuxTun {
|
||||
|
||||
// create PlatformTunMTU instance
|
||||
Ok((
|
||||
vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux
|
||||
vec![LinuxTunReader { fd }], // TODO: use multi-queue for Linux
|
||||
LinuxTunWriter { fd },
|
||||
LinuxTunStatus::new(req.name)?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::env;
|
||||
|
||||
fn is_root() -> bool {
|
||||
match env::var("USER") {
|
||||
Ok(val) => val == "root",
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tun_create() {
|
||||
if !is_root() {
|
||||
return;
|
||||
}
|
||||
let (readers, writers, mtu) = LinuxTun::create("test").unwrap();
|
||||
// TODO: test (any good idea how?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use crossbeam_channel::Sender;
|
||||
use x25519_dalek::PublicKey;
|
||||
|
||||
pub struct Peer<T: Tun, B: UDP> {
|
||||
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
|
||||
pub router: Arc<router::PeerHandle<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
|
||||
pub state: Arc<PeerInner<T, B>>,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::mpsc::sync_channel;
|
||||
use std::sync::mpsc::SyncSender;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
use log::debug;
|
||||
use spin::{Mutex, RwLock};
|
||||
use treebitmap::IpLookupTable;
|
||||
use zerocopy::LayoutVerified;
|
||||
|
||||
use super::anti_replay::AntiReplay;
|
||||
use super::constants::*;
|
||||
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::peer::{new_peer, Peer, PeerInner};
|
||||
use super::types::{Callbacks, RouterError};
|
||||
use super::workers::{worker_parallel, JobParallel};
|
||||
use super::SIZE_MESSAGE_PREFIX;
|
||||
|
||||
use super::route::get_route;
|
||||
|
||||
use super::super::{bind, tun, Endpoint, KeyPair};
|
||||
|
||||
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
// inbound writer (TUN)
|
||||
pub inbound: T,
|
||||
|
||||
// outbound writer (Bind)
|
||||
pub outbound: RwLock<(bool, Option<B>)>,
|
||||
|
||||
// routing
|
||||
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
|
||||
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
|
||||
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
|
||||
|
||||
// work queues
|
||||
pub queue_next: AtomicUsize, // next round-robin index
|
||||
pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
||||
}
|
||||
|
||||
pub struct EncryptionState {
|
||||
pub keypair: Arc<KeyPair>, // keypair
|
||||
pub nonce: u64, // next available nonce
|
||||
pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
|
||||
}
|
||||
|
||||
pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
pub keypair: Arc<KeyPair>,
|
||||
pub confirmed: AtomicBool,
|
||||
pub protector: Mutex<AntiReplay>,
|
||||
pub peer: Arc<PeerInner<E, C, T, B>>,
|
||||
pub death: Instant, // time when the key can no longer be used for decryption
|
||||
}
|
||||
|
||||
pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
|
||||
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> {
|
||||
fn drop(&mut self) {
|
||||
debug!("router: dropping device");
|
||||
|
||||
// drop all queues
|
||||
{
|
||||
let mut queues = self.state.queues.lock();
|
||||
while queues.pop().is_some() {}
|
||||
}
|
||||
|
||||
// join all worker threads
|
||||
while match self.handles.pop() {
|
||||
Some(handle) => {
|
||||
handle.thread().unpark();
|
||||
handle.join().unwrap();
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
} {}
|
||||
|
||||
debug!("router: device dropped");
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
|
||||
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
|
||||
// allocate shared device state
|
||||
let inner = DeviceInner {
|
||||
inbound: tun,
|
||||
outbound: RwLock::new((true, None)),
|
||||
queues: Mutex::new(Vec::with_capacity(num_workers)),
|
||||
queue_next: AtomicUsize::new(0),
|
||||
recv: RwLock::new(HashMap::new()),
|
||||
ipv4: RwLock::new(IpLookupTable::new()),
|
||||
ipv6: RwLock::new(IpLookupTable::new()),
|
||||
};
|
||||
|
||||
// start worker threads
|
||||
let mut threads = Vec::with_capacity(num_workers);
|
||||
for _ in 0..num_workers {
|
||||
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
|
||||
inner.queues.lock().push(tx);
|
||||
threads.push(thread::spawn(move || worker_parallel(rx)));
|
||||
}
|
||||
|
||||
// return exported device handle
|
||||
Device {
|
||||
state: Arc::new(inner),
|
||||
handles: threads,
|
||||
}
|
||||
}
|
||||
|
||||
/// Brings the router down.
|
||||
/// When the router is brought down it:
|
||||
/// - Prevents transmission of outbound messages.
|
||||
pub fn down(&self) {
|
||||
self.state.outbound.write().0 = false;
|
||||
}
|
||||
|
||||
/// Brints the router up
|
||||
/// When the router is brought up it enables the transmission of outbound messages.
|
||||
pub fn up(&self) {
|
||||
self.state.outbound.write().0 = true;
|
||||
}
|
||||
|
||||
/// A new secret key has been set for the device.
|
||||
/// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
|
||||
pub fn new_sk(&self) {}
|
||||
|
||||
/// Adds a new peer to the device
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A atomic ref. counted peer (with liftime matching the device)
|
||||
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
|
||||
new_peer(self.state.clone(), opaque)
|
||||
}
|
||||
|
||||
/// Cryptkey routes and sends a plaintext message (IP packet)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - msg: IP packet to crypt-key route
|
||||
///
|
||||
pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX);
|
||||
log::trace!(
|
||||
"Router, outbound packet = {}",
|
||||
hex::encode(&msg[SIZE_MESSAGE_PREFIX..])
|
||||
);
|
||||
|
||||
// ignore header prefix (for in-place transport message construction)
|
||||
let packet = &msg[SIZE_MESSAGE_PREFIX..];
|
||||
|
||||
// lookup peer based on IP packet destination address
|
||||
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?;
|
||||
|
||||
// schedule for encryption and transmission to peer
|
||||
if let Some(job) = peer.send_job(msg, true) {
|
||||
// add job to worker queue
|
||||
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
let queues = self.state.queues.lock();
|
||||
queues[idx % queues.len()].send(job).unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive an encrypted transport message
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - src: Source address of the packet
|
||||
/// - msg: Encrypted transport message
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
///
|
||||
pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
// parse / cast
|
||||
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
return Err(RouterError::MalformedTransportMessage);
|
||||
}
|
||||
};
|
||||
|
||||
let header: LayoutVerified<&[u8], TransportHeader> = header;
|
||||
|
||||
debug_assert!(
|
||||
header.f_type.get() == TYPE_TRANSPORT as u32,
|
||||
"this should be checked by the message type multiplexer"
|
||||
);
|
||||
|
||||
log::trace!(
|
||||
"Router, handle transport message: (receiver = {}, counter = {})",
|
||||
header.f_receiver,
|
||||
header.f_counter
|
||||
);
|
||||
|
||||
// lookup peer based on receiver id
|
||||
let dec = self.state.recv.read();
|
||||
let dec = dec
|
||||
.get(&header.f_receiver.get())
|
||||
.ok_or(RouterError::UnknownReceiverId)?;
|
||||
|
||||
// schedule for decryption and TUN write
|
||||
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
|
||||
// add job to worker queue
|
||||
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
let queues = self.state.queues.lock();
|
||||
queues[idx % queues.len()].send(job).unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set outbound writer
|
||||
///
|
||||
///
|
||||
pub fn set_outbound_writer(&self, new: B) {
|
||||
self.state.outbound.write().1 = Some(new);
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Deref;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::mpsc::sync_channel;
|
||||
use std::sync::mpsc::SyncSender;
|
||||
use std::sync::mpsc::{Receiver, SyncSender};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
@@ -11,18 +12,61 @@ use spin::{Mutex, RwLock};
|
||||
use zerocopy::LayoutVerified;
|
||||
|
||||
use super::anti_replay::AntiReplay;
|
||||
use super::constants::*;
|
||||
use super::pool::Job;
|
||||
|
||||
use super::inbound;
|
||||
use super::outbound;
|
||||
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::peer::{new_peer, Peer, PeerInner};
|
||||
use super::peer::{new_peer, Peer, PeerHandle};
|
||||
use super::types::{Callbacks, RouterError};
|
||||
use super::workers::{worker_parallel, JobParallel};
|
||||
use super::SIZE_MESSAGE_PREFIX;
|
||||
|
||||
use super::route::RoutingTable;
|
||||
|
||||
use super::super::{tun, udp, Endpoint, KeyPair};
|
||||
|
||||
pub struct ParallelQueue<T> {
|
||||
next: AtomicUsize, // next round-robin index
|
||||
queues: Vec<Mutex<SyncSender<T>>>, // work queues (1 per thread)
|
||||
}
|
||||
|
||||
impl<T> ParallelQueue<T> {
|
||||
fn new(queues: usize) -> (Vec<Receiver<T>>, Self) {
|
||||
let mut rxs = vec![];
|
||||
let mut txs = vec![];
|
||||
|
||||
for _ in 0..queues {
|
||||
let (tx, rx) = sync_channel(128);
|
||||
txs.push(Mutex::new(tx));
|
||||
rxs.push(rx);
|
||||
}
|
||||
|
||||
(
|
||||
rxs,
|
||||
ParallelQueue {
|
||||
next: AtomicUsize::new(0),
|
||||
queues: txs,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn send(&self, v: T) {
|
||||
let len = self.queues.len();
|
||||
let idx = self.next.fetch_add(1, Ordering::SeqCst);
|
||||
let que = self.queues[idx % len].lock();
|
||||
que.send(v).unwrap();
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
for i in 0..self.queues.len() {
|
||||
let (tx, _) = sync_channel(0);
|
||||
let queue = &self.queues[i];
|
||||
*queue.lock() = tx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
|
||||
// inbound writer (TUN)
|
||||
pub inbound: T,
|
||||
@@ -32,11 +76,11 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer
|
||||
|
||||
// routing
|
||||
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
|
||||
pub table: RoutingTable<PeerInner<E, C, T, B>>,
|
||||
pub table: RoutingTable<Peer<E, C, T, B>>,
|
||||
|
||||
// work queues
|
||||
pub queue_next: AtomicUsize, // next round-robin index
|
||||
pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
||||
pub outbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
|
||||
pub inbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
|
||||
}
|
||||
|
||||
pub struct EncryptionState {
|
||||
@@ -49,24 +93,53 @@ pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Wr
|
||||
pub keypair: Arc<KeyPair>,
|
||||
pub confirmed: AtomicBool,
|
||||
pub protector: Mutex<AntiReplay>,
|
||||
pub peer: Arc<PeerInner<E, C, T, B>>,
|
||||
pub peer: Peer<E, C, T, B>,
|
||||
pub death: Instant, // time when the key can no longer be used for decryption
|
||||
}
|
||||
|
||||
pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
|
||||
state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
|
||||
inner: Arc<DeviceInner<E, C, T, B>>,
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Device<E, C, T, B> {
|
||||
fn clone(&self) -> Self {
|
||||
Device {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq
|
||||
for Device<E, C, T, B>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
Arc::ptr_eq(&self.inner, &other.inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Device<E, C, T, B> {}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Device<E, C, T, B> {
|
||||
type Target = DeviceInner<E, C, T, B>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DeviceHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
|
||||
state: Device<E, C, T, B>, // reference to device state
|
||||
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Device<E, C, T, B> {
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
|
||||
for DeviceHandle<E, C, T, B>
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
debug!("router: dropping device");
|
||||
|
||||
// drop all queues
|
||||
{
|
||||
let mut queues = self.state.queues.lock();
|
||||
while queues.pop().is_some() {}
|
||||
}
|
||||
// close worker queues
|
||||
self.state.outbound_queue.close();
|
||||
self.state.inbound_queue.close();
|
||||
|
||||
// join all worker threads
|
||||
while match self.handles.pop() {
|
||||
@@ -82,14 +155,16 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Devi
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, T, B> {
|
||||
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
|
||||
pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> {
|
||||
// allocate shared device state
|
||||
let (mut outrx, outbound_queue) = ParallelQueue::new(num_workers);
|
||||
let (mut inrx, inbound_queue) = ParallelQueue::new(num_workers);
|
||||
let inner = DeviceInner {
|
||||
inbound: tun,
|
||||
inbound_queue,
|
||||
outbound: RwLock::new((true, None)),
|
||||
queues: Mutex::new(Vec::with_capacity(num_workers)),
|
||||
queue_next: AtomicUsize::new(0),
|
||||
outbound_queue,
|
||||
recv: RwLock::new(HashMap::new()),
|
||||
table: RoutingTable::new(),
|
||||
};
|
||||
@@ -97,14 +172,20 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
|
||||
// start worker threads
|
||||
let mut threads = Vec::with_capacity(num_workers);
|
||||
for _ in 0..num_workers {
|
||||
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
|
||||
inner.queues.lock().push(tx);
|
||||
threads.push(thread::spawn(move || worker_parallel(rx)));
|
||||
let rx = inrx.pop().unwrap();
|
||||
threads.push(thread::spawn(move || inbound::worker(rx)));
|
||||
}
|
||||
|
||||
for _ in 0..num_workers {
|
||||
let rx = outrx.pop().unwrap();
|
||||
threads.push(thread::spawn(move || outbound::worker(rx)));
|
||||
}
|
||||
|
||||
// return exported device handle
|
||||
Device {
|
||||
state: Arc::new(inner),
|
||||
DeviceHandle {
|
||||
state: Device {
|
||||
inner: Arc::new(inner),
|
||||
},
|
||||
handles: threads,
|
||||
}
|
||||
}
|
||||
@@ -131,7 +212,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
|
||||
/// # Returns
|
||||
///
|
||||
/// A atomic ref. counted peer (with liftime matching the device)
|
||||
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
|
||||
pub fn new_peer(&self, opaque: C::Opaque) -> PeerHandle<E, C, T, B> {
|
||||
new_peer(self.state.clone(), opaque)
|
||||
}
|
||||
|
||||
@@ -160,10 +241,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
|
||||
|
||||
// schedule for encryption and transmission to peer
|
||||
if let Some(job) = peer.send_job(msg, true) {
|
||||
// add job to worker queue
|
||||
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
let queues = self.state.queues.lock();
|
||||
queues[idx % queues.len()].send(job).unwrap();
|
||||
self.state.outbound_queue.send(job);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -209,10 +287,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
|
||||
|
||||
// schedule for decryption and TUN write
|
||||
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
|
||||
// add job to worker queue
|
||||
let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
let queues = self.state.queues.lock();
|
||||
queues[idx % queues.len()].send(job).unwrap();
|
||||
self.state.inbound_queue.send(job);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
172
src/wireguard/router/inbound.rs
Normal file
172
src/wireguard/router/inbound.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
use super::device::DecryptionState;
|
||||
use super::messages::TransportHeader;
|
||||
use super::peer::Peer;
|
||||
use super::pool::*;
|
||||
use super::types::Callbacks;
|
||||
use super::{tun, udp, Endpoint};
|
||||
|
||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
||||
use zerocopy::{AsBytes, LayoutVerified};
|
||||
|
||||
use std::mem;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::mpsc::Receiver;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub const SIZE_TAG: usize = 16;
|
||||
|
||||
pub struct Inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
|
||||
msg: Vec<u8>,
|
||||
failed: bool,
|
||||
state: Arc<DecryptionState<E, C, T, B>>,
|
||||
endpoint: Option<E>,
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, T, B> {
|
||||
pub fn new(
|
||||
msg: Vec<u8>,
|
||||
state: Arc<DecryptionState<E, C, T, B>>,
|
||||
endpoint: E,
|
||||
) -> Inbound<E, C, T, B> {
|
||||
Inbound {
|
||||
msg,
|
||||
state,
|
||||
failed: false,
|
||||
endpoint: Some(endpoint),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
peer: &Peer<E, C, T, B>,
|
||||
body: &mut Inbound<E, C, T, B>,
|
||||
) {
|
||||
// cast to header followed by payload
|
||||
let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
|
||||
match LayoutVerified::new_from_prefix(&mut body.msg[..]) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
log::debug!("inbound worker: failed to parse message");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// authenticate and decrypt payload
|
||||
{
|
||||
// create nonce object
|
||||
let mut nonce = [0u8; 12];
|
||||
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
|
||||
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
|
||||
let nonce = Nonce::assume_unique_for_key(nonce);
|
||||
|
||||
// do the weird ring AEAD dance
|
||||
let key = LessSafeKey::new(
|
||||
UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(),
|
||||
);
|
||||
|
||||
// attempt to open (and authenticate) the body
|
||||
match key.open_in_place(nonce, Aad::empty(), packet) {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
// fault and return early
|
||||
body.failed = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cryptokey route and strip padding
|
||||
let inner_len = {
|
||||
let length = packet.len() - SIZE_TAG;
|
||||
if length > 0 {
|
||||
peer.device.table.check_route(&peer, &packet[..length])
|
||||
} else {
|
||||
Some(0)
|
||||
}
|
||||
};
|
||||
|
||||
// truncate to remove tag
|
||||
match inner_len {
|
||||
None => {
|
||||
body.failed = true;
|
||||
}
|
||||
Some(len) => {
|
||||
body.msg.truncate(mem::size_of::<TransportHeader>() + len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
peer: &Peer<E, C, T, B>,
|
||||
body: &mut Inbound<E, C, T, B>,
|
||||
) {
|
||||
// decryption failed, return early
|
||||
if body.failed {
|
||||
return;
|
||||
}
|
||||
|
||||
// cast transport header
|
||||
let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
|
||||
match LayoutVerified::new_from_prefix(&body.msg[..]) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
log::debug!("inbound worker: failed to parse message");
|
||||
return;
|
||||
}
|
||||
};
|
||||
debug_assert!(
|
||||
packet.len() >= CHACHA20_POLY1305.tag_len(),
|
||||
"this should be checked earlier in the pipeline (decryption should fail)"
|
||||
);
|
||||
|
||||
// check for replay
|
||||
if !body.state.protector.lock().update(header.f_counter.get()) {
|
||||
log::debug!("inbound worker: replay detected");
|
||||
return;
|
||||
}
|
||||
|
||||
// check for confirms key
|
||||
if !body.state.confirmed.swap(true, Ordering::SeqCst) {
|
||||
log::debug!("inbound worker: message confirms key");
|
||||
peer.confirm_key(&body.state.keypair);
|
||||
}
|
||||
|
||||
// update endpoint
|
||||
*peer.endpoint.lock() = body.endpoint.take();
|
||||
|
||||
// calculate length of IP packet + padding
|
||||
let length = packet.len() - SIZE_TAG;
|
||||
log::debug!("inbound worker: plaintext length = {}", length);
|
||||
|
||||
// check if should be written to TUN
|
||||
let mut sent = false;
|
||||
if length > 0 {
|
||||
sent = match peer.device.inbound.write(&packet[..]) {
|
||||
Err(e) => {
|
||||
log::debug!("failed to write inbound packet to TUN: {:?}", e);
|
||||
false
|
||||
}
|
||||
Ok(_) => true,
|
||||
}
|
||||
} else {
|
||||
log::debug!("inbound worker: received keepalive")
|
||||
}
|
||||
|
||||
// trigger callback
|
||||
C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
peer: &Peer<E, C, T, B>,
|
||||
) -> &InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>> {
|
||||
&peer.inbound
|
||||
}
|
||||
|
||||
pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>,
|
||||
) {
|
||||
worker_template(receiver, parallel, sequential, queue)
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
mod anti_replay;
|
||||
mod constants;
|
||||
mod device;
|
||||
mod inbound;
|
||||
mod ip;
|
||||
mod messages;
|
||||
mod outbound;
|
||||
mod peer;
|
||||
mod pool;
|
||||
mod route;
|
||||
mod types;
|
||||
mod workers;
|
||||
|
||||
// mod workers;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
@@ -16,15 +20,17 @@ use std::mem;
|
||||
|
||||
use super::constants::REJECT_AFTER_MESSAGES;
|
||||
use super::types::*;
|
||||
use super::{tun, udp, Endpoint};
|
||||
|
||||
pub const SIZE_TAG: usize = 16;
|
||||
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
|
||||
pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG;
|
||||
pub const CAPACITY_MESSAGE_POSTFIX: usize = SIZE_TAG;
|
||||
|
||||
pub const fn message_data_len(payload: usize) -> usize {
|
||||
payload + mem::size_of::<TransportHeader>() + workers::SIZE_TAG
|
||||
payload + mem::size_of::<TransportHeader>() + SIZE_TAG
|
||||
}
|
||||
|
||||
pub use device::Device;
|
||||
pub use device::DeviceHandle as Device;
|
||||
pub use messages::TYPE_TRANSPORT;
|
||||
pub use peer::Peer;
|
||||
pub use peer::PeerHandle;
|
||||
pub use types::Callbacks;
|
||||
|
||||
104
src/wireguard/router/outbound.rs
Normal file
104
src/wireguard/router/outbound.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::peer::Peer;
|
||||
use super::pool::*;
|
||||
use super::types::Callbacks;
|
||||
use super::KeyPair;
|
||||
use super::REJECT_AFTER_MESSAGES;
|
||||
use super::{tun, udp, Endpoint};
|
||||
|
||||
use std::sync::mpsc::Receiver;
|
||||
use std::sync::Arc;
|
||||
|
||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
||||
use zerocopy::{AsBytes, LayoutVerified};
|
||||
|
||||
pub const SIZE_TAG: usize = 16;
|
||||
|
||||
pub struct Outbound {
|
||||
msg: Vec<u8>,
|
||||
keypair: Arc<KeyPair>,
|
||||
counter: u64,
|
||||
}
|
||||
|
||||
impl Outbound {
|
||||
pub fn new(msg: Vec<u8>, keypair: Arc<KeyPair>, counter: u64) -> Outbound {
|
||||
Outbound {
|
||||
msg,
|
||||
keypair,
|
||||
counter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
_peer: &Peer<E, C, T, B>,
|
||||
body: &mut Outbound,
|
||||
) {
|
||||
// make space for the tag
|
||||
body.msg.extend([0u8; SIZE_TAG].iter());
|
||||
|
||||
// cast to header (should never fail)
|
||||
let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
|
||||
LayoutVerified::new_from_prefix(&mut body.msg[..])
|
||||
.expect("earlier code should ensure that there is ample space");
|
||||
|
||||
// set header fields
|
||||
debug_assert!(
|
||||
body.counter < REJECT_AFTER_MESSAGES,
|
||||
"should be checked when assigning counters"
|
||||
);
|
||||
header.f_type.set(TYPE_TRANSPORT);
|
||||
header.f_receiver.set(body.keypair.send.id);
|
||||
header.f_counter.set(body.counter);
|
||||
|
||||
// create a nonce object
|
||||
let mut nonce = [0u8; 12];
|
||||
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
|
||||
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
|
||||
let nonce = Nonce::assume_unique_for_key(nonce);
|
||||
|
||||
// do the weird ring AEAD dance
|
||||
let key =
|
||||
LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap());
|
||||
|
||||
// encrypt content of transport message in-place
|
||||
let end = packet.len() - SIZE_TAG;
|
||||
let tag = key
|
||||
.seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
|
||||
.unwrap();
|
||||
|
||||
// append tag
|
||||
packet[end..].copy_from_slice(tag.as_ref());
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
peer: &Peer<E, C, T, B>,
|
||||
body: &mut Outbound,
|
||||
) {
|
||||
// send to peer
|
||||
let xmit = peer.send(&body.msg[..]).is_ok();
|
||||
|
||||
// trigger callback
|
||||
C::send(
|
||||
&peer.opaque,
|
||||
body.msg.len(),
|
||||
xmit,
|
||||
&body.keypair,
|
||||
body.counter,
|
||||
);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
peer: &Peer<E, C, T, B>,
|
||||
) -> &InorderQueue<Peer<E, C, T, B>, Outbound> {
|
||||
&peer.outbound
|
||||
}
|
||||
|
||||
pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>,
|
||||
) {
|
||||
worker_template(receiver, parallel, sequential, queue)
|
||||
}
|
||||
@@ -2,10 +2,7 @@ use std::mem;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::ops::Deref;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::mpsc::{sync_channel, SyncSender};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
use arraydeque::{ArrayDeque, Wrapping};
|
||||
use log::debug;
|
||||
@@ -16,18 +13,18 @@ use super::super::{tun, udp, Endpoint, KeyPair};
|
||||
|
||||
use super::anti_replay::AntiReplay;
|
||||
use super::device::DecryptionState;
|
||||
use super::device::DeviceInner;
|
||||
use super::device::Device;
|
||||
use super::device::EncryptionState;
|
||||
use super::messages::TransportHeader;
|
||||
|
||||
use futures::*;
|
||||
|
||||
use super::workers::{worker_inbound, worker_outbound};
|
||||
use super::workers::{JobDecryption, JobEncryption, JobInbound, JobOutbound, JobParallel};
|
||||
use super::SIZE_MESSAGE_PREFIX;
|
||||
|
||||
use super::constants::*;
|
||||
use super::types::{Callbacks, RouterError};
|
||||
use super::SIZE_MESSAGE_PREFIX;
|
||||
|
||||
// worker pool related
|
||||
use super::inbound::Inbound;
|
||||
use super::outbound::Outbound;
|
||||
use super::pool::{InorderQueue, Job};
|
||||
|
||||
pub struct KeyWheel {
|
||||
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
|
||||
@@ -37,10 +34,10 @@ pub struct KeyWheel {
|
||||
}
|
||||
|
||||
pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
|
||||
pub device: Arc<DeviceInner<E, C, T, B>>,
|
||||
pub device: Device<E, C, T, B>,
|
||||
pub opaque: C::Opaque,
|
||||
pub outbound: Mutex<SyncSender<JobOutbound>>,
|
||||
pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>,
|
||||
pub outbound: InorderQueue<Peer<E, C, T, B>, Outbound>,
|
||||
pub inbound: InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>>,
|
||||
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>,
|
||||
pub keys: Mutex<KeyWheel>,
|
||||
pub ekey: Mutex<Option<EncryptionState>>,
|
||||
@@ -48,16 +45,42 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E
|
||||
}
|
||||
|
||||
pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
|
||||
state: Arc<PeerInner<E, C, T, B>>,
|
||||
thread_outbound: Option<thread::JoinHandle<()>>,
|
||||
thread_inbound: Option<thread::JoinHandle<()>>,
|
||||
inner: Arc<PeerInner<E, C, T, B>>,
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> {
|
||||
type Target = Arc<PeerInner<E, C, T, B>>;
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Peer<E, C, T, B> {
|
||||
fn clone(&self) -> Self {
|
||||
Peer {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for Peer<E, C, T, B> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
Arc::ptr_eq(&self.inner, &other.inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> {
|
||||
type Target = PeerInner<E, C, T, B>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
|
||||
peer: Peer<E, C, T, B>,
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref
|
||||
for PeerHandle<E, C, T, B>
|
||||
{
|
||||
type Target = PeerInner<E, C, T, B>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.peer
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,37 +95,24 @@ impl EncryptionState {
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DecryptionState<E, C, T, B> {
|
||||
fn new(
|
||||
peer: &Arc<PeerInner<E, C, T, B>>,
|
||||
keypair: &Arc<KeyPair>,
|
||||
) -> DecryptionState<E, C, T, B> {
|
||||
fn new(peer: Peer<E, C, T, B>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> {
|
||||
DecryptionState {
|
||||
confirmed: AtomicBool::new(keypair.initiator),
|
||||
keypair: keypair.clone(),
|
||||
protector: spin::Mutex::new(AntiReplay::new()),
|
||||
peer: peer.clone(),
|
||||
death: keypair.birth + REJECT_AFTER_TIME,
|
||||
peer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer<E, C, T, B> {
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for PeerHandle<E, C, T, B> {
|
||||
fn drop(&mut self) {
|
||||
let peer = &self.state;
|
||||
let peer = &self.peer;
|
||||
|
||||
// remove from cryptkey router
|
||||
|
||||
self.state.device.table.remove(peer);
|
||||
|
||||
// drop channels
|
||||
|
||||
mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0);
|
||||
mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0);
|
||||
|
||||
// join with workers
|
||||
|
||||
mem::replace(&mut self.thread_inbound, None).map(|v| v.join());
|
||||
mem::replace(&mut self.thread_outbound, None).map(|v| v.join());
|
||||
self.peer.device.table.remove(peer);
|
||||
|
||||
// release ids from the receiver map
|
||||
|
||||
@@ -134,20 +144,18 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer
|
||||
}
|
||||
|
||||
pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
device: Arc<DeviceInner<E, C, T, B>>,
|
||||
device: Device<E, C, T, B>,
|
||||
opaque: C::Opaque,
|
||||
) -> Peer<E, C, T, B> {
|
||||
let (out_tx, out_rx) = sync_channel(128);
|
||||
let (in_tx, in_rx) = sync_channel(128);
|
||||
|
||||
) -> PeerHandle<E, C, T, B> {
|
||||
// allocate peer object
|
||||
let peer = {
|
||||
let device = device.clone();
|
||||
Arc::new(PeerInner {
|
||||
Peer {
|
||||
inner: Arc::new(PeerInner {
|
||||
opaque,
|
||||
device,
|
||||
inbound: Mutex::new(in_tx),
|
||||
outbound: Mutex::new(out_tx),
|
||||
inbound: InorderQueue::new(),
|
||||
outbound: InorderQueue::new(),
|
||||
ekey: spin::Mutex::new(None),
|
||||
endpoint: spin::Mutex::new(None),
|
||||
keys: spin::Mutex::new(KeyWheel {
|
||||
@@ -157,27 +165,11 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
|
||||
retired: vec![],
|
||||
}),
|
||||
staged_packets: spin::Mutex::new(ArrayDeque::new()),
|
||||
})
|
||||
};
|
||||
|
||||
// spawn outbound thread
|
||||
let thread_inbound = {
|
||||
let peer = peer.clone();
|
||||
thread::spawn(move || worker_outbound(peer, out_rx))
|
||||
};
|
||||
|
||||
// spawn inbound thread
|
||||
let thread_outbound = {
|
||||
let peer = peer.clone();
|
||||
let device = device.clone();
|
||||
thread::spawn(move || worker_inbound(device, peer, in_rx))
|
||||
};
|
||||
|
||||
Peer {
|
||||
state: peer,
|
||||
thread_inbound: Some(thread_inbound),
|
||||
thread_outbound: Some(thread_outbound),
|
||||
}),
|
||||
}
|
||||
};
|
||||
|
||||
PeerHandle { peer }
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, C, T, B> {
|
||||
@@ -210,7 +202,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
|
||||
None => Err(RouterError::NoEndpoint),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> {
|
||||
// Transmit all staged packets
|
||||
fn send_staged(&self) -> bool {
|
||||
debug!("peer.send_staged");
|
||||
@@ -230,16 +224,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
|
||||
// Treat the msg as the payload of a transport message
|
||||
// Unlike device.send, peer.send_raw does not buffer messages when a key is not available.
|
||||
fn send_raw(&self, msg: Vec<u8>) -> bool {
|
||||
debug!("peer.send_raw");
|
||||
log::debug!("peer.send_raw");
|
||||
match self.send_job(msg, false) {
|
||||
Some(job) => {
|
||||
self.device.outbound_queue.send(job);
|
||||
debug!("send_raw: got obtained send_job");
|
||||
let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
let queues = self.device.queues.lock();
|
||||
match queues[index % queues.len()].send(job) {
|
||||
Ok(_) => true,
|
||||
Err(_) => false,
|
||||
}
|
||||
true
|
||||
}
|
||||
None => false,
|
||||
}
|
||||
@@ -285,16 +275,11 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
|
||||
src: E,
|
||||
dec: Arc<DecryptionState<E, C, T, B>>,
|
||||
msg: Vec<u8>,
|
||||
) -> Option<JobParallel> {
|
||||
let (tx, rx) = oneshot();
|
||||
let keypair = dec.keypair.clone();
|
||||
match self.inbound.lock().try_send((dec, src, rx)) {
|
||||
Ok(_) => Some(JobParallel::Decryption(tx, JobDecryption { msg, keypair })),
|
||||
Err(_) => None,
|
||||
}
|
||||
) -> Option<Job<Self, Inbound<E, C, T, B>>> {
|
||||
Some(Job::new(self.clone(), Inbound::new(msg, dec, src)))
|
||||
}
|
||||
|
||||
pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<JobParallel> {
|
||||
pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<Job<Self, Outbound>> {
|
||||
debug!("peer.send_job");
|
||||
debug_assert!(
|
||||
msg.len() >= mem::size_of::<TransportHeader>(),
|
||||
@@ -337,22 +322,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
|
||||
}?;
|
||||
|
||||
// add job to in-order queue and return sender to device for inclusion in worker pool
|
||||
let (tx, rx) = oneshot();
|
||||
match self.outbound.lock().try_send(rx) {
|
||||
Ok(_) => Some(JobParallel::Encryption(
|
||||
tx,
|
||||
JobEncryption {
|
||||
msg,
|
||||
counter,
|
||||
keypair,
|
||||
},
|
||||
)),
|
||||
Err(_) => None,
|
||||
}
|
||||
let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter));
|
||||
self.outbound.send(job.clone());
|
||||
Some(job)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> {
|
||||
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> {
|
||||
/// Set the endpoint of the peer
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -365,7 +341,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
/// as sockets should be "unsticked" when manually updating the endpoint
|
||||
pub fn set_endpoint(&self, endpoint: E) {
|
||||
debug!("peer.set_endpoint");
|
||||
*self.state.endpoint.lock() = Some(endpoint);
|
||||
*self.peer.endpoint.lock() = Some(endpoint);
|
||||
}
|
||||
|
||||
/// Returns the current endpoint of the peer (for configuration)
|
||||
@@ -375,11 +351,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
/// Does not convey potential "sticky socket" information
|
||||
pub fn get_endpoint(&self) -> Option<SocketAddr> {
|
||||
debug!("peer.get_endpoint");
|
||||
self.state
|
||||
.endpoint
|
||||
.lock()
|
||||
.as_ref()
|
||||
.map(|e| e.into_address())
|
||||
self.peer.endpoint.lock().as_ref().map(|e| e.into_address())
|
||||
}
|
||||
|
||||
/// Zero all key-material related to the peer
|
||||
@@ -387,7 +359,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
debug!("peer.zero_keys");
|
||||
|
||||
let mut release: Vec<u32> = Vec::with_capacity(3);
|
||||
let mut keys = self.state.keys.lock();
|
||||
let mut keys = self.peer.keys.lock();
|
||||
|
||||
// update key-wheel
|
||||
|
||||
@@ -398,14 +370,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
|
||||
// update inbound "recv" map
|
||||
{
|
||||
let mut recv = self.state.device.recv.write();
|
||||
let mut recv = self.peer.device.recv.write();
|
||||
for id in release {
|
||||
recv.remove(&id);
|
||||
}
|
||||
}
|
||||
|
||||
// clear encryption state
|
||||
*self.state.ekey.lock() = None;
|
||||
*self.peer.ekey.lock() = None;
|
||||
}
|
||||
|
||||
pub fn down(&self) {
|
||||
@@ -436,13 +408,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
let initiator = new.initiator;
|
||||
let release = {
|
||||
let new = Arc::new(new);
|
||||
let mut keys = self.state.keys.lock();
|
||||
let mut keys = self.peer.keys.lock();
|
||||
let mut release = mem::replace(&mut keys.retired, vec![]);
|
||||
|
||||
// update key-wheel
|
||||
if new.initiator {
|
||||
// start using key for encryption
|
||||
*self.state.ekey.lock() = Some(EncryptionState::new(&new));
|
||||
*self.peer.ekey.lock() = Some(EncryptionState::new(&new));
|
||||
|
||||
// move current into previous
|
||||
keys.previous = keys.current.as_ref().map(|v| v.clone());
|
||||
@@ -456,7 +428,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
// update incoming packet id map
|
||||
{
|
||||
debug!("peer.add_keypair: updating inbound id map");
|
||||
let mut recv = self.state.device.recv.write();
|
||||
let mut recv = self.peer.device.recv.write();
|
||||
|
||||
// purge recv map of previous id
|
||||
keys.previous.as_ref().map(|k| {
|
||||
@@ -468,7 +440,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
debug_assert!(!recv.contains_key(&new.recv.id));
|
||||
recv.insert(
|
||||
new.recv.id,
|
||||
Arc::new(DecryptionState::new(&self.state, &new)),
|
||||
Arc::new(DecryptionState::new(self.peer.clone(), &new)),
|
||||
);
|
||||
}
|
||||
release
|
||||
@@ -476,10 +448,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
|
||||
// schedule confirmation
|
||||
if initiator {
|
||||
debug_assert!(self.state.ekey.lock().is_some());
|
||||
debug_assert!(self.peer.ekey.lock().is_some());
|
||||
debug!("peer.add_keypair: is initiator, must confirm the key");
|
||||
// attempt to confirm using staged packets
|
||||
if !self.state.send_staged() {
|
||||
if !self.peer.send_staged() {
|
||||
// fall back to keepalive packet
|
||||
let ok = self.send_keepalive();
|
||||
debug!(
|
||||
@@ -499,7 +471,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
|
||||
pub fn send_keepalive(&self) -> bool {
|
||||
debug!("peer.send_keepalive");
|
||||
self.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
|
||||
self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
|
||||
}
|
||||
|
||||
/// Map a subnet to the peer
|
||||
@@ -517,10 +489,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
/// If an identical value already exists as part of a prior peer,
|
||||
/// the allowed IP entry will be removed from that peer and added to this peer.
|
||||
pub fn add_allowed_ip(&self, ip: IpAddr, masklen: u32) {
|
||||
self.state
|
||||
self.peer
|
||||
.device
|
||||
.table
|
||||
.insert(ip, masklen, self.state.clone())
|
||||
.insert(ip, masklen, self.peer.clone())
|
||||
}
|
||||
|
||||
/// List subnets mapped to the peer
|
||||
@@ -529,23 +501,21 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
|
||||
///
|
||||
/// A vector of subnets, represented by as mask/size
|
||||
pub fn list_allowed_ips(&self) -> Vec<(IpAddr, u32)> {
|
||||
self.state.device.table.list(&self.state)
|
||||
self.peer.device.table.list(&self.peer)
|
||||
}
|
||||
|
||||
/// Clear subnets mapped to the peer.
|
||||
/// After the call, no subnets will be cryptkey routed to the peer.
|
||||
/// Used for the UAPI command "replace_allowed_ips=true"
|
||||
pub fn remove_allowed_ips(&self) {
|
||||
self.state.device.table.remove(&self.state)
|
||||
self.peer.device.table.remove(&self.peer)
|
||||
}
|
||||
|
||||
pub fn clear_src(&self) {
|
||||
(*self.state.endpoint.lock())
|
||||
.as_mut()
|
||||
.map(|e| e.clear_src());
|
||||
(*self.peer.endpoint.lock()).as_mut().map(|e| e.clear_src());
|
||||
}
|
||||
|
||||
pub fn purge_staged_packets(&self) {
|
||||
self.state.staged_packets.lock().clear();
|
||||
self.peer.staged_packets.lock().clear();
|
||||
}
|
||||
}
|
||||
|
||||
132
src/wireguard/router/pool.rs
Normal file
132
src/wireguard/router/pool.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
use arraydeque::ArrayDeque;
|
||||
use spin::{Mutex, MutexGuard};
|
||||
use std::sync::mpsc::Receiver;
|
||||
use std::sync::Arc;
|
||||
|
||||
const INORDER_QUEUE_SIZE: usize = 64;
|
||||
|
||||
pub struct InnerJob<P, B> {
|
||||
// peer (used by worker to schedule/handle inorder queue),
|
||||
// when the peer is None, the job is complete
|
||||
peer: Option<P>,
|
||||
pub body: B,
|
||||
}
|
||||
|
||||
pub struct Job<P, B> {
|
||||
inner: Arc<Mutex<InnerJob<P, B>>>,
|
||||
}
|
||||
|
||||
impl<P, B> Clone for Job<P, B> {
|
||||
fn clone(&self) -> Job<P, B> {
|
||||
Job {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, B> Job<P, B> {
|
||||
pub fn new(peer: P, body: B) -> Job<P, B> {
|
||||
Job {
|
||||
inner: Arc::new(Mutex::new(InnerJob {
|
||||
peer: Some(peer),
|
||||
body,
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, B> Job<P, B> {
|
||||
/// Returns a mutex guard to the inner job if complete
|
||||
pub fn complete(&self) -> Option<MutexGuard<InnerJob<P, B>>> {
|
||||
self.inner
|
||||
.try_lock()
|
||||
.and_then(|m| if m.peer.is_none() { Some(m) } else { None })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct InorderQueue<P, B> {
|
||||
queue: Mutex<ArrayDeque<[Job<P, B>; INORDER_QUEUE_SIZE]>>,
|
||||
}
|
||||
|
||||
impl<P, B> InorderQueue<P, B> {
|
||||
pub fn send(&self, job: Job<P, B>) -> bool {
|
||||
self.queue.lock().push_back(job).is_ok()
|
||||
}
|
||||
|
||||
pub fn new() -> InorderQueue<P, B> {
|
||||
InorderQueue {
|
||||
queue: Mutex::new(ArrayDeque::new()),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn handle<F: Fn(&mut InnerJob<P, B>)>(&self, f: F) {
|
||||
// take the mutex
|
||||
let mut queue = self.queue.lock();
|
||||
|
||||
// handle all complete messages
|
||||
while queue
|
||||
.pop_front()
|
||||
.and_then(|j| {
|
||||
// check if job is complete
|
||||
let ret = if let Some(mut guard) = j.complete() {
|
||||
f(&mut *guard);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
// return job to cyclic buffer if not complete
|
||||
if ret {
|
||||
let _res = queue.push_front(j);
|
||||
debug_assert!(_res.is_ok());
|
||||
None
|
||||
} else {
|
||||
// add job back to pool
|
||||
Some(())
|
||||
}
|
||||
})
|
||||
.is_some()
|
||||
{}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allows easy construction of a semi-parallel worker.
|
||||
/// Applicable for both decryption and encryption workers.
|
||||
#[inline(always)]
|
||||
pub fn worker_template<
|
||||
P, // represents a peer (atomic reference counted pointer)
|
||||
B, // inner body type (message buffer, key material, ...)
|
||||
W: Fn(&P, &mut B),
|
||||
S: Fn(&P, &mut B),
|
||||
Q: Fn(&P) -> &InorderQueue<P, B>,
|
||||
>(
|
||||
receiver: Receiver<Job<P, B>>, // receiever for new jobs
|
||||
work_parallel: W, // perform parallel / out-of-order work on peer
|
||||
work_sequential: S, // perform sequential work on peer
|
||||
queue: Q, // resolve a peer to an inorder queue
|
||||
) {
|
||||
loop {
|
||||
// handle new job
|
||||
let peer = {
|
||||
// get next job
|
||||
let job = match receiver.recv() {
|
||||
Ok(job) => job,
|
||||
_ => return,
|
||||
};
|
||||
|
||||
// lock the job
|
||||
let mut job = job.inner.lock();
|
||||
|
||||
// take the peer from the job
|
||||
let peer = job.peer.take().unwrap();
|
||||
|
||||
// process job
|
||||
work_parallel(&peer, &mut job.body);
|
||||
peer
|
||||
};
|
||||
|
||||
// process inorder jobs for peer
|
||||
queue(&peer).handle(|j| work_sequential(&peer, &mut j.body));
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,6 @@ use zerocopy::LayoutVerified;
|
||||
|
||||
use std::mem;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::Arc;
|
||||
|
||||
use spin::RwLock;
|
||||
use treebitmap::address::Address;
|
||||
@@ -12,12 +11,12 @@ use treebitmap::IpLookupTable;
|
||||
|
||||
/* Functions for obtaining and validating "cryptokey" routes */
|
||||
|
||||
pub struct RoutingTable<T> {
|
||||
ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<T>>>,
|
||||
ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<T>>>,
|
||||
pub struct RoutingTable<T: Eq + Clone> {
|
||||
ipv4: RwLock<IpLookupTable<Ipv4Addr, T>>,
|
||||
ipv6: RwLock<IpLookupTable<Ipv6Addr, T>>,
|
||||
}
|
||||
|
||||
impl<T> RoutingTable<T> {
|
||||
impl<T: Eq + Clone> RoutingTable<T> {
|
||||
pub fn new() -> Self {
|
||||
RoutingTable {
|
||||
ipv4: RwLock::new(IpLookupTable::new()),
|
||||
@@ -26,27 +25,27 @@ impl<T> RoutingTable<T> {
|
||||
}
|
||||
|
||||
// collect keys mapping to the given value
|
||||
fn collect<A>(table: &IpLookupTable<A, Arc<T>>, value: &Arc<T>) -> Vec<(A, u32)>
|
||||
fn collect<A>(table: &IpLookupTable<A, T>, value: &T) -> Vec<(A, u32)>
|
||||
where
|
||||
A: Address,
|
||||
{
|
||||
let mut res = Vec::new();
|
||||
for (ip, cidr, v) in table.iter() {
|
||||
if Arc::ptr_eq(v, value) {
|
||||
if v == value {
|
||||
res.push((ip, cidr))
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc<T>) {
|
||||
pub fn insert(&self, ip: IpAddr, cidr: u32, value: T) {
|
||||
match ip {
|
||||
IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value),
|
||||
IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn list(&self, value: &Arc<T>) -> Vec<(IpAddr, u32)> {
|
||||
pub fn list(&self, value: &T) -> Vec<(IpAddr, u32)> {
|
||||
let mut res = vec![];
|
||||
res.extend(
|
||||
Self::collect(&*self.ipv4.read(), value)
|
||||
@@ -61,7 +60,7 @@ impl<T> RoutingTable<T> {
|
||||
res
|
||||
}
|
||||
|
||||
pub fn remove(&self, value: &Arc<T>) {
|
||||
pub fn remove(&self, value: &T) {
|
||||
let mut v4 = self.ipv4.write();
|
||||
for (ip, cidr) in Self::collect(&*v4, value) {
|
||||
v4.remove(ip, cidr);
|
||||
@@ -74,7 +73,7 @@ impl<T> RoutingTable<T> {
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn get_route(&self, packet: &[u8]) -> Option<Arc<T>> {
|
||||
pub fn get_route(&self, packet: &[u8]) -> Option<T> {
|
||||
match packet.get(0)? >> 4 {
|
||||
VERSION_IP4 => {
|
||||
// check length and cast to IPv4 header
|
||||
@@ -113,7 +112,7 @@ impl<T> RoutingTable<T> {
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn check_route(&self, peer: &Arc<T>, packet: &[u8]) -> Option<usize> {
|
||||
pub fn check_route(&self, peer: &T, packet: &[u8]) -> Option<usize> {
|
||||
match packet.get(0)? >> 4 {
|
||||
VERSION_IP4 => {
|
||||
// check length and cast to IPv4 header
|
||||
@@ -130,7 +129,7 @@ impl<T> RoutingTable<T> {
|
||||
.read()
|
||||
.longest_match(Ipv4Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| {
|
||||
if Arc::ptr_eq(p, peer) {
|
||||
if p == peer {
|
||||
Some(header.f_total_len.get() as usize)
|
||||
} else {
|
||||
None
|
||||
@@ -152,7 +151,7 @@ impl<T> RoutingTable<T> {
|
||||
.read()
|
||||
.longest_match(Ipv6Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| {
|
||||
if Arc::ptr_eq(p, peer) {
|
||||
if p == peer {
|
||||
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
|
||||
} else {
|
||||
None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::dummy;
|
||||
use super::wireguard::Wireguard;
|
||||
use super::{dummy, tun, udp};
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::thread;
|
||||
|
||||
@@ -137,6 +137,7 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
|
||||
pub fn timers_handshake_complete(&self) {
|
||||
let timers = self.timers();
|
||||
if timers.enabled {
|
||||
timers.retransmit_handshake.stop();
|
||||
timers.handshake_attempts.store(0, Ordering::SeqCst);
|
||||
timers.sent_lastminute_handshake.store(false, Ordering::SeqCst);
|
||||
*self.walltime_last_handshake.lock() = Some(SystemTime::now());
|
||||
|
||||
Reference in New Issue
Block a user