Make IO traits suitable for Tun events (up/down)

This commit is contained in:
Mathias Hall-Andersen
2019-11-24 18:41:43 +01:00
parent dee23969f5
commit 3bff078e3f
20 changed files with 186 additions and 126 deletions

View File

@@ -4,8 +4,8 @@ use std::sync::atomic::Ordering;
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use x25519_dalek::{PublicKey, StaticSecret}; use x25519_dalek::{PublicKey, StaticSecret};
use super::udp::Owner;
use super::*; use super::*;
use bind::Owner;
/// The goal of the configuration interface is, among others, /// The goal of the configuration interface is, among others,
/// to hide the IO implementations (over which the WG device is generic), /// to hide the IO implementations (over which the WG device is generic),
@@ -26,13 +26,13 @@ pub struct PeerState {
pub preshared_key: [u8; 32], // 0^32 is the "default value" pub preshared_key: [u8; 32], // 0^32 is the "default value"
} }
pub struct WireguardConfig<T: tun::Tun, B: bind::PlatformBind> { pub struct WireguardConfig<T: tun::Tun, B: udp::PlatformUDP> {
wireguard: Wireguard<T, B>, wireguard: Wireguard<T, B>,
fwmark: Mutex<Option<u32>>, fwmark: Mutex<Option<u32>>,
network: Mutex<Option<B::Owner>>, network: Mutex<Option<B::Owner>>,
} }
impl<T: tun::Tun, B: bind::PlatformBind> WireguardConfig<T, B> { impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> {
pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> { pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
WireguardConfig { WireguardConfig {
wireguard: wg, wireguard: wg,
@@ -170,7 +170,7 @@ pub trait Configuration {
fn get_fwmark(&self) -> Option<u32>; fn get_fwmark(&self) -> Option<u32>;
} }
impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B> { impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> {
fn get_fwmark(&self) -> Option<u32> { fn get_fwmark(&self) -> Option<u32> {
self.network self.network
.lock() .lock()

View File

@@ -3,7 +3,7 @@ mod error;
pub mod uapi; pub mod uapi;
use super::platform::Endpoint; use super::platform::Endpoint;
use super::platform::{bind, tun}; use super::platform::{tun, udp};
use super::wireguard::Wireguard; use super::wireguard::Wireguard;
pub use error::ConfigError; pub use error::ConfigError;

View File

@@ -6,6 +6,7 @@ use log;
use daemonize::Daemonize; use daemonize::Daemonize;
use std::env; use std::env;
use std::process::exit; use std::process::exit;
use std::thread;
mod configuration; mod configuration;
mod platform; mod platform;
@@ -52,7 +53,7 @@ fn main() {
}); });
// create TUN device // create TUN device
let (readers, writer, mtu) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| { let (readers, writer, status) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| {
eprintln!("Failed to create TUN device: {}", e); eprintln!("Failed to create TUN device: {}", e);
exit(-3); exit(-3);
}); });
@@ -78,8 +79,26 @@ fn main() {
if drop_privileges {} if drop_privileges {}
// create WireGuard device // create WireGuard device
let wg: wireguard::Wireguard<plt::Tun, plt::Bind> = let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(readers, writer);
wireguard::Wireguard::new(readers, writer, mtu);
wg.set_mtu(1420);
// start Tun event thread
/*
{
let wg = wg.clone();
let mut status = status;
thread::spawn(move || loop {
match status.event() {
Err(_) => break,
Ok(tun::TunEvent::Up(mtu)) => {
wg.mtu.store(mtu, Ordering::Relaxed);
}
Ok(tun::TunEvent::Down) => {}
}
});
}
*/
// handle TUN updates up/down // handle TUN updates up/down

View File

@@ -11,7 +11,7 @@ use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use super::super::bind::*; use super::super::udp::*;
use super::UnitEndpoint; use super::UnitEndpoint;
@@ -82,7 +82,7 @@ impl Writer<UnitEndpoint> for VoidBind {
} }
} }
impl Bind for VoidBind { impl UDP for VoidBind {
type Error = BindError; type Error = BindError;
type Endpoint = UnitEndpoint; type Endpoint = UnitEndpoint;
@@ -193,7 +193,7 @@ impl PairBind {
} }
} }
impl Bind for PairBind { impl UDP for PairBind {
type Error = BindError; type Error = BindError;
type Endpoint = UnitEndpoint; type Endpoint = UnitEndpoint;
type Reader = PairReader<Self::Endpoint>; type Reader = PairReader<Self::Endpoint>;
@@ -216,7 +216,7 @@ impl Owner for VoidOwner {
} }
} }
impl PlatformBind for PairBind { impl PlatformUDP for PairBind {
type Owner = VoidOwner; type Owner = VoidOwner;
fn bind(_port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { fn bind(_port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
Err(BindError::Disconnected) Err(BindError::Disconnected)

View File

@@ -10,6 +10,8 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use std::thread;
use std::time::Duration;
use super::super::tun::*; use super::super::tun::*;
@@ -83,9 +85,8 @@ pub struct TunWriter {
tx: Mutex<SyncSender<Vec<u8>>>, tx: Mutex<SyncSender<Vec<u8>>>,
} }
#[derive(Clone)] pub struct TunStatus {
pub struct TunMTU { first: bool,
mtu: Arc<AtomicUsize>,
} }
impl Reader for TunReader { impl Reader for TunReader {
@@ -131,16 +132,25 @@ impl Writer for TunWriter {
} }
} }
impl MTU for TunMTU { impl Status for TunStatus {
fn mtu(&self) -> usize { type Error = TunError;
self.mtu.load(Ordering::Acquire)
fn event(&mut self) -> Result<TunEvent, Self::Error> {
if self.first {
self.first = false;
return Ok(TunEvent::Up(1420));
}
loop {
thread::sleep(Duration::from_secs(60 * 60));
}
} }
} }
impl Tun for TunTest { impl Tun for TunTest {
type Writer = TunWriter; type Writer = TunWriter;
type Reader = TunReader; type Reader = TunReader;
type MTU = TunMTU; type Status = TunStatus;
type Error = TunError; type Error = TunError;
} }
@@ -157,7 +167,7 @@ impl TunFakeIO {
} }
impl TunTest { impl TunTest {
pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunStatus) {
let (tx1, rx1) = if store { let (tx1, rx1) = if store {
sync_channel(32) sync_channel(32)
} else { } else {
@@ -184,16 +194,13 @@ impl TunTest {
tx: Mutex::new(tx2), tx: Mutex::new(tx2),
store, store,
}; };
let mtu = TunMTU { let status = TunStatus { first: true };
mtu: Arc::new(AtomicUsize::new(mtu)), (fake, reader, writer, status)
};
(fake, reader, writer, mtu)
} }
} }
impl PlatformTun for TunTest { impl PlatformTun for TunTest {
fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> { fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error> {
Err(TunError::Disconnected) Err(TunError::Disconnected)
} }
} }

View File

@@ -4,4 +4,4 @@ mod udp;
pub use tun::LinuxTun as Tun; pub use tun::LinuxTun as Tun;
pub use uapi::LinuxUAPI as UAPI; pub use uapi::LinuxUAPI as UAPI;
pub use udp::LinuxBind as Bind; pub use udp::LinuxUDP as UDP;

View File

@@ -6,8 +6,8 @@ use std::error::Error;
use std::fmt; use std::fmt;
use std::os::raw::c_short; use std::os::raw::c_short;
use std::os::unix::io::RawFd; use std::os::unix::io::RawFd;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::thread;
use std::sync::Arc; use std::time::Duration;
const IFNAMSIZ: usize = 16; const IFNAMSIZ: usize = 16;
const TUNSETIFF: u64 = 0x4004_54ca; const TUNSETIFF: u64 = 0x4004_54ca;
@@ -30,7 +30,9 @@ struct Ifreq {
_pad: [u8; 64], _pad: [u8; 64],
} }
pub struct LinuxTun {} pub struct LinuxTun {
events: Vec<TunEvent>,
}
pub struct LinuxTunReader { pub struct LinuxTunReader {
fd: RawFd, fd: RawFd,
@@ -44,8 +46,8 @@ pub struct LinuxTunWriter {
* announcing an MTU update for the interface * announcing an MTU update for the interface
*/ */
#[derive(Clone)] #[derive(Clone)]
pub struct LinuxTunMTU { pub struct LinuxTunStatus {
value: Arc<AtomicUsize>, first: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@@ -81,13 +83,6 @@ impl Error for LinuxTunError {
} }
} }
impl MTU for LinuxTunMTU {
#[inline(always)]
fn mtu(&self) -> usize {
self.value.load(Ordering::Relaxed)
}
}
impl Reader for LinuxTunReader { impl Reader for LinuxTunReader {
type Error = LinuxTunError; type Error = LinuxTunError;
@@ -118,15 +113,30 @@ impl Writer for LinuxTunWriter {
} }
} }
impl Status for LinuxTunStatus {
type Error = LinuxTunError;
fn event(&mut self) -> Result<TunEvent, Self::Error> {
if self.first {
self.first = false;
return Ok(TunEvent::Up(1420));
}
loop {
thread::sleep(Duration::from_secs(60 * 60));
}
}
}
impl Tun for LinuxTun { impl Tun for LinuxTun {
type Error = LinuxTunError; type Error = LinuxTunError;
type Reader = LinuxTunReader; type Reader = LinuxTunReader;
type Writer = LinuxTunWriter; type Writer = LinuxTunWriter;
type MTU = LinuxTunMTU; type Status = LinuxTunStatus;
} }
impl PlatformTun for LinuxTun { impl PlatformTun for LinuxTun {
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> { fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error> {
// construct request struct // construct request struct
let mut req = Ifreq { let mut req = Ifreq {
name: [0u8; libc::IFNAMSIZ], name: [0u8; libc::IFNAMSIZ],
@@ -157,9 +167,7 @@ impl PlatformTun for LinuxTun {
Ok(( Ok((
vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux
LinuxTunWriter { fd }, LinuxTunWriter { fd },
LinuxTunMTU { LinuxTunStatus { first: true },
value: Arc::new(AtomicUsize::new(1500)), // TODO: fetch and update
},
)) ))
} }
} }

View File

@@ -1,4 +1,4 @@
use super::super::bind::*; use super::super::udp::*;
use super::super::Endpoint; use super::super::Endpoint;
use std::io; use std::io;
@@ -6,7 +6,7 @@ use std::net::{SocketAddr, UdpSocket};
use std::sync::Arc; use std::sync::Arc;
#[derive(Clone)] #[derive(Clone)]
pub struct LinuxBind(Arc<UdpSocket>); pub struct LinuxUDP(Arc<UdpSocket>);
pub struct LinuxOwner(Arc<UdpSocket>); pub struct LinuxOwner(Arc<UdpSocket>);
@@ -22,7 +22,7 @@ impl Endpoint for SocketAddr {
} }
} }
impl Reader<SocketAddr> for LinuxBind { impl Reader<SocketAddr> for LinuxUDP {
type Error = io::Error; type Error = io::Error;
fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> { fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
@@ -30,7 +30,7 @@ impl Reader<SocketAddr> for LinuxBind {
} }
} }
impl Writer<SocketAddr> for LinuxBind { impl Writer<SocketAddr> for LinuxUDP {
type Error = io::Error; type Error = io::Error;
fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> { fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> {
@@ -56,17 +56,19 @@ impl Owner for LinuxOwner {
} }
impl Drop for LinuxOwner { impl Drop for LinuxOwner {
fn drop(&mut self) {} fn drop(&mut self) {
// TODO: close udp bind
}
} }
impl Bind for LinuxBind { impl UDP for LinuxUDP {
type Error = io::Error; type Error = io::Error;
type Endpoint = SocketAddr; type Endpoint = SocketAddr;
type Reader = LinuxBind; type Reader = Self;
type Writer = LinuxBind; type Writer = Self;
} }
impl PlatformBind for LinuxBind { impl PlatformUDP for LinuxUDP {
type Owner = LinuxOwner; type Owner = LinuxOwner;
fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
@@ -74,8 +76,8 @@ impl PlatformBind for LinuxBind {
let socket = Arc::new(socket); let socket = Arc::new(socket);
Ok(( Ok((
vec![LinuxBind(socket.clone())], vec![LinuxUDP(socket.clone())],
LinuxBind(socket.clone()), LinuxUDP(socket.clone()),
LinuxOwner(socket), LinuxOwner(socket),
)) ))
} }

View File

@@ -1,8 +1,8 @@
mod endpoint; mod endpoint;
pub mod bind;
pub mod tun; pub mod tun;
pub mod uapi; pub mod uapi;
pub mod udp;
pub use endpoint::Endpoint; pub use endpoint::Endpoint;

View File

@@ -1,5 +1,18 @@
use std::error::Error; use std::error::Error;
pub enum TunEvent {
Up(usize), // interface is up (supply MTU)
Down, // interface is down
}
pub trait Status: Send + 'static {
type Error: Error;
/// Returns status updates for the interface
/// When the status is unchanged the method blocks
fn event(&mut self) -> Result<TunEvent, Self::Error>;
}
pub trait Writer: Send + Sync + 'static { pub trait Writer: Send + Sync + 'static {
type Error: Error; type Error: Error;
@@ -35,27 +48,14 @@ pub trait Reader: Send + 'static {
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error>; fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error>;
} }
pub trait MTU: Send + Sync + Clone + 'static {
/// Returns the MTU of the device
///
/// This function needs to be efficient (called for every read).
/// The goto implementation strategy is to .load an atomic variable,
/// then use e.g. netlink to update the variable in a separate thread.
///
/// # Returns
///
/// The MTU of the interface in bytes
fn mtu(&self) -> usize;
}
pub trait Tun: Send + Sync + 'static { pub trait Tun: Send + Sync + 'static {
type Writer: Writer; type Writer: Writer;
type Reader: Reader; type Reader: Reader;
type MTU: MTU; type Status: Status;
type Error: Error; type Error: Error;
} }
/// On some platforms the application can create the TUN device itself. /// On some platforms the application can create the TUN device itself.
pub trait PlatformTun: Tun { pub trait PlatformTun: Tun {
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error>; fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error>;
} }

View File

@@ -13,7 +13,7 @@ pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static {
fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>; fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>;
} }
pub trait Bind: Send + Sync + 'static { pub trait UDP: Send + Sync + 'static {
type Error: Error; type Error: Error;
type Endpoint: Endpoint; type Endpoint: Endpoint;
@@ -37,7 +37,7 @@ pub trait Owner: Send {
/// On some platforms the application can itself bind to a socket. /// On some platforms the application can itself bind to a socket.
/// This enables configuration using the UAPI interface. /// This enables configuration using the UAPI interface.
pub trait PlatformBind: Bind { pub trait PlatformUDP: UDP {
type Owner: Owner; type Owner: Owner;
/// Bind to a new port, returning the reader/writer and /// Bind to a new port, returning the reader/writer and

View File

@@ -20,7 +20,7 @@ pub use types::dummy_keypair;
#[cfg(test)] #[cfg(test)]
use super::platform::dummy; use super::platform::dummy;
use super::platform::{bind, tun, Endpoint}; use super::platform::{tun, udp, Endpoint};
use peer::PeerInner; use peer::PeerInner;
use types::KeyPair; use types::KeyPair;
use wireguard::HandshakeJob; use wireguard::HandshakeJob;

View File

@@ -2,8 +2,8 @@ use super::router;
use super::timers::{Events, Timers}; use super::timers::{Events, Timers};
use super::HandshakeJob; use super::HandshakeJob;
use super::bind::Bind;
use super::tun::Tun; use super::tun::Tun;
use super::udp::UDP;
use super::wireguard::WireguardInner; use super::wireguard::WireguardInner;
use std::fmt; use std::fmt;
@@ -17,12 +17,12 @@ use spin::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
use crossbeam_channel::Sender; use crossbeam_channel::Sender;
use x25519_dalek::PublicKey; use x25519_dalek::PublicKey;
pub struct Peer<T: Tun, B: Bind> { pub struct Peer<T: Tun, B: UDP> {
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>, pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
pub state: Arc<PeerInner<T, B>>, pub state: Arc<PeerInner<T, B>>,
} }
pub struct PeerInner<T: Tun, B: Bind> { pub struct PeerInner<T: Tun, B: UDP> {
// internal id (for logging) // internal id (for logging)
pub id: u64, pub id: u64,
@@ -44,7 +44,7 @@ pub struct PeerInner<T: Tun, B: Bind> {
pub timers: RwLock<Timers>, pub timers: RwLock<Timers>,
} }
impl<T: Tun, B: Bind> Clone for Peer<T, B> { impl<T: Tun, B: UDP> Clone for Peer<T, B> {
fn clone(&self) -> Peer<T, B> { fn clone(&self) -> Peer<T, B> {
Peer { Peer {
router: self.router.clone(), router: self.router.clone(),
@@ -53,7 +53,7 @@ impl<T: Tun, B: Bind> Clone for Peer<T, B> {
} }
} }
impl<T: Tun, B: Bind> PeerInner<T, B> { impl<T: Tun, B: UDP> PeerInner<T, B> {
#[inline(always)] #[inline(always)]
pub fn timers(&self) -> RwLockReadGuard<Timers> { pub fn timers(&self) -> RwLockReadGuard<Timers> {
self.timers.read() self.timers.read()
@@ -65,20 +65,20 @@ impl<T: Tun, B: Bind> PeerInner<T, B> {
} }
} }
impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> { impl<T: Tun, B: UDP> fmt::Display for Peer<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "peer(id = {})", self.id) write!(f, "peer(id = {})", self.id)
} }
} }
impl<T: Tun, B: Bind> Deref for Peer<T, B> { impl<T: Tun, B: UDP> Deref for Peer<T, B> {
type Target = PeerInner<T, B>; type Target = PeerInner<T, B>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.state &self.state
} }
} }
impl<T: Tun, B: Bind> Peer<T, B> { impl<T: Tun, B: UDP> Peer<T, B> {
/// Bring the peer down. Causing: /// Bring the peer down. Causing:
/// ///
/// - Timers to be stopped and disabled. /// - Timers to be stopped and disabled.

View File

@@ -21,9 +21,9 @@ use super::SIZE_MESSAGE_PREFIX;
use super::route::RoutingTable; use super::route::RoutingTable;
use super::super::{bind, tun, Endpoint, KeyPair}; use super::super::{tun, udp, Endpoint, KeyPair};
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
// inbound writer (TUN) // inbound writer (TUN)
pub inbound: T, pub inbound: T,
@@ -45,7 +45,7 @@ pub struct EncryptionState {
pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
} }
pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
pub keypair: Arc<KeyPair>, pub keypair: Arc<KeyPair>,
pub confirmed: AtomicBool, pub confirmed: AtomicBool,
pub protector: Mutex<AntiReplay>, pub protector: Mutex<AntiReplay>,
@@ -53,12 +53,12 @@ pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::W
pub death: Instant, // time when the key can no longer be used for decryption pub death: Instant, // time when the key can no longer be used for decryption
} }
pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
state: Arc<DeviceInner<E, C, T, B>>, // reference to device state state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
handles: Vec<thread::JoinHandle<()>>, // join handles for workers handles: Vec<thread::JoinHandle<()>>, // join handles for workers
} }
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> { impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Device<E, C, T, B> {
fn drop(&mut self) { fn drop(&mut self) {
debug!("router: dropping device"); debug!("router: dropping device");
@@ -82,7 +82,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Dev
} }
} }
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> { impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, T, B> {
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> { pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
// allocate shared device state // allocate shared device state
let inner = DeviceInner { let inner = DeviceInner {

View File

@@ -12,7 +12,7 @@ use log::debug;
use spin::Mutex; use spin::Mutex;
use super::super::constants::*; use super::super::constants::*;
use super::super::{bind, tun, Endpoint, KeyPair}; use super::super::{tun, udp, Endpoint, KeyPair};
use super::anti_replay::AntiReplay; use super::anti_replay::AntiReplay;
use super::device::DecryptionState; use super::device::DecryptionState;
@@ -36,7 +36,7 @@ pub struct KeyWheel {
retired: Vec<u32>, // retired ids retired: Vec<u32>, // retired ids
} }
pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
pub device: Arc<DeviceInner<E, C, T, B>>, pub device: Arc<DeviceInner<E, C, T, B>>,
pub opaque: C::Opaque, pub opaque: C::Opaque,
pub outbound: Mutex<SyncSender<JobOutbound>>, pub outbound: Mutex<SyncSender<JobOutbound>>,
@@ -47,13 +47,13 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<
pub endpoint: Mutex<Option<E>>, pub endpoint: Mutex<Option<E>>,
} }
pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
state: Arc<PeerInner<E, C, T, B>>, state: Arc<PeerInner<E, C, T, B>>,
thread_outbound: Option<thread::JoinHandle<()>>, thread_outbound: Option<thread::JoinHandle<()>>,
thread_inbound: Option<thread::JoinHandle<()>>, thread_inbound: Option<thread::JoinHandle<()>>,
} }
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Deref for Peer<E, C, T, B> { impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> {
type Target = Arc<PeerInner<E, C, T, B>>; type Target = Arc<PeerInner<E, C, T, B>>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@@ -71,7 +71,7 @@ impl EncryptionState {
} }
} }
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> { impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DecryptionState<E, C, T, B> {
fn new( fn new(
peer: &Arc<PeerInner<E, C, T, B>>, peer: &Arc<PeerInner<E, C, T, B>>,
keypair: &Arc<KeyPair>, keypair: &Arc<KeyPair>,
@@ -86,7 +86,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionSt
} }
} }
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> { impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer<E, C, T, B> {
fn drop(&mut self) { fn drop(&mut self) {
let peer = &self.state; let peer = &self.state;
@@ -133,7 +133,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Pee
} }
} }
pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, device: Arc<DeviceInner<E, C, T, B>>,
opaque: C::Opaque, opaque: C::Opaque,
) -> Peer<E, C, T, B> { ) -> Peer<E, C, T, B> {
@@ -180,7 +180,7 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
} }
} }
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> { impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, C, T, B> {
/// Send a raw message to the peer (used for handshake messages) /// Send a raw message to the peer (used for handshake messages)
/// ///
/// # Arguments /// # Arguments
@@ -352,7 +352,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E,
} }
} }
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> { impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> {
/// Set the endpoint of the peer /// Set the endpoint of the peer
/// ///
/// # Arguments /// # Arguments

View File

@@ -7,10 +7,10 @@ use std::time::Duration;
use num_cpus; use num_cpus;
use super::super::bind::*;
use super::super::dummy; use super::super::dummy;
use super::super::dummy_keypair; use super::super::dummy_keypair;
use super::super::tests::make_packet_dst; use super::super::tests::make_packet_dst;
use super::super::udp::*;
use super::KeyPair; use super::KeyPair;
use super::SIZE_MESSAGE_PREFIX; use super::SIZE_MESSAGE_PREFIX;
use super::{Callbacks, Device}; use super::{Callbacks, Device};

View File

@@ -19,7 +19,7 @@ use super::types::Callbacks;
use super::REJECT_AFTER_MESSAGES; use super::REJECT_AFTER_MESSAGES;
use super::super::types::KeyPair; use super::super::types::KeyPair;
use super::super::{bind, tun, Endpoint}; use super::super::{tun, udp, Endpoint};
pub const SIZE_TAG: usize = 16; pub const SIZE_TAG: usize = 16;
@@ -40,7 +40,7 @@ pub enum JobParallel {
} }
#[allow(type_alias_bounds)] #[allow(type_alias_bounds)]
pub type JobInbound<E, C, T, B: bind::Writer<E>> = ( pub type JobInbound<E, C, T, B: udp::Writer<E>> = (
Arc<DecryptionState<E, C, T, B>>, Arc<DecryptionState<E, C, T, B>>,
E, E,
oneshot::Receiver<Option<JobDecryption>>, oneshot::Receiver<Option<JobDecryption>>,
@@ -50,7 +50,7 @@ pub type JobOutbound = oneshot::Receiver<JobEncryption>;
/* TODO: Replace with run-queue /* TODO: Replace with run-queue
*/ */
pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device device: Arc<DeviceInner<E, C, T, B>>, // related device
peer: Arc<PeerInner<E, C, T, B>>, // related peer peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobInbound<E, C, T, B>>, receiver: Receiver<JobInbound<E, C, T, B>>,
@@ -137,7 +137,7 @@ pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer
/* TODO: Replace with run-queue /* TODO: Replace with run-queue
*/ */
pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
peer: Arc<PeerInner<E, C, T, B>>, peer: Arc<PeerInner<E, C, T, B>>,
receiver: Receiver<JobOutbound>, receiver: Receiver<JobOutbound>,
) { ) {

View File

@@ -1,5 +1,5 @@
use super::wireguard::Wireguard; use super::wireguard::Wireguard;
use super::{bind, dummy, tun}; use super::{dummy, tun, udp};
use std::net::IpAddr; use std::net::IpAddr;
use std::thread; use std::thread;
@@ -84,13 +84,17 @@ fn test_pure_wireguard() {
// create WG instances for dummy TUN devices // create WG instances for dummy TUN devices
let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true); let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(1500, true);
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> = let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader1], tun_writer1, mtu1); Wireguard::new(vec![tun_reader1], tun_writer1);
let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true); wg1.set_mtu(1500);
let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(1500, true);
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> = let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader2], tun_writer2, mtu2); Wireguard::new(vec![tun_reader2], tun_writer2);
wg2.set_mtu(1500);
// create pair bind to connect the interfaces "over the internet" // create pair bind to connect the interfaces "over the internet"

View File

@@ -9,7 +9,7 @@ use hjul::{Runner, Timer};
use super::constants::*; use super::constants::*;
use super::router::{message_data_len, Callbacks}; use super::router::{message_data_len, Callbacks};
use super::{Peer, PeerInner}; use super::{Peer, PeerInner};
use super::{bind, tun}; use super::{udp, tun};
use super::types::KeyPair; use super::types::KeyPair;
pub struct Timers { pub struct Timers {
@@ -35,7 +35,7 @@ impl Timers {
} }
} }
impl<T: tun::Tun, B: bind::Bind> PeerInner<T, B> { impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
pub fn get_keepalive_interval(&self) -> u64 { pub fn get_keepalive_interval(&self) -> u64 {
self.timers().keepalive_interval self.timers().keepalive_interval
@@ -224,7 +224,7 @@ impl Timers {
pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers
where where
T: tun::Tun, T: tun::Tun,
B: bind::Bind, B: udp::UDP,
{ {
// create a timer instance for the provided peer // create a timer instance for the provided peer
Timers { Timers {
@@ -335,7 +335,7 @@ impl Timers {
pub struct Events<T, B>(PhantomData<(T, B)>); pub struct Events<T, B>(PhantomData<(T, B)>);
impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> { impl<T: tun::Tun, B: udp::UDP> Callbacks for Events<T, B> {
type Opaque = Arc<PeerInner<T, B>>; type Opaque = Arc<PeerInner<T, B>>;
/* Called after the router encrypts a transport message destined for the peer. /* Called after the router encrypts a transport message destined for the peer.

View File

@@ -4,9 +4,13 @@ use super::router;
use super::timers::{Events, Timers}; use super::timers::{Events, Timers};
use super::{Peer, PeerInner}; use super::{Peer, PeerInner};
use super::bind::Reader as BindReader; use super::tun;
use super::bind::{Bind, Writer}; use super::tun::Reader as TunReader;
use super::tun::{Reader, Tun, MTU};
use super::udp;
use super::udp::Reader as UDPReader;
use super::udp::Writer as UDPWriter;
use super::Endpoint; use super::Endpoint;
use hjul::Runner; use hjul::Runner;
@@ -34,13 +38,15 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
pub struct WireguardInner<T: Tun, B: Bind> { pub struct WireguardInner<T: tun::Tun, B: udp::UDP> {
// identifier (for logging) // identifier (for logging)
id: u32, id: u32,
start: Instant, start: Instant,
// current MTU
mtu: AtomicUsize,
// provides access to the MTU value of the tun device // provides access to the MTU value of the tun device
mtu: T::MTU,
send: RwLock<Option<B::Writer>>, send: RwLock<Option<B::Writer>>,
// identity and configuration map // identity and configuration map
@@ -56,7 +62,7 @@ pub struct WireguardInner<T: Tun, B: Bind> {
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
} }
impl<T: Tun, B: Bind> PeerInner<T, B> { impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
/* Queue a handshake request for the parallel workers /* Queue a handshake request for the parallel workers
* (if one does not already exist) * (if one does not already exist)
* *
@@ -87,20 +93,20 @@ pub enum HandshakeJob<E> {
New(PublicKey), New(PublicKey),
} }
impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> { impl<T: tun::Tun, B: udp::UDP> fmt::Display for WireguardInner<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "wireguard({:x})", self.id) write!(f, "wireguard({:x})", self.id)
} }
} }
impl<T: Tun, B: Bind> Deref for Wireguard<T, B> { impl<T: tun::Tun, B: udp::UDP> Deref for Wireguard<T, B> {
type Target = Arc<WireguardInner<T, B>>; type Target = Arc<WireguardInner<T, B>>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.state &self.state
} }
} }
pub struct Wireguard<T: Tun, B: Bind> { pub struct Wireguard<T: tun::Tun, B: udp::UDP> {
runner: Runner, runner: Runner,
state: Arc<WireguardInner<T, B>>, state: Arc<WireguardInner<T, B>>,
} }
@@ -127,7 +133,7 @@ const fn padding(size: usize, mtu: usize) -> usize {
min(mtu, size + (pad - size % pad) % pad) min(mtu, size + (pad - size % pad) % pad)
} }
impl<T: Tun, B: Bind> Wireguard<T, B> { impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
/// Brings the WireGuard device down. /// Brings the WireGuard device down.
/// Usually called when the associated interface is brought down. /// Usually called when the associated interface is brought down.
/// ///
@@ -269,7 +275,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
loop { loop {
// create vector big enough for any message given current MTU // create vector big enough for any message given current MTU
let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; let mtu = wg.mtu.load(Ordering::Relaxed);
let size = mtu + handshake::MAX_HANDSHAKE_MSG_SIZE;
let mut msg: Vec<u8> = Vec::with_capacity(size); let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0); msg.resize(size, 0);
@@ -283,6 +290,11 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}; };
msg.truncate(size); msg.truncate(size);
// TODO: start device down
if mtu == 0 {
continue;
}
// message type de-multiplexer // message type de-multiplexer
if msg.len() < std::mem::size_of::<u32>() { if msg.len() < std::mem::size_of::<u32>() {
continue; continue;
@@ -326,13 +338,17 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}); });
} }
pub fn set_mtu(&self, mtu: usize) {
self.mtu.store(mtu, Ordering::Relaxed);
}
pub fn set_writer(&self, writer: B::Writer) { pub fn set_writer(&self, writer: B::Writer) {
// TODO: Consider unifying these and avoid Clone requirement on writer // TODO: Consider unifying these and avoid Clone requirement on writer
*self.state.send.write() = Some(writer.clone()); *self.state.send.write() = Some(writer.clone());
self.state.router.set_outbound_writer(writer); self.state.router.set_outbound_writer(writer);
} }
pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> { pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer) -> Wireguard<T, B> {
// create device state // create device state
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().unwrap();
@@ -342,7 +358,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let wg = Arc::new(WireguardInner { let wg = Arc::new(WireguardInner {
start: Instant::now(), start: Instant::now(),
id: rng.gen(), id: rng.gen(),
mtu: mtu.clone(), mtu: AtomicUsize::new(0),
peers: RwLock::new(HashMap::new()), peers: RwLock::new(HashMap::new()),
send: RwLock::new(None), send: RwLock::new(None),
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
@@ -475,10 +491,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
); );
while let Some(reader) = readers.pop() { while let Some(reader) = readers.pop() {
let wg = wg.clone(); let wg = wg.clone();
let mtu = mtu.clone();
thread::spawn(move || loop { thread::spawn(move || loop {
// create vector big enough for any transport message (based on MTU) // create vector big enough for any transport message (based on MTU)
let mtu = mtu.mtu(); let mtu = wg.mtu.load(Ordering::Relaxed);
let size = mtu + router::SIZE_MESSAGE_PREFIX; let size = mtu + router::SIZE_MESSAGE_PREFIX;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
msg.resize(size, 0); msg.resize(size, 0);
@@ -493,6 +508,11 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}; };
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
// TODO: start device down
if mtu == 0 {
continue;
}
// truncate padding // truncate padding
let padded = padding(payload, mtu); let padded = padding(payload, mtu);
log::trace!( log::trace!(