Added zero_key to peer
This commit is contained in:
@@ -64,10 +64,16 @@ impl Device {
|
|||||||
self.macs = macs::Validator::new(pk);
|
self.macs = macs::Validator::new(pk);
|
||||||
|
|
||||||
// recalculate the shared secrets for every peer
|
// recalculate the shared secrets for every peer
|
||||||
for &mut peer in self.pk_map.values_mut() {
|
let mut ids = vec![];
|
||||||
peer.reset_state().map(|id| self.release(id));
|
for mut peer in self.pk_map.values_mut() {
|
||||||
|
peer.reset_state().map(|id| ids.push(id));
|
||||||
peer.ss = self.sk.diffie_hellman(&peer.pk)
|
peer.ss = self.sk.diffie_hellman(&peer.pk)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// release ids from aborted handshakes
|
||||||
|
for id in ids {
|
||||||
|
self.release(id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a new public key to the state machine
|
/// Add a new public key to the state machine
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
|
|||||||
mod constants;
|
mod constants;
|
||||||
mod handshake;
|
mod handshake;
|
||||||
mod router;
|
mod router;
|
||||||
|
mod timers;
|
||||||
mod types;
|
mod types;
|
||||||
mod wireguard;
|
mod wireguard;
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ pub struct KeyWheel {
|
|||||||
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
|
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
|
||||||
current: Option<Arc<KeyPair>>, // current key state (used for encryption)
|
current: Option<Arc<KeyPair>>, // current key state (used for encryption)
|
||||||
previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
|
previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
|
||||||
retired: Option<u32>, // retired id (previous id, after confirming key-pair)
|
retired: Vec<u32>, // retired ids
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
|
pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
|
||||||
@@ -188,7 +188,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
|
|||||||
next: None,
|
next: None,
|
||||||
current: None,
|
current: None,
|
||||||
previous: None,
|
previous: None,
|
||||||
retired: None,
|
retired: vec![],
|
||||||
}),
|
}),
|
||||||
staged_packets: spin::Mutex::new(ArrayDeque::new()),
|
staged_packets: spin::Mutex::new(ArrayDeque::new()),
|
||||||
})
|
})
|
||||||
@@ -375,6 +375,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
|||||||
*self.state.endpoint.lock() = Some(B::Endpoint::from_address(address));
|
*self.state.endpoint.lock() = Some(B::Endpoint::from_address(address));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the current endpoint of the peer (for configuration)
|
||||||
|
///
|
||||||
|
/// # Note
|
||||||
|
///
|
||||||
|
/// Does not convey potential "sticky socket" information
|
||||||
pub fn get_endpoint(&self) -> Option<SocketAddr> {
|
pub fn get_endpoint(&self) -> Option<SocketAddr> {
|
||||||
self.state
|
self.state
|
||||||
.endpoint
|
.endpoint
|
||||||
@@ -383,6 +388,30 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
|||||||
.map(|e| e.into_address())
|
.map(|e| e.into_address())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Zero all key-material related to the peer
|
||||||
|
pub fn zero_keys(&self) {
|
||||||
|
let mut release: Vec<u32> = Vec::with_capacity(3);
|
||||||
|
let mut keys = self.state.keys.lock();
|
||||||
|
|
||||||
|
// update key-wheel
|
||||||
|
|
||||||
|
mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id()));
|
||||||
|
mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id()));
|
||||||
|
mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id()));
|
||||||
|
keys.retired.extend(&release[..]);
|
||||||
|
|
||||||
|
// update inbound "recv" map
|
||||||
|
{
|
||||||
|
let mut recv = self.state.device.recv.write();
|
||||||
|
for id in release {
|
||||||
|
recv.remove(&id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear encryption state
|
||||||
|
*self.state.ekey.lock() = None;
|
||||||
|
}
|
||||||
|
|
||||||
/// Add a new keypair
|
/// Add a new keypair
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@@ -393,14 +422,16 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
|||||||
///
|
///
|
||||||
/// A vector of ids which has been released.
|
/// A vector of ids which has been released.
|
||||||
/// These should be released in the handshake module.
|
/// These should be released in the handshake module.
|
||||||
|
///
|
||||||
|
/// # Note
|
||||||
|
///
|
||||||
|
/// The number of ids to be released can be at most 3,
|
||||||
|
/// since the only way to add additional keys to the peer is by using this method
|
||||||
|
/// and a peer can have at most 3 keys allocated in the router at any time.
|
||||||
pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
|
pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
|
||||||
let mut keys = self.state.keys.lock();
|
|
||||||
let mut release = Vec::with_capacity(2);
|
|
||||||
let new = Arc::new(new);
|
let new = Arc::new(new);
|
||||||
|
let mut keys = self.state.keys.lock();
|
||||||
// collect ids to be released
|
let mut release = mem::replace(&mut keys.retired, vec![]);
|
||||||
keys.retired.map(|v| release.push(v));
|
|
||||||
keys.previous.as_ref().map(|k| release.push(k.recv.id));
|
|
||||||
|
|
||||||
// update key-wheel
|
// update key-wheel
|
||||||
if new.initiator {
|
if new.initiator {
|
||||||
@@ -420,10 +451,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
|||||||
{
|
{
|
||||||
let mut recv = self.state.device.recv.write();
|
let mut recv = self.state.device.recv.write();
|
||||||
|
|
||||||
// purge recv map of released ids
|
// purge recv map of previous id
|
||||||
for id in &release {
|
keys.previous.as_ref().map(|k| {
|
||||||
recv.remove(&id);
|
recv.remove(&k.local_id());
|
||||||
}
|
release.push(k.local_id());
|
||||||
|
});
|
||||||
|
|
||||||
// map new id to decryption state
|
// map new id to decryption state
|
||||||
debug_assert!(!recv.contains_key(&new.recv.id));
|
debug_assert!(!recv.contains_key(&new.recv.id));
|
||||||
@@ -442,7 +474,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// return the released id (for handshake state machine)
|
debug_assert!(release.len() <= 3);
|
||||||
release
|
release
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
65
src/timers.rs
Normal file
65
src/timers.rs
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use hjul::{Runner, Timer};
|
||||||
|
|
||||||
|
use crate::router::Callbacks;
|
||||||
|
|
||||||
|
const ZERO_DURATION: Duration = Duration::from_micros(0);
|
||||||
|
|
||||||
|
pub struct TimersInner {
|
||||||
|
handshake_pending: AtomicBool,
|
||||||
|
handshake_attempts: AtomicUsize,
|
||||||
|
|
||||||
|
retransmit_handshake: Timer,
|
||||||
|
send_keepalive: Timer,
|
||||||
|
zero_key_material: Timer,
|
||||||
|
new_handshake: Timer,
|
||||||
|
|
||||||
|
// stats
|
||||||
|
rx_bytes: AtomicU64,
|
||||||
|
tx_bytes: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TimersInner {
|
||||||
|
pub fn new(runner: &Runner) -> Timers {
|
||||||
|
Arc::new(TimersInner {
|
||||||
|
handshake_pending: AtomicBool::new(false),
|
||||||
|
handshake_attempts: AtomicUsize::new(0),
|
||||||
|
retransmit_handshake: runner.timer(|| {}),
|
||||||
|
new_handshake: runner.timer(|| {}),
|
||||||
|
send_keepalive: runner.timer(|| {}),
|
||||||
|
zero_key_material: runner.timer(|| {}),
|
||||||
|
rx_bytes: AtomicU64::new(0),
|
||||||
|
tx_bytes: AtomicU64::new(0),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handshake_sent(&self) {
|
||||||
|
self.send_keepalive.stop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Timers = Arc<TimersInner>;
|
||||||
|
|
||||||
|
pub struct Events();
|
||||||
|
|
||||||
|
impl Callbacks for Events {
|
||||||
|
type Opaque = Timers;
|
||||||
|
|
||||||
|
fn send(t: &Timers, size: usize, data: bool, sent: bool) {
|
||||||
|
t.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recv(t: &Timers, size: usize, data: bool, sent: bool) {
|
||||||
|
t.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn need_key(t: &Timers) {
|
||||||
|
if !t.handshake_pending.swap(true, Ordering::SeqCst) {
|
||||||
|
t.handshake_attempts.store(0, Ordering::SeqCst);
|
||||||
|
t.new_handshake.reset(ZERO_DURATION);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -28,3 +28,9 @@ pub struct KeyPair {
|
|||||||
pub send: Key, // key for outbound messages
|
pub send: Key, // key for outbound messages
|
||||||
pub recv: Key, // key for inbound messages
|
pub recv: Key, // key for inbound messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl KeyPair {
|
||||||
|
pub fn local_id(&self) -> u32 {
|
||||||
|
self.recv.id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
116
src/wireguard.rs
116
src/wireguard.rs
@@ -1,5 +1,6 @@
|
|||||||
use crate::handshake;
|
use crate::handshake;
|
||||||
use crate::router;
|
use crate::router;
|
||||||
|
use crate::timers::{Events, Timers};
|
||||||
use crate::types::{Bind, Endpoint, Tun};
|
use crate::types::{Bind, Endpoint, Tun};
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
||||||
@@ -21,28 +22,19 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
|
|||||||
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
|
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
|
||||||
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
|
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
|
||||||
|
|
||||||
#[derive(Clone)]
|
type Peer<T: Tun, B: Bind> = Arc<PeerInner<T, B>>;
|
||||||
pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
|
|
||||||
|
|
||||||
pub struct PeerInner<T: Tun, B: Bind> {
|
pub struct PeerInner<T: Tun, B: Bind> {
|
||||||
router: router::Peer<Events, T, B>,
|
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
|
||||||
timers: Timers,
|
router: router::Peer<Events, T, B>, // router peer
|
||||||
rx: AtomicU64,
|
timers: Option<Timers>, //
|
||||||
tx: AtomicU64,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Timers {}
|
impl<T: Tun, B: Bind> PeerInner<T, B> {
|
||||||
|
#[inline(always)]
|
||||||
pub struct Events();
|
fn timers(&self) -> &Timers {
|
||||||
|
self.timers.as_ref().unwrap()
|
||||||
impl router::Callbacks for Events {
|
}
|
||||||
type Opaque = Timers;
|
|
||||||
|
|
||||||
fn send(t: &Timers, size: usize, data: bool, sent: bool) {}
|
|
||||||
|
|
||||||
fn recv(t: &Timers, size: usize, data: bool, sent: bool) {}
|
|
||||||
|
|
||||||
fn need_key(t: &Timers) {}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Handshake {
|
struct Handshake {
|
||||||
@@ -50,6 +42,11 @@ struct Handshake {
|
|||||||
active: bool,
|
active: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum HandshakeJob<E> {
|
||||||
|
Message(Vec<u8>, E),
|
||||||
|
New(PublicKey),
|
||||||
|
}
|
||||||
|
|
||||||
struct WireguardInner<T: Tun, B: Bind> {
|
struct WireguardInner<T: Tun, B: Bind> {
|
||||||
// identify and configuration map
|
// identify and configuration map
|
||||||
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
|
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
|
||||||
@@ -61,7 +58,7 @@ struct WireguardInner<T: Tun, B: Bind> {
|
|||||||
handshake: RwLock<Handshake>,
|
handshake: RwLock<Handshake>,
|
||||||
under_load: AtomicBool,
|
under_load: AtomicBool,
|
||||||
pending: AtomicUsize, // num of pending handshake packets in queue
|
pending: AtomicUsize, // num of pending handshake packets in queue
|
||||||
queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>,
|
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
|
||||||
|
|
||||||
// IO
|
// IO
|
||||||
bind: B,
|
bind: B,
|
||||||
@@ -90,7 +87,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
fn new(tun: T, bind: B) -> Wireguard<T, B> {
|
fn new(tun: T, bind: B) -> Wireguard<T, B> {
|
||||||
// create device state
|
// create device state
|
||||||
let mut rng = OsRng::new().unwrap();
|
let mut rng = OsRng::new().unwrap();
|
||||||
let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||||
let wg = Arc::new(WireguardInner {
|
let wg = Arc::new(WireguardInner {
|
||||||
peers: RwLock::new(HashMap::new()),
|
peers: RwLock::new(HashMap::new()),
|
||||||
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
|
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
|
||||||
@@ -114,50 +111,64 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
let mut rng = OsRng::new().unwrap();
|
let mut rng = OsRng::new().unwrap();
|
||||||
|
|
||||||
// process elements from the handshake queue
|
// process elements from the handshake queue
|
||||||
for (msg, src) in rx {
|
for job in rx {
|
||||||
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
|
||||||
// feed message to handshake device
|
|
||||||
let src_validate = (&src).into_address(); // TODO avoid
|
|
||||||
let state = wg.handshake.read();
|
let state = wg.handshake.read();
|
||||||
if !state.active {
|
if !state.active {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// process message
|
match job {
|
||||||
match state.device.process(
|
HandshakeJob::Message(msg, src) => {
|
||||||
&mut rng,
|
// feed message to handshake device
|
||||||
&msg[..],
|
let src_validate = (&src).into_address(); // TODO avoid
|
||||||
if wg.under_load.load(Ordering::Relaxed) {
|
|
||||||
Some(&src_validate)
|
// process message
|
||||||
} else {
|
match state.device.process(
|
||||||
None
|
&mut rng,
|
||||||
},
|
&msg[..],
|
||||||
) {
|
if wg.under_load.load(Ordering::Relaxed) {
|
||||||
Ok((pk, msg, keypair)) => {
|
Some(&src_validate)
|
||||||
// send response
|
} else {
|
||||||
if let Some(msg) = msg {
|
None
|
||||||
let _ = bind.send(&msg[..], &src).map_err(|e| {
|
},
|
||||||
debug!(
|
) {
|
||||||
|
Ok((pk, msg, keypair)) => {
|
||||||
|
// send response
|
||||||
|
if let Some(msg) = msg {
|
||||||
|
let _ = bind.send(&msg[..], &src).map_err(|e| {
|
||||||
|
debug!(
|
||||||
"handshake worker, failed to send response, error = {:?}",
|
"handshake worker, failed to send response, error = {:?}",
|
||||||
e
|
e
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// update timers
|
// update timers
|
||||||
if let Some(pk) = pk {
|
if let Some(pk) = pk {
|
||||||
// add keypair to peer and free any unused ids
|
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||||
if let Some(keypair) = keypair {
|
// update endpoint (DISCUSS: right semantics?)
|
||||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
peer.router.set_endpoint(src_validate);
|
||||||
for id in peer.0.router.add_keypair(keypair) {
|
|
||||||
state.device.release(id);
|
// add keypair to peer and free any unused ids
|
||||||
|
if let Some(keypair) = keypair {
|
||||||
|
for id in peer.router.add_keypair(keypair) {
|
||||||
|
state.device.release(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Err(e) => debug!("handshake worker, error = {:?}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
HandshakeJob::New(pk) => {
|
||||||
|
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
|
||||||
|
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||||
|
peer.router.send(&msg[..]);
|
||||||
|
peer.timers().handshake_sent();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => debug!("handshake worker, error = {:?}", e),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -197,7 +208,10 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
wg.under_load.store(false, Ordering::SeqCst);
|
wg.under_load.store(false, Ordering::SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.queue.lock().send((msg, src)).unwrap();
|
wg.queue
|
||||||
|
.lock()
|
||||||
|
.send(HandshakeJob::Message(msg, src))
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
router::TYPE_TRANSPORT => {
|
router::TYPE_TRANSPORT => {
|
||||||
// transport message
|
// transport message
|
||||||
@@ -223,7 +237,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
|
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
|
||||||
msg.truncate(size);
|
msg.truncate(size);
|
||||||
|
|
||||||
// pad message to multiple of 16
|
// pad message to multiple of 16 bytes
|
||||||
while msg.len() < mtu && msg.len() % 16 != 0 {
|
while msg.len() < mtu && msg.len() % 16 != 0 {
|
||||||
msg.push(0);
|
msg.push(0);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user