Restructure IO traits.

This commit is contained in:
Mathias Hall-Andersen
2019-10-09 15:08:26 +02:00
parent c82d3e554b
commit 761c46064d
16 changed files with 643 additions and 376 deletions

183
src/config.rs Normal file
View File

@@ -0,0 +1,183 @@
use std::error::Error;
use std::net::{IpAddr, SocketAddr};
use x25519_dalek::{PublicKey, StaticSecret};
use crate::wireguard::Wireguard;
use crate::types::{Bind, Endpoint, Tun};
///
/// The goal of the configuration interface is, among others,
/// to hide the IO implementations (over which the WG device is generic),
/// from the configuration and UAPI code.
/// Describes a snapshot of the state of a peer
pub struct PeerState {
rx_bytes: u64,
tx_bytes: u64,
last_handshake_time_sec: u64,
last_handshake_time_nsec: u64,
public_key: PublicKey,
allowed_ips: Vec<(IpAddr, u32)>,
}
pub enum ConfigError {
NoSuchPeer
}
impl ConfigError {
fn errno(&self) -> i32 {
match self {
NoSuchPeer => 1,
}
}
}
/// Exposed configuration interface
pub trait Configuration {
/// Updates the private key of the device
///
/// # Arguments
///
/// - `sk`: The new private key (or None, if the private key should be cleared)
fn set_private_key(&self, sk: Option<StaticSecret>);
/// Returns the private key of the device
///
/// # Returns
///
/// The private if set, otherwise None.
fn get_private_key(&self) -> Option<StaticSecret>;
/// Returns the protocol version of the device
///
/// # Returns
///
/// An integer indicating the protocol version
fn get_protocol_version(&self) -> usize;
fn set_listen_port(&self, port: u16) -> Option<ConfigError>;
/// Set the firewall mark (or similar, depending on platform)
///
/// # Arguments
///
/// - `mark`: The fwmark value
///
/// # Returns
///
/// An error if this operation is not supported by the underlying
/// "bind" implementation.
fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError>;
/// Removes all peers from the device
fn replace_peers(&self);
/// Remove the peer from the
///
/// # Arguments
///
/// - `peer`: The public key of the peer to remove
///
/// # Returns
///
/// If the peer does not exists this operation is a noop
fn remove_peer(&self, peer: PublicKey);
/// Adds a new peer to the device
///
/// # Arguments
///
/// - `peer`: The public key of the peer to add
///
/// # Returns
///
/// A bool indicating if the peer was added.
///
/// If the peer already exists this operation is a noop
fn add_peer(&self, peer: PublicKey) -> bool;
/// Update the psk of a peer
///
/// # Arguments
///
/// - `peer`: The public key of the peer
/// - `psk`: The new psk or None if the psk should be unset
///
/// # Returns
///
/// An error if no such peer exists
fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError>;
/// Update the endpoint of the
///
/// # Arguments
///
/// - `peer': The public key of the peer
/// - `psk`
fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option<ConfigError>;
/// Update the endpoint of the
///
/// # Arguments
///
/// - `peer': The public key of the peer
/// - `psk`
fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option<ConfigError>;
/// Remove all allowed IPs from the peer
///
/// # Arguments
///
/// - `peer': The public key of the peer
///
/// # Returns
///
/// An error if no such peer exists
fn replace_allowed_ips(&self, peer: PublicKey) -> Option<ConfigError>;
/// Add a new allowed subnet to the peer
///
/// # Arguments
///
/// - `peer`: The public key of the peer
/// - `ip`: Subnet mask
/// - `masklen`:
///
/// # Returns
///
/// An error if the peer does not exist
///
/// # Note:
///
/// The API must itself sanitize the (ip, masklen) set:
/// The ip should be masked to remove any set bits right of the first "masklen" bits.
fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError>;
/// Returns the state of all peers
///
/// # Returns
///
/// A list of structures describing the state of each peer
fn get_peers(&self) -> Vec<PeerState>;
}
impl <T : Tun, B : Bind>Configuration for Wireguard<T, B> {
fn set_private_key(&self, sk : Option<StaticSecret>) {
self.set_key(sk)
}
fn get_private_key(&self) -> Option<StaticSecret> {
self.get_sk()
}
fn get_protocol_version(&self) -> usize {
1
}
fn set_listen_port(&self, port : u16) -> Option<ConfigError> {
}
}

View File

@@ -76,6 +76,15 @@ impl Device {
}
}
/// Return the secret key of the device
///
/// # Returns
///
/// A secret key (x25519 scalar)
pub fn get_sk(&self) -> StaticSecret {
StaticSecret::from(self.sk.to_bytes())
}
/// Add a new public key to the state machine
/// To remove public keys, you must create a new machine instance
///

View File

@@ -5,6 +5,7 @@ extern crate jemallocator;
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
// mod config;
mod constants;
mod handshake;
mod router;
@@ -14,7 +15,8 @@ mod wireguard;
#[cfg(test)]
mod tests {
use crate::types::{dummy, Bind};
use crate::types::tun::Tun;
use crate::types::{bind, dummy, tun};
use crate::wireguard::Wireguard;
use std::thread;
@@ -27,7 +29,8 @@ mod tests {
#[test]
fn test_pure_wireguard() {
init();
let wg = Wireguard::new(dummy::TunTest::new(), dummy::VoidBind::new());
let (reader, writer, mtu) = dummy::TunTest::create("name").unwrap();
let wg: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(reader, writer, mtu);
thread::sleep(Duration::from_millis(500));
}
}

View File

@@ -17,21 +17,23 @@ use super::constants::*;
use super::ip::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerInner};
use super::types::{Callbacks, Opaque, RouterError};
use super::types::{Callbacks, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation};
use super::SIZE_MESSAGE_PREFIX;
use super::super::types::{Bind, KeyPair, Tun};
use super::super::types::{KeyPair, Endpoint, bind, tun};
pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
// IO & timer callbacks
pub tun: T,
pub bind: B,
pub struct DeviceInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
// inbound writer (TUN)
pub inbound: T,
// outbound writer (Bind)
pub outbound: RwLock<Option<B>>,
// routing
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
// work queues
pub queue_next: AtomicUsize, // next round-robin index
@@ -45,20 +47,20 @@ pub struct EncryptionState {
pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
}
pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
pub struct DecryptionState<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
pub keypair: Arc<KeyPair>,
pub confirmed: AtomicBool,
pub protector: Mutex<AntiReplay>,
pub peer: Arc<PeerInner<C, T, B>>,
pub peer: Arc<PeerInner<E, C, T, B>>,
pub death: Instant, // time when the key can no longer be used for decryption
}
pub struct Device<C: Callbacks, T: Tun, B: Bind> {
state: Arc<DeviceInner<C, T, B>>, // reference to device state
pub struct Device<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
}
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> {
fn drop(&mut self) {
debug!("router: dropping device");
@@ -83,10 +85,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
}
#[inline(always)]
fn get_route<C: Callbacks, T: Tun, B: Bind>(
device: &Arc<DeviceInner<C, T, B>>,
fn get_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
packet: &[u8],
) -> Option<Arc<PeerInner<C, T, B>>> {
) -> Option<Arc<PeerInner<E, C, T, B>>> {
// ensure version access within bounds
if packet.len() < 1 {
return None;
@@ -122,12 +124,12 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
}
}
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> {
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
// allocate shared device state
let mut inner = DeviceInner {
tun,
bind,
inbound: tun,
outbound: RwLock::new(None),
queues: Mutex::new(Vec::with_capacity(num_workers)),
queue_next: AtomicUsize::new(0),
recv: RwLock::new(HashMap::new()),
@@ -159,7 +161,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
/// # Returns
///
/// A atomic ref. counted peer (with liftime matching the device)
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T, B> {
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
new_peer(self.state.clone(), opaque)
}
@@ -199,7 +201,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
/// # Returns
///
///
pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> {
pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
// parse / cast
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
Some(v) => v,
@@ -231,4 +233,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
Ok(())
}
/// Set outbound writer
///
///
pub fn set_outbound_writer(&self, new : B) {
*self.state.outbound.write() = Some(new);
}
}

View File

@@ -14,7 +14,7 @@ use treebitmap::IpLookupTable;
use zerocopy::LayoutVerified;
use super::super::constants::*;
use super::super::types::{Bind, Endpoint, KeyPair, Tun};
use super::super::types::{Endpoint, KeyPair, bind, tun};
use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
@@ -39,28 +39,28 @@ pub struct KeyWheel {
retired: Vec<u32>, // retired ids
}
pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
pub device: Arc<DeviceInner<C, T, B>>,
pub struct PeerInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
pub device: Arc<DeviceInner<E, C, T, B>>,
pub opaque: C::Opaque,
pub outbound: Mutex<SyncSender<JobOutbound>>,
pub inbound: Mutex<SyncSender<JobInbound<C, T, B>>>,
pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>,
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>,
pub keys: Mutex<KeyWheel>,
pub ekey: Mutex<Option<EncryptionState>>,
pub endpoint: Mutex<Option<B::Endpoint>>,
pub endpoint: Mutex<Option<E>>,
}
pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
state: Arc<PeerInner<C, T, B>>,
pub struct Peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
state: Arc<PeerInner<E, C, T, B>>,
thread_outbound: Option<thread::JoinHandle<()>>,
thread_inbound: Option<thread::JoinHandle<()>>,
}
fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
peer: &Arc<PeerInner<C, T, B>>,
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
callback: Box<dyn Fn(A, u32) -> E>,
) -> Vec<E>
fn treebit_list<A, R, E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
peer: &Arc<PeerInner<E, C, T, B>>,
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
callback: Box<dyn Fn(A, u32) -> R>,
) -> Vec<R>
where
A: Address,
{
@@ -74,9 +74,9 @@ where
res
}
fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
peer: &Peer<C, T, B>,
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
fn treebit_remove<E : Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
peer: &Peer<E, C, T, B>,
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
) {
let mut m = table.write();
@@ -107,8 +107,8 @@ impl EncryptionState {
}
}
impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> {
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> {
fn new(peer: &Arc<PeerInner<E, C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> {
DecryptionState {
confirmed: AtomicBool::new(keypair.initiator),
keypair: keypair.clone(),
@@ -119,7 +119,7 @@ impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
}
}
impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> {
fn drop(&mut self) {
let peer = &self.state;
@@ -167,10 +167,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
}
}
pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
device: Arc<DeviceInner<C, T, B>>,
pub fn new_peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>,
opaque: C::Opaque,
) -> Peer<C, T, B> {
) -> Peer<E, C, T, B> {
let (out_tx, out_rx) = sync_channel(128);
let (in_tx, in_rx) = sync_channel(128);
@@ -215,7 +215,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
}
}
impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> {
fn send_staged(&self) -> bool {
debug!("peer.send_staged");
let mut sent = false;
@@ -286,8 +286,8 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
pub fn recv_job(
&self,
src: B::Endpoint,
dec: Arc<DecryptionState<C, T, B>>,
src: E,
dec: Arc<DecryptionState<E, C, T, B>>,
mut msg: Vec<u8>,
) -> Option<JobParallel> {
let (tx, rx) = oneshot();
@@ -370,7 +370,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
}
}
impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> {
/// Set the endpoint of the peer
///
/// # Arguments
@@ -381,9 +381,9 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
///
/// This API still permits support for the "sticky socket" behavior,
/// as sockets should be "unsticked" when manually updating the endpoint
pub fn set_endpoint(&self, address: SocketAddr) {
pub fn set_endpoint(&self, endpoint: E) {
debug!("peer.set_endpoint");
*self.state.endpoint.lock() = Some(B::Endpoint::from_address(address));
*self.state.endpoint.lock() = Some(endpoint);
}
/// Returns the current endpoint of the peer (for configuration)
@@ -591,11 +591,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
debug!("peer.send");
let inner = &self.state;
match inner.endpoint.lock().as_ref() {
Some(endpoint) => inner
.device
.bind
.send(msg, endpoint)
.map_err(|_| RouterError::SendError),
Some(endpoint) => inner.device
.outbound
.read()
.as_ref()
.ok_or(RouterError::SendError)
.and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError) ),
None => Err(RouterError::NoEndpoint),
}
}

View File

@@ -1,18 +1,18 @@
use std::error::Error;
use std::fmt;
use std::net::{IpAddr, SocketAddr};
use std::net::IpAddr;
use std::sync::atomic::Ordering;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
use std::time::{Duration, Instant};
use std::time::Duration;
use num_cpus;
use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
use super::super::types::{dummy, Bind, Endpoint, Key, KeyPair, Tun};
use super::super::types::bind::*;
use super::super::types::tun::*;
use super::super::types::*;
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
extern crate test;
@@ -145,8 +145,9 @@ mod tests {
}
// create device
let router: Device<BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
Device::new(num_cpus::get(), dummy::TunTest {}, dummy::VoidBind::new());
let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
let router: Device<_, BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
Device::new(num_cpus::get(), tun_writer);
// add new peer
let opaque = Arc::new(AtomicUsize::new(0));
@@ -174,8 +175,9 @@ mod tests {
init();
// create device
let router: Device<TestCallbacks, _, _> =
Device::new(1, dummy::TunTest::new(), dummy::VoidBind::new());
let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
router.set_outbound_writer(dummy::VoidBind::new());
let tests = vec![
("192.168.1.0", 24, "192.168.1.20", true),
@@ -315,12 +317,18 @@ mod tests {
];
for (stage, p1, p2) in tests.iter() {
// create matching devices
let (bind1, bind2) = dummy::PairBind::pair();
let router1: Device<TestCallbacks, _, _> =
Device::new(1, dummy::TunTest::new(), bind1.clone());
let router2: Device<TestCallbacks, _, _> =
Device::new(1, dummy::TunTest::new(), bind2.clone());
let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
dummy::PairBind::pair();
// create matching device
let (tun_writer1, _, _) = dummy::TunTest::create("tun1").unwrap();
let (tun_writer2, _, _) = dummy::TunTest::create("tun1").unwrap();
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
router1.set_outbound_writer(bind_writer1);
let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2);
router2.set_outbound_writer(bind_writer2);
// prepare opaque values for tracing callbacks
@@ -339,7 +347,7 @@ mod tests {
let peer2 = router2.new_peer(opaq2.clone());
let mask: IpAddr = mask.parse().unwrap();
peer2.add_subnet(mask, *len);
peer2.set_endpoint("127.0.0.1:8080".parse().unwrap());
peer2.set_endpoint(dummy::UnitEndpoint::new());
if *stage {
// stage a packet which can be used for confirmation (in place of a keepalive)
@@ -372,7 +380,7 @@ mod tests {
// read confirming message received by the other end ("across the internet")
let mut buf = vec![0u8; 2048];
let (len, from) = bind1.recv(&mut buf).unwrap();
let (len, from) = bind_reader1.read(&mut buf).unwrap();
buf.truncate(len);
router1.recv(from, buf).unwrap();
@@ -411,7 +419,7 @@ mod tests {
// receive ("across the internet") on the other end
let mut buf = vec![0u8; 2048];
let (len, from) = bind2.recv(&mut buf).unwrap();
let (len, from) = bind_reader2.read(&mut buf).unwrap();
buf.truncate(len);
router2.recv(from, buf).unwrap();

View File

@@ -1,6 +1,8 @@
use std::error::Error;
use std::fmt;
use super::super::types::Endpoint;
pub trait Opaque: Send + Sync + 'static {}
impl<T> Opaque for T where T: Send + Sync + 'static {}

View File

@@ -17,7 +17,7 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::PeerInner;
use super::types::Callbacks;
use super::super::types::{Bind, Tun};
use super::super::types::{Endpoint, tun, bind};
use super::ip::*;
const SIZE_TAG: usize = 16;
@@ -38,18 +38,18 @@ pub struct JobBuffer {
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
#[allow(type_alias_bounds)]
pub type JobInbound<C, T, B: Bind> = (
Arc<DecryptionState<C, T, B>>,
B::Endpoint,
pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
Arc<DecryptionState<E, C, T, B>>,
E,
oneshot::Receiver<JobBuffer>,
);
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
#[inline(always)]
fn check_route<C: Callbacks, T: Tun, B: Bind>(
device: &Arc<DeviceInner<C, T, B>>,
peer: &Arc<PeerInner<C, T, B>>,
fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
peer: &Arc<PeerInner<E, C, T, B>>,
packet: &[u8],
) -> Option<usize> {
match packet[0] >> 4 {
@@ -93,10 +93,10 @@ fn check_route<C: Callbacks, T: Tun, B: Bind>(
}
}
pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
device: Arc<DeviceInner<C, T, B>>, // related device
peer: Arc<PeerInner<C, T, B>>, // related peer
receiver: Receiver<JobInbound<C, T, B>>,
pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device
peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobInbound<E, C, T, B>>,
) {
loop {
// fetch job
@@ -153,7 +153,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
debug_assert!(inner_len <= length, "should be validated");
if inner_len <= length {
sent = match device.tun.write(&packet[..inner_len]) {
sent = match device.inbound.write(&packet[..inner_len]) {
Err(e) => {
debug!("failed to write inbound packet to TUN: {:?}", e);
false
@@ -176,9 +176,9 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
}
}
pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
device: Arc<DeviceInner<C, T, B>>, // related device
peer: Arc<PeerInner<C, T, B>>, // related peer
pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device
peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobOutbound>,
) {
loop {
@@ -198,7 +198,9 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
if buf.okay {
// write to UDP bind
let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
match device.bind.send(&buf.msg[..], dst) {
let send : &Option<B> = &*device.outbound.read();
if let Some(writer) = send.as_ref() {
match writer.write(&buf.msg[..], dst) {
Err(e) => {
debug!("failed to send outbound packet: {:?}", e);
false
@@ -207,6 +209,9 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
}
} else {
false
}
} else {
false
};
// trigger callback

View File

@@ -7,7 +7,7 @@ use hjul::{Runner, Timer};
use crate::constants::*;
use crate::router::Callbacks;
use crate::types::{Bind, Tun};
use crate::types::{tun, bind};
use crate::wireguard::{Peer, PeerInner};
pub struct Timers {
@@ -23,8 +23,8 @@ pub struct Timers {
impl Timers {
pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers
where
T: Tun,
B: Bind,
T: tun::Tun,
B: bind::Bind,
{
// create a timer instance for the provided peer
Timers {
@@ -103,7 +103,7 @@ impl Timers {
pub struct Events<T, B>(PhantomData<(T, B)>);
impl<T: Tun, B: Bind> Callbacks for Events<T, B> {
impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
type Opaque = Arc<PeerInner<B>>;
fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {

View File

@@ -1,73 +1,28 @@
use super::Endpoint;
use std::error;
use std::error::Error;
/// Traits representing the "internet facing" end of the VPN.
///
/// In practice this is a UDP socket (but the router interface is agnostic).
/// Often these traits will be implemented on the same type.
pub trait Reader<E: Endpoint>: Send + Sync {
type Error: Error;
/// Bind interface provided to the router code
pub trait RouterBind: Send + Sync {
type Error: error::Error;
fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>;
}
pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static {
type Error: Error;
fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>;
}
pub trait Bind: Send + Sync + 'static {
type Error: Error;
type Endpoint: Endpoint;
/// Receive a buffer on the bind
///
/// # Arguments
///
/// - `buf`, buffer for storing the packet. If the buffer is too short, the packet should just be truncated.
///
/// # Note
///
/// The size of the buffer is derieved from the MTU of the Tun device.
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>;
/* Until Rust gets type equality constraints these have to be generic */
type Writer: Writer<Self::Endpoint>;
type Reader: Reader<Self::Endpoint>;
/// Send a buffer to the endpoint
///
/// # Arguments
///
/// - `buf`, packet src buffer (in practice the body of a UDP datagram)
/// - `dst`, destination endpoint (in practice, src: (ip, port) + dst: (ip, port) for sticky sockets)
///
/// # Returns
///
/// The unit type or an error if transmission failed
fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>;
}
/// Bind interface provided for configuration (setting / getting the port)
pub trait ConfigBind {
type Error: error::Error;
/// Return a new (unbound) instance of a configuration bind
fn new() -> Self;
/// Updates the port of the bind
///
/// # Arguments
///
/// - `port`, the new port to bind to. 0 means any available port.
///
/// # Returns
///
/// The unit type or an error, if binding fails
fn set_port(&self, port: u16) -> Result<(), Self::Error>;
/// Returns the current port of the bind
fn get_port(&self) -> Option<u16>;
/// Set the mark (e.g. on Linus this is the fwmark) on the bind
///
/// # Arguments
///
/// - `mark`, the mark to set
///
/// # Note
///
/// The mark should be retained accross calls to `set_port`.
///
/// # Returns
///
/// The unit type or an error, if the operation fails due to permission errors
fn set_mark(&self, mark: u16) -> Result<(), Self::Error>;
/* Used to close the reader/writer when binding to a new port */
type Closer;
fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error>;
}

View File

@@ -5,8 +5,9 @@ use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Instant;
use std::marker;
use super::{Bind, Endpoint, Key, KeyPair, Tun};
use super::*;
/* This submodule provides pure/dummy implementations of the IO interfaces
* for use in unit tests thoughout the project.
@@ -72,104 +73,103 @@ impl Endpoint for UnitEndpoint {
}
}
impl UnitEndpoint {
pub fn new() -> UnitEndpoint {
UnitEndpoint{}
}
}
/* */
#[derive(Clone, Copy)]
pub struct TunTest {}
impl Tun for TunTest {
impl tun::Reader for TunTest {
type Error = TunError;
fn mtu(&self) -> usize {
1500
}
fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
Ok(0)
}
}
impl tun::MTU for TunTest {
fn mtu(&self) -> usize {
1500
}
}
impl tun::Writer for TunTest {
type Error = TunError;
fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
Ok(())
}
}
impl tun::Tun for TunTest {
type Writer = TunTest;
type Reader = TunTest;
type MTU = TunTest;
type Error = TunError;
}
impl TunTest {
pub fn new() -> TunTest {
TunTest {}
pub fn create(_name: &str) -> Result<(TunTest, TunTest, TunTest), TunError> {
Ok((TunTest {},TunTest {}, TunTest{}))
}
}
/* Bind implemenentations */
/* Void Bind */
#[derive(Clone, Copy)]
pub struct VoidBind {}
impl Bind for VoidBind {
impl bind::Reader<UnitEndpoint> for VoidBind {
type Error = BindError;
type Endpoint = UnitEndpoint;
fn new() -> VoidBind {
VoidBind {}
}
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
Ok(())
}
fn get_port(&self) -> Option<u16> {
None
}
fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
Ok((0, UnitEndpoint {}))
}
}
fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
impl bind::Writer<UnitEndpoint> for VoidBind {
type Error = BindError;
fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
Ok(())
}
}
#[derive(Clone)]
pub struct PairBind {
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
}
impl PairBind {
pub fn pair() -> (PairBind, PairBind) {
let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128);
(
PairBind {
send: Arc::new(Mutex::new(tx1)),
recv: Arc::new(Mutex::new(rx2)),
},
PairBind {
send: Arc::new(Mutex::new(tx2)),
recv: Arc::new(Mutex::new(rx1)),
},
)
}
}
impl Bind for PairBind {
impl bind::Bind for VoidBind {
type Error = BindError;
type Endpoint = UnitEndpoint;
fn new() -> PairBind {
PairBind {
send: Arc::new(Mutex::new(sync_channel(0).0)),
recv: Arc::new(Mutex::new(sync_channel(0).1)),
type Reader = VoidBind;
type Writer = VoidBind;
type Closer = ();
fn bind(_ : u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
Ok((VoidBind{}, VoidBind{}, (), 2600))
}
}
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
Ok(())
impl VoidBind {
pub fn new() -> VoidBind {
VoidBind{}
}
}
fn get_port(&self) -> Option<u16> {
None
/* Pair Bind */
#[derive(Clone)]
pub struct PairReader<E> {
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
_marker: marker::PhantomData<E>,
}
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
impl bind::Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
type Error = BindError;
fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
let vec = self
.recv
.lock()
@@ -180,8 +180,11 @@ impl Bind for PairBind {
buf[..len].copy_from_slice(&vec[..]);
Ok((vec.len(), UnitEndpoint {}))
}
}
fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
impl bind::Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
type Error = BindError;
fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
let owned = buf.to_owned();
match self.send.lock().unwrap().send(owned) {
Err(_) => Err(BindError::Disconnected),
@@ -190,6 +193,57 @@ impl Bind for PairBind {
}
}
#[derive(Clone)]
pub struct PairWriter<E> {
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
_marker: marker::PhantomData<E>,
}
#[derive(Clone)]
pub struct PairBind {}
impl PairBind {
pub fn pair<E>() -> ((PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>)) {
let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128);
(
(
PairReader{
recv: Arc::new(Mutex::new(rx1)),
_marker: marker::PhantomData
},
PairWriter{
send: Arc::new(Mutex::new(tx2)),
_marker: marker::PhantomData
}
),
(
PairReader{
recv: Arc::new(Mutex::new(rx2)),
_marker: marker::PhantomData
},
PairWriter{
send: Arc::new(Mutex::new(tx1)),
_marker: marker::PhantomData
}
),
)
}
}
impl bind::Bind for PairBind {
type Closer = ();
type Error = BindError;
type Endpoint = UnitEndpoint;
type Reader = PairReader<Self::Endpoint>;
type Writer = PairWriter<Self::Endpoint>;
fn bind(_port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
Err(BindError::Disconnected)
}
}
pub fn keypair(initiator: bool) -> KeyPair {
let k1 = Key {
key: [0x53u8; 32],

View File

@@ -1,6 +1,6 @@
use std::net::SocketAddr;
pub trait Endpoint: Send {
pub trait Endpoint: Send + 'static {
fn from_address(addr: SocketAddr) -> Self;
fn into_address(&self) -> SocketAddr;
}

View File

@@ -1,12 +1,10 @@
mod endpoint;
mod keys;
mod tun;
mod udp;
pub mod tun;
pub mod bind;
#[cfg(test)]
pub mod dummy;
pub use endpoint::Endpoint;
pub use keys::{Key, KeyPair};
pub use tun::Tun;
pub use udp::Bind;

View File

@@ -1,18 +1,22 @@
use std::error;
use std::error::Error;
pub trait Tun: Send + Sync + Clone + 'static {
type Error: error::Error;
pub trait Writer: Send + Sync + 'static {
type Error: Error;
/// Returns the MTU of the device
/// Receive a cryptkey routed IP packet
///
/// This function needs to be efficient (called for every read).
/// The goto implementation strategy is to .load an atomic variable,
/// then use e.g. netlink to update the variable in a separate thread.
/// # Arguments
///
/// - src: Buffer containing the IP packet to be written
///
/// # Returns
///
/// The MTU of the interface in bytes
fn mtu(&self) -> usize;
/// Unit type or an error
fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
}
pub trait Reader: Send + 'static {
type Error: Error;
/// Reads an IP packet into dst[offset:] from the tunnel device
///
@@ -29,15 +33,24 @@ pub trait Tun: Send + Sync + Clone + 'static {
///
/// The size of the IP packet (ignoring the header) or an std::error::Error instance:
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error>;
}
/// Writes an IP packet to the tunnel device
pub trait MTU: Send + Sync + Clone + 'static {
/// Returns the MTU of the device
///
/// # Arguments
///
/// - src: Buffer containing the IP packet to be written
/// This function needs to be efficient (called for every read).
/// The goto implementation strategy is to .load an atomic variable,
/// then use e.g. netlink to update the variable in a separate thread.
///
/// # Returns
///
/// Unit type or an error
fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
/// The MTU of the interface in bytes
fn mtu(&self) -> usize;
}
pub trait Tun: Send + Sync + 'static {
type Writer: Writer;
type Reader: Reader;
type MTU: MTU;
type Error: Error;
}

View File

@@ -1,29 +0,0 @@
use super::Endpoint;
use std::error;
/* Often times an a file descriptor in an atomic might suffice.
*/
pub trait Bind: Send + Sync + Clone + 'static {
type Error: error::Error + Send;
type Endpoint: Endpoint;
fn new() -> Self;
/// Updates the port of the Bind
///
/// # Arguments
///
/// - port, The new port to bind to. 0 means any available port.
///
/// # Returns
///
/// The unit type or an error, if binding fails
fn set_port(&self, port: u16) -> Result<(), Self::Error>;
/// Returns the current port of the bind
fn get_port(&self) -> Option<u16>;
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>;
fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>;
}

View File

@@ -2,11 +2,13 @@ use crate::constants::*;
use crate::handshake;
use crate::router;
use crate::timers::{Events, Timers};
use crate::types::{Bind, Endpoint, Tun};
use crate::types::Endpoint;
use crate::types::tun::{Tun, Reader, MTU};
use crate::types::bind::{Bind, Writer};
use hjul::Runner;
use std::cmp;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
@@ -27,12 +29,20 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
#[derive(Clone)]
pub struct Peer<T: Tun, B: Bind> {
pub router: Arc<router::Peer<Events<T, B>, T, B>>,
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
pub state: Arc<PeerInner<B>>,
}
impl <T : Tun, B : Bind> Clone for Peer<T, B > {
fn clone(&self) -> Peer<T, B> {
Peer{
router: self.router.clone(),
state: self.state.clone()
}
}
}
pub struct PeerInner<B: Bind> {
pub keepalive: AtomicUsize, // keepalive interval
pub rx_bytes: AtomicU64,
@@ -66,20 +76,22 @@ pub enum HandshakeJob<E> {
}
struct WireguardInner<T: Tun, B: Bind> {
// provides access to the MTU value of the tun device
// (otherwise owned solely by the router and a dedicated read IO thread)
mtu: T::MTU,
send: RwLock<Option<B::Writer>>,
// identify and configuration map
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
// cryptkey router
router: router::Device<Events<T, B>, T, B>,
router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
// handshake related state
handshake: RwLock<Handshake>,
under_load: AtomicBool,
pending: AtomicUsize, // num of pending handshake packets in queue
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
// IO
bind: B,
}
pub struct Wireguard<T: Tun, B: Bind> {
@@ -87,6 +99,17 @@ pub struct Wireguard<T: Tun, B: Bind> {
state: Arc<WireguardInner<T, B>>,
}
/* Returns the padded length of a message:
*
* # Arguments
*
* - `size` : Size of unpadded message
* - `mtu` : Maximum transmission unit of the device
*
* # Returns
*
* The padded length (always less than or equal to the MTU)
*/
#[inline(always)]
const fn padding(size: usize, mtu: usize) -> usize {
#[inline(always)]
@@ -114,6 +137,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
}
pub fn get_sk(&self) -> Option<StaticSecret> {
let mut handshake = self.state.handshake.read();
if handshake.active {
Some(handshake.device.get_sk())
} else {
None
}
}
pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
let state = Arc::new(PeerInner {
pk,
@@ -137,113 +169,36 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
peer
}
pub fn new(tun: T, bind: B) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
peers: RwLock::new(HashMap::new()),
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
pending: AtomicUsize::new(0),
handshake: RwLock::new(Handshake {
device: handshake::Device::new(StaticSecret::new(&mut rng)),
active: false,
}),
under_load: AtomicBool::new(false),
bind: bind.clone(),
queue: Mutex::new(tx),
});
// start handshake workers
for _ in 0..num_cpus::get() {
let wg = wg.clone();
let rx = rx.clone();
let bind = bind.clone();
thread::spawn(move || {
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue
for job in rx {
wg.pending.fetch_sub(1, Ordering::SeqCst);
let state = wg.handshake.read();
if !state.active {
continue;
}
match job {
HandshakeJob::Message(msg, src) => {
// feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid
// process message
match state.device.process(
&mut rng,
&msg[..],
if wg.under_load.load(Ordering::Relaxed) {
Some(&src_validate)
} else {
None
},
pub fn new_bind(
reader: B::Reader,
writer: B::Writer,
closer: B::Closer
) {
Ok((pk, msg, keypair)) => {
// send response
if let Some(msg) = msg {
let _ = bind.send(&msg[..], &src).map_err(|e| {
debug!(
"handshake worker, failed to send response, error = {:?}",
e
)
});
}
// update timers
if let Some(pk) = pk {
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
// update endpoint (DISCUSS: right semantics?)
peer.router.set_endpoint(src_validate);
// drop existing closer
// swap IO thread for new reader
// add keypair to peer and free any unused ids
if let Some(keypair) = keypair {
for id in peer.router.add_keypair(keypair) {
state.device.release(id);
}
}
}
}
}
Err(e) => debug!("handshake worker, error = {:?}", e),
}
}
HandshakeJob::New(pk) => {
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
peer.router.send(&msg[..]);
peer.timers.read().handshake_sent();
}
}
}
}
});
}
// start UDP read IO thread
/*
{
let wg = wg.clone();
let tun = tun.clone();
let bind = bind.clone();
let mtu = mtu.clone();
thread::spawn(move || {
let mut last_under_load =
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
loop {
// create vector big enough for any message given current MTU
let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0);
// read UDP packet into vector
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error
msg.truncate(size);
// message type de-multiplexer
@@ -276,19 +231,120 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
});
}
*/
}
pub fn new(
reader: T::Reader,
writer: T::Writer,
mtu: T::MTU,
) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
mtu: mtu.clone(),
peers: RwLock::new(HashMap::new()),
send: RwLock::new(None),
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
pending: AtomicUsize::new(0),
handshake: RwLock::new(Handshake {
device: handshake::Device::new(StaticSecret::new(&mut rng)),
active: false,
}),
under_load: AtomicBool::new(false),
queue: Mutex::new(tx),
});
// start handshake workers
for _ in 0..num_cpus::get() {
let wg = wg.clone();
let rx = rx.clone();
thread::spawn(move || {
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue
for job in rx {
wg.pending.fetch_sub(1, Ordering::SeqCst);
let state = wg.handshake.read();
if !state.active {
continue;
}
match job {
HandshakeJob::Message(msg, src) => {
// feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid
// process message
match state.device.process(
&mut rng,
&msg[..],
if wg.under_load.load(Ordering::Relaxed) {
Some(&src_validate)
} else {
None
},
) {
Ok((pk, msg, keypair)) => {
// send response
if let Some(msg) = msg {
let send : &Option<B::Writer> = &*wg.send.read();
if let Some(writer) = send.as_ref() {
let _ = writer.write(&msg[..], &src).map_err(|e| {
debug!(
"handshake worker, failed to send response, error = {:?}",
e
)
});
}
}
// update timers
if let Some(pk) = pk {
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
// update endpoint
peer.router.set_endpoint(src);
// add keypair to peer and free any unused ids
if let Some(keypair) = keypair {
for id in peer.router.add_keypair(keypair) {
state.device.release(id);
}
}
}
}
}
Err(e) => debug!("handshake worker, error = {:?}", e),
}
}
HandshakeJob::New(pk) => {
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
peer.router.send(&msg[..]);
peer.timers.read().handshake_sent();
}
}
}
}
});
}
// start TUN read IO thread
{
let wg = wg.clone();
thread::spawn(move || loop {
// create vector big enough for any transport message (based on MTU)
let mtu = tun.mtu();
let mtu = mtu.mtu();
let size = mtu + router::SIZE_MESSAGE_PREFIX;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
msg.resize(size, 0);
// read a new IP packet
let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
let payload = reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
// truncate padding