Added initiation flood protection
This commit is contained in:
@@ -17,6 +17,8 @@ use super::peer::Peer;
|
|||||||
use super::ratelimiter::RateLimiter;
|
use super::ratelimiter::RateLimiter;
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
|
|
||||||
|
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
|
||||||
|
|
||||||
pub struct Device<T> {
|
pub struct Device<T> {
|
||||||
pub sk: StaticSecret, // static secret key
|
pub sk: StaticSecret, // static secret key
|
||||||
pub pk: PublicKey, // static public key
|
pub pk: PublicKey, // static public key
|
||||||
@@ -59,21 +61,23 @@ where
|
|||||||
/// * `identifier` - Associated identifier which can be used to distinguish the peers
|
/// * `identifier` - Associated identifier which can be used to distinguish the peers
|
||||||
pub fn add(&mut self, pk: PublicKey, identifier: T) -> Result<(), ConfigError> {
|
pub fn add(&mut self, pk: PublicKey, identifier: T) -> Result<(), ConfigError> {
|
||||||
// check that the pk is not added twice
|
// check that the pk is not added twice
|
||||||
|
|
||||||
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
|
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
|
||||||
return Err(ConfigError::new("Duplicate public key"));
|
return Err(ConfigError::new("Duplicate public key"));
|
||||||
};
|
};
|
||||||
|
|
||||||
// check that the pk is not that of the device
|
// check that the pk is not that of the device
|
||||||
|
|
||||||
if *self.pk.as_bytes() == *pk.as_bytes() {
|
if *self.pk.as_bytes() == *pk.as_bytes() {
|
||||||
return Err(ConfigError::new(
|
return Err(ConfigError::new(
|
||||||
"Public key corresponds to secret key of interface",
|
"Public key corresponds to secret key of interface",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// map : pk -> new index
|
// ensure less than 2^20 peers
|
||||||
|
if self.pk_map.len() > MAX_PEER_PER_DEVICE {
|
||||||
|
return Err(ConfigError::new("Too many peers for device"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// map the public key to the peer state
|
||||||
self.pk_map.insert(
|
self.pk_map.insert(
|
||||||
*pk.as_bytes(),
|
*pk.as_bytes(),
|
||||||
Peer::new(identifier, pk, self.sk.diffie_hellman(&pk)),
|
Peer::new(identifier, pk, self.sk.diffie_hellman(&pk)),
|
||||||
@@ -353,6 +357,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use hex;
|
use hex;
|
||||||
use rand::rngs::OsRng;
|
use rand::rngs::OsRng;
|
||||||
|
use std::thread;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn handshake() {
|
fn handshake() {
|
||||||
@@ -419,6 +425,9 @@ mod tests {
|
|||||||
|
|
||||||
dev1.release(ks_i.send.id);
|
dev1.release(ks_i.send.id);
|
||||||
dev2.release(ks_r.send.id);
|
dev2.release(ks_r.send.id);
|
||||||
|
|
||||||
|
// to avoid flood detection
|
||||||
|
thread::sleep(Duration::from_millis(20));
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(dev1.get_psk(pk2).unwrap(), psk);
|
assert_eq!(dev1.get_psk(pk2).unwrap(), psk);
|
||||||
|
|||||||
@@ -306,7 +306,7 @@ pub fn consume_initiation<'a, T: Copy>(
|
|||||||
|
|
||||||
// check and update timestamp
|
// check and update timestamp
|
||||||
|
|
||||||
peer.check_timestamp(device, &ts)?;
|
peer.check_replay_flood(device, &ts)?;
|
||||||
|
|
||||||
// H := Hash(H || msg.timestamp)
|
// H := Hash(H || msg.timestamp)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
|
use lazy_static::lazy_static;
|
||||||
use spin::Mutex;
|
use spin::Mutex;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use generic_array::typenum::U32;
|
use generic_array::typenum::U32;
|
||||||
use generic_array::GenericArray;
|
use generic_array::GenericArray;
|
||||||
@@ -8,15 +10,18 @@ use x25519_dalek::SharedSecret;
|
|||||||
use x25519_dalek::StaticSecret;
|
use x25519_dalek::StaticSecret;
|
||||||
|
|
||||||
use super::device::Device;
|
use super::device::Device;
|
||||||
|
use super::macs;
|
||||||
use super::timestamp;
|
use super::timestamp;
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
use super::macs;
|
|
||||||
|
lazy_static! {
|
||||||
|
pub static ref TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20);
|
||||||
|
}
|
||||||
|
|
||||||
/* Represents the recomputation and state of a peer.
|
/* Represents the recomputation and state of a peer.
|
||||||
*
|
*
|
||||||
* This type is only for internal use and not exposed.
|
* This type is only for internal use and not exposed.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
pub struct Peer<T> {
|
pub struct Peer<T> {
|
||||||
// external identifier
|
// external identifier
|
||||||
pub(crate) identifier: T,
|
pub(crate) identifier: T,
|
||||||
@@ -24,6 +29,7 @@ pub struct Peer<T> {
|
|||||||
// mutable state
|
// mutable state
|
||||||
state: Mutex<State>,
|
state: Mutex<State>,
|
||||||
timestamp: Mutex<Option<timestamp::TAI64N>>,
|
timestamp: Mutex<Option<timestamp::TAI64N>>,
|
||||||
|
last_initiation_consumption: Mutex<Option<Instant>>,
|
||||||
|
|
||||||
// state related to DoS mitigation fields
|
// state related to DoS mitigation fields
|
||||||
pub(crate) macs: Mutex<macs::Generator>,
|
pub(crate) macs: Mutex<macs::Generator>,
|
||||||
@@ -77,6 +83,7 @@ where
|
|||||||
identifier: identifier,
|
identifier: identifier,
|
||||||
state: Mutex::new(State::Reset),
|
state: Mutex::new(State::Reset),
|
||||||
timestamp: Mutex::new(None),
|
timestamp: Mutex::new(None),
|
||||||
|
last_initiation_consumption: Mutex::new(None),
|
||||||
pk: pk,
|
pk: pk,
|
||||||
ss: ss,
|
ss: ss,
|
||||||
psk: [0u8; 32],
|
psk: [0u8; 32],
|
||||||
@@ -104,38 +111,45 @@ where
|
|||||||
///
|
///
|
||||||
/// * st_new - The updated state of the peer
|
/// * st_new - The updated state of the peer
|
||||||
/// * ts_new - The associated timestamp
|
/// * ts_new - The associated timestamp
|
||||||
pub fn check_timestamp(
|
pub fn check_replay_flood(
|
||||||
&self,
|
&self,
|
||||||
device: &Device<T>,
|
device: &Device<T>,
|
||||||
timestamp_new: ×tamp::TAI64N,
|
timestamp_new: ×tamp::TAI64N,
|
||||||
) -> Result<(), HandshakeError> {
|
) -> Result<(), HandshakeError> {
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let mut timestamp = self.timestamp.lock();
|
let mut timestamp = self.timestamp.lock();
|
||||||
|
let mut last_initiation_consumption = self.last_initiation_consumption.lock();
|
||||||
|
|
||||||
let update = match *timestamp {
|
// check replay attack
|
||||||
None => true,
|
match *timestamp {
|
||||||
Some(timestamp_old) => {
|
Some(timestamp_old) => {
|
||||||
if timestamp::compare(×tamp_old, ×tamp_new) {
|
if !timestamp::compare(×tamp_old, ×tamp_new) {
|
||||||
true
|
return Err(HandshakeError::OldTimestamp);
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ => (),
|
||||||
};
|
};
|
||||||
|
|
||||||
if update {
|
// check flood attack
|
||||||
// release existing identifier
|
match *last_initiation_consumption {
|
||||||
match *state {
|
Some(last) => {
|
||||||
State::InitiationSent { sender, .. } => device.release(sender),
|
if last.elapsed() < *TIME_BETWEEN_INITIATIONS {
|
||||||
_ => (),
|
return Err(HandshakeError::InitiationFlood);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
_ => (),
|
||||||
// reset state and update timestamp
|
|
||||||
*state = State::Reset;
|
|
||||||
*timestamp = Some(*timestamp_new);
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(HandshakeError::OldTimestamp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reset state
|
||||||
|
match *state {
|
||||||
|
State::InitiationSent { sender, .. } => device.release(sender),
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
// update replay & flood protection
|
||||||
|
*state = State::Reset;
|
||||||
|
*timestamp = Some(*timestamp_new);
|
||||||
|
*last_initiation_consumption = Some(Instant::now());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,7 +43,8 @@ pub enum HandshakeError {
|
|||||||
OldTimestamp,
|
OldTimestamp,
|
||||||
InvalidState,
|
InvalidState,
|
||||||
InvalidMac1,
|
InvalidMac1,
|
||||||
RateLimited
|
RateLimited,
|
||||||
|
InitiationFlood,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for HandshakeError {
|
impl fmt::Display for HandshakeError {
|
||||||
@@ -58,7 +59,10 @@ impl fmt::Display for HandshakeError {
|
|||||||
HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
|
HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
|
||||||
HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
|
HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
|
||||||
HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
|
HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
|
||||||
HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter")
|
HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter"),
|
||||||
|
HandshakeError::InitiationFlood => {
|
||||||
|
write!(f, "Message was dropped because of initiation flood")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user