Enable adding TUN reader to WG interface

This commit is contained in:
Mathias Hall-Andersen
2019-10-11 12:57:24 +02:00
parent 7ce5415169
commit 3d6e8f08a7
7 changed files with 247 additions and 147 deletions

View File

@@ -15,25 +15,6 @@ mod types;
mod wireguard; mod wireguard;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests;
use crate::types::tun::Tun;
use crate::types::{bind, dummy, tun};
use crate::wireguard::Wireguard;
use std::thread;
use std::time::Duration;
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[test]
fn test_pure_wireguard() {
init();
let (reader, writer, mtu) = dummy::TunTest::create("name").unwrap();
let wg: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(reader, writer, mtu);
thread::sleep(Duration::from_millis(500));
}
}
fn main() {} fn main() {}

View File

@@ -145,8 +145,8 @@ mod tests {
} }
// create device // create device
let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap(); let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
let router: Device<_, BencherCallbacks, dummy::TunTest, dummy::VoidBind> = let router: Device< _, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> =
Device::new(num_cpus::get(), tun_writer); Device::new(num_cpus::get(), tun_writer);
// add new peer // add new peer
@@ -175,7 +175,7 @@ mod tests {
init(); init();
// create device // create device
let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap(); let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer); let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
router.set_outbound_writer(dummy::VoidBind::new()); router.set_outbound_writer(dummy::VoidBind::new());
@@ -321,8 +321,8 @@ mod tests {
dummy::PairBind::pair(); dummy::PairBind::pair();
// create matching device // create matching device
let (tun_writer1, _, _) = dummy::TunTest::create("tun1").unwrap(); let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false);
let (tun_writer2, _, _) = dummy::TunTest::create("tun1").unwrap(); let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false);
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1); let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
router1.set_outbound_writer(bind_writer1); router1.set_outbound_writer(bind_writer1);

46
src/tests.rs Normal file
View File

@@ -0,0 +1,46 @@
use crate::types::tun::Tun;
use crate::types::{bind, dummy, tun};
use crate::wireguard::Wireguard;
use std::thread;
use std::time::Duration;
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
/* Create and configure two matching pure instances of WireGuard
*
*/
#[test]
fn test_pure_wireguard() {
init();
// create WG instances for fake TUN devices
let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true);
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader1], tun_writer1, mtu1);
let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true);
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader2], tun_writer2, mtu2);
// create pair bind to connect the interfaces "over the internet"
let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair();
wg1.set_writer(bind_writer1);
wg2.set_writer(bind_writer2);
wg1.add_reader(bind_reader1);
wg2.add_reader(bind_reader2);
// generate (public, pivate) key pairs
// configure cryptkey router
// create IP packets
thread::sleep(Duration::from_millis(500));
}

View File

@@ -20,9 +20,4 @@ pub trait Bind: Send + Sync + 'static {
/* Until Rust gets type equality constraints these have to be generic */ /* Until Rust gets type equality constraints these have to be generic */
type Writer: Writer<Self::Endpoint>; type Writer: Writer<Self::Endpoint>;
type Reader: Reader<Self::Endpoint>; type Reader: Reader<Self::Endpoint>;
/* Used to close the reader/writer when binding to a new port */
type Closer;
fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error>;
} }

View File

@@ -1,11 +1,12 @@
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use std::marker;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use std::time::Instant; use std::time::Instant;
use std::marker; use std::sync::atomic::{Ordering, AtomicUsize};
use super::*; use super::*;
@@ -41,7 +42,9 @@ impl fmt::Display for BindError {
/* TUN implementation */ /* TUN implementation */
#[derive(Debug)] #[derive(Debug)]
pub enum TunError {} pub enum TunError {
Disconnected
}
impl Error for TunError { impl Error for TunError {
fn description(&self) -> &str { fn description(&self) -> &str {
@@ -68,54 +71,111 @@ impl Endpoint for UnitEndpoint {
fn from_address(_: SocketAddr) -> UnitEndpoint { fn from_address(_: SocketAddr) -> UnitEndpoint {
UnitEndpoint {} UnitEndpoint {}
} }
fn into_address(&self) -> SocketAddr { fn into_address(&self) -> SocketAddr {
"127.0.0.1:8080".parse().unwrap() "127.0.0.1:8080".parse().unwrap()
} }
fn clear_src(&self) {}
} }
impl UnitEndpoint { impl UnitEndpoint {
pub fn new() -> UnitEndpoint { pub fn new() -> UnitEndpoint {
UnitEndpoint{} UnitEndpoint {}
} }
} }
/* */ /* */
#[derive(Clone, Copy)]
pub struct TunTest {} pub struct TunTest {}
impl tun::Reader for TunTest { pub struct TunFakeIO {
store: bool,
tx: SyncSender<Vec<u8>>,
rx: Receiver<Vec<u8>>
}
pub struct TunReader {
rx: Receiver<Vec<u8>>
}
pub struct TunWriter {
store: bool,
tx: Mutex<SyncSender<Vec<u8>>>
}
#[derive(Clone)]
pub struct TunMTU {
mtu: Arc<AtomicUsize>
}
impl tun::Reader for TunReader {
type Error = TunError; type Error = TunError;
fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> { fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
Ok(0) match self.rx.recv() {
Ok(m) => {
buf[offset..].copy_from_slice(&m[..]);
Ok(m.len())
}
Err(_) => Err(TunError::Disconnected)
}
} }
} }
impl tun::MTU for TunTest { impl tun::Writer for TunWriter {
type Error = TunError;
fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
if self.store {
let m = src.to_owned();
match self.tx.lock().unwrap().send(m) {
Ok(_) => Ok(()),
Err(_) => Err(TunError::Disconnected)
}
} else {
Ok(())
}
}
}
impl tun::MTU for TunMTU {
fn mtu(&self) -> usize { fn mtu(&self) -> usize {
1500 self.mtu.load(Ordering::Acquire)
}
}
impl tun::Writer for TunTest {
type Error = TunError;
fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
Ok(())
} }
} }
impl tun::Tun for TunTest { impl tun::Tun for TunTest {
type Writer = TunTest; type Writer = TunWriter;
type Reader = TunTest; type Reader = TunReader;
type MTU = TunTest; type MTU = TunMTU;
type Error = TunError; type Error = TunError;
} }
impl TunFakeIO {
pub fn write(&self, msg : Vec<u8>) {
if self.store {
self.tx.send(msg).unwrap();
}
}
pub fn read(&self) -> Vec<u8> {
self.rx.recv().unwrap()
}
}
impl TunTest { impl TunTest {
pub fn create(_name: &str) -> Result<(TunTest, TunTest, TunTest), TunError> { pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) {
Ok((TunTest {},TunTest {}, TunTest{}))
let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) };
let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) };
let fake = TunFakeIO{tx: tx1, rx: rx2, store};
let reader = TunReader{rx : rx1};
let writer = TunWriter{tx : Mutex::new(tx2), store};
let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))};
(fake, reader, writer, mtu)
} }
} }
@@ -146,16 +206,11 @@ impl bind::Bind for VoidBind {
type Reader = VoidBind; type Reader = VoidBind;
type Writer = VoidBind; type Writer = VoidBind;
type Closer = ();
fn bind(_ : u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
Ok((VoidBind{}, VoidBind{}, (), 2600))
}
} }
impl VoidBind { impl VoidBind {
pub fn new() -> VoidBind { pub fn new() -> VoidBind {
VoidBind{} VoidBind {}
} }
} }
@@ -203,45 +258,42 @@ pub struct PairWriter<E> {
pub struct PairBind {} pub struct PairBind {}
impl PairBind { impl PairBind {
pub fn pair<E>() -> ((PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>)) { pub fn pair<E>() -> (
(PairReader<E>, PairWriter<E>),
(PairReader<E>, PairWriter<E>),
) {
let (tx1, rx1) = sync_channel(128); let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128); let (tx2, rx2) = sync_channel(128);
( (
( (
PairReader{ PairReader {
recv: Arc::new(Mutex::new(rx1)),
recv: Arc::new(Mutex::new(rx1)), _marker: marker::PhantomData,
_marker: marker::PhantomData },
}, PairWriter {
PairWriter{
send: Arc::new(Mutex::new(tx2)), send: Arc::new(Mutex::new(tx2)),
_marker: marker::PhantomData _marker: marker::PhantomData,
} },
), ),
( (
PairReader{ PairReader {
recv: Arc::new(Mutex::new(rx2)), recv: Arc::new(Mutex::new(rx2)),
_marker: marker::PhantomData _marker: marker::PhantomData,
}, },
PairWriter{ PairWriter {
send: Arc::new(Mutex::new(tx1)), send: Arc::new(Mutex::new(tx1)),
_marker: marker::PhantomData _marker: marker::PhantomData,
} },
), ),
) )
} }
} }
impl bind::Bind for PairBind { impl bind::Bind for PairBind {
type Closer = ();
type Error = BindError; type Error = BindError;
type Endpoint = UnitEndpoint; type Endpoint = UnitEndpoint;
type Reader = PairReader<Self::Endpoint>; type Reader = PairReader<Self::Endpoint>;
type Writer = PairWriter<Self::Endpoint>; type Writer = PairWriter<Self::Endpoint>;
fn bind(_port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
Err(BindError::Disconnected)
}
} }
pub fn keypair(initiator: bool) -> KeyPair { pub fn keypair(initiator: bool) -> KeyPair {

View File

@@ -3,4 +3,5 @@ use std::net::SocketAddr;
pub trait Endpoint: Send + 'static { pub trait Endpoint: Send + 'static {
fn from_address(addr: SocketAddr) -> Self; fn from_address(addr: SocketAddr) -> Self;
fn into_address(&self) -> SocketAddr; fn into_address(&self) -> SocketAddr;
fn clear_src(&self);
} }

View File

@@ -3,8 +3,10 @@ use crate::handshake;
use crate::router; use crate::router;
use crate::timers::{Events, Timers}; use crate::timers::{Events, Timers};
use crate::types::bind::Reader as BindReader;
use crate::types::bind::{Bind, Writer}; use crate::types::bind::{Bind, Writer};
use crate::types::tun::{Reader, Tun, MTU}; use crate::types::tun::{Reader, Tun, MTU};
use crate::types::Endpoint; use crate::types::Endpoint;
use hjul::Runner; use hjul::Runner;
@@ -53,7 +55,7 @@ pub struct PeerInner<B: Bind> {
pub timers: RwLock<Timers>, // pub timers: RwLock<Timers>, //
} }
impl <B:Bind > PeerInner<B> { impl<B: Bind> PeerInner<B> {
#[inline(always)] #[inline(always)]
pub fn timers(&self) -> RwLockReadGuard<Timers> { pub fn timers(&self) -> RwLockReadGuard<Timers> {
self.timers.read() self.timers.read()
@@ -153,7 +155,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
pub fn get_sk(&self) -> Option<StaticSecret> { pub fn get_sk(&self) -> Option<StaticSecret> {
let mut handshake = self.state.handshake.read(); let handshake = self.state.handshake.read();
if handshake.active { if handshake.active {
Some(handshake.device.get_sk()) Some(handshake.device.get_sk())
} else { } else {
@@ -184,66 +186,73 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
peer peer
} }
pub fn new_bind(reader: B::Reader, writer: B::Writer, closer: B::Closer) { /* Begin consuming messages from the reader.
*
* Any previous reader thread is stopped by closing the previous reader,
* which unblocks the thread and causes an error on reader.read
*/
pub fn add_reader(&self, reader: B::Reader) {
let wg = self.state.clone();
thread::spawn(move || {
let mut last_under_load =
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
// drop existing closer loop {
// create vector big enough for any message given current MTU
let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0);
// swap IO thread for new reader // read UDP packet into vector
let (size, src) = match reader.read(&mut msg) {
// start UDP read IO thread Err(e) => {
debug!("Bind reader closed with {}", e);
/* return;
{
let wg = wg.clone();
let mtu = mtu.clone();
thread::spawn(move || {
let mut last_under_load =
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
loop {
// create vector big enough for any message given current MTU
let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0);
// read UDP packet into vector
let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error
msg.truncate(size);
// message type de-multiplexer
if msg.len() < std::mem::size_of::<u32>() {
continue;
} }
match LittleEndian::read_u32(&msg[..]) { Ok(v) => v,
handshake::TYPE_COOKIE_REPLY };
| handshake::TYPE_INITIATION msg.truncate(size);
| handshake::TYPE_RESPONSE => {
// update under_load flag
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
last_under_load = Instant::now();
wg.under_load.store(true, Ordering::SeqCst);
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
wg.under_load.store(false, Ordering::SeqCst);
}
wg.queue // message type de-multiplexer
.lock() if msg.len() < std::mem::size_of::<u32>() {
.send(HandshakeJob::Message(msg, src)) continue;
.unwrap();
}
router::TYPE_TRANSPORT => {
// transport message
let _ = wg.router.recv(src, msg);
}
_ => (),
}
} }
}); match LittleEndian::read_u32(&msg[..]) {
} handshake::TYPE_COOKIE_REPLY
*/ | handshake::TYPE_INITIATION
| handshake::TYPE_RESPONSE => {
// update under_load flag
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
last_under_load = Instant::now();
wg.under_load.store(true, Ordering::SeqCst);
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
wg.under_load.store(false, Ordering::SeqCst);
}
wg.queue
.lock()
.send(HandshakeJob::Message(msg, src))
.unwrap();
}
router::TYPE_TRANSPORT => {
// transport message
let _ = wg.router.recv(src, msg).map_err(|e| {
debug!("Failed to handle incoming transport message: {}", e);
});
}
_ => (),
}
}
});
} }
pub fn new(reader: T::Reader, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> { pub fn set_writer(&self, writer: B::Writer) {
// TODO: Consider unifying these and avoid Clone requirement on writer
*self.state.send.write() = Some(writer.clone());
self.state.router.set_outbound_writer(writer);
}
pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
// create device state // create device state
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
@@ -292,14 +301,16 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
None None
}, },
) { ) {
Ok((pk, msg, keypair)) => { Ok((pk, resp, keypair)) => {
// send response // send response
if let Some(msg) = msg { let mut resp_len: u64 = 0;
if let Some(msg) = resp {
resp_len = msg.len() as u64;
let send: &Option<B::Writer> = &*wg.send.read(); let send: &Option<B::Writer> = &*wg.send.read();
if let Some(writer) = send.as_ref() { if let Some(writer) = send.as_ref() {
let _ = writer.write(&msg[..], &src).map_err(|e| { let _ = writer.write(&msg[..], &src).map_err(|e| {
debug!( debug!(
"handshake worker, failed to send response, error = {:?}", "handshake worker, failed to send response, error = {}",
e e
) )
}); });
@@ -308,16 +319,23 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// update timers // update timers
if let Some(pk) = pk { if let Some(pk) = pk {
// authenticated handshake packet received
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
// add to rx_bytes and tx_bytes
let req_len = msg.len() as u64;
peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
// update endpoint // update endpoint
peer.router.set_endpoint(src); peer.router.set_endpoint(src);
// add keypair to peer and free any unused ids // add keypair to peer
if let Some(keypair) = keypair { keypair.map(|kp| {
for id in peer.router.add_keypair(keypair) { // free any unused ids
for id in peer.router.add_keypair(kp) {
state.device.release(id); state.device.release(id);
} }
} });
} }
} }
} }
@@ -325,20 +343,27 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
} }
HandshakeJob::New(pk) => { HandshakeJob::New(pk) => {
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle let _ = state.device.begin(&mut rng, &pk).map(|msg| {
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
peer.router.send(&msg[..]); let _ = peer.router.send(&msg[..]).map_err(|e| {
peer.timers.read().handshake_sent(); debug!("handshake worker, failed to send handshake initiation, error = {}", e)
} });
}
});
} }
} }
} }
}); });
} }
// start TUN read IO thread // start TUN read IO threads (multiple threads to support multi-queue interfaces)
{ debug_assert!(
readers.len() > 0,
"attempted to create WG device without TUN readers"
);
while let Some(reader) = readers.pop() {
let wg = wg.clone(); let wg = wg.clone();
let mtu = mtu.clone();
thread::spawn(move || loop { thread::spawn(move || loop {
// create vector big enough for any transport message (based on MTU) // create vector big enough for any transport message (based on MTU)
let mtu = mtu.mtu(); let mtu = mtu.mtu();