Begin work on full router interaction unittest

This commit is contained in:
Mathias Hall-Andersen
2019-09-08 12:59:35 +02:00
parent eae915b2e8
commit e371d39052
5 changed files with 239 additions and 91 deletions

1
Cargo.lock generated
View File

@@ -1569,6 +1569,7 @@ dependencies = [
"hmac 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)", "hmac 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
"num_cpus 1.10.1 (registry+https://github.com/rust-lang/crates.io-index)",
"parking_lot 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", "parking_lot 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
"pnet 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)", "pnet 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)",
"proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)", "proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@@ -30,6 +30,7 @@ clear_on_drop = "0.2.3"
parking_lot = "^0.9" parking_lot = "^0.9"
futures-channel = "^0.2" futures-channel = "^0.2"
env_logger = "0.6" env_logger = "0.6"
num_cpus = "^1.10"
[dependencies.x25519-dalek] [dependencies.x25519-dalek]
version = "^0.5" version = "^0.5"

View File

@@ -1,32 +1,28 @@
use std::cmp;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::sync_channel; use std::sync::mpsc::sync_channel;
use std::sync::mpsc::SyncSender; use std::sync::mpsc::SyncSender;
use std::sync::{Arc, Weak}; use std::sync::Arc;
use std::thread; use std::thread;
use std::time::Instant; use std::time::Instant;
use log::debug; use log::debug;
use spin::{Mutex, RwLock}; use spin::{Mutex, RwLock};
use treebitmap::IpLookupTable; use treebitmap::IpLookupTable;
use zerocopy::LayoutVerified; use zerocopy::LayoutVerified;
use super::super::types::{Bind, KeyPair, Tun};
use super::anti_replay::AntiReplay; use super::anti_replay::AntiReplay;
use super::peer;
use super::peer::{Peer, PeerInner};
use super::SIZE_MESSAGE_PREFIX;
use super::constants::*; use super::constants::*;
use super::ip::*; use super::ip::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer;
use super::peer::{Peer, PeerInner};
use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError}; use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation}; use super::workers::{worker_parallel, JobParallel, Operation};
use super::SIZE_MESSAGE_PREFIX;
use super::super::types::{Bind, KeyPair, Tun};
pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> { pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
// IO & timer callbacks // IO & timer callbacks
@@ -139,8 +135,8 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
match packet[0] >> 4 { match packet[0] >> 4 {
VERSION_IP4 => { VERSION_IP4 => {
// check length and cast to IPv4 header // check length and cast to IPv4 header
let (header, _) = LayoutVerified::new_from_prefix(packet)?; let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
let header: LayoutVerified<&[u8], IPv4Header> = header; LayoutVerified::new_from_prefix(packet)?;
// lookup destination address // lookup destination address
device device
@@ -151,8 +147,8 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
} }
VERSION_IP6 => { VERSION_IP6 => {
// check length and cast to IPv6 header // check length and cast to IPv6 header
let (header, packet) = LayoutVerified::new_from_prefix(packet)?; let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
let header: LayoutVerified<&[u8], IPv6Header> = header; LayoutVerified::new_from_prefix(packet)?;
// lookup destination address // lookup destination address
device device

View File

@@ -2,10 +2,13 @@ use std::error::Error;
use std::fmt; use std::fmt;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex;
use std::thread; use std::thread;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use num_cpus;
use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet; use pnet::packet::ipv6::MutableIpv6Packet;
@@ -14,6 +17,33 @@ use super::{Device, SIZE_MESSAGE_PREFIX};
extern crate test; extern crate test;
/* Error implementation */
#[derive(Debug)]
enum BindError {
Disconnected,
}
impl Error for BindError {
fn description(&self) -> &str {
"Generic Bind Error"
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}
impl fmt::Display for BindError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BindError::Disconnected => write!(f, "PairBind disconnected"),
}
}
}
/* TUN implementation */
#[derive(Debug)] #[derive(Debug)]
enum TunError {} enum TunError {}
@@ -33,6 +63,22 @@ impl fmt::Display for TunError {
} }
} }
/* Endpoint implementation */
struct UnitEndpoint {}
impl From<SocketAddr> for UnitEndpoint {
fn from(addr: SocketAddr) -> UnitEndpoint {
UnitEndpoint {}
}
}
impl Into<SocketAddr> for UnitEndpoint {
fn into(self) -> SocketAddr {
"127.0.0.1:8080".parse().unwrap()
}
}
struct TunTest {} struct TunTest {}
impl Tun for TunTest { impl Tun for TunTest {
@@ -51,14 +97,16 @@ impl Tun for TunTest {
} }
} }
struct BindTest {} /* Bind implemenentations */
impl Bind for BindTest { struct VoidBind {}
impl Bind for VoidBind {
type Error = BindError; type Error = BindError;
type Endpoint = SocketAddr; type Endpoint = UnitEndpoint;
fn new() -> BindTest { fn new() -> VoidBind {
BindTest {} VoidBind {}
} }
fn set_port(&self, port: u16) -> Result<(), Self::Error> { fn set_port(&self, port: u16) -> Result<(), Self::Error> {
@@ -70,7 +118,7 @@ impl Bind for BindTest {
} }
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
Ok((0, "127.0.0.1:8080".parse().unwrap())) Ok((0, UnitEndpoint {}))
} }
fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> { fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> {
@@ -78,23 +126,59 @@ impl Bind for BindTest {
} }
} }
#[derive(Debug)] struct PairBind {
enum BindError {} send: Mutex<SyncSender<Vec<u8>>>,
recv: Mutex<Receiver<Vec<u8>>>,
}
impl Error for BindError { impl Bind for PairBind {
fn description(&self) -> &str { type Error = BindError;
"Generic Bind Error" type Endpoint = UnitEndpoint;
fn new() -> PairBind {
PairBind {
send: Mutex::new(sync_channel(0).0),
recv: Mutex::new(sync_channel(0).1),
}
} }
fn source(&self) -> Option<&(dyn Error + 'static)> { fn set_port(&self, port: u16) -> Result<(), Self::Error> {
Ok(())
}
fn get_port(&self) -> Option<u16> {
None None
} }
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
let vec = self
.recv
.lock()
.unwrap()
.recv()
.map_err(|_| BindError::Disconnected)?;
buf.copy_from_slice(&vec[..]);
Ok((vec.len(), UnitEndpoint {}))
}
fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> {
Ok(())
}
} }
impl fmt::Display for BindError { fn bind_pair() -> (PairBind, PairBind) {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let (tx1, rx1) = sync_channel(0);
write!(f, "Not Possible") let (tx2, rx2) = sync_channel(0);
} (
PairBind {
send: Mutex::new(tx1),
recv: Mutex::new(rx2),
},
PairBind {
send: Mutex::new(tx2),
recv: Mutex::new(rx1),
},
)
} }
fn dummy_keypair(initiator: bool) -> KeyPair { fn dummy_keypair(initiator: bool) -> KeyPair {
@@ -131,6 +215,32 @@ mod tests {
use std::sync::atomic::AtomicU64; use std::sync::atomic::AtomicU64;
use test::Bencher; use test::Bencher;
fn get_tests() -> Vec<(&'static str, u32, &'static str, bool)> {
vec![
("192.168.1.0", 24, "192.168.1.20", true),
("172.133.133.133", 32, "172.133.133.133", true),
("172.133.133.133", 32, "172.133.133.132", false),
(
"2001:db8::ff00:42:0000",
112,
"2001:db8::ff00:42:3242",
true,
),
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:0660",
false,
),
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:ffff",
true,
),
]
}
fn init() { fn init() {
let _ = env_logger::builder().is_test(true).try_init(); let _ = env_logger::builder().is_test(true).try_init();
} }
@@ -162,16 +272,15 @@ mod tests {
type Opaque = Arc<AtomicU64>; type Opaque = Arc<AtomicU64>;
// create device // create device
let workers = 4;
let router = Device::new( let router = Device::new(
workers, num_cpus::get(),
TunTest {}, TunTest {},
BindTest {}, VoidBind::new(),
|t: &Opaque, _data: bool, _sent: bool| { |t: &Opaque, _data: bool, _sent: bool| {
t.fetch_add(1, Ordering::SeqCst); t.fetch_add(1, Ordering::SeqCst);
}, },
|t: &Opaque, _data: bool, _sent: bool| {}, |_t: &Opaque, _data: bool, _sent: bool| {},
|t: &Opaque| {}, |_t: &Opaque| {},
); );
// add new peer // add new peer
@@ -185,16 +294,10 @@ mod tests {
let ip: IpAddr = ip.parse().unwrap(); let ip: IpAddr = ip.parse().unwrap();
peer.add_subnet(mask, len); peer.add_subnet(mask, len);
for _ in 0..1024 { // every iteration sends 10 MB
let msg = make_packet(1024, ip);
router.send(msg).unwrap();
}
b.iter(|| { b.iter(|| {
opaque.store(0, Ordering::SeqCst); opaque.store(0, Ordering::SeqCst);
// wait till 10 MB
while opaque.load(Ordering::Acquire) < 10 * 1024 { while opaque.load(Ordering::Acquire) < 10 * 1024 {
// create "IP packet"
let msg = make_packet(1024, ip); let msg = make_packet(1024, ip);
router.send(msg).unwrap(); router.send(msg).unwrap();
} }
@@ -214,40 +317,16 @@ mod tests {
type Opaque = Arc<Flags>; type Opaque = Arc<Flags>;
// create device // create device
let workers = 4;
let router = Device::new( let router = Device::new(
workers, 1,
TunTest {}, TunTest {},
BindTest {}, VoidBind::new(),
|t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst), |t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst),
|t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst), |t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst),
|t: &Opaque| t.need_key.store(true, Ordering::SeqCst), |t: &Opaque| t.need_key.store(true, Ordering::SeqCst),
); );
let tests = vec![ let tests = get_tests();
("192.168.1.0", 24, "192.168.1.20", true),
("172.133.133.133", 32, "172.133.133.133", true),
("172.133.133.133", 32, "172.133.133.132", false),
(
"2001:db8::ff00:42:0000",
112,
"2001:db8::ff00:42:3242",
true,
),
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:0660",
false,
),
(
"2001:db8::ff00:42:8000",
113,
"2001:db8::ff00:42:ffff",
true,
),
];
for (num, (mask, len, ip, okay)) in tests.iter().enumerate() { for (num, (mask, len, ip, okay)) in tests.iter().enumerate() {
for set_key in vec![true, false] { for set_key in vec![true, false] {
debug!("index = {}, set_key = {}", num, set_key); debug!("index = {}, set_key = {}", num, set_key);
@@ -317,4 +396,60 @@ mod tests {
} }
} }
} }
#[test]
fn test_outbound_inbound() {
// type for tracking events inside the router module
struct Flags {
send: AtomicBool,
recv: AtomicBool,
need_key: AtomicBool,
}
type Opaque = Arc<Flags>;
let (bind1, bind2) = bind_pair();
// create matching devices
let router1 = Device::new(
1,
TunTest {},
bind1,
|t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst),
|t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst),
|t: &Opaque| t.need_key.store(true, Ordering::SeqCst),
);
let router2 = Device::new(
1,
TunTest {},
bind2,
|t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst),
|t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst),
|t: &Opaque| t.need_key.store(true, Ordering::SeqCst),
);
// create peers with matching keypairs
let opaq1 = Arc::new(Flags {
send: AtomicBool::new(false),
recv: AtomicBool::new(false),
need_key: AtomicBool::new(false),
});
let opaq2 = Arc::new(Flags {
send: AtomicBool::new(false),
recv: AtomicBool::new(false),
need_key: AtomicBool::new(false),
});
let peer1 = router1.new_peer(opaq1.clone());
peer1.set_endpoint("127.0.0.1:8080".parse().unwrap());
peer1.add_keypair(dummy_keypair(false));
let peer2 = router2.new_peer(opaq2.clone());
peer2.set_endpoint("127.0.0.1:8080".parse().unwrap());
peer2.add_keypair(dummy_keypair(true)); // this should cause an empty key-confirmation packet
}
} }

View File

@@ -54,8 +54,8 @@ fn check_route<C: Callbacks, T: Tun, B: Bind>(
match packet[0] >> 4 { match packet[0] >> 4 {
VERSION_IP4 => { VERSION_IP4 => {
// check length and cast to IPv4 header // check length and cast to IPv4 header
let (header, _) = LayoutVerified::new_from_prefix(packet)?; let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
let header: LayoutVerified<&[u8], IPv4Header> = header; LayoutVerified::new_from_prefix(packet)?;
// check IPv4 source address // check IPv4 source address
device device
@@ -72,8 +72,8 @@ fn check_route<C: Callbacks, T: Tun, B: Bind>(
} }
VERSION_IP6 => { VERSION_IP6 => {
// check length and cast to IPv6 header // check length and cast to IPv6 header
let (header, _) = LayoutVerified::new_from_prefix(packet)?; let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
let header: LayoutVerified<&[u8], IPv6Header> = header; LayoutVerified::new_from_prefix(packet)?;
// check IPv6 source address // check IPv6 source address
device device
@@ -110,14 +110,15 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
let _ = rx let _ = rx
.map(|buf| { .map(|buf| {
if buf.okay { if buf.okay {
// parse / cast // cast transport header
let (header, packet) = match LayoutVerified::new_from_prefix(&buf.msg[..]) { let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
Some(v) => v, match LayoutVerified::new_from_prefix(&buf.msg[..]) {
None => { Some(v) => v,
return; None => {
} return;
}; }
let header: LayoutVerified<&[u8], TransportHeader> = header; };
debug_assert!( debug_assert!(
packet.len() >= CHACHA20_POLY1305.tag_len(), packet.len() >= CHACHA20_POLY1305.tag_len(),
"this should be checked earlier in the pipeline" "this should be checked earlier in the pipeline"
@@ -145,8 +146,13 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
debug_assert!(inner_len <= length, "should be validated"); debug_assert!(inner_len <= length, "should be validated");
if inner_len <= length { if inner_len <= length {
sent = true; sent = match device.tun.write(&packet[..inner_len]) {
let _ = device.tun.write(&packet[..inner_len]); Err(e) => {
debug!("failed to write inbound packet to TUN: {:?}", e);
false
}
Ok(_) => true,
}
} }
} }
} }
@@ -177,8 +183,18 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
let _ = rx let _ = rx
.map(|buf| { .map(|buf| {
if buf.okay { if buf.okay {
// write to UDP device, TODO // write to UDP bind
let xmit = false; let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
match device.bind.send(&buf.msg[..], dst) {
Err(e) => {
debug!("failed to send outbound packet: {:?}", e);
false
}
Ok(_) => true,
}
} else {
false
};
// trigger callback // trigger callback
(device.call_send)( (device.call_send)(
@@ -204,17 +220,16 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) {
}; };
// cast and check size of packet // cast and check size of packet
let (header, packet) = match LayoutVerified::new_from_prefix(&buf.msg[..]) { let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
Some(v) => v, match LayoutVerified::new_from_prefix(&buf.msg[..]) {
None => continue, Some(v) => v,
}; None => continue,
};
if packet.len() < CHACHA20_POLY1305.nonce_len() { if packet.len() < CHACHA20_POLY1305.nonce_len() {
continue; continue;
} }
let header: LayoutVerified<&[u8], TransportHeader> = header;
// do the weird ring AEAD dance // do the weird ring AEAD dance
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap()); let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap());