Removal of secret key in the handshake module

This commit is contained in:
Mathias Hall-Andersen
2019-11-08 19:00:12 +01:00
parent 293914e47b
commit dd85201c15
6 changed files with 189 additions and 168 deletions

View File

@@ -21,10 +21,14 @@ use super::types::*;
const MAX_PEER_PER_DEVICE: usize = 1 << 20; const MAX_PEER_PER_DEVICE: usize = 1 << 20;
pub struct KeyState {
pub sk: StaticSecret, // static secret key
pub pk: PublicKey, // static public key
macs: macs::Validator, // validator for the mac fields
}
pub struct Device { pub struct Device {
pub sk: StaticSecret, // static secret key keyst: Option<KeyState>, // secret/public key
pub pk: PublicKey, // static public key
macs: macs::Validator, // validator for the mac fields
pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
limiter: Mutex<RateLimiter>, limiter: Mutex<RateLimiter>,
@@ -35,45 +39,68 @@ pub struct Device {
*/ */
impl Device { impl Device {
/// Initialize a new handshake state machine /// Initialize a new handshake state machine
/// pub fn new() -> Device {
/// # Arguments
///
/// * `sk` - x25519 scalar representing the local private key
pub fn new(sk: StaticSecret) -> Device {
let pk = PublicKey::from(&sk);
Device { Device {
pk, keyst: None,
sk,
macs: macs::Validator::new(pk),
pk_map: HashMap::new(), pk_map: HashMap::new(),
id_map: RwLock::new(HashMap::new()), id_map: RwLock::new(HashMap::new()),
limiter: Mutex::new(RateLimiter::new()), limiter: Mutex::new(RateLimiter::new()),
} }
} }
fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> {
if let Some(key) = self.keyst.as_ref() {
if *peer.pk.as_bytes() == *key.pk.as_bytes() {
return Some(peer.pk);
}
peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
} else {
peer.ss = [0u8; 32];
};
None
}
/// Update the secret key of the device /// Update the secret key of the device
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `sk` - x25519 scalar representing the local private key /// * `sk` - x25519 scalar representing the local private key
pub fn set_sk(&mut self, sk: StaticSecret) { pub fn set_sk(&mut self, sk: Option<StaticSecret>) -> Option<PublicKey> {
// update secret and public key // update secret and public key
let pk = PublicKey::from(&sk); self.keyst = sk.map(|sk| {
self.sk = sk; let pk = PublicKey::from(&sk);
self.pk = pk; let macs = macs::Validator::new(pk);
self.macs = macs::Validator::new(pk); KeyState { pk, sk, macs }
});
// recalculate the shared secrets for every peer // recalculate / erase the shared secrets for every peer
let mut ids = vec![]; let mut ids = vec![];
let mut same = None;
for mut peer in self.pk_map.values_mut() { for mut peer in self.pk_map.values_mut() {
// clear any existing handshake state
peer.reset_state().map(|id| ids.push(id)); peer.reset_state().map(|id| ids.push(id));
peer.ss = self.sk.diffie_hellman(&peer.pk)
// update precomputed shared secret
if let Some(key) = self.keyst.as_ref() {
peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
if *peer.pk.as_bytes() == *key.pk.as_bytes() {
same = Some(peer.pk)
}
} else {
peer.ss = [0u8; 32];
};
} }
// release ids from aborted handshakes // release ids from aborted handshakes
for id in ids { for id in ids {
self.release(id) self.release(id)
} }
// if we found a peer matching the device public key, remove it.
same.map(|pk| {
self.pk_map.remove(pk.as_bytes());
pk
})
} }
/// Return the secret key of the device /// Return the secret key of the device
@@ -81,8 +108,8 @@ impl Device {
/// # Returns /// # Returns
/// ///
/// A secret key (x25519 scalar) /// A secret key (x25519 scalar)
pub fn get_sk(&self) -> StaticSecret { pub fn get_sk(&self) -> Option<&StaticSecret> {
StaticSecret::from(self.sk.to_bytes()) self.keyst.as_ref().map(|key| &key.sk)
} }
/// Add a new public key to the state machine /// Add a new public key to the state machine
@@ -93,28 +120,28 @@ impl Device {
/// * `pk` - The public key to add /// * `pk` - The public key to add
/// * `identifier` - Associated identifier which can be used to distinguish the peers /// * `identifier` - Associated identifier which can be used to distinguish the peers
pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
// check that the pk is not added twice
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
return Ok(());
};
// check that the pk is not that of the device
if *self.pk.as_bytes() == *pk.as_bytes() {
return Err(ConfigError::new(
"Public key corresponds to secret key of interface",
));
}
// ensure less than 2^20 peers // ensure less than 2^20 peers
if self.pk_map.len() > MAX_PEER_PER_DEVICE { if self.pk_map.len() > MAX_PEER_PER_DEVICE {
return Err(ConfigError::new("Too many peers for device")); return Err(ConfigError::new("Too many peers for device"));
} }
// map the public key to the peer state // create peer and precompute static secret
self.pk_map let mut peer = Peer::new(
.insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk))); pk,
self.keyst
.as_ref()
.map(|key| *key.sk.diffie_hellman(&pk).as_bytes())
.unwrap_or([0u8; 32]),
);
Ok(()) // add peer to device
match self.update_ss(&mut peer) {
Some(_) => Err(ConfigError::new("Public key of peer matches the device")),
None => {
self.pk_map.insert(*pk.as_bytes(), peer);
Ok(())
}
}
} }
/// Remove a peer by public key /// Remove a peer by public key
@@ -203,17 +230,17 @@ impl Device {
rng: &mut R, rng: &mut R,
pk: &PublicKey, pk: &PublicKey,
) -> Result<Vec<u8>, HandshakeError> { ) -> Result<Vec<u8>, HandshakeError> {
match self.pk_map.get(pk.as_bytes()) { match (self.keyst.as_ref(), self.pk_map.get(pk.as_bytes())) {
None => Err(HandshakeError::UnknownPublicKey), (_, None) => Err(HandshakeError::UnknownPublicKey),
Some(peer) => { (None, _) => Err(HandshakeError::UnknownPublicKey),
(Some(keyst), Some(peer)) => {
let sender = self.allocate(rng, peer); let sender = self.allocate(rng, peer);
let mut msg = Initiation::default(); let mut msg = Initiation::default();
noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?; // create noise part of initation
noise::create_initiation(rng, keyst, peer, sender, &mut msg.noise)?;
// add macs to initation // add macs to initation
peer.macs peer.macs
.lock() .lock()
.generate(msg.noise.as_bytes(), &mut msg.macs); .generate(msg.noise.as_bytes(), &mut msg.macs);
@@ -242,6 +269,15 @@ impl Device {
return Err(HandshakeError::InvalidMessageFormat); return Err(HandshakeError::InvalidMessageFormat);
} }
// obtain reference to key state
// if no key is configured return a noop.
let keyst = match self.keyst.as_ref() {
Some(key) => key,
None => {
return Ok((None, None, None));
}
};
// de-multiplex the message type field // de-multiplex the message type field
match LittleEndian::read_u32(msg) { match LittleEndian::read_u32(msg) {
TYPE_INITIATION => { TYPE_INITIATION => {
@@ -249,7 +285,7 @@ impl Device {
let msg = Initiation::parse(msg)?; let msg = Initiation::parse(msg)?;
// check mac1 field // check mac1 field
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
// address validation & DoS mitigation // address validation & DoS mitigation
if let Some(src) = src { if let Some(src) = src {
@@ -257,9 +293,9 @@ impl Device {
let src = src.into(); let src = src.into();
// check mac2 field // check mac2 field
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default(); let mut reply = Default::default();
self.macs.create_cookie_reply( keyst.macs.create_cookie_reply(
rng, rng,
msg.noise.f_sender.get(), msg.noise.f_sender.get(),
src, src,
@@ -276,7 +312,7 @@ impl Device {
} }
// consume the initiation // consume the initiation
let (peer, st) = noise::consume_initiation(self, &msg.noise)?; let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?;
// allocate new index for response // allocate new index for response
let sender = self.allocate(rng, peer); let sender = self.allocate(rng, peer);
@@ -304,7 +340,7 @@ impl Device {
let msg = Response::parse(msg)?; let msg = Response::parse(msg)?;
// check mac1 field // check mac1 field
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
// address validation & DoS mitigation // address validation & DoS mitigation
if let Some(src) = src { if let Some(src) = src {
@@ -312,9 +348,9 @@ impl Device {
let src = src.into(); let src = src.into();
// check mac2 field // check mac2 field
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default(); let mut reply = Default::default();
self.macs.create_cookie_reply( keyst.macs.create_cookie_reply(
rng, rng,
msg.noise.f_sender.get(), msg.noise.f_sender.get(),
src, src,
@@ -331,7 +367,7 @@ impl Device {
} }
// consume inner playload // consume inner playload
noise::consume_response(self, &msg.noise) noise::consume_response(self, keyst, &msg.noise)
} }
TYPE_COOKIE_REPLY => { TYPE_COOKIE_REPLY => {
let msg = CookieReply::parse(msg)?; let msg = CookieReply::parse(msg)?;
@@ -421,8 +457,11 @@ mod tests {
// intialize devices on both ends // intialize devices on both ends
let mut dev1 = Device::new(sk1); let mut dev1 = Device::new();
let mut dev2 = Device::new(sk2); let mut dev2 = Device::new();
dev1.set_sk(Some(sk1));
dev2.set_sk(Some(sk2));
dev1.add(pk2).unwrap(); dev1.add(pk2).unwrap();
dev2.add(pk1).unwrap(); dev2.add(pk1).unwrap();

View File

@@ -22,7 +22,7 @@ use clear_on_drop::clear_stack_on_return;
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use super::device::Device; use super::device::{Device, KeyState};
use super::messages::{NoiseInitiation, NoiseResponse}; use super::messages::{NoiseInitiation, NoiseResponse};
use super::messages::{TYPE_INITIATION, TYPE_RESPONSE}; use super::messages::{TYPE_INITIATION, TYPE_RESPONSE};
use super::peer::{Peer, State}; use super::peer::{Peer, State};
@@ -219,7 +219,7 @@ mod tests {
pub fn create_initiation<R: RngCore + CryptoRng>( pub fn create_initiation<R: RngCore + CryptoRng>(
rng: &mut R, rng: &mut R,
device: &Device, keyst: &KeyState,
peer: &Peer, peer: &Peer,
sender: u32, sender: u32,
msg: &mut NoiseInitiation, msg: &mut NoiseInitiation,
@@ -260,9 +260,9 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
SEAL!( SEAL!(
&key, &key,
&hs, // ad &hs, // ad
device.pk.as_bytes(), // pt keyst.pk.as_bytes(), // pt
&mut msg.f_static // ct || tag &mut msg.f_static // ct || tag
); );
// H := Hash(H || msg.static) // H := Hash(H || msg.static)
@@ -271,7 +271,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
// (C, k) := Kdf2(C, DH(S_priv, S_pub)) // (C, k) := Kdf2(C, DH(S_priv, S_pub))
let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); let (ck, key) = KDF2!(&ck, &peer.ss);
// msg.timestamp := Aead(k, 0, Timestamp(), H) // msg.timestamp := Aead(k, 0, Timestamp(), H)
@@ -301,6 +301,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
pub fn consume_initiation<'a>( pub fn consume_initiation<'a>(
device: &'a Device, device: &'a Device,
keyst: &KeyState,
msg: &NoiseInitiation, msg: &NoiseInitiation,
) -> Result<(&'a Peer, TemporaryState), HandshakeError> { ) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
debug!("consume initation"); debug!("consume initation");
@@ -309,7 +310,7 @@ pub fn consume_initiation<'a>(
let ck = INITIAL_CK; let ck = INITIAL_CK;
let hs = INITIAL_HS; let hs = INITIAL_HS;
let hs = HASH!(&hs, device.pk.as_bytes()); let hs = HASH!(&hs, keyst.pk.as_bytes());
// C := Kdf(C, E_pub) // C := Kdf(C, E_pub)
@@ -322,7 +323,7 @@ pub fn consume_initiation<'a>(
// (C, k) := Kdf2(C, DH(E_priv, S_pub)) // (C, k) := Kdf2(C, DH(E_priv, S_pub))
let eph_r_pk = PublicKey::from(msg.f_ephemeral); let eph_r_pk = PublicKey::from(msg.f_ephemeral);
let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes()); let (ck, key) = KDF2!(&ck, keyst.sk.diffie_hellman(&eph_r_pk).as_bytes());
// msg.static := Aead(k, 0, S_pub, H) // msg.static := Aead(k, 0, S_pub, H)
@@ -347,7 +348,7 @@ pub fn consume_initiation<'a>(
// (C, k) := Kdf2(C, DH(S_priv, S_pub)) // (C, k) := Kdf2(C, DH(S_priv, S_pub))
let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); let (ck, key) = KDF2!(&ck, &peer.ss);
// msg.timestamp := Aead(k, 0, Timestamp(), H) // msg.timestamp := Aead(k, 0, Timestamp(), H)
@@ -461,7 +462,11 @@ pub fn create_response<R: RngCore + CryptoRng>(
* allow concurrent processing of potential responses to the initiation, * allow concurrent processing of potential responses to the initiation,
* in order to better mitigate DoS from malformed response messages. * in order to better mitigate DoS from malformed response messages.
*/ */
pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> { pub fn consume_response(
device: &Device,
keyst: &KeyState,
msg: &NoiseResponse,
) -> Result<Output, HandshakeError> {
debug!("consume response"); debug!("consume response");
clear_stack_on_return(CLEAR_PAGES, || { clear_stack_on_return(CLEAR_PAGES, || {
// retrieve peer and copy initiation state // retrieve peer and copy initiation state
@@ -492,7 +497,7 @@ pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output,
// C := Kdf1(C, DH(E_priv, S_pub)) // C := Kdf1(C, DH(E_priv, S_pub))
let ck = KDF1!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes()); let ck = KDF1!(&ck, keyst.sk.diffie_hellman(&eph_r_pk).as_bytes());
// (C, tau, k) := Kdf3(C, Q) // (C, tau, k) := Kdf3(C, Q)

View File

@@ -33,9 +33,9 @@ pub struct Peer {
pub(crate) macs: Mutex<macs::Generator>, pub(crate) macs: Mutex<macs::Generator>,
// constant state // constant state
pub(crate) pk: PublicKey, // public key of peer pub(crate) pk: PublicKey, // public key of peer
pub(crate) ss: SharedSecret, // precomputed DH(static, static) pub(crate) ss: [u8; 32], // precomputed DH(static, static)
pub(crate) psk: Psk, // psk of peer pub(crate) psk: Psk, // psk of peer
} }
pub enum State { pub enum State {
@@ -62,17 +62,14 @@ impl Drop for State {
} }
impl Peer { impl Peer {
pub fn new( pub fn new(pk: PublicKey, ss: [u8; 32]) -> Self {
pk: PublicKey, // public key of peer
ss: SharedSecret, // precomputed DH(static, static)
) -> Self {
Self { Self {
macs: Mutex::new(macs::Generator::new(pk)), macs: Mutex::new(macs::Generator::new(pk)),
state: Mutex::new(State::Reset), state: Mutex::new(State::Reset),
timestamp: Mutex::new(None), timestamp: Mutex::new(None),
last_initiation_consumption: Mutex::new(None), last_initiation_consumption: Mutex::new(None),
pk: pk, pk,
ss: ss, ss,
psk: [0u8; 32], psk: [0u8; 32],
} }
} }

View File

@@ -5,6 +5,7 @@ use super::HandshakeJob;
use super::bind::Bind; use super::bind::Bind;
use super::tun::Tun; use super::tun::Tun;
use super::wireguard::WireguardInner;
use std::fmt; use std::fmt;
use std::ops::Deref; use std::ops::Deref;
@@ -19,13 +20,16 @@ use x25519_dalek::PublicKey;
pub struct Peer<T: Tun, B: Bind> { pub struct Peer<T: Tun, B: Bind> {
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>, pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
pub state: Arc<PeerInner<B>>, pub state: Arc<PeerInner<T, B>>,
} }
pub struct PeerInner<B: Bind> { pub struct PeerInner<T: Tun, B: Bind> {
// internal id (for logging) // internal id (for logging)
pub id: u64, pub id: u64,
// wireguard device state
pub wg: Arc<WireguardInner<T, B>>,
// handshake state // handshake state
pub walltime_last_handshake: Mutex<SystemTime>, pub walltime_last_handshake: Mutex<SystemTime>,
pub last_handshake_sent: Mutex<Instant>, // instant for last handshake pub last_handshake_sent: Mutex<Instant>, // instant for last handshake
@@ -50,7 +54,7 @@ impl<T: Tun, B: Bind> Clone for Peer<T, B> {
} }
} }
impl<B: Bind> PeerInner<B> { impl<T: Tun, B: Bind> PeerInner<T, B> {
#[inline(always)] #[inline(always)]
pub fn timers(&self) -> RwLockReadGuard<Timers> { pub fn timers(&self) -> RwLockReadGuard<Timers> {
self.timers.read() self.timers.read()
@@ -69,7 +73,7 @@ impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> {
} }
impl<T: Tun, B: Bind> Deref for Peer<T, B> { impl<T: Tun, B: Bind> Deref for Peer<T, B> {
type Target = PeerInner<B>; type Target = PeerInner<T, B>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.state &self.state
} }
@@ -91,28 +95,3 @@ impl<T: Tun, B: Bind> Peer<T, B> {
self.start_timers(); self.start_timers();
} }
} }
impl<B: Bind> PeerInner<B> {
/* Queue a handshake request for the parallel workers
* (if one does not already exist)
*
* The function is ratelimited.
*/
pub fn packet_send_handshake_initiation(&self) {
// the function is rate limited
{
let mut lhs = self.last_handshake_sent.lock();
if lhs.elapsed() < REKEY_TIMEOUT {
return;
}
*lhs = Instant::now();
}
// create a new handshake job for the peer
if !self.handshake_queued.swap(true, Ordering::SeqCst) {
self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
}
}
}

View File

@@ -35,7 +35,7 @@ impl Timers {
} }
} }
impl<B: bind::Bind> PeerInner<B> { impl<T: tun::Tun, B: bind::Bind> PeerInner<T, B> {
pub fn stop_timers(&self) { pub fn stop_timers(&self) {
// take a write lock preventing simultaneous timer events or "start_timers" call // take a write lock preventing simultaneous timer events or "start_timers" call
let mut timers = self.timers_mut(); let mut timers = self.timers_mut();
@@ -180,7 +180,6 @@ impl<B: bind::Bind> PeerInner<B> {
*/ */
pub fn sent_handshake_initiation(&self) { pub fn sent_handshake_initiation(&self) {
*self.last_handshake_sent.lock() = Instant::now(); *self.last_handshake_sent.lock() = Instant::now();
self.handshake_queued.store(false, Ordering::SeqCst);
self.timers_set_retransmit_handshake(); self.timers_set_retransmit_handshake();
self.timers_any_authenticated_packet_traversal(); self.timers_any_authenticated_packet_traversal();
self.timers_any_authenticated_packet_sent(); self.timers_any_authenticated_packet_sent();
@@ -333,7 +332,7 @@ impl Timers {
pub struct Events<T, B>(PhantomData<(T, B)>); pub struct Events<T, B>(PhantomData<(T, B)>);
impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> { impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
type Opaque = Arc<PeerInner<B>>; type Opaque = Arc<PeerInner<T, B>>;
/* Called after the router encrypts a transport message destined for the peer. /* Called after the router encrypts a transport message destined for the peer.
* This method is called, even if the encrypted payload is empty (keepalive) * This method is called, even if the encrypted payload is empty (keepalive)

View File

@@ -42,46 +42,56 @@ pub struct WireguardInner<T: Tun, B: Bind> {
mtu: T::MTU, mtu: T::MTU,
send: RwLock<Option<B::Writer>>, send: RwLock<Option<B::Writer>>,
// identify and configuration map // identity and configuration map
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>, peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
// cryptokey router // cryptokey router
router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>, router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
// handshake related state // handshake related state
handshake: RwLock<Handshake>, handshake: RwLock<handshake::Device>,
under_load: AtomicBool, under_load: AtomicBool,
pending: AtomicUsize, // num of pending handshake packets in queue pending: AtomicUsize, // num of pending handshake packets in queue
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
} }
impl<T: Tun, B: Bind> PeerInner<T, B> {
/* Queue a handshake request for the parallel workers
* (if one does not already exist)
*
* The function is ratelimited.
*/
pub fn packet_send_handshake_initiation(&self) {
// the function is rate limited
{
let mut lhs = self.last_handshake_sent.lock();
if lhs.elapsed() < REKEY_TIMEOUT {
return;
}
*lhs = Instant::now();
}
// create a new handshake job for the peer
if !self.handshake_queued.swap(true, Ordering::SeqCst) {
self.wg.pending.fetch_add(1, Ordering::SeqCst);
self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
}
}
}
pub enum HandshakeJob<E> { pub enum HandshakeJob<E> {
Message(Vec<u8>, E), Message(Vec<u8>, E),
New(PublicKey), New(PublicKey),
} }
#[derive(Clone)]
pub struct WireguardHandle<T: Tun, B: Bind> {
inner: Arc<WireguardInner<T, B>>,
}
impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> { impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "wireguard({:x})", self.id) write!(f, "wireguard({:x})", self.id)
} }
} }
struct Handshake {
device: handshake::Device,
active: bool,
}
impl<T: Tun, B: Bind> Deref for WireguardHandle<T, B> {
type Target = Arc<WireguardInner<T, B>>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T: Tun, B: Bind> Deref for Wireguard<T, B> { impl<T: Tun, B: Bind> Deref for Wireguard<T, B> {
type Target = Arc<WireguardInner<T, B>>; type Target = Arc<WireguardInner<T, B>>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@@ -91,7 +101,7 @@ impl<T: Tun, B: Bind> Deref for Wireguard<T, B> {
pub struct Wireguard<T: Tun, B: Bind> { pub struct Wireguard<T: Tun, B: Bind> {
runner: Runner, runner: Runner,
state: WireguardHandle<T, B>, state: Arc<WireguardInner<T, B>>,
} }
/* Returns the padded length of a message: /* Returns the padded length of a message:
@@ -181,31 +191,18 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
pub fn set_key(&self, sk: Option<StaticSecret>) { pub fn set_key(&self, sk: Option<StaticSecret>) {
let mut handshake = self.state.handshake.write(); self.handshake.write().set_sk(sk);
match sk {
None => {
let mut rng = OsRng::new().unwrap();
handshake.device.set_sk(StaticSecret::new(&mut rng));
handshake.active = false;
}
Some(sk) => {
handshake.device.set_sk(sk);
handshake.active = true;
}
}
} }
pub fn get_sk(&self) -> Option<StaticSecret> { pub fn get_sk(&self) -> Option<StaticSecret> {
let handshake = self.state.handshake.read(); self.handshake
if handshake.active { .read()
Some(handshake.device.get_sk()) .get_sk()
} else { .map(|sk| StaticSecret::from(sk.to_bytes()))
None
}
} }
pub fn set_psk(&self, pk: PublicKey, psk: Option<[u8; 32]>) -> bool { pub fn set_psk(&self, pk: PublicKey, psk: Option<[u8; 32]>) -> bool {
self.state.handshake.write().device.set_psk(pk, psk).is_ok() self.state.handshake.write().set_psk(pk, psk).is_ok()
} }
pub fn add_peer(&self, pk: PublicKey) { pub fn add_peer(&self, pk: PublicKey) {
@@ -217,6 +214,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let state = Arc::new(PeerInner { let state = Arc::new(PeerInner {
id: rng.gen(), id: rng.gen(),
pk, pk,
wg: self.state.clone(),
walltime_last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), walltime_last_handshake: Mutex::new(SystemTime::UNIX_EPOCH),
last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON), last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON),
handshake_queued: AtomicBool::new(false), handshake_queued: AtomicBool::new(false),
@@ -245,14 +243,14 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
peers.entry(*pk.as_bytes()).or_insert(peer); peers.entry(*pk.as_bytes()).or_insert(peer);
// add to the handshake device // add to the handshake device
self.state.handshake.write().device.add(pk).unwrap(); // TODO: handle adding of public key for interface self.state.handshake.write().add(pk).unwrap(); // TODO: handle adding of public key for interface
} }
/* Begin consuming messages from the reader. /// Begin consuming messages from the reader.
* /// Multiple readers can be added to support multi-queue and individual Ipv6/Ipv4 sockets interfaces
* Any previous reader thread is stopped by closing the previous reader, ///
* which unblocks the thread and causes an error on reader.read /// Any previous reader thread is stopped by closing the previous reader,
*/ /// which unblocks the thread and causes an error on reader.read
pub fn add_reader(&self, reader: B::Reader) { pub fn add_reader(&self, reader: B::Reader) {
let wg = self.state.clone(); let wg = self.state.clone();
thread::spawn(move || { thread::spawn(move || {
@@ -285,6 +283,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
| handshake::TYPE_RESPONSE => { | handshake::TYPE_RESPONSE => {
debug!("{} : reader, received handshake message", wg); debug!("{} : reader, received handshake message", wg);
// add one to pending
let pending = wg.pending.fetch_add(1, Ordering::SeqCst); let pending = wg.pending.fetch_add(1, Ordering::SeqCst);
// update under_load flag // update under_load flag
@@ -297,6 +296,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
wg.under_load.store(false, Ordering::SeqCst); wg.under_load.store(false, Ordering::SeqCst);
} }
// add to handshake queue
wg.queue wg.queue
.lock() .lock()
.send(HandshakeJob::Message(msg, src)) .send(HandshakeJob::Message(msg, src))
@@ -325,7 +325,10 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> { pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
// create device state // create device state
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().unwrap();
// handshake queue
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner { let wg = Arc::new(WireguardInner {
start: Instant::now(), start: Instant::now(),
id: rng.gen(), id: rng.gen(),
@@ -334,10 +337,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
send: RwLock::new(None), send: RwLock::new(None),
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
pending: AtomicUsize::new(0), pending: AtomicUsize::new(0),
handshake: RwLock::new(Handshake { handshake: RwLock::new(handshake::Device::new()),
device: handshake::Device::new(StaticSecret::new(&mut rng)),
active: false,
}),
under_load: AtomicBool::new(false), under_load: AtomicBool::new(false),
queue: Mutex::new(tx), queue: Mutex::new(tx),
}); });
@@ -350,24 +350,22 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
debug!("{} : handshake worker, started", wg); debug!("{} : handshake worker, started", wg);
// prepare OsRng instance for this thread // prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().expect("Unable to obtain a CSPRNG");
// process elements from the handshake queue // process elements from the handshake queue
for job in rx { for job in rx {
let state = wg.handshake.read(); // decrement pending
if !state.active { wg.pending.fetch_sub(1, Ordering::SeqCst);
continue;
} let device = wg.handshake.read();
match job { match job {
HandshakeJob::Message(msg, src) => { HandshakeJob::Message(msg, src) => {
wg.pending.fetch_sub(1, Ordering::SeqCst);
// feed message to handshake device // feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid let src_validate = (&src).into_address(); // TODO avoid
// process message // process message
match state.device.process( match device.process(
&mut rng, &mut rng,
&msg[..], &msg[..],
if wg.under_load.load(Ordering::Relaxed) { if wg.under_load.load(Ordering::Relaxed) {
@@ -428,7 +426,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// free any unused ids // free any unused ids
for id in peer.router.add_keypair(kp) { for id in peer.router.add_keypair(kp) {
state.device.release(id); device.release(id);
} }
}); });
} }
@@ -438,15 +436,19 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
} }
HandshakeJob::New(pk) => { HandshakeJob::New(pk) => {
debug!("{} : handshake worker, new handshake requested", wg); if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
let _ = state.device.begin(&mut rng, &pk).map(|msg| { debug!(
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { "{} : handshake worker, new handshake requested for {}",
wg, peer
);
let _ = device.begin(&mut rng, &peer.pk).map(|msg| {
let _ = peer.router.send(&msg[..]).map_err(|e| { let _ = peer.router.send(&msg[..]).map_err(|e| {
debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
}); });
peer.state.sent_handshake_initiation(); peer.state.sent_handshake_initiation();
} });
}); peer.handshake_queued.store(false, Ordering::SeqCst);
}
} }
} }
} }
@@ -498,7 +500,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
} }
Wireguard { Wireguard {
state: WireguardHandle { inner: wg }, state: wg,
runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY), runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),
} }
} }