Add rate limiter check to handshake messages.
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -586,6 +586,7 @@ dependencies = [
|
|||||||
"generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
"generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"hex 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
"hex 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"hmac 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
"hmac 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
"proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"sodiumoxide 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
"sodiumoxide 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ zerocopy = "0.2.7"
|
|||||||
byteorder = "1.3.1"
|
byteorder = "1.3.1"
|
||||||
digest = "0.8.0"
|
digest = "0.8.0"
|
||||||
sodiumoxide = "0.2.2"
|
sodiumoxide = "0.2.2"
|
||||||
|
lazy_static = "^1.3"
|
||||||
|
|
||||||
[dependencies.x25519-dalek]
|
[dependencies.x25519-dalek]
|
||||||
version = "^0.5"
|
version = "^0.5"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use spin::RwLock;
|
use spin::RwLock;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Mutex;
|
||||||
use zerocopy::AsBytes;
|
use zerocopy::AsBytes;
|
||||||
|
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
@@ -13,6 +14,7 @@ use super::messages::{CookieReply, Initiation, Response};
|
|||||||
use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
|
use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
|
||||||
use super::noise;
|
use super::noise;
|
||||||
use super::peer::Peer;
|
use super::peer::Peer;
|
||||||
|
use super::ratelimiter::RateLimiter;
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
|
|
||||||
pub struct Device<T> {
|
pub struct Device<T> {
|
||||||
@@ -21,6 +23,7 @@ pub struct Device<T> {
|
|||||||
macs: macs::Validator, // validator for the mac fields
|
macs: macs::Validator, // validator for the mac fields
|
||||||
pk_map: HashMap<[u8; 32], Peer<T>>, // public key -> peer state
|
pk_map: HashMap<[u8; 32], Peer<T>>, // public key -> peer state
|
||||||
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
|
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
|
||||||
|
limiter: Mutex<RateLimiter>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* A mutable reference to the device needs to be held during configuration.
|
/* A mutable reference to the device needs to be held during configuration.
|
||||||
@@ -43,6 +46,7 @@ where
|
|||||||
macs: macs::Validator::new(pk),
|
macs: macs::Validator::new(pk),
|
||||||
pk_map: HashMap::new(),
|
pk_map: HashMap::new(),
|
||||||
id_map: RwLock::new(HashMap::new()),
|
id_map: RwLock::new(HashMap::new()),
|
||||||
|
limiter: Mutex::new(RateLimiter::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,8 +207,9 @@ where
|
|||||||
// check mac1 field
|
// check mac1 field
|
||||||
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
||||||
|
|
||||||
// check mac2 field
|
// address validation & DoS mitigation
|
||||||
if let Some(src) = src {
|
if let Some(src) = src {
|
||||||
|
// check mac2 field
|
||||||
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
||||||
let mut reply = Default::default();
|
let mut reply = Default::default();
|
||||||
self.macs.create_cookie_reply(
|
self.macs.create_cookie_reply(
|
||||||
@@ -216,6 +221,11 @@ where
|
|||||||
);
|
);
|
||||||
return Ok((None, Some(reply.as_bytes().to_owned()), None));
|
return Ok((None, Some(reply.as_bytes().to_owned()), None));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check ratelimiter
|
||||||
|
if !self.limiter.lock().unwrap().allow(&src.ip()) {
|
||||||
|
return Err(HandshakeError::RateLimited);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// consume the initiation
|
// consume the initiation
|
||||||
@@ -253,8 +263,9 @@ where
|
|||||||
// check mac1 field
|
// check mac1 field
|
||||||
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
|
||||||
|
|
||||||
// check mac2 field
|
// address validation & DoS mitigation
|
||||||
if let Some(src) = src {
|
if let Some(src) = src {
|
||||||
|
// check mac2 field
|
||||||
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
|
||||||
let mut reply = Default::default();
|
let mut reply = Default::default();
|
||||||
self.macs.create_cookie_reply(
|
self.macs.create_cookie_reply(
|
||||||
@@ -266,6 +277,11 @@ where
|
|||||||
);
|
);
|
||||||
return Ok((None, Some(reply.as_bytes().to_owned()), None));
|
return Ok((None, Some(reply.as_bytes().to_owned()), None));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check ratelimiter
|
||||||
|
if !self.limiter.lock().unwrap().allow(&src.ip()) {
|
||||||
|
return Err(HandshakeError::RateLimited);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// consume inner playload
|
// consume inner playload
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use lazy_static::lazy_static;
|
||||||
use rand::{CryptoRng, RngCore};
|
use rand::{CryptoRng, RngCore};
|
||||||
use spin::RwLock;
|
use spin::RwLock;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
@@ -19,7 +20,9 @@ const SIZE_COOKIE: usize = 16;
|
|||||||
const SIZE_SECRET: usize = 32;
|
const SIZE_SECRET: usize = 32;
|
||||||
const SIZE_MAC: usize = 16; // blake2s-mac128
|
const SIZE_MAC: usize = 16; // blake2s-mac128
|
||||||
|
|
||||||
const SECS_COOKIE_UPDATE: u64 = 120;
|
lazy_static! {
|
||||||
|
pub static ref COOKIE_UPDATE_INTERVAL: Duration = Duration::new(120, 0);
|
||||||
|
}
|
||||||
|
|
||||||
macro_rules! HASH {
|
macro_rules! HASH {
|
||||||
( $($input:expr),* ) => {{
|
( $($input:expr),* ) => {{
|
||||||
@@ -172,7 +175,7 @@ impl Generator {
|
|||||||
macs.f_mac1 = MAC!(&self.mac1_key, inner);
|
macs.f_mac1 = MAC!(&self.mac1_key, inner);
|
||||||
macs.f_mac2 = match &self.cookie {
|
macs.f_mac2 = match &self.cookie {
|
||||||
Some(cookie) => {
|
Some(cookie) => {
|
||||||
if cookie.birth.elapsed() > Duration::from_secs(SECS_COOKIE_UPDATE) {
|
if cookie.birth.elapsed() > *COOKIE_UPDATE_INTERVAL {
|
||||||
self.cookie = None;
|
self.cookie = None;
|
||||||
[0u8; SIZE_MAC]
|
[0u8; SIZE_MAC]
|
||||||
} else {
|
} else {
|
||||||
@@ -203,14 +206,14 @@ impl Validator {
|
|||||||
cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
|
cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
|
||||||
secret: RwLock::new(Secret {
|
secret: RwLock::new(Secret {
|
||||||
value: [0u8; SIZE_SECRET],
|
value: [0u8; SIZE_SECRET],
|
||||||
birth: Instant::now() - Duration::from_secs(2 * SECS_COOKIE_UPDATE),
|
birth: Instant::now() - Duration::new(86400, 0),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> {
|
fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> {
|
||||||
let secret = self.secret.read();
|
let secret = self.secret.read();
|
||||||
if secret.birth.elapsed() < Duration::from_secs(SECS_COOKIE_UPDATE) {
|
if secret.birth.elapsed() < *COOKIE_UPDATE_INTERVAL {
|
||||||
Some(MAC!(&secret.value, src))
|
Some(MAC!(&secret.value, src))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@@ -221,7 +224,7 @@ impl Validator {
|
|||||||
// check if current value is still valid
|
// check if current value is still valid
|
||||||
{
|
{
|
||||||
let secret = self.secret.read();
|
let secret = self.secret.read();
|
||||||
if secret.birth.elapsed() < Duration::from_secs(SECS_COOKIE_UPDATE) {
|
if secret.birth.elapsed() < *COOKIE_UPDATE_INTERVAL {
|
||||||
return MAC!(&secret.value, src);
|
return MAC!(&secret.value, src);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -229,7 +232,7 @@ impl Validator {
|
|||||||
// take write lock, check again
|
// take write lock, check again
|
||||||
{
|
{
|
||||||
let mut secret = self.secret.write();
|
let mut secret = self.secret.write();
|
||||||
if secret.birth.elapsed() < Duration::from_secs(SECS_COOKIE_UPDATE) {
|
if secret.birth.elapsed() < *COOKIE_UPDATE_INTERVAL {
|
||||||
return MAC!(&secret.value, src);
|
return MAC!(&secret.value, src);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ mod macs;
|
|||||||
mod messages;
|
mod messages;
|
||||||
mod noise;
|
mod noise;
|
||||||
mod peer;
|
mod peer;
|
||||||
|
mod ratelimiter;
|
||||||
mod timestamp;
|
mod timestamp;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
|
|||||||
162
src/handshake/ratelimiter.rs
Normal file
162
src/handshake/ratelimiter.rs
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use lazy_static::lazy_static;
|
||||||
|
|
||||||
|
const PACKETS_PER_SECOND: u64 = 20;
|
||||||
|
const PACKETS_BURSTABLE: u64 = 5;
|
||||||
|
const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND;
|
||||||
|
const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE;
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
pub static ref GC_INTERVAL: Duration = Duration::new(1, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Entry {
|
||||||
|
pub last_time: Instant,
|
||||||
|
pub tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RateLimiter {
|
||||||
|
garbage_collect: Instant,
|
||||||
|
table: HashMap<IpAddr, Entry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimiter {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
RateLimiter {
|
||||||
|
garbage_collect: Instant::now(),
|
||||||
|
table: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn allow(&mut self, addr: &IpAddr) -> bool {
|
||||||
|
// check for garbage collection
|
||||||
|
if self.garbage_collect.elapsed() > *GC_INTERVAL {
|
||||||
|
self.handle_gc();
|
||||||
|
}
|
||||||
|
|
||||||
|
// update existing entry
|
||||||
|
if let Some(entry) = self.table.get_mut(addr) {
|
||||||
|
// add tokens earned since last time
|
||||||
|
entry.tokens =
|
||||||
|
MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
|
||||||
|
entry.last_time = Instant::now();
|
||||||
|
|
||||||
|
// subtract cost of packet
|
||||||
|
if entry.tokens > PACKET_COST {
|
||||||
|
entry.tokens -= PACKET_COST;
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add new entry
|
||||||
|
self.table.insert(
|
||||||
|
*addr,
|
||||||
|
Entry {
|
||||||
|
last_time: Instant::now(),
|
||||||
|
tokens: MAX_TOKENS - PACKET_COST,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_gc(&mut self) {
|
||||||
|
self.table
|
||||||
|
.retain(|_, ref mut entry| entry.last_time.elapsed() <= *GC_INTERVAL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std;
|
||||||
|
|
||||||
|
struct Result {
|
||||||
|
allowed: bool,
|
||||||
|
text: &'static str,
|
||||||
|
wait: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ratelimiter() {
|
||||||
|
let mut ratelimiter = RateLimiter::new();
|
||||||
|
let mut expected = vec![];
|
||||||
|
let ips = vec![
|
||||||
|
"127.0.0.1".parse().unwrap(),
|
||||||
|
"192.168.1.1".parse().unwrap(),
|
||||||
|
"172.167.2.3".parse().unwrap(),
|
||||||
|
"97.231.252.215".parse().unwrap(),
|
||||||
|
"248.97.91.167".parse().unwrap(),
|
||||||
|
"188.208.233.47".parse().unwrap(),
|
||||||
|
"104.2.183.179".parse().unwrap(),
|
||||||
|
"72.129.46.120".parse().unwrap(),
|
||||||
|
"2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
|
||||||
|
"f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
|
||||||
|
"b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
|
||||||
|
"a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
|
||||||
|
"ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
|
||||||
|
"3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
|
||||||
|
];
|
||||||
|
|
||||||
|
for _ in 0..PACKETS_BURSTABLE {
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: true,
|
||||||
|
wait: Duration::new(0, 0),
|
||||||
|
text: "inital burst",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: false,
|
||||||
|
wait: Duration::new(0, 0),
|
||||||
|
text: "after burst",
|
||||||
|
});
|
||||||
|
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: true,
|
||||||
|
wait: Duration::new(0, PACKET_COST as u32),
|
||||||
|
text: "filling tokens for single packet",
|
||||||
|
});
|
||||||
|
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: false,
|
||||||
|
wait: Duration::new(0, 0),
|
||||||
|
text: "not having refilled enough",
|
||||||
|
});
|
||||||
|
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: true,
|
||||||
|
wait: Duration::new(0, 2 * PACKET_COST as u32),
|
||||||
|
text: "filling tokens for 2 * packet burst",
|
||||||
|
});
|
||||||
|
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: true,
|
||||||
|
wait: Duration::new(0, 0),
|
||||||
|
text: "second packet in 2 packet burst",
|
||||||
|
});
|
||||||
|
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: false,
|
||||||
|
wait: Duration::new(0, 0),
|
||||||
|
text: "packet following 2 packet burst",
|
||||||
|
});
|
||||||
|
|
||||||
|
for item in expected {
|
||||||
|
std::thread::sleep(item.wait);
|
||||||
|
for ip in ips.iter() {
|
||||||
|
if ratelimiter.allow(&ip) != item.allowed {
|
||||||
|
panic!(
|
||||||
|
"test failed for {} on {}. expected: {}, got: {}",
|
||||||
|
ip, item.text, item.allowed, !item.allowed
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -43,6 +43,7 @@ pub enum HandshakeError {
|
|||||||
OldTimestamp,
|
OldTimestamp,
|
||||||
InvalidState,
|
InvalidState,
|
||||||
InvalidMac1,
|
InvalidMac1,
|
||||||
|
RateLimited
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for HandshakeError {
|
impl fmt::Display for HandshakeError {
|
||||||
@@ -57,6 +58,7 @@ impl fmt::Display for HandshakeError {
|
|||||||
HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
|
HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
|
||||||
HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
|
HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
|
||||||
HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
|
HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
|
||||||
|
HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user