Write inbound packets to TUN device
This commit is contained in:
@@ -1,2 +1,7 @@
|
|||||||
|
// WireGuard semantics constants
|
||||||
|
|
||||||
pub const MAX_STAGED_PACKETS: usize = 128;
|
pub const MAX_STAGED_PACKETS: usize = 128;
|
||||||
|
|
||||||
|
// performance constants
|
||||||
|
|
||||||
pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS;
|
pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS;
|
||||||
|
|||||||
@@ -10,8 +10,9 @@ use std::time::Instant;
|
|||||||
|
|
||||||
use log::debug;
|
use log::debug;
|
||||||
|
|
||||||
use spin;
|
use spin::{Mutex, RwLock};
|
||||||
use treebitmap::IpLookupTable;
|
use treebitmap::IpLookupTable;
|
||||||
|
use zerocopy::LayoutVerified;
|
||||||
|
|
||||||
use super::super::types::{Bind, KeyPair, Tun};
|
use super::super::types::{Bind, KeyPair, Tun};
|
||||||
|
|
||||||
@@ -20,23 +21,15 @@ use super::peer;
|
|||||||
use super::peer::{Peer, PeerInner};
|
use super::peer::{Peer, PeerInner};
|
||||||
use super::SIZE_MESSAGE_PREFIX;
|
use super::SIZE_MESSAGE_PREFIX;
|
||||||
|
|
||||||
use super::constants::WORKER_QUEUE_SIZE;
|
use super::constants::*;
|
||||||
use super::messages::TYPE_TRANSPORT;
|
use super::ip::*;
|
||||||
|
|
||||||
|
use super::messages::{TransportHeader, TYPE_TRANSPORT};
|
||||||
use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
|
use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
|
||||||
use super::workers::{worker_parallel, JobParallel};
|
use super::workers::{worker_parallel, JobParallel, Operation};
|
||||||
|
|
||||||
// minimum sizes for IP headers
|
|
||||||
const SIZE_IP4_HEADER: usize = 16;
|
|
||||||
const SIZE_IP6_HEADER: usize = 36;
|
|
||||||
|
|
||||||
const VERSION_IP4: u8 = 4;
|
|
||||||
const VERSION_IP6: u8 = 6;
|
|
||||||
|
|
||||||
const OFFSET_IP4_DST: usize = 16;
|
|
||||||
const OFFSET_IP6_DST: usize = 24;
|
|
||||||
|
|
||||||
pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
|
pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
|
||||||
// IO & timer generics
|
// IO & timer callbacks
|
||||||
pub tun: T,
|
pub tun: T,
|
||||||
pub bind: B,
|
pub bind: B,
|
||||||
pub call_recv: C::CallbackRecv,
|
pub call_recv: C::CallbackRecv,
|
||||||
@@ -44,9 +37,9 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
|
|||||||
pub call_need_key: C::CallbackKey,
|
pub call_need_key: C::CallbackKey,
|
||||||
|
|
||||||
// routing
|
// routing
|
||||||
pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T, B>>>, // receiver id -> decryption state
|
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
|
||||||
pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
|
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
|
||||||
pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
|
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct EncryptionState {
|
pub struct EncryptionState {
|
||||||
@@ -57,19 +50,18 @@ pub struct EncryptionState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
|
pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
|
||||||
pub key: [u8; 32],
|
pub keypair: Arc<KeyPair>,
|
||||||
pub keypair: Weak<KeyPair>, // only the key-wheel has a strong reference
|
|
||||||
pub confirmed: AtomicBool,
|
pub confirmed: AtomicBool,
|
||||||
pub protector: spin::Mutex<AntiReplay>,
|
pub protector: Mutex<AntiReplay>,
|
||||||
pub peer: Weak<PeerInner<C, T, B>>,
|
pub peer: Arc<PeerInner<C, T, B>>,
|
||||||
pub death: Instant, // time when the key can no longer be used for decryption
|
pub death: Instant, // time when the key can no longer be used for decryption
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Device<C: Callbacks, T: Tun, B: Bind> {
|
pub struct Device<C: Callbacks, T: Tun, B: Bind> {
|
||||||
pub state: Arc<DeviceInner<C, T, B>>, // reference to device state
|
state: Arc<DeviceInner<C, T, B>>, // reference to device state
|
||||||
pub handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
|
||||||
pub queue_next: AtomicUsize, // next round-robin index
|
queue_next: AtomicUsize, // next round-robin index
|
||||||
pub queues: Vec<spin::Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
queues: Vec<Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
|
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
|
||||||
@@ -109,9 +101,9 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
|
|||||||
call_recv,
|
call_recv,
|
||||||
call_send,
|
call_send,
|
||||||
call_need_key,
|
call_need_key,
|
||||||
recv: spin::RwLock::new(HashMap::new()),
|
recv: RwLock::new(HashMap::new()),
|
||||||
ipv4: spin::RwLock::new(IpLookupTable::new()),
|
ipv4: RwLock::new(IpLookupTable::new()),
|
||||||
ipv6: spin::RwLock::new(IpLookupTable::new()),
|
ipv6: RwLock::new(IpLookupTable::new()),
|
||||||
});
|
});
|
||||||
|
|
||||||
// start worker threads
|
// start worker threads
|
||||||
@@ -119,7 +111,7 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
|
|||||||
let mut threads = Vec::with_capacity(num_workers);
|
let mut threads = Vec::with_capacity(num_workers);
|
||||||
for _ in 0..num_workers {
|
for _ in 0..num_workers {
|
||||||
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
|
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
|
||||||
queues.push(spin::Mutex::new(tx));
|
queues.push(Mutex::new(tx));
|
||||||
threads.push(thread::spawn(move || worker_parallel(rx)));
|
threads.push(thread::spawn(move || worker_parallel(rx)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,6 +125,40 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn get_route<C: Callbacks, T: Tun, B: Bind>(
|
||||||
|
device: &Arc<DeviceInner<C, T, B>>,
|
||||||
|
packet: &[u8],
|
||||||
|
) -> Option<Arc<PeerInner<C, T, B>>> {
|
||||||
|
match packet[0] >> 4 {
|
||||||
|
VERSION_IP4 => {
|
||||||
|
// check length and cast to IPv4 header
|
||||||
|
let (header, _) = LayoutVerified::new_from_prefix(packet)?;
|
||||||
|
let header: LayoutVerified<&[u8], IPv4Header> = header;
|
||||||
|
|
||||||
|
// check IPv4 source address
|
||||||
|
device
|
||||||
|
.ipv4
|
||||||
|
.read()
|
||||||
|
.longest_match(Ipv4Addr::from(header.f_source))
|
||||||
|
.and_then(|(_, _, p)| Some(p.clone()))
|
||||||
|
}
|
||||||
|
VERSION_IP6 => {
|
||||||
|
// check length and cast to IPv6 header
|
||||||
|
let (header, packet) = LayoutVerified::new_from_prefix(packet)?;
|
||||||
|
let header: LayoutVerified<&[u8], IPv6Header> = header;
|
||||||
|
|
||||||
|
// check IPv6 source address
|
||||||
|
device
|
||||||
|
.ipv6
|
||||||
|
.read()
|
||||||
|
.longest_match(Ipv6Addr::from(header.f_source))
|
||||||
|
.and_then(|(_, _, p)| Some(p.clone()))
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||||
/// Adds a new peer to the device
|
/// Adds a new peer to the device
|
||||||
///
|
///
|
||||||
@@ -159,48 +185,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
|||||||
let packet = &msg[SIZE_MESSAGE_PREFIX..];
|
let packet = &msg[SIZE_MESSAGE_PREFIX..];
|
||||||
|
|
||||||
// lookup peer based on IP packet destination address
|
// lookup peer based on IP packet destination address
|
||||||
let peer = match packet[0] >> 4 {
|
let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
|
||||||
VERSION_IP4 => {
|
|
||||||
if msg.len() >= SIZE_IP4_HEADER {
|
|
||||||
// extract IPv4 destination address
|
|
||||||
let mut dst = [0u8; 4];
|
|
||||||
dst.copy_from_slice(&packet[OFFSET_IP4_DST..OFFSET_IP4_DST + 4]);
|
|
||||||
let dst = Ipv4Addr::from(dst);
|
|
||||||
|
|
||||||
// lookup peer (project unto and clone "value" field)
|
|
||||||
self.state
|
|
||||||
.ipv4
|
|
||||||
.read()
|
|
||||||
.longest_match(dst)
|
|
||||||
.and_then(|(_, _, p)| p.upgrade())
|
|
||||||
.ok_or(RouterError::NoCryptKeyRoute)
|
|
||||||
} else {
|
|
||||||
Err(RouterError::MalformedIPHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
VERSION_IP6 => {
|
|
||||||
if msg.len() >= SIZE_IP6_HEADER {
|
|
||||||
// extract IPv6 destination address
|
|
||||||
let mut dst = [0u8; 16];
|
|
||||||
dst.copy_from_slice(&packet[OFFSET_IP6_DST..OFFSET_IP6_DST + 16]);
|
|
||||||
let dst = Ipv6Addr::from(dst);
|
|
||||||
|
|
||||||
// lookup peer (project unto and clone "value" field)
|
|
||||||
self.state
|
|
||||||
.ipv6
|
|
||||||
.read()
|
|
||||||
.longest_match(dst)
|
|
||||||
.and_then(|(_, _, p)| p.upgrade())
|
|
||||||
.ok_or(RouterError::NoCryptKeyRoute)
|
|
||||||
} else {
|
|
||||||
Err(RouterError::MalformedIPHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => Err(RouterError::MalformedIPHeader),
|
|
||||||
}?;
|
|
||||||
|
|
||||||
// schedule for encryption and transmission to peer
|
// schedule for encryption and transmission to peer
|
||||||
if let Some(job) = peer.send_job(msg) {
|
if let Some(job) = peer.send_job(msg) {
|
||||||
|
debug_assert_eq!(job.1.op, Operation::Encryption);
|
||||||
|
|
||||||
// add job to worker queue
|
// add job to worker queue
|
||||||
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
|
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||||
self.queues[idx % self.queues.len()]
|
self.queues[idx % self.queues.len()]
|
||||||
@@ -216,17 +206,44 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
|||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
|
/// - src: Source address of the packet
|
||||||
/// - msg: Encrypted transport message
|
/// - msg: Encrypted transport message
|
||||||
pub fn recv(&self, msg: Vec<u8>) -> Result<(), RouterError> {
|
///
|
||||||
// ensure that the type field access is within bounds
|
/// # Returns
|
||||||
if msg.len() < SIZE_MESSAGE_PREFIX || msg[0] != TYPE_TRANSPORT {
|
///
|
||||||
return Err(RouterError::MalformedTransportMessage);
|
///
|
||||||
}
|
pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> {
|
||||||
|
|
||||||
// parse / cast
|
// parse / cast
|
||||||
|
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
|
||||||
|
Some(v) => v,
|
||||||
|
None => {
|
||||||
|
return Err(RouterError::MalformedTransportMessage);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let header: LayoutVerified<&[u8], TransportHeader> = header;
|
||||||
|
debug_assert!(
|
||||||
|
header.f_type.get() == TYPE_TRANSPORT as u32,
|
||||||
|
"this should be checked by the message type multiplexer"
|
||||||
|
);
|
||||||
|
|
||||||
// lookup peer based on receiver id
|
// lookup peer based on receiver id
|
||||||
|
let dec = self.state.recv.read();
|
||||||
|
let dec = dec
|
||||||
|
.get(&header.f_receiver.get())
|
||||||
|
.ok_or(RouterError::UnkownReceiverId)?;
|
||||||
|
|
||||||
unimplemented!();
|
// schedule for decryption and TUN write
|
||||||
|
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
|
||||||
|
debug_assert_eq!(job.1.op, Operation::Decryption);
|
||||||
|
|
||||||
|
// add job to worker queue
|
||||||
|
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
|
||||||
|
self.queues[idx % self.queues.len()]
|
||||||
|
.lock()
|
||||||
|
.send(job)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
37
src/router/ip.rs
Normal file
37
src/router/ip.rs
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
use byteorder::BigEndian;
|
||||||
|
use zerocopy::byteorder::U16;
|
||||||
|
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
|
||||||
|
|
||||||
|
pub const SIZE_IP4_HEADER: usize = 16;
|
||||||
|
pub const SIZE_IP6_HEADER: usize = 36;
|
||||||
|
|
||||||
|
pub const VERSION_IP4: u8 = 4;
|
||||||
|
pub const VERSION_IP6: u8 = 6;
|
||||||
|
|
||||||
|
pub const OFFSET_IP4_SRC: usize = 12;
|
||||||
|
pub const OFFSET_IP6_SRC: usize = 8;
|
||||||
|
|
||||||
|
pub const OFFSET_IP4_DST: usize = 16;
|
||||||
|
pub const OFFSET_IP6_DST: usize = 24;
|
||||||
|
|
||||||
|
pub const TYPE_TRANSPORT: u8 = 4;
|
||||||
|
|
||||||
|
#[repr(packed)]
|
||||||
|
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||||
|
pub struct IPv4Header {
|
||||||
|
_f_space1: [u8; 2],
|
||||||
|
pub f_total_len: U16<BigEndian>,
|
||||||
|
_f_space2: [u8; 8],
|
||||||
|
pub f_source: [u8; 4],
|
||||||
|
pub f_destination: [u8; 4],
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(packed)]
|
||||||
|
#[derive(Copy, Clone, FromBytes, AsBytes)]
|
||||||
|
pub struct IPv6Header {
|
||||||
|
_f_pre: [u8; 4],
|
||||||
|
pub f_len: U16<BigEndian>,
|
||||||
|
_f_space2: [u8; 2],
|
||||||
|
pub f_source: [u8; 16],
|
||||||
|
pub f_destination: [u8; 16],
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
mod anti_replay;
|
mod anti_replay;
|
||||||
mod constants;
|
mod constants;
|
||||||
mod device;
|
mod device;
|
||||||
|
mod ip;
|
||||||
mod messages;
|
mod messages;
|
||||||
mod peer;
|
mod peer;
|
||||||
mod types;
|
mod types;
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ use super::workers::Operation;
|
|||||||
use super::workers::{worker_inbound, worker_outbound};
|
use super::workers::{worker_inbound, worker_outbound};
|
||||||
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
|
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
|
||||||
|
|
||||||
use super::constants::MAX_STAGED_PACKETS;
|
use super::constants::*;
|
||||||
use super::types::Callbacks;
|
use super::types::Callbacks;
|
||||||
|
|
||||||
pub struct KeyWheel {
|
pub struct KeyWheel {
|
||||||
@@ -50,7 +50,7 @@ pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
|
|||||||
pub tx_bytes: AtomicU64, // transmitted bytes
|
pub tx_bytes: AtomicU64, // transmitted bytes
|
||||||
pub keys: Mutex<KeyWheel>, // key-wheel
|
pub keys: Mutex<KeyWheel>, // key-wheel
|
||||||
pub ekey: Mutex<Option<EncryptionState>>, // encryption state
|
pub ekey: Mutex<Option<EncryptionState>>, // encryption state
|
||||||
pub endpoint: Mutex<Option<Arc<SocketAddr>>>,
|
pub endpoint: Mutex<Option<B::Endpoint>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
|
pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
|
||||||
@@ -61,7 +61,7 @@ pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
|
|||||||
|
|
||||||
fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
|
fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
|
||||||
peer: &Arc<PeerInner<C, T, B>>,
|
peer: &Arc<PeerInner<C, T, B>>,
|
||||||
table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
|
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
|
||||||
callback: Box<dyn Fn(A, u32) -> E>,
|
callback: Box<dyn Fn(A, u32) -> E>,
|
||||||
) -> Vec<E>
|
) -> Vec<E>
|
||||||
where
|
where
|
||||||
@@ -70,10 +70,8 @@ where
|
|||||||
let mut res = Vec::new();
|
let mut res = Vec::new();
|
||||||
for subnet in table.read().iter() {
|
for subnet in table.read().iter() {
|
||||||
let (ip, masklen, p) = subnet;
|
let (ip, masklen, p) = subnet;
|
||||||
if let Some(p) = p.upgrade() {
|
if Arc::ptr_eq(&p, &peer) {
|
||||||
if Arc::ptr_eq(&p, &peer) {
|
res.push(callback(ip, masklen))
|
||||||
res.push(callback(ip, masklen))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
res
|
res
|
||||||
@@ -81,7 +79,7 @@ where
|
|||||||
|
|
||||||
fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
|
fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
|
||||||
peer: &Peer<C, T, B>,
|
peer: &Peer<C, T, B>,
|
||||||
table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
|
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
|
||||||
) {
|
) {
|
||||||
let mut m = table.write();
|
let mut m = table.write();
|
||||||
|
|
||||||
@@ -89,10 +87,8 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
|
|||||||
let mut subnets = vec![];
|
let mut subnets = vec![];
|
||||||
for subnet in m.iter() {
|
for subnet in m.iter() {
|
||||||
let (ip, masklen, p) = subnet;
|
let (ip, masklen, p) = subnet;
|
||||||
if let Some(p) = p.upgrade() {
|
if Arc::ptr_eq(&p, &peer.state) {
|
||||||
if Arc::ptr_eq(&p, &peer.state) {
|
subnets.push((ip, masklen))
|
||||||
subnets.push((ip, masklen))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,6 +99,29 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl EncryptionState {
|
||||||
|
fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
|
||||||
|
EncryptionState {
|
||||||
|
id: keypair.send.id,
|
||||||
|
key: keypair.send.key,
|
||||||
|
nonce: 0,
|
||||||
|
death: keypair.birth + REJECT_AFTER_TIME,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
|
||||||
|
fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> {
|
||||||
|
DecryptionState {
|
||||||
|
confirmed: AtomicBool::new(keypair.initiator),
|
||||||
|
keypair: keypair.clone(),
|
||||||
|
protector: spin::Mutex::new(AntiReplay::new()),
|
||||||
|
peer: peer.clone(),
|
||||||
|
death: keypair.birth + REJECT_AFTER_TIME,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
|
impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
let peer = &self.state;
|
let peer = &self.state;
|
||||||
@@ -202,12 +221,52 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
|
impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
|
||||||
pub fn confirm_key(&self, kp: Weak<KeyPair>) {
|
pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
|
||||||
// upgrade key-pair to strong reference
|
// take lock and check keypair = keys.next
|
||||||
|
let mut keys = self.keys.lock();
|
||||||
|
let next = match keys.next.as_ref() {
|
||||||
|
Some(next) => next,
|
||||||
|
None => {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if !Arc::ptr_eq(&next, keypair) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// check it is the new unconfirmed key
|
// allocate new encryption state
|
||||||
|
let ekey = Some(EncryptionState::new(&next));
|
||||||
|
|
||||||
// rotate key-wheel
|
// rotate key-wheel
|
||||||
|
let mut swap = None;
|
||||||
|
mem::swap(&mut keys.next, &mut swap);
|
||||||
|
mem::swap(&mut keys.current, &mut swap);
|
||||||
|
mem::swap(&mut keys.previous, &mut swap);
|
||||||
|
|
||||||
|
// set new encryption key
|
||||||
|
*self.ekey.lock() = ekey;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn recv_job(
|
||||||
|
&self,
|
||||||
|
src: B::Endpoint,
|
||||||
|
dec: Arc<DecryptionState<C, T, B>>,
|
||||||
|
mut msg: Vec<u8>,
|
||||||
|
) -> Option<JobParallel> {
|
||||||
|
let (tx, rx) = oneshot();
|
||||||
|
let key = dec.keypair.send.key;
|
||||||
|
match self.inbound.lock().try_send((dec, src, rx)) {
|
||||||
|
Ok(_) => Some((
|
||||||
|
tx,
|
||||||
|
JobBuffer {
|
||||||
|
msg,
|
||||||
|
key: key,
|
||||||
|
okay: false,
|
||||||
|
op: Operation::Decryption,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
Err(_) => None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> {
|
pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> {
|
||||||
@@ -260,7 +319,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
|
|||||||
|
|
||||||
impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||||
pub fn set_endpoint(&self, endpoint: SocketAddr) {
|
pub fn set_endpoint(&self, endpoint: SocketAddr) {
|
||||||
*self.state.endpoint.lock() = Some(Arc::new(endpoint))
|
*self.state.endpoint.lock() = Some(endpoint.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a new keypair
|
/// Add a new keypair
|
||||||
@@ -285,12 +344,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
|||||||
// update key-wheel
|
// update key-wheel
|
||||||
if new.initiator {
|
if new.initiator {
|
||||||
// start using key for encryption
|
// start using key for encryption
|
||||||
*self.state.ekey.lock() = Some(EncryptionState {
|
*self.state.ekey.lock() = Some(EncryptionState::new(&new));
|
||||||
id: new.send.id,
|
|
||||||
key: new.send.key,
|
|
||||||
nonce: 0,
|
|
||||||
death: new.birth + REJECT_AFTER_TIME,
|
|
||||||
});
|
|
||||||
|
|
||||||
// move current into previous
|
// move current into previous
|
||||||
keys.previous = keys.current.as_ref().map(|v| v.clone());
|
keys.previous = keys.current.as_ref().map(|v| v.clone());
|
||||||
@@ -310,19 +364,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
|||||||
recv.remove(&id);
|
recv.remove(&id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// map new id to keypair
|
// map new id to decryption state
|
||||||
debug_assert!(!recv.contains_key(&new.recv.id));
|
debug_assert!(!recv.contains_key(&new.recv.id));
|
||||||
|
|
||||||
recv.insert(
|
recv.insert(
|
||||||
new.recv.id,
|
new.recv.id,
|
||||||
DecryptionState {
|
Arc::new(DecryptionState::new(&self.state, &new)),
|
||||||
confirmed: AtomicBool::new(new.initiator),
|
|
||||||
keypair: Arc::downgrade(&new),
|
|
||||||
key: new.recv.key,
|
|
||||||
protector: spin::Mutex::new(AntiReplay::new()),
|
|
||||||
peer: Arc::downgrade(&self.state),
|
|
||||||
death: new.birth + REJECT_AFTER_TIME,
|
|
||||||
},
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -345,14 +391,14 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
|||||||
.device
|
.device
|
||||||
.ipv4
|
.ipv4
|
||||||
.write()
|
.write()
|
||||||
.insert(v4, masklen, Arc::downgrade(&self.state))
|
.insert(v4, masklen, self.state.clone())
|
||||||
}
|
}
|
||||||
IpAddr::V6(v6) => {
|
IpAddr::V6(v6) => {
|
||||||
self.state
|
self.state
|
||||||
.device
|
.device
|
||||||
.ipv6
|
.ipv6
|
||||||
.write()
|
.write()
|
||||||
.insert(v6, masklen, Arc::downgrade(&self.state))
|
.insert(v6, masklen, self.state.clone())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -156,6 +156,8 @@ mod tests {
|
|||||||
|
|
||||||
#[bench]
|
#[bench]
|
||||||
fn bench_outbound(b: &mut Bencher) {
|
fn bench_outbound(b: &mut Bencher) {
|
||||||
|
init();
|
||||||
|
|
||||||
// type for tracking number of packets
|
// type for tracking number of packets
|
||||||
type Opaque = Arc<AtomicU64>;
|
type Opaque = Arc<AtomicU64>;
|
||||||
|
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ pub enum RouterError {
|
|||||||
NoCryptKeyRoute,
|
NoCryptKeyRoute,
|
||||||
MalformedIPHeader,
|
MalformedIPHeader,
|
||||||
MalformedTransportMessage,
|
MalformedTransportMessage,
|
||||||
|
UnkownReceiverId,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for RouterError {
|
impl fmt::Display for RouterError {
|
||||||
@@ -65,6 +66,9 @@ impl fmt::Display for RouterError {
|
|||||||
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
|
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
|
||||||
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
|
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
|
||||||
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
|
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
|
||||||
|
RouterError::UnkownReceiverId => {
|
||||||
|
write!(f, "No decryption state associated with receiver id")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use std::mem;
|
use std::mem;
|
||||||
use std::sync::mpsc::Receiver;
|
use std::sync::mpsc::Receiver;
|
||||||
use std::sync::{Arc, Weak};
|
use std::sync::Arc;
|
||||||
|
|
||||||
use futures::sync::oneshot;
|
use futures::sync::oneshot;
|
||||||
use futures::*;
|
use futures::*;
|
||||||
@@ -8,15 +8,17 @@ use futures::*;
|
|||||||
use log::debug;
|
use log::debug;
|
||||||
|
|
||||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||||
|
use std::sync::atomic::Ordering;
|
||||||
use zerocopy::{AsBytes, LayoutVerified};
|
use zerocopy::{AsBytes, LayoutVerified};
|
||||||
|
|
||||||
use super::device::DecryptionState;
|
use super::device::{DecryptionState, DeviceInner};
|
||||||
use super::device::DeviceInner;
|
|
||||||
use super::messages::TransportHeader;
|
use super::messages::TransportHeader;
|
||||||
use super::peer::PeerInner;
|
use super::peer::PeerInner;
|
||||||
use super::types::Callbacks;
|
use super::types::Callbacks;
|
||||||
|
|
||||||
|
use super::ip::*;
|
||||||
|
|
||||||
use super::super::types::{Bind, Tun};
|
use super::super::types::{Bind, Tun};
|
||||||
|
|
||||||
#[derive(PartialEq, Debug)]
|
#[derive(PartialEq, Debug)]
|
||||||
@@ -33,9 +35,60 @@ pub struct JobBuffer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
|
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
|
||||||
pub type JobInbound<C, T, B> = (Weak<DecryptionState<C, T, B>>, oneshot::Receiver<JobBuffer>);
|
pub type JobInbound<C, T, B: Bind> = (
|
||||||
|
Arc<DecryptionState<C, T, B>>,
|
||||||
|
B::Endpoint,
|
||||||
|
oneshot::Receiver<JobBuffer>,
|
||||||
|
);
|
||||||
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
|
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn check_route<C: Callbacks, T: Tun, B: Bind>(
|
||||||
|
device: &Arc<DeviceInner<C, T, B>>,
|
||||||
|
peer: &Arc<PeerInner<C, T, B>>,
|
||||||
|
packet: &[u8],
|
||||||
|
) -> Option<usize> {
|
||||||
|
match packet[0] >> 4 {
|
||||||
|
VERSION_IP4 => {
|
||||||
|
// check length and cast to IPv4 header
|
||||||
|
let (header, _) = LayoutVerified::new_from_prefix(packet)?;
|
||||||
|
let header: LayoutVerified<&[u8], IPv4Header> = header;
|
||||||
|
|
||||||
|
// check IPv4 source address
|
||||||
|
device
|
||||||
|
.ipv4
|
||||||
|
.read()
|
||||||
|
.longest_match(Ipv4Addr::from(header.f_source))
|
||||||
|
.and_then(|(_, _, p)| {
|
||||||
|
if Arc::ptr_eq(p, &peer) {
|
||||||
|
Some(header.f_total_len.get() as usize)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
VERSION_IP6 => {
|
||||||
|
// check length and cast to IPv6 header
|
||||||
|
let (header, packet) = LayoutVerified::new_from_prefix(packet)?;
|
||||||
|
let header: LayoutVerified<&[u8], IPv6Header> = header;
|
||||||
|
|
||||||
|
// check IPv6 source address
|
||||||
|
device
|
||||||
|
.ipv6
|
||||||
|
.read()
|
||||||
|
.longest_match(Ipv6Addr::from(header.f_source))
|
||||||
|
.and_then(|(_, _, p)| {
|
||||||
|
if Arc::ptr_eq(p, &peer) {
|
||||||
|
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
||||||
device: Arc<DeviceInner<C, T, B>>, // related device
|
device: Arc<DeviceInner<C, T, B>>, // related device
|
||||||
peer: Arc<PeerInner<C, T, B>>, // related peer
|
peer: Arc<PeerInner<C, T, B>>, // related peer
|
||||||
@@ -43,7 +96,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
|||||||
) {
|
) {
|
||||||
loop {
|
loop {
|
||||||
// fetch job
|
// fetch job
|
||||||
let (state, rx) = match receiver.recv() {
|
let (state, endpoint, rx) = match receiver.recv() {
|
||||||
Ok(v) => v,
|
Ok(v) => v,
|
||||||
_ => {
|
_ => {
|
||||||
return;
|
return;
|
||||||
@@ -62,13 +115,10 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
let header: LayoutVerified<&[u8], TransportHeader> = header;
|
let header: LayoutVerified<&[u8], TransportHeader> = header;
|
||||||
|
debug_assert!(
|
||||||
// obtain strong reference to decryption state
|
packet.len() >= 16,
|
||||||
let state = if let Some(state) = state.upgrade() {
|
"this should be checked earlier in the pipeline"
|
||||||
state
|
);
|
||||||
} else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
// check for replay
|
// check for replay
|
||||||
if !state.protector.lock().update(header.f_counter.get()) {
|
if !state.protector.lock().update(header.f_counter.get()) {
|
||||||
@@ -77,23 +127,29 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
|
|||||||
|
|
||||||
// check for confirms key
|
// check for confirms key
|
||||||
if !state.confirmed.swap(true, Ordering::SeqCst) {
|
if !state.confirmed.swap(true, Ordering::SeqCst) {
|
||||||
peer.confirm_key(state.keypair.clone());
|
peer.confirm_key(&state.keypair);
|
||||||
}
|
}
|
||||||
|
|
||||||
// update endpoint, TODO
|
// update endpoint
|
||||||
|
*peer.endpoint.lock() = Some(endpoint);
|
||||||
|
|
||||||
// write packet to TUN device, TODO
|
// calculate length of IP packet + padding
|
||||||
|
let length = packet.len() - CHACHA20_POLY1305.nonce_len();
|
||||||
|
|
||||||
|
// check if should be written to TUN
|
||||||
|
let mut sent = false;
|
||||||
|
if length > 0 {
|
||||||
|
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
|
||||||
|
debug_assert!(inner_len <= length, "should be validated");
|
||||||
|
if inner_len <= length {
|
||||||
|
sent = true;
|
||||||
|
let _ = device.tun.write(&packet[..inner_len]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// trigger callback
|
// trigger callback
|
||||||
debug_assert!(
|
(device.call_recv)(&peer.opaque, length == 0, sent);
|
||||||
packet.len() >= CHACHA20_POLY1305.nonce_len(),
|
|
||||||
"this should be checked earlier in the pipeline"
|
|
||||||
);
|
|
||||||
(device.call_recv)(
|
|
||||||
&peer.opaque,
|
|
||||||
packet.len() > CHACHA20_POLY1305.nonce_len(),
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.wait();
|
.wait();
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> {}
|
pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> + Send {}
|
||||||
|
|
||||||
impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> {}
|
impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> + Send {}
|
||||||
|
|||||||
Reference in New Issue
Block a user