Ensure peer threads are stopped on drop

This commit is contained in:
Mathias Hall-Andersen
2019-08-20 21:19:53 +02:00
parent f4da998812
commit 9cef264581
3 changed files with 158 additions and 101 deletions

View File

@@ -15,6 +15,8 @@ fn main() {
sodiumoxide::init().unwrap(); sodiumoxide::init().unwrap();
let mut router = router::Device::new(8); let mut router = router::Device::new(8);
{
let peer = router.new_peer(); let peer = router.new_peer();
}
loop {}
} }

View File

@@ -1,29 +1,29 @@
use std::sync::atomic::{AtomicU64, AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Weak, Arc}; use std::sync::{Arc, Weak};
use std::thread; use std::thread;
use std::mem;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::mpsc::{sync_channel, SyncSender}; use std::sync::mpsc::{sync_channel, SyncSender};
use spin; use spin;
use arraydeque::{ArrayDeque, Wrapping}; use arraydeque::{ArrayDeque, Wrapping};
use treebitmap::IpLookupTable;
use treebitmap::address::Address; use treebitmap::address::Address;
use treebitmap::IpLookupTable;
use super::super::types::KeyPair;
use super::super::constants::*; use super::super::constants::*;
use super::super::types::KeyPair;
use super::anti_replay::AntiReplay; use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
use super::device::DeviceInner; use super::device::DeviceInner;
use super::device::EncryptionState; use super::device::EncryptionState;
use super::device::DecryptionState; use super::workers::{worker_inbound, worker_outbound, JobInbound, JobOutbound};
const MAX_STAGED_PACKETS: usize = 128; const MAX_STAGED_PACKETS: usize = 128;
struct KeyWheel { pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed) next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
current: Option<Arc<KeyPair>>, // current key state (used for encryption) current: Option<Arc<KeyPair>>, // current key state (used for encryption)
previous: Option<Arc<KeyPair>>, // old key state (used for decryption) previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
@@ -31,18 +31,18 @@ struct KeyWheel {
} }
pub struct PeerInner { pub struct PeerInner {
stopped: AtomicBool, pub stopped: AtomicBool,
device: Arc<DeviceInner>, pub device: Arc<DeviceInner>,
thread_outbound: spin::Mutex<thread::JoinHandle<()>>, pub thread_outbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
thread_inbound: spin::Mutex<thread::JoinHandle<()>>, pub thread_inbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
inorder_outbound: SyncSender<()>, pub queue_outbound: SyncSender<JobOutbound>,
inorder_inbound: SyncSender<()>, pub queue_inbound: SyncSender<JobInbound>,
staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake pub staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake
rx_bytes: AtomicU64, // received bytes pub rx_bytes: AtomicU64, // received bytes
tx_bytes: AtomicU64, // transmitted bytes pub tx_bytes: AtomicU64, // transmitted bytes
keys: spin::Mutex<KeyWheel>, // key-wheel pub keys: spin::Mutex<KeyWheel>, // key-wheel
ekey: spin::Mutex<Option<EncryptionState>>, // encryption state pub ekey: spin::Mutex<Option<EncryptionState>>, // encryption state
endpoint: spin::Mutex<Option<Arc<SocketAddr>>>, pub endpoint: spin::Mutex<Option<Arc<SocketAddr>>>,
} }
pub struct Peer(Arc<PeerInner>); pub struct Peer(Arc<PeerInner>);
@@ -93,6 +93,7 @@ where
impl Drop for Peer { impl Drop for Peer {
fn drop(&mut self) { fn drop(&mut self) {
// mark peer as stopped // mark peer as stopped
let peer = &self.0; let peer = &self.0;
@@ -105,8 +106,19 @@ impl Drop for Peer {
// unpark threads // unpark threads
peer.thread_inbound.lock().thread().unpark(); peer.thread_inbound
peer.thread_outbound.lock().thread().unpark(); .lock()
.as_ref()
.unwrap()
.thread()
.unpark();
peer.thread_outbound
.lock()
.as_ref()
.unwrap()
.thread()
.unpark();
// release ids from the receiver map // release ids from the receiver map
@@ -132,26 +144,25 @@ impl Drop for Peer {
*peer.ekey.lock() = None; *peer.ekey.lock() = None;
*peer.endpoint.lock() = None; *peer.endpoint.lock() = None;
} }
} }
pub fn new_peer(device: Arc<DeviceInner>) -> Peer { pub fn new_peer(device: Arc<DeviceInner>) -> Peer {
// spawn inbound thread // allocate in-order queues
let (send_inbound, recv_inbound) = sync_channel(1); let (send_inbound, recv_inbound) = sync_channel(MAX_STAGED_PACKETS);
let handle_inbound = thread::spawn(move || {}); let (send_outbound, recv_outbound) = sync_channel(MAX_STAGED_PACKETS);
// spawn outbound thread
let (send_outbound, recv_inbound) = sync_channel(1);
let handle_outbound = thread::spawn(move || {});
// allocate peer object // allocate peer object
Peer::new(PeerInner { let peer = {
let device = device.clone();
Arc::new(PeerInner {
stopped: AtomicBool::new(false), stopped: AtomicBool::new(false),
device: device, device: device,
ekey: spin::Mutex::new(None), ekey: spin::Mutex::new(None),
endpoint: spin::Mutex::new(None), endpoint: spin::Mutex::new(None),
inorder_inbound: send_inbound, queue_inbound: send_inbound,
inorder_outbound: send_outbound, queue_outbound: send_outbound,
keys: spin::Mutex::new(KeyWheel { keys: spin::Mutex::new(KeyWheel {
next: None, next: None,
current: None, current: None,
@@ -161,13 +172,34 @@ pub fn new_peer(device: Arc<DeviceInner>) -> Peer {
rx_bytes: AtomicU64::new(0), rx_bytes: AtomicU64::new(0),
tx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0),
staged_packets: spin::Mutex::new(ArrayDeque::new()), staged_packets: spin::Mutex::new(ArrayDeque::new()),
thread_inbound: spin::Mutex::new(handle_inbound), thread_inbound: spin::Mutex::new(None),
thread_outbound: spin::Mutex::new(handle_outbound), thread_outbound: spin::Mutex::new(None),
}) })
};
// spawn inbound thread
*peer.thread_inbound.lock() = {
let peer = peer.clone();
let device = device.clone();
Some(thread::spawn(move || {
worker_outbound(device, peer, recv_outbound)
}))
};
// spawn outbound thread
*peer.thread_outbound.lock() = {
let peer = peer.clone();
let device = device.clone();
Some(thread::spawn(move || {
worker_inbound(device, peer, recv_inbound)
}))
};
Peer(peer)
} }
impl Peer { impl Peer {
fn new(inner : PeerInner) -> Peer { fn new(inner: PeerInner) -> Peer {
Peer(Arc::new(inner)) Peer(Arc::new(inner))
} }

View File

@@ -6,7 +6,7 @@ use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use spin; use spin;
use std::iter; use std::iter;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{sync_channel, Receiver}; use std::sync::mpsc::{sync_channel, Receiver, TryRecvError};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::thread; use std::thread;
@@ -23,17 +23,17 @@ enum Status {
Waiting, // job awaiting completion Waiting, // job awaiting completion
} }
struct JobInner { pub struct JobInner {
msg: Vec<u8>, // message buffer (nonce and receiver id set) msg: Vec<u8>, // message buffer (nonce and receiver id set)
key: [u8; 32], // chacha20poly1305 key key: [u8; 32], // chacha20poly1305 key
status: Status, // state of the job status: Status, // state of the job
op: Operation, // should be buffer be encrypted / decrypted? op: Operation, // should be buffer be encrypted / decrypted?
} }
type JobBuffer = Arc<spin::Mutex<JobInner>>; pub type JobBuffer = Arc<spin::Mutex<JobInner>>;
type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer); pub type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer);
type JobInbound = (Arc<DecryptionState>, JobBuffer); pub type JobInbound = (Weak<DecryptionState>, JobBuffer);
type JobOutbound = (Weak<PeerInner>, JobBuffer); pub type JobOutbound = JobBuffer;
/* Strategy for workers acquiring a new job: /* Strategy for workers acquiring a new job:
* *
@@ -53,64 +53,87 @@ fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]
}) })
} }
fn worker_inbound( fn wait_buffer(stopped: AtomicBool, buf: &JobBuffer) {
while !stopped.load(Ordering::Acquire) {
match buf.try_lock() {
None => (),
Some(buf) => {
if buf.status == Status::Waiting {
return;
}
}
};
thread::park();
}
}
fn wait_recv<T>(stopped: &AtomicBool, recv: &Receiver<T>) -> Result<T, TryRecvError> {
while !stopped.load(Ordering::Acquire) {
match recv.try_recv() {
Err(TryRecvError::Empty) => (),
value => {
return value;
}
};
thread::park();
}
return Err(TryRecvError::Disconnected);
}
pub fn worker_inbound(
device: Arc<DeviceInner>, // related device device: Arc<DeviceInner>, // related device
peer: Arc<PeerInner>, // related peer peer: Arc<PeerInner>, // related peer
recv: Receiver<JobInbound>, // in order queue recv: Receiver<JobInbound>, // in order queue
) { ) {
// reads from in order channel
for job in recv.recv().iter() {
loop { loop {
let (state, buf) = job; match wait_recv(&peer.stopped, &recv) {
Ok((state, buf)) => {
// check if job is complete while !peer.stopped.load(Ordering::Acquire) {
match buf.try_lock() { match buf.try_lock() {
None => (), None => (),
Some(buf) => { Some(buf) => {
if buf.status != Status::Waiting { if buf.status != Status::Waiting {
// check replay protector // consume
// check if confirms keypair
// write to tun device
// continue to next job (no parking)
break; break;
} }
} }
} };
// wait for job to complete
thread::park(); thread::park();
} }
} }
Err(_) => {
break;
}
}
}
} }
fn worker_outbound( pub fn worker_outbound(
device: Arc<DeviceInner>, // related device device: Arc<DeviceInner>, // related device
peer: Arc<PeerInner>, // related peer peer: Arc<PeerInner>, // related peer
recv: Receiver<JobInbound>, // in order queue recv: Receiver<JobOutbound>, // in order queue
) { ) {
// reads from in order channel
for job in recv.recv().iter() {
loop { loop {
let (peer, buf) = job; match wait_recv(&peer.stopped, &recv) {
Ok(buf) => {
// check if job is complete while !peer.stopped.load(Ordering::Acquire) {
match buf.try_lock() { match buf.try_lock() {
None => (), None => (),
Some(buf) => { Some(buf) => {
if buf.status != Status::Waiting { if buf.status != Status::Waiting {
// send buffer to peer endpoint // consume
break; break;
} }
} }
} };
// wait for job to complete
thread::park(); thread::park();
} }
} }
Err(_) => {
break;
}
}
}
} }
fn worker_parallel( fn worker_parallel(