Ensure peer threads are stopped on drop

This commit is contained in:
Mathias Hall-Andersen
2019-08-20 21:19:53 +02:00
parent f4da998812
commit 9cef264581
3 changed files with 158 additions and 101 deletions

View File

@@ -15,6 +15,8 @@ fn main() {
sodiumoxide::init().unwrap();
let mut router = router::Device::new(8);
let peer = router.new_peer();
{
let peer = router.new_peer();
}
loop {}
}

View File

@@ -1,29 +1,29 @@
use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
use std::sync::{Weak, Arc};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::thread;
use std::mem;
use std::net::{IpAddr, SocketAddr};
use std::sync::mpsc::{sync_channel, SyncSender};
use spin;
use arraydeque::{ArrayDeque, Wrapping};
use treebitmap::IpLookupTable;
use treebitmap::address::Address;
use treebitmap::IpLookupTable;
use super::super::types::KeyPair;
use super::super::constants::*;
use super::super::types::KeyPair;
use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
use super::device::DeviceInner;
use super::device::EncryptionState;
use super::device::DecryptionState;
use super::workers::{worker_inbound, worker_outbound, JobInbound, JobOutbound};
const MAX_STAGED_PACKETS: usize = 128;
struct KeyWheel {
pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
current: Option<Arc<KeyPair>>, // current key state (used for encryption)
previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
@@ -31,18 +31,18 @@ struct KeyWheel {
}
pub struct PeerInner {
stopped: AtomicBool,
device: Arc<DeviceInner>,
thread_outbound: spin::Mutex<thread::JoinHandle<()>>,
thread_inbound: spin::Mutex<thread::JoinHandle<()>>,
inorder_outbound: SyncSender<()>,
inorder_inbound: SyncSender<()>,
staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake
rx_bytes: AtomicU64, // received bytes
tx_bytes: AtomicU64, // transmitted bytes
keys: spin::Mutex<KeyWheel>, // key-wheel
ekey: spin::Mutex<Option<EncryptionState>>, // encryption state
endpoint: spin::Mutex<Option<Arc<SocketAddr>>>,
pub stopped: AtomicBool,
pub device: Arc<DeviceInner>,
pub thread_outbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
pub thread_inbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
pub queue_outbound: SyncSender<JobOutbound>,
pub queue_inbound: SyncSender<JobInbound>,
pub staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake
pub rx_bytes: AtomicU64, // received bytes
pub tx_bytes: AtomicU64, // transmitted bytes
pub keys: spin::Mutex<KeyWheel>, // key-wheel
pub ekey: spin::Mutex<Option<EncryptionState>>, // encryption state
pub endpoint: spin::Mutex<Option<Arc<SocketAddr>>>,
}
pub struct Peer(Arc<PeerInner>);
@@ -93,6 +93,7 @@ where
impl Drop for Peer {
fn drop(&mut self) {
// mark peer as stopped
let peer = &self.0;
@@ -105,8 +106,19 @@ impl Drop for Peer {
// unpark threads
peer.thread_inbound.lock().thread().unpark();
peer.thread_outbound.lock().thread().unpark();
peer.thread_inbound
.lock()
.as_ref()
.unwrap()
.thread()
.unpark();
peer.thread_outbound
.lock()
.as_ref()
.unwrap()
.thread()
.unpark();
// release ids from the receiver map
@@ -132,42 +144,62 @@ impl Drop for Peer {
*peer.ekey.lock() = None;
*peer.endpoint.lock() = None;
}
}
pub fn new_peer(device: Arc<DeviceInner>) -> Peer {
// spawn inbound thread
let (send_inbound, recv_inbound) = sync_channel(1);
let handle_inbound = thread::spawn(move || {});
// spawn outbound thread
let (send_outbound, recv_inbound) = sync_channel(1);
let handle_outbound = thread::spawn(move || {});
// allocate in-order queues
let (send_inbound, recv_inbound) = sync_channel(MAX_STAGED_PACKETS);
let (send_outbound, recv_outbound) = sync_channel(MAX_STAGED_PACKETS);
// allocate peer object
Peer::new(PeerInner {
stopped: AtomicBool::new(false),
device: device,
ekey: spin::Mutex::new(None),
endpoint: spin::Mutex::new(None),
inorder_inbound: send_inbound,
inorder_outbound: send_outbound,
keys: spin::Mutex::new(KeyWheel {
next: None,
current: None,
previous: None,
retired: None,
}),
rx_bytes: AtomicU64::new(0),
tx_bytes: AtomicU64::new(0),
staged_packets: spin::Mutex::new(ArrayDeque::new()),
thread_inbound: spin::Mutex::new(handle_inbound),
thread_outbound: spin::Mutex::new(handle_outbound),
})
let peer = {
let device = device.clone();
Arc::new(PeerInner {
stopped: AtomicBool::new(false),
device: device,
ekey: spin::Mutex::new(None),
endpoint: spin::Mutex::new(None),
queue_inbound: send_inbound,
queue_outbound: send_outbound,
keys: spin::Mutex::new(KeyWheel {
next: None,
current: None,
previous: None,
retired: None,
}),
rx_bytes: AtomicU64::new(0),
tx_bytes: AtomicU64::new(0),
staged_packets: spin::Mutex::new(ArrayDeque::new()),
thread_inbound: spin::Mutex::new(None),
thread_outbound: spin::Mutex::new(None),
})
};
// spawn inbound thread
*peer.thread_inbound.lock() = {
let peer = peer.clone();
let device = device.clone();
Some(thread::spawn(move || {
worker_outbound(device, peer, recv_outbound)
}))
};
// spawn outbound thread
*peer.thread_outbound.lock() = {
let peer = peer.clone();
let device = device.clone();
Some(thread::spawn(move || {
worker_inbound(device, peer, recv_inbound)
}))
};
Peer(peer)
}
impl Peer {
fn new(inner : PeerInner) -> Peer {
fn new(inner: PeerInner) -> Peer {
Peer(Arc::new(inner))
}
@@ -282,4 +314,4 @@ impl Peer {
));
res
}
}
}

View File

@@ -6,7 +6,7 @@ use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use spin;
use std::iter;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{sync_channel, Receiver};
use std::sync::mpsc::{sync_channel, Receiver, TryRecvError};
use std::sync::{Arc, Weak};
use std::thread;
@@ -23,17 +23,17 @@ enum Status {
Waiting, // job awaiting completion
}
struct JobInner {
pub struct JobInner {
msg: Vec<u8>, // message buffer (nonce and receiver id set)
key: [u8; 32], // chacha20poly1305 key
status: Status, // state of the job
op: Operation, // should be buffer be encrypted / decrypted?
}
type JobBuffer = Arc<spin::Mutex<JobInner>>;
type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer);
type JobInbound = (Arc<DecryptionState>, JobBuffer);
type JobOutbound = (Weak<PeerInner>, JobBuffer);
pub type JobBuffer = Arc<spin::Mutex<JobInner>>;
pub type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer);
pub type JobInbound = (Weak<DecryptionState>, JobBuffer);
pub type JobOutbound = JobBuffer;
/* Strategy for workers acquiring a new job:
*
@@ -53,62 +53,85 @@ fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]
})
}
fn worker_inbound(
fn wait_buffer(stopped: AtomicBool, buf: &JobBuffer) {
while !stopped.load(Ordering::Acquire) {
match buf.try_lock() {
None => (),
Some(buf) => {
if buf.status == Status::Waiting {
return;
}
}
};
thread::park();
}
}
fn wait_recv<T>(stopped: &AtomicBool, recv: &Receiver<T>) -> Result<T, TryRecvError> {
while !stopped.load(Ordering::Acquire) {
match recv.try_recv() {
Err(TryRecvError::Empty) => (),
value => {
return value;
}
};
thread::park();
}
return Err(TryRecvError::Disconnected);
}
pub fn worker_inbound(
device: Arc<DeviceInner>, // related device
peer: Arc<PeerInner>, // related peer
recv: Receiver<JobInbound>, // in order queue
) {
// reads from in order channel
for job in recv.recv().iter() {
loop {
let (state, buf) = job;
// check if job is complete
match buf.try_lock() {
None => (),
Some(buf) => {
if buf.status != Status::Waiting {
// check replay protector
// check if confirms keypair
// write to tun device
// continue to next job (no parking)
break;
}
loop {
match wait_recv(&peer.stopped, &recv) {
Ok((state, buf)) => {
while !peer.stopped.load(Ordering::Acquire) {
match buf.try_lock() {
None => (),
Some(buf) => {
if buf.status != Status::Waiting {
// consume
break;
}
}
};
thread::park();
}
}
// wait for job to complete
thread::park();
Err(_) => {
break;
}
}
}
}
fn worker_outbound(
device: Arc<DeviceInner>, // related device
peer: Arc<PeerInner>, // related peer
recv: Receiver<JobInbound>, // in order queue
pub fn worker_outbound(
device: Arc<DeviceInner>, // related device
peer: Arc<PeerInner>, // related peer
recv: Receiver<JobOutbound>, // in order queue
) {
// reads from in order channel
for job in recv.recv().iter() {
loop {
let (peer, buf) = job;
// check if job is complete
match buf.try_lock() {
None => (),
Some(buf) => {
if buf.status != Status::Waiting {
// send buffer to peer endpoint
break;
}
loop {
match wait_recv(&peer.stopped, &recv) {
Ok(buf) => {
while !peer.stopped.load(Ordering::Acquire) {
match buf.try_lock() {
None => (),
Some(buf) => {
if buf.status != Status::Waiting {
// consume
break;
}
}
};
thread::park();
}
}
// wait for job to complete
thread::park();
Err(_) => {
break;
}
}
}
}