Kill GC thread on Ratelimiter drop
This commit is contained in:
@@ -356,20 +356,18 @@ mod tests {
|
|||||||
use super::super::messages::*;
|
use super::super::messages::*;
|
||||||
use super::*;
|
use super::*;
|
||||||
use hex;
|
use hex;
|
||||||
use rand::rngs::OsRng;
|
|
||||||
use std::thread;
|
use std::thread;
|
||||||
|
use rand::rngs::OsRng;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
#[test]
|
fn setup_devices<R: RngCore + CryptoRng>(rng : &mut R) -> (PublicKey, Device<usize>, PublicKey, Device<usize>) {
|
||||||
fn handshake() {
|
|
||||||
// generate new keypairs
|
// generate new keypairs
|
||||||
|
|
||||||
let mut rng = OsRng::new().unwrap();
|
let sk1 = StaticSecret::new(rng);
|
||||||
|
|
||||||
let sk1 = StaticSecret::new(&mut rng);
|
|
||||||
let pk1 = PublicKey::from(&sk1);
|
let pk1 = PublicKey::from(&sk1);
|
||||||
|
|
||||||
let sk2 = StaticSecret::new(&mut rng);
|
let sk2 = StaticSecret::new(rng);
|
||||||
let pk2 = PublicKey::from(&sk2);
|
let pk2 = PublicKey::from(&sk2);
|
||||||
|
|
||||||
// pick random psk
|
// pick random psk
|
||||||
@@ -388,7 +386,103 @@ mod tests {
|
|||||||
dev1.set_psk(pk2, Some(psk)).unwrap();
|
dev1.set_psk(pk2, Some(psk)).unwrap();
|
||||||
dev2.set_psk(pk1, Some(psk)).unwrap();
|
dev2.set_psk(pk1, Some(psk)).unwrap();
|
||||||
|
|
||||||
// do a few handshakes
|
(pk1, dev1, pk2, dev2)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Test longest possible handshake interaction (7 messages):
|
||||||
|
*
|
||||||
|
* 1. I -> R (initation)
|
||||||
|
* 2. I <- R (cookie reply)
|
||||||
|
* 3. I -> R (initation)
|
||||||
|
* 4. I <- R (response)
|
||||||
|
* 5. I -> R (cookie reply)
|
||||||
|
* 6. I -> R (initation)
|
||||||
|
* 7. I <- R (response)
|
||||||
|
*/
|
||||||
|
#[test]
|
||||||
|
fn handshake_under_load() {
|
||||||
|
let mut rng = OsRng::new().unwrap();
|
||||||
|
let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng);
|
||||||
|
|
||||||
|
let src1 : SocketAddr = "172.16.0.1:8080".parse().unwrap();
|
||||||
|
let src2 : SocketAddr = "172.16.0.2:7070".parse().unwrap();
|
||||||
|
|
||||||
|
// 1. device-1 : create first initation
|
||||||
|
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
|
||||||
|
|
||||||
|
// 2. device-2 : responds with CookieReply
|
||||||
|
let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
|
||||||
|
(None, Some(msg), None) => msg,
|
||||||
|
_ => panic!("unexpected response")
|
||||||
|
};
|
||||||
|
|
||||||
|
// device-1 : processes CookieReply (no response)
|
||||||
|
match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
|
||||||
|
(None, None, None) => (),
|
||||||
|
_ => panic!("unexpected response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// avoid initation flood
|
||||||
|
thread::sleep(Duration::from_millis(20));
|
||||||
|
|
||||||
|
// 3. device-1 : create second initation
|
||||||
|
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
|
||||||
|
|
||||||
|
// 4. device-2 : responds with noise response
|
||||||
|
let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
|
||||||
|
(Some(_), Some(msg), Some(kp)) => {
|
||||||
|
assert_eq!(kp.confirmed, false);
|
||||||
|
msg
|
||||||
|
},
|
||||||
|
_ => panic!("unexpected response")
|
||||||
|
};
|
||||||
|
|
||||||
|
// 5. device-1 : responds with CookieReply
|
||||||
|
let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
|
||||||
|
(None, Some(msg), None) => msg,
|
||||||
|
_ => panic!("unexpected response")
|
||||||
|
};
|
||||||
|
|
||||||
|
// device-2 : processes CookieReply (no response)
|
||||||
|
match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
|
||||||
|
(None, None, None) => (),
|
||||||
|
_ => panic!("unexpected response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// avoid initation flood
|
||||||
|
thread::sleep(Duration::from_millis(20));
|
||||||
|
|
||||||
|
// 6. device-1 : create third initation
|
||||||
|
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
|
||||||
|
|
||||||
|
// 7. device-2 : responds with noise response
|
||||||
|
let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
|
||||||
|
(Some(_), Some(msg), Some(kp)) => {
|
||||||
|
assert_eq!(kp.confirmed, false);
|
||||||
|
(msg, kp)
|
||||||
|
},
|
||||||
|
_ => panic!("unexpected response")
|
||||||
|
};
|
||||||
|
|
||||||
|
// device-1 : process noise response
|
||||||
|
let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
|
||||||
|
(Some(_), None, Some(kp)) => {
|
||||||
|
assert_eq!(kp.confirmed, true);
|
||||||
|
kp
|
||||||
|
},
|
||||||
|
_ => panic!("unexpected response")
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(kp1.send, kp2.recv);
|
||||||
|
assert_eq!(kp1.recv, kp2.send);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn handshake_no_load() {
|
||||||
|
let mut rng = OsRng::new().unwrap();
|
||||||
|
let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng);
|
||||||
|
|
||||||
|
// do a few handshakes (every handshake should succeed)
|
||||||
|
|
||||||
for i in 0..10 {
|
for i in 0..10 {
|
||||||
println!("handshake : {}", i);
|
println!("handshake : {}", i);
|
||||||
@@ -430,9 +524,6 @@ mod tests {
|
|||||||
thread::sleep(Duration::from_millis(20));
|
thread::sleep(Duration::from_millis(20));
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(dev1.get_psk(pk2).unwrap(), psk);
|
|
||||||
assert_eq!(dev2.get_psk(pk1).unwrap(), psk);
|
|
||||||
|
|
||||||
dev1.remove(pk2).unwrap();
|
dev1.remove(pk2).unwrap();
|
||||||
dev2.remove(pk1).unwrap();
|
dev2.remove(pk1).unwrap();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -309,7 +309,7 @@ mod tests {
|
|||||||
let mut msg = CookieReply::default();
|
let mut msg = CookieReply::default();
|
||||||
let mut rng = OsRng::new().expect("failed to create rng");
|
let mut rng = OsRng::new().expect("failed to create rng");
|
||||||
let mut macs = MacsFooter::default();
|
let mut macs = MacsFooter::default();
|
||||||
let src = "127.0.0.1:8080".parse().unwrap();
|
let src = "192.0.2.16:8080".parse().unwrap();
|
||||||
let (validator, mut generator) = new_validator_generator();
|
let (validator, mut generator) = new_validator_generator();
|
||||||
|
|
||||||
// generate mac1 for first message
|
// generate mac1 for first message
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::{Condvar, Mutex, Arc};
|
||||||
|
use std::thread;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use spin::{RwLock, Mutex};
|
use spin;
|
||||||
|
|
||||||
use tokio::prelude::future;
|
|
||||||
use future::{loop_fn, Future, Loop, lazy};
|
|
||||||
use tokio::timer::Delay;
|
|
||||||
|
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
|
|
||||||
@@ -29,15 +26,27 @@ pub struct RateLimiter(Arc<RateLimiterInner>);
|
|||||||
|
|
||||||
struct RateLimiterInner{
|
struct RateLimiterInner{
|
||||||
gc_running: AtomicBool,
|
gc_running: AtomicBool,
|
||||||
table: RwLock<HashMap<IpAddr, Mutex<Entry>>>,
|
gc_dropped: (Mutex<bool>, Condvar),
|
||||||
|
table: spin::RwLock<HashMap<IpAddr, spin::Mutex<Entry>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for RateLimiter {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
// wake up & terminate any lingering GC thread
|
||||||
|
let &(ref lock, ref cvar) = &self.0.gc_dropped;
|
||||||
|
let mut dropped = lock.lock().unwrap();
|
||||||
|
*dropped = true;
|
||||||
|
cvar.notify_all();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RateLimiter {
|
impl RateLimiter {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
RateLimiter (
|
RateLimiter (
|
||||||
Arc::new(RateLimiterInner {
|
Arc::new(RateLimiterInner {
|
||||||
|
gc_dropped: (Mutex::new(false), Condvar::new()),
|
||||||
gc_running: AtomicBool::from(false),
|
gc_running: AtomicBool::from(false),
|
||||||
table: RwLock::new(HashMap::new()),
|
table: spin::RwLock::new(HashMap::new()),
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -45,7 +54,7 @@ impl RateLimiter {
|
|||||||
pub fn allow(&self, addr: &IpAddr) -> bool {
|
pub fn allow(&self, addr: &IpAddr) -> bool {
|
||||||
// check if allowed
|
// check if allowed
|
||||||
let allowed = {
|
let allowed = {
|
||||||
// check for existing entry (required read lock)
|
// check for existing entry (only requires read lock)
|
||||||
if let Some(entry) = self.0.table.read().get(addr) {
|
if let Some(entry) = self.0.table.read().get(addr) {
|
||||||
// update existing entry
|
// update existing entry
|
||||||
let mut entry = entry.lock();
|
let mut entry = entry.lock();
|
||||||
@@ -67,7 +76,7 @@ impl RateLimiter {
|
|||||||
// add new entry (write lock)
|
// add new entry (write lock)
|
||||||
self.0.table.write().insert(
|
self.0.table.write().insert(
|
||||||
*addr,
|
*addr,
|
||||||
Mutex::new(Entry {
|
spin::Mutex::new(Entry {
|
||||||
last_time: Instant::now(),
|
last_time: Instant::now(),
|
||||||
tokens: MAX_TOKENS - PACKET_COST,
|
tokens: MAX_TOKENS - PACKET_COST,
|
||||||
}),
|
}),
|
||||||
@@ -75,27 +84,28 @@ impl RateLimiter {
|
|||||||
true
|
true
|
||||||
};
|
};
|
||||||
|
|
||||||
// check that GC is scheduled
|
// check that GC thread is scheduled
|
||||||
if !self.0.gc_running.swap(true, Ordering::Relaxed) {
|
if !self.0.gc_running.swap(true, Ordering::Relaxed) {
|
||||||
let limiter = self.0.clone();
|
let limiter = self.0.clone();
|
||||||
tokio::spawn(
|
thread::spawn(move || {
|
||||||
loop_fn((), move |_| {
|
let &(ref lock, ref cvar) = &limiter.gc_dropped;
|
||||||
let limiter = limiter.clone();
|
let mut dropped = lock.lock().unwrap();
|
||||||
let next_gc = Instant::now() + *GC_INTERVAL;
|
while !*dropped {
|
||||||
Delay::new(next_gc)
|
// garbage collect
|
||||||
.map_err(|_| ())
|
{
|
||||||
.and_then(move |_| {
|
let mut tw = limiter.table.write();
|
||||||
let mut tw = limiter.table.write();
|
tw.retain(|_, ref mut entry| entry.lock().last_time.elapsed() <= *GC_INTERVAL);
|
||||||
tw.retain(|_, ref mut entry| entry.lock().last_time.elapsed() <= *GC_INTERVAL);
|
if tw.len() == 0 {
|
||||||
if tw.len() > 0 {
|
limiter.gc_running.store(false, Ordering::Relaxed);
|
||||||
Ok(Loop::Continue(()))
|
return;
|
||||||
} else {
|
}
|
||||||
limiter.gc_running.store(false, Ordering::Relaxed);
|
}
|
||||||
Ok(Loop::Break(()))
|
|
||||||
}
|
// wait until stopped or new GC (~1 every sec)
|
||||||
})
|
let res = cvar.wait_timeout(dropped,*GC_INTERVAL).unwrap();
|
||||||
})
|
dropped = res.0;
|
||||||
);
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
allowed
|
allowed
|
||||||
@@ -116,83 +126,79 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_ratelimiter() {
|
fn test_ratelimiter() {
|
||||||
tokio::run(lazy(|| {
|
let ratelimiter = RateLimiter::new();
|
||||||
let mut ratelimiter = RateLimiter::new();
|
let mut expected = vec![];
|
||||||
let mut expected = vec![];
|
let ips = vec![
|
||||||
let ips = vec![
|
"127.0.0.1".parse().unwrap(),
|
||||||
"127.0.0.1".parse().unwrap(),
|
"192.168.1.1".parse().unwrap(),
|
||||||
"192.168.1.1".parse().unwrap(),
|
"172.167.2.3".parse().unwrap(),
|
||||||
"172.167.2.3".parse().unwrap(),
|
"97.231.252.215".parse().unwrap(),
|
||||||
"97.231.252.215".parse().unwrap(),
|
"248.97.91.167".parse().unwrap(),
|
||||||
"248.97.91.167".parse().unwrap(),
|
"188.208.233.47".parse().unwrap(),
|
||||||
"188.208.233.47".parse().unwrap(),
|
"104.2.183.179".parse().unwrap(),
|
||||||
"104.2.183.179".parse().unwrap(),
|
"72.129.46.120".parse().unwrap(),
|
||||||
"72.129.46.120".parse().unwrap(),
|
"2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
|
||||||
"2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
|
"f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
|
||||||
"f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
|
"b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
|
||||||
"b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
|
"a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
|
||||||
"a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
|
"ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
|
||||||
"ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
|
"3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
|
||||||
"3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
|
];
|
||||||
];
|
|
||||||
|
|
||||||
for _ in 0..PACKETS_BURSTABLE {
|
|
||||||
expected.push(Result {
|
|
||||||
allowed: true,
|
|
||||||
wait: Duration::new(0, 0),
|
|
||||||
text: "inital burst",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
expected.push(Result {
|
|
||||||
allowed: false,
|
|
||||||
wait: Duration::new(0, 0),
|
|
||||||
text: "after burst",
|
|
||||||
});
|
|
||||||
|
|
||||||
expected.push(Result {
|
|
||||||
allowed: true,
|
|
||||||
wait: Duration::new(0, PACKET_COST as u32),
|
|
||||||
text: "filling tokens for single packet",
|
|
||||||
});
|
|
||||||
|
|
||||||
expected.push(Result {
|
|
||||||
allowed: false,
|
|
||||||
wait: Duration::new(0, 0),
|
|
||||||
text: "not having refilled enough",
|
|
||||||
});
|
|
||||||
|
|
||||||
expected.push(Result {
|
|
||||||
allowed: true,
|
|
||||||
wait: Duration::new(0, 2 * PACKET_COST as u32),
|
|
||||||
text: "filling tokens for 2 * packet burst",
|
|
||||||
});
|
|
||||||
|
|
||||||
|
for _ in 0..PACKETS_BURSTABLE {
|
||||||
expected.push(Result {
|
expected.push(Result {
|
||||||
allowed: true,
|
allowed: true,
|
||||||
wait: Duration::new(0, 0),
|
wait: Duration::new(0, 0),
|
||||||
text: "second packet in 2 packet burst",
|
text: "inital burst",
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
expected.push(Result {
|
expected.push(Result {
|
||||||
allowed: false,
|
allowed: false,
|
||||||
wait: Duration::new(0, 0),
|
wait: Duration::new(0, 0),
|
||||||
text: "packet following 2 packet burst",
|
text: "after burst",
|
||||||
});
|
});
|
||||||
|
|
||||||
for item in expected {
|
expected.push(Result {
|
||||||
std::thread::sleep(item.wait);
|
allowed: true,
|
||||||
for ip in ips.iter() {
|
wait: Duration::new(0, PACKET_COST as u32),
|
||||||
if ratelimiter.allow(&ip) != item.allowed {
|
text: "filling tokens for single packet",
|
||||||
panic!(
|
});
|
||||||
"test failed for {} on {}. expected: {}, got: {}",
|
|
||||||
ip, item.text, item.allowed, !item.allowed
|
expected.push(Result {
|
||||||
)
|
allowed: false,
|
||||||
}
|
wait: Duration::new(0, 0),
|
||||||
|
text: "not having refilled enough",
|
||||||
|
});
|
||||||
|
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: true,
|
||||||
|
wait: Duration::new(0, 2 * PACKET_COST as u32),
|
||||||
|
text: "filling tokens for 2 * packet burst",
|
||||||
|
});
|
||||||
|
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: true,
|
||||||
|
wait: Duration::new(0, 0),
|
||||||
|
text: "second packet in 2 packet burst",
|
||||||
|
});
|
||||||
|
|
||||||
|
expected.push(Result {
|
||||||
|
allowed: false,
|
||||||
|
wait: Duration::new(0, 0),
|
||||||
|
text: "packet following 2 packet burst",
|
||||||
|
});
|
||||||
|
|
||||||
|
for item in expected {
|
||||||
|
std::thread::sleep(item.wait);
|
||||||
|
for ip in ips.iter() {
|
||||||
|
if ratelimiter.allow(&ip) != item.allowed {
|
||||||
|
panic!(
|
||||||
|
"test failed for {} on {}. expected: {}, got: {}",
|
||||||
|
ip, item.text, item.allowed, !item.allowed
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
Ok(())
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user