Unified use of make_packet during tests

This commit is contained in:
Mathias Hall-Andersen
2019-10-29 16:53:59 +01:00
parent 4ff328b7da
commit e04a11a8ca
6 changed files with 144 additions and 119 deletions

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::sync_channel; use std::sync::mpsc::sync_channel;
@@ -14,13 +15,15 @@ use zerocopy::LayoutVerified;
use super::anti_replay::AntiReplay; use super::anti_replay::AntiReplay;
use super::constants::*; use super::constants::*;
use super::ip::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerInner}; use super::peer::{new_peer, Peer, PeerInner};
use super::types::{Callbacks, RouterError}; use super::types::{Callbacks, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation}; use super::workers::{worker_parallel, JobParallel, Operation};
use super::SIZE_MESSAGE_PREFIX; use super::SIZE_MESSAGE_PREFIX;
use super::route::get_route;
use super::super::{bind, tun, Endpoint, KeyPair}; use super::super::{bind, tun, Endpoint, KeyPair};
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
@@ -84,40 +87,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Dev
} }
} }
#[inline(always)]
fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
packet: &[u8],
) -> Option<Arc<PeerInner<E, C, T, B>>> {
match packet.get(0)? >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// lookup destination address
device
.ipv4
.read()
.longest_match(Ipv4Addr::from(header.f_destination))
.and_then(|(_, _, p)| Some(p.clone()))
}
VERSION_IP6 => {
// check length and cast to IPv6 header
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// lookup destination address
device
.ipv6
.read()
.longest_match(Ipv6Addr::from(header.f_destination))
.and_then(|(_, _, p)| Some(p.clone()))
}
_ => None,
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> { impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> { pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
// allocate shared device state // allocate shared device state

View File

@@ -4,6 +4,7 @@ mod device;
mod ip; mod ip;
mod messages; mod messages;
mod peer; mod peer;
mod route;
mod types; mod types;
mod workers; mod workers;

View File

@@ -0,0 +1,101 @@
use super::super::{bind, tun, Endpoint};
use super::device::DeviceInner;
use super::ip::*;
use super::peer::PeerInner;
use super::types::Callbacks;
use log::trace;
use zerocopy::LayoutVerified;
use std::mem;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
#[inline(always)]
pub fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
packet: &[u8],
) -> Option<Arc<PeerInner<E, C, T, B>>> {
match packet.get(0)? >> 4 {
VERSION_IP4 => {
trace!("cryptokey router, get route for IPv4 packet");
// check length and cast to IPv4 header
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// check IPv4 source address
device
.ipv4
.read()
.longest_match(Ipv4Addr::from(header.f_destination))
.and_then(|(_, _, p)| Some(p.clone()))
}
VERSION_IP6 => {
trace!("cryptokey router, get route for IPv6 packet");
// check length and cast to IPv6 header
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// check IPv6 source address
device
.ipv6
.read()
.longest_match(Ipv6Addr::from(header.f_destination))
.and_then(|(_, _, p)| Some(p.clone()))
}
_ => None,
}
}
#[inline(always)]
pub fn check_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
peer: &Arc<PeerInner<E, C, T, B>>,
packet: &[u8],
) -> Option<usize> {
match packet.get(0)? >> 4 {
VERSION_IP4 => {
trace!("cryptokey route, check route for IPv4 packet");
// check length and cast to IPv4 header
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// check IPv4 source address
device
.ipv4
.read()
.longest_match(Ipv4Addr::from(header.f_source))
.and_then(|(_, _, p)| {
if Arc::ptr_eq(p, peer) {
Some(header.f_total_len.get() as usize)
} else {
None
}
})
}
VERSION_IP6 => {
trace!("cryptokey route, check route for IPv6 packet");
// check length and cast to IPv6 header
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// check IPv6 source address
device
.ipv6
.read()
.longest_match(Ipv6Addr::from(header.f_source))
.and_then(|(_, _, p)| {
if Arc::ptr_eq(p, peer) {
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
} else {
None
}
})
}
_ => None,
}
}

View File

@@ -6,13 +6,13 @@ use std::thread;
use std::time::Duration; use std::time::Duration;
use num_cpus; use num_cpus;
use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
use super::super::bind::*; use super::super::bind::*;
use super::super::dummy; use super::super::dummy;
use super::super::dummy_keypair; use super::super::dummy_keypair;
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; use super::super::tests::make_packet_dst;
use super::SIZE_MESSAGE_PREFIX;
use super::{Callbacks, Device};
extern crate test; extern crate test;
@@ -111,23 +111,11 @@ mod tests {
let _ = env_logger::builder().is_test(true).try_init(); let _ = env_logger::builder().is_test(true).try_init();
} }
fn make_packet(size: usize, ip: IpAddr) -> Vec<u8> { fn make_packet_dst_padded(size: usize, dst: IpAddr, id: u64) -> Vec<u8> {
// create "IP packet" let p = make_packet_dst(size, dst, id);
let mut msg = Vec::with_capacity(SIZE_MESSAGE_PREFIX + size + 16); let mut o = vec![0; p.len() + SIZE_MESSAGE_PREFIX];
msg.resize(SIZE_MESSAGE_PREFIX + size, 0); o[SIZE_MESSAGE_PREFIX..SIZE_MESSAGE_PREFIX + p.len()].copy_from_slice(&p[..]);
match ip { o
IpAddr::V4(ip) => {
let mut packet = MutableIpv4Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
packet.set_destination(ip);
packet.set_version(4);
}
IpAddr::V6(ip) => {
let mut packet = MutableIpv6Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
packet.set_destination(ip);
packet.set_version(6);
}
}
msg
} }
#[bench] #[bench]
@@ -162,7 +150,7 @@ mod tests {
// every iteration sends 10 GB // every iteration sends 10 GB
b.iter(|| { b.iter(|| {
opaque.store(0, Ordering::SeqCst); opaque.store(0, Ordering::SeqCst);
let msg = make_packet(1024, ip1); let msg = make_packet_dst_padded(1024, ip1, 0);
while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 { while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
router.send(msg.to_vec()).unwrap(); router.send(msg.to_vec()).unwrap();
} }
@@ -218,7 +206,7 @@ mod tests {
peer.add_allowed_ips(mask, *len); peer.add_allowed_ips(mask, *len);
// create "IP packet" // create "IP packet"
let msg = make_packet(1024, ip.parse().unwrap()); let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), 0);
// cryptkey route the IP packet // cryptkey route the IP packet
let res = router.send(msg); let res = router.send(msg);
@@ -228,7 +216,7 @@ mod tests {
if *okay { if *okay {
// cryptkey routing succeeded // cryptkey routing succeeded
assert!(res.is_ok(), "crypt-key routing should succeed"); assert!(res.is_ok(), "crypt-key routing should succeed: {:?}", res);
assert_eq!( assert_eq!(
opaque.need_key().is_some(), opaque.need_key().is_some(),
!set_key, !set_key,
@@ -351,7 +339,7 @@ mod tests {
if *stage { if *stage {
// stage a packet which can be used for confirmation (in place of a keepalive) // stage a packet which can be used for confirmation (in place of a keepalive)
let (_mask, _len, ip, _okay) = p2; let (_mask, _len, ip, _okay) = p2;
let msg = make_packet(1024, ip.parse().unwrap()); let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), 0);
router2.send(msg).expect("failed to sent staged packet"); router2.send(msg).expect("failed to sent staged packet");
wait(); wait();
@@ -396,7 +384,7 @@ mod tests {
// now that peer1 has an endpoint // now that peer1 has an endpoint
// route packets : peer1 -> peer2 // route packets : peer1 -> peer2
for _ in 0..10 { for id in 0..10 {
assert!( assert!(
opaq1.is_empty(), opaq1.is_empty(),
"we should have asserted a value for every callback on peer1" "we should have asserted a value for every callback on peer1"
@@ -408,7 +396,7 @@ mod tests {
// pass IP packet to router // pass IP packet to router
let (_mask, _len, ip, _okay) = p1; let (_mask, _len, ip, _okay) = p1;
let msg = make_packet(1024, ip.parse().unwrap()); let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), id);
router1.send(msg).unwrap(); router1.send(msg).unwrap();
wait(); wait();

View File

@@ -1,4 +1,3 @@
use std::mem;
use std::sync::mpsc::Receiver; use std::sync::mpsc::Receiver;
use std::sync::Arc; use std::sync::Arc;
@@ -8,17 +7,17 @@ use futures::*;
use log::debug; use log::debug;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use zerocopy::{AsBytes, LayoutVerified}; use zerocopy::{AsBytes, LayoutVerified};
use super::device::{DecryptionState, DeviceInner}; use super::device::{DecryptionState, DeviceInner};
use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::PeerInner; use super::peer::PeerInner;
use super::route::check_route;
use super::types::Callbacks; use super::types::Callbacks;
use super::super::{bind, tun, Endpoint}; use super::super::{bind, tun, Endpoint};
use super::ip::*;
pub const SIZE_TAG: usize = 16; pub const SIZE_TAG: usize = 16;
@@ -46,53 +45,6 @@ pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
pub type JobOutbound = oneshot::Receiver<JobBuffer>; pub type JobOutbound = oneshot::Receiver<JobBuffer>;
#[inline(always)]
fn check_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
peer: &Arc<PeerInner<E, C, T, B>>,
packet: &[u8],
) -> Option<usize> {
match packet[0] >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// check IPv4 source address
device
.ipv4
.read()
.longest_match(Ipv4Addr::from(header.f_source))
.and_then(|(_, _, p)| {
if Arc::ptr_eq(p, &peer) {
Some(header.f_total_len.get() as usize)
} else {
None
}
})
}
VERSION_IP6 => {
// check length and cast to IPv6 header
let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
LayoutVerified::new_from_prefix(packet)?;
// check IPv6 source address
device
.ipv6
.read()
.longest_match(Ipv6Addr::from(header.f_source))
.and_then(|(_, _, p)| {
if Arc::ptr_eq(p, &peer) {
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
} else {
None
}
})
}
_ => None,
}
}
pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device device: Arc<DeviceInner<E, C, T, B>>, // related device
peer: Arc<PeerInner<E, C, T, B>>, // related peer peer: Arc<PeerInner<E, C, T, B>>, // related peer

View File

@@ -14,19 +14,32 @@ use x25519_dalek::{PublicKey, StaticSecret};
use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet; use pnet::packet::ipv6::MutableIpv6Packet;
fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> { pub fn make_packet_src(size: usize, src: IpAddr, id: u64) -> Vec<u8> {
match src {
IpAddr::V4(_) => make_packet(size, src, "127.0.0.1".parse().unwrap(), id),
IpAddr::V6(_) => make_packet(size, src, "::1".parse().unwrap(), id),
}
}
pub fn make_packet_dst(size: usize, dst: IpAddr, id: u64) -> Vec<u8> {
match dst {
IpAddr::V4(_) => make_packet(size, "127.0.0.1".parse().unwrap(), dst, id),
IpAddr::V6(_) => make_packet(size, "::1".parse().unwrap(), dst, id),
}
}
pub fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
// expand pseudo random payload // expand pseudo random payload
let mut rng: _ = ChaCha8Rng::seed_from_u64(id); let mut rng: _ = ChaCha8Rng::seed_from_u64(id);
let mut p: Vec<u8> = vec![]; let mut p: Vec<u8> = vec![0; size];
for _ in 0..size { rng.fill_bytes(&mut p[..]);
p.push(rng.next_u32() as u8);
}
// create "IP packet" // create "IP packet"
let mut msg = Vec::with_capacity(size); let mut msg = Vec::with_capacity(size);
msg.resize(size, 0); msg.resize(size, 0);
match dst { match dst {
IpAddr::V4(dst) => { IpAddr::V4(dst) => {
let length = size - MutableIpv4Packet::minimum_packet_size();
let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap(); let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst); packet.set_destination(dst);
packet.set_total_length(size as u16); packet.set_total_length(size as u16);
@@ -35,19 +48,20 @@ fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
} else { } else {
panic!("src.version != dst.version") panic!("src.version != dst.version")
}); });
packet.set_payload(&p[..]); packet.set_payload(&p[..length]);
packet.set_version(4); packet.set_version(4);
} }
IpAddr::V6(dst) => { IpAddr::V6(dst) => {
let length = size - MutableIpv6Packet::minimum_packet_size();
let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap(); let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst); packet.set_destination(dst);
packet.set_payload_length((size - MutableIpv6Packet::minimum_packet_size()) as u16); packet.set_payload_length(length as u16);
packet.set_source(if let IpAddr::V6(src) = src { packet.set_source(if let IpAddr::V6(src) = src {
src src
} else { } else {
panic!("src.version != dst.version") panic!("src.version != dst.version")
}); });
packet.set_payload(&p[..]); packet.set_payload(&p[..length]);
packet.set_version(6); packet.set_version(6);
} }
} }