Added zero_key to peer
This commit is contained in:
@@ -64,10 +64,16 @@ impl Device {
|
||||
self.macs = macs::Validator::new(pk);
|
||||
|
||||
// recalculate the shared secrets for every peer
|
||||
for &mut peer in self.pk_map.values_mut() {
|
||||
peer.reset_state().map(|id| self.release(id));
|
||||
let mut ids = vec![];
|
||||
for mut peer in self.pk_map.values_mut() {
|
||||
peer.reset_state().map(|id| ids.push(id));
|
||||
peer.ss = self.sk.diffie_hellman(&peer.pk)
|
||||
}
|
||||
|
||||
// release ids from aborted handshakes
|
||||
for id in ids {
|
||||
self.release(id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new public key to the state machine
|
||||
|
||||
@@ -8,6 +8,7 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
|
||||
mod constants;
|
||||
mod handshake;
|
||||
mod router;
|
||||
mod timers;
|
||||
mod types;
|
||||
mod wireguard;
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ pub struct KeyWheel {
|
||||
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
|
||||
current: Option<Arc<KeyPair>>, // current key state (used for encryption)
|
||||
previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
|
||||
retired: Option<u32>, // retired id (previous id, after confirming key-pair)
|
||||
retired: Vec<u32>, // retired ids
|
||||
}
|
||||
|
||||
pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
|
||||
@@ -188,7 +188,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
|
||||
next: None,
|
||||
current: None,
|
||||
previous: None,
|
||||
retired: None,
|
||||
retired: vec![],
|
||||
}),
|
||||
staged_packets: spin::Mutex::new(ArrayDeque::new()),
|
||||
})
|
||||
@@ -375,6 +375,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
*self.state.endpoint.lock() = Some(B::Endpoint::from_address(address));
|
||||
}
|
||||
|
||||
/// Returns the current endpoint of the peer (for configuration)
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// Does not convey potential "sticky socket" information
|
||||
pub fn get_endpoint(&self) -> Option<SocketAddr> {
|
||||
self.state
|
||||
.endpoint
|
||||
@@ -383,6 +388,30 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
.map(|e| e.into_address())
|
||||
}
|
||||
|
||||
/// Zero all key-material related to the peer
|
||||
pub fn zero_keys(&self) {
|
||||
let mut release: Vec<u32> = Vec::with_capacity(3);
|
||||
let mut keys = self.state.keys.lock();
|
||||
|
||||
// update key-wheel
|
||||
|
||||
mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id()));
|
||||
mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id()));
|
||||
mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id()));
|
||||
keys.retired.extend(&release[..]);
|
||||
|
||||
// update inbound "recv" map
|
||||
{
|
||||
let mut recv = self.state.device.recv.write();
|
||||
for id in release {
|
||||
recv.remove(&id);
|
||||
}
|
||||
}
|
||||
|
||||
// clear encryption state
|
||||
*self.state.ekey.lock() = None;
|
||||
}
|
||||
|
||||
/// Add a new keypair
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -393,14 +422,16 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
///
|
||||
/// A vector of ids which has been released.
|
||||
/// These should be released in the handshake module.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The number of ids to be released can be at most 3,
|
||||
/// since the only way to add additional keys to the peer is by using this method
|
||||
/// and a peer can have at most 3 keys allocated in the router at any time.
|
||||
pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
|
||||
let mut keys = self.state.keys.lock();
|
||||
let mut release = Vec::with_capacity(2);
|
||||
let new = Arc::new(new);
|
||||
|
||||
// collect ids to be released
|
||||
keys.retired.map(|v| release.push(v));
|
||||
keys.previous.as_ref().map(|k| release.push(k.recv.id));
|
||||
let mut keys = self.state.keys.lock();
|
||||
let mut release = mem::replace(&mut keys.retired, vec![]);
|
||||
|
||||
// update key-wheel
|
||||
if new.initiator {
|
||||
@@ -420,10 +451,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
{
|
||||
let mut recv = self.state.device.recv.write();
|
||||
|
||||
// purge recv map of released ids
|
||||
for id in &release {
|
||||
recv.remove(&id);
|
||||
}
|
||||
// purge recv map of previous id
|
||||
keys.previous.as_ref().map(|k| {
|
||||
recv.remove(&k.local_id());
|
||||
release.push(k.local_id());
|
||||
});
|
||||
|
||||
// map new id to decryption state
|
||||
debug_assert!(!recv.contains_key(&new.recv.id));
|
||||
@@ -442,7 +474,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
||||
}
|
||||
}
|
||||
|
||||
// return the released id (for handshake state machine)
|
||||
debug_assert!(release.len() <= 3);
|
||||
release
|
||||
}
|
||||
|
||||
|
||||
65
src/timers.rs
Normal file
65
src/timers.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use hjul::{Runner, Timer};
|
||||
|
||||
use crate::router::Callbacks;
|
||||
|
||||
const ZERO_DURATION: Duration = Duration::from_micros(0);
|
||||
|
||||
pub struct TimersInner {
|
||||
handshake_pending: AtomicBool,
|
||||
handshake_attempts: AtomicUsize,
|
||||
|
||||
retransmit_handshake: Timer,
|
||||
send_keepalive: Timer,
|
||||
zero_key_material: Timer,
|
||||
new_handshake: Timer,
|
||||
|
||||
// stats
|
||||
rx_bytes: AtomicU64,
|
||||
tx_bytes: AtomicU64,
|
||||
}
|
||||
|
||||
impl TimersInner {
|
||||
pub fn new(runner: &Runner) -> Timers {
|
||||
Arc::new(TimersInner {
|
||||
handshake_pending: AtomicBool::new(false),
|
||||
handshake_attempts: AtomicUsize::new(0),
|
||||
retransmit_handshake: runner.timer(|| {}),
|
||||
new_handshake: runner.timer(|| {}),
|
||||
send_keepalive: runner.timer(|| {}),
|
||||
zero_key_material: runner.timer(|| {}),
|
||||
rx_bytes: AtomicU64::new(0),
|
||||
tx_bytes: AtomicU64::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handshake_sent(&self) {
|
||||
self.send_keepalive.stop();
|
||||
}
|
||||
}
|
||||
|
||||
pub type Timers = Arc<TimersInner>;
|
||||
|
||||
pub struct Events();
|
||||
|
||||
impl Callbacks for Events {
|
||||
type Opaque = Timers;
|
||||
|
||||
fn send(t: &Timers, size: usize, data: bool, sent: bool) {
|
||||
t.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn recv(t: &Timers, size: usize, data: bool, sent: bool) {
|
||||
t.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn need_key(t: &Timers) {
|
||||
if !t.handshake_pending.swap(true, Ordering::SeqCst) {
|
||||
t.handshake_attempts.store(0, Ordering::SeqCst);
|
||||
t.new_handshake.reset(ZERO_DURATION);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -28,3 +28,9 @@ pub struct KeyPair {
|
||||
pub send: Key, // key for outbound messages
|
||||
pub recv: Key, // key for inbound messages
|
||||
}
|
||||
|
||||
impl KeyPair {
|
||||
pub fn local_id(&self) -> u32 {
|
||||
self.recv.id
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::handshake;
|
||||
use crate::router;
|
||||
use crate::timers::{Events, Timers};
|
||||
use crate::types::{Bind, Endpoint, Tun};
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
||||
@@ -21,28 +22,19 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
|
||||
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
|
||||
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
|
||||
type Peer<T: Tun, B: Bind> = Arc<PeerInner<T, B>>;
|
||||
|
||||
pub struct PeerInner<T: Tun, B: Bind> {
|
||||
router: router::Peer<Events, T, B>,
|
||||
timers: Timers,
|
||||
rx: AtomicU64,
|
||||
tx: AtomicU64,
|
||||
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
|
||||
router: router::Peer<Events, T, B>, // router peer
|
||||
timers: Option<Timers>, //
|
||||
}
|
||||
|
||||
pub struct Timers {}
|
||||
|
||||
pub struct Events();
|
||||
|
||||
impl router::Callbacks for Events {
|
||||
type Opaque = Timers;
|
||||
|
||||
fn send(t: &Timers, size: usize, data: bool, sent: bool) {}
|
||||
|
||||
fn recv(t: &Timers, size: usize, data: bool, sent: bool) {}
|
||||
|
||||
fn need_key(t: &Timers) {}
|
||||
impl<T: Tun, B: Bind> PeerInner<T, B> {
|
||||
#[inline(always)]
|
||||
fn timers(&self) -> &Timers {
|
||||
self.timers.as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
struct Handshake {
|
||||
@@ -50,6 +42,11 @@ struct Handshake {
|
||||
active: bool,
|
||||
}
|
||||
|
||||
enum HandshakeJob<E> {
|
||||
Message(Vec<u8>, E),
|
||||
New(PublicKey),
|
||||
}
|
||||
|
||||
struct WireguardInner<T: Tun, B: Bind> {
|
||||
// identify and configuration map
|
||||
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
|
||||
@@ -61,7 +58,7 @@ struct WireguardInner<T: Tun, B: Bind> {
|
||||
handshake: RwLock<Handshake>,
|
||||
under_load: AtomicBool,
|
||||
pending: AtomicUsize, // num of pending handshake packets in queue
|
||||
queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>,
|
||||
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
|
||||
|
||||
// IO
|
||||
bind: B,
|
||||
@@ -90,7 +87,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
fn new(tun: T, bind: B) -> Wireguard<T, B> {
|
||||
// create device state
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||
let wg = Arc::new(WireguardInner {
|
||||
peers: RwLock::new(HashMap::new()),
|
||||
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
|
||||
@@ -114,16 +111,18 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
let mut rng = OsRng::new().unwrap();
|
||||
|
||||
// process elements from the handshake queue
|
||||
for (msg, src) in rx {
|
||||
for job in rx {
|
||||
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||
|
||||
// feed message to handshake device
|
||||
let src_validate = (&src).into_address(); // TODO avoid
|
||||
let state = wg.handshake.read();
|
||||
if !state.active {
|
||||
continue;
|
||||
}
|
||||
|
||||
match job {
|
||||
HandshakeJob::Message(msg, src) => {
|
||||
// feed message to handshake device
|
||||
let src_validate = (&src).into_address(); // TODO avoid
|
||||
|
||||
// process message
|
||||
match state.device.process(
|
||||
&mut rng,
|
||||
@@ -147,10 +146,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
|
||||
// update timers
|
||||
if let Some(pk) = pk {
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
// update endpoint (DISCUSS: right semantics?)
|
||||
peer.router.set_endpoint(src_validate);
|
||||
|
||||
// add keypair to peer and free any unused ids
|
||||
if let Some(keypair) = keypair {
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
for id in peer.0.router.add_keypair(keypair) {
|
||||
for id in peer.router.add_keypair(keypair) {
|
||||
state.device.release(id);
|
||||
}
|
||||
}
|
||||
@@ -160,6 +162,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
Err(e) => debug!("handshake worker, error = {:?}", e),
|
||||
}
|
||||
}
|
||||
HandshakeJob::New(pk) => {
|
||||
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
|
||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||
peer.router.send(&msg[..]);
|
||||
peer.timers().handshake_sent();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -197,7 +208,10 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
wg.under_load.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
wg.queue.lock().send((msg, src)).unwrap();
|
||||
wg.queue
|
||||
.lock()
|
||||
.send(HandshakeJob::Message(msg, src))
|
||||
.unwrap();
|
||||
}
|
||||
router::TYPE_TRANSPORT => {
|
||||
// transport message
|
||||
@@ -223,7 +237,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
|
||||
msg.truncate(size);
|
||||
|
||||
// pad message to multiple of 16
|
||||
// pad message to multiple of 16 bytes
|
||||
while msg.len() < mtu && msg.len() % 16 != 0 {
|
||||
msg.push(0);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user