Concurrent rate limiter
The new rate limiter allows multiple simultaneous .allow calls. Also delegated GC to tokio.
This commit is contained in:
@@ -1,6 +1,13 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use spin::{RwLock, Mutex};
|
||||
|
||||
use tokio::prelude::future;
|
||||
use future::{loop_fn, Future, Loop, lazy};
|
||||
use tokio::timer::Delay;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
@@ -18,57 +25,82 @@ struct Entry {
|
||||
pub tokens: u64,
|
||||
}
|
||||
|
||||
pub struct RateLimiter {
|
||||
garbage_collect: Instant,
|
||||
table: HashMap<IpAddr, Entry>,
|
||||
pub struct RateLimiter(Arc<RateLimiterInner>);
|
||||
|
||||
struct RateLimiterInner{
|
||||
gc_running: AtomicBool,
|
||||
table: RwLock<HashMap<IpAddr, Mutex<Entry>>>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new() -> Self {
|
||||
RateLimiter {
|
||||
garbage_collect: Instant::now(),
|
||||
table: HashMap::new(),
|
||||
}
|
||||
RateLimiter (
|
||||
Arc::new(RateLimiterInner {
|
||||
gc_running: AtomicBool::from(false),
|
||||
table: RwLock::new(HashMap::new()),
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
pub fn allow(&mut self, addr: &IpAddr) -> bool {
|
||||
// check for garbage collection
|
||||
if self.garbage_collect.elapsed() > *GC_INTERVAL {
|
||||
self.handle_gc();
|
||||
}
|
||||
pub fn allow(&self, addr: &IpAddr) -> bool {
|
||||
// check if allowed
|
||||
let allowed = {
|
||||
// check for existing entry (required read lock)
|
||||
if let Some(entry) = self.0.table.read().get(addr) {
|
||||
// update existing entry
|
||||
let mut entry = entry.lock();
|
||||
|
||||
// update existing entry
|
||||
if let Some(entry) = self.table.get_mut(addr) {
|
||||
// add tokens earned since last time
|
||||
entry.tokens =
|
||||
MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
|
||||
entry.last_time = Instant::now();
|
||||
// add tokens earned since last time
|
||||
entry.tokens =
|
||||
MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
|
||||
entry.last_time = Instant::now();
|
||||
|
||||
// subtract cost of packet
|
||||
if entry.tokens > PACKET_COST {
|
||||
entry.tokens -= PACKET_COST;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
// subtract cost of packet
|
||||
if entry.tokens > PACKET_COST {
|
||||
entry.tokens -= PACKET_COST;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// add new entry (write lock)
|
||||
self.0.table.write().insert(
|
||||
*addr,
|
||||
Mutex::new(Entry {
|
||||
last_time: Instant::now(),
|
||||
tokens: MAX_TOKENS - PACKET_COST,
|
||||
}),
|
||||
);
|
||||
true
|
||||
};
|
||||
|
||||
// check that GC is scheduled
|
||||
if !self.0.gc_running.swap(true, Ordering::Relaxed) {
|
||||
let limiter = self.0.clone();
|
||||
tokio::spawn(
|
||||
loop_fn((), move |_| {
|
||||
let limiter = limiter.clone();
|
||||
let next_gc = Instant::now() + *GC_INTERVAL;
|
||||
Delay::new(next_gc)
|
||||
.map_err(|_| ())
|
||||
.and_then(move |_| {
|
||||
let mut tw = limiter.table.write();
|
||||
tw.retain(|_, ref mut entry| entry.lock().last_time.elapsed() <= *GC_INTERVAL);
|
||||
if tw.len() > 0 {
|
||||
Ok(Loop::Continue(()))
|
||||
} else {
|
||||
limiter.gc_running.store(false, Ordering::Relaxed);
|
||||
Ok(Loop::Break(()))
|
||||
}
|
||||
})
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// add new entry
|
||||
self.table.insert(
|
||||
*addr,
|
||||
Entry {
|
||||
last_time: Instant::now(),
|
||||
tokens: MAX_TOKENS - PACKET_COST,
|
||||
},
|
||||
);
|
||||
|
||||
true
|
||||
allowed
|
||||
}
|
||||
|
||||
fn handle_gc(&mut self) {
|
||||
self.table
|
||||
.retain(|_, ref mut entry| entry.last_time.elapsed() <= *GC_INTERVAL);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -84,79 +116,83 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_ratelimiter() {
|
||||
let mut ratelimiter = RateLimiter::new();
|
||||
let mut expected = vec![];
|
||||
let ips = vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"192.168.1.1".parse().unwrap(),
|
||||
"172.167.2.3".parse().unwrap(),
|
||||
"97.231.252.215".parse().unwrap(),
|
||||
"248.97.91.167".parse().unwrap(),
|
||||
"188.208.233.47".parse().unwrap(),
|
||||
"104.2.183.179".parse().unwrap(),
|
||||
"72.129.46.120".parse().unwrap(),
|
||||
"2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
|
||||
"f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
|
||||
"b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
|
||||
"a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
|
||||
"ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
|
||||
"3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
|
||||
];
|
||||
tokio::run(lazy(|| {
|
||||
let mut ratelimiter = RateLimiter::new();
|
||||
let mut expected = vec![];
|
||||
let ips = vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"192.168.1.1".parse().unwrap(),
|
||||
"172.167.2.3".parse().unwrap(),
|
||||
"97.231.252.215".parse().unwrap(),
|
||||
"248.97.91.167".parse().unwrap(),
|
||||
"188.208.233.47".parse().unwrap(),
|
||||
"104.2.183.179".parse().unwrap(),
|
||||
"72.129.46.120".parse().unwrap(),
|
||||
"2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
|
||||
"f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
|
||||
"b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
|
||||
"a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
|
||||
"ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
|
||||
"3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
|
||||
];
|
||||
|
||||
for _ in 0..PACKETS_BURSTABLE {
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "inital burst",
|
||||
});
|
||||
}
|
||||
|
||||
expected.push(Result {
|
||||
allowed: false,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "after burst",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, PACKET_COST as u32),
|
||||
text: "filling tokens for single packet",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: false,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "not having refilled enough",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, 2 * PACKET_COST as u32),
|
||||
text: "filling tokens for 2 * packet burst",
|
||||
});
|
||||
|
||||
for _ in 0..PACKETS_BURSTABLE {
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "inital burst",
|
||||
text: "second packet in 2 packet burst",
|
||||
});
|
||||
}
|
||||
|
||||
expected.push(Result {
|
||||
allowed: false,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "after burst",
|
||||
});
|
||||
expected.push(Result {
|
||||
allowed: false,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "packet following 2 packet burst",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, PACKET_COST as u32),
|
||||
text: "filling tokens for single packet",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: false,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "not having refilled enough",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, 2 * PACKET_COST as u32),
|
||||
text: "filling tokens for 2 * packet burst",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: true,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "second packet in 2 packet burst",
|
||||
});
|
||||
|
||||
expected.push(Result {
|
||||
allowed: false,
|
||||
wait: Duration::new(0, 0),
|
||||
text: "packet following 2 packet burst",
|
||||
});
|
||||
|
||||
for item in expected {
|
||||
std::thread::sleep(item.wait);
|
||||
for ip in ips.iter() {
|
||||
if ratelimiter.allow(&ip) != item.allowed {
|
||||
panic!(
|
||||
"test failed for {} on {}. expected: {}, got: {}",
|
||||
ip, item.text, item.allowed, !item.allowed
|
||||
)
|
||||
for item in expected {
|
||||
std::thread::sleep(item.wait);
|
||||
for ip in ips.iter() {
|
||||
if ratelimiter.allow(&ip) != item.allowed {
|
||||
panic!(
|
||||
"test failed for {} on {}. expected: {}, got: {}",
|
||||
ip, item.text, item.allowed, !item.allowed
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user