Write inbound packets to TUN device

This commit is contained in:
Mathias Hall-Andersen
2019-09-07 18:38:19 +02:00
parent 8551e03ee3
commit 7b61ee4c2d
9 changed files with 305 additions and 137 deletions

View File

@@ -1,2 +1,7 @@
// WireGuard semantics constants
pub const MAX_STAGED_PACKETS: usize = 128; pub const MAX_STAGED_PACKETS: usize = 128;
// performance constants
pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS; pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS;

View File

@@ -10,8 +10,9 @@ use std::time::Instant;
use log::debug; use log::debug;
use spin; use spin::{Mutex, RwLock};
use treebitmap::IpLookupTable; use treebitmap::IpLookupTable;
use zerocopy::LayoutVerified;
use super::super::types::{Bind, KeyPair, Tun}; use super::super::types::{Bind, KeyPair, Tun};
@@ -20,23 +21,15 @@ use super::peer;
use super::peer::{Peer, PeerInner}; use super::peer::{Peer, PeerInner};
use super::SIZE_MESSAGE_PREFIX; use super::SIZE_MESSAGE_PREFIX;
use super::constants::WORKER_QUEUE_SIZE; use super::constants::*;
use super::messages::TYPE_TRANSPORT; use super::ip::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError}; use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
use super::workers::{worker_parallel, JobParallel}; use super::workers::{worker_parallel, JobParallel, Operation};
// minimum sizes for IP headers
const SIZE_IP4_HEADER: usize = 16;
const SIZE_IP6_HEADER: usize = 36;
const VERSION_IP4: u8 = 4;
const VERSION_IP6: u8 = 6;
const OFFSET_IP4_DST: usize = 16;
const OFFSET_IP6_DST: usize = 24;
pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> { pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
// IO & timer generics // IO & timer callbacks
pub tun: T, pub tun: T,
pub bind: B, pub bind: B,
pub call_recv: C::CallbackRecv, pub call_recv: C::CallbackRecv,
@@ -44,9 +37,9 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
pub call_need_key: C::CallbackKey, pub call_need_key: C::CallbackKey,
// routing // routing
pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T, B>>>, // receiver id -> decryption state pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
} }
pub struct EncryptionState { pub struct EncryptionState {
@@ -57,19 +50,18 @@ pub struct EncryptionState {
} }
pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> { pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
pub key: [u8; 32], pub keypair: Arc<KeyPair>,
pub keypair: Weak<KeyPair>, // only the key-wheel has a strong reference
pub confirmed: AtomicBool, pub confirmed: AtomicBool,
pub protector: spin::Mutex<AntiReplay>, pub protector: Mutex<AntiReplay>,
pub peer: Weak<PeerInner<C, T, B>>, pub peer: Arc<PeerInner<C, T, B>>,
pub death: Instant, // time when the key can no longer be used for decryption pub death: Instant, // time when the key can no longer be used for decryption
} }
pub struct Device<C: Callbacks, T: Tun, B: Bind> { pub struct Device<C: Callbacks, T: Tun, B: Bind> {
pub state: Arc<DeviceInner<C, T, B>>, // reference to device state state: Arc<DeviceInner<C, T, B>>, // reference to device state
pub handles: Vec<thread::JoinHandle<()>>, // join handles for workers handles: Vec<thread::JoinHandle<()>>, // join handles for workers
pub queue_next: AtomicUsize, // next round-robin index queue_next: AtomicUsize, // next round-robin index
pub queues: Vec<spin::Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread) queues: Vec<Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
} }
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> { impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
@@ -109,9 +101,9 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
call_recv, call_recv,
call_send, call_send,
call_need_key, call_need_key,
recv: spin::RwLock::new(HashMap::new()), recv: RwLock::new(HashMap::new()),
ipv4: spin::RwLock::new(IpLookupTable::new()), ipv4: RwLock::new(IpLookupTable::new()),
ipv6: spin::RwLock::new(IpLookupTable::new()), ipv6: RwLock::new(IpLookupTable::new()),
}); });
// start worker threads // start worker threads
@@ -119,7 +111,7 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
let mut threads = Vec::with_capacity(num_workers); let mut threads = Vec::with_capacity(num_workers);
for _ in 0..num_workers { for _ in 0..num_workers {
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
queues.push(spin::Mutex::new(tx)); queues.push(Mutex::new(tx));
threads.push(thread::spawn(move || worker_parallel(rx))); threads.push(thread::spawn(move || worker_parallel(rx)));
} }
@@ -133,6 +125,40 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
} }
} }
#[inline(always)]
fn get_route<C: Callbacks, T: Tun, B: Bind>(
device: &Arc<DeviceInner<C, T, B>>,
packet: &[u8],
) -> Option<Arc<PeerInner<C, T, B>>> {
match packet[0] >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
let (header, _) = LayoutVerified::new_from_prefix(packet)?;
let header: LayoutVerified<&[u8], IPv4Header> = header;
// check IPv4 source address
device
.ipv4
.read()
.longest_match(Ipv4Addr::from(header.f_source))
.and_then(|(_, _, p)| Some(p.clone()))
}
VERSION_IP6 => {
// check length and cast to IPv6 header
let (header, packet) = LayoutVerified::new_from_prefix(packet)?;
let header: LayoutVerified<&[u8], IPv6Header> = header;
// check IPv6 source address
device
.ipv6
.read()
.longest_match(Ipv6Addr::from(header.f_source))
.and_then(|(_, _, p)| Some(p.clone()))
}
_ => None,
}
}
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
/// Adds a new peer to the device /// Adds a new peer to the device
/// ///
@@ -159,48 +185,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
let packet = &msg[SIZE_MESSAGE_PREFIX..]; let packet = &msg[SIZE_MESSAGE_PREFIX..];
// lookup peer based on IP packet destination address // lookup peer based on IP packet destination address
let peer = match packet[0] >> 4 { let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
VERSION_IP4 => {
if msg.len() >= SIZE_IP4_HEADER {
// extract IPv4 destination address
let mut dst = [0u8; 4];
dst.copy_from_slice(&packet[OFFSET_IP4_DST..OFFSET_IP4_DST + 4]);
let dst = Ipv4Addr::from(dst);
// lookup peer (project unto and clone "value" field)
self.state
.ipv4
.read()
.longest_match(dst)
.and_then(|(_, _, p)| p.upgrade())
.ok_or(RouterError::NoCryptKeyRoute)
} else {
Err(RouterError::MalformedIPHeader)
}
}
VERSION_IP6 => {
if msg.len() >= SIZE_IP6_HEADER {
// extract IPv6 destination address
let mut dst = [0u8; 16];
dst.copy_from_slice(&packet[OFFSET_IP6_DST..OFFSET_IP6_DST + 16]);
let dst = Ipv6Addr::from(dst);
// lookup peer (project unto and clone "value" field)
self.state
.ipv6
.read()
.longest_match(dst)
.and_then(|(_, _, p)| p.upgrade())
.ok_or(RouterError::NoCryptKeyRoute)
} else {
Err(RouterError::MalformedIPHeader)
}
}
_ => Err(RouterError::MalformedIPHeader),
}?;
// schedule for encryption and transmission to peer // schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg) { if let Some(job) = peer.send_job(msg) {
debug_assert_eq!(job.1.op, Operation::Encryption);
// add job to worker queue // add job to worker queue
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst); let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
self.queues[idx % self.queues.len()] self.queues[idx % self.queues.len()]
@@ -216,17 +206,44 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
/// ///
/// # Arguments /// # Arguments
/// ///
/// - src: Source address of the packet
/// - msg: Encrypted transport message /// - msg: Encrypted transport message
pub fn recv(&self, msg: Vec<u8>) -> Result<(), RouterError> { ///
// ensure that the type field access is within bounds /// # Returns
if msg.len() < SIZE_MESSAGE_PREFIX || msg[0] != TYPE_TRANSPORT { ///
return Err(RouterError::MalformedTransportMessage); ///
} pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> {
// parse / cast // parse / cast
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
Some(v) => v,
None => {
return Err(RouterError::MalformedTransportMessage);
}
};
let header: LayoutVerified<&[u8], TransportHeader> = header;
debug_assert!(
header.f_type.get() == TYPE_TRANSPORT as u32,
"this should be checked by the message type multiplexer"
);
// lookup peer based on receiver id // lookup peer based on receiver id
let dec = self.state.recv.read();
let dec = dec
.get(&header.f_receiver.get())
.ok_or(RouterError::UnkownReceiverId)?;
unimplemented!(); // schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
debug_assert_eq!(job.1.op, Operation::Decryption);
// add job to worker queue
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
self.queues[idx % self.queues.len()]
.lock()
.send(job)
.unwrap();
}
Ok(())
} }
} }

37
src/router/ip.rs Normal file
View File

@@ -0,0 +1,37 @@
use byteorder::BigEndian;
use zerocopy::byteorder::U16;
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
pub const SIZE_IP4_HEADER: usize = 16;
pub const SIZE_IP6_HEADER: usize = 36;
pub const VERSION_IP4: u8 = 4;
pub const VERSION_IP6: u8 = 6;
pub const OFFSET_IP4_SRC: usize = 12;
pub const OFFSET_IP6_SRC: usize = 8;
pub const OFFSET_IP4_DST: usize = 16;
pub const OFFSET_IP6_DST: usize = 24;
pub const TYPE_TRANSPORT: u8 = 4;
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct IPv4Header {
_f_space1: [u8; 2],
pub f_total_len: U16<BigEndian>,
_f_space2: [u8; 8],
pub f_source: [u8; 4],
pub f_destination: [u8; 4],
}
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct IPv6Header {
_f_pre: [u8; 4],
pub f_len: U16<BigEndian>,
_f_space2: [u8; 2],
pub f_source: [u8; 16],
pub f_destination: [u8; 16],
}

View File

@@ -1,6 +1,7 @@
mod anti_replay; mod anti_replay;
mod constants; mod constants;
mod device; mod device;
mod ip;
mod messages; mod messages;
mod peer; mod peer;
mod types; mod types;

View File

@@ -30,7 +30,7 @@ use super::workers::Operation;
use super::workers::{worker_inbound, worker_outbound}; use super::workers::{worker_inbound, worker_outbound};
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel}; use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
use super::constants::MAX_STAGED_PACKETS; use super::constants::*;
use super::types::Callbacks; use super::types::Callbacks;
pub struct KeyWheel { pub struct KeyWheel {
@@ -50,7 +50,7 @@ pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
pub tx_bytes: AtomicU64, // transmitted bytes pub tx_bytes: AtomicU64, // transmitted bytes
pub keys: Mutex<KeyWheel>, // key-wheel pub keys: Mutex<KeyWheel>, // key-wheel
pub ekey: Mutex<Option<EncryptionState>>, // encryption state pub ekey: Mutex<Option<EncryptionState>>, // encryption state
pub endpoint: Mutex<Option<Arc<SocketAddr>>>, pub endpoint: Mutex<Option<B::Endpoint>>,
} }
pub struct Peer<C: Callbacks, T: Tun, B: Bind> { pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
@@ -61,7 +61,7 @@ pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>( fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
peer: &Arc<PeerInner<C, T, B>>, peer: &Arc<PeerInner<C, T, B>>,
table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>, table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
callback: Box<dyn Fn(A, u32) -> E>, callback: Box<dyn Fn(A, u32) -> E>,
) -> Vec<E> ) -> Vec<E>
where where
@@ -70,10 +70,8 @@ where
let mut res = Vec::new(); let mut res = Vec::new();
for subnet in table.read().iter() { for subnet in table.read().iter() {
let (ip, masklen, p) = subnet; let (ip, masklen, p) = subnet;
if let Some(p) = p.upgrade() { if Arc::ptr_eq(&p, &peer) {
if Arc::ptr_eq(&p, &peer) { res.push(callback(ip, masklen))
res.push(callback(ip, masklen))
}
} }
} }
res res
@@ -81,7 +79,7 @@ where
fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>( fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
peer: &Peer<C, T, B>, peer: &Peer<C, T, B>,
table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>, table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
) { ) {
let mut m = table.write(); let mut m = table.write();
@@ -89,10 +87,8 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
let mut subnets = vec![]; let mut subnets = vec![];
for subnet in m.iter() { for subnet in m.iter() {
let (ip, masklen, p) = subnet; let (ip, masklen, p) = subnet;
if let Some(p) = p.upgrade() { if Arc::ptr_eq(&p, &peer.state) {
if Arc::ptr_eq(&p, &peer.state) { subnets.push((ip, masklen))
subnets.push((ip, masklen))
}
} }
} }
@@ -103,6 +99,29 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
} }
} }
impl EncryptionState {
fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
EncryptionState {
id: keypair.send.id,
key: keypair.send.key,
nonce: 0,
death: keypair.birth + REJECT_AFTER_TIME,
}
}
}
impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> {
DecryptionState {
confirmed: AtomicBool::new(keypair.initiator),
keypair: keypair.clone(),
protector: spin::Mutex::new(AntiReplay::new()),
peer: peer.clone(),
death: keypair.birth + REJECT_AFTER_TIME,
}
}
}
impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> { impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
fn drop(&mut self) { fn drop(&mut self) {
let peer = &self.state; let peer = &self.state;
@@ -202,12 +221,52 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
} }
impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
pub fn confirm_key(&self, kp: Weak<KeyPair>) { pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
// upgrade key-pair to strong reference // take lock and check keypair = keys.next
let mut keys = self.keys.lock();
let next = match keys.next.as_ref() {
Some(next) => next,
None => {
return;
}
};
if !Arc::ptr_eq(&next, keypair) {
return;
}
// check it is the new unconfirmed key // allocate new encryption state
let ekey = Some(EncryptionState::new(&next));
// rotate key-wheel // rotate key-wheel
let mut swap = None;
mem::swap(&mut keys.next, &mut swap);
mem::swap(&mut keys.current, &mut swap);
mem::swap(&mut keys.previous, &mut swap);
// set new encryption key
*self.ekey.lock() = ekey;
}
pub fn recv_job(
&self,
src: B::Endpoint,
dec: Arc<DecryptionState<C, T, B>>,
mut msg: Vec<u8>,
) -> Option<JobParallel> {
let (tx, rx) = oneshot();
let key = dec.keypair.send.key;
match self.inbound.lock().try_send((dec, src, rx)) {
Ok(_) => Some((
tx,
JobBuffer {
msg,
key: key,
okay: false,
op: Operation::Decryption,
},
)),
Err(_) => None,
}
} }
pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> { pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> {
@@ -260,7 +319,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
pub fn set_endpoint(&self, endpoint: SocketAddr) { pub fn set_endpoint(&self, endpoint: SocketAddr) {
*self.state.endpoint.lock() = Some(Arc::new(endpoint)) *self.state.endpoint.lock() = Some(endpoint.into());
} }
/// Add a new keypair /// Add a new keypair
@@ -285,12 +344,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
// update key-wheel // update key-wheel
if new.initiator { if new.initiator {
// start using key for encryption // start using key for encryption
*self.state.ekey.lock() = Some(EncryptionState { *self.state.ekey.lock() = Some(EncryptionState::new(&new));
id: new.send.id,
key: new.send.key,
nonce: 0,
death: new.birth + REJECT_AFTER_TIME,
});
// move current into previous // move current into previous
keys.previous = keys.current.as_ref().map(|v| v.clone()); keys.previous = keys.current.as_ref().map(|v| v.clone());
@@ -310,19 +364,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
recv.remove(&id); recv.remove(&id);
} }
// map new id to keypair // map new id to decryption state
debug_assert!(!recv.contains_key(&new.recv.id)); debug_assert!(!recv.contains_key(&new.recv.id));
recv.insert( recv.insert(
new.recv.id, new.recv.id,
DecryptionState { Arc::new(DecryptionState::new(&self.state, &new)),
confirmed: AtomicBool::new(new.initiator),
keypair: Arc::downgrade(&new),
key: new.recv.key,
protector: spin::Mutex::new(AntiReplay::new()),
peer: Arc::downgrade(&self.state),
death: new.birth + REJECT_AFTER_TIME,
},
); );
} }
@@ -345,14 +391,14 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
.device .device
.ipv4 .ipv4
.write() .write()
.insert(v4, masklen, Arc::downgrade(&self.state)) .insert(v4, masklen, self.state.clone())
} }
IpAddr::V6(v6) => { IpAddr::V6(v6) => {
self.state self.state
.device .device
.ipv6 .ipv6
.write() .write()
.insert(v6, masklen, Arc::downgrade(&self.state)) .insert(v6, masklen, self.state.clone())
} }
}; };
} }

View File

@@ -156,6 +156,8 @@ mod tests {
#[bench] #[bench]
fn bench_outbound(b: &mut Bencher) { fn bench_outbound(b: &mut Bencher) {
init();
// type for tracking number of packets // type for tracking number of packets
type Opaque = Arc<AtomicU64>; type Opaque = Arc<AtomicU64>;

View File

@@ -57,6 +57,7 @@ pub enum RouterError {
NoCryptKeyRoute, NoCryptKeyRoute,
MalformedIPHeader, MalformedIPHeader,
MalformedTransportMessage, MalformedTransportMessage,
UnkownReceiverId,
} }
impl fmt::Display for RouterError { impl fmt::Display for RouterError {
@@ -65,6 +66,9 @@ impl fmt::Display for RouterError {
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"), RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"), RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"), RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
RouterError::UnkownReceiverId => {
write!(f, "No decryption state associated with receiver id")
}
} }
} }
} }

View File

@@ -1,6 +1,6 @@
use std::mem; use std::mem;
use std::sync::mpsc::Receiver; use std::sync::mpsc::Receiver;
use std::sync::{Arc, Weak}; use std::sync::Arc;
use futures::sync::oneshot; use futures::sync::oneshot;
use futures::*; use futures::*;
@@ -8,15 +8,17 @@ use futures::*;
use log::debug; use log::debug;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use std::sync::atomic::{AtomicBool, Ordering}; use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::Ordering;
use zerocopy::{AsBytes, LayoutVerified}; use zerocopy::{AsBytes, LayoutVerified};
use super::device::DecryptionState; use super::device::{DecryptionState, DeviceInner};
use super::device::DeviceInner;
use super::messages::TransportHeader; use super::messages::TransportHeader;
use super::peer::PeerInner; use super::peer::PeerInner;
use super::types::Callbacks; use super::types::Callbacks;
use super::ip::*;
use super::super::types::{Bind, Tun}; use super::super::types::{Bind, Tun};
#[derive(PartialEq, Debug)] #[derive(PartialEq, Debug)]
@@ -33,9 +35,60 @@ pub struct JobBuffer {
} }
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer); pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
pub type JobInbound<C, T, B> = (Weak<DecryptionState<C, T, B>>, oneshot::Receiver<JobBuffer>); pub type JobInbound<C, T, B: Bind> = (
Arc<DecryptionState<C, T, B>>,
B::Endpoint,
oneshot::Receiver<JobBuffer>,
);
pub type JobOutbound = oneshot::Receiver<JobBuffer>; pub type JobOutbound = oneshot::Receiver<JobBuffer>;
#[inline(always)]
fn check_route<C: Callbacks, T: Tun, B: Bind>(
device: &Arc<DeviceInner<C, T, B>>,
peer: &Arc<PeerInner<C, T, B>>,
packet: &[u8],
) -> Option<usize> {
match packet[0] >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
let (header, _) = LayoutVerified::new_from_prefix(packet)?;
let header: LayoutVerified<&[u8], IPv4Header> = header;
// check IPv4 source address
device
.ipv4
.read()
.longest_match(Ipv4Addr::from(header.f_source))
.and_then(|(_, _, p)| {
if Arc::ptr_eq(p, &peer) {
Some(header.f_total_len.get() as usize)
} else {
None
}
})
}
VERSION_IP6 => {
// check length and cast to IPv6 header
let (header, packet) = LayoutVerified::new_from_prefix(packet)?;
let header: LayoutVerified<&[u8], IPv6Header> = header;
// check IPv6 source address
device
.ipv6
.read()
.longest_match(Ipv6Addr::from(header.f_source))
.and_then(|(_, _, p)| {
if Arc::ptr_eq(p, &peer) {
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
} else {
None
}
})
}
_ => None,
}
}
pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
device: Arc<DeviceInner<C, T, B>>, // related device device: Arc<DeviceInner<C, T, B>>, // related device
peer: Arc<PeerInner<C, T, B>>, // related peer peer: Arc<PeerInner<C, T, B>>, // related peer
@@ -43,7 +96,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
) { ) {
loop { loop {
// fetch job // fetch job
let (state, rx) = match receiver.recv() { let (state, endpoint, rx) = match receiver.recv() {
Ok(v) => v, Ok(v) => v,
_ => { _ => {
return; return;
@@ -62,13 +115,10 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
} }
}; };
let header: LayoutVerified<&[u8], TransportHeader> = header; let header: LayoutVerified<&[u8], TransportHeader> = header;
debug_assert!(
// obtain strong reference to decryption state packet.len() >= 16,
let state = if let Some(state) = state.upgrade() { "this should be checked earlier in the pipeline"
state );
} else {
return;
};
// check for replay // check for replay
if !state.protector.lock().update(header.f_counter.get()) { if !state.protector.lock().update(header.f_counter.get()) {
@@ -77,23 +127,29 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
// check for confirms key // check for confirms key
if !state.confirmed.swap(true, Ordering::SeqCst) { if !state.confirmed.swap(true, Ordering::SeqCst) {
peer.confirm_key(state.keypair.clone()); peer.confirm_key(&state.keypair);
} }
// update endpoint, TODO // update endpoint
*peer.endpoint.lock() = Some(endpoint);
// write packet to TUN device, TODO // calculate length of IP packet + padding
let length = packet.len() - CHACHA20_POLY1305.nonce_len();
// check if should be written to TUN
let mut sent = false;
if length > 0 {
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
debug_assert!(inner_len <= length, "should be validated");
if inner_len <= length {
sent = true;
let _ = device.tun.write(&packet[..inner_len]);
}
}
}
// trigger callback // trigger callback
debug_assert!( (device.call_recv)(&peer.opaque, length == 0, sent);
packet.len() >= CHACHA20_POLY1305.nonce_len(),
"this should be checked earlier in the pipeline"
);
(device.call_recv)(
&peer.opaque,
packet.len() > CHACHA20_POLY1305.nonce_len(),
true,
);
} }
}) })
.wait(); .wait();

View File

@@ -1,5 +1,5 @@
use std::net::SocketAddr; use std::net::SocketAddr;
pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> {} pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> + Send {}
impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> {} impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> + Send {}