Bumped crate versions.

This commit is contained in:
Mathias Hall-Andersen
2020-02-12 21:38:25 +01:00
parent dcd567c08f
commit 5e6edb280e
12 changed files with 489 additions and 451 deletions

504
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -43,8 +43,9 @@ features = ["nightly"]
profiler = ["cpuprofiler"] profiler = ["cpuprofiler"]
start_up = [] start_up = []
[dev-dependencies] [dev-dependencies]
pnet = "^0.22" pnet = "0.25.0"
proptest = "0.9.4" proptest = "0.9.4"
rand_chacha = "0.2.1" rand_chacha = "0.2.1"

View File

@@ -62,14 +62,10 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Clone for WireGuardConfig<T, B> {
/// Exposed configuration interface /// Exposed configuration interface
pub trait Configuration { pub trait Configuration {
fn up(&self, mtu: usize); fn up(&self, mtu: usize) -> Result<(), ConfigError>;
fn down(&self); fn down(&self);
fn start_listener(&self) -> Result<(), ConfigError>;
fn stop_listener(&self) -> Result<(), ConfigError>;
/// Updates the private key of the device /// Updates the private key of the device
/// ///
/// # Arguments /// # Arguments
@@ -196,13 +192,48 @@ pub trait Configuration {
fn get_fwmark(&self) -> Option<u32>; fn get_fwmark(&self) -> Option<u32>;
} }
fn start_listener<T: tun::Tun, B: udp::PlatformUDP>(
mut cfg: MutexGuard<Inner<T, B>>,
) -> Result<(), ConfigError> {
cfg.bind = None;
// create new listener
let (mut readers, writer, mut owner) = match B::bind(cfg.port) {
Ok(r) => r,
Err(_) => {
return Err(ConfigError::FailedToBind);
}
};
// set fwmark
let _ = owner.set_fwmark(cfg.fwmark); // TODO: handle
// set writer on WireGuard
cfg.wireguard.set_writer(writer);
// add readers
while let Some(reader) = readers.pop() {
cfg.wireguard.add_udp_reader(reader);
}
// create new UDP state
cfg.bind = Some(owner);
Ok(())
}
impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> { impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
fn up(&self, mtu: usize) { fn up(&self, mtu: usize) -> Result<(), ConfigError> {
self.lock().wireguard.up(mtu); log::info!("configuration, set device up");
let cfg = self.lock();
cfg.wireguard.up(mtu);
start_listener(cfg)
} }
fn down(&self) { fn down(&self) {
self.lock().wireguard.down(); log::info!("configuration, set device down");
let mut cfg = self.lock();
cfg.wireguard.down();
cfg.bind = None;
} }
fn get_fwmark(&self) -> Option<u32> { fn get_fwmark(&self) -> Option<u32> {
@@ -210,6 +241,7 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
} }
fn set_private_key(&self, sk: Option<StaticSecret>) { fn set_private_key(&self, sk: Option<StaticSecret>) {
log::info!("configuration, set private key");
self.lock().wireguard.set_key(sk) self.lock().wireguard.set_key(sk)
} }
@@ -227,62 +259,23 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
st.bind.as_ref().map(|bind| bind.get_port()) st.bind.as_ref().map(|bind| bind.get_port())
} }
fn stop_listener(&self) -> Result<(), ConfigError> {
self.lock().bind = None;
Ok(())
}
fn start_listener(&self) -> Result<(), ConfigError> {
let mut cfg = self.lock();
// check if already listening
if cfg.bind.is_some() {
return Ok(());
}
// create new listener
let (mut readers, writer, mut owner) = match B::bind(cfg.port) {
Ok(r) => r,
Err(_) => {
return Err(ConfigError::FailedToBind);
}
};
// set fwmark
let _ = owner.set_fwmark(cfg.fwmark); // TODO: handle
// set writer on wireguard
cfg.wireguard.set_writer(writer);
// add readers
while let Some(reader) = readers.pop() {
cfg.wireguard.add_udp_reader(reader);
}
// create new UDP state
cfg.bind = Some(owner);
Ok(())
}
fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> { fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> {
log::trace!("Config, Set listen port: {:?}", port); log::trace!("Config, Set listen port: {:?}", port);
// update port and take old bind // update port and take old bind
let old: Option<B::Owner> = { let mut cfg = self.lock();
let mut cfg = self.lock(); let bound: bool = {
let old = mem::replace(&mut cfg.bind, None); let old = mem::replace(&mut cfg.bind, None);
cfg.port = port; cfg.port = port;
old old.is_some()
}; };
// restart listener if bound // restart listener if bound
if old.is_some() { if bound {
self.start_listener() start_listener(cfg)
} else { } else {
Ok(()) Ok(())
} }
// old bind is dropped, causing the file-descriptors to be released
} }
fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> { fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> {

View File

@@ -145,25 +145,11 @@ fn main() {
} }
Ok(tun::TunEvent::Up(mtu)) => { Ok(tun::TunEvent::Up(mtu)) => {
log::info!("Tun up (mtu = {})", mtu); log::info!("Tun up (mtu = {})", mtu);
let _ = cfg.up(mtu); // TODO: handle
// bring the wireguard device up
cfg.up(mtu);
// start listening on UDP
let _ = cfg
.start_listener()
.map_err(|e| log::info!("Failed to start UDP listener: {}", e));
} }
Ok(tun::TunEvent::Down) => { Ok(tun::TunEvent::Down) => {
log::info!("Tun down"); log::info!("Tun down");
// set wireguard device down
cfg.down(); cfg.down();
// close UDP listener
let _ = cfg
.stop_listener()
.map_err(|e| log::info!("Failed to stop UDP listener {}", e));
} }
} }
}); });

View File

@@ -9,14 +9,17 @@ use std::mem;
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::unix::io::RawFd; use std::os::unix::io::RawFd;
use std::ptr; use std::ptr;
use std::sync::Arc;
fn errno() -> libc::c_int { pub struct FD(RawFd);
unsafe {
let ptr = libc::__errno_location(); impl Drop for FD {
if ptr.is_null() { fn drop(&mut self) {
0 if self.0 != -1 {
} else { log::debug!("linux udp, release fd (fd = {})", self.0);
*ptr unsafe {
libc::close(self.0);
};
} }
} }
} }
@@ -47,19 +50,19 @@ pub struct LinuxUDP();
pub struct LinuxOwner { pub struct LinuxOwner {
port: u16, port: u16,
sock4: Option<RawFd>, sock4: Option<Arc<FD>>,
sock6: Option<RawFd>, sock6: Option<Arc<FD>>,
} }
pub enum LinuxUDPReader { pub enum LinuxUDPReader {
V4(RawFd), V4(Arc<FD>),
V6(RawFd), V6(Arc<FD>),
} }
#[derive(Clone)] #[derive(Clone)]
pub struct LinuxUDPWriter { pub struct LinuxUDPWriter {
sock4: RawFd, sock4: Arc<FD>,
sock6: RawFd, sock6: Arc<FD>,
} }
pub enum LinuxEndpoint { pub enum LinuxEndpoint {
@@ -67,6 +70,67 @@ pub enum LinuxEndpoint {
V6(EndpointV6), V6(EndpointV6),
} }
fn errno() -> libc::c_int {
unsafe {
let ptr = libc::__errno_location();
if ptr.is_null() {
0
} else {
*ptr
}
}
}
fn setsockopt<V: Sized>(
fd: RawFd,
level: libc::c_int,
name: libc::c_int,
value: &V,
) -> Result<(), io::Error> {
let res = unsafe {
libc::setsockopt(
fd,
level,
name,
mem::transmute(value),
mem::size_of_val(value).try_into().unwrap(),
)
};
if res == 0 {
Ok(())
} else {
Err(io::Error::new(
io::ErrorKind::Other,
format!("Failed to set sockopt (res = {}, errno = {})", res, errno()),
))
}
}
#[inline(always)]
fn setsockopt_int(
fd: RawFd,
level: libc::c_int,
name: libc::c_int,
value: libc::c_int,
) -> Result<(), io::Error> {
setsockopt(fd, level, name, &value)
}
#[allow(non_snake_case)]
const fn CMSG_ALIGN(len: usize) -> usize {
(((len) + mem::size_of::<u32>() - 1) & !(mem::size_of::<u32>() - 1))
}
#[allow(non_snake_case)]
const fn CMSG_LEN(len: usize) -> usize {
CMSG_ALIGN(len + mem::size_of::<libc::cmsghdr>())
}
#[inline(always)]
fn safe_cast<T, D>(v: &mut T) -> *mut D {
(v as *mut T) as *mut D
}
impl Endpoint for LinuxEndpoint { impl Endpoint for LinuxEndpoint {
fn clear_src(&mut self) { fn clear_src(&mut self) {
match self { match self {
@@ -134,56 +198,6 @@ impl Endpoint for LinuxEndpoint {
} }
} }
fn setsockopt<V: Sized>(
fd: RawFd,
level: libc::c_int,
name: libc::c_int,
value: &V,
) -> Result<(), io::Error> {
let res = unsafe {
libc::setsockopt(
fd,
level,
name,
mem::transmute(value),
mem::size_of_val(value).try_into().unwrap(),
)
};
if res == 0 {
Ok(())
} else {
Err(io::Error::new(
io::ErrorKind::Other,
format!("Failed to set sockopt (res = {}, errno = {})", res, errno()),
))
}
}
#[inline(always)]
fn setsockopt_int(
fd: RawFd,
level: libc::c_int,
name: libc::c_int,
value: libc::c_int,
) -> Result<(), io::Error> {
setsockopt(fd, level, name, &value)
}
#[allow(non_snake_case)]
const fn CMSG_ALIGN(len: usize) -> usize {
(((len) + mem::size_of::<u32>() - 1) & !(mem::size_of::<u32>() - 1))
}
#[allow(non_snake_case)]
const fn CMSG_LEN(len: usize) -> usize {
CMSG_ALIGN(len + mem::size_of::<libc::cmsghdr>())
}
#[inline(always)]
fn safe_cast<T, D>(v: &mut T) -> *mut D {
(v as *mut T) as *mut D
}
impl LinuxUDPReader { impl LinuxUDPReader {
fn read6(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> { fn read6(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> {
log::trace!( log::trace!(
@@ -192,6 +206,8 @@ impl LinuxUDPReader {
buf.len() buf.len()
); );
debug_assert!(buf.len() > 0, "reading into empty buffer (will fail)");
let mut iovs: [libc::iovec; 1] = [libc::iovec { let mut iovs: [libc::iovec; 1] = [libc::iovec {
iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
iov_len: buf.len(), iov_len: buf.len(),
@@ -215,10 +231,16 @@ impl LinuxUDPReader {
let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) }; let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) };
if len < 0 { if len <= 0 {
// TODO: FIX!
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::NotConnected, io::ErrorKind::NotConnected,
"failed to receive", format!(
"Failed to receive (len = {}, fd = {}, errno = {})",
len,
fd,
errno()
),
)); ));
} }
@@ -238,6 +260,8 @@ impl LinuxUDPReader {
buf.len() buf.len()
); );
debug_assert!(buf.len() > 0, "reading into empty buffer (will fail)");
let mut iovs: [libc::iovec; 1] = [libc::iovec { let mut iovs: [libc::iovec; 1] = [libc::iovec {
iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
iov_len: buf.len(), iov_len: buf.len(),
@@ -261,10 +285,15 @@ impl LinuxUDPReader {
let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) }; let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) };
if len < 0 { if len <= 0 {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::NotConnected, io::ErrorKind::NotConnected,
"failed to receive", format!(
"failed to receive (len = {}, fd = {}, errno = {})",
len,
fd,
errno()
),
)); ));
} }
@@ -283,8 +312,8 @@ impl Reader<LinuxEndpoint> for LinuxUDPReader {
fn read(&self, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), Self::Error> { fn read(&self, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), Self::Error> {
match self { match self {
Self::V4(fd) => Self::read4(*fd, buf), Self::V4(fd) => Self::read4(fd.0, buf),
Self::V6(fd) => Self::read6(*fd, buf), Self::V6(fd) => Self::read6(fd.0, buf),
} }
} }
} }
@@ -426,8 +455,8 @@ impl Writer<LinuxEndpoint> for LinuxUDPWriter {
fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> { fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> {
match dst { match dst {
LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4, buf, end), LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4.0, buf, end),
LinuxEndpoint::V6(ref mut end) => Self::write6(self.sock6, buf, end), LinuxEndpoint::V6(ref mut end) => Self::write6(self.sock6.0, buf, end),
} }
} }
} }
@@ -448,21 +477,21 @@ impl Owner for LinuxOwner {
} }
} }
let value = value.unwrap_or(0); let value = value.unwrap_or(0);
set_mark(self.sock6, value)?; set_mark(self.sock6.as_ref().map(|fd| fd.0), value)?;
set_mark(self.sock4, value) set_mark(self.sock4.as_ref().map(|fd| fd.0), value)
} }
} }
impl Drop for LinuxOwner { impl Drop for LinuxOwner {
fn drop(&mut self) { fn drop(&mut self) {
log::trace!("closing the bind (port {})", self.port); log::debug!("closing the bind (port = {})", self.port);
self.sock4.map(|fd| unsafe { self.sock4.as_ref().map(|fd| unsafe {
libc::shutdown(fd, libc::SHUT_RDWR); log::debug!("shutdown IPv4 (fd = {})", fd.0);
libc::close(fd) libc::shutdown(fd.0, libc::SHUT_RDWR);
}); });
self.sock6.map(|fd| unsafe { self.sock6.as_ref().map(|fd| unsafe {
libc::shutdown(fd, libc::SHUT_RDWR); log::debug!("shutdown IPv6 (fd = {})", fd.0);
libc::close(fd) libc::shutdown(fd.0, libc::SHUT_RDWR);
}); });
} }
} }
@@ -491,7 +520,7 @@ impl LinuxUDP {
// create socket fd // create socket fd
let fd: RawFd = unsafe { libc::socket(libc::AF_INET6, libc::SOCK_DGRAM, 0) }; let fd: RawFd = unsafe { libc::socket(libc::AF_INET6, libc::SOCK_DGRAM, 0) };
if fd < 0 { if fd < 0 {
log::debug!("failed to create IPv6 socket"); log::debug!("failed to create IPv6 socket (errno = {})", errno());
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"failed to create socket", "failed to create socket",
@@ -502,11 +531,13 @@ impl LinuxUDP {
setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, 1)?; setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, 1)?;
setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_V6ONLY, 1)?; setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_V6ONLY, 1)?;
const INADDR_ANY: libc::in6_addr = libc::in6_addr { s6_addr: [0; 16] };
// bind // bind
let mut sockaddr = libc::sockaddr_in6 { let mut sockaddr = libc::sockaddr_in6 {
sin6_addr: libc::in6_addr { s6_addr: [0; 16] }, sin6_addr: INADDR_ANY,
sin6_family: libc::AF_INET6 as libc::sa_family_t, sin6_family: libc::AF_INET6 as libc::sa_family_t,
sin6_port: port.to_be(), // convert to network (big-endian) byteorder sin6_port: port.to_be(), // convert to network (big-endian) byte-order
sin6_scope_id: 0, sin6_scope_id: 0,
sin6_flowinfo: 0, sin6_flowinfo: 0,
}; };
@@ -514,13 +545,12 @@ impl LinuxUDP {
let err = unsafe { let err = unsafe {
libc::bind( libc::bind(
fd, fd,
mem::transmute(&sockaddr as *const libc::sockaddr_in6), safe_cast(&mut sockaddr),
mem::size_of_val(&sockaddr).try_into().unwrap(), mem::size_of_val(&sockaddr).try_into().unwrap(),
) )
}; };
if err != 0 { if err != 0 {
log::debug!("failed to bind IPv6 socket"); log::debug!("failed to bind IPv6 socket (errno = {})", errno());
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"failed to create socket", "failed to create socket",
@@ -532,12 +562,12 @@ impl LinuxUDP {
let err = unsafe { let err = unsafe {
libc::getsockname( libc::getsockname(
fd, fd,
mem::transmute(&mut sockaddr as *mut libc::sockaddr_in6), safe_cast(&mut sockaddr),
&mut socklen as *mut libc::socklen_t, &mut socklen as *mut libc::socklen_t,
) )
}; };
if err != 0 { if err != 0 {
log::debug!("failed to get port of IPv6 socket"); log::debug!("failed to get port of IPv6 socket (errno = {})", errno());
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"failed to create socket", "failed to create socket",
@@ -569,7 +599,7 @@ impl LinuxUDP {
// create socket fd // create socket fd
let fd: RawFd = unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM, 0) }; let fd: RawFd = unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM, 0) };
if fd < 0 { if fd < 0 {
log::trace!("failed to create IPv4 socket (errno = {})", errno()); log::debug!("failed to create IPv4 socket (errno = {})", errno());
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"failed to create socket", "failed to create socket",
@@ -592,13 +622,12 @@ impl LinuxUDP {
let err = unsafe { let err = unsafe {
libc::bind( libc::bind(
fd, fd,
mem::transmute(&sockaddr as *const libc::sockaddr_in), safe_cast(&mut sockaddr),
mem::size_of_val(&sockaddr).try_into().unwrap(), mem::size_of_val(&sockaddr).try_into().unwrap(),
) )
}; };
if err != 0 { if err != 0 {
log::trace!("failed to bind IPv4 socket (errno = {})", errno()); log::debug!("failed to bind IPv4 socket (errno = {})", errno());
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"failed to create socket", "failed to create socket",
@@ -610,12 +639,12 @@ impl LinuxUDP {
let err = unsafe { let err = unsafe {
libc::getsockname( libc::getsockname(
fd, fd,
mem::transmute(&mut sockaddr as *mut libc::sockaddr_in), safe_cast(&mut sockaddr),
&mut socklen as *mut libc::socklen_t, &mut socklen as *mut libc::socklen_t,
) )
}; };
if err != 0 { if err != 0 {
log::trace!("failed to get port of IPv4 socket (errno = {})", errno()); log::debug!("failed to get port of IPv4 socket (errno = {})", errno());
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"failed to create socket", "failed to create socket",
@@ -656,26 +685,30 @@ impl PlatformUDP for LinuxUDP {
return Err(bind6.unwrap_err()); return Err(bind6.unwrap_err());
} }
let sock6 = bind6.ok().map(|(_, fd)| fd); let sock6 = bind6.ok().map(|(_, fd)| Arc::new(FD(fd)));
let sock4 = bind4.ok().map(|(_, fd)| fd); let sock4 = bind4.ok().map(|(_, fd)| Arc::new(FD(fd)));
// create owner // create owner
let owner = LinuxOwner { let owner = LinuxOwner {
port, port,
sock6: sock6, sock6: sock6.clone(),
sock4: sock4, sock4: sock4.clone(),
}; };
// create readers // create readers
let mut readers: Vec<Self::Reader> = Vec::with_capacity(2); let mut readers: Vec<Self::Reader> = Vec::with_capacity(2);
sock6.map(|sock| readers.push(LinuxUDPReader::V6(sock))); sock6
sock4.map(|sock| readers.push(LinuxUDPReader::V4(sock))); .clone()
.map(|sock| readers.push(LinuxUDPReader::V6(sock)));
sock4
.clone()
.map(|sock| readers.push(LinuxUDPReader::V4(sock)));
debug_assert!(readers.len() > 0); debug_assert!(readers.len() > 0);
// create writer // create writer
let writer = LinuxUDPWriter { let writer = LinuxUDPWriter {
sock4: sock4.unwrap_or(-1), sock4: sock4.unwrap_or(Arc::new(FD(-1))),
sock6: sock6.unwrap_or(-1), sock6: sock6.unwrap_or(Arc::new(FD(-1))),
}; };
Ok((readers, writer, owner)) Ok((readers, writer, owner))

View File

@@ -7,7 +7,7 @@ pub trait Reader<E: Endpoint>: Send + Sync {
fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>; fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>;
} }
pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static { pub trait Writer<E: Endpoint>: Send + Sync + 'static {
type Error: Error; type Error: Error;
fn write(&self, buf: &[u8], dst: &mut E) -> Result<(), Self::Error>; fn write(&self, buf: &[u8], dst: &mut E) -> Result<(), Self::Error>;

View File

@@ -193,6 +193,7 @@ impl<O> Device<O> {
opaque, opaque,
), ),
); );
Ok(()) Ok(())
} }
@@ -474,3 +475,39 @@ impl<O> Device<O> {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
use std::collections::HashSet;
proptest! {
#[test]
fn unique_shared_secrets(sk_bs: [u8; 32], pk1_bs: [u8; 32], pk2_bs: [u8; 32]) {
let sk = StaticSecret::from(sk_bs);
let pk1 = PublicKey::from(pk1_bs);
let pk2 = PublicKey::from(pk2_bs);
assert_eq!(pk1.as_bytes(), &pk1_bs);
assert_eq!(pk2.as_bytes(), &pk2_bs);
let mut dev : Device<u32> = Device::new();
dev.set_sk(Some(sk));
dev.add(pk1, 1).unwrap();
if dev.add(pk2, 0).is_err() {
assert_eq!(pk1_bs, pk2_bs);
assert_eq!(*dev.get(&pk1).unwrap(), 1);
}
// every shared secret is unique
let mut ss: HashSet<[u8; 32]> = HashSet::new();
for peer in dev.pk_map.values() {
ss.insert(peer.ss);
}
assert_eq!(ss.len(), dev.len());
}
}
}

View File

@@ -5,7 +5,7 @@ use std::net::SocketAddr;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use rand::prelude::*; use rand::prelude::{CryptoRng, RngCore};
use x25519_dalek::PublicKey; use x25519_dalek::PublicKey;
use x25519_dalek::StaticSecret; use x25519_dalek::StaticSecret;

View File

@@ -311,4 +311,17 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
pub fn set_outbound_writer(&self, new: B) { pub fn set_outbound_writer(&self, new: B) {
self.state.outbound.write().1 = Some(new); self.state.outbound.write().1 = Some(new);
} }
pub fn write(&self, msg: &[u8], endpoint: &mut E) -> Result<(), RouterError> {
let outbound = self.state.outbound.read();
if outbound.0 {
outbound
.1
.as_ref()
.ok_or(RouterError::SendError)
.and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError))
} else {
Ok(())
}
}
} }

View File

@@ -162,7 +162,7 @@ mod tests {
}; };
let msg = make_packet_padded(1024, src, dst, 0); let msg = make_packet_padded(1024, src, dst, 0);
// every iteration sends 10 GB // every iteration sends 10 MB
b.iter(|| { b.iter(|| {
opaque.store(0, Ordering::SeqCst); opaque.store(0, Ordering::SeqCst);
while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 { while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {

View File

@@ -44,9 +44,6 @@ pub struct WireguardInner<T: Tun, B: UDP> {
// current MTU // current MTU
pub mtu: AtomicUsize, pub mtu: AtomicUsize,
// outbound writer
pub send: RwLock<Option<B::Writer>>,
// peer map // peer map
pub peers: RwLock<handshake::Device<Peer<T, B>>>, pub peers: RwLock<handshake::Device<Peer<T, B>>>,
@@ -134,7 +131,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
// set mtu // set mtu
self.mtu.store(0, Ordering::Relaxed); self.mtu.store(0, Ordering::Relaxed);
// avoid tranmission from router // avoid transmission from router
self.router.down(); self.router.down();
// set all peers down (stops timers) // set all peers down (stops timers)
@@ -264,8 +261,6 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
} }
pub fn set_writer(&self, writer: B::Writer) { pub fn set_writer(&self, writer: B::Writer) {
// TODO: Consider unifying these and avoid Clone requirement on writer
*self.send.write() = Some(writer.clone());
self.router.set_outbound_writer(writer); self.router.set_outbound_writer(writer);
} }
@@ -301,8 +296,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
id: OsRng.gen(), id: OsRng.gen(),
mtu: AtomicUsize::new(0), mtu: AtomicUsize::new(0),
last_under_load: Mutex::new(Instant::now() - TIME_HORIZON), last_under_load: Mutex::new(Instant::now() - TIME_HORIZON),
send: RwLock::new(None), router: router::Device::new(num_cpus::get(), writer),
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
pending: AtomicUsize::new(0), pending: AtomicUsize::new(0),
peers: RwLock::new(handshake::Device::new()), peers: RwLock::new(handshake::Device::new()),
runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)), runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)),

View File

@@ -14,7 +14,6 @@ use super::tun::Reader as TunReader;
use super::tun::Tun; use super::tun::Tun;
use super::udp::Reader as UDPReader; use super::udp::Reader as UDPReader;
use super::udp::Writer as UDPWriter;
use super::udp::UDP; use super::udp::UDP;
// constants // constants
@@ -195,20 +194,12 @@ pub fn handshake_worker<T: Tun, B: UDP>(
let mut resp_len: u64 = 0; let mut resp_len: u64 = 0;
if let Some(msg) = resp { if let Some(msg) = resp {
resp_len = msg.len() as u64; resp_len = msg.len() as u64;
let send: &Option<B::Writer> = &*wg.send.read(); let _ = wg.router.write(&msg[..], &mut src).map_err(|e| {
if let Some(writer) = send.as_ref() {
debug!( debug!(
"{} : handshake worker, send response ({} bytes)", "{} : handshake worker, failed to send response, error = {}",
wg, resp_len wg, e
); );
let _ = writer.write(&msg[..], &mut src).map_err(|e| { });
debug!(
"{} : handshake worker, failed to send response, error = {}",
wg,
e
)
});
}
} }
// update peer state // update peer state