Begin work on the pure Wireguard implemenation

Start joining the handshake device and router device in the top-level Wireguard implemenation.
This commit is contained in:
Mathias Hall-Andersen
2019-09-14 12:43:09 +02:00
parent c3ad827197
commit b31becda71
13 changed files with 181 additions and 143 deletions

View File

@@ -4,6 +4,8 @@ use std::net::SocketAddr;
use std::sync::Mutex;
use zerocopy::AsBytes;
use byteorder::{LittleEndian, ByteOrder};
use rand::prelude::*;
use x25519_dalek::PublicKey;
@@ -206,8 +208,14 @@ where
where
&'a S: Into<&'a SocketAddr>,
{
match msg.get(0) {
Some(&TYPE_INITIATION) => {
// ensure type read in-range
if msg.len() < 4 {
return Err(HandshakeError::InvalidMessageFormat);
}
// de-multiplex the message type field
match LittleEndian::read_u32(msg) {
TYPE_INITIATION => {
// parse message
let msg = Initiation::parse(msg)?;
@@ -267,7 +275,7 @@ where
Some(keys),
))
}
Some(&TYPE_RESPONSE) => {
TYPE_RESPONSE => {
let msg = Response::parse(msg)?;
// check mac1 field
@@ -300,7 +308,7 @@ where
// consume inner playload
noise::consume_response(self, &msg.noise)
}
Some(&TYPE_COOKIE_REPLY) => {
TYPE_COOKIE_REPLY => {
let msg = CookieReply::parse(msg)?;
// lookup peer

View File

@@ -17,9 +17,9 @@ const SIZE_COOKIE: usize = 16; //
const SIZE_X25519_POINT: usize = 32; // x25519 public key
const SIZE_TIMESTAMP: usize = 12;
pub const TYPE_INITIATION: u8 = 1;
pub const TYPE_RESPONSE: u8 = 2;
pub const TYPE_COOKIE_REPLY: u8 = 3;
pub const TYPE_INITIATION: u32 = 1;
pub const TYPE_RESPONSE: u32 = 2;
pub const TYPE_COOKIE_REPLY: u32 = 3;
/* Handshake messsages */

View File

@@ -18,3 +18,4 @@ mod types;
// publicly exposed interface
pub use device::Device;
pub use messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE };

View File

@@ -9,5 +9,6 @@ mod constants;
mod handshake;
mod router;
mod types;
mod wireguard;
fn main() {}

View File

@@ -17,7 +17,7 @@ use super::constants::*;
use super::ip::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerInner};
use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
use super::types::{Callbacks, Opaque, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation};
use super::SIZE_MESSAGE_PREFIX;
@@ -27,9 +27,6 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
// IO & timer callbacks
pub tun: T,
pub bind: B,
pub call_recv: C::CallbackRecv,
pub call_send: C::CallbackSend,
pub call_need_key: C::CallbackKey,
// routing
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
@@ -83,47 +80,6 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
}
}
impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bind>
Device<PhantomCallbacks<O, R, S, K>, T, B>
{
pub fn new(
num_workers: usize,
tun: T,
bind: B,
call_send: S,
call_recv: R,
call_need_key: K,
) -> Device<PhantomCallbacks<O, R, S, K>, T, B> {
// allocate shared device state
let mut inner = DeviceInner {
tun,
bind,
call_recv,
call_send,
queues: Mutex::new(Vec::with_capacity(num_workers)),
queue_next: AtomicUsize::new(0),
call_need_key,
recv: RwLock::new(HashMap::new()),
ipv4: RwLock::new(IpLookupTable::new()),
ipv6: RwLock::new(IpLookupTable::new()),
};
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
inner.queues.lock().push(tx);
threads.push(thread::spawn(move || worker_parallel(rx)));
}
// return exported device handle
Device {
state: Arc::new(inner),
handles: threads,
}
}
}
#[inline(always)]
fn get_route<C: Callbacks, T: Tun, B: Bind>(
device: &Arc<DeviceInner<C, T, B>>,
@@ -165,6 +121,34 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
}
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> {
// allocate shared device state
let mut inner = DeviceInner {
tun,
bind,
queues: Mutex::new(Vec::with_capacity(num_workers)),
queue_next: AtomicUsize::new(0),
recv: RwLock::new(HashMap::new()),
ipv4: RwLock::new(IpLookupTable::new()),
ipv6: RwLock::new(IpLookupTable::new()),
};
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
inner.queues.lock().push(tx);
threads.push(thread::spawn(move || worker_parallel(rx)));
}
// return exported device handle
Device {
state: Arc::new(inner),
handles: threads,
}
}
/// Adds a new peer to the device
///
/// # Returns
@@ -228,7 +212,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
let dec = self.state.recv.read();
let dec = dec
.get(&header.f_receiver.get())
.ok_or(RouterError::UnkownReceiverId)?;
.ok_or(RouterError::UnknownReceiverId)?;
// schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {

View File

@@ -14,5 +14,9 @@ use messages::TransportHeader;
use std::mem;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
pub const CAPACITY_MESSAGE_POSTFIX: usize = 16;
pub use messages::TYPE_TRANSPORT;
pub use device::Device;
pub use peer::Peer;
pub use types::Callbacks;

View File

@@ -280,7 +280,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
None => {
// add to staged packets (create no job)
debug!("execute callback: call_need_key");
(self.device.call_need_key)(&self.opaque);
C::need_key(&self.opaque);
self.staged_packets.lock().push_back(msg);
return None;
}

View File

@@ -13,7 +13,7 @@ use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
use super::super::types::{Bind, Key, KeyPair, Tun};
use super::{Device, SIZE_MESSAGE_PREFIX};
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
extern crate test;
@@ -82,6 +82,7 @@ impl Into<SocketAddr> for UnitEndpoint {
}
}
#[derive(Clone, Copy)]
struct TunTest {}
impl Tun for TunTest {
@@ -102,6 +103,7 @@ impl Tun for TunTest {
/* Bind implemenentations */
#[derive(Clone, Copy)]
struct VoidBind {}
impl Bind for VoidBind {
@@ -166,7 +168,7 @@ impl Bind for PairBind {
Ok((vec.len(), UnitEndpoint {}))
}
fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> {
fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
let owned = buf.to_owned();
match self.send.lock().unwrap().send(owned) {
Err(_) => Err(BindError::Disconnected),
@@ -221,7 +223,7 @@ mod tests {
use super::*;
use env_logger;
use log::debug;
use std::sync::atomic::{AtomicU64, AtomicUsize};
use std::sync::atomic::AtomicUsize;
use test::Bencher;
// type for tracking events inside the router module
@@ -234,6 +236,8 @@ mod tests {
#[derive(Clone)]
struct Opaque(Arc<Flags>);
struct TestCallbacks();
impl Opaque {
fn new() -> Opaque {
Opaque(Arc::new(Flags {
@@ -269,17 +273,21 @@ mod tests {
}
}
fn callback_send(t: &Opaque, size: usize, data: bool, sent: bool) {
impl Callbacks for TestCallbacks {
type Opaque = Opaque;
fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
t.0.send.lock().unwrap().push((size, data, sent))
}
fn callback_recv(t: &Opaque, size: usize, data: bool, sent: bool) {
fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
t.0.recv.lock().unwrap().push((size, data, sent))
}
fn callback_need_key(t: &Opaque) {
fn need_key(t: &Self::Opaque) {
t.0.need_key.lock().unwrap().push(());
}
}
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
@@ -306,19 +314,19 @@ mod tests {
#[bench]
fn bench_outbound(b: &mut Bencher) {
struct BencherCallbacks {}
impl Callbacks for BencherCallbacks {
type Opaque = Arc<AtomicUsize>;
fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) {
t.fetch_add(size, Ordering::SeqCst);
}
fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
fn need_key(_: &Self::Opaque) {}
}
// create device
let router = Device::new(
num_cpus::get(),
TunTest {},
VoidBind::new(),
|t: &Opaque, size: usize, _data: bool, _sent: bool| {
t.fetch_add(size, Ordering::SeqCst);
},
|t: &Opaque, _size: usize, _data: bool, _sent: bool| {},
|t: &Opaque| (),
);
let router: Device<BencherCallbacks, TunTest, VoidBind> =
Device::new(num_cpus::get(), TunTest {}, VoidBind::new());
// add new peer
let opaque = Arc::new(AtomicUsize::new(0));
@@ -328,15 +336,15 @@ mod tests {
// add subnet to peer
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
let mask: IpAddr = mask.parse().unwrap();
let ip: IpAddr = ip.parse().unwrap();
let ip1: IpAddr = ip.parse().unwrap();
peer.add_subnet(mask, len);
// every iteration sends 10 MB
// every iteration sends 50 GB
b.iter(|| {
opaque.store(0, Ordering::SeqCst);
while opaque.load(Ordering::Acquire) < 10 * 1024 {
let msg = make_packet(1024, ip);
router.send(msg).unwrap();
let msg = make_packet(1024, ip1);
while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
router.send(msg.to_vec()).unwrap();
}
});
}
@@ -346,14 +354,7 @@ mod tests {
init();
// create device
let router = Device::new(
1,
TunTest {},
VoidBind::new(),
callback_send,
callback_recv,
callback_need_key,
);
let router: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, VoidBind::new());
let tests = vec![
("192.168.1.0", 24, "192.168.1.20", true),
@@ -447,7 +448,7 @@ mod tests {
}
fn wait() {
thread::sleep(Duration::from_millis(10));
thread::sleep(Duration::from_millis(20));
}
#[test]
@@ -472,23 +473,9 @@ mod tests {
// create matching devices
let router1 = Device::new(
1,
TunTest {},
bind1.clone(),
callback_send,
callback_recv,
callback_need_key,
);
let router1: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind1.clone());
let router2 = Device::new(
1,
TunTest {},
bind2.clone(),
callback_send,
callback_recv,
callback_need_key,
);
let router2: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind2.clone());
// prepare opaque values for tracing callbacks
@@ -514,6 +501,7 @@ mod tests {
let (_mask, _len, ip, _okay) = p2;
let msg = make_packet(1024, ip.parse().unwrap());
router2.send(msg).expect("failed to sent staged packet");
wait();
assert!(opaq2.recv().is_none());
assert!(
@@ -537,7 +525,7 @@ mod tests {
assert!(opaq2.recv().is_none());
assert!(opaq2.need_key().is_none());
assert!(opaq2.is_empty());
assert!(opaq1.is_empty(), "nothing should happend on peer1");
assert!(opaq1.is_empty(), "nothing should happened on peer1");
// read confirming message received by the other end ("across the internet")
let mut buf = vec![0u8; 2048];
@@ -551,7 +539,7 @@ mod tests {
assert!(opaq1.need_key().is_none());
assert!(opaq1.is_empty());
assert!(peer1.get_endpoint().is_some());
assert!(opaq2.is_empty(), "nothing should happend on peer2");
assert!(opaq2.is_empty(), "nothing should happened on peer2");
// how that peer1 has an endpoint
// route packets : peer1 -> peer2

View File

@@ -22,34 +22,11 @@ pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {}
impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
pub trait Endpoint: Send + Sync {}
pub trait Callbacks: Send + Sync + 'static {
type Opaque: Opaque;
type CallbackRecv: Callback<Self::Opaque>;
type CallbackSend: Callback<Self::Opaque>;
type CallbackKey: KeyCallback<Self::Opaque>;
}
/* Concrete implementation of "Callbacks",
* used to hide the constituent type parameters.
*
* This type is never instantiated.
*/
pub struct PhantomCallbacks<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> {
_phantom_opaque: PhantomData<O>,
_phantom_recv: PhantomData<R>,
_phantom_send: PhantomData<S>,
_phantom_key: PhantomData<K>,
}
impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> Callbacks
for PhantomCallbacks<O, R, S, K>
{
type Opaque = O;
type CallbackRecv = R;
type CallbackSend = S;
type CallbackKey = K;
fn send(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
fn recv(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
fn need_key(_opaque: &Self::Opaque) {}
}
#[derive(Debug)]
@@ -57,7 +34,7 @@ pub enum RouterError {
NoCryptKeyRoute,
MalformedIPHeader,
MalformedTransportMessage,
UnkownReceiverId,
UnknownReceiverId,
NoEndpoint,
SendError,
}
@@ -68,7 +45,7 @@ impl fmt::Display for RouterError {
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
RouterError::UnkownReceiverId => {
RouterError::UnknownReceiverId => {
write!(f, "No decryption state associated with receiver id")
}
RouterError::NoEndpoint => write!(f, "No endpoint for peer"),

View File

@@ -167,7 +167,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
}
// trigger callback
(device.call_recv)(&peer.opaque, buf.msg.len(), length == 0, sent);
C::recv(&peer.opaque, buf.msg.len(), length == 0, sent);
} else {
debug!("inbound worker: authentication failure")
}
@@ -210,7 +210,7 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
};
// trigger callback
(device.call_send)(
C::send(
&peer.opaque,
buf.msg.len(),
buf.msg.len() > SIZE_TAG + mem::size_of::<TransportHeader>(),

View File

@@ -1,13 +1,13 @@
use std::error;
pub trait Tun: Send + Sync + 'static {
pub trait Tun: Send + Sync + Clone + 'static {
type Error: error::Error;
/// Returns the MTU of the device
///
/// This function needs to be efficient (called for every read).
/// The goto implementation stragtegy is to .load an atomic variable,
/// then use e.g. netlink to update the variable in a seperate thread.
/// The goto implementation strategy is to .load an atomic variable,
/// then use e.g. netlink to update the variable in a separate thread.
///
/// # Returns
///

View File

@@ -3,8 +3,8 @@ use std::error;
/* Often times an a file descriptor in an atomic might suffice.
*/
pub trait Bind: Send + Sync + 'static {
type Error: error::Error;
pub trait Bind: Send + Sync + Clone + 'static {
type Error: error::Error + Send;
type Endpoint: Endpoint;
fn new() -> Self;

75
src/wireguard.rs Normal file
View File

@@ -0,0 +1,75 @@
use crate::handshake;
use crate::router;
use crate::types::{Bind, Tun};
use byteorder::{ByteOrder, LittleEndian};
use std::thread;
use x25519_dalek::StaticSecret;
pub struct Timers {}
pub struct Events();
impl router::Callbacks for Events {
type Opaque = Timers;
fn send(t: &Timers, size: usize, data: bool, sent: bool) {}
fn recv(t: &Timers, size: usize, data: bool, sent: bool) {}
fn need_key(t: &Timers) {}
}
pub struct Wireguard<T: Tun, B: Bind> {
router: router::Device<Events, T, B>,
handshake: Option<handshake::Device<()>>,
}
impl<T: Tun, B: Bind> Wireguard<T, B> {
fn new(tun: T, bind: B) -> Wireguard<T, B> {
let router = router::Device::new(num_cpus::get(), tun.clone(), bind.clone());
// start UDP read IO thread
{
let tun = tun.clone();
thread::spawn(move || {
loop {
// read UDP packet into vector
let size = tun.mtu() + 148; // maximum message size
let mut msg: Vec<u8> =
Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
msg.resize(size, 0);
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
msg.truncate(size);
// message type de-multiplexer
if msg.len() < 4 {
continue;
}
match LittleEndian::read_u32(&msg[..]) {
handshake::TYPE_COOKIE_REPLY
| handshake::TYPE_INITIATION
| handshake::TYPE_RESPONSE => {
// handshake message
}
router::TYPE_TRANSPORT => {
// transport message
}
_ => (),
}
}
});
}
// start TUN read IO thread
thread::spawn(move || {});
Wireguard {
router,
handshake: None,
}
}
}