Move to run queue

Mathias Hall-Andersen
2019-12-09 13:21:12 +01:00
parent 74e576a9c2
commit 115fa574a8
9 changed files with 500 additions and 273 deletions

View File

@@ -312,7 +312,7 @@ impl LinuxTunStatus {
             Err(LinuxTunError::Closed)
         } else {
             Ok(LinuxTunStatus {
-                events: vec![],
+                events: vec![TunEvent::Up(1500)], // TODO: for testing
                 index: get_ifindex(&name),
                 fd,
                 name,
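
(Note: seeding the events queue with TunEvent::Up(1500) makes the router observe the interface as up, with an MTU of 1500, immediately at creation instead of waiting for a real event from the kernel; the TODO marks this as a temporary testing shortcut.)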

View File

@@ -20,6 +20,7 @@ use super::peer::{new_peer, Peer, PeerHandle};
 use super::types::{Callbacks, RouterError};
 use super::SIZE_MESSAGE_PREFIX;
+use super::runq::RunQueue;
 use super::route::RoutingTable;
 use super::super::{tun, udp, Endpoint, KeyPair};
@@ -37,8 +38,12 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
     pub table: RoutingTable<Peer<E, C, T, B>>,

     // work queues
-    pub outbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
-    pub inbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
+    pub queue_outbound: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
+    pub queue_inbound: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
+
+    // run queues
+    pub run_inbound: RunQueue<Peer<E, C, T, B>>,
+    pub run_outbound: RunQueue<Peer<E, C, T, B>>,
 }

 pub struct EncryptionState {
@@ -96,8 +101,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
         debug!("router: dropping device");

         // close worker queues
-        self.state.outbound_queue.close();
-        self.state.inbound_queue.close();
+        self.state.queue_outbound.close();
+        self.state.queue_inbound.close();
+
+        // close run queues
+        self.state.run_outbound.close();
+        self.state.run_inbound.close();

         // join all worker threads
         while match self.handles.pop() {
@@ -116,43 +125,73 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
 impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
     pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> {
         // allocate shared device state
-        let (mut outrx, outbound_queue) = ParallelQueue::new(num_workers);
-        let (mut inrx, inbound_queue) = ParallelQueue::new(num_workers);
-        let inner = DeviceInner {
+        let (mut outrx, queue_outbound) = ParallelQueue::new(num_workers);
+        let (mut inrx, queue_inbound) = ParallelQueue::new(num_workers);
+        let device = Device {
+            inner: Arc::new(DeviceInner {
                 inbound: tun,
-                inbound_queue,
+                queue_inbound,
                 outbound: RwLock::new((true, None)),
-                outbound_queue,
+                queue_outbound,
+                run_inbound: RunQueue::new(),
+                run_outbound: RunQueue::new(),
                 recv: RwLock::new(HashMap::new()),
                 table: RoutingTable::new(),
+            })
         };

         // start worker threads
         let mut threads = Vec::with_capacity(num_workers);

+        // inbound/decryption workers
         for _ in 0..num_workers {
+            // parallel workers (parallel processing)
+            {
+                let device = device.clone();
                 let rx = inrx.pop().unwrap();
                 threads.push(thread::spawn(move || {
-                    log::debug!("inbound router worker started");
-                    inbound::worker(rx)
+                    log::debug!("inbound parallel router worker started");
+                    inbound::parallel(device, rx)
                 }));
+            }
+
+            // sequential workers (in-order processing)
+            {
+                let device = device.clone();
+                threads.push(thread::spawn(move || {
+                    log::debug!("inbound sequential router worker started");
+                    inbound::sequential(device)
+                }));
+            }
         }

+        // outbound/encryption workers
         for _ in 0..num_workers {
+            // parallel workers (parallel processing)
+            {
+                let device = device.clone();
                 let rx = outrx.pop().unwrap();
                 threads.push(thread::spawn(move || {
-                    log::debug!("outbound router worker started");
-                    outbound::worker(rx)
+                    log::debug!("outbound parallel router worker started");
+                    outbound::parallel(device, rx)
                 }));
+            }
+
+            // sequential workers (in-order processing)
+            {
+                let device = device.clone();
+                threads.push(thread::spawn(move || {
+                    log::debug!("outbound sequential router worker started");
+                    outbound::sequential(device)
+                }));
+            }
         }

-        debug_assert_eq!(threads.len(), num_workers * 2);
+        debug_assert_eq!(threads.len(), num_workers * 4);

         // return exported device handle
         DeviceHandle {
-            state: Device {
-                inner: Arc::new(inner),
-            },
+            state: device,
             handles: threads,
         }
     }
@@ -192,7 +231,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
     pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> {
         debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX);
         log::trace!(
-            "Router, outbound packet = {}",
+            "send, packet = {}",
             hex::encode(&msg[SIZE_MESSAGE_PREFIX..])
         );
@@ -208,7 +247,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
         // schedule for encryption and transmission to peer
         if let Some(job) = peer.send_job(msg, true) {
-            self.state.outbound_queue.send(job);
+            self.state.queue_outbound.send(job);
         }

         Ok(())
@@ -225,6 +264,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
     ///
     ///
     pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
+        log::trace!("receive, src: {}", src.into_address());
+
         // parse / cast
         let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
             Some(v) => v,
@@ -255,7 +296,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
         // schedule for decryption and TUN write
         if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
             log::trace!("schedule decryption of transport message");
-            self.state.inbound_queue.send(job);
+            self.state.queue_inbound.send(job);
         }
         Ok(())
     }
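
Note on the new thread topology: DeviceHandle::new now spawns, per worker index, one parallel (out-of-order) and one sequential (in-order) thread for each direction, i.e. num_workers * 4 threads in total, as the debug_assert above checks. A minimal, self-contained sketch of that spawn pattern (closure bodies are placeholders, not the crate's API):

    use std::thread;

    fn spawn_router_threads(num_workers: usize) -> Vec<thread::JoinHandle<()>> {
        let mut threads = Vec::with_capacity(num_workers * 4);
        for _ in 0..num_workers {
            // one parallel and one sequential thread per direction,
            // mirroring DeviceHandle::new above
            threads.push(thread::spawn(|| { /* inbound::parallel */ }));
            threads.push(thread::spawn(|| { /* inbound::sequential */ }));
            threads.push(thread::spawn(|| { /* outbound::parallel */ }));
            threads.push(thread::spawn(|| { /* outbound::sequential */ }));
        }
        threads
    }

    fn main() {
        let threads = spawn_router_threads(2);
        assert_eq!(threads.len(), 8);
        for t in threads {
            t.join().unwrap();
        }
    }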

View File

@@ -4,6 +4,8 @@ use super::peer::Peer;
 use super::pool::*;
 use super::types::Callbacks;
 use super::{tun, udp, Endpoint};
+use super::device::Device;
+use super::runq::RunQueue;

 use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
 use zerocopy::{AsBytes, LayoutVerified};
@@ -38,10 +40,22 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, T, B> {
     }

 #[inline(always)]
-fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+    device: Device<E, C, T, B>,
+    receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>,
+) {
+    // run queue to schedule
+    fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+        device: &Device<E, C, T, B>,
+    ) -> &RunQueue<Peer<E, C, T, B>> {
+        &device.run_inbound
+    }
+
+    // parallel work to apply
+    fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
         peer: &Peer<E, C, T, B>,
         body: &mut Inbound<E, C, T, B>,
-) {
+    ) {
         log::trace!("worker, parallel section, obtained job");

         // cast to header followed by payload
@@ -104,13 +118,20 @@ fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
             body.msg.truncate(mem::size_of::<TransportHeader>() + len);
         }
     }
+    }
+
+    worker_parallel(device, |dev| &dev.run_inbound, receiver, work)
 }

 #[inline(always)]
-fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
-    peer: &Peer<E, C, T, B>,
-    body: &mut Inbound<E, C, T, B>,
+pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+    device: Device<E, C, T, B>,
 ) {
+    // sequential work to apply
+    fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+        peer: &Peer<E, C, T, B>,
+        body: &mut Inbound<E, C, T, B>,
+    ) {
         log::trace!("worker, sequential section, obtained job");

         // decryption failed, return early
@@ -160,17 +181,10 @@ fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
             // trigger callback
             C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair);
         }
-}
-
-#[inline(always)]
-fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
-    peer: &Peer<E, C, T, B>,
-) -> &InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>> {
-    &peer.inbound
-}
-
-pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
-    receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>,
-) {
-    worker_template(receiver, parallel, sequential, queue)
+    }
+
+    // handle messages from the peer's inbound queue
+    device.run_inbound.run(|peer| {
+        peer.inbound.handle(|body| work(&peer, body));
+    });
 }
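
Note on ordering: the ParallelQueue fans decryption jobs out to any available worker, while the run queue hands each peer to at most one sequential worker at a time; since InorderQueue::handle only pops jobs that are complete, in submission order, per-peer packet order is preserved on the TUN device even though decryption itself finishes out of order.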

View File

@@ -10,6 +10,7 @@ mod pool;
 mod queue;
 mod route;
 mod types;
+mod runq;

 // mod workers;

View File

@@ -5,6 +5,7 @@ use super::types::Callbacks;
 use super::KeyPair;
 use super::REJECT_AFTER_MESSAGES;
 use super::{tun, udp, Endpoint};
+use super::device::Device;

 use std::sync::mpsc::Receiver;
 use std::sync::Arc;
@@ -31,10 +32,15 @@ impl Outbound {
     }

 #[inline(always)]
-fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+    device: Device<E, C, T, B>,
+    receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>,
+) {
+    fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
         _peer: &Peer<E, C, T, B>,
         body: &mut Outbound,
-) {
+    ) {
         log::trace!("worker, parallel section, obtained job");

         // make space for the tag
@@ -72,13 +78,18 @@ fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
         // append tag
         packet[end..].copy_from_slice(tag.as_ref());
     }
+
+    worker_parallel(device, |dev| &dev.run_outbound, receiver, work);
 }

 #[inline(always)]
-fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
-    peer: &Peer<E, C, T, B>,
-    body: &mut Outbound,
+pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+    device: Device<E, C, T, B>,
 ) {
+    device.run_outbound.run(|peer| {
+        peer.outbound.handle(|body| {
             log::trace!("worker, sequential section, obtained job");

             // send to peer
@@ -92,17 +103,6 @@ fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
                 &body.keypair,
                 body.counter,
             );
-}
-
-#[inline(always)]
-pub fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
-    peer: &Peer<E, C, T, B>,
-) -> &InorderQueue<Peer<E, C, T, B>, Outbound> {
-    &peer.outbound
-}
-
-pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
-    receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>,
-) {
-    worker_template(receiver, parallel, sequential, queue)
+        });
+    });
 }
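
Note: the outbound sequential worker nests the two stages directly as closures (run_outbound.run wrapping peer.outbound.handle), where the inbound side uses a named work function; the structure is otherwise identical.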

View File

@@ -20,6 +20,7 @@ use super::messages::TransportHeader;
 use super::constants::*;
 use super::types::{Callbacks, RouterError};
 use super::SIZE_MESSAGE_PREFIX;
+use super::runq::ToKey;

 // worker pool related
 use super::inbound::Inbound;
@@ -56,14 +57,28 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Peer<E, C, T, B> {
     }
 }

+/* Equality of peers is defined as pointer equality of
+ * the atomic reference counted pointer.
+ */
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for Peer<E, C, T, B> {
+    fn eq(&self, other: &Self) -> bool {
+        Arc::ptr_eq(&self.inner, &other.inner)
+    }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ToKey for Peer<E, C, T, B> {
+    type Key = usize;
+    fn to_key(&self) -> usize {
+        Arc::downgrade(&self.inner).into_raw() as usize
+    }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {}

 /* A peer is transparently dereferenced to the inner type.
  */
 impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> {
     type Target = PeerInner<E, C, T, B>;
     fn deref(&self) -> &Self::Target {
@@ -71,6 +86,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> {
     }
 }

+/* A peer handle is a specially designated peer pointer
+ * which removes the peer from the device when dropped.
+ */
 pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
     peer: Peer<E, C, T, B>,
 }
@@ -227,7 +246,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> {
         log::debug!("peer.send_raw");
         match self.send_job(msg, false) {
             Some(job) => {
-                self.device.outbound_queue.send(job);
+                self.device.queue_outbound.send(job);
                 debug!("send_raw: obtained send_job");
                 true
             }
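
Note on ToKey above: the run queue identifies a peer by the address of its reference-counted allocation, so all clones of a peer map to the same queue entry (as written, each to_key call also leaks one weak reference, since the Weak returned by Arc::downgrade is never reconstructed and dropped). A sketch of the same idea with invented types:

    use std::sync::Arc;

    // stand-in for the ToKey trait defined in runq.rs
    trait ToKey {
        type Key: std::hash::Hash + Eq;
        fn to_key(&self) -> Self::Key;
    }

    // an Arc-backed handle keyed by the address of its shared allocation
    struct Handle(Arc<u32>);

    impl ToKey for Handle {
        type Key = usize;
        fn to_key(&self) -> usize {
            &*self.0 as *const u32 as usize // allocation address as key
        }
    }

    fn main() {
        let a = Handle(Arc::new(7));
        let b = Handle(a.0.clone());
        assert_eq!(a.to_key(), b.to_key()); // clones share one key
    }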

View File

@@ -2,6 +2,9 @@ use arraydeque::ArrayDeque;
 use spin::{Mutex, MutexGuard};
 use std::sync::mpsc::Receiver;
 use std::sync::Arc;
+use std::mem;
+
+use super::runq::{RunQueue, ToKey};

 const INORDER_QUEUE_SIZE: usize = 64;
@@ -60,51 +63,53 @@ impl<P, B> InorderQueue<P, B> {
     }

     #[inline(always)]
-    pub fn handle<F: Fn(&mut InnerJob<P, B>)>(&self, f: F) {
+    pub fn handle<F: Fn(&mut B)>(&self, f: F) {
         // take the mutex
         let mut queue = self.queue.lock();

-        // handle all complete messages
-        while queue
-            .pop_front()
-            .and_then(|j| {
-                // check if job is complete
-                let ret = if let Some(mut guard) = j.complete() {
-                    f(&mut *guard);
+        loop {
+            // attempt to extract front element
+            let front = queue.pop_front();
+            let elem = match front {
+                Some(elem) => elem,
+                _ => {
+                    return;
+                }
+            };
+
+            // apply function if job complete
+            let ret = if let Some(mut guard) = elem.complete() {
+                mem::drop(queue);
+                f(&mut guard.body);
+                queue = self.queue.lock();
                 false
             } else {
                 true
             };
-                // return job to cyclic buffer if not complete
-                if ret {
-                    let _res = queue.push_front(j);
-                    debug_assert!(_res.is_ok());
-                    None
-                } else {
-                    // add job back to pool
-                    Some(())
-                }
-            })
-            .is_some()
-        {}
+
+            // job not complete yet, return job to front
+            if ret {
+                queue.push_front(elem).unwrap();
+                return;
+            }
+        }
     }
 }
 /// Allows easy construction of a semi-parallel worker.
 /// Applicable for both decryption and encryption workers.
 #[inline(always)]
-pub fn worker_template<
-    P, // represents a peer (atomic reference counted pointer)
+pub fn worker_parallel<
+    P: ToKey, // represents a peer (atomic reference counted pointer)
     B, // inner body type (message buffer, key material, ...)
+    D, // device
     W: Fn(&P, &mut B),
-    S: Fn(&P, &mut B),
-    Q: Fn(&P) -> &InorderQueue<P, B>,
+    Q: Fn(&D) -> &RunQueue<P>,
 >(
-    receiver: Receiver<Job<P, B>>, // receiver for new jobs
-    work_parallel: W, // perform parallel / out-of-order work on peer
-    work_sequential: S, // perform sequential work on peer
-    queue: Q, // resolve a peer to an inorder queue
+    device: D,
+    queue: Q,
+    receiver: Receiver<Job<P, B>>,
+    work: W,
 ) {
     log::trace!("router worker started");
     loop {
@@ -123,11 +128,11 @@ pub fn worker_template<
             let peer = job.peer.take().unwrap();

             // process job
-            work_parallel(&peer, &mut job.body);
+            work(&peer, &mut job.body);
             peer
         };

-        // process inorder jobs for peer
-        queue(&peer).handle(|j| work_sequential(&peer, &mut j.body));
+        queue(&device).insert(peer);
     }
 }
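
Note on the reworked InorderQueue::handle above: the spin::Mutex guard is dropped (mem::drop(queue)) for the duration of the callback and retaken afterwards, so producers can enqueue while a job body is being processed. The same pattern with std's Mutex and toy types (not the crate's API):

    use std::collections::VecDeque;
    use std::sync::Mutex;

    fn drain<T, F: Fn(&mut T)>(queue: &Mutex<VecDeque<T>>, f: F) {
        let mut guard = queue.lock().unwrap();
        while let Some(mut elem) = guard.pop_front() {
            drop(guard); // release the lock while the callback runs
            f(&mut elem);
            guard = queue.lock().unwrap(); // retake before the next pop
        }
    }

    fn main() {
        let q = Mutex::new(VecDeque::from(vec![1, 2, 3]));
        drain(&q, |x| println!("processed {}", x));
    }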

View File

@@ -0,0 +1,145 @@
use std::mem;
use std::sync::{Condvar, Mutex};
use std::hash::Hash;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::collections::VecDeque;

pub trait ToKey {
    type Key: Hash + Eq;
    fn to_key(&self) -> Self::Key;
}

pub struct RunQueue<T: ToKey> {
    cvar: Condvar,
    inner: Mutex<Inner<T>>,
}

struct Inner<T: ToKey> {
    stop: bool,
    queue: VecDeque<T>,
    members: HashMap<T::Key, usize>,
}

impl<T: ToKey> RunQueue<T> {
    pub fn close(&self) {
        let mut inner = self.inner.lock().unwrap();
        inner.stop = true;
        self.cvar.notify_all();
    }

    pub fn new() -> RunQueue<T> {
        RunQueue {
            cvar: Condvar::new(),
            inner: Mutex::new(Inner {
                stop: false,
                queue: VecDeque::new(),
                members: HashMap::new(),
            }),
        }
    }

    pub fn insert(&self, v: T) {
        let key = v.to_key();
        let mut inner = self.inner.lock().unwrap();
        match inner.members.entry(key) {
            Entry::Occupied(mut elem) => {
                *elem.get_mut() += 1;
            }
            Entry::Vacant(spot) => {
                // add entry to back of queue
                spot.insert(0);
                inner.queue.push_back(v);

                // wake a thread
                self.cvar.notify_one();
            }
        }
    }

    pub fn run<F: Fn(&T) -> ()>(&self, f: F) {
        let mut inner = self.inner.lock().unwrap();
        loop {
            // fetch next element
            let elem = loop {
                // run-queue closed
                if inner.stop {
                    return;
                }

                // try to pop from queue
                match inner.queue.pop_front() {
                    Some(elem) => {
                        break elem;
                    }
                    None => (),
                };

                // wait for an element to be inserted
                inner = self.cvar.wait(inner).unwrap();
            };

            // fetch current request number
            let key = elem.to_key();
            let old_n = *inner.members.get(&key).unwrap();
            mem::drop(inner); // drop guard

            // handle element
            f(&elem);

            // retake lock and check if should be added back to queue
            inner = self.inner.lock().unwrap();
            match inner.members.entry(key) {
                Entry::Occupied(occ) => {
                    if *occ.get() == old_n {
                        // no new requests since last, remove entry.
                        occ.remove();
                    } else {
                        // new requests, reschedule.
                        inner.queue.push_back(elem);
                    }
                }
                Entry::Vacant(_) => {
                    unreachable!();
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::sync::Arc;
    use std::time::Duration;

    /*
    #[test]
    fn test_wait() {
        let queue: Arc<RunQueue<usize>> = Arc::new(RunQueue::new());
        {
            let queue = queue.clone();
            thread::spawn(move || {
                queue.run(|e| {
                    println!("t0 {}", e);
                    thread::sleep(Duration::from_millis(100));
                })
            });
        }
        {
            let queue = queue.clone();
            thread::spawn(move || {
                queue.run(|e| {
                    println!("t1 {}", e);
                    thread::sleep(Duration::from_millis(100));
                })
            });
        }
    }
    */
}
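
Note: the commented-out test above spawns two consumers but never inserts work or closes the queue, so it exits immediately. A complete usage might look like the following sketch; it assumes a ToKey impl for usize (shown inline), since RunQueue<usize> requires one:

    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    // hypothetical key: a usize is simply its own key
    impl ToKey for usize {
        type Key = usize;
        fn to_key(&self) -> usize {
            *self
        }
    }

    fn example() {
        let queue: Arc<RunQueue<usize>> = Arc::new(RunQueue::new());

        // single consumer; RunQueue::run blocks until close() is called
        let worker = {
            let queue = queue.clone();
            thread::spawn(move || queue.run(|e| println!("handled {}", e)))
        };

        queue.insert(1);
        queue.insert(2);
        queue.insert(1); // while key 1 is pending, this bumps its counter so it re-runs instead of queueing twice

        thread::sleep(Duration::from_millis(200));
        queue.close(); // sets stop and wakes the worker so run() returns
        worker.join().unwrap();
    }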

View File

@@ -273,6 +273,8 @@ mod tests {
                 }
             }
         }
+
+        println!("Test complete, drop device");
     }

     #[test]