WIP: TUN IO worker
Also removed the type parameters from the handshake device.
This commit is contained in:
@@ -21,11 +21,11 @@ use super::types::*;
|
|||||||
|
|
||||||
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
|
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
|
||||||
|
|
||||||
pub struct Device<T> {
|
pub struct Device {
|
||||||
pub sk: StaticSecret, // static secret key
|
pub sk: StaticSecret, // static secret key
|
||||||
pub pk: PublicKey, // static public key
|
pub pk: PublicKey, // static public key
|
||||||
macs: macs::Validator, // validator for the mac fields
|
macs: macs::Validator, // validator for the mac fields
|
||||||
pk_map: HashMap<[u8; 32], Peer<T>>, // public key -> peer state
|
pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
|
||||||
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
|
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
|
||||||
limiter: Mutex<RateLimiter>,
|
limiter: Mutex<RateLimiter>,
|
||||||
}
|
}
|
||||||
@@ -33,16 +33,13 @@ pub struct Device<T> {
|
|||||||
/* A mutable reference to the device needs to be held during configuration.
|
/* A mutable reference to the device needs to be held during configuration.
|
||||||
* Wrapping the device in a RwLock enables peer config after "configuration time"
|
* Wrapping the device in a RwLock enables peer config after "configuration time"
|
||||||
*/
|
*/
|
||||||
impl<T> Device<T>
|
impl Device {
|
||||||
where
|
|
||||||
T: Clone,
|
|
||||||
{
|
|
||||||
/// Initialize a new handshake state machine
|
/// Initialize a new handshake state machine
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `sk` - x25519 scalar representing the local private key
|
/// * `sk` - x25519 scalar representing the local private key
|
||||||
pub fn new(sk: StaticSecret) -> Device<T> {
|
pub fn new(sk: StaticSecret) -> Device {
|
||||||
let pk = PublicKey::from(&sk);
|
let pk = PublicKey::from(&sk);
|
||||||
Device {
|
Device {
|
||||||
pk,
|
pk,
|
||||||
@@ -54,6 +51,25 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Update the secret key of the device
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `sk` - x25519 scalar representing the local private key
|
||||||
|
pub fn set_sk(&mut self, sk: StaticSecret) {
|
||||||
|
// update secret and public key
|
||||||
|
let pk = PublicKey::from(&sk);
|
||||||
|
self.sk = sk;
|
||||||
|
self.pk = pk;
|
||||||
|
self.macs = macs::Validator::new(pk);
|
||||||
|
|
||||||
|
// recalculate the shared secrets for every peer
|
||||||
|
for &mut peer in self.pk_map.values_mut() {
|
||||||
|
peer.reset_state().map(|id| self.release(id));
|
||||||
|
peer.ss = self.sk.diffie_hellman(&peer.pk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Add a new public key to the state machine
|
/// Add a new public key to the state machine
|
||||||
/// To remove public keys, you must create a new machine instance
|
/// To remove public keys, you must create a new machine instance
|
||||||
///
|
///
|
||||||
@@ -61,7 +77,7 @@ where
|
|||||||
///
|
///
|
||||||
/// * `pk` - The public key to add
|
/// * `pk` - The public key to add
|
||||||
/// * `identifier` - Associated identifier which can be used to distinguish the peers
|
/// * `identifier` - Associated identifier which can be used to distinguish the peers
|
||||||
pub fn add(&mut self, pk: PublicKey, identifier: T) -> Result<(), ConfigError> {
|
pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
|
||||||
// check that the pk is not added twice
|
// check that the pk is not added twice
|
||||||
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
|
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
|
||||||
return Err(ConfigError::new("Duplicate public key"));
|
return Err(ConfigError::new("Duplicate public key"));
|
||||||
@@ -80,10 +96,8 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
// map the public key to the peer state
|
// map the public key to the peer state
|
||||||
self.pk_map.insert(
|
self.pk_map
|
||||||
*pk.as_bytes(),
|
.insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk)));
|
||||||
Peer::new(identifier, pk, self.sk.diffie_hellman(&pk)),
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -204,7 +218,7 @@ where
|
|||||||
rng: &mut R, // rng instance to sample randomness from
|
rng: &mut R, // rng instance to sample randomness from
|
||||||
msg: &[u8], // message buffer
|
msg: &[u8], // message buffer
|
||||||
src: Option<&'a S>, // optional source endpoint, set when "under load"
|
src: Option<&'a S>, // optional source endpoint, set when "under load"
|
||||||
) -> Result<Output<T>, HandshakeError>
|
) -> Result<Output, HandshakeError>
|
||||||
where
|
where
|
||||||
&'a S: Into<&'a SocketAddr>,
|
&'a S: Into<&'a SocketAddr>,
|
||||||
{
|
{
|
||||||
@@ -269,11 +283,7 @@ where
|
|||||||
.generate(resp.noise.as_bytes(), &mut resp.macs);
|
.generate(resp.noise.as_bytes(), &mut resp.macs);
|
||||||
|
|
||||||
// return unconfirmed keypair and the response as vector
|
// return unconfirmed keypair and the response as vector
|
||||||
Ok((
|
Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys)))
|
||||||
Some(peer.identifier.clone()),
|
|
||||||
Some(resp.as_bytes().to_owned()),
|
|
||||||
Some(keys),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
TYPE_RESPONSE => {
|
TYPE_RESPONSE => {
|
||||||
let msg = Response::parse(msg)?;
|
let msg = Response::parse(msg)?;
|
||||||
@@ -328,7 +338,7 @@ where
|
|||||||
// Internal function
|
// Internal function
|
||||||
//
|
//
|
||||||
// Return the peer associated with the public key
|
// Return the peer associated with the public key
|
||||||
pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer<T>, HandshakeError> {
|
pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> {
|
||||||
self.pk_map
|
self.pk_map
|
||||||
.get(pk.as_bytes())
|
.get(pk.as_bytes())
|
||||||
.ok_or(HandshakeError::UnknownPublicKey)
|
.ok_or(HandshakeError::UnknownPublicKey)
|
||||||
@@ -337,7 +347,7 @@ where
|
|||||||
// Internal function
|
// Internal function
|
||||||
//
|
//
|
||||||
// Return the peer currently associated with the receiver identifier
|
// Return the peer currently associated with the receiver identifier
|
||||||
pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer<T>, HandshakeError> {
|
pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> {
|
||||||
let im = self.id_map.read();
|
let im = self.id_map.read();
|
||||||
let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
|
let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
|
||||||
match self.pk_map.get(pk) {
|
match self.pk_map.get(pk) {
|
||||||
@@ -349,7 +359,7 @@ where
|
|||||||
// Internal function
|
// Internal function
|
||||||
//
|
//
|
||||||
// Allocated a new receiver identifier for the peer
|
// Allocated a new receiver identifier for the peer
|
||||||
fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer<T>) -> u32 {
|
fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 {
|
||||||
loop {
|
loop {
|
||||||
let id = rng.gen();
|
let id = rng.gen();
|
||||||
|
|
||||||
@@ -380,7 +390,7 @@ mod tests {
|
|||||||
|
|
||||||
fn setup_devices<R: RngCore + CryptoRng>(
|
fn setup_devices<R: RngCore + CryptoRng>(
|
||||||
rng: &mut R,
|
rng: &mut R,
|
||||||
) -> (PublicKey, Device<usize>, PublicKey, Device<usize>) {
|
) -> (PublicKey, Device, PublicKey, Device) {
|
||||||
// generate new keypairs
|
// generate new keypairs
|
||||||
|
|
||||||
let sk1 = StaticSecret::new(rng);
|
let sk1 = StaticSecret::new(rng);
|
||||||
@@ -399,8 +409,8 @@ mod tests {
|
|||||||
let mut dev1 = Device::new(sk1);
|
let mut dev1 = Device::new(sk1);
|
||||||
let mut dev2 = Device::new(sk2);
|
let mut dev2 = Device::new(sk2);
|
||||||
|
|
||||||
dev1.add(pk2, 1337).unwrap();
|
dev1.add(pk2).unwrap();
|
||||||
dev2.add(pk1, 2600).unwrap();
|
dev2.add(pk1).unwrap();
|
||||||
|
|
||||||
dev1.set_psk(pk2, Some(psk)).unwrap();
|
dev1.set_psk(pk2, Some(psk)).unwrap();
|
||||||
dev2.set_psk(pk1, Some(psk)).unwrap();
|
dev2.set_psk(pk1, Some(psk)).unwrap();
|
||||||
|
|||||||
@@ -215,10 +215,10 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_initiation<T: Clone, R: RngCore + CryptoRng>(
|
pub fn create_initiation<R: RngCore + CryptoRng>(
|
||||||
rng: &mut R,
|
rng: &mut R,
|
||||||
device: &Device<T>,
|
device: &Device,
|
||||||
peer: &Peer<T>,
|
peer: &Peer,
|
||||||
sender: u32,
|
sender: u32,
|
||||||
msg: &mut NoiseInitiation,
|
msg: &mut NoiseInitiation,
|
||||||
) -> Result<(), HandshakeError> {
|
) -> Result<(), HandshakeError> {
|
||||||
@@ -296,10 +296,10 @@ pub fn create_initiation<T: Clone, R: RngCore + CryptoRng>(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn consume_initiation<'a, T: Clone>(
|
pub fn consume_initiation<'a>(
|
||||||
device: &'a Device<T>,
|
device: &'a Device,
|
||||||
msg: &NoiseInitiation,
|
msg: &NoiseInitiation,
|
||||||
) -> Result<(&'a Peer<T>, TemporaryState), HandshakeError> {
|
) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
|
||||||
clear_stack_on_return(CLEAR_PAGES, || {
|
clear_stack_on_return(CLEAR_PAGES, || {
|
||||||
// initialize new state
|
// initialize new state
|
||||||
|
|
||||||
@@ -370,9 +370,9 @@ pub fn consume_initiation<'a, T: Clone>(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_response<T: Clone, R: RngCore + CryptoRng>(
|
pub fn create_response<R: RngCore + CryptoRng>(
|
||||||
rng: &mut R,
|
rng: &mut R,
|
||||||
peer: &Peer<T>,
|
peer: &Peer,
|
||||||
sender: u32, // sending identifier
|
sender: u32, // sending identifier
|
||||||
state: TemporaryState, // state from "consume_initiation"
|
state: TemporaryState, // state from "consume_initiation"
|
||||||
msg: &mut NoiseResponse, // resulting response
|
msg: &mut NoiseResponse, // resulting response
|
||||||
@@ -456,10 +456,7 @@ pub fn create_response<T: Clone, R: RngCore + CryptoRng>(
|
|||||||
* allow concurrent processing of potential responses to the initiation,
|
* allow concurrent processing of potential responses to the initiation,
|
||||||
* in order to better mitigate DoS from malformed response messages.
|
* in order to better mitigate DoS from malformed response messages.
|
||||||
*/
|
*/
|
||||||
pub fn consume_response<T: Clone>(
|
pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
|
||||||
device: &Device<T>,
|
|
||||||
msg: &NoiseResponse,
|
|
||||||
) -> Result<Output<T>, HandshakeError> {
|
|
||||||
clear_stack_on_return(CLEAR_PAGES, || {
|
clear_stack_on_return(CLEAR_PAGES, || {
|
||||||
// retrieve peer and copy initiation state
|
// retrieve peer and copy initiation state
|
||||||
let peer = device.lookup_id(msg.f_receiver.get())?;
|
let peer = device.lookup_id(msg.f_receiver.get())?;
|
||||||
@@ -530,7 +527,7 @@ pub fn consume_response<T: Clone>(
|
|||||||
|
|
||||||
// return confirmed key-pair
|
// return confirmed key-pair
|
||||||
Ok((
|
Ok((
|
||||||
Some(peer.identifier.clone()),
|
Some(peer.pk),
|
||||||
None,
|
None,
|
||||||
Some(KeyPair {
|
Some(KeyPair {
|
||||||
birth,
|
birth,
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use spin::Mutex;
|
use spin::Mutex;
|
||||||
|
|
||||||
|
use std::mem;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use generic_array::typenum::U32;
|
use generic_array::typenum::U32;
|
||||||
@@ -24,10 +26,7 @@ lazy_static! {
|
|||||||
*
|
*
|
||||||
* This type is only for internal use and not exposed.
|
* This type is only for internal use and not exposed.
|
||||||
*/
|
*/
|
||||||
pub struct Peer<T> {
|
pub struct Peer {
|
||||||
// external identifier
|
|
||||||
pub(crate) identifier: T,
|
|
||||||
|
|
||||||
// mutable state
|
// mutable state
|
||||||
pub(crate) state: Mutex<State>,
|
pub(crate) state: Mutex<State>,
|
||||||
pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>,
|
pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>,
|
||||||
@@ -65,18 +64,13 @@ impl Drop for State {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Peer<T>
|
impl Peer {
|
||||||
where
|
|
||||||
T: Clone,
|
|
||||||
{
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
identifier: T, // external identifier
|
|
||||||
pk: PublicKey, // public key of peer
|
pk: PublicKey, // public key of peer
|
||||||
ss: SharedSecret, // precomputed DH(static, static)
|
ss: SharedSecret, // precomputed DH(static, static)
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
macs: Mutex::new(macs::Generator::new(pk)),
|
macs: Mutex::new(macs::Generator::new(pk)),
|
||||||
identifier: identifier,
|
|
||||||
state: Mutex::new(State::Reset),
|
state: Mutex::new(State::Reset),
|
||||||
timestamp: Mutex::new(None),
|
timestamp: Mutex::new(None),
|
||||||
last_initiation_consumption: Mutex::new(None),
|
last_initiation_consumption: Mutex::new(None),
|
||||||
@@ -94,6 +88,13 @@ where
|
|||||||
*self.state.lock() = state_new;
|
*self.state.lock() = state_new;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn reset_state(&self) -> Option<u32> {
|
||||||
|
match mem::replace(&mut *self.state.lock(), State::Reset) {
|
||||||
|
State::InitiationSent { sender, .. } => Some(sender),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the mutable state of the peer conditioned on the timestamp being newer
|
/// Set the mutable state of the peer conditioned on the timestamp being newer
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@@ -102,7 +103,7 @@ where
|
|||||||
/// * ts_new - The associated timestamp
|
/// * ts_new - The associated timestamp
|
||||||
pub fn check_replay_flood(
|
pub fn check_replay_flood(
|
||||||
&self,
|
&self,
|
||||||
device: &Device<T>,
|
device: &Device,
|
||||||
timestamp_new: ×tamp::TAI64N,
|
timestamp_new: ×tamp::TAI64N,
|
||||||
) -> Result<(), HandshakeError> {
|
) -> Result<(), HandshakeError> {
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
use x25519_dalek::PublicKey;
|
||||||
|
|
||||||
use crate::types::KeyPair;
|
use crate::types::KeyPair;
|
||||||
|
|
||||||
/* Internal types for the noise IKpsk2 implementation */
|
/* Internal types for the noise IKpsk2 implementation */
|
||||||
@@ -77,8 +79,8 @@ impl Error for HandshakeError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Output<T> = (
|
pub type Output = (
|
||||||
Option<T>, // external identifier associated with peer
|
Option<PublicKey>, // external identifier associated with peer
|
||||||
Option<Vec<u8>>, // message to send
|
Option<Vec<u8>>, // message to send
|
||||||
Option<KeyPair>, // resulting key-pair of successful handshake
|
Option<KeyPair>, // resulting key-pair of successful handshake
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -121,7 +121,6 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
||||||
|
|
||||||
pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> {
|
pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> {
|
||||||
// allocate shared device state
|
// allocate shared device state
|
||||||
let mut inner = DeviceInner {
|
let mut inner = DeviceInner {
|
||||||
@@ -149,6 +148,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A new secret key has been set for the device.
|
||||||
|
/// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
|
||||||
|
pub fn new_sk(&self) {}
|
||||||
|
|
||||||
/// Adds a new peer to the device
|
/// Adds a new peer to the device
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
|
|||||||
236
src/wireguard.rs
236
src/wireguard.rs
@@ -2,17 +2,20 @@ use crate::handshake;
|
|||||||
use crate::router;
|
use crate::router;
|
||||||
use crate::types::{Bind, Endpoint, Tun};
|
use crate::types::{Bind, Endpoint, Tun};
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use log::debug;
|
use log::debug;
|
||||||
use rand::rngs::OsRng;
|
use rand::rngs::OsRng;
|
||||||
|
use spin::{Mutex, RwLock};
|
||||||
|
|
||||||
use byteorder::{ByteOrder, LittleEndian};
|
use byteorder::{ByteOrder, LittleEndian};
|
||||||
use crossbeam_channel::bounded;
|
use crossbeam_channel::{bounded, Sender};
|
||||||
use x25519_dalek::StaticSecret;
|
use x25519_dalek::{PublicKey, StaticSecret};
|
||||||
|
|
||||||
const SIZE_HANDSHAKE_QUEUE: usize = 128;
|
const SIZE_HANDSHAKE_QUEUE: usize = 128;
|
||||||
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
|
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
|
||||||
@@ -22,8 +25,10 @@ const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
|
|||||||
pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
|
pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
|
||||||
|
|
||||||
pub struct PeerInner<T: Tun, B: Bind> {
|
pub struct PeerInner<T: Tun, B: Bind> {
|
||||||
peer: router::Peer<Events, T, B>,
|
router: router::Peer<Events, T, B>,
|
||||||
timers: Timers,
|
timers: Timers,
|
||||||
|
rx: AtomicU64,
|
||||||
|
tx: AtomicU64,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Timers {}
|
pub struct Timers {}
|
||||||
@@ -40,39 +45,137 @@ impl router::Callbacks for Events {
|
|||||||
fn need_key(t: &Timers) {}
|
fn need_key(t: &Timers) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Handshake {
|
||||||
|
device: handshake::Device,
|
||||||
|
active: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WireguardInner<T: Tun, B: Bind> {
|
||||||
|
// identify and configuration map
|
||||||
|
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
|
||||||
|
|
||||||
|
// cryptkey routing
|
||||||
|
router: router::Device<Events, T, B>,
|
||||||
|
|
||||||
|
// handshake related state
|
||||||
|
handshake: RwLock<Handshake>,
|
||||||
|
under_load: AtomicBool,
|
||||||
|
pending: AtomicUsize, // num of pending handshake packets in queue
|
||||||
|
queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>,
|
||||||
|
|
||||||
|
// IO
|
||||||
|
bind: B,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Wireguard<T: Tun, B: Bind> {
|
pub struct Wireguard<T: Tun, B: Bind> {
|
||||||
router: Arc<router::Device<Events, T, B>>,
|
state: Arc<WireguardInner<T, B>>,
|
||||||
handshake: Option<Arc<handshake::Device<()>>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Tun, B: Bind> Wireguard<T, B> {
|
impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||||
fn start(&self) {}
|
fn set_key(&self, sk: Option<StaticSecret>) {
|
||||||
|
let mut handshake = self.state.handshake.write();
|
||||||
|
match sk {
|
||||||
|
None => {
|
||||||
|
let mut rng = OsRng::new().unwrap();
|
||||||
|
handshake.device.set_sk(StaticSecret::new(&mut rng));
|
||||||
|
handshake.active = false;
|
||||||
|
}
|
||||||
|
Some(sk) => {
|
||||||
|
handshake.device.set_sk(sk);
|
||||||
|
handshake.active = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn new(tun: T, bind: B, sk: StaticSecret) -> Wireguard<T, B> {
|
fn new(tun: T, bind: B) -> Wireguard<T, B> {
|
||||||
let router = Arc::new(router::Device::new(
|
// create device state
|
||||||
num_cpus::get(),
|
let mut rng = OsRng::new().unwrap();
|
||||||
tun.clone(),
|
let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||||
bind.clone(),
|
let wg = Arc::new(WireguardInner {
|
||||||
));
|
peers: RwLock::new(HashMap::new()),
|
||||||
|
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
|
||||||
|
pending: AtomicUsize::new(0),
|
||||||
|
handshake: RwLock::new(Handshake {
|
||||||
|
device: handshake::Device::new(StaticSecret::new(&mut rng)),
|
||||||
|
active: false,
|
||||||
|
}),
|
||||||
|
under_load: AtomicBool::new(false),
|
||||||
|
bind: bind.clone(),
|
||||||
|
queue: Mutex::new(tx),
|
||||||
|
});
|
||||||
|
|
||||||
let handshake_staged = Arc::new(AtomicUsize::new(0));
|
// start handshake workers
|
||||||
let handshake_device: Arc<handshake::Device<Peer<T, B>>> =
|
for _ in 0..num_cpus::get() {
|
||||||
Arc::new(handshake::Device::new(sk));
|
let wg = wg.clone();
|
||||||
|
let rx = rx.clone();
|
||||||
|
let bind = bind.clone();
|
||||||
|
thread::spawn(move || {
|
||||||
|
// prepare OsRng instance for this thread
|
||||||
|
let mut rng = OsRng::new().unwrap();
|
||||||
|
|
||||||
|
// process elements from the handshake queue
|
||||||
|
for (msg, src) in rx {
|
||||||
|
wg.pending.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
|
||||||
|
// feed message to handshake device
|
||||||
|
let src_validate = (&src).into_address(); // TODO avoid
|
||||||
|
let state = wg.handshake.read();
|
||||||
|
if !state.active {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// process message
|
||||||
|
match state.device.process(
|
||||||
|
&mut rng,
|
||||||
|
&msg[..],
|
||||||
|
if wg.under_load.load(Ordering::Relaxed) {
|
||||||
|
Some(&src_validate)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
Ok((pk, msg, keypair)) => {
|
||||||
|
// send response
|
||||||
|
if let Some(msg) = msg {
|
||||||
|
let _ = bind.send(&msg[..], &src).map_err(|e| {
|
||||||
|
debug!(
|
||||||
|
"handshake worker, failed to send response, error = {:?}",
|
||||||
|
e
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// update timers
|
||||||
|
if let Some(pk) = pk {
|
||||||
|
// add keypair to peer and free any unused ids
|
||||||
|
if let Some(keypair) = keypair {
|
||||||
|
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||||
|
for id in peer.0.router.add_keypair(keypair) {
|
||||||
|
state.device.release(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => debug!("handshake worker, error = {:?}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// start UDP read IO thread
|
// start UDP read IO thread
|
||||||
let (handshake_tx, handshake_rx) = bounded(128);
|
|
||||||
{
|
{
|
||||||
|
let wg = wg.clone();
|
||||||
let tun = tun.clone();
|
let tun = tun.clone();
|
||||||
let bind = bind.clone();
|
let bind = bind.clone();
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
let mut under_load =
|
let mut last_under_load =
|
||||||
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
|
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// read UDP packet into vector
|
// read UDP packet into vector
|
||||||
let size = tun.mtu() + 148; // maximum message size
|
let size = tun.mtu() + 148; // maximum message size
|
||||||
let mut msg: Vec<u8> =
|
let mut msg: Vec<u8> = Vec::with_capacity(size);
|
||||||
Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
|
|
||||||
msg.resize(size, 0);
|
msg.resize(size, 0);
|
||||||
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
|
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
|
||||||
msg.truncate(size);
|
msg.truncate(size);
|
||||||
@@ -86,20 +189,22 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
handshake::TYPE_COOKIE_REPLY
|
handshake::TYPE_COOKIE_REPLY
|
||||||
| handshake::TYPE_INITIATION
|
| handshake::TYPE_INITIATION
|
||||||
| handshake::TYPE_RESPONSE => {
|
| handshake::TYPE_RESPONSE => {
|
||||||
// detect if under load
|
// update under_load flag
|
||||||
if handshake_staged.fetch_add(1, Ordering::SeqCst)
|
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
|
||||||
> THRESHOLD_UNDER_LOAD
|
last_under_load = Instant::now();
|
||||||
{
|
wg.under_load.store(true, Ordering::SeqCst);
|
||||||
under_load = Instant::now()
|
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
|
||||||
|
wg.under_load.store(false, Ordering::SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
// pass source address along if under load
|
wg.queue.lock().send((msg, src)).unwrap();
|
||||||
handshake_tx
|
|
||||||
.send((msg, src, under_load.elapsed() < DURATION_UNDER_LOAD))
|
|
||||||
.unwrap();
|
|
||||||
}
|
}
|
||||||
router::TYPE_TRANSPORT => {
|
router::TYPE_TRANSPORT => {
|
||||||
// transport message
|
// transport message
|
||||||
|
|
||||||
|
// pad the message
|
||||||
|
|
||||||
|
let _ = wg.router.recv(src, msg);
|
||||||
}
|
}
|
||||||
_ => (),
|
_ => (),
|
||||||
}
|
}
|
||||||
@@ -107,62 +212,27 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// start handshake workers
|
|
||||||
for _ in 0..num_cpus::get() {
|
|
||||||
let bind = bind.clone();
|
|
||||||
let handshake_rx = handshake_rx.clone();
|
|
||||||
let handshake_device = handshake_device.clone();
|
|
||||||
thread::spawn(move || {
|
|
||||||
// prepare OsRng instance for this thread
|
|
||||||
let mut rng = OsRng::new().unwrap();
|
|
||||||
|
|
||||||
// process elements from the handshake queue
|
|
||||||
for (msg, src, under_load) in handshake_rx {
|
|
||||||
// feed message to handshake device
|
|
||||||
let src_validate = (&src).into_address(); // TODO avoid
|
|
||||||
match handshake_device.process(
|
|
||||||
&mut rng,
|
|
||||||
&msg[..],
|
|
||||||
if under_load {
|
|
||||||
Some(&src_validate)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
},
|
|
||||||
) {
|
|
||||||
Ok((identity, msg, keypair)) => {
|
|
||||||
// send response
|
|
||||||
if let Some(msg) = msg {
|
|
||||||
let _ = bind.send(&msg[..], &src).map_err(|e| {
|
|
||||||
debug!(
|
|
||||||
"handshake worker, failed to send response, error = {:?}",
|
|
||||||
e
|
|
||||||
)
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// update timers
|
|
||||||
if let Some(identity) = identity {
|
|
||||||
// add keypair to peer and free any unused ids
|
|
||||||
if let Some(keypair) = keypair {
|
|
||||||
for id in identity.0.peer.add_keypair(keypair) {
|
|
||||||
handshake_device.release(id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => debug!("handshake worker, error = {:?}", e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// start TUN read IO thread
|
// start TUN read IO thread
|
||||||
|
{
|
||||||
|
let wg = wg.clone();
|
||||||
|
thread::spawn(move || loop {
|
||||||
|
// read a new IP packet
|
||||||
|
let mtu = tun.mtu();
|
||||||
|
let size = mtu + 148;
|
||||||
|
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
|
||||||
|
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
|
||||||
|
msg.truncate(size);
|
||||||
|
|
||||||
thread::spawn(move || {});
|
// pad message to multiple of 16
|
||||||
|
while msg.len() < mtu && msg.len() % 16 != 0 {
|
||||||
Wireguard {
|
msg.push(0);
|
||||||
router,
|
|
||||||
handshake: None,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// crypt-key route
|
||||||
|
let _ = wg.router.send(msg);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Wireguard { state: wg }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user