Enable adding TUN reader to WG interface
This commit is contained in:
21
src/main.rs
21
src/main.rs
@@ -15,25 +15,6 @@ mod types;
|
|||||||
mod wireguard;
|
mod wireguard;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests;
|
||||||
use crate::types::tun::Tun;
|
|
||||||
use crate::types::{bind, dummy, tun};
|
|
||||||
use crate::wireguard::Wireguard;
|
|
||||||
|
|
||||||
use std::thread;
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
fn init() {
|
|
||||||
let _ = env_logger::builder().is_test(true).try_init();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_pure_wireguard() {
|
|
||||||
init();
|
|
||||||
let (reader, writer, mtu) = dummy::TunTest::create("name").unwrap();
|
|
||||||
let wg: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(reader, writer, mtu);
|
|
||||||
thread::sleep(Duration::from_millis(500));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {}
|
fn main() {}
|
||||||
|
|||||||
@@ -145,8 +145,8 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create device
|
// create device
|
||||||
let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
|
let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
|
||||||
let router: Device<_, BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
|
let router: Device< _, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> =
|
||||||
Device::new(num_cpus::get(), tun_writer);
|
Device::new(num_cpus::get(), tun_writer);
|
||||||
|
|
||||||
// add new peer
|
// add new peer
|
||||||
@@ -175,7 +175,7 @@ mod tests {
|
|||||||
init();
|
init();
|
||||||
|
|
||||||
// create device
|
// create device
|
||||||
let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
|
let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
|
||||||
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
|
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
|
||||||
router.set_outbound_writer(dummy::VoidBind::new());
|
router.set_outbound_writer(dummy::VoidBind::new());
|
||||||
|
|
||||||
@@ -321,8 +321,8 @@ mod tests {
|
|||||||
dummy::PairBind::pair();
|
dummy::PairBind::pair();
|
||||||
|
|
||||||
// create matching device
|
// create matching device
|
||||||
let (tun_writer1, _, _) = dummy::TunTest::create("tun1").unwrap();
|
let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false);
|
||||||
let (tun_writer2, _, _) = dummy::TunTest::create("tun1").unwrap();
|
let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false);
|
||||||
|
|
||||||
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
|
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
|
||||||
router1.set_outbound_writer(bind_writer1);
|
router1.set_outbound_writer(bind_writer1);
|
||||||
|
|||||||
46
src/tests.rs
Normal file
46
src/tests.rs
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
use crate::types::tun::Tun;
|
||||||
|
use crate::types::{bind, dummy, tun};
|
||||||
|
use crate::wireguard::Wireguard;
|
||||||
|
|
||||||
|
use std::thread;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
fn init() {
|
||||||
|
let _ = env_logger::builder().is_test(true).try_init();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Create and configure two matching pure instances of WireGuard
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
#[test]
|
||||||
|
fn test_pure_wireguard() {
|
||||||
|
init();
|
||||||
|
|
||||||
|
// create WG instances for fake TUN devices
|
||||||
|
|
||||||
|
let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true);
|
||||||
|
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
|
||||||
|
Wireguard::new(vec![tun_reader1], tun_writer1, mtu1);
|
||||||
|
|
||||||
|
let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true);
|
||||||
|
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
|
||||||
|
Wireguard::new(vec![tun_reader2], tun_writer2, mtu2);
|
||||||
|
|
||||||
|
// create pair bind to connect the interfaces "over the internet"
|
||||||
|
|
||||||
|
let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair();
|
||||||
|
|
||||||
|
wg1.set_writer(bind_writer1);
|
||||||
|
wg2.set_writer(bind_writer2);
|
||||||
|
|
||||||
|
wg1.add_reader(bind_reader1);
|
||||||
|
wg2.add_reader(bind_reader2);
|
||||||
|
|
||||||
|
// generate (public, pivate) key pairs
|
||||||
|
|
||||||
|
// configure cryptkey router
|
||||||
|
|
||||||
|
// create IP packets
|
||||||
|
|
||||||
|
thread::sleep(Duration::from_millis(500));
|
||||||
|
}
|
||||||
@@ -20,9 +20,4 @@ pub trait Bind: Send + Sync + 'static {
|
|||||||
/* Until Rust gets type equality constraints these have to be generic */
|
/* Until Rust gets type equality constraints these have to be generic */
|
||||||
type Writer: Writer<Self::Endpoint>;
|
type Writer: Writer<Self::Endpoint>;
|
||||||
type Reader: Reader<Self::Endpoint>;
|
type Reader: Reader<Self::Endpoint>;
|
||||||
|
|
||||||
/* Used to close the reader/writer when binding to a new port */
|
|
||||||
type Closer;
|
|
||||||
|
|
||||||
fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error>;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
use std::marker;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
|
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use std::marker;
|
use std::sync::atomic::{Ordering, AtomicUsize};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
@@ -41,7 +42,9 @@ impl fmt::Display for BindError {
|
|||||||
/* TUN implementation */
|
/* TUN implementation */
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum TunError {}
|
pub enum TunError {
|
||||||
|
Disconnected
|
||||||
|
}
|
||||||
|
|
||||||
impl Error for TunError {
|
impl Error for TunError {
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
@@ -68,54 +71,111 @@ impl Endpoint for UnitEndpoint {
|
|||||||
fn from_address(_: SocketAddr) -> UnitEndpoint {
|
fn from_address(_: SocketAddr) -> UnitEndpoint {
|
||||||
UnitEndpoint {}
|
UnitEndpoint {}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn into_address(&self) -> SocketAddr {
|
fn into_address(&self) -> SocketAddr {
|
||||||
"127.0.0.1:8080".parse().unwrap()
|
"127.0.0.1:8080".parse().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_src(&self) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UnitEndpoint {
|
impl UnitEndpoint {
|
||||||
pub fn new() -> UnitEndpoint {
|
pub fn new() -> UnitEndpoint {
|
||||||
UnitEndpoint{}
|
UnitEndpoint {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* */
|
/* */
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
|
||||||
pub struct TunTest {}
|
pub struct TunTest {}
|
||||||
|
|
||||||
impl tun::Reader for TunTest {
|
pub struct TunFakeIO {
|
||||||
|
store: bool,
|
||||||
|
tx: SyncSender<Vec<u8>>,
|
||||||
|
rx: Receiver<Vec<u8>>
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TunReader {
|
||||||
|
rx: Receiver<Vec<u8>>
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TunWriter {
|
||||||
|
store: bool,
|
||||||
|
tx: Mutex<SyncSender<Vec<u8>>>
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct TunMTU {
|
||||||
|
mtu: Arc<AtomicUsize>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tun::Reader for TunReader {
|
||||||
type Error = TunError;
|
type Error = TunError;
|
||||||
|
|
||||||
fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
|
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
|
||||||
Ok(0)
|
match self.rx.recv() {
|
||||||
|
Ok(m) => {
|
||||||
|
buf[offset..].copy_from_slice(&m[..]);
|
||||||
|
Ok(m.len())
|
||||||
|
}
|
||||||
|
Err(_) => Err(TunError::Disconnected)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl tun::MTU for TunTest {
|
impl tun::Writer for TunWriter {
|
||||||
|
type Error = TunError;
|
||||||
|
|
||||||
|
fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
|
||||||
|
if self.store {
|
||||||
|
let m = src.to_owned();
|
||||||
|
match self.tx.lock().unwrap().send(m) {
|
||||||
|
Ok(_) => Ok(()),
|
||||||
|
Err(_) => Err(TunError::Disconnected)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tun::MTU for TunMTU {
|
||||||
fn mtu(&self) -> usize {
|
fn mtu(&self) -> usize {
|
||||||
1500
|
self.mtu.load(Ordering::Acquire)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl tun::Writer for TunTest {
|
|
||||||
type Error = TunError;
|
|
||||||
|
|
||||||
fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl tun::Tun for TunTest {
|
impl tun::Tun for TunTest {
|
||||||
type Writer = TunTest;
|
type Writer = TunWriter;
|
||||||
type Reader = TunTest;
|
type Reader = TunReader;
|
||||||
type MTU = TunTest;
|
type MTU = TunMTU;
|
||||||
type Error = TunError;
|
type Error = TunError;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl TunFakeIO {
|
||||||
|
pub fn write(&self, msg : Vec<u8>) {
|
||||||
|
if self.store {
|
||||||
|
self.tx.send(msg).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read(&self) -> Vec<u8> {
|
||||||
|
self.rx.recv().unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl TunTest {
|
impl TunTest {
|
||||||
pub fn create(_name: &str) -> Result<(TunTest, TunTest, TunTest), TunError> {
|
pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) {
|
||||||
Ok((TunTest {},TunTest {}, TunTest{}))
|
|
||||||
|
let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) };
|
||||||
|
let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) };
|
||||||
|
|
||||||
|
let fake = TunFakeIO{tx: tx1, rx: rx2, store};
|
||||||
|
let reader = TunReader{rx : rx1};
|
||||||
|
let writer = TunWriter{tx : Mutex::new(tx2), store};
|
||||||
|
let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))};
|
||||||
|
|
||||||
|
(fake, reader, writer, mtu)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,16 +206,11 @@ impl bind::Bind for VoidBind {
|
|||||||
|
|
||||||
type Reader = VoidBind;
|
type Reader = VoidBind;
|
||||||
type Writer = VoidBind;
|
type Writer = VoidBind;
|
||||||
type Closer = ();
|
|
||||||
|
|
||||||
fn bind(_ : u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
|
|
||||||
Ok((VoidBind{}, VoidBind{}, (), 2600))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VoidBind {
|
impl VoidBind {
|
||||||
pub fn new() -> VoidBind {
|
pub fn new() -> VoidBind {
|
||||||
VoidBind{}
|
VoidBind {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,45 +258,42 @@ pub struct PairWriter<E> {
|
|||||||
pub struct PairBind {}
|
pub struct PairBind {}
|
||||||
|
|
||||||
impl PairBind {
|
impl PairBind {
|
||||||
pub fn pair<E>() -> ((PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>)) {
|
pub fn pair<E>() -> (
|
||||||
|
(PairReader<E>, PairWriter<E>),
|
||||||
|
(PairReader<E>, PairWriter<E>),
|
||||||
|
) {
|
||||||
let (tx1, rx1) = sync_channel(128);
|
let (tx1, rx1) = sync_channel(128);
|
||||||
let (tx2, rx2) = sync_channel(128);
|
let (tx2, rx2) = sync_channel(128);
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
PairReader{
|
PairReader {
|
||||||
|
recv: Arc::new(Mutex::new(rx1)),
|
||||||
recv: Arc::new(Mutex::new(rx1)),
|
_marker: marker::PhantomData,
|
||||||
_marker: marker::PhantomData
|
},
|
||||||
},
|
PairWriter {
|
||||||
PairWriter{
|
|
||||||
send: Arc::new(Mutex::new(tx2)),
|
send: Arc::new(Mutex::new(tx2)),
|
||||||
_marker: marker::PhantomData
|
_marker: marker::PhantomData,
|
||||||
}
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
PairReader{
|
PairReader {
|
||||||
recv: Arc::new(Mutex::new(rx2)),
|
recv: Arc::new(Mutex::new(rx2)),
|
||||||
_marker: marker::PhantomData
|
_marker: marker::PhantomData,
|
||||||
},
|
},
|
||||||
PairWriter{
|
PairWriter {
|
||||||
send: Arc::new(Mutex::new(tx1)),
|
send: Arc::new(Mutex::new(tx1)),
|
||||||
_marker: marker::PhantomData
|
_marker: marker::PhantomData,
|
||||||
}
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl bind::Bind for PairBind {
|
impl bind::Bind for PairBind {
|
||||||
type Closer = ();
|
|
||||||
type Error = BindError;
|
type Error = BindError;
|
||||||
type Endpoint = UnitEndpoint;
|
type Endpoint = UnitEndpoint;
|
||||||
type Reader = PairReader<Self::Endpoint>;
|
type Reader = PairReader<Self::Endpoint>;
|
||||||
type Writer = PairWriter<Self::Endpoint>;
|
type Writer = PairWriter<Self::Endpoint>;
|
||||||
|
|
||||||
fn bind(_port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
|
|
||||||
Err(BindError::Disconnected)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keypair(initiator: bool) -> KeyPair {
|
pub fn keypair(initiator: bool) -> KeyPair {
|
||||||
|
|||||||
@@ -3,4 +3,5 @@ use std::net::SocketAddr;
|
|||||||
pub trait Endpoint: Send + 'static {
|
pub trait Endpoint: Send + 'static {
|
||||||
fn from_address(addr: SocketAddr) -> Self;
|
fn from_address(addr: SocketAddr) -> Self;
|
||||||
fn into_address(&self) -> SocketAddr;
|
fn into_address(&self) -> SocketAddr;
|
||||||
|
fn clear_src(&self);
|
||||||
}
|
}
|
||||||
|
|||||||
163
src/wireguard.rs
163
src/wireguard.rs
@@ -3,8 +3,10 @@ use crate::handshake;
|
|||||||
use crate::router;
|
use crate::router;
|
||||||
use crate::timers::{Events, Timers};
|
use crate::timers::{Events, Timers};
|
||||||
|
|
||||||
|
use crate::types::bind::Reader as BindReader;
|
||||||
use crate::types::bind::{Bind, Writer};
|
use crate::types::bind::{Bind, Writer};
|
||||||
use crate::types::tun::{Reader, Tun, MTU};
|
use crate::types::tun::{Reader, Tun, MTU};
|
||||||
|
|
||||||
use crate::types::Endpoint;
|
use crate::types::Endpoint;
|
||||||
|
|
||||||
use hjul::Runner;
|
use hjul::Runner;
|
||||||
@@ -53,7 +55,7 @@ pub struct PeerInner<B: Bind> {
|
|||||||
pub timers: RwLock<Timers>, //
|
pub timers: RwLock<Timers>, //
|
||||||
}
|
}
|
||||||
|
|
||||||
impl <B:Bind > PeerInner<B> {
|
impl<B: Bind> PeerInner<B> {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub fn timers(&self) -> RwLockReadGuard<Timers> {
|
pub fn timers(&self) -> RwLockReadGuard<Timers> {
|
||||||
self.timers.read()
|
self.timers.read()
|
||||||
@@ -153,7 +155,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_sk(&self) -> Option<StaticSecret> {
|
pub fn get_sk(&self) -> Option<StaticSecret> {
|
||||||
let mut handshake = self.state.handshake.read();
|
let handshake = self.state.handshake.read();
|
||||||
if handshake.active {
|
if handshake.active {
|
||||||
Some(handshake.device.get_sk())
|
Some(handshake.device.get_sk())
|
||||||
} else {
|
} else {
|
||||||
@@ -184,66 +186,73 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
peer
|
peer
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_bind(reader: B::Reader, writer: B::Writer, closer: B::Closer) {
|
/* Begin consuming messages from the reader.
|
||||||
|
*
|
||||||
|
* Any previous reader thread is stopped by closing the previous reader,
|
||||||
|
* which unblocks the thread and causes an error on reader.read
|
||||||
|
*/
|
||||||
|
pub fn add_reader(&self, reader: B::Reader) {
|
||||||
|
let wg = self.state.clone();
|
||||||
|
thread::spawn(move || {
|
||||||
|
let mut last_under_load =
|
||||||
|
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
|
||||||
|
|
||||||
// drop existing closer
|
loop {
|
||||||
|
// create vector big enough for any message given current MTU
|
||||||
|
let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
|
||||||
|
let mut msg: Vec<u8> = Vec::with_capacity(size);
|
||||||
|
msg.resize(size, 0);
|
||||||
|
|
||||||
// swap IO thread for new reader
|
// read UDP packet into vector
|
||||||
|
let (size, src) = match reader.read(&mut msg) {
|
||||||
// start UDP read IO thread
|
Err(e) => {
|
||||||
|
debug!("Bind reader closed with {}", e);
|
||||||
/*
|
return;
|
||||||
{
|
|
||||||
let wg = wg.clone();
|
|
||||||
let mtu = mtu.clone();
|
|
||||||
thread::spawn(move || {
|
|
||||||
let mut last_under_load =
|
|
||||||
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
|
|
||||||
|
|
||||||
loop {
|
|
||||||
// create vector big enough for any message given current MTU
|
|
||||||
let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
|
|
||||||
let mut msg: Vec<u8> = Vec::with_capacity(size);
|
|
||||||
msg.resize(size, 0);
|
|
||||||
|
|
||||||
// read UDP packet into vector
|
|
||||||
let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error
|
|
||||||
msg.truncate(size);
|
|
||||||
|
|
||||||
// message type de-multiplexer
|
|
||||||
if msg.len() < std::mem::size_of::<u32>() {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
match LittleEndian::read_u32(&msg[..]) {
|
Ok(v) => v,
|
||||||
handshake::TYPE_COOKIE_REPLY
|
};
|
||||||
| handshake::TYPE_INITIATION
|
msg.truncate(size);
|
||||||
| handshake::TYPE_RESPONSE => {
|
|
||||||
// update under_load flag
|
|
||||||
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
|
|
||||||
last_under_load = Instant::now();
|
|
||||||
wg.under_load.store(true, Ordering::SeqCst);
|
|
||||||
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
|
|
||||||
wg.under_load.store(false, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.queue
|
// message type de-multiplexer
|
||||||
.lock()
|
if msg.len() < std::mem::size_of::<u32>() {
|
||||||
.send(HandshakeJob::Message(msg, src))
|
continue;
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
router::TYPE_TRANSPORT => {
|
|
||||||
// transport message
|
|
||||||
let _ = wg.router.recv(src, msg);
|
|
||||||
}
|
|
||||||
_ => (),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
match LittleEndian::read_u32(&msg[..]) {
|
||||||
}
|
handshake::TYPE_COOKIE_REPLY
|
||||||
*/
|
| handshake::TYPE_INITIATION
|
||||||
|
| handshake::TYPE_RESPONSE => {
|
||||||
|
// update under_load flag
|
||||||
|
if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
|
||||||
|
last_under_load = Instant::now();
|
||||||
|
wg.under_load.store(true, Ordering::SeqCst);
|
||||||
|
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
|
||||||
|
wg.under_load.store(false, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.queue
|
||||||
|
.lock()
|
||||||
|
.send(HandshakeJob::Message(msg, src))
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
router::TYPE_TRANSPORT => {
|
||||||
|
// transport message
|
||||||
|
let _ = wg.router.recv(src, msg).map_err(|e| {
|
||||||
|
debug!("Failed to handle incoming transport message: {}", e);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(reader: T::Reader, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
|
pub fn set_writer(&self, writer: B::Writer) {
|
||||||
|
// TODO: Consider unifying these and avoid Clone requirement on writer
|
||||||
|
*self.state.send.write() = Some(writer.clone());
|
||||||
|
self.state.router.set_outbound_writer(writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
|
||||||
// create device state
|
// create device state
|
||||||
let mut rng = OsRng::new().unwrap();
|
let mut rng = OsRng::new().unwrap();
|
||||||
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||||
@@ -292,14 +301,16 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
None
|
None
|
||||||
},
|
},
|
||||||
) {
|
) {
|
||||||
Ok((pk, msg, keypair)) => {
|
Ok((pk, resp, keypair)) => {
|
||||||
// send response
|
// send response
|
||||||
if let Some(msg) = msg {
|
let mut resp_len: u64 = 0;
|
||||||
|
if let Some(msg) = resp {
|
||||||
|
resp_len = msg.len() as u64;
|
||||||
let send: &Option<B::Writer> = &*wg.send.read();
|
let send: &Option<B::Writer> = &*wg.send.read();
|
||||||
if let Some(writer) = send.as_ref() {
|
if let Some(writer) = send.as_ref() {
|
||||||
let _ = writer.write(&msg[..], &src).map_err(|e| {
|
let _ = writer.write(&msg[..], &src).map_err(|e| {
|
||||||
debug!(
|
debug!(
|
||||||
"handshake worker, failed to send response, error = {:?}",
|
"handshake worker, failed to send response, error = {}",
|
||||||
e
|
e
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
@@ -308,16 +319,23 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
|
|
||||||
// update timers
|
// update timers
|
||||||
if let Some(pk) = pk {
|
if let Some(pk) = pk {
|
||||||
|
// authenticated handshake packet received
|
||||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||||
|
// add to rx_bytes and tx_bytes
|
||||||
|
let req_len = msg.len() as u64;
|
||||||
|
peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
|
||||||
|
peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
|
||||||
|
|
||||||
// update endpoint
|
// update endpoint
|
||||||
peer.router.set_endpoint(src);
|
peer.router.set_endpoint(src);
|
||||||
|
|
||||||
// add keypair to peer and free any unused ids
|
// add keypair to peer
|
||||||
if let Some(keypair) = keypair {
|
keypair.map(|kp| {
|
||||||
for id in peer.router.add_keypair(keypair) {
|
// free any unused ids
|
||||||
|
for id in peer.router.add_keypair(kp) {
|
||||||
state.device.release(id);
|
state.device.release(id);
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -325,20 +343,27 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
HandshakeJob::New(pk) => {
|
HandshakeJob::New(pk) => {
|
||||||
let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
|
let _ = state.device.begin(&mut rng, &pk).map(|msg| {
|
||||||
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
|
||||||
peer.router.send(&msg[..]);
|
let _ = peer.router.send(&msg[..]).map_err(|e| {
|
||||||
peer.timers.read().handshake_sent();
|
debug!("handshake worker, failed to send handshake initiation, error = {}", e)
|
||||||
}
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// start TUN read IO thread
|
// start TUN read IO threads (multiple threads to support multi-queue interfaces)
|
||||||
{
|
debug_assert!(
|
||||||
|
readers.len() > 0,
|
||||||
|
"attempted to create WG device without TUN readers"
|
||||||
|
);
|
||||||
|
while let Some(reader) = readers.pop() {
|
||||||
let wg = wg.clone();
|
let wg = wg.clone();
|
||||||
|
let mtu = mtu.clone();
|
||||||
thread::spawn(move || loop {
|
thread::spawn(move || loop {
|
||||||
// create vector big enough for any transport message (based on MTU)
|
// create vector big enough for any transport message (based on MTU)
|
||||||
let mtu = mtu.mtu();
|
let mtu = mtu.mtu();
|
||||||
|
|||||||
Reference in New Issue
Block a user