Work on netlink IF event code for Linux
This commit is contained in:
@@ -261,7 +261,7 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> {
|
|||||||
|
|
||||||
// add readers
|
// add readers
|
||||||
while let Some(reader) = readers.pop() {
|
while let Some(reader) = readers.pop() {
|
||||||
cfg.wireguard.add_reader(reader);
|
cfg.wireguard.add_udp_reader(reader);
|
||||||
}
|
}
|
||||||
|
|
||||||
// create new UDP state
|
// create new UDP state
|
||||||
|
|||||||
21
src/main.rs
21
src/main.rs
@@ -26,7 +26,7 @@ fn main() {
|
|||||||
let mut foreground = false;
|
let mut foreground = false;
|
||||||
let mut args = env::args();
|
let mut args = env::args();
|
||||||
|
|
||||||
args.next(); // skip path
|
args.next(); // skip path (argv[0])
|
||||||
|
|
||||||
for arg in args {
|
for arg in args {
|
||||||
match arg.as_str() {
|
match arg.as_str() {
|
||||||
@@ -56,7 +56,7 @@ fn main() {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// create TUN device
|
// create TUN device
|
||||||
let (readers, writer, status) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| {
|
let (mut readers, writer, status) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| {
|
||||||
eprintln!("Failed to create TUN device: {}", e);
|
eprintln!("Failed to create TUN device: {}", e);
|
||||||
exit(-3);
|
exit(-3);
|
||||||
});
|
});
|
||||||
@@ -82,7 +82,15 @@ fn main() {
|
|||||||
if drop_privileges {}
|
if drop_privileges {}
|
||||||
|
|
||||||
// create WireGuard device
|
// create WireGuard device
|
||||||
let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(readers, writer);
|
let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(writer);
|
||||||
|
|
||||||
|
// add all Tun readers
|
||||||
|
while let Some(reader) = readers.pop() {
|
||||||
|
wg.add_tun_reader(reader);
|
||||||
|
}
|
||||||
|
|
||||||
|
// obtain handle for waiting
|
||||||
|
let wait = wg.wait();
|
||||||
|
|
||||||
// wrap in configuration interface
|
// wrap in configuration interface
|
||||||
let cfg = configuration::WireguardConfig::new(wg);
|
let cfg = configuration::WireguardConfig::new(wg);
|
||||||
@@ -124,7 +132,7 @@ fn main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// start UAPI server
|
// start UAPI server
|
||||||
loop {
|
thread::spawn(move || loop {
|
||||||
match uapi.connect() {
|
match uapi.connect() {
|
||||||
Ok(mut stream) => {
|
Ok(mut stream) => {
|
||||||
let cfg = cfg.clone();
|
let cfg = cfg.clone();
|
||||||
@@ -137,5 +145,8 @@ fn main() {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
|
|
||||||
|
// block until all tun readers closed
|
||||||
|
wait.wait();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
use super::super::tun::*;
|
use super::super::tun::*;
|
||||||
|
|
||||||
use libc::*;
|
use libc;
|
||||||
|
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
use std::mem;
|
||||||
use std::os::raw::c_short;
|
use std::os::raw::c_short;
|
||||||
use std::os::unix::io::RawFd;
|
use std::os::unix::io::RawFd;
|
||||||
use std::thread;
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
const IFNAMSIZ: usize = 16;
|
const IFNAMSIZ: usize = 16;
|
||||||
const TUNSETIFF: u64 = 0x4004_54ca;
|
const TUNSETIFF: u64 = 0x4004_54ca;
|
||||||
@@ -30,6 +29,18 @@ struct Ifreq {
|
|||||||
_pad: [u8; 64],
|
_pad: [u8; 64],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// man 7 rtnetlink
|
||||||
|
// Layout from: https://elixir.bootlin.com/linux/latest/source/include/uapi/linux/rtnetlink.h#L516
|
||||||
|
#[repr(C)]
|
||||||
|
struct IfInfomsg {
|
||||||
|
ifi_family: libc::c_uchar,
|
||||||
|
__ifi_pad: libc::c_uchar,
|
||||||
|
ifi_type: libc::c_ushort,
|
||||||
|
ifi_index: libc::c_int,
|
||||||
|
ifi_flags: libc::c_uint,
|
||||||
|
ifi_change: libc::c_uint,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct LinuxTun {
|
pub struct LinuxTun {
|
||||||
events: Vec<TunEvent>,
|
events: Vec<TunEvent>,
|
||||||
}
|
}
|
||||||
@@ -42,12 +53,9 @@ pub struct LinuxTunWriter {
|
|||||||
fd: RawFd,
|
fd: RawFd,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Listens for netlink messages
|
|
||||||
* announcing an MTU update for the interface
|
|
||||||
*/
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct LinuxTunStatus {
|
pub struct LinuxTunStatus {
|
||||||
first: bool,
|
events: Vec<TunEvent>,
|
||||||
|
fd: RawFd,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -94,7 +102,7 @@ impl Reader for LinuxTunReader {
|
|||||||
);
|
);
|
||||||
*/
|
*/
|
||||||
let n: isize =
|
let n: isize =
|
||||||
unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) };
|
unsafe { libc::read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) };
|
||||||
if n < 0 {
|
if n < 0 {
|
||||||
Err(LinuxTunError::Closed)
|
Err(LinuxTunError::Closed)
|
||||||
} else {
|
} else {
|
||||||
@@ -108,7 +116,7 @@ impl Writer for LinuxTunWriter {
|
|||||||
type Error = LinuxTunError;
|
type Error = LinuxTunError;
|
||||||
|
|
||||||
fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
|
fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
|
||||||
match unsafe { write(self.fd, src.as_ptr() as _, src.len() as _) } {
|
match unsafe { libc::write(self.fd, src.as_ptr() as _, src.len() as _) } {
|
||||||
-1 => Err(LinuxTunError::Closed),
|
-1 => Err(LinuxTunError::Closed),
|
||||||
_ => Ok(()),
|
_ => Ok(()),
|
||||||
}
|
}
|
||||||
@@ -119,13 +127,124 @@ impl Status for LinuxTunStatus {
|
|||||||
type Error = LinuxTunError;
|
type Error = LinuxTunError;
|
||||||
|
|
||||||
fn event(&mut self) -> Result<TunEvent, Self::Error> {
|
fn event(&mut self) -> Result<TunEvent, Self::Error> {
|
||||||
if self.first {
|
const DONE: u16 = libc::NLMSG_DONE as u16;
|
||||||
self.first = false;
|
const ERROR: u16 = libc::NLMSG_ERROR as u16;
|
||||||
return Ok(TunEvent::Up(1420));
|
const INFO_SIZE: usize = mem::size_of::<IfInfomsg>();
|
||||||
|
const HDR_SIZE: usize = mem::size_of::<libc::nlmsghdr>();
|
||||||
|
|
||||||
|
let mut buf = [0u8; 1 << 12];
|
||||||
|
log::debug!("netlink, fetch event (fd = {})", self.fd);
|
||||||
|
loop {
|
||||||
|
// attempt to return a buffered event
|
||||||
|
if let Some(event) = self.events.pop() {
|
||||||
|
return Ok(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
// read message
|
||||||
|
let size: libc::ssize_t =
|
||||||
|
unsafe { libc::recv(self.fd, mem::transmute(&mut buf), buf.len(), 0) };
|
||||||
|
if size < 0 {
|
||||||
|
break Err(LinuxTunError::Closed);
|
||||||
|
}
|
||||||
|
|
||||||
|
// cut buffer to size
|
||||||
|
let size: usize = size as usize;
|
||||||
|
let mut remain = &buf[..size];
|
||||||
|
log::debug!("netlink, recieved message ({} bytes)", size);
|
||||||
|
|
||||||
|
// handle messages
|
||||||
|
while remain.len() >= HDR_SIZE {
|
||||||
|
// extract the header
|
||||||
|
assert!(remain.len() > HDR_SIZE);
|
||||||
|
let mut hdr = [0u8; HDR_SIZE];
|
||||||
|
hdr.copy_from_slice(&remain[..HDR_SIZE]);
|
||||||
|
let hdr: libc::nlmsghdr = unsafe { mem::transmute(hdr) };
|
||||||
|
|
||||||
|
// upcast length
|
||||||
|
let body: &[u8] = &remain[HDR_SIZE..];
|
||||||
|
let msg_len: usize = hdr.nlmsg_len as usize;
|
||||||
|
assert!(msg_len <= remain.len(), "malformed netlink message");
|
||||||
|
|
||||||
|
// handle message body
|
||||||
|
match hdr.nlmsg_type {
|
||||||
|
DONE => break,
|
||||||
|
ERROR => break,
|
||||||
|
libc::RTM_NEWLINK => {
|
||||||
|
// extract info struct
|
||||||
|
if body.len() < INFO_SIZE {
|
||||||
|
return Err(LinuxTunError::Closed);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut info = [0u8; INFO_SIZE];
|
||||||
|
info.copy_from_slice(&body[..INFO_SIZE]);
|
||||||
|
log::debug!("netlink, RTM_NEWLINK {:?}", &info[..]);
|
||||||
|
let info: IfInfomsg = unsafe { mem::transmute(info) };
|
||||||
|
|
||||||
|
// trace log
|
||||||
|
log::trace!(
|
||||||
|
"netlink, IfInfomsg{{ family = {}, type = {}, index = {}, flags = {}, change = {}}}",
|
||||||
|
info.ifi_family,
|
||||||
|
info.ifi_type,
|
||||||
|
info.ifi_index,
|
||||||
|
info.ifi_flags,
|
||||||
|
info.ifi_change,
|
||||||
|
);
|
||||||
|
debug_assert_eq!(info.__ifi_pad, 0);
|
||||||
|
|
||||||
|
// handle up / down
|
||||||
|
if info.ifi_flags & (libc::IFF_UP as u32) != 0 {
|
||||||
|
log::trace!("netlink, up event");
|
||||||
|
self.events.push(TunEvent::Up(1420));
|
||||||
|
} else {
|
||||||
|
log::trace!("netlink, down event");
|
||||||
|
self.events.push(TunEvent::Down);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
};
|
||||||
|
|
||||||
|
// go to next message
|
||||||
|
remain = &remain[msg_len..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LinuxTunStatus {
|
||||||
|
const RTNLGRP_LINK: libc::c_uint = 1;
|
||||||
|
const RTNLGRP_IPV4_IFADDR: libc::c_uint = 5;
|
||||||
|
const RTNLGRP_IPV6_IFADDR: libc::c_uint = 9;
|
||||||
|
|
||||||
|
fn new() -> Result<LinuxTunStatus, LinuxTunError> {
|
||||||
|
// create netlink socket
|
||||||
|
let fd = unsafe { libc::socket(libc::AF_NETLINK, libc::SOCK_RAW, libc::NETLINK_ROUTE) };
|
||||||
|
if fd < 0 {
|
||||||
|
return Err(LinuxTunError::Closed);
|
||||||
}
|
}
|
||||||
|
|
||||||
loop {
|
// prepare address (specify groups)
|
||||||
thread::sleep(Duration::from_secs(60 * 60));
|
let groups = (1 << (Self::RTNLGRP_LINK - 1))
|
||||||
|
| (1 << (Self::RTNLGRP_IPV4_IFADDR - 1))
|
||||||
|
| (1 << (Self::RTNLGRP_IPV6_IFADDR - 1));
|
||||||
|
|
||||||
|
let mut sockaddr: libc::sockaddr_nl = unsafe { mem::zeroed() };
|
||||||
|
sockaddr.nl_family = libc::AF_NETLINK as u16;
|
||||||
|
sockaddr.nl_groups = groups;
|
||||||
|
sockaddr.nl_pid = 0;
|
||||||
|
|
||||||
|
// attempt to bind
|
||||||
|
let res = unsafe {
|
||||||
|
libc::bind(
|
||||||
|
fd,
|
||||||
|
mem::transmute(&mut sockaddr),
|
||||||
|
mem::size_of::<libc::sockaddr_nl>() as u32,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
if res != 0 {
|
||||||
|
Err(LinuxTunError::Closed)
|
||||||
|
} else {
|
||||||
|
Ok(LinuxTunStatus { events: vec![], fd })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -155,14 +274,14 @@ impl PlatformTun for LinuxTun {
|
|||||||
req.name[..bs.len()].copy_from_slice(bs);
|
req.name[..bs.len()].copy_from_slice(bs);
|
||||||
|
|
||||||
// open clone device
|
// open clone device
|
||||||
let fd: RawFd = match unsafe { open(CLONE_DEVICE_PATH.as_ptr() as _, O_RDWR) } {
|
let fd: RawFd = match unsafe { libc::open(CLONE_DEVICE_PATH.as_ptr() as _, libc::O_RDWR) } {
|
||||||
-1 => return Err(LinuxTunError::FailedToOpenCloneDevice),
|
-1 => return Err(LinuxTunError::FailedToOpenCloneDevice),
|
||||||
fd => fd,
|
fd => fd,
|
||||||
};
|
};
|
||||||
assert!(fd >= 0);
|
assert!(fd >= 0);
|
||||||
|
|
||||||
// create TUN device
|
// create TUN device
|
||||||
if unsafe { ioctl(fd, TUNSETIFF as _, &req) } < 0 {
|
if unsafe { libc::ioctl(fd, TUNSETIFF as _, &req) } < 0 {
|
||||||
return Err(LinuxTunError::SetIFFIoctlFailed);
|
return Err(LinuxTunError::SetIFFIoctlFailed);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -170,7 +289,7 @@ impl PlatformTun for LinuxTun {
|
|||||||
Ok((
|
Ok((
|
||||||
vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux
|
vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux
|
||||||
LinuxTunWriter { fd },
|
LinuxTunWriter { fd },
|
||||||
LinuxTunStatus { first: true },
|
LinuxTunStatus::new()?,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -85,15 +85,13 @@ fn test_pure_wireguard() {
|
|||||||
// create WG instances for dummy TUN devices
|
// create WG instances for dummy TUN devices
|
||||||
|
|
||||||
let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true);
|
let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true);
|
||||||
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
|
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(tun_writer1);
|
||||||
Wireguard::new(vec![tun_reader1], tun_writer1);
|
wg1.add_tun_reader(tun_reader1);
|
||||||
|
|
||||||
wg1.up(1500);
|
wg1.up(1500);
|
||||||
|
|
||||||
let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true);
|
let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true);
|
||||||
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
|
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(tun_writer2);
|
||||||
Wireguard::new(vec![tun_reader2], tun_writer2);
|
wg2.add_tun_reader(tun_reader2);
|
||||||
|
|
||||||
wg2.up(1500);
|
wg2.up(1500);
|
||||||
|
|
||||||
// create pair bind to connect the interfaces "over the internet"
|
// create pair bind to connect the interfaces "over the internet"
|
||||||
@@ -103,8 +101,8 @@ fn test_pure_wireguard() {
|
|||||||
wg1.set_writer(bind_writer1);
|
wg1.set_writer(bind_writer1);
|
||||||
wg2.set_writer(bind_writer2);
|
wg2.set_writer(bind_writer2);
|
||||||
|
|
||||||
wg1.add_reader(bind_reader1);
|
wg1.add_udp_reader(bind_reader1);
|
||||||
wg2.add_reader(bind_reader2);
|
wg2.add_udp_reader(bind_reader2);
|
||||||
|
|
||||||
// generate (public, pivate) key pairs
|
// generate (public, pivate) key pairs
|
||||||
|
|
||||||
|
|||||||
@@ -221,14 +221,14 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
|
|||||||
|
|
||||||
|
|
||||||
impl Timers {
|
impl Timers {
|
||||||
pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers
|
pub fn new<T, B>(runner: &Runner, running: bool, peer: Peer<T, B>) -> Timers
|
||||||
where
|
where
|
||||||
T: tun::Tun,
|
T: tun::Tun,
|
||||||
B: udp::UDP,
|
B: udp::UDP,
|
||||||
{
|
{
|
||||||
// create a timer instance for the provided peer
|
// create a timer instance for the provided peer
|
||||||
Timers {
|
Timers {
|
||||||
enabled: true,
|
enabled: running,
|
||||||
keepalive_interval: 0, // disabled
|
keepalive_interval: 0, // disabled
|
||||||
need_another_keepalive: AtomicBool::new(false),
|
need_another_keepalive: AtomicBool::new(false),
|
||||||
sent_lastminute_handshake: AtomicBool::new(false),
|
sent_lastminute_handshake: AtomicBool::new(false),
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ use std::sync::Arc;
|
|||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
// TODO: avoid
|
||||||
|
use std::sync::Condvar;
|
||||||
|
use std::sync::Mutex as StdMutex;
|
||||||
|
|
||||||
use std::collections::hash_map::Entry;
|
use std::collections::hash_map::Entry;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
@@ -38,15 +42,51 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
|
|||||||
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
|
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
|
||||||
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
|
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WaitHandle(Arc<(StdMutex<usize>, Condvar)>);
|
||||||
|
|
||||||
|
impl WaitHandle {
|
||||||
|
pub fn wait(&self) {
|
||||||
|
let (lock, cvar) = &*self.0;
|
||||||
|
let mut nread = lock.lock().unwrap();
|
||||||
|
while *nread > 0 {
|
||||||
|
nread = cvar.wait(nread).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new() -> Self {
|
||||||
|
Self(Arc::new((StdMutex::new(0), Condvar::new())))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decrease(&self) {
|
||||||
|
let (lock, cvar) = &*self.0;
|
||||||
|
let mut nread = lock.lock().unwrap();
|
||||||
|
assert!(*nread > 0);
|
||||||
|
*nread -= 1;
|
||||||
|
cvar.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn increase(&self) {
|
||||||
|
let (lock, _) = &*self.0;
|
||||||
|
let mut nread = lock.lock().unwrap();
|
||||||
|
*nread += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct WireguardInner<T: tun::Tun, B: udp::UDP> {
|
pub struct WireguardInner<T: tun::Tun, B: udp::UDP> {
|
||||||
// identifier (for logging)
|
// identifier (for logging)
|
||||||
id: u32,
|
id: u32,
|
||||||
start: Instant,
|
|
||||||
|
// device enabled
|
||||||
|
enabled: RwLock<bool>,
|
||||||
|
|
||||||
|
// enables waiting for all readers to finish
|
||||||
|
tun_readers: WaitHandle,
|
||||||
|
|
||||||
// current MTU
|
// current MTU
|
||||||
mtu: AtomicUsize,
|
mtu: AtomicUsize,
|
||||||
|
|
||||||
// provides access to the MTU value of the tun device
|
// outbound writer
|
||||||
send: RwLock<Option<B::Writer>>,
|
send: RwLock<Option<B::Writer>>,
|
||||||
|
|
||||||
// identity and configuration map
|
// identity and configuration map
|
||||||
@@ -145,7 +185,12 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
|
|||||||
/// on both ends of the device.
|
/// on both ends of the device.
|
||||||
pub fn down(&self) {
|
pub fn down(&self) {
|
||||||
// ensure exclusive access (to avoid race with "up" call)
|
// ensure exclusive access (to avoid race with "up" call)
|
||||||
let peers = self.peers.write();
|
let mut enabled = self.enabled.write();
|
||||||
|
|
||||||
|
// check if already down
|
||||||
|
if *enabled == false {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// set mtu
|
// set mtu
|
||||||
self.state.mtu.store(0, Ordering::Relaxed);
|
self.state.mtu.store(0, Ordering::Relaxed);
|
||||||
@@ -154,27 +199,36 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
|
|||||||
self.router.down();
|
self.router.down();
|
||||||
|
|
||||||
// set all peers down (stops timers)
|
// set all peers down (stops timers)
|
||||||
for peer in peers.values() {
|
for peer in self.peers.write().values() {
|
||||||
peer.down();
|
peer.down();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*enabled = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Brings the WireGuard device up.
|
/// Brings the WireGuard device up.
|
||||||
/// Usually called when the associated interface is brought up.
|
/// Usually called when the associated interface is brought up.
|
||||||
pub fn up(&self, mtu: usize) {
|
pub fn up(&self, mtu: usize) {
|
||||||
// ensure exclusive access (to avoid race with "down" call)
|
// ensure exclusive access (to avoid race with "up" call)
|
||||||
let peers = self.peers.write();
|
let mut enabled = self.enabled.write();
|
||||||
|
|
||||||
// set mtu
|
// set mtu
|
||||||
self.state.mtu.store(mtu, Ordering::Relaxed);
|
self.state.mtu.store(mtu, Ordering::Relaxed);
|
||||||
|
|
||||||
|
// check if already up
|
||||||
|
if *enabled {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// enable tranmission from router
|
// enable tranmission from router
|
||||||
self.router.up();
|
self.router.up();
|
||||||
|
|
||||||
// set all peers up (restarts timers)
|
// set all peers up (restarts timers)
|
||||||
for peer in peers.values() {
|
for peer in self.peers.write().values() {
|
||||||
peer.up();
|
peer.up();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*enabled = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn clear_peers(&self) {
|
pub fn clear_peers(&self) {
|
||||||
@@ -232,7 +286,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
|
|||||||
pk,
|
pk,
|
||||||
wg: self.state.clone(),
|
wg: self.state.clone(),
|
||||||
walltime_last_handshake: Mutex::new(None),
|
walltime_last_handshake: Mutex::new(None),
|
||||||
last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON),
|
last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON),
|
||||||
handshake_queued: AtomicBool::new(false),
|
handshake_queued: AtomicBool::new(false),
|
||||||
queue: Mutex::new(self.state.queue.lock().clone()),
|
queue: Mutex::new(self.state.queue.lock().clone()),
|
||||||
rx_bytes: AtomicU64::new(0),
|
rx_bytes: AtomicU64::new(0),
|
||||||
@@ -246,24 +300,31 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
|
|||||||
// form WireGuard peer
|
// form WireGuard peer
|
||||||
let peer = Peer { router, state };
|
let peer = Peer { router, state };
|
||||||
|
|
||||||
/* The need for dummy timers arises from the chicken-egg
|
|
||||||
* problem of the timer callbacks being able to set timers themselves.
|
|
||||||
*
|
|
||||||
* This is in fact the only place where the write lock is ever taken.
|
|
||||||
* TODO: Consider the ease of using atomic pointers instead.
|
|
||||||
*/
|
|
||||||
*peer.timers.write() = Timers::new(&self.runner, peer.clone());
|
|
||||||
|
|
||||||
// finally, add the peer to the wireguard device
|
// finally, add the peer to the wireguard device
|
||||||
let mut peers = self.state.peers.write();
|
let mut peers = self.state.peers.write();
|
||||||
match peers.entry(*pk.as_bytes()) {
|
match peers.entry(*pk.as_bytes()) {
|
||||||
Entry::Occupied(_) => false,
|
Entry::Occupied(_) => false,
|
||||||
Entry::Vacant(vacancy) => {
|
Entry::Vacant(vacancy) => {
|
||||||
|
// check that the public key does not cause conflict with the private key of the device
|
||||||
let ok_pk = self.state.handshake.write().add(pk).is_ok();
|
let ok_pk = self.state.handshake.write().add(pk).is_ok();
|
||||||
if ok_pk {
|
if !ok_pk {
|
||||||
vacancy.insert(peer);
|
return false;
|
||||||
}
|
}
|
||||||
ok_pk
|
|
||||||
|
// prevent up/down while inserting
|
||||||
|
let enabled = self.enabled.read();
|
||||||
|
|
||||||
|
/* The need for dummy timers arises from the chicken-egg
|
||||||
|
* problem of the timer callbacks being able to set timers themselves.
|
||||||
|
*
|
||||||
|
* This is in fact the only place where the write lock is ever taken.
|
||||||
|
* TODO: Consider the ease of using atomic pointers instead.
|
||||||
|
*/
|
||||||
|
*peer.timers.write() = Timers::new(&self.runner, *enabled, peer.clone());
|
||||||
|
|
||||||
|
// insert into peer map (takes ownership and ensures that the peer is not dropped)
|
||||||
|
vacancy.insert(peer);
|
||||||
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -273,7 +334,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
|
|||||||
///
|
///
|
||||||
/// Any previous reader thread is stopped by closing the previous reader,
|
/// Any previous reader thread is stopped by closing the previous reader,
|
||||||
/// which unblocks the thread and causes an error on reader.read
|
/// which unblocks the thread and causes an error on reader.read
|
||||||
pub fn add_reader(&self, reader: B::Reader) {
|
pub fn add_udp_reader(&self, reader: B::Reader) {
|
||||||
let wg = self.state.clone();
|
let wg = self.state.clone();
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
let mut last_under_load =
|
let mut last_under_load =
|
||||||
@@ -350,7 +411,72 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
|
|||||||
self.state.router.set_outbound_writer(writer);
|
self.state.router.set_outbound_writer(writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer) -> Wireguard<T, B> {
|
pub fn add_tun_reader(&self, reader: T::Reader) {
|
||||||
|
fn worker<T: tun::Tun, B: udp::UDP>(wg: &Arc<WireguardInner<T, B>>, reader: T::Reader) {
|
||||||
|
loop {
|
||||||
|
// create vector big enough for any transport message (based on MTU)
|
||||||
|
let mtu = wg.mtu.load(Ordering::Relaxed);
|
||||||
|
let size = mtu + router::SIZE_MESSAGE_PREFIX + 1;
|
||||||
|
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
|
||||||
|
msg.resize(size, 0);
|
||||||
|
|
||||||
|
// read a new IP packet
|
||||||
|
let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) {
|
||||||
|
Ok(payload) => payload,
|
||||||
|
Err(e) => {
|
||||||
|
debug!("TUN worker, failed to read from tun device: {}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
|
||||||
|
|
||||||
|
// TODO: start device down
|
||||||
|
if mtu == 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncate padding
|
||||||
|
let padded = padding(payload, mtu);
|
||||||
|
log::trace!(
|
||||||
|
"TUN worker, payload length = {}, padded length = {}",
|
||||||
|
payload,
|
||||||
|
padded
|
||||||
|
);
|
||||||
|
msg.truncate(router::SIZE_MESSAGE_PREFIX + padded);
|
||||||
|
debug_assert!(padded <= mtu);
|
||||||
|
debug_assert_eq!(
|
||||||
|
if padded < mtu {
|
||||||
|
(msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
},
|
||||||
|
0
|
||||||
|
);
|
||||||
|
|
||||||
|
// crypt-key route
|
||||||
|
let e = wg.router.send(msg);
|
||||||
|
debug!("TUN worker, router returned {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// start a thread for every reader
|
||||||
|
let wg = self.state.clone();
|
||||||
|
|
||||||
|
// increment reader count
|
||||||
|
wg.tun_readers.increase();
|
||||||
|
|
||||||
|
// start worker
|
||||||
|
thread::spawn(move || {
|
||||||
|
worker(&wg, reader);
|
||||||
|
wg.tun_readers.decrease();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn wait(&self) -> WaitHandle {
|
||||||
|
self.state.tun_readers.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(writer: T::Writer) -> Wireguard<T, B> {
|
||||||
// create device state
|
// create device state
|
||||||
let mut rng = OsRng::new().unwrap();
|
let mut rng = OsRng::new().unwrap();
|
||||||
|
|
||||||
@@ -358,7 +484,8 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
|
|||||||
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
|
||||||
|
|
||||||
let wg = Arc::new(WireguardInner {
|
let wg = Arc::new(WireguardInner {
|
||||||
start: Instant::now(),
|
enabled: RwLock::new(false),
|
||||||
|
tun_readers: WaitHandle::new(),
|
||||||
id: rng.gen(),
|
id: rng.gen(),
|
||||||
mtu: AtomicUsize::new(0),
|
mtu: AtomicUsize::new(0),
|
||||||
peers: RwLock::new(HashMap::new()),
|
peers: RwLock::new(HashMap::new()),
|
||||||
@@ -486,59 +613,6 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// start TUN read IO threads (multiple threads to support multi-queue interfaces)
|
|
||||||
debug_assert!(
|
|
||||||
readers.len() > 0,
|
|
||||||
"attempted to create WG device without TUN readers"
|
|
||||||
);
|
|
||||||
while let Some(reader) = readers.pop() {
|
|
||||||
let wg = wg.clone();
|
|
||||||
thread::spawn(move || loop {
|
|
||||||
// create vector big enough for any transport message (based on MTU)
|
|
||||||
let mtu = wg.mtu.load(Ordering::Relaxed);
|
|
||||||
let size = mtu + router::SIZE_MESSAGE_PREFIX;
|
|
||||||
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
|
|
||||||
msg.resize(size, 0);
|
|
||||||
|
|
||||||
// read a new IP packet
|
|
||||||
let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) {
|
|
||||||
Ok(payload) => payload,
|
|
||||||
Err(e) => {
|
|
||||||
debug!("TUN worker, failed to read from tun device: {}", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
|
|
||||||
|
|
||||||
// TODO: start device down
|
|
||||||
if mtu == 0 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// truncate padding
|
|
||||||
let padded = padding(payload, mtu);
|
|
||||||
log::trace!(
|
|
||||||
"TUN worker, payload length = {}, padded length = {}",
|
|
||||||
payload,
|
|
||||||
padded
|
|
||||||
);
|
|
||||||
msg.truncate(router::SIZE_MESSAGE_PREFIX + padded);
|
|
||||||
debug_assert!(padded <= mtu);
|
|
||||||
debug_assert_eq!(
|
|
||||||
if padded < mtu {
|
|
||||||
(msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
},
|
|
||||||
0
|
|
||||||
);
|
|
||||||
|
|
||||||
// crypt-key route
|
|
||||||
let e = wg.router.send(msg);
|
|
||||||
debug!("TUN worker, router returned {:?}", e);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
Wireguard {
|
Wireguard {
|
||||||
state: wg,
|
state: wg,
|
||||||
runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),
|
runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),
|
||||||
|
|||||||
Reference in New Issue
Block a user