Write inbound packets to TUN device
This commit is contained in:
@@ -10,8 +10,9 @@ use std::time::Instant;
|
||||
|
||||
use log::debug;
|
||||
|
||||
use spin;
|
||||
use spin::{Mutex, RwLock};
|
||||
use treebitmap::IpLookupTable;
|
||||
use zerocopy::LayoutVerified;
|
||||
|
||||
use super::super::types::{Bind, KeyPair, Tun};
|
||||
|
||||
@@ -20,23 +21,15 @@ use super::peer;
|
||||
use super::peer::{Peer, PeerInner};
|
||||
use super::SIZE_MESSAGE_PREFIX;
|
||||
|
||||
use super::constants::WORKER_QUEUE_SIZE;
|
||||
use super::messages::TYPE_TRANSPORT;
|
||||
use super::constants::*;
|
||||
use super::ip::*;
|
||||
|
||||
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||
use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
|
||||
use super::workers::{worker_parallel, JobParallel};
|
||||
|
||||
// minimum sizes for IP headers
|
||||
const SIZE_IP4_HEADER: usize = 16;
|
||||
const SIZE_IP6_HEADER: usize = 36;
|
||||
|
||||
const VERSION_IP4: u8 = 4;
|
||||
const VERSION_IP6: u8 = 6;
|
||||
|
||||
const OFFSET_IP4_DST: usize = 16;
|
||||
const OFFSET_IP6_DST: usize = 24;
|
||||
use super::workers::{worker_parallel, JobParallel, Operation};
|
||||
|
||||
pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
|
||||
// IO & timer generics
|
||||
// IO & timer callbacks
|
||||
pub tun: T,
|
||||
pub bind: B,
|
||||
pub call_recv: C::CallbackRecv,
|
||||
@@ -44,9 +37,9 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
|
||||
pub call_need_key: C::CallbackKey,
|
||||
|
||||
// routing
|
||||
pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T, B>>>, // receiver id -> decryption state
|
||||
pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
|
||||
pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
|
||||
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
|
||||
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
|
||||
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
|
||||
}
|
||||
|
||||
pub struct EncryptionState {
|
||||
@@ -57,19 +50,18 @@ pub struct EncryptionState {
|
||||
}
|
||||
|
||||
pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
|
||||
pub key: [u8; 32],
|
||||
pub keypair: Weak<KeyPair>, // only the key-wheel has a strong reference
|
||||
pub keypair: Arc<KeyPair>,
|
||||
pub confirmed: AtomicBool,
|
||||
pub protector: spin::Mutex<AntiReplay>,
|
||||
pub peer: Weak<PeerInner<C, T, B>>,
|
||||
pub protector: Mutex<AntiReplay>,
|
||||
pub peer: Arc<PeerInner<C, T, B>>,
|
||||
pub death: Instant, // time when the key can no longer be used for decryption
|
||||
}
|
||||
|
||||
pub struct Device<C: Callbacks, T: Tun, B: Bind> {
|
||||
pub state: Arc<DeviceInner<C, T, B>>, // reference to device state
|
||||
pub handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||
pub queue_next: AtomicUsize, // next round-robin index
|
||||
pub queues: Vec<spin::Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
||||
state: Arc<DeviceInner<C, T, B>>, // reference to device state
|
||||
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||
queue_next: AtomicUsize, // next round-robin index
|
||||
queues: Vec<Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
|
||||
@@ -109,9 +101,9 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
|
||||
call_recv,
|
||||
call_send,
|
||||
call_need_key,
|
||||
recv: spin::RwLock::new(HashMap::new()),
|
||||
ipv4: spin::RwLock::new(IpLookupTable::new()),
|
||||
ipv6: spin::RwLock::new(IpLookupTable::new()),
|
||||
recv: RwLock::new(HashMap::new()),
|
||||
ipv4: RwLock::new(IpLookupTable::new()),
|
||||
ipv6: RwLock::new(IpLookupTable::new()),
|
||||
});
|
||||
|
||||
// start worker threads
|
||||
@@ -119,7 +111,7 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
|
||||
let mut threads = Vec::with_capacity(num_workers);
|
||||
for _ in 0..num_workers {
|
||||
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
|
||||
queues.push(spin::Mutex::new(tx));
|
||||
queues.push(Mutex::new(tx));
|
||||
threads.push(thread::spawn(move || worker_parallel(rx)));
|
||||
}
|
||||
|
||||
@@ -133,6 +125,40 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_route<C: Callbacks, T: Tun, B: Bind>(
|
||||
device: &Arc<DeviceInner<C, T, B>>,
|
||||
packet: &[u8],
|
||||
) -> Option<Arc<PeerInner<C, T, B>>> {
|
||||
match packet[0] >> 4 {
|
||||
VERSION_IP4 => {
|
||||
// check length and cast to IPv4 header
|
||||
let (header, _) = LayoutVerified::new_from_prefix(packet)?;
|
||||
let header: LayoutVerified<&[u8], IPv4Header> = header;
|
||||
|
||||
// check IPv4 source address
|
||||
device
|
||||
.ipv4
|
||||
.read()
|
||||
.longest_match(Ipv4Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| Some(p.clone()))
|
||||
}
|
||||
VERSION_IP6 => {
|
||||
// check length and cast to IPv6 header
|
||||
let (header, packet) = LayoutVerified::new_from_prefix(packet)?;
|
||||
let header: LayoutVerified<&[u8], IPv6Header> = header;
|
||||
|
||||
// check IPv6 source address
|
||||
device
|
||||
.ipv6
|
||||
.read()
|
||||
.longest_match(Ipv6Addr::from(header.f_source))
|
||||
.and_then(|(_, _, p)| Some(p.clone()))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
/// Adds a new peer to the device
|
||||
///
|
||||
@@ -159,48 +185,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
let packet = &msg[SIZE_MESSAGE_PREFIX..];
|
||||
|
||||
// lookup peer based on IP packet destination address
|
||||
let peer = match packet[0] >> 4 {
|
||||
VERSION_IP4 => {
|
||||
if msg.len() >= SIZE_IP4_HEADER {
|
||||
// extract IPv4 destination address
|
||||
let mut dst = [0u8; 4];
|
||||
dst.copy_from_slice(&packet[OFFSET_IP4_DST..OFFSET_IP4_DST + 4]);
|
||||
let dst = Ipv4Addr::from(dst);
|
||||
|
||||
// lookup peer (project unto and clone "value" field)
|
||||
self.state
|
||||
.ipv4
|
||||
.read()
|
||||
.longest_match(dst)
|
||||
.and_then(|(_, _, p)| p.upgrade())
|
||||
.ok_or(RouterError::NoCryptKeyRoute)
|
||||
} else {
|
||||
Err(RouterError::MalformedIPHeader)
|
||||
}
|
||||
}
|
||||
VERSION_IP6 => {
|
||||
if msg.len() >= SIZE_IP6_HEADER {
|
||||
// extract IPv6 destination address
|
||||
let mut dst = [0u8; 16];
|
||||
dst.copy_from_slice(&packet[OFFSET_IP6_DST..OFFSET_IP6_DST + 16]);
|
||||
let dst = Ipv6Addr::from(dst);
|
||||
|
||||
// lookup peer (project unto and clone "value" field)
|
||||
self.state
|
||||
.ipv6
|
||||
.read()
|
||||
.longest_match(dst)
|
||||
.and_then(|(_, _, p)| p.upgrade())
|
||||
.ok_or(RouterError::NoCryptKeyRoute)
|
||||
} else {
|
||||
Err(RouterError::MalformedIPHeader)
|
||||
}
|
||||
}
|
||||
_ => Err(RouterError::MalformedIPHeader),
|
||||
}?;
|
||||
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
|
||||
|
||||
// schedule for encryption and transmission to peer
|
||||
if let Some(job) = peer.send_job(msg) {
|
||||
debug_assert_eq!(job.1.op, Operation::Encryption);
|
||||
|
||||
// add job to worker queue
|
||||
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
self.queues[idx % self.queues.len()]
|
||||
@@ -216,17 +206,44 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - src: Source address of the packet
|
||||
/// - msg: Encrypted transport message
|
||||
pub fn recv(&self, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
// ensure that the type field access is within bounds
|
||||
if msg.len() < SIZE_MESSAGE_PREFIX || msg[0] != TYPE_TRANSPORT {
|
||||
return Err(RouterError::MalformedTransportMessage);
|
||||
}
|
||||
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
///
|
||||
pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||
// parse / cast
|
||||
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
return Err(RouterError::MalformedTransportMessage);
|
||||
}
|
||||
};
|
||||
let header: LayoutVerified<&[u8], TransportHeader> = header;
|
||||
debug_assert!(
|
||||
header.f_type.get() == TYPE_TRANSPORT as u32,
|
||||
"this should be checked by the message type multiplexer"
|
||||
);
|
||||
|
||||
// lookup peer based on receiver id
|
||||
let dec = self.state.recv.read();
|
||||
let dec = dec
|
||||
.get(&header.f_receiver.get())
|
||||
.ok_or(RouterError::UnkownReceiverId)?;
|
||||
|
||||
unimplemented!();
|
||||
// schedule for decryption and TUN write
|
||||
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
|
||||
debug_assert_eq!(job.1.op, Operation::Decryption);
|
||||
|
||||
// add job to worker queue
|
||||
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||
self.queues[idx % self.queues.len()]
|
||||
.lock()
|
||||
.send(job)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user