Restructure dummy implementations
This commit is contained in:
@@ -16,3 +16,5 @@ pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200);
|
|||||||
pub const TIMERS_TICK: Duration = Duration::from_millis(100);
|
pub const TIMERS_TICK: Duration = Duration::from_millis(100);
|
||||||
pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize;
|
pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize;
|
||||||
pub const TIMERS_CAPACITY: usize = 1024;
|
pub const TIMERS_CAPACITY: usize = 1024;
|
||||||
|
|
||||||
|
pub const MESSAGE_PADDING_MULTIPLE: usize = 16;
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ use hex;
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
use std::cmp;
|
||||||
|
use std::mem;
|
||||||
|
|
||||||
use byteorder::LittleEndian;
|
use byteorder::LittleEndian;
|
||||||
use zerocopy::byteorder::U32;
|
use zerocopy::byteorder::U32;
|
||||||
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
|
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
|
||||||
@@ -21,6 +24,16 @@ pub const TYPE_INITIATION: u32 = 1;
|
|||||||
pub const TYPE_RESPONSE: u32 = 2;
|
pub const TYPE_RESPONSE: u32 = 2;
|
||||||
pub const TYPE_COOKIE_REPLY: u32 = 3;
|
pub const TYPE_COOKIE_REPLY: u32 = 3;
|
||||||
|
|
||||||
|
const fn max(a: usize, b: usize) -> usize {
|
||||||
|
let m: usize = (a > b) as usize;
|
||||||
|
m * a + (1 - m) * b
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const MAX_HANDSHAKE_MSG_SIZE: usize = max(
|
||||||
|
max(mem::size_of::<Response>(), mem::size_of::<Initiation>()),
|
||||||
|
mem::size_of::<CookieReply>(),
|
||||||
|
);
|
||||||
|
|
||||||
/* Handshake messsages */
|
/* Handshake messsages */
|
||||||
|
|
||||||
#[repr(packed)]
|
#[repr(packed)]
|
||||||
|
|||||||
@@ -18,4 +18,4 @@ mod types;
|
|||||||
// publicly exposed interface
|
// publicly exposed interface
|
||||||
|
|
||||||
pub use device::Device;
|
pub use device::Device;
|
||||||
pub use messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
|
pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
|
||||||
|
|||||||
21
src/main.rs
21
src/main.rs
@@ -12,7 +12,24 @@ mod timers;
|
|||||||
mod types;
|
mod types;
|
||||||
mod wireguard;
|
mod wireguard;
|
||||||
|
|
||||||
#[test]
|
#[cfg(test)]
|
||||||
fn test_pure_wireguard() {}
|
mod tests {
|
||||||
|
use crate::types::{dummy, Bind};
|
||||||
|
use crate::wireguard::Wireguard;
|
||||||
|
|
||||||
|
use std::thread;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
fn init() {
|
||||||
|
let _ = env_logger::builder().is_test(true).try_init();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pure_wireguard() {
|
||||||
|
init();
|
||||||
|
let wg = Wireguard::new(dummy::TunTest::new(), dummy::VoidBind::new());
|
||||||
|
thread::sleep(Duration::from_millis(500));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {}
|
fn main() {}
|
||||||
|
|||||||
@@ -12,209 +12,13 @@ use num_cpus;
|
|||||||
use pnet::packet::ipv4::MutableIpv4Packet;
|
use pnet::packet::ipv4::MutableIpv4Packet;
|
||||||
use pnet::packet::ipv6::MutableIpv6Packet;
|
use pnet::packet::ipv6::MutableIpv6Packet;
|
||||||
|
|
||||||
use super::super::types::{Bind, Endpoint, Key, KeyPair, Tun};
|
use super::super::types::{dummy, Bind, Endpoint, Key, KeyPair, Tun};
|
||||||
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
|
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
|
||||||
|
|
||||||
extern crate test;
|
extern crate test;
|
||||||
|
|
||||||
const SIZE_KEEPALIVE: usize = 32;
|
const SIZE_KEEPALIVE: usize = 32;
|
||||||
|
|
||||||
/* Error implementation */
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum BindError {
|
|
||||||
Disconnected,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Error for BindError {
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Generic Bind Error"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for BindError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
match self {
|
|
||||||
BindError::Disconnected => write!(f, "PairBind disconnected"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* TUN implementation */
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum TunError {}
|
|
||||||
|
|
||||||
impl Error for TunError {
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Generic Tun Error"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for TunError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(f, "Not Possible")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Endpoint implementation */
|
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
|
||||||
struct UnitEndpoint {}
|
|
||||||
|
|
||||||
impl Endpoint for UnitEndpoint {
|
|
||||||
fn from_address(_: SocketAddr) -> UnitEndpoint {
|
|
||||||
UnitEndpoint {}
|
|
||||||
}
|
|
||||||
fn into_address(&self) -> SocketAddr {
|
|
||||||
"127.0.0.1:8080".parse().unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
|
||||||
struct TunTest {}
|
|
||||||
|
|
||||||
impl Tun for TunTest {
|
|
||||||
type Error = TunError;
|
|
||||||
|
|
||||||
fn mtu(&self) -> usize {
|
|
||||||
1500
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
|
|
||||||
Ok(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Bind implemenentations */
|
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
|
||||||
struct VoidBind {}
|
|
||||||
|
|
||||||
impl Bind for VoidBind {
|
|
||||||
type Error = BindError;
|
|
||||||
type Endpoint = UnitEndpoint;
|
|
||||||
|
|
||||||
fn new() -> VoidBind {
|
|
||||||
VoidBind {}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_port(&self) -> Option<u16> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
|
|
||||||
Ok((0, UnitEndpoint {}))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct PairBind {
|
|
||||||
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
|
|
||||||
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Bind for PairBind {
|
|
||||||
type Error = BindError;
|
|
||||||
type Endpoint = UnitEndpoint;
|
|
||||||
|
|
||||||
fn new() -> PairBind {
|
|
||||||
PairBind {
|
|
||||||
send: Arc::new(Mutex::new(sync_channel(0).0)),
|
|
||||||
recv: Arc::new(Mutex::new(sync_channel(0).1)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_port(&self) -> Option<u16> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
|
|
||||||
let vec = self
|
|
||||||
.recv
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.recv()
|
|
||||||
.map_err(|_| BindError::Disconnected)?;
|
|
||||||
let len = vec.len();
|
|
||||||
buf[..len].copy_from_slice(&vec[..]);
|
|
||||||
Ok((vec.len(), UnitEndpoint {}))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
|
|
||||||
let owned = buf.to_owned();
|
|
||||||
match self.send.lock().unwrap().send(owned) {
|
|
||||||
Err(_) => Err(BindError::Disconnected),
|
|
||||||
Ok(_) => Ok(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn bind_pair() -> (PairBind, PairBind) {
|
|
||||||
let (tx1, rx1) = sync_channel(128);
|
|
||||||
let (tx2, rx2) = sync_channel(128);
|
|
||||||
(
|
|
||||||
PairBind {
|
|
||||||
send: Arc::new(Mutex::new(tx1)),
|
|
||||||
recv: Arc::new(Mutex::new(rx2)),
|
|
||||||
},
|
|
||||||
PairBind {
|
|
||||||
send: Arc::new(Mutex::new(tx2)),
|
|
||||||
recv: Arc::new(Mutex::new(rx1)),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dummy_keypair(initiator: bool) -> KeyPair {
|
|
||||||
let k1 = Key {
|
|
||||||
key: [0x53u8; 32],
|
|
||||||
id: 0x646e6573,
|
|
||||||
};
|
|
||||||
let k2 = Key {
|
|
||||||
key: [0x52u8; 32],
|
|
||||||
id: 0x76636572,
|
|
||||||
};
|
|
||||||
if initiator {
|
|
||||||
KeyPair {
|
|
||||||
birth: Instant::now(),
|
|
||||||
initiator: true,
|
|
||||||
send: k1,
|
|
||||||
recv: k2,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
KeyPair {
|
|
||||||
birth: Instant::now(),
|
|
||||||
initiator: false,
|
|
||||||
send: k2,
|
|
||||||
recv: k1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -341,13 +145,13 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create device
|
// create device
|
||||||
let router: Device<BencherCallbacks, TunTest, VoidBind> =
|
let router: Device<BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
|
||||||
Device::new(num_cpus::get(), TunTest {}, VoidBind::new());
|
Device::new(num_cpus::get(), dummy::TunTest {}, dummy::VoidBind::new());
|
||||||
|
|
||||||
// add new peer
|
// add new peer
|
||||||
let opaque = Arc::new(AtomicUsize::new(0));
|
let opaque = Arc::new(AtomicUsize::new(0));
|
||||||
let peer = router.new_peer(opaque.clone());
|
let peer = router.new_peer(opaque.clone());
|
||||||
peer.add_keypair(dummy_keypair(true));
|
peer.add_keypair(dummy::keypair(true));
|
||||||
|
|
||||||
// add subnet to peer
|
// add subnet to peer
|
||||||
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
|
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
|
||||||
@@ -370,7 +174,8 @@ mod tests {
|
|||||||
init();
|
init();
|
||||||
|
|
||||||
// create device
|
// create device
|
||||||
let router: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, VoidBind::new());
|
let router: Device<TestCallbacks, _, _> =
|
||||||
|
Device::new(1, dummy::TunTest::new(), dummy::VoidBind::new());
|
||||||
|
|
||||||
let tests = vec![
|
let tests = vec![
|
||||||
("192.168.1.0", 24, "192.168.1.20", true),
|
("192.168.1.0", 24, "192.168.1.20", true),
|
||||||
@@ -404,9 +209,8 @@ mod tests {
|
|||||||
let opaque = Opaque::new();
|
let opaque = Opaque::new();
|
||||||
let peer = router.new_peer(opaque.clone());
|
let peer = router.new_peer(opaque.clone());
|
||||||
let mask: IpAddr = mask.parse().unwrap();
|
let mask: IpAddr = mask.parse().unwrap();
|
||||||
|
|
||||||
if set_key {
|
if set_key {
|
||||||
peer.add_keypair(dummy_keypair(true));
|
peer.add_keypair(dummy::keypair(true));
|
||||||
}
|
}
|
||||||
|
|
||||||
// map subnet to peer
|
// map subnet to peer
|
||||||
@@ -512,9 +316,11 @@ mod tests {
|
|||||||
|
|
||||||
for (stage, p1, p2) in tests.iter() {
|
for (stage, p1, p2) in tests.iter() {
|
||||||
// create matching devices
|
// create matching devices
|
||||||
let (bind1, bind2) = bind_pair();
|
let (bind1, bind2) = dummy::PairBind::pair();
|
||||||
let router1: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind1.clone());
|
let router1: Device<TestCallbacks, _, _> =
|
||||||
let router2: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind2.clone());
|
Device::new(1, dummy::TunTest::new(), bind1.clone());
|
||||||
|
let router2: Device<TestCallbacks, _, _> =
|
||||||
|
Device::new(1, dummy::TunTest::new(), bind2.clone());
|
||||||
|
|
||||||
// prepare opaque values for tracing callbacks
|
// prepare opaque values for tracing callbacks
|
||||||
|
|
||||||
@@ -527,7 +333,7 @@ mod tests {
|
|||||||
let peer1 = router1.new_peer(opaq1.clone());
|
let peer1 = router1.new_peer(opaq1.clone());
|
||||||
let mask: IpAddr = mask.parse().unwrap();
|
let mask: IpAddr = mask.parse().unwrap();
|
||||||
peer1.add_subnet(mask, *len);
|
peer1.add_subnet(mask, *len);
|
||||||
peer1.add_keypair(dummy_keypair(false));
|
peer1.add_keypair(dummy::keypair(false));
|
||||||
|
|
||||||
let (mask, len, _ip, _okay) = p2;
|
let (mask, len, _ip, _okay) = p2;
|
||||||
let peer2 = router2.new_peer(opaq2.clone());
|
let peer2 = router2.new_peer(opaq2.clone());
|
||||||
@@ -557,7 +363,7 @@ mod tests {
|
|||||||
// this should cause a key-confirmation packet (keepalive or staged packet)
|
// this should cause a key-confirmation packet (keepalive or staged packet)
|
||||||
// this also causes peer1 to learn the "endpoint" for peer2
|
// this also causes peer1 to learn the "endpoint" for peer2
|
||||||
assert!(peer1.get_endpoint().is_none());
|
assert!(peer1.get_endpoint().is_none());
|
||||||
peer2.add_keypair(dummy_keypair(true));
|
peer2.add_keypair(dummy::keypair(true));
|
||||||
|
|
||||||
wait();
|
wait();
|
||||||
assert!(opaq2.send().is_some());
|
assert!(opaq2.send().is_some());
|
||||||
|
|||||||
217
src/types/dummy.rs
Normal file
217
src/types/dummy.rs
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use super::{Bind, Endpoint, Key, KeyPair, Tun};
|
||||||
|
|
||||||
|
/* This submodule provides pure/dummy implementations of the IO interfaces
|
||||||
|
* for use in unit tests thoughout the project.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* Error implementation */
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum BindError {
|
||||||
|
Disconnected,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for BindError {
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Generic Bind Error"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for BindError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
BindError::Disconnected => write!(f, "PairBind disconnected"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* TUN implementation */
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum TunError {}
|
||||||
|
|
||||||
|
impl Error for TunError {
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Generic Tun Error"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for TunError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Not Possible")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Endpoint implementation */
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct UnitEndpoint {}
|
||||||
|
|
||||||
|
impl Endpoint for UnitEndpoint {
|
||||||
|
fn from_address(_: SocketAddr) -> UnitEndpoint {
|
||||||
|
UnitEndpoint {}
|
||||||
|
}
|
||||||
|
fn into_address(&self) -> SocketAddr {
|
||||||
|
"127.0.0.1:8080".parse().unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct TunTest {}
|
||||||
|
|
||||||
|
impl Tun for TunTest {
|
||||||
|
type Error = TunError;
|
||||||
|
|
||||||
|
fn mtu(&self) -> usize {
|
||||||
|
1500
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
|
||||||
|
Ok(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TunTest {
|
||||||
|
pub fn new() -> TunTest {
|
||||||
|
TunTest {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Bind implemenentations */
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct VoidBind {}
|
||||||
|
|
||||||
|
impl Bind for VoidBind {
|
||||||
|
type Error = BindError;
|
||||||
|
type Endpoint = UnitEndpoint;
|
||||||
|
|
||||||
|
fn new() -> VoidBind {
|
||||||
|
VoidBind {}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_port(&self) -> Option<u16> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
|
||||||
|
Ok((0, UnitEndpoint {}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct PairBind {
|
||||||
|
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
|
||||||
|
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PairBind {
|
||||||
|
pub fn pair() -> (PairBind, PairBind) {
|
||||||
|
let (tx1, rx1) = sync_channel(128);
|
||||||
|
let (tx2, rx2) = sync_channel(128);
|
||||||
|
(
|
||||||
|
PairBind {
|
||||||
|
send: Arc::new(Mutex::new(tx1)),
|
||||||
|
recv: Arc::new(Mutex::new(rx2)),
|
||||||
|
},
|
||||||
|
PairBind {
|
||||||
|
send: Arc::new(Mutex::new(tx2)),
|
||||||
|
recv: Arc::new(Mutex::new(rx1)),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Bind for PairBind {
|
||||||
|
type Error = BindError;
|
||||||
|
type Endpoint = UnitEndpoint;
|
||||||
|
|
||||||
|
fn new() -> PairBind {
|
||||||
|
PairBind {
|
||||||
|
send: Arc::new(Mutex::new(sync_channel(0).0)),
|
||||||
|
recv: Arc::new(Mutex::new(sync_channel(0).1)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_port(&self) -> Option<u16> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
|
||||||
|
let vec = self
|
||||||
|
.recv
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.recv()
|
||||||
|
.map_err(|_| BindError::Disconnected)?;
|
||||||
|
let len = vec.len();
|
||||||
|
buf[..len].copy_from_slice(&vec[..]);
|
||||||
|
Ok((vec.len(), UnitEndpoint {}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
|
||||||
|
let owned = buf.to_owned();
|
||||||
|
match self.send.lock().unwrap().send(owned) {
|
||||||
|
Err(_) => Err(BindError::Disconnected),
|
||||||
|
Ok(_) => Ok(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn keypair(initiator: bool) -> KeyPair {
|
||||||
|
let k1 = Key {
|
||||||
|
key: [0x53u8; 32],
|
||||||
|
id: 0x646e6573,
|
||||||
|
};
|
||||||
|
let k2 = Key {
|
||||||
|
key: [0x52u8; 32],
|
||||||
|
id: 0x76636572,
|
||||||
|
};
|
||||||
|
if initiator {
|
||||||
|
KeyPair {
|
||||||
|
birth: Instant::now(),
|
||||||
|
initiator: true,
|
||||||
|
send: k1,
|
||||||
|
recv: k2,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
KeyPair {
|
||||||
|
birth: Instant::now(),
|
||||||
|
initiator: false,
|
||||||
|
send: k2,
|
||||||
|
recv: k1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,9 @@ mod keys;
|
|||||||
mod tun;
|
mod tun;
|
||||||
mod udp;
|
mod udp;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub mod dummy;
|
||||||
|
|
||||||
pub use endpoint::Endpoint;
|
pub use endpoint::Endpoint;
|
||||||
pub use keys::{Key, KeyPair};
|
pub use keys::{Key, KeyPair};
|
||||||
pub use tun::Tun;
|
pub use tun::Tun;
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ use crate::types::{Bind, Endpoint, Tun};
|
|||||||
|
|
||||||
use hjul::Runner;
|
use hjul::Runner;
|
||||||
|
|
||||||
|
use std::cmp;
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -86,8 +87,19 @@ pub struct Wireguard<T: Tun, B: Bind> {
|
|||||||
state: Arc<WireguardInner<T, B>>,
|
state: Arc<WireguardInner<T, B>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
const fn padding(size: usize, mtu: usize) -> usize {
|
||||||
|
#[inline(always)]
|
||||||
|
const fn min(a: usize, b: usize) -> usize {
|
||||||
|
let m = (a > b) as usize;
|
||||||
|
a * m + (1 - m) * b
|
||||||
|
}
|
||||||
|
let pad = MESSAGE_PADDING_MULTIPLE;
|
||||||
|
min(mtu, size + (pad - size % pad) % pad)
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: Tun, B: Bind> Wireguard<T, B> {
|
impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||||
fn set_key(&self, sk: Option<StaticSecret>) {
|
pub fn set_key(&self, sk: Option<StaticSecret>) {
|
||||||
let mut handshake = self.state.handshake.write();
|
let mut handshake = self.state.handshake.write();
|
||||||
match sk {
|
match sk {
|
||||||
None => {
|
None => {
|
||||||
@@ -102,7 +114,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
|
pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
|
||||||
let state = Arc::new(PeerInner {
|
let state = Arc::new(PeerInner {
|
||||||
pk,
|
pk,
|
||||||
queue: Mutex::new(self.state.queue.lock().clone()),
|
queue: Mutex::new(self.state.queue.lock().clone()),
|
||||||
@@ -111,11 +123,21 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
tx_bytes: AtomicU64::new(0),
|
tx_bytes: AtomicU64::new(0),
|
||||||
timers: RwLock::new(Timers::dummy(&self.runner)),
|
timers: RwLock::new(Timers::dummy(&self.runner)),
|
||||||
});
|
});
|
||||||
|
|
||||||
let router = Arc::new(self.state.router.new_peer(state.clone()));
|
let router = Arc::new(self.state.router.new_peer(state.clone()));
|
||||||
Peer { router, state }
|
|
||||||
|
let peer = Peer { router, state };
|
||||||
|
|
||||||
|
/* The need for dummy timers arises from the chicken-egg
|
||||||
|
* problem of the timer callbacks being able to set timers themselves.
|
||||||
|
*
|
||||||
|
* This is in fact the only place where the write lock is ever taken.
|
||||||
|
*/
|
||||||
|
*peer.timers.write() = Timers::new(&self.runner, peer.clone());
|
||||||
|
peer
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new(tun: T, bind: B) -> Wireguard<T, B> {
|
pub fn new(tun: T, bind: B) -> Wireguard<T, B> {
|
||||||
// create device state
|
// create device state
|
||||||
let mut rng = OsRng::new().unwrap();
|
let mut rng = OsRng::new().unwrap();
|
||||||
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||||
@@ -215,10 +237,12 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
|
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// read UDP packet into vector
|
// create vector big enough for any message given current MTU
|
||||||
let size = tun.mtu() + 148; // maximum message size
|
let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
|
||||||
let mut msg: Vec<u8> = Vec::with_capacity(size);
|
let mut msg: Vec<u8> = Vec::with_capacity(size);
|
||||||
msg.resize(size, 0);
|
msg.resize(size, 0);
|
||||||
|
|
||||||
|
// read UDP packet into vector
|
||||||
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
|
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
|
||||||
msg.truncate(size);
|
msg.truncate(size);
|
||||||
|
|
||||||
@@ -226,7 +250,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
if msg.len() < std::mem::size_of::<u32>() {
|
if msg.len() < std::mem::size_of::<u32>() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
match LittleEndian::read_u32(&msg[..]) {
|
match LittleEndian::read_u32(&msg[..]) {
|
||||||
handshake::TYPE_COOKIE_REPLY
|
handshake::TYPE_COOKIE_REPLY
|
||||||
| handshake::TYPE_INITIATION
|
| handshake::TYPE_INITIATION
|
||||||
@@ -246,9 +269,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
}
|
}
|
||||||
router::TYPE_TRANSPORT => {
|
router::TYPE_TRANSPORT => {
|
||||||
// transport message
|
// transport message
|
||||||
|
|
||||||
// pad the message
|
|
||||||
|
|
||||||
let _ = wg.router.recv(src, msg);
|
let _ = wg.router.recv(src, msg);
|
||||||
}
|
}
|
||||||
_ => (),
|
_ => (),
|
||||||
@@ -261,20 +281,32 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
|||||||
{
|
{
|
||||||
let wg = wg.clone();
|
let wg = wg.clone();
|
||||||
thread::spawn(move || loop {
|
thread::spawn(move || loop {
|
||||||
// read a new IP packet
|
// create vector big enough for any transport message (based on MTU)
|
||||||
let mtu = tun.mtu();
|
let mtu = tun.mtu();
|
||||||
let size = mtu + 148;
|
let size = mtu + router::SIZE_MESSAGE_PREFIX;
|
||||||
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
|
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
|
||||||
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
|
msg.resize(size, 0);
|
||||||
msg.truncate(size);
|
|
||||||
|
|
||||||
// pad message to multiple of 16 bytes
|
// read a new IP packet
|
||||||
while msg.len() < mtu && msg.len() % 16 != 0 {
|
let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
|
||||||
msg.push(0);
|
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
|
||||||
}
|
|
||||||
|
// truncate padding
|
||||||
|
let payload = padding(payload, mtu);
|
||||||
|
msg.truncate(router::SIZE_MESSAGE_PREFIX + payload);
|
||||||
|
debug_assert!(payload <= mtu);
|
||||||
|
debug_assert_eq!(
|
||||||
|
if payload < mtu {
|
||||||
|
(msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
},
|
||||||
|
0
|
||||||
|
);
|
||||||
|
|
||||||
// crypt-key route
|
// crypt-key route
|
||||||
let _ = wg.router.send(msg);
|
let e = wg.router.send(msg);
|
||||||
|
debug!("TUN worker, router returned {:?}", e);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user