Restructure dummy implementations

This commit is contained in:
Mathias Hall-Andersen
2019-10-06 13:33:15 +02:00
parent edfd2f235a
commit c82d3e554b
8 changed files with 320 additions and 230 deletions

View File

@@ -16,3 +16,5 @@ pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200);
pub const TIMERS_TICK: Duration = Duration::from_millis(100); pub const TIMERS_TICK: Duration = Duration::from_millis(100);
pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize; pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize;
pub const TIMERS_CAPACITY: usize = 1024; pub const TIMERS_CAPACITY: usize = 1024;
pub const MESSAGE_PADDING_MULTIPLE: usize = 16;

View File

@@ -4,6 +4,9 @@ use hex;
#[cfg(test)] #[cfg(test)]
use std::fmt; use std::fmt;
use std::cmp;
use std::mem;
use byteorder::LittleEndian; use byteorder::LittleEndian;
use zerocopy::byteorder::U32; use zerocopy::byteorder::U32;
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified}; use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
@@ -21,6 +24,16 @@ pub const TYPE_INITIATION: u32 = 1;
pub const TYPE_RESPONSE: u32 = 2; pub const TYPE_RESPONSE: u32 = 2;
pub const TYPE_COOKIE_REPLY: u32 = 3; pub const TYPE_COOKIE_REPLY: u32 = 3;
const fn max(a: usize, b: usize) -> usize {
let m: usize = (a > b) as usize;
m * a + (1 - m) * b
}
pub const MAX_HANDSHAKE_MSG_SIZE: usize = max(
max(mem::size_of::<Response>(), mem::size_of::<Initiation>()),
mem::size_of::<CookieReply>(),
);
/* Handshake messsages */ /* Handshake messsages */
#[repr(packed)] #[repr(packed)]

View File

@@ -18,4 +18,4 @@ mod types;
// publicly exposed interface // publicly exposed interface
pub use device::Device; pub use device::Device;
pub use messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};

View File

@@ -12,7 +12,24 @@ mod timers;
mod types; mod types;
mod wireguard; mod wireguard;
#[test] #[cfg(test)]
fn test_pure_wireguard() {} mod tests {
use crate::types::{dummy, Bind};
use crate::wireguard::Wireguard;
use std::thread;
use std::time::Duration;
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[test]
fn test_pure_wireguard() {
init();
let wg = Wireguard::new(dummy::TunTest::new(), dummy::VoidBind::new());
thread::sleep(Duration::from_millis(500));
}
}
fn main() {} fn main() {}

View File

@@ -12,209 +12,13 @@ use num_cpus;
use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet; use pnet::packet::ipv6::MutableIpv6Packet;
use super::super::types::{Bind, Endpoint, Key, KeyPair, Tun}; use super::super::types::{dummy, Bind, Endpoint, Key, KeyPair, Tun};
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
extern crate test; extern crate test;
const SIZE_KEEPALIVE: usize = 32; const SIZE_KEEPALIVE: usize = 32;
/* Error implementation */
#[derive(Debug)]
enum BindError {
Disconnected,
}
impl Error for BindError {
fn description(&self) -> &str {
"Generic Bind Error"
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
impl fmt::Display for BindError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BindError::Disconnected => write!(f, "PairBind disconnected"),
}
}
}
/* TUN implementation */
#[derive(Debug)]
enum TunError {}
impl Error for TunError {
fn description(&self) -> &str {
"Generic Tun Error"
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
impl fmt::Display for TunError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Not Possible")
}
}
/* Endpoint implementation */
#[derive(Clone, Copy)]
struct UnitEndpoint {}
impl Endpoint for UnitEndpoint {
fn from_address(_: SocketAddr) -> UnitEndpoint {
UnitEndpoint {}
}
fn into_address(&self) -> SocketAddr {
"127.0.0.1:8080".parse().unwrap()
}
}
#[derive(Clone, Copy)]
struct TunTest {}
impl Tun for TunTest {
type Error = TunError;
fn mtu(&self) -> usize {
1500
}
fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
Ok(0)
}
fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
Ok(())
}
}
/* Bind implemenentations */
#[derive(Clone, Copy)]
struct VoidBind {}
impl Bind for VoidBind {
type Error = BindError;
type Endpoint = UnitEndpoint;
fn new() -> VoidBind {
VoidBind {}
}
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
Ok(())
}
fn get_port(&self) -> Option<u16> {
None
}
fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
Ok((0, UnitEndpoint {}))
}
fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
Ok(())
}
}
#[derive(Clone)]
struct PairBind {
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
}
impl Bind for PairBind {
type Error = BindError;
type Endpoint = UnitEndpoint;
fn new() -> PairBind {
PairBind {
send: Arc::new(Mutex::new(sync_channel(0).0)),
recv: Arc::new(Mutex::new(sync_channel(0).1)),
}
}
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
Ok(())
}
fn get_port(&self) -> Option<u16> {
None
}
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
let vec = self
.recv
.lock()
.unwrap()
.recv()
.map_err(|_| BindError::Disconnected)?;
let len = vec.len();
buf[..len].copy_from_slice(&vec[..]);
Ok((vec.len(), UnitEndpoint {}))
}
fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
let owned = buf.to_owned();
match self.send.lock().unwrap().send(owned) {
Err(_) => Err(BindError::Disconnected),
Ok(_) => Ok(()),
}
}
}
fn bind_pair() -> (PairBind, PairBind) {
let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128);
(
PairBind {
send: Arc::new(Mutex::new(tx1)),
recv: Arc::new(Mutex::new(rx2)),
},
PairBind {
send: Arc::new(Mutex::new(tx2)),
recv: Arc::new(Mutex::new(rx1)),
},
)
}
fn dummy_keypair(initiator: bool) -> KeyPair {
let k1 = Key {
key: [0x53u8; 32],
id: 0x646e6573,
};
let k2 = Key {
key: [0x52u8; 32],
id: 0x76636572,
};
if initiator {
KeyPair {
birth: Instant::now(),
initiator: true,
send: k1,
recv: k2,
}
} else {
KeyPair {
birth: Instant::now(),
initiator: false,
send: k2,
recv: k1,
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -341,13 +145,13 @@ mod tests {
} }
// create device // create device
let router: Device<BencherCallbacks, TunTest, VoidBind> = let router: Device<BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
Device::new(num_cpus::get(), TunTest {}, VoidBind::new()); Device::new(num_cpus::get(), dummy::TunTest {}, dummy::VoidBind::new());
// add new peer // add new peer
let opaque = Arc::new(AtomicUsize::new(0)); let opaque = Arc::new(AtomicUsize::new(0));
let peer = router.new_peer(opaque.clone()); let peer = router.new_peer(opaque.clone());
peer.add_keypair(dummy_keypair(true)); peer.add_keypair(dummy::keypair(true));
// add subnet to peer // add subnet to peer
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20"); let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
@@ -370,7 +174,8 @@ mod tests {
init(); init();
// create device // create device
let router: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, VoidBind::new()); let router: Device<TestCallbacks, _, _> =
Device::new(1, dummy::TunTest::new(), dummy::VoidBind::new());
let tests = vec![ let tests = vec![
("192.168.1.0", 24, "192.168.1.20", true), ("192.168.1.0", 24, "192.168.1.20", true),
@@ -404,9 +209,8 @@ mod tests {
let opaque = Opaque::new(); let opaque = Opaque::new();
let peer = router.new_peer(opaque.clone()); let peer = router.new_peer(opaque.clone());
let mask: IpAddr = mask.parse().unwrap(); let mask: IpAddr = mask.parse().unwrap();
if set_key { if set_key {
peer.add_keypair(dummy_keypair(true)); peer.add_keypair(dummy::keypair(true));
} }
// map subnet to peer // map subnet to peer
@@ -512,9 +316,11 @@ mod tests {
for (stage, p1, p2) in tests.iter() { for (stage, p1, p2) in tests.iter() {
// create matching devices // create matching devices
let (bind1, bind2) = bind_pair(); let (bind1, bind2) = dummy::PairBind::pair();
let router1: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind1.clone()); let router1: Device<TestCallbacks, _, _> =
let router2: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind2.clone()); Device::new(1, dummy::TunTest::new(), bind1.clone());
let router2: Device<TestCallbacks, _, _> =
Device::new(1, dummy::TunTest::new(), bind2.clone());
// prepare opaque values for tracing callbacks // prepare opaque values for tracing callbacks
@@ -527,7 +333,7 @@ mod tests {
let peer1 = router1.new_peer(opaq1.clone()); let peer1 = router1.new_peer(opaq1.clone());
let mask: IpAddr = mask.parse().unwrap(); let mask: IpAddr = mask.parse().unwrap();
peer1.add_subnet(mask, *len); peer1.add_subnet(mask, *len);
peer1.add_keypair(dummy_keypair(false)); peer1.add_keypair(dummy::keypair(false));
let (mask, len, _ip, _okay) = p2; let (mask, len, _ip, _okay) = p2;
let peer2 = router2.new_peer(opaq2.clone()); let peer2 = router2.new_peer(opaq2.clone());
@@ -557,7 +363,7 @@ mod tests {
// this should cause a key-confirmation packet (keepalive or staged packet) // this should cause a key-confirmation packet (keepalive or staged packet)
// this also causes peer1 to learn the "endpoint" for peer2 // this also causes peer1 to learn the "endpoint" for peer2
assert!(peer1.get_endpoint().is_none()); assert!(peer1.get_endpoint().is_none());
peer2.add_keypair(dummy_keypair(true)); peer2.add_keypair(dummy::keypair(true));
wait(); wait();
assert!(opaq2.send().is_some()); assert!(opaq2.send().is_some());

217
src/types/dummy.rs Normal file
View File

@@ -0,0 +1,217 @@
use std::error::Error;
use std::fmt;
use std::net::SocketAddr;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Instant;
use super::{Bind, Endpoint, Key, KeyPair, Tun};
/* This submodule provides pure/dummy implementations of the IO interfaces
* for use in unit tests thoughout the project.
*/
/* Error implementation */
#[derive(Debug)]
pub enum BindError {
Disconnected,
}
impl Error for BindError {
fn description(&self) -> &str {
"Generic Bind Error"
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
impl fmt::Display for BindError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BindError::Disconnected => write!(f, "PairBind disconnected"),
}
}
}
/* TUN implementation */
#[derive(Debug)]
pub enum TunError {}
impl Error for TunError {
fn description(&self) -> &str {
"Generic Tun Error"
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
impl fmt::Display for TunError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Not Possible")
}
}
/* Endpoint implementation */
#[derive(Clone, Copy)]
pub struct UnitEndpoint {}
impl Endpoint for UnitEndpoint {
fn from_address(_: SocketAddr) -> UnitEndpoint {
UnitEndpoint {}
}
fn into_address(&self) -> SocketAddr {
"127.0.0.1:8080".parse().unwrap()
}
}
#[derive(Clone, Copy)]
pub struct TunTest {}
impl Tun for TunTest {
type Error = TunError;
fn mtu(&self) -> usize {
1500
}
fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
Ok(0)
}
fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
Ok(())
}
}
impl TunTest {
pub fn new() -> TunTest {
TunTest {}
}
}
/* Bind implemenentations */
#[derive(Clone, Copy)]
pub struct VoidBind {}
impl Bind for VoidBind {
type Error = BindError;
type Endpoint = UnitEndpoint;
fn new() -> VoidBind {
VoidBind {}
}
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
Ok(())
}
fn get_port(&self) -> Option<u16> {
None
}
fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
Ok((0, UnitEndpoint {}))
}
fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
Ok(())
}
}
#[derive(Clone)]
pub struct PairBind {
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
}
impl PairBind {
pub fn pair() -> (PairBind, PairBind) {
let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128);
(
PairBind {
send: Arc::new(Mutex::new(tx1)),
recv: Arc::new(Mutex::new(rx2)),
},
PairBind {
send: Arc::new(Mutex::new(tx2)),
recv: Arc::new(Mutex::new(rx1)),
},
)
}
}
impl Bind for PairBind {
type Error = BindError;
type Endpoint = UnitEndpoint;
fn new() -> PairBind {
PairBind {
send: Arc::new(Mutex::new(sync_channel(0).0)),
recv: Arc::new(Mutex::new(sync_channel(0).1)),
}
}
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
Ok(())
}
fn get_port(&self) -> Option<u16> {
None
}
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
let vec = self
.recv
.lock()
.unwrap()
.recv()
.map_err(|_| BindError::Disconnected)?;
let len = vec.len();
buf[..len].copy_from_slice(&vec[..]);
Ok((vec.len(), UnitEndpoint {}))
}
fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
let owned = buf.to_owned();
match self.send.lock().unwrap().send(owned) {
Err(_) => Err(BindError::Disconnected),
Ok(_) => Ok(()),
}
}
}
pub fn keypair(initiator: bool) -> KeyPair {
let k1 = Key {
key: [0x53u8; 32],
id: 0x646e6573,
};
let k2 = Key {
key: [0x52u8; 32],
id: 0x76636572,
};
if initiator {
KeyPair {
birth: Instant::now(),
initiator: true,
send: k1,
recv: k2,
}
} else {
KeyPair {
birth: Instant::now(),
initiator: false,
send: k2,
recv: k1,
}
}
}

View File

@@ -3,6 +3,9 @@ mod keys;
mod tun; mod tun;
mod udp; mod udp;
#[cfg(test)]
pub mod dummy;
pub use endpoint::Endpoint; pub use endpoint::Endpoint;
pub use keys::{Key, KeyPair}; pub use keys::{Key, KeyPair};
pub use tun::Tun; pub use tun::Tun;

View File

@@ -6,6 +6,7 @@ use crate::types::{Bind, Endpoint, Tun};
use hjul::Runner; use hjul::Runner;
use std::cmp;
use std::ops::Deref; use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
@@ -86,8 +87,19 @@ pub struct Wireguard<T: Tun, B: Bind> {
state: Arc<WireguardInner<T, B>>, state: Arc<WireguardInner<T, B>>,
} }
#[inline(always)]
const fn padding(size: usize, mtu: usize) -> usize {
#[inline(always)]
const fn min(a: usize, b: usize) -> usize {
let m = (a > b) as usize;
a * m + (1 - m) * b
}
let pad = MESSAGE_PADDING_MULTIPLE;
min(mtu, size + (pad - size % pad) % pad)
}
impl<T: Tun, B: Bind> Wireguard<T, B> { impl<T: Tun, B: Bind> Wireguard<T, B> {
fn set_key(&self, sk: Option<StaticSecret>) { pub fn set_key(&self, sk: Option<StaticSecret>) {
let mut handshake = self.state.handshake.write(); let mut handshake = self.state.handshake.write();
match sk { match sk {
None => { None => {
@@ -102,7 +114,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
} }
fn new_peer(&self, pk: PublicKey) -> Peer<T, B> { pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
let state = Arc::new(PeerInner { let state = Arc::new(PeerInner {
pk, pk,
queue: Mutex::new(self.state.queue.lock().clone()), queue: Mutex::new(self.state.queue.lock().clone()),
@@ -111,11 +123,21 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
tx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0),
timers: RwLock::new(Timers::dummy(&self.runner)), timers: RwLock::new(Timers::dummy(&self.runner)),
}); });
let router = Arc::new(self.state.router.new_peer(state.clone())); let router = Arc::new(self.state.router.new_peer(state.clone()));
Peer { router, state }
let peer = Peer { router, state };
/* The need for dummy timers arises from the chicken-egg
* problem of the timer callbacks being able to set timers themselves.
*
* This is in fact the only place where the write lock is ever taken.
*/
*peer.timers.write() = Timers::new(&self.runner, peer.clone());
peer
} }
fn new(tun: T, bind: B) -> Wireguard<T, B> { pub fn new(tun: T, bind: B) -> Wireguard<T, B> {
// create device state // create device state
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
@@ -215,10 +237,12 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
loop { loop {
// read UDP packet into vector // create vector big enough for any message given current MTU
let size = tun.mtu() + 148; // maximum message size let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
let mut msg: Vec<u8> = Vec::with_capacity(size); let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0); msg.resize(size, 0);
// read UDP packet into vector
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
msg.truncate(size); msg.truncate(size);
@@ -226,7 +250,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
if msg.len() < std::mem::size_of::<u32>() { if msg.len() < std::mem::size_of::<u32>() {
continue; continue;
} }
match LittleEndian::read_u32(&msg[..]) { match LittleEndian::read_u32(&msg[..]) {
handshake::TYPE_COOKIE_REPLY handshake::TYPE_COOKIE_REPLY
| handshake::TYPE_INITIATION | handshake::TYPE_INITIATION
@@ -246,9 +269,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
router::TYPE_TRANSPORT => { router::TYPE_TRANSPORT => {
// transport message // transport message
// pad the message
let _ = wg.router.recv(src, msg); let _ = wg.router.recv(src, msg);
} }
_ => (), _ => (),
@@ -261,20 +281,32 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
{ {
let wg = wg.clone(); let wg = wg.clone();
thread::spawn(move || loop { thread::spawn(move || loop {
// read a new IP packet // create vector big enough for any transport message (based on MTU)
let mtu = tun.mtu(); let mtu = tun.mtu();
let size = mtu + 148; let size = mtu + router::SIZE_MESSAGE_PREFIX;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); msg.resize(size, 0);
msg.truncate(size);
// pad message to multiple of 16 bytes // read a new IP packet
while msg.len() < mtu && msg.len() % 16 != 0 { let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
msg.push(0); debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
}
// truncate padding
let payload = padding(payload, mtu);
msg.truncate(router::SIZE_MESSAGE_PREFIX + payload);
debug_assert!(payload <= mtu);
debug_assert_eq!(
if payload < mtu {
(msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
} else {
0
},
0
);
// crypt-key route // crypt-key route
let _ = wg.router.send(msg); let e = wg.router.send(msg);
debug!("TUN worker, router returned {:?}", e);
}); });
} }