Restructure IO traits.
This commit is contained in:
183
src/config.rs
Normal file
183
src/config.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use std::error::Error;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
use crate::wireguard::Wireguard;
|
||||
use crate::types::{Bind, Endpoint, Tun};
|
||||
|
||||
///
|
||||
/// The goal of the configuration interface is, among others,
|
||||
/// to hide the IO implementations (over which the WG device is generic),
|
||||
/// from the configuration and UAPI code.
|
||||
|
||||
/// Describes a snapshot of the state of a peer
|
||||
pub struct PeerState {
|
||||
rx_bytes: u64,
|
||||
tx_bytes: u64,
|
||||
last_handshake_time_sec: u64,
|
||||
last_handshake_time_nsec: u64,
|
||||
public_key: PublicKey,
|
||||
allowed_ips: Vec<(IpAddr, u32)>,
|
||||
}
|
||||
|
||||
pub enum ConfigError {
|
||||
NoSuchPeer
|
||||
}
|
||||
|
||||
impl ConfigError {
|
||||
|
||||
fn errno(&self) -> i32 {
|
||||
match self {
|
||||
NoSuchPeer => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Exposed configuration interface
|
||||
pub trait Configuration {
|
||||
/// Updates the private key of the device
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `sk`: The new private key (or None, if the private key should be cleared)
|
||||
fn set_private_key(&self, sk: Option<StaticSecret>);
|
||||
|
||||
/// Returns the private key of the device
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The private if set, otherwise None.
|
||||
fn get_private_key(&self) -> Option<StaticSecret>;
|
||||
|
||||
/// Returns the protocol version of the device
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An integer indicating the protocol version
|
||||
fn get_protocol_version(&self) -> usize;
|
||||
|
||||
fn set_listen_port(&self, port: u16) -> Option<ConfigError>;
|
||||
|
||||
/// Set the firewall mark (or similar, depending on platform)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `mark`: The fwmark value
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An error if this operation is not supported by the underlying
|
||||
/// "bind" implementation.
|
||||
fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError>;
|
||||
|
||||
/// Removes all peers from the device
|
||||
fn replace_peers(&self);
|
||||
|
||||
/// Remove the peer from the
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer`: The public key of the peer to remove
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// If the peer does not exists this operation is a noop
|
||||
fn remove_peer(&self, peer: PublicKey);
|
||||
|
||||
/// Adds a new peer to the device
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer`: The public key of the peer to add
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A bool indicating if the peer was added.
|
||||
///
|
||||
/// If the peer already exists this operation is a noop
|
||||
fn add_peer(&self, peer: PublicKey) -> bool;
|
||||
|
||||
/// Update the psk of a peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer`: The public key of the peer
|
||||
/// - `psk`: The new psk or None if the psk should be unset
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An error if no such peer exists
|
||||
fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError>;
|
||||
|
||||
/// Update the endpoint of the
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer': The public key of the peer
|
||||
/// - `psk`
|
||||
fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option<ConfigError>;
|
||||
|
||||
/// Update the endpoint of the
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer': The public key of the peer
|
||||
/// - `psk`
|
||||
fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option<ConfigError>;
|
||||
|
||||
/// Remove all allowed IPs from the peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer': The public key of the peer
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An error if no such peer exists
|
||||
fn replace_allowed_ips(&self, peer: PublicKey) -> Option<ConfigError>;
|
||||
|
||||
/// Add a new allowed subnet to the peer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `peer`: The public key of the peer
|
||||
/// - `ip`: Subnet mask
|
||||
/// - `masklen`:
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An error if the peer does not exist
|
||||
///
|
||||
/// # Note:
|
||||
///
|
||||
/// The API must itself sanitize the (ip, masklen) set:
|
||||
/// The ip should be masked to remove any set bits right of the first "masklen" bits.
|
||||
fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError>;
|
||||
|
||||
/// Returns the state of all peers
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A list of structures describing the state of each peer
|
||||
fn get_peers(&self) -> Vec<PeerState>;
|
||||
}
|
||||
|
||||
impl <T : Tun, B : Bind>Configuration for Wireguard<T, B> {
|
||||
|
||||
fn set_private_key(&self, sk : Option<StaticSecret>) {
|
||||
self.set_key(sk)
|
||||
}
|
||||
|
||||
fn get_private_key(&self) -> Option<StaticSecret> {
|
||||
self.get_sk()
|
||||
}
|
||||
|
||||
fn get_protocol_version(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn set_listen_port(&self, port : u16) -> Option<ConfigError> {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -76,6 +76,15 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the secret key of the device
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A secret key (x25519 scalar)
|
||||
pub fn get_sk(&self) -> StaticSecret {
|
||||
StaticSecret::from(self.sk.to_bytes())
|
||||
}
|
||||
|
||||
/// Add a new public key to the state machine
|
||||
/// To remove public keys, you must create a new machine instance
|
||||
///
|
||||
|
||||
@@ -5,6 +5,7 @@ extern crate jemallocator;
|
||||
#[global_allocator]
|
||||
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
|
||||
|
||||
// mod config;
|
||||
mod constants;
|
||||
mod handshake;
|
||||
mod router;
|
||||
@@ -14,7 +15,8 @@ mod wireguard;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::types::{dummy, Bind};
|
||||
use crate::types::tun::Tun;
|
||||
use crate::types::{bind, dummy, tun};
|
||||
use crate::wireguard::Wireguard;
|
||||
|
||||
use std::thread;
|
||||
@@ -27,7 +29,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_pure_wireguard() {
|
||||
init();
|
||||
let wg = Wireguard::new(dummy::TunTest::new(), dummy::VoidBind::new());
|
||||
let (reader, writer, mtu) = dummy::TunTest::create("name").unwrap();
|
||||
let wg: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(reader, writer, mtu);
|
||||
thread::sleep(Duration::from_millis(500));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,21 +17,23 @@ use super::constants::*;
|
||||
use super::ip::*;
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::peer::{new_peer, Peer, PeerInner};
|
||||
use super::types::{Callbacks, Opaque, RouterError};
|
||||
use super::types::{Callbacks, RouterError};
|
||||
use super::workers::{worker_parallel, JobParallel, Operation};
|
||||
use super::SIZE_MESSAGE_PREFIX;
|
||||
|
||||
use super::super::types::{Bind, KeyPair, Tun};
|
||||
use super::super::types::{KeyPair, Endpoint, bind, tun};
|
||||
|
||||
pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
|
||||
// IO & timer callbacks
|
||||
pub tun: T,
|
||||
pub bind: B,
|
||||
pub struct DeviceInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
// inbound writer (TUN)
|
||||
pub inbound: T,
|
||||
|
||||
// outbound writer (Bind)
|
||||
pub outbound: RwLock<Option<B>>,
|
||||
|
||||
// routing
|
||||
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
|
||||
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
|
||||
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
|
||||
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
|
||||
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
|
||||
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
|
||||
|
||||
// work queues
|
||||
pub queue_next: AtomicUsize, // next round-robin index
|
||||
@@ -45,20 +47,20 @@ pub struct EncryptionState {
|
||||
pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
|
||||
}
|
||||
|
||||
pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
|
||||
pub struct DecryptionState<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
pub keypair: Arc<KeyPair>,
|
||||
pub confirmed: AtomicBool,
|
||||
pub protector: Mutex<AntiReplay>,
|
||||
pub peer: Arc<PeerInner<C, T, B>>,
|
||||
pub peer: Arc<PeerInner<E, C, T, B>>,
|
||||
pub death: Instant, // time when the key can no longer be used for decryption
|
||||
}
|
||||
|
||||
pub struct Device<C: Callbacks, T: Tun, B: Bind> {
|
||||
state: Arc<DeviceInner<C, T, B>>, // reference to device state
|
||||
pub struct Device<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
|
||||
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
|
||||
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> {
|
||||
fn drop(&mut self) {
|
||||
debug!("router: dropping device");
|
||||
|
||||
@@ -83,10 +85,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_route<C: Callbacks, T: Tun, B: Bind>(
|
||||
device: &Arc<DeviceInner<C, T, B>>,
|
||||
fn get_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: &Arc<DeviceInner<E, C, T, B>>,
|
||||
packet: &[u8],
|
||||
) -> Option<Arc<PeerInner<C, T, B>>> {
|
||||
) -> Option<Arc<PeerInner<E, C, T, B>>> {
|
||||
// ensure version access within bounds
|
||||
if packet.len() < 1 {
|
||||
return None;
|
||||
@@ -122,12 +124,12 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> {
|
||||
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
|
||||
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
|
||||
// allocate shared device state
|
||||
let mut inner = DeviceInner {
|
||||
tun,
|
||||
bind,
|
||||
inbound: tun,
|
||||
outbound: RwLock::new(None),
|
||||
queues: Mutex::new(Vec::with_capacity(num_workers)),
|
||||
queue_next: AtomicUsize::new(0),
|
||||
recv: RwLock::new(HashMap::new()),
|
||||
@@ -159,7 +161,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
/// # Returns
|
||||
///
|
||||
/// A atomic ref. counted peer (with liftime matching the device)
|
||||
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T, B> {
|
||||
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
|
||||
new_peer(self.state.clone(), opaque)
|
||||
}
|
||||
|
||||
@@ -199,7 +201,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
/// # Returns
|
||||
///
|
||||
///
|
||||
pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
// parse / cast
|
||||
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
|
||||
Some(v) => v,
|
||||
@@ -231,4 +233,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set outbound writer
|
||||
///
|
||||
///
|
||||
pub fn set_outbound_writer(&self, new : B) {
|
||||
*self.state.outbound.write() = Some(new);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ use treebitmap::IpLookupTable;
|
||||
use zerocopy::LayoutVerified;
|
||||
|
||||
use super::super::constants::*;
|
||||
use super::super::types::{Bind, Endpoint, KeyPair, Tun};
|
||||
use super::super::types::{Endpoint, KeyPair, bind, tun};
|
||||
|
||||
use super::anti_replay::AntiReplay;
|
||||
use super::device::DecryptionState;
|
||||
@@ -39,28 +39,28 @@ pub struct KeyWheel {
|
||||
retired: Vec<u32>, // retired ids
|
||||
}
|
||||
|
||||
pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
|
||||
pub device: Arc<DeviceInner<C, T, B>>,
|
||||
pub struct PeerInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
pub device: Arc<DeviceInner<E, C, T, B>>,
|
||||
pub opaque: C::Opaque,
|
||||
pub outbound: Mutex<SyncSender<JobOutbound>>,
|
||||
pub inbound: Mutex<SyncSender<JobInbound<C, T, B>>>,
|
||||
pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>,
|
||||
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>,
|
||||
pub keys: Mutex<KeyWheel>,
|
||||
pub ekey: Mutex<Option<EncryptionState>>,
|
||||
pub endpoint: Mutex<Option<B::Endpoint>>,
|
||||
pub endpoint: Mutex<Option<E>>,
|
||||
}
|
||||
|
||||
pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
|
||||
state: Arc<PeerInner<C, T, B>>,
|
||||
pub struct Peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
|
||||
state: Arc<PeerInner<E, C, T, B>>,
|
||||
thread_outbound: Option<thread::JoinHandle<()>>,
|
||||
thread_inbound: Option<thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
|
||||
peer: &Arc<PeerInner<C, T, B>>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
|
||||
callback: Box<dyn Fn(A, u32) -> E>,
|
||||
) -> Vec<E>
|
||||
fn treebit_list<A, R, E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
peer: &Arc<PeerInner<E, C, T, B>>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
|
||||
callback: Box<dyn Fn(A, u32) -> R>,
|
||||
) -> Vec<R>
|
||||
where
|
||||
A: Address,
|
||||
{
|
||||
@@ -74,9 +74,9 @@ where
|
||||
res
|
||||
}
|
||||
|
||||
fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
|
||||
peer: &Peer<C, T, B>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
|
||||
fn treebit_remove<E : Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
peer: &Peer<E, C, T, B>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
|
||||
) {
|
||||
let mut m = table.write();
|
||||
|
||||
@@ -107,8 +107,8 @@ impl EncryptionState {
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
|
||||
fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> {
|
||||
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> {
|
||||
fn new(peer: &Arc<PeerInner<E, C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> {
|
||||
DecryptionState {
|
||||
confirmed: AtomicBool::new(keypair.initiator),
|
||||
keypair: keypair.clone(),
|
||||
@@ -119,7 +119,7 @@ impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
|
||||
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> {
|
||||
fn drop(&mut self) {
|
||||
let peer = &self.state;
|
||||
|
||||
@@ -167,10 +167,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
|
||||
device: Arc<DeviceInner<C, T, B>>,
|
||||
pub fn new_peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: Arc<DeviceInner<E, C, T, B>>,
|
||||
opaque: C::Opaque,
|
||||
) -> Peer<C, T, B> {
|
||||
) -> Peer<E, C, T, B> {
|
||||
let (out_tx, out_rx) = sync_channel(128);
|
||||
let (in_tx, in_rx) = sync_channel(128);
|
||||
|
||||
@@ -215,7 +215,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
|
||||
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> {
|
||||
fn send_staged(&self) -> bool {
|
||||
debug!("peer.send_staged");
|
||||
let mut sent = false;
|
||||
@@ -286,8 +286,8 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
|
||||
|
||||
pub fn recv_job(
|
||||
&self,
|
||||
src: B::Endpoint,
|
||||
dec: Arc<DecryptionState<C, T, B>>,
|
||||
src: E,
|
||||
dec: Arc<DecryptionState<E, C, T, B>>,
|
||||
mut msg: Vec<u8>,
|
||||
) -> Option<JobParallel> {
|
||||
let (tx, rx) = oneshot();
|
||||
@@ -370,7 +370,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> {
|
||||
/// Set the endpoint of the peer
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -381,9 +381,9 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
///
|
||||
/// This API still permits support for the "sticky socket" behavior,
|
||||
/// as sockets should be "unsticked" when manually updating the endpoint
|
||||
pub fn set_endpoint(&self, address: SocketAddr) {
|
||||
pub fn set_endpoint(&self, endpoint: E) {
|
||||
debug!("peer.set_endpoint");
|
||||
*self.state.endpoint.lock() = Some(B::Endpoint::from_address(address));
|
||||
*self.state.endpoint.lock() = Some(endpoint);
|
||||
}
|
||||
|
||||
/// Returns the current endpoint of the peer (for configuration)
|
||||
@@ -591,11 +591,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
debug!("peer.send");
|
||||
let inner = &self.state;
|
||||
match inner.endpoint.lock().as_ref() {
|
||||
Some(endpoint) => inner
|
||||
.device
|
||||
.bind
|
||||
.send(msg, endpoint)
|
||||
.map_err(|_| RouterError::SendError),
|
||||
Some(endpoint) => inner.device
|
||||
.outbound
|
||||
.read()
|
||||
.as_ref()
|
||||
.ok_or(RouterError::SendError)
|
||||
.and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError) ),
|
||||
None => Err(RouterError::NoEndpoint),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::net::IpAddr;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::time::Duration;
|
||||
|
||||
use num_cpus;
|
||||
use pnet::packet::ipv4::MutableIpv4Packet;
|
||||
use pnet::packet::ipv6::MutableIpv6Packet;
|
||||
|
||||
use super::super::types::{dummy, Bind, Endpoint, Key, KeyPair, Tun};
|
||||
use super::super::types::bind::*;
|
||||
use super::super::types::tun::*;
|
||||
use super::super::types::*;
|
||||
|
||||
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
|
||||
|
||||
extern crate test;
|
||||
@@ -145,8 +145,9 @@ mod tests {
|
||||
}
|
||||
|
||||
// create device
|
||||
let router: Device<BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
|
||||
Device::new(num_cpus::get(), dummy::TunTest {}, dummy::VoidBind::new());
|
||||
let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
|
||||
let router: Device<_, BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
|
||||
Device::new(num_cpus::get(), tun_writer);
|
||||
|
||||
// add new peer
|
||||
let opaque = Arc::new(AtomicUsize::new(0));
|
||||
@@ -174,8 +175,9 @@ mod tests {
|
||||
init();
|
||||
|
||||
// create device
|
||||
let router: Device<TestCallbacks, _, _> =
|
||||
Device::new(1, dummy::TunTest::new(), dummy::VoidBind::new());
|
||||
let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
|
||||
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
|
||||
router.set_outbound_writer(dummy::VoidBind::new());
|
||||
|
||||
let tests = vec![
|
||||
("192.168.1.0", 24, "192.168.1.20", true),
|
||||
@@ -315,12 +317,18 @@ mod tests {
|
||||
];
|
||||
|
||||
for (stage, p1, p2) in tests.iter() {
|
||||
// create matching devices
|
||||
let (bind1, bind2) = dummy::PairBind::pair();
|
||||
let router1: Device<TestCallbacks, _, _> =
|
||||
Device::new(1, dummy::TunTest::new(), bind1.clone());
|
||||
let router2: Device<TestCallbacks, _, _> =
|
||||
Device::new(1, dummy::TunTest::new(), bind2.clone());
|
||||
let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
|
||||
dummy::PairBind::pair();
|
||||
|
||||
// create matching device
|
||||
let (tun_writer1, _, _) = dummy::TunTest::create("tun1").unwrap();
|
||||
let (tun_writer2, _, _) = dummy::TunTest::create("tun1").unwrap();
|
||||
|
||||
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
|
||||
router1.set_outbound_writer(bind_writer1);
|
||||
|
||||
let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2);
|
||||
router2.set_outbound_writer(bind_writer2);
|
||||
|
||||
// prepare opaque values for tracing callbacks
|
||||
|
||||
@@ -339,7 +347,7 @@ mod tests {
|
||||
let peer2 = router2.new_peer(opaq2.clone());
|
||||
let mask: IpAddr = mask.parse().unwrap();
|
||||
peer2.add_subnet(mask, *len);
|
||||
peer2.set_endpoint("127.0.0.1:8080".parse().unwrap());
|
||||
peer2.set_endpoint(dummy::UnitEndpoint::new());
|
||||
|
||||
if *stage {
|
||||
// stage a packet which can be used for confirmation (in place of a keepalive)
|
||||
@@ -372,7 +380,7 @@ mod tests {
|
||||
|
||||
// read confirming message received by the other end ("across the internet")
|
||||
let mut buf = vec![0u8; 2048];
|
||||
let (len, from) = bind1.recv(&mut buf).unwrap();
|
||||
let (len, from) = bind_reader1.read(&mut buf).unwrap();
|
||||
buf.truncate(len);
|
||||
router1.recv(from, buf).unwrap();
|
||||
|
||||
@@ -411,7 +419,7 @@ mod tests {
|
||||
|
||||
// receive ("across the internet") on the other end
|
||||
let mut buf = vec![0u8; 2048];
|
||||
let (len, from) = bind2.recv(&mut buf).unwrap();
|
||||
let (len, from) = bind_reader2.read(&mut buf).unwrap();
|
||||
buf.truncate(len);
|
||||
router2.recv(from, buf).unwrap();
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
use super::super::types::Endpoint;
|
||||
|
||||
pub trait Opaque: Send + Sync + 'static {}
|
||||
|
||||
impl<T> Opaque for T where T: Send + Sync + 'static {}
|
||||
|
||||
@@ -17,7 +17,7 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::peer::PeerInner;
|
||||
use super::types::Callbacks;
|
||||
|
||||
use super::super::types::{Bind, Tun};
|
||||
use super::super::types::{Endpoint, tun, bind};
|
||||
use super::ip::*;
|
||||
|
||||
const SIZE_TAG: usize = 16;
|
||||
@@ -38,18 +38,18 @@ pub struct JobBuffer {
|
||||
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
|
||||
|
||||
#[allow(type_alias_bounds)]
|
||||
pub type JobInbound<C, T, B: Bind> = (
|
||||
Arc<DecryptionState<C, T, B>>,
|
||||
B::Endpoint,
|
||||
pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
|
||||
Arc<DecryptionState<E, C, T, B>>,
|
||||
E,
|
||||
oneshot::Receiver<JobBuffer>,
|
||||
);
|
||||
|
||||
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
|
||||
|
||||
#[inline(always)]
|
||||
fn check_route<C: Callbacks, T: Tun, B: Bind>(
|
||||
device: &Arc<DeviceInner<C, T, B>>,
|
||||
peer: &Arc<PeerInner<C, T, B>>,
|
||||
fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: &Arc<DeviceInner<E, C, T, B>>,
|
||||
peer: &Arc<PeerInner<E, C, T, B>>,
|
||||
packet: &[u8],
|
||||
) -> Option<usize> {
|
||||
match packet[0] >> 4 {
|
||||
@@ -93,10 +93,10 @@ fn check_route<C: Callbacks, T: Tun, B: Bind>(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
||||
device: Arc<DeviceInner<C, T, B>>, // related device
|
||||
peer: Arc<PeerInner<C, T, B>>, // related peer
|
||||
receiver: Receiver<JobInbound<C, T, B>>,
|
||||
pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: Arc<DeviceInner<E, C, T, B>>, // related device
|
||||
peer: Arc<PeerInner<E, C, T, B>>, // related peer
|
||||
receiver: Receiver<JobInbound<E, C, T, B>>,
|
||||
) {
|
||||
loop {
|
||||
// fetch job
|
||||
@@ -153,7 +153,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
||||
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
|
||||
debug_assert!(inner_len <= length, "should be validated");
|
||||
if inner_len <= length {
|
||||
sent = match device.tun.write(&packet[..inner_len]) {
|
||||
sent = match device.inbound.write(&packet[..inner_len]) {
|
||||
Err(e) => {
|
||||
debug!("failed to write inbound packet to TUN: {:?}", e);
|
||||
false
|
||||
@@ -176,9 +176,9 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
|
||||
device: Arc<DeviceInner<C, T, B>>, // related device
|
||||
peer: Arc<PeerInner<C, T, B>>, // related peer
|
||||
pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
|
||||
device: Arc<DeviceInner<E, C, T, B>>, // related device
|
||||
peer: Arc<PeerInner<E, C, T, B>>, // related peer
|
||||
receiver: Receiver<JobOutbound>,
|
||||
) {
|
||||
loop {
|
||||
@@ -198,12 +198,17 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
|
||||
if buf.okay {
|
||||
// write to UDP bind
|
||||
let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
|
||||
match device.bind.send(&buf.msg[..], dst) {
|
||||
Err(e) => {
|
||||
debug!("failed to send outbound packet: {:?}", e);
|
||||
false
|
||||
let send : &Option<B> = &*device.outbound.read();
|
||||
if let Some(writer) = send.as_ref() {
|
||||
match writer.write(&buf.msg[..], dst) {
|
||||
Err(e) => {
|
||||
debug!("failed to send outbound packet: {:?}", e);
|
||||
false
|
||||
}
|
||||
Ok(_) => true,
|
||||
}
|
||||
Ok(_) => true,
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
|
||||
@@ -7,7 +7,7 @@ use hjul::{Runner, Timer};
|
||||
|
||||
use crate::constants::*;
|
||||
use crate::router::Callbacks;
|
||||
use crate::types::{Bind, Tun};
|
||||
use crate::types::{tun, bind};
|
||||
use crate::wireguard::{Peer, PeerInner};
|
||||
|
||||
pub struct Timers {
|
||||
@@ -23,8 +23,8 @@ pub struct Timers {
|
||||
impl Timers {
|
||||
pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers
|
||||
where
|
||||
T: Tun,
|
||||
B: Bind,
|
||||
T: tun::Tun,
|
||||
B: bind::Bind,
|
||||
{
|
||||
// create a timer instance for the provided peer
|
||||
Timers {
|
||||
@@ -103,7 +103,7 @@ impl Timers {
|
||||
|
||||
pub struct Events<T, B>(PhantomData<(T, B)>);
|
||||
|
||||
impl<T: Tun, B: Bind> Callbacks for Events<T, B> {
|
||||
impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
|
||||
type Opaque = Arc<PeerInner<B>>;
|
||||
|
||||
fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
|
||||
|
||||
@@ -1,73 +1,28 @@
|
||||
use super::Endpoint;
|
||||
use std::error;
|
||||
use std::error::Error;
|
||||
|
||||
/// Traits representing the "internet facing" end of the VPN.
|
||||
///
|
||||
/// In practice this is a UDP socket (but the router interface is agnostic).
|
||||
/// Often these traits will be implemented on the same type.
|
||||
pub trait Reader<E: Endpoint>: Send + Sync {
|
||||
type Error: Error;
|
||||
|
||||
/// Bind interface provided to the router code
|
||||
pub trait RouterBind: Send + Sync {
|
||||
type Error: error::Error;
|
||||
fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>;
|
||||
}
|
||||
|
||||
pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static {
|
||||
type Error: Error;
|
||||
|
||||
fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
pub trait Bind: Send + Sync + 'static {
|
||||
type Error: Error;
|
||||
type Endpoint: Endpoint;
|
||||
|
||||
/// Receive a buffer on the bind
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `buf`, buffer for storing the packet. If the buffer is too short, the packet should just be truncated.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The size of the buffer is derieved from the MTU of the Tun device.
|
||||
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>;
|
||||
/* Until Rust gets type equality constraints these have to be generic */
|
||||
type Writer: Writer<Self::Endpoint>;
|
||||
type Reader: Reader<Self::Endpoint>;
|
||||
|
||||
/// Send a buffer to the endpoint
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `buf`, packet src buffer (in practice the body of a UDP datagram)
|
||||
/// - `dst`, destination endpoint (in practice, src: (ip, port) + dst: (ip, port) for sticky sockets)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The unit type or an error if transmission failed
|
||||
fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
/// Bind interface provided for configuration (setting / getting the port)
|
||||
pub trait ConfigBind {
|
||||
type Error: error::Error;
|
||||
|
||||
/// Return a new (unbound) instance of a configuration bind
|
||||
fn new() -> Self;
|
||||
|
||||
/// Updates the port of the bind
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `port`, the new port to bind to. 0 means any available port.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The unit type or an error, if binding fails
|
||||
fn set_port(&self, port: u16) -> Result<(), Self::Error>;
|
||||
|
||||
/// Returns the current port of the bind
|
||||
fn get_port(&self) -> Option<u16>;
|
||||
|
||||
/// Set the mark (e.g. on Linus this is the fwmark) on the bind
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `mark`, the mark to set
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The mark should be retained accross calls to `set_port`.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The unit type or an error, if the operation fails due to permission errors
|
||||
fn set_mark(&self, mark: u16) -> Result<(), Self::Error>;
|
||||
/* Used to close the reader/writer when binding to a new port */
|
||||
type Closer;
|
||||
|
||||
fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error>;
|
||||
}
|
||||
|
||||
@@ -5,8 +5,9 @@ use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Instant;
|
||||
use std::marker;
|
||||
|
||||
use super::{Bind, Endpoint, Key, KeyPair, Tun};
|
||||
use super::*;
|
||||
|
||||
/* This submodule provides pure/dummy implementations of the IO interfaces
|
||||
* for use in unit tests thoughout the project.
|
||||
@@ -72,104 +73,103 @@ impl Endpoint for UnitEndpoint {
|
||||
}
|
||||
}
|
||||
|
||||
impl UnitEndpoint {
|
||||
pub fn new() -> UnitEndpoint {
|
||||
UnitEndpoint{}
|
||||
}
|
||||
}
|
||||
|
||||
/* */
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct TunTest {}
|
||||
|
||||
impl Tun for TunTest {
|
||||
impl tun::Reader for TunTest {
|
||||
type Error = TunError;
|
||||
|
||||
fn mtu(&self) -> usize {
|
||||
1500
|
||||
}
|
||||
|
||||
fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl tun::MTU for TunTest {
|
||||
fn mtu(&self) -> usize {
|
||||
1500
|
||||
}
|
||||
}
|
||||
|
||||
impl tun::Writer for TunTest {
|
||||
type Error = TunError;
|
||||
|
||||
fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl tun::Tun for TunTest {
|
||||
type Writer = TunTest;
|
||||
type Reader = TunTest;
|
||||
type MTU = TunTest;
|
||||
type Error = TunError;
|
||||
}
|
||||
|
||||
impl TunTest {
|
||||
pub fn new() -> TunTest {
|
||||
TunTest {}
|
||||
pub fn create(_name: &str) -> Result<(TunTest, TunTest, TunTest), TunError> {
|
||||
Ok((TunTest {},TunTest {}, TunTest{}))
|
||||
}
|
||||
}
|
||||
|
||||
/* Bind implemenentations */
|
||||
/* Void Bind */
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct VoidBind {}
|
||||
|
||||
impl Bind for VoidBind {
|
||||
impl bind::Reader<UnitEndpoint> for VoidBind {
|
||||
type Error = BindError;
|
||||
type Endpoint = UnitEndpoint;
|
||||
|
||||
fn new() -> VoidBind {
|
||||
VoidBind {}
|
||||
}
|
||||
|
||||
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_port(&self) -> Option<u16> {
|
||||
None
|
||||
}
|
||||
|
||||
fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
|
||||
fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
|
||||
Ok((0, UnitEndpoint {}))
|
||||
}
|
||||
}
|
||||
|
||||
fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
|
||||
impl bind::Writer<UnitEndpoint> for VoidBind {
|
||||
type Error = BindError;
|
||||
|
||||
fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PairBind {
|
||||
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
|
||||
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
|
||||
}
|
||||
|
||||
impl PairBind {
|
||||
pub fn pair() -> (PairBind, PairBind) {
|
||||
let (tx1, rx1) = sync_channel(128);
|
||||
let (tx2, rx2) = sync_channel(128);
|
||||
(
|
||||
PairBind {
|
||||
send: Arc::new(Mutex::new(tx1)),
|
||||
recv: Arc::new(Mutex::new(rx2)),
|
||||
},
|
||||
PairBind {
|
||||
send: Arc::new(Mutex::new(tx2)),
|
||||
recv: Arc::new(Mutex::new(rx1)),
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for PairBind {
|
||||
impl bind::Bind for VoidBind {
|
||||
type Error = BindError;
|
||||
type Endpoint = UnitEndpoint;
|
||||
|
||||
fn new() -> PairBind {
|
||||
PairBind {
|
||||
send: Arc::new(Mutex::new(sync_channel(0).0)),
|
||||
recv: Arc::new(Mutex::new(sync_channel(0).1)),
|
||||
}
|
||||
}
|
||||
type Reader = VoidBind;
|
||||
type Writer = VoidBind;
|
||||
type Closer = ();
|
||||
|
||||
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
fn bind(_ : u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
|
||||
Ok((VoidBind{}, VoidBind{}, (), 2600))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_port(&self) -> Option<u16> {
|
||||
None
|
||||
impl VoidBind {
|
||||
pub fn new() -> VoidBind {
|
||||
VoidBind{}
|
||||
}
|
||||
}
|
||||
|
||||
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
|
||||
/* Pair Bind */
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PairReader<E> {
|
||||
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
|
||||
_marker: marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
impl bind::Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
|
||||
type Error = BindError;
|
||||
fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
|
||||
let vec = self
|
||||
.recv
|
||||
.lock()
|
||||
@@ -180,8 +180,11 @@ impl Bind for PairBind {
|
||||
buf[..len].copy_from_slice(&vec[..]);
|
||||
Ok((vec.len(), UnitEndpoint {}))
|
||||
}
|
||||
}
|
||||
|
||||
fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
|
||||
impl bind::Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
|
||||
type Error = BindError;
|
||||
fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
|
||||
let owned = buf.to_owned();
|
||||
match self.send.lock().unwrap().send(owned) {
|
||||
Err(_) => Err(BindError::Disconnected),
|
||||
@@ -190,6 +193,57 @@ impl Bind for PairBind {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PairWriter<E> {
|
||||
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
|
||||
_marker: marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PairBind {}
|
||||
|
||||
impl PairBind {
|
||||
pub fn pair<E>() -> ((PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>)) {
|
||||
let (tx1, rx1) = sync_channel(128);
|
||||
let (tx2, rx2) = sync_channel(128);
|
||||
(
|
||||
(
|
||||
PairReader{
|
||||
|
||||
recv: Arc::new(Mutex::new(rx1)),
|
||||
_marker: marker::PhantomData
|
||||
},
|
||||
PairWriter{
|
||||
send: Arc::new(Mutex::new(tx2)),
|
||||
_marker: marker::PhantomData
|
||||
}
|
||||
),
|
||||
(
|
||||
PairReader{
|
||||
recv: Arc::new(Mutex::new(rx2)),
|
||||
_marker: marker::PhantomData
|
||||
},
|
||||
PairWriter{
|
||||
send: Arc::new(Mutex::new(tx1)),
|
||||
_marker: marker::PhantomData
|
||||
}
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl bind::Bind for PairBind {
|
||||
type Closer = ();
|
||||
type Error = BindError;
|
||||
type Endpoint = UnitEndpoint;
|
||||
type Reader = PairReader<Self::Endpoint>;
|
||||
type Writer = PairWriter<Self::Endpoint>;
|
||||
|
||||
fn bind(_port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
|
||||
Err(BindError::Disconnected)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn keypair(initiator: bool) -> KeyPair {
|
||||
let k1 = Key {
|
||||
key: [0x53u8; 32],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub trait Endpoint: Send {
|
||||
pub trait Endpoint: Send + 'static {
|
||||
fn from_address(addr: SocketAddr) -> Self;
|
||||
fn into_address(&self) -> SocketAddr;
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
mod endpoint;
|
||||
mod keys;
|
||||
mod tun;
|
||||
mod udp;
|
||||
pub mod tun;
|
||||
pub mod bind;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod dummy;
|
||||
|
||||
pub use endpoint::Endpoint;
|
||||
pub use keys::{Key, KeyPair};
|
||||
pub use tun::Tun;
|
||||
pub use udp::Bind;
|
||||
pub use keys::{Key, KeyPair};
|
||||
@@ -1,18 +1,22 @@
|
||||
use std::error;
|
||||
use std::error::Error;
|
||||
|
||||
pub trait Tun: Send + Sync + Clone + 'static {
|
||||
type Error: error::Error;
|
||||
pub trait Writer: Send + Sync + 'static {
|
||||
type Error: Error;
|
||||
|
||||
/// Returns the MTU of the device
|
||||
/// Receive a cryptkey routed IP packet
|
||||
///
|
||||
/// This function needs to be efficient (called for every read).
|
||||
/// The goto implementation strategy is to .load an atomic variable,
|
||||
/// then use e.g. netlink to update the variable in a separate thread.
|
||||
/// # Arguments
|
||||
///
|
||||
/// - src: Buffer containing the IP packet to be written
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The MTU of the interface in bytes
|
||||
fn mtu(&self) -> usize;
|
||||
/// Unit type or an error
|
||||
fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
pub trait Reader: Send + 'static {
|
||||
type Error: Error;
|
||||
|
||||
/// Reads an IP packet into dst[offset:] from the tunnel device
|
||||
///
|
||||
@@ -29,15 +33,24 @@ pub trait Tun: Send + Sync + Clone + 'static {
|
||||
///
|
||||
/// The size of the IP packet (ignoring the header) or an std::error::Error instance:
|
||||
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error>;
|
||||
}
|
||||
|
||||
/// Writes an IP packet to the tunnel device
|
||||
pub trait MTU: Send + Sync + Clone + 'static {
|
||||
/// Returns the MTU of the device
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - src: Buffer containing the IP packet to be written
|
||||
/// This function needs to be efficient (called for every read).
|
||||
/// The goto implementation strategy is to .load an atomic variable,
|
||||
/// then use e.g. netlink to update the variable in a separate thread.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Unit type or an error
|
||||
fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
|
||||
/// The MTU of the interface in bytes
|
||||
fn mtu(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait Tun: Send + Sync + 'static {
|
||||
type Writer: Writer;
|
||||
type Reader: Reader;
|
||||
type MTU: MTU;
|
||||
type Error: Error;
|
||||
}
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
use super::Endpoint;
|
||||
use std::error;
|
||||
|
||||
/* Often times an a file descriptor in an atomic might suffice.
|
||||
*/
|
||||
pub trait Bind: Send + Sync + Clone + 'static {
|
||||
type Error: error::Error + Send;
|
||||
type Endpoint: Endpoint;
|
||||
|
||||
fn new() -> Self;
|
||||
|
||||
/// Updates the port of the Bind
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - port, The new port to bind to. 0 means any available port.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The unit type or an error, if binding fails
|
||||
fn set_port(&self, port: u16) -> Result<(), Self::Error>;
|
||||
|
||||
/// Returns the current port of the bind
|
||||
fn get_port(&self) -> Option<u16>;
|
||||
|
||||
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>;
|
||||
|
||||
fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>;
|
||||
}
|
||||
254
src/wireguard.rs
254
src/wireguard.rs
@@ -2,11 +2,13 @@ use crate::constants::*;
|
||||
use crate::handshake;
|
||||
use crate::router;
|
||||
use crate::timers::{Events, Timers};
|
||||
use crate::types::{Bind, Endpoint, Tun};
|
||||
|
||||
use crate::types::Endpoint;
|
||||
use crate::types::tun::{Tun, Reader, MTU};
|
||||
use crate::types::bind::{Bind, Writer};
|
||||
|
||||
use hjul::Runner;
|
||||
|
||||
use std::cmp;
|
||||
use std::ops::Deref;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
@@ -27,12 +29,20 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
|
||||
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
|
||||
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Peer<T: Tun, B: Bind> {
|
||||
pub router: Arc<router::Peer<Events<T, B>, T, B>>,
|
||||
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
|
||||
pub state: Arc<PeerInner<B>>,
|
||||
}
|
||||
|
||||
impl <T : Tun, B : Bind> Clone for Peer<T, B > {
|
||||
fn clone(&self) -> Peer<T, B> {
|
||||
Peer{
|
||||
router: self.router.clone(),
|
||||
state: self.state.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PeerInner<B: Bind> {
|
||||
pub keepalive: AtomicUsize, // keepalive interval
|
||||
pub rx_bytes: AtomicU64,
|
||||
@@ -66,20 +76,22 @@ pub enum HandshakeJob<E> {
|
||||
}
|
||||
|
||||
struct WireguardInner<T: Tun, B: Bind> {
|
||||
// provides access to the MTU value of the tun device
|
||||
// (otherwise owned solely by the router and a dedicated read IO thread)
|
||||
mtu: T::MTU,
|
||||
send: RwLock<Option<B::Writer>>,
|
||||
|
||||
// identify and configuration map
|
||||
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
|
||||
|
||||
// cryptkey router
|
||||
router: router::Device<Events<T, B>, T, B>,
|
||||
router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
|
||||
|
||||
// handshake related state
|
||||
handshake: RwLock<Handshake>,
|
||||
under_load: AtomicBool,
|
||||
pending: AtomicUsize, // num of pending handshake packets in queue
|
||||
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
|
||||
|
||||
// IO
|
||||
bind: B,
|
||||
}
|
||||
|
||||
pub struct Wireguard<T: Tun, B: Bind> {
|
||||
@@ -87,6 +99,17 @@ pub struct Wireguard<T: Tun, B: Bind> {
|
||||
state: Arc<WireguardInner<T, B>>,
|
||||
}
|
||||
|
||||
/* Returns the padded length of a message:
|
||||
*
|
||||
* # Arguments
|
||||
*
|
||||
* - `size` : Size of unpadded message
|
||||
* - `mtu` : Maximum transmission unit of the device
|
||||
*
|
||||
* # Returns
|
||||
*
|
||||
* The padded length (always less than or equal to the MTU)
|
||||
*/
|
||||
#[inline(always)]
|
||||
const fn padding(size: usize, mtu: usize) -> usize {
|
||||
#[inline(always)]
|
||||
@@ -114,6 +137,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_sk(&self) -> Option<StaticSecret> {
|
||||
let mut handshake = self.state.handshake.read();
|
||||
if handshake.active {
|
||||
Some(handshake.device.get_sk())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
|
||||
let state = Arc::new(PeerInner {
|
||||
pk,
|
||||
@@ -137,113 +169,36 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
peer
|
||||
}
|
||||
|
||||
pub fn new(tun: T, bind: B) -> Wireguard<T, B> {
|
||||
// create device state
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||
let wg = Arc::new(WireguardInner {
|
||||
peers: RwLock::new(HashMap::new()),
|
||||
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
|
||||
pending: AtomicUsize::new(0),
|
||||
handshake: RwLock::new(Handshake {
|
||||
device: handshake::Device::new(StaticSecret::new(&mut rng)),
|
||||
active: false,
|
||||
}),
|
||||
under_load: AtomicBool::new(false),
|
||||
bind: bind.clone(),
|
||||
queue: Mutex::new(tx),
|
||||
});
|
||||
pub fn new_bind(
|
||||
reader: B::Reader,
|
||||
writer: B::Writer,
|
||||
closer: B::Closer
|
||||
) {
|
||||
|
||||
// start handshake workers
|
||||
for _ in 0..num_cpus::get() {
|
||||
let wg = wg.clone();
|
||||
let rx = rx.clone();
|
||||
let bind = bind.clone();
|
||||
thread::spawn(move || {
|
||||
// prepare OsRng instance for this thread
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
// drop existing closer
|
||||
|
||||
// process elements from the handshake queue
|
||||
for job in rx {
|
||||
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||
let state = wg.handshake.read();
|
||||
if !state.active {
|
||||
continue;
|
||||
}
|
||||
|
||||
match job {
|
||||
HandshakeJob::Message(msg, src) => {
|
||||
// feed message to handshake device
|
||||
let src_validate = (&src).into_address(); // TODO avoid
|
||||
// swap IO thread for new reader
|
||||
|
||||
// process message
|
||||
match state.device.process(
|
||||
&mut rng,
|
||||
&msg[..],
|
||||
if wg.under_load.load(Ordering::Relaxed) {
|
||||
Some(&src_validate)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
) {
|
||||
Ok((pk, msg, keypair)) => {
|
||||
// send response
|
||||
if let Some(msg) = msg {
|
||||
let _ = bind.send(&msg[..], &src).map_err(|e| {
|
||||
debug!(
|
||||
"handshake worker, failed to send response, error = {:?}",
|
||||
e
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
// update timers
|
||||
if let Some(pk) = pk {
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
// update endpoint (DISCUSS: right semantics?)
|
||||
peer.router.set_endpoint(src_validate);
|
||||
|
||||
// add keypair to peer and free any unused ids
|
||||
if let Some(keypair) = keypair {
|
||||
for id in peer.router.add_keypair(keypair) {
|
||||
state.device.release(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => debug!("handshake worker, error = {:?}", e),
|
||||
}
|
||||
}
|
||||
HandshakeJob::New(pk) => {
|
||||
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
peer.router.send(&msg[..]);
|
||||
peer.timers.read().handshake_sent();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// start UDP read IO thread
|
||||
|
||||
/*
|
||||
{
|
||||
let wg = wg.clone();
|
||||
let tun = tun.clone();
|
||||
let bind = bind.clone();
|
||||
let mtu = mtu.clone();
|
||||
thread::spawn(move || {
|
||||
let mut last_under_load =
|
||||
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
|
||||
|
||||
loop {
|
||||
// create vector big enough for any message given current MTU
|
||||
let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
|
||||
let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
|
||||
let mut msg: Vec<u8> = Vec::with_capacity(size);
|
||||
msg.resize(size, 0);
|
||||
|
||||
// read UDP packet into vector
|
||||
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
|
||||
let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error
|
||||
msg.truncate(size);
|
||||
|
||||
// message type de-multiplexer
|
||||
@@ -276,19 +231,120 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
}
|
||||
});
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
reader: T::Reader,
|
||||
writer: T::Writer,
|
||||
mtu: T::MTU,
|
||||
) -> Wireguard<T, B> {
|
||||
// create device state
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||
let wg = Arc::new(WireguardInner {
|
||||
mtu: mtu.clone(),
|
||||
peers: RwLock::new(HashMap::new()),
|
||||
send: RwLock::new(None),
|
||||
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
|
||||
pending: AtomicUsize::new(0),
|
||||
handshake: RwLock::new(Handshake {
|
||||
device: handshake::Device::new(StaticSecret::new(&mut rng)),
|
||||
active: false,
|
||||
}),
|
||||
under_load: AtomicBool::new(false),
|
||||
queue: Mutex::new(tx),
|
||||
});
|
||||
|
||||
// start handshake workers
|
||||
for _ in 0..num_cpus::get() {
|
||||
let wg = wg.clone();
|
||||
let rx = rx.clone();
|
||||
thread::spawn(move || {
|
||||
// prepare OsRng instance for this thread
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
|
||||
// process elements from the handshake queue
|
||||
for job in rx {
|
||||
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||
let state = wg.handshake.read();
|
||||
if !state.active {
|
||||
continue;
|
||||
}
|
||||
|
||||
match job {
|
||||
HandshakeJob::Message(msg, src) => {
|
||||
// feed message to handshake device
|
||||
let src_validate = (&src).into_address(); // TODO avoid
|
||||
|
||||
// process message
|
||||
match state.device.process(
|
||||
&mut rng,
|
||||
&msg[..],
|
||||
if wg.under_load.load(Ordering::Relaxed) {
|
||||
Some(&src_validate)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
) {
|
||||
Ok((pk, msg, keypair)) => {
|
||||
// send response
|
||||
if let Some(msg) = msg {
|
||||
let send : &Option<B::Writer> = &*wg.send.read();
|
||||
if let Some(writer) = send.as_ref() {
|
||||
let _ = writer.write(&msg[..], &src).map_err(|e| {
|
||||
debug!(
|
||||
"handshake worker, failed to send response, error = {:?}",
|
||||
e
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// update timers
|
||||
if let Some(pk) = pk {
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
// update endpoint
|
||||
peer.router.set_endpoint(src);
|
||||
|
||||
// add keypair to peer and free any unused ids
|
||||
if let Some(keypair) = keypair {
|
||||
for id in peer.router.add_keypair(keypair) {
|
||||
state.device.release(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => debug!("handshake worker, error = {:?}", e),
|
||||
}
|
||||
}
|
||||
HandshakeJob::New(pk) => {
|
||||
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
peer.router.send(&msg[..]);
|
||||
peer.timers.read().handshake_sent();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// start TUN read IO thread
|
||||
{
|
||||
let wg = wg.clone();
|
||||
thread::spawn(move || loop {
|
||||
// create vector big enough for any transport message (based on MTU)
|
||||
let mtu = tun.mtu();
|
||||
let mtu = mtu.mtu();
|
||||
let size = mtu + router::SIZE_MESSAGE_PREFIX;
|
||||
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
|
||||
msg.resize(size, 0);
|
||||
|
||||
// read a new IP packet
|
||||
let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
|
||||
let payload = reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
|
||||
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
|
||||
|
||||
// truncate padding
|
||||
|
||||
Reference in New Issue
Block a user