Write inbound packets to TUN device
This commit is contained in:
@@ -1,2 +1,7 @@
|
||||
// WireGuard semantics constants
|
||||
|
||||
pub const MAX_STAGED_PACKETS: usize = 128;
|
||||
|
||||
// performance constants
|
||||
|
||||
pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS;
|
||||
|
||||
@@ -10,8 +10,9 @@ use std::time::Instant;
|
||||
|
||||
use log::debug;
|
||||
|
||||
use spin;
|
||||
use spin::{Mutex, RwLock};
|
||||
use treebitmap::IpLookupTable;
|
||||
use zerocopy::LayoutVerified;
|
||||
|
||||
use super::super::types::{Bind, KeyPair, Tun};
|
||||
|
||||
@@ -20,23 +21,15 @@ use super::peer;
|
||||
use super::peer::{Peer, PeerInner};
|
||||
use super::SIZE_MESSAGE_PREFIX;
|
||||
|
||||
use super::constants::WORKER_QUEUE_SIZE;
|
||||
use super::messages::TYPE_TRANSPORT;
|
||||
use super::constants::*;
|
||||
use super::ip::*;
|
||||
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
|
||||
use super::workers::{worker_parallel, JobParallel};
|
||||
|
||||
// minimum sizes for IP headers
|
||||
const SIZE_IP4_HEADER: usize = 16;
|
||||
const SIZE_IP6_HEADER: usize = 36;
|
||||
|
||||
const VERSION_IP4: u8 = 4;
|
||||
const VERSION_IP6: u8 = 6;
|
||||
|
||||
const OFFSET_IP4_DST: usize = 16;
|
||||
const OFFSET_IP6_DST: usize = 24;
|
||||
use super::workers::{worker_parallel, JobParallel, Operation};
|
||||
|
||||
pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
|
||||
// IO & timer generics
|
||||
// IO & timer callbacks
|
||||
pub tun: T,
|
||||
pub bind: B,
|
||||
pub call_recv: C::CallbackRecv,
|
||||
@@ -44,9 +37,9 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
|
||||
pub call_need_key: C::CallbackKey,
|
||||
|
||||
// routing
|
||||
pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T, B>>>, // receiver id -> decryption state
|
||||
pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
|
||||
pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
|
||||
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
|
||||
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
|
||||
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
|
||||
}
|
||||
|
||||
pub struct EncryptionState {
|
||||
@@ -57,19 +50,18 @@ pub struct EncryptionState {
|
||||
}
|
||||
|
||||
pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
|
||||
pub key: [u8; 32],
|
||||
pub keypair: Weak<KeyPair>, // only the key-wheel has a strong reference
|
||||
pub keypair: Arc<KeyPair>,
|
||||
pub confirmed: AtomicBool,
|
||||
pub protector: spin::Mutex<AntiReplay>,
|
||||
pub peer: Weak<PeerInner<C, T, B>>,
|
||||
pub protector: Mutex<AntiReplay>,
|
||||
pub peer: Arc<PeerInner<C, T, B>>,
|
||||
pub death: Instant, // time when the key can no longer be used for decryption
|
||||
}
|
||||
|
||||
pub struct Device<C: Callbacks, T: Tun, B: Bind> {
|
||||
pub state: Arc<DeviceInner<C, T, B>>, // reference to device state
|
||||
pub handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||
pub queue_next: AtomicUsize, // next round-robin index
|
||||
pub queues: Vec<spin::Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
||||
state: Arc<DeviceInner<C, T, B>>, // reference to device state
|
||||
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||
queue_next: AtomicUsize, // next round-robin index
|
||||
queues: Vec<Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
|
||||
@@ -109,9 +101,9 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
|
||||
call_recv,
|
||||
call_send,
|
||||
call_need_key,
|
||||
recv: spin::RwLock::new(HashMap::new()),
|
||||
ipv4: spin::RwLock::new(IpLookupTable::new()),
|
||||
ipv6: spin::RwLock::new(IpLookupTable::new()),
|
||||
recv: RwLock::new(HashMap::new()),
|
||||
ipv4: RwLock::new(IpLookupTable::new()),
|
||||
ipv6: RwLock::new(IpLookupTable::new()),
|
||||
});
|
||||
|
||||
// start worker threads
|
||||
@@ -119,7 +111,7 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
|
||||
let mut threads = Vec::with_capacity(num_workers);
|
||||
for _ in 0..num_workers {
|
||||
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
|
||||
queues.push(spin::Mutex::new(tx));
|
||||
queues.push(Mutex::new(tx));
|
||||
threads.push(thread::spawn(move || worker_parallel(rx)));
|
||||
}
|
||||
|
||||
@@ -133,6 +125,40 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_route<C: Callbacks, T: Tun, B: Bind>(
|
||||
device: &Arc<DeviceInner<C, T, B>>,
|
||||
packet: &[u8],
|
||||
) -> Option<Arc<PeerInner<C, T, B>>> {
|
||||
match packet[0] >> 4 {
|
||||
VERSION_IP4 => {
|
||||
// check length and cast to IPv4 header
|
||||
let (header, _) = LayoutVerified::new_from_prefix(packet)?;
|
||||
let header: LayoutVerified<&[u8], IPv4Header> = header;
|
||||
|
||||
// check IPv4 source address
|
||||
device
|
||||
.ipv4
|
||||
.read()
|
||||
.longest_match(Ipv4Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| Some(p.clone()))
|
||||
}
|
||||
VERSION_IP6 => {
|
||||
// check length and cast to IPv6 header
|
||||
let (header, packet) = LayoutVerified::new_from_prefix(packet)?;
|
||||
let header: LayoutVerified<&[u8], IPv6Header> = header;
|
||||
|
||||
// check IPv6 source address
|
||||
device
|
||||
.ipv6
|
||||
.read()
|
||||
.longest_match(Ipv6Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| Some(p.clone()))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
/// Adds a new peer to the device
|
||||
///
|
||||
@@ -159,48 +185,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
let packet = &msg[SIZE_MESSAGE_PREFIX..];
|
||||
|
||||
// lookup peer based on IP packet destination address
|
||||
let peer = match packet[0] >> 4 {
|
||||
VERSION_IP4 => {
|
||||
if msg.len() >= SIZE_IP4_HEADER {
|
||||
// extract IPv4 destination address
|
||||
let mut dst = [0u8; 4];
|
||||
dst.copy_from_slice(&packet[OFFSET_IP4_DST..OFFSET_IP4_DST + 4]);
|
||||
let dst = Ipv4Addr::from(dst);
|
||||
|
||||
// lookup peer (project unto and clone "value" field)
|
||||
self.state
|
||||
.ipv4
|
||||
.read()
|
||||
.longest_match(dst)
|
||||
.and_then(|(_, _, p)| p.upgrade())
|
||||
.ok_or(RouterError::NoCryptKeyRoute)
|
||||
} else {
|
||||
Err(RouterError::MalformedIPHeader)
|
||||
}
|
||||
}
|
||||
VERSION_IP6 => {
|
||||
if msg.len() >= SIZE_IP6_HEADER {
|
||||
// extract IPv6 destination address
|
||||
let mut dst = [0u8; 16];
|
||||
dst.copy_from_slice(&packet[OFFSET_IP6_DST..OFFSET_IP6_DST + 16]);
|
||||
let dst = Ipv6Addr::from(dst);
|
||||
|
||||
// lookup peer (project unto and clone "value" field)
|
||||
self.state
|
||||
.ipv6
|
||||
.read()
|
||||
.longest_match(dst)
|
||||
.and_then(|(_, _, p)| p.upgrade())
|
||||
.ok_or(RouterError::NoCryptKeyRoute)
|
||||
} else {
|
||||
Err(RouterError::MalformedIPHeader)
|
||||
}
|
||||
}
|
||||
_ => Err(RouterError::MalformedIPHeader),
|
||||
}?;
|
||||
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
|
||||
|
||||
// schedule for encryption and transmission to peer
|
||||
if let Some(job) = peer.send_job(msg) {
|
||||
debug_assert_eq!(job.1.op, Operation::Encryption);
|
||||
|
||||
// add job to worker queue
|
||||
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
self.queues[idx % self.queues.len()]
|
||||
@@ -216,17 +206,44 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - src: Source address of the packet
|
||||
/// - msg: Encrypted transport message
|
||||
pub fn recv(&self, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
// ensure that the type field access is within bounds
|
||||
if msg.len() < SIZE_MESSAGE_PREFIX || msg[0] != TYPE_TRANSPORT {
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
///
|
||||
pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
// parse / cast
|
||||
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
return Err(RouterError::MalformedTransportMessage);
|
||||
}
|
||||
|
||||
// parse / cast
|
||||
};
|
||||
let header: LayoutVerified<&[u8], TransportHeader> = header;
|
||||
debug_assert!(
|
||||
header.f_type.get() == TYPE_TRANSPORT as u32,
|
||||
"this should be checked by the message type multiplexer"
|
||||
);
|
||||
|
||||
// lookup peer based on receiver id
|
||||
let dec = self.state.recv.read();
|
||||
let dec = dec
|
||||
.get(&header.f_receiver.get())
|
||||
.ok_or(RouterError::UnkownReceiverId)?;
|
||||
|
||||
unimplemented!();
|
||||
// schedule for decryption and TUN write
|
||||
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
|
||||
debug_assert_eq!(job.1.op, Operation::Decryption);
|
||||
|
||||
// add job to worker queue
|
||||
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
self.queues[idx % self.queues.len()]
|
||||
.lock()
|
||||
.send(job)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
37
src/router/ip.rs
Normal file
37
src/router/ip.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use byteorder::BigEndian;
|
||||
use zerocopy::byteorder::U16;
|
||||
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
|
||||
|
||||
pub const SIZE_IP4_HEADER: usize = 16;
|
||||
pub const SIZE_IP6_HEADER: usize = 36;
|
||||
|
||||
pub const VERSION_IP4: u8 = 4;
|
||||
pub const VERSION_IP6: u8 = 6;
|
||||
|
||||
pub const OFFSET_IP4_SRC: usize = 12;
|
||||
pub const OFFSET_IP6_SRC: usize = 8;
|
||||
|
||||
pub const OFFSET_IP4_DST: usize = 16;
|
||||
pub const OFFSET_IP6_DST: usize = 24;
|
||||
|
||||
pub const TYPE_TRANSPORT: u8 = 4;
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct IPv4Header {
|
||||
_f_space1: [u8; 2],
|
||||
pub f_total_len: U16<BigEndian>,
|
||||
_f_space2: [u8; 8],
|
||||
pub f_source: [u8; 4],
|
||||
pub f_destination: [u8; 4],
|
||||
}
|
||||
|
||||
#[repr(packed)]
|
||||
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||
pub struct IPv6Header {
|
||||
_f_pre: [u8; 4],
|
||||
pub f_len: U16<BigEndian>,
|
||||
_f_space2: [u8; 2],
|
||||
pub f_source: [u8; 16],
|
||||
pub f_destination: [u8; 16],
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
mod anti_replay;
|
||||
mod constants;
|
||||
mod device;
|
||||
mod ip;
|
||||
mod messages;
|
||||
mod peer;
|
||||
mod types;
|
||||
|
||||
@@ -30,7 +30,7 @@ use super::workers::Operation;
|
||||
use super::workers::{worker_inbound, worker_outbound};
|
||||
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
|
||||
|
||||
use super::constants::MAX_STAGED_PACKETS;
|
||||
use super::constants::*;
|
||||
use super::types::Callbacks;
|
||||
|
||||
pub struct KeyWheel {
|
||||
@@ -50,7 +50,7 @@ pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
|
||||
pub tx_bytes: AtomicU64, // transmitted bytes
|
||||
pub keys: Mutex<KeyWheel>, // key-wheel
|
||||
pub ekey: Mutex<Option<EncryptionState>>, // encryption state
|
||||
pub endpoint: Mutex<Option<Arc<SocketAddr>>>,
|
||||
pub endpoint: Mutex<Option<B::Endpoint>>,
|
||||
}
|
||||
|
||||
pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
|
||||
@@ -61,7 +61,7 @@ pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
|
||||
|
||||
fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
|
||||
peer: &Arc<PeerInner<C, T, B>>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
|
||||
callback: Box<dyn Fn(A, u32) -> E>,
|
||||
) -> Vec<E>
|
||||
where
|
||||
@@ -70,18 +70,16 @@ where
|
||||
let mut res = Vec::new();
|
||||
for subnet in table.read().iter() {
|
||||
let (ip, masklen, p) = subnet;
|
||||
if let Some(p) = p.upgrade() {
|
||||
if Arc::ptr_eq(&p, &peer) {
|
||||
res.push(callback(ip, masklen))
|
||||
}
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
|
||||
peer: &Peer<C, T, B>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
|
||||
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
|
||||
) {
|
||||
let mut m = table.write();
|
||||
|
||||
@@ -89,12 +87,10 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
|
||||
let mut subnets = vec![];
|
||||
for subnet in m.iter() {
|
||||
let (ip, masklen, p) = subnet;
|
||||
if let Some(p) = p.upgrade() {
|
||||
if Arc::ptr_eq(&p, &peer.state) {
|
||||
subnets.push((ip, masklen))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// remove all key mappings
|
||||
for (ip, masklen) in subnets {
|
||||
@@ -103,6 +99,29 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptionState {
|
||||
fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
|
||||
EncryptionState {
|
||||
id: keypair.send.id,
|
||||
key: keypair.send.key,
|
||||
nonce: 0,
|
||||
death: keypair.birth + REJECT_AFTER_TIME,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
|
||||
fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> {
|
||||
DecryptionState {
|
||||
confirmed: AtomicBool::new(keypair.initiator),
|
||||
keypair: keypair.clone(),
|
||||
protector: spin::Mutex::new(AntiReplay::new()),
|
||||
peer: peer.clone(),
|
||||
death: keypair.birth + REJECT_AFTER_TIME,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
|
||||
fn drop(&mut self) {
|
||||
let peer = &self.state;
|
||||
@@ -202,12 +221,52 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
|
||||
pub fn confirm_key(&self, kp: Weak<KeyPair>) {
|
||||
// upgrade key-pair to strong reference
|
||||
pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
|
||||
// take lock and check keypair = keys.next
|
||||
let mut keys = self.keys.lock();
|
||||
let next = match keys.next.as_ref() {
|
||||
Some(next) => next,
|
||||
None => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
if !Arc::ptr_eq(&next, keypair) {
|
||||
return;
|
||||
}
|
||||
|
||||
// check it is the new unconfirmed key
|
||||
// allocate new encryption state
|
||||
let ekey = Some(EncryptionState::new(&next));
|
||||
|
||||
// rotate key-wheel
|
||||
let mut swap = None;
|
||||
mem::swap(&mut keys.next, &mut swap);
|
||||
mem::swap(&mut keys.current, &mut swap);
|
||||
mem::swap(&mut keys.previous, &mut swap);
|
||||
|
||||
// set new encryption key
|
||||
*self.ekey.lock() = ekey;
|
||||
}
|
||||
|
||||
pub fn recv_job(
|
||||
&self,
|
||||
src: B::Endpoint,
|
||||
dec: Arc<DecryptionState<C, T, B>>,
|
||||
mut msg: Vec<u8>,
|
||||
) -> Option<JobParallel> {
|
||||
let (tx, rx) = oneshot();
|
||||
let key = dec.keypair.send.key;
|
||||
match self.inbound.lock().try_send((dec, src, rx)) {
|
||||
Ok(_) => Some((
|
||||
tx,
|
||||
JobBuffer {
|
||||
msg,
|
||||
key: key,
|
||||
okay: false,
|
||||
op: Operation::Decryption,
|
||||
},
|
||||
)),
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> {
|
||||
@@ -260,7 +319,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
pub fn set_endpoint(&self, endpoint: SocketAddr) {
|
||||
*self.state.endpoint.lock() = Some(Arc::new(endpoint))
|
||||
*self.state.endpoint.lock() = Some(endpoint.into());
|
||||
}
|
||||
|
||||
/// Add a new keypair
|
||||
@@ -285,12 +344,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
// update key-wheel
|
||||
if new.initiator {
|
||||
// start using key for encryption
|
||||
*self.state.ekey.lock() = Some(EncryptionState {
|
||||
id: new.send.id,
|
||||
key: new.send.key,
|
||||
nonce: 0,
|
||||
death: new.birth + REJECT_AFTER_TIME,
|
||||
});
|
||||
*self.state.ekey.lock() = Some(EncryptionState::new(&new));
|
||||
|
||||
// move current into previous
|
||||
keys.previous = keys.current.as_ref().map(|v| v.clone());
|
||||
@@ -310,19 +364,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
recv.remove(&id);
|
||||
}
|
||||
|
||||
// map new id to keypair
|
||||
// map new id to decryption state
|
||||
debug_assert!(!recv.contains_key(&new.recv.id));
|
||||
|
||||
recv.insert(
|
||||
new.recv.id,
|
||||
DecryptionState {
|
||||
confirmed: AtomicBool::new(new.initiator),
|
||||
keypair: Arc::downgrade(&new),
|
||||
key: new.recv.key,
|
||||
protector: spin::Mutex::new(AntiReplay::new()),
|
||||
peer: Arc::downgrade(&self.state),
|
||||
death: new.birth + REJECT_AFTER_TIME,
|
||||
},
|
||||
Arc::new(DecryptionState::new(&self.state, &new)),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -345,14 +391,14 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
.device
|
||||
.ipv4
|
||||
.write()
|
||||
.insert(v4, masklen, Arc::downgrade(&self.state))
|
||||
.insert(v4, masklen, self.state.clone())
|
||||
}
|
||||
IpAddr::V6(v6) => {
|
||||
self.state
|
||||
.device
|
||||
.ipv6
|
||||
.write()
|
||||
.insert(v6, masklen, Arc::downgrade(&self.state))
|
||||
.insert(v6, masklen, self.state.clone())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -156,6 +156,8 @@ mod tests {
|
||||
|
||||
#[bench]
|
||||
fn bench_outbound(b: &mut Bencher) {
|
||||
init();
|
||||
|
||||
// type for tracking number of packets
|
||||
type Opaque = Arc<AtomicU64>;
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ pub enum RouterError {
|
||||
NoCryptKeyRoute,
|
||||
MalformedIPHeader,
|
||||
MalformedTransportMessage,
|
||||
UnkownReceiverId,
|
||||
}
|
||||
|
||||
impl fmt::Display for RouterError {
|
||||
@@ -65,6 +66,9 @@ impl fmt::Display for RouterError {
|
||||
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
|
||||
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
|
||||
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
|
||||
RouterError::UnkownReceiverId => {
|
||||
write!(f, "No decryption state associated with receiver id")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::mem;
|
||||
use std::sync::mpsc::Receiver;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::sync::oneshot;
|
||||
use futures::*;
|
||||
@@ -8,15 +8,17 @@ use futures::*;
|
||||
use log::debug;
|
||||
|
||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::atomic::Ordering;
|
||||
use zerocopy::{AsBytes, LayoutVerified};
|
||||
|
||||
use super::device::DecryptionState;
|
||||
use super::device::DeviceInner;
|
||||
use super::device::{DecryptionState, DeviceInner};
|
||||
use super::messages::TransportHeader;
|
||||
use super::peer::PeerInner;
|
||||
use super::types::Callbacks;
|
||||
|
||||
use super::ip::*;
|
||||
|
||||
use super::super::types::{Bind, Tun};
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
@@ -33,9 +35,60 @@ pub struct JobBuffer {
|
||||
}
|
||||
|
||||
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
|
||||
pub type JobInbound<C, T, B> = (Weak<DecryptionState<C, T, B>>, oneshot::Receiver<JobBuffer>);
|
||||
pub type JobInbound<C, T, B: Bind> = (
|
||||
Arc<DecryptionState<C, T, B>>,
|
||||
B::Endpoint,
|
||||
oneshot::Receiver<JobBuffer>,
|
||||
);
|
||||
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
|
||||
|
||||
#[inline(always)]
|
||||
fn check_route<C: Callbacks, T: Tun, B: Bind>(
|
||||
device: &Arc<DeviceInner<C, T, B>>,
|
||||
peer: &Arc<PeerInner<C, T, B>>,
|
||||
packet: &[u8],
|
||||
) -> Option<usize> {
|
||||
match packet[0] >> 4 {
|
||||
VERSION_IP4 => {
|
||||
// check length and cast to IPv4 header
|
||||
let (header, _) = LayoutVerified::new_from_prefix(packet)?;
|
||||
let header: LayoutVerified<&[u8], IPv4Header> = header;
|
||||
|
||||
// check IPv4 source address
|
||||
device
|
||||
.ipv4
|
||||
.read()
|
||||
.longest_match(Ipv4Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| {
|
||||
if Arc::ptr_eq(p, &peer) {
|
||||
Some(header.f_total_len.get() as usize)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
VERSION_IP6 => {
|
||||
// check length and cast to IPv6 header
|
||||
let (header, packet) = LayoutVerified::new_from_prefix(packet)?;
|
||||
let header: LayoutVerified<&[u8], IPv6Header> = header;
|
||||
|
||||
// check IPv6 source address
|
||||
device
|
||||
.ipv6
|
||||
.read()
|
||||
.longest_match(Ipv6Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| {
|
||||
if Arc::ptr_eq(p, &peer) {
|
||||
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
||||
device: Arc<DeviceInner<C, T, B>>, // related device
|
||||
peer: Arc<PeerInner<C, T, B>>, // related peer
|
||||
@@ -43,7 +96,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
||||
) {
|
||||
loop {
|
||||
// fetch job
|
||||
let (state, rx) = match receiver.recv() {
|
||||
let (state, endpoint, rx) = match receiver.recv() {
|
||||
Ok(v) => v,
|
||||
_ => {
|
||||
return;
|
||||
@@ -62,13 +115,10 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
||||
}
|
||||
};
|
||||
let header: LayoutVerified<&[u8], TransportHeader> = header;
|
||||
|
||||
// obtain strong reference to decryption state
|
||||
let state = if let Some(state) = state.upgrade() {
|
||||
state
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
debug_assert!(
|
||||
packet.len() >= 16,
|
||||
"this should be checked earlier in the pipeline"
|
||||
);
|
||||
|
||||
// check for replay
|
||||
if !state.protector.lock().update(header.f_counter.get()) {
|
||||
@@ -77,23 +127,29 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
||||
|
||||
// check for confirms key
|
||||
if !state.confirmed.swap(true, Ordering::SeqCst) {
|
||||
peer.confirm_key(state.keypair.clone());
|
||||
peer.confirm_key(&state.keypair);
|
||||
}
|
||||
|
||||
// update endpoint, TODO
|
||||
// update endpoint
|
||||
*peer.endpoint.lock() = Some(endpoint);
|
||||
|
||||
// write packet to TUN device, TODO
|
||||
// calculate length of IP packet + padding
|
||||
let length = packet.len() - CHACHA20_POLY1305.nonce_len();
|
||||
|
||||
// check if should be written to TUN
|
||||
let mut sent = false;
|
||||
if length > 0 {
|
||||
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
|
||||
debug_assert!(inner_len <= length, "should be validated");
|
||||
if inner_len <= length {
|
||||
sent = true;
|
||||
let _ = device.tun.write(&packet[..inner_len]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// trigger callback
|
||||
debug_assert!(
|
||||
packet.len() >= CHACHA20_POLY1305.nonce_len(),
|
||||
"this should be checked earlier in the pipeline"
|
||||
);
|
||||
(device.call_recv)(
|
||||
&peer.opaque,
|
||||
packet.len() > CHACHA20_POLY1305.nonce_len(),
|
||||
true,
|
||||
);
|
||||
(device.call_recv)(&peer.opaque, length == 0, sent);
|
||||
}
|
||||
})
|
||||
.wait();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> {}
|
||||
pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> + Send {}
|
||||
|
||||
impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> {}
|
||||
impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> + Send {}
|
||||
|
||||
Reference in New Issue
Block a user