Removal of secret key in the handshake module
This commit is contained in:
@@ -21,10 +21,14 @@ use super::types::*;
|
|||||||
|
|
||||||
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
|
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
|
||||||
|
|
||||||
|
pub struct KeyState {
|
||||||
|
pub sk: StaticSecret, // static secret key
|
||||||
|
pub pk: PublicKey, // static public key
|
||||||
|
macs: macs::Validator, // validator for the mac fields
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Device {
|
pub struct Device {
|
||||||
pub sk: StaticSecret, // static secret key
|
keyst: Option<KeyState>, // secret/public key
|
||||||
pub pk: PublicKey, // static public key
|
|
||||||
macs: macs::Validator, // validator for the mac fields
|
|
||||||
pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
|
pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
|
||||||
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
|
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
|
||||||
limiter: Mutex<RateLimiter>,
|
limiter: Mutex<RateLimiter>,
|
||||||
@@ -35,45 +39,68 @@ pub struct Device {
|
|||||||
*/
|
*/
|
||||||
impl Device {
|
impl Device {
|
||||||
/// Initialize a new handshake state machine
|
/// Initialize a new handshake state machine
|
||||||
///
|
pub fn new() -> Device {
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `sk` - x25519 scalar representing the local private key
|
|
||||||
pub fn new(sk: StaticSecret) -> Device {
|
|
||||||
let pk = PublicKey::from(&sk);
|
|
||||||
Device {
|
Device {
|
||||||
pk,
|
keyst: None,
|
||||||
sk,
|
|
||||||
macs: macs::Validator::new(pk),
|
|
||||||
pk_map: HashMap::new(),
|
pk_map: HashMap::new(),
|
||||||
id_map: RwLock::new(HashMap::new()),
|
id_map: RwLock::new(HashMap::new()),
|
||||||
limiter: Mutex::new(RateLimiter::new()),
|
limiter: Mutex::new(RateLimiter::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> {
|
||||||
|
if let Some(key) = self.keyst.as_ref() {
|
||||||
|
if *peer.pk.as_bytes() == *key.pk.as_bytes() {
|
||||||
|
return Some(peer.pk);
|
||||||
|
}
|
||||||
|
peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
|
||||||
|
} else {
|
||||||
|
peer.ss = [0u8; 32];
|
||||||
|
};
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
/// Update the secret key of the device
|
/// Update the secret key of the device
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `sk` - x25519 scalar representing the local private key
|
/// * `sk` - x25519 scalar representing the local private key
|
||||||
pub fn set_sk(&mut self, sk: StaticSecret) {
|
pub fn set_sk(&mut self, sk: Option<StaticSecret>) -> Option<PublicKey> {
|
||||||
// update secret and public key
|
// update secret and public key
|
||||||
let pk = PublicKey::from(&sk);
|
self.keyst = sk.map(|sk| {
|
||||||
self.sk = sk;
|
let pk = PublicKey::from(&sk);
|
||||||
self.pk = pk;
|
let macs = macs::Validator::new(pk);
|
||||||
self.macs = macs::Validator::new(pk);
|
KeyState { pk, sk, macs }
|
||||||
|
});
|
||||||
|
|
||||||
// recalculate the shared secrets for every peer
|
// recalculate / erase the shared secrets for every peer
|
||||||
let mut ids = vec![];
|
let mut ids = vec![];
|
||||||
|
let mut same = None;
|
||||||
for mut peer in self.pk_map.values_mut() {
|
for mut peer in self.pk_map.values_mut() {
|
||||||
|
// clear any existing handshake state
|
||||||
peer.reset_state().map(|id| ids.push(id));
|
peer.reset_state().map(|id| ids.push(id));
|
||||||
peer.ss = self.sk.diffie_hellman(&peer.pk)
|
|
||||||
|
// update precomputed shared secret
|
||||||
|
if let Some(key) = self.keyst.as_ref() {
|
||||||
|
peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
|
||||||
|
if *peer.pk.as_bytes() == *key.pk.as_bytes() {
|
||||||
|
same = Some(peer.pk)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
peer.ss = [0u8; 32];
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// release ids from aborted handshakes
|
// release ids from aborted handshakes
|
||||||
for id in ids {
|
for id in ids {
|
||||||
self.release(id)
|
self.release(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if we found a peer matching the device public key, remove it.
|
||||||
|
same.map(|pk| {
|
||||||
|
self.pk_map.remove(pk.as_bytes());
|
||||||
|
pk
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the secret key of the device
|
/// Return the secret key of the device
|
||||||
@@ -81,8 +108,8 @@ impl Device {
|
|||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
/// A secret key (x25519 scalar)
|
/// A secret key (x25519 scalar)
|
||||||
pub fn get_sk(&self) -> StaticSecret {
|
pub fn get_sk(&self) -> Option<&StaticSecret> {
|
||||||
StaticSecret::from(self.sk.to_bytes())
|
self.keyst.as_ref().map(|key| &key.sk)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a new public key to the state machine
|
/// Add a new public key to the state machine
|
||||||
@@ -93,28 +120,28 @@ impl Device {
|
|||||||
/// * `pk` - The public key to add
|
/// * `pk` - The public key to add
|
||||||
/// * `identifier` - Associated identifier which can be used to distinguish the peers
|
/// * `identifier` - Associated identifier which can be used to distinguish the peers
|
||||||
pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
|
pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
|
||||||
// check that the pk is not added twice
|
|
||||||
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
|
|
||||||
return Ok(());
|
|
||||||
};
|
|
||||||
|
|
||||||
// check that the pk is not that of the device
|
|
||||||
if *self.pk.as_bytes() == *pk.as_bytes() {
|
|
||||||
return Err(ConfigError::new(
|
|
||||||
"Public key corresponds to secret key of interface",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensure less than 2^20 peers
|
// ensure less than 2^20 peers
|
||||||
if self.pk_map.len() > MAX_PEER_PER_DEVICE {
|
if self.pk_map.len() > MAX_PEER_PER_DEVICE {
|
||||||
return Err(ConfigError::new("Too many peers for device"));
|
return Err(ConfigError::new("Too many peers for device"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// map the public key to the peer state
|
// create peer and precompute static secret
|
||||||
self.pk_map
|
let mut peer = Peer::new(
|
||||||
.insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk)));
|
pk,
|
||||||
|
self.keyst
|
||||||
|
.as_ref()
|
||||||
|
.map(|key| *key.sk.diffie_hellman(&pk).as_bytes())
|
||||||
|
.unwrap_or([0u8; 32]),
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
// add peer to device
|
||||||
|
match self.update_ss(&mut peer) {
|
||||||
|
Some(_) => Err(ConfigError::new("Public key of peer matches the device")),
|
||||||
|
None => {
|
||||||
|
self.pk_map.insert(*pk.as_bytes(), peer);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Remove a peer by public key
|
/// Remove a peer by public key
|
||||||
@@ -203,17 +230,17 @@ impl Device {
|
|||||||
rng: &mut R,
|
rng: &mut R,
|
||||||
pk: &PublicKey,
|
pk: &PublicKey,
|
||||||
) -> Result<Vec<u8>, HandshakeError> {
|
) -> Result<Vec<u8>, HandshakeError> {
|
||||||
match self.pk_map.get(pk.as_bytes()) {
|
match (self.keyst.as_ref(), self.pk_map.get(pk.as_bytes())) {
|
||||||
None => Err(HandshakeError::UnknownPublicKey),
|
(_, None) => Err(HandshakeError::UnknownPublicKey),
|
||||||
Some(peer) => {
|
(None, _) => Err(HandshakeError::UnknownPublicKey),
|
||||||
|
(Some(keyst), Some(peer)) => {
|
||||||
let sender = self.allocate(rng, peer);
|
let sender = self.allocate(rng, peer);
|
||||||
|
|
||||||
let mut msg = Initiation::default();
|
let mut msg = Initiation::default();
|
||||||
|
|
||||||
noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?;
|
// create noise part of initation
|
||||||
|
noise::create_initiation(rng, keyst, peer, sender, &mut msg.noise)?;
|
||||||
|
|
||||||
// add macs to initation
|
// add macs to initation
|
||||||
|
|
||||||
peer.macs
|
peer.macs
|
||||||
.lock()
|
.lock()
|
||||||
.generate(msg.noise.as_bytes(), &mut msg.macs);
|
.generate(msg.noise.as_bytes(), &mut msg.macs);
|
||||||
@@ -242,6 +269,15 @@ impl Device {
|
|||||||
return Err(HandshakeError::InvalidMessageFormat);
|
return Err(HandshakeError::InvalidMessageFormat);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// obtain reference to key state
|
||||||
|
// if no key is configured return a noop.
|
||||||
|
let keyst = match self.keyst.as_ref() {
|
||||||
|
Some(key) => key,
|
||||||
|
None => {
|
||||||
|
return Ok((None, None, None));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// de-multiplex the message type field
|
// de-multiplex the message type field
|
||||||
match LittleEndian::read_u32(msg) {
|
match LittleEndian::read_u32(msg) {
|
||||||
TYPE_INITIATION => {
|
TYPE_INITIATION => {
|
||||||
@@ -249,7 +285,7 @@ impl Device {
|
|||||||
let msg = Initiation::parse(msg)?;
|
let msg = Initiation::parse(msg)?;
|
||||||
|
|
||||||
// check mac1 field
|
// check mac1 field
|
||||||
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
||||||
|
|
||||||
// address validation & DoS mitigation
|
// address validation & DoS mitigation
|
||||||
if let Some(src) = src {
|
if let Some(src) = src {
|
||||||
@@ -257,9 +293,9 @@ impl Device {
|
|||||||
let src = src.into();
|
let src = src.into();
|
||||||
|
|
||||||
// check mac2 field
|
// check mac2 field
|
||||||
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
||||||
let mut reply = Default::default();
|
let mut reply = Default::default();
|
||||||
self.macs.create_cookie_reply(
|
keyst.macs.create_cookie_reply(
|
||||||
rng,
|
rng,
|
||||||
msg.noise.f_sender.get(),
|
msg.noise.f_sender.get(),
|
||||||
src,
|
src,
|
||||||
@@ -276,7 +312,7 @@ impl Device {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// consume the initiation
|
// consume the initiation
|
||||||
let (peer, st) = noise::consume_initiation(self, &msg.noise)?;
|
let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?;
|
||||||
|
|
||||||
// allocate new index for response
|
// allocate new index for response
|
||||||
let sender = self.allocate(rng, peer);
|
let sender = self.allocate(rng, peer);
|
||||||
@@ -304,7 +340,7 @@ impl Device {
|
|||||||
let msg = Response::parse(msg)?;
|
let msg = Response::parse(msg)?;
|
||||||
|
|
||||||
// check mac1 field
|
// check mac1 field
|
||||||
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
||||||
|
|
||||||
// address validation & DoS mitigation
|
// address validation & DoS mitigation
|
||||||
if let Some(src) = src {
|
if let Some(src) = src {
|
||||||
@@ -312,9 +348,9 @@ impl Device {
|
|||||||
let src = src.into();
|
let src = src.into();
|
||||||
|
|
||||||
// check mac2 field
|
// check mac2 field
|
||||||
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
||||||
let mut reply = Default::default();
|
let mut reply = Default::default();
|
||||||
self.macs.create_cookie_reply(
|
keyst.macs.create_cookie_reply(
|
||||||
rng,
|
rng,
|
||||||
msg.noise.f_sender.get(),
|
msg.noise.f_sender.get(),
|
||||||
src,
|
src,
|
||||||
@@ -331,7 +367,7 @@ impl Device {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// consume inner playload
|
// consume inner playload
|
||||||
noise::consume_response(self, &msg.noise)
|
noise::consume_response(self, keyst, &msg.noise)
|
||||||
}
|
}
|
||||||
TYPE_COOKIE_REPLY => {
|
TYPE_COOKIE_REPLY => {
|
||||||
let msg = CookieReply::parse(msg)?;
|
let msg = CookieReply::parse(msg)?;
|
||||||
@@ -421,8 +457,11 @@ mod tests {
|
|||||||
|
|
||||||
// intialize devices on both ends
|
// intialize devices on both ends
|
||||||
|
|
||||||
let mut dev1 = Device::new(sk1);
|
let mut dev1 = Device::new();
|
||||||
let mut dev2 = Device::new(sk2);
|
let mut dev2 = Device::new();
|
||||||
|
|
||||||
|
dev1.set_sk(Some(sk1));
|
||||||
|
dev2.set_sk(Some(sk2));
|
||||||
|
|
||||||
dev1.add(pk2).unwrap();
|
dev1.add(pk2).unwrap();
|
||||||
dev2.add(pk1).unwrap();
|
dev2.add(pk1).unwrap();
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ use clear_on_drop::clear_stack_on_return;
|
|||||||
|
|
||||||
use subtle::ConstantTimeEq;
|
use subtle::ConstantTimeEq;
|
||||||
|
|
||||||
use super::device::Device;
|
use super::device::{Device, KeyState};
|
||||||
use super::messages::{NoiseInitiation, NoiseResponse};
|
use super::messages::{NoiseInitiation, NoiseResponse};
|
||||||
use super::messages::{TYPE_INITIATION, TYPE_RESPONSE};
|
use super::messages::{TYPE_INITIATION, TYPE_RESPONSE};
|
||||||
use super::peer::{Peer, State};
|
use super::peer::{Peer, State};
|
||||||
@@ -219,7 +219,7 @@ mod tests {
|
|||||||
|
|
||||||
pub fn create_initiation<R: RngCore + CryptoRng>(
|
pub fn create_initiation<R: RngCore + CryptoRng>(
|
||||||
rng: &mut R,
|
rng: &mut R,
|
||||||
device: &Device,
|
keyst: &KeyState,
|
||||||
peer: &Peer,
|
peer: &Peer,
|
||||||
sender: u32,
|
sender: u32,
|
||||||
msg: &mut NoiseInitiation,
|
msg: &mut NoiseInitiation,
|
||||||
@@ -260,9 +260,9 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
|
|||||||
|
|
||||||
SEAL!(
|
SEAL!(
|
||||||
&key,
|
&key,
|
||||||
&hs, // ad
|
&hs, // ad
|
||||||
device.pk.as_bytes(), // pt
|
keyst.pk.as_bytes(), // pt
|
||||||
&mut msg.f_static // ct || tag
|
&mut msg.f_static // ct || tag
|
||||||
);
|
);
|
||||||
|
|
||||||
// H := Hash(H || msg.static)
|
// H := Hash(H || msg.static)
|
||||||
@@ -271,7 +271,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
|
|||||||
|
|
||||||
// (C, k) := Kdf2(C, DH(S_priv, S_pub))
|
// (C, k) := Kdf2(C, DH(S_priv, S_pub))
|
||||||
|
|
||||||
let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
|
let (ck, key) = KDF2!(&ck, &peer.ss);
|
||||||
|
|
||||||
// msg.timestamp := Aead(k, 0, Timestamp(), H)
|
// msg.timestamp := Aead(k, 0, Timestamp(), H)
|
||||||
|
|
||||||
@@ -301,6 +301,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
|
|||||||
|
|
||||||
pub fn consume_initiation<'a>(
|
pub fn consume_initiation<'a>(
|
||||||
device: &'a Device,
|
device: &'a Device,
|
||||||
|
keyst: &KeyState,
|
||||||
msg: &NoiseInitiation,
|
msg: &NoiseInitiation,
|
||||||
) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
|
) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
|
||||||
debug!("consume initation");
|
debug!("consume initation");
|
||||||
@@ -309,7 +310,7 @@ pub fn consume_initiation<'a>(
|
|||||||
|
|
||||||
let ck = INITIAL_CK;
|
let ck = INITIAL_CK;
|
||||||
let hs = INITIAL_HS;
|
let hs = INITIAL_HS;
|
||||||
let hs = HASH!(&hs, device.pk.as_bytes());
|
let hs = HASH!(&hs, keyst.pk.as_bytes());
|
||||||
|
|
||||||
// C := Kdf(C, E_pub)
|
// C := Kdf(C, E_pub)
|
||||||
|
|
||||||
@@ -322,7 +323,7 @@ pub fn consume_initiation<'a>(
|
|||||||
// (C, k) := Kdf2(C, DH(E_priv, S_pub))
|
// (C, k) := Kdf2(C, DH(E_priv, S_pub))
|
||||||
|
|
||||||
let eph_r_pk = PublicKey::from(msg.f_ephemeral);
|
let eph_r_pk = PublicKey::from(msg.f_ephemeral);
|
||||||
let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
|
let (ck, key) = KDF2!(&ck, keyst.sk.diffie_hellman(&eph_r_pk).as_bytes());
|
||||||
|
|
||||||
// msg.static := Aead(k, 0, S_pub, H)
|
// msg.static := Aead(k, 0, S_pub, H)
|
||||||
|
|
||||||
@@ -347,7 +348,7 @@ pub fn consume_initiation<'a>(
|
|||||||
|
|
||||||
// (C, k) := Kdf2(C, DH(S_priv, S_pub))
|
// (C, k) := Kdf2(C, DH(S_priv, S_pub))
|
||||||
|
|
||||||
let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
|
let (ck, key) = KDF2!(&ck, &peer.ss);
|
||||||
|
|
||||||
// msg.timestamp := Aead(k, 0, Timestamp(), H)
|
// msg.timestamp := Aead(k, 0, Timestamp(), H)
|
||||||
|
|
||||||
@@ -461,7 +462,11 @@ pub fn create_response<R: RngCore + CryptoRng>(
|
|||||||
* allow concurrent processing of potential responses to the initiation,
|
* allow concurrent processing of potential responses to the initiation,
|
||||||
* in order to better mitigate DoS from malformed response messages.
|
* in order to better mitigate DoS from malformed response messages.
|
||||||
*/
|
*/
|
||||||
pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
|
pub fn consume_response(
|
||||||
|
device: &Device,
|
||||||
|
keyst: &KeyState,
|
||||||
|
msg: &NoiseResponse,
|
||||||
|
) -> Result<Output, HandshakeError> {
|
||||||
debug!("consume response");
|
debug!("consume response");
|
||||||
clear_stack_on_return(CLEAR_PAGES, || {
|
clear_stack_on_return(CLEAR_PAGES, || {
|
||||||
// retrieve peer and copy initiation state
|
// retrieve peer and copy initiation state
|
||||||
@@ -492,7 +497,7 @@ pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output,
|
|||||||
|
|
||||||
// C := Kdf1(C, DH(E_priv, S_pub))
|
// C := Kdf1(C, DH(E_priv, S_pub))
|
||||||
|
|
||||||
let ck = KDF1!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
|
let ck = KDF1!(&ck, keyst.sk.diffie_hellman(&eph_r_pk).as_bytes());
|
||||||
|
|
||||||
// (C, tau, k) := Kdf3(C, Q)
|
// (C, tau, k) := Kdf3(C, Q)
|
||||||
|
|
||||||
|
|||||||
@@ -33,9 +33,9 @@ pub struct Peer {
|
|||||||
pub(crate) macs: Mutex<macs::Generator>,
|
pub(crate) macs: Mutex<macs::Generator>,
|
||||||
|
|
||||||
// constant state
|
// constant state
|
||||||
pub(crate) pk: PublicKey, // public key of peer
|
pub(crate) pk: PublicKey, // public key of peer
|
||||||
pub(crate) ss: SharedSecret, // precomputed DH(static, static)
|
pub(crate) ss: [u8; 32], // precomputed DH(static, static)
|
||||||
pub(crate) psk: Psk, // psk of peer
|
pub(crate) psk: Psk, // psk of peer
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum State {
|
pub enum State {
|
||||||
@@ -62,17 +62,14 @@ impl Drop for State {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Peer {
|
impl Peer {
|
||||||
pub fn new(
|
pub fn new(pk: PublicKey, ss: [u8; 32]) -> Self {
|
||||||
pk: PublicKey, // public key of peer
|
|
||||||
ss: SharedSecret, // precomputed DH(static, static)
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
macs: Mutex::new(macs::Generator::new(pk)),
|
macs: Mutex::new(macs::Generator::new(pk)),
|
||||||
state: Mutex::new(State::Reset),
|
state: Mutex::new(State::Reset),
|
||||||
timestamp: Mutex::new(None),
|
timestamp: Mutex::new(None),
|
||||||
last_initiation_consumption: Mutex::new(None),
|
last_initiation_consumption: Mutex::new(None),
|
||||||
pk: pk,
|
pk,
|
||||||
ss: ss,
|
ss,
|
||||||
psk: [0u8; 32],
|
psk: [0u8; 32],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ use super::HandshakeJob;
|
|||||||
|
|
||||||
use super::bind::Bind;
|
use super::bind::Bind;
|
||||||
use super::tun::Tun;
|
use super::tun::Tun;
|
||||||
|
use super::wireguard::WireguardInner;
|
||||||
|
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
@@ -19,13 +20,16 @@ use x25519_dalek::PublicKey;
|
|||||||
|
|
||||||
pub struct Peer<T: Tun, B: Bind> {
|
pub struct Peer<T: Tun, B: Bind> {
|
||||||
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
|
pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
|
||||||
pub state: Arc<PeerInner<B>>,
|
pub state: Arc<PeerInner<T, B>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct PeerInner<B: Bind> {
|
pub struct PeerInner<T: Tun, B: Bind> {
|
||||||
// internal id (for logging)
|
// internal id (for logging)
|
||||||
pub id: u64,
|
pub id: u64,
|
||||||
|
|
||||||
|
// wireguard device state
|
||||||
|
pub wg: Arc<WireguardInner<T, B>>,
|
||||||
|
|
||||||
// handshake state
|
// handshake state
|
||||||
pub walltime_last_handshake: Mutex<SystemTime>,
|
pub walltime_last_handshake: Mutex<SystemTime>,
|
||||||
pub last_handshake_sent: Mutex<Instant>, // instant for last handshake
|
pub last_handshake_sent: Mutex<Instant>, // instant for last handshake
|
||||||
@@ -50,7 +54,7 @@ impl<T: Tun, B: Bind> Clone for Peer<T, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Bind> PeerInner<B> {
|
impl<T: Tun, B: Bind> PeerInner<T, B> {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub fn timers(&self) -> RwLockReadGuard<Timers> {
|
pub fn timers(&self) -> RwLockReadGuard<Timers> {
|
||||||
self.timers.read()
|
self.timers.read()
|
||||||
@@ -69,7 +73,7 @@ impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Tun, B: Bind> Deref for Peer<T, B> {
|
impl<T: Tun, B: Bind> Deref for Peer<T, B> {
|
||||||
type Target = PeerInner<B>;
|
type Target = PeerInner<T, B>;
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.state
|
&self.state
|
||||||
}
|
}
|
||||||
@@ -91,28 +95,3 @@ impl<T: Tun, B: Bind> Peer<T, B> {
|
|||||||
self.start_timers();
|
self.start_timers();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Bind> PeerInner<B> {
|
|
||||||
/* Queue a handshake request for the parallel workers
|
|
||||||
* (if one does not already exist)
|
|
||||||
*
|
|
||||||
* The function is ratelimited.
|
|
||||||
*/
|
|
||||||
pub fn packet_send_handshake_initiation(&self) {
|
|
||||||
// the function is rate limited
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut lhs = self.last_handshake_sent.lock();
|
|
||||||
if lhs.elapsed() < REKEY_TIMEOUT {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
*lhs = Instant::now();
|
|
||||||
}
|
|
||||||
|
|
||||||
// create a new handshake job for the peer
|
|
||||||
|
|
||||||
if !self.handshake_queued.swap(true, Ordering::SeqCst) {
|
|
||||||
self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ impl Timers {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: bind::Bind> PeerInner<B> {
|
impl<T: tun::Tun, B: bind::Bind> PeerInner<T, B> {
|
||||||
pub fn stop_timers(&self) {
|
pub fn stop_timers(&self) {
|
||||||
// take a write lock preventing simultaneous timer events or "start_timers" call
|
// take a write lock preventing simultaneous timer events or "start_timers" call
|
||||||
let mut timers = self.timers_mut();
|
let mut timers = self.timers_mut();
|
||||||
@@ -180,7 +180,6 @@ impl<B: bind::Bind> PeerInner<B> {
|
|||||||
*/
|
*/
|
||||||
pub fn sent_handshake_initiation(&self) {
|
pub fn sent_handshake_initiation(&self) {
|
||||||
*self.last_handshake_sent.lock() = Instant::now();
|
*self.last_handshake_sent.lock() = Instant::now();
|
||||||
self.handshake_queued.store(false, Ordering::SeqCst);
|
|
||||||
self.timers_set_retransmit_handshake();
|
self.timers_set_retransmit_handshake();
|
||||||
self.timers_any_authenticated_packet_traversal();
|
self.timers_any_authenticated_packet_traversal();
|
||||||
self.timers_any_authenticated_packet_sent();
|
self.timers_any_authenticated_packet_sent();
|
||||||
@@ -333,7 +332,7 @@ impl Timers {
|
|||||||
pub struct Events<T, B>(PhantomData<(T, B)>);
|
pub struct Events<T, B>(PhantomData<(T, B)>);
|
||||||
|
|
||||||
impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
|
impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
|
||||||
type Opaque = Arc<PeerInner<B>>;
|
type Opaque = Arc<PeerInner<T, B>>;
|
||||||
|
|
||||||
/* Called after the router encrypts a transport message destined for the peer.
|
/* Called after the router encrypts a transport message destined for the peer.
|
||||||
* This method is called, even if the encrypted payload is empty (keepalive)
|
* This method is called, even if the encrypted payload is empty (keepalive)
|
||||||
|
|||||||
@@ -42,46 +42,56 @@ pub struct WireguardInner<T: Tun, B: Bind> {
|
|||||||
mtu: T::MTU,
|
mtu: T::MTU,
|
||||||
send: RwLock<Option<B::Writer>>,
|
send: RwLock<Option<B::Writer>>,
|
||||||
|
|
||||||
// identify and configuration map
|
// identity and configuration map
|
||||||
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
|
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
|
||||||
|
|
||||||
// cryptokey router
|
// cryptokey router
|
||||||
router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
|
router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
|
||||||
|
|
||||||
// handshake related state
|
// handshake related state
|
||||||
handshake: RwLock<Handshake>,
|
handshake: RwLock<handshake::Device>,
|
||||||
under_load: AtomicBool,
|
under_load: AtomicBool,
|
||||||
pending: AtomicUsize, // num of pending handshake packets in queue
|
pending: AtomicUsize, // num of pending handshake packets in queue
|
||||||
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
|
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: Tun, B: Bind> PeerInner<T, B> {
|
||||||
|
/* Queue a handshake request for the parallel workers
|
||||||
|
* (if one does not already exist)
|
||||||
|
*
|
||||||
|
* The function is ratelimited.
|
||||||
|
*/
|
||||||
|
pub fn packet_send_handshake_initiation(&self) {
|
||||||
|
// the function is rate limited
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut lhs = self.last_handshake_sent.lock();
|
||||||
|
if lhs.elapsed() < REKEY_TIMEOUT {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
*lhs = Instant::now();
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a new handshake job for the peer
|
||||||
|
|
||||||
|
if !self.handshake_queued.swap(true, Ordering::SeqCst) {
|
||||||
|
self.wg.pending.fetch_add(1, Ordering::SeqCst);
|
||||||
|
self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub enum HandshakeJob<E> {
|
pub enum HandshakeJob<E> {
|
||||||
Message(Vec<u8>, E),
|
Message(Vec<u8>, E),
|
||||||
New(PublicKey),
|
New(PublicKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct WireguardHandle<T: Tun, B: Bind> {
|
|
||||||
inner: Arc<WireguardInner<T, B>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> {
|
impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
write!(f, "wireguard({:x})", self.id)
|
write!(f, "wireguard({:x})", self.id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Handshake {
|
|
||||||
device: handshake::Device,
|
|
||||||
active: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Tun, B: Bind> Deref for WireguardHandle<T, B> {
|
|
||||||
type Target = Arc<WireguardInner<T, B>>;
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl<T: Tun, B: Bind> Deref for Wireguard<T, B> {
|
impl<T: Tun, B: Bind> Deref for Wireguard<T, B> {
|
||||||
type Target = Arc<WireguardInner<T, B>>;
|
type Target = Arc<WireguardInner<T, B>>;
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
@@ -91,7 +101,7 @@ impl<T: Tun, B: Bind> Deref for Wireguard<T, B> {
|
|||||||
|
|
||||||
pub struct Wireguard<T: Tun, B: Bind> {
|
pub struct Wireguard<T: Tun, B: Bind> {
|
||||||
runner: Runner,
|
runner: Runner,
|
||||||
state: WireguardHandle<T, B>,
|
state: Arc<WireguardInner<T, B>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Returns the padded length of a message:
|
/* Returns the padded length of a message:
|
||||||
@@ -181,31 +191,18 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_key(&self, sk: Option<StaticSecret>) {
|
pub fn set_key(&self, sk: Option<StaticSecret>) {
|
||||||
let mut handshake = self.state.handshake.write();
|
self.handshake.write().set_sk(sk);
|
||||||
match sk {
|
|
||||||
None => {
|
|
||||||
let mut rng = OsRng::new().unwrap();
|
|
||||||
handshake.device.set_sk(StaticSecret::new(&mut rng));
|
|
||||||
handshake.active = false;
|
|
||||||
}
|
|
||||||
Some(sk) => {
|
|
||||||
handshake.device.set_sk(sk);
|
|
||||||
handshake.active = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_sk(&self) -> Option<StaticSecret> {
|
pub fn get_sk(&self) -> Option<StaticSecret> {
|
||||||
let handshake = self.state.handshake.read();
|
self.handshake
|
||||||
if handshake.active {
|
.read()
|
||||||
Some(handshake.device.get_sk())
|
.get_sk()
|
||||||
} else {
|
.map(|sk| StaticSecret::from(sk.to_bytes()))
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_psk(&self, pk: PublicKey, psk: Option<[u8; 32]>) -> bool {
|
pub fn set_psk(&self, pk: PublicKey, psk: Option<[u8; 32]>) -> bool {
|
||||||
self.state.handshake.write().device.set_psk(pk, psk).is_ok()
|
self.state.handshake.write().set_psk(pk, psk).is_ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_peer(&self, pk: PublicKey) {
|
pub fn add_peer(&self, pk: PublicKey) {
|
||||||
@@ -217,6 +214,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
let state = Arc::new(PeerInner {
|
let state = Arc::new(PeerInner {
|
||||||
id: rng.gen(),
|
id: rng.gen(),
|
||||||
pk,
|
pk,
|
||||||
|
wg: self.state.clone(),
|
||||||
walltime_last_handshake: Mutex::new(SystemTime::UNIX_EPOCH),
|
walltime_last_handshake: Mutex::new(SystemTime::UNIX_EPOCH),
|
||||||
last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON),
|
last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON),
|
||||||
handshake_queued: AtomicBool::new(false),
|
handshake_queued: AtomicBool::new(false),
|
||||||
@@ -245,14 +243,14 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
peers.entry(*pk.as_bytes()).or_insert(peer);
|
peers.entry(*pk.as_bytes()).or_insert(peer);
|
||||||
|
|
||||||
// add to the handshake device
|
// add to the handshake device
|
||||||
self.state.handshake.write().device.add(pk).unwrap(); // TODO: handle adding of public key for interface
|
self.state.handshake.write().add(pk).unwrap(); // TODO: handle adding of public key for interface
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Begin consuming messages from the reader.
|
/// Begin consuming messages from the reader.
|
||||||
*
|
/// Multiple readers can be added to support multi-queue and individual Ipv6/Ipv4 sockets interfaces
|
||||||
* Any previous reader thread is stopped by closing the previous reader,
|
///
|
||||||
* which unblocks the thread and causes an error on reader.read
|
/// Any previous reader thread is stopped by closing the previous reader,
|
||||||
*/
|
/// which unblocks the thread and causes an error on reader.read
|
||||||
pub fn add_reader(&self, reader: B::Reader) {
|
pub fn add_reader(&self, reader: B::Reader) {
|
||||||
let wg = self.state.clone();
|
let wg = self.state.clone();
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
@@ -285,6 +283,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
| handshake::TYPE_RESPONSE => {
|
| handshake::TYPE_RESPONSE => {
|
||||||
debug!("{} : reader, received handshake message", wg);
|
debug!("{} : reader, received handshake message", wg);
|
||||||
|
|
||||||
|
// add one to pending
|
||||||
let pending = wg.pending.fetch_add(1, Ordering::SeqCst);
|
let pending = wg.pending.fetch_add(1, Ordering::SeqCst);
|
||||||
|
|
||||||
// update under_load flag
|
// update under_load flag
|
||||||
@@ -297,6 +296,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
wg.under_load.store(false, Ordering::SeqCst);
|
wg.under_load.store(false, Ordering::SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add to handshake queue
|
||||||
wg.queue
|
wg.queue
|
||||||
.lock()
|
.lock()
|
||||||
.send(HandshakeJob::Message(msg, src))
|
.send(HandshakeJob::Message(msg, src))
|
||||||
@@ -325,7 +325,10 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
|
pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
|
||||||
// create device state
|
// create device state
|
||||||
let mut rng = OsRng::new().unwrap();
|
let mut rng = OsRng::new().unwrap();
|
||||||
|
|
||||||
|
// handshake queue
|
||||||
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||||
|
|
||||||
let wg = Arc::new(WireguardInner {
|
let wg = Arc::new(WireguardInner {
|
||||||
start: Instant::now(),
|
start: Instant::now(),
|
||||||
id: rng.gen(),
|
id: rng.gen(),
|
||||||
@@ -334,10 +337,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
send: RwLock::new(None),
|
send: RwLock::new(None),
|
||||||
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
|
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
|
||||||
pending: AtomicUsize::new(0),
|
pending: AtomicUsize::new(0),
|
||||||
handshake: RwLock::new(Handshake {
|
handshake: RwLock::new(handshake::Device::new()),
|
||||||
device: handshake::Device::new(StaticSecret::new(&mut rng)),
|
|
||||||
active: false,
|
|
||||||
}),
|
|
||||||
under_load: AtomicBool::new(false),
|
under_load: AtomicBool::new(false),
|
||||||
queue: Mutex::new(tx),
|
queue: Mutex::new(tx),
|
||||||
});
|
});
|
||||||
@@ -350,24 +350,22 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
debug!("{} : handshake worker, started", wg);
|
debug!("{} : handshake worker, started", wg);
|
||||||
|
|
||||||
// prepare OsRng instance for this thread
|
// prepare OsRng instance for this thread
|
||||||
let mut rng = OsRng::new().unwrap();
|
let mut rng = OsRng::new().expect("Unable to obtain a CSPRNG");
|
||||||
|
|
||||||
// process elements from the handshake queue
|
// process elements from the handshake queue
|
||||||
for job in rx {
|
for job in rx {
|
||||||
let state = wg.handshake.read();
|
// decrement pending
|
||||||
if !state.active {
|
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||||
continue;
|
|
||||||
}
|
let device = wg.handshake.read();
|
||||||
|
|
||||||
match job {
|
match job {
|
||||||
HandshakeJob::Message(msg, src) => {
|
HandshakeJob::Message(msg, src) => {
|
||||||
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
|
||||||
|
|
||||||
// feed message to handshake device
|
// feed message to handshake device
|
||||||
let src_validate = (&src).into_address(); // TODO avoid
|
let src_validate = (&src).into_address(); // TODO avoid
|
||||||
|
|
||||||
// process message
|
// process message
|
||||||
match state.device.process(
|
match device.process(
|
||||||
&mut rng,
|
&mut rng,
|
||||||
&msg[..],
|
&msg[..],
|
||||||
if wg.under_load.load(Ordering::Relaxed) {
|
if wg.under_load.load(Ordering::Relaxed) {
|
||||||
@@ -428,7 +426,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
|
|
||||||
// free any unused ids
|
// free any unused ids
|
||||||
for id in peer.router.add_keypair(kp) {
|
for id in peer.router.add_keypair(kp) {
|
||||||
state.device.release(id);
|
device.release(id);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -438,15 +436,19 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
HandshakeJob::New(pk) => {
|
HandshakeJob::New(pk) => {
|
||||||
debug!("{} : handshake worker, new handshake requested", wg);
|
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||||
let _ = state.device.begin(&mut rng, &pk).map(|msg| {
|
debug!(
|
||||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
"{} : handshake worker, new handshake requested for {}",
|
||||||
|
wg, peer
|
||||||
|
);
|
||||||
|
let _ = device.begin(&mut rng, &peer.pk).map(|msg| {
|
||||||
let _ = peer.router.send(&msg[..]).map_err(|e| {
|
let _ = peer.router.send(&msg[..]).map_err(|e| {
|
||||||
debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
|
debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
|
||||||
});
|
});
|
||||||
peer.state.sent_handshake_initiation();
|
peer.state.sent_handshake_initiation();
|
||||||
}
|
});
|
||||||
});
|
peer.handshake_queued.store(false, Ordering::SeqCst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -498,7 +500,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Wireguard {
|
Wireguard {
|
||||||
state: WireguardHandle { inner: wg },
|
state: wg,
|
||||||
runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),
|
runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user