Enable up/down from configuration interface

This commit is contained in:
Mathias Hall-Andersen
2019-11-25 13:33:00 +01:00
parent 3bff078e3f
commit f228b6f98b
9 changed files with 185 additions and 90 deletions

View File

@@ -1,7 +1,8 @@
use spin::Mutex;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use x25519_dalek::{PublicKey, StaticSecret}; use x25519_dalek::{PublicKey, StaticSecret};
use super::udp::Owner; use super::udp::Owner;
@@ -23,27 +24,57 @@ pub struct PeerState {
pub allowed_ips: Vec<(IpAddr, u32)>, pub allowed_ips: Vec<(IpAddr, u32)>,
pub endpoint: Option<SocketAddr>, pub endpoint: Option<SocketAddr>,
pub persistent_keepalive_interval: u64, pub persistent_keepalive_interval: u64,
pub preshared_key: [u8; 32], // 0^32 is the "default value" pub preshared_key: [u8; 32], // 0^32 is the "default value" (though treated like any other psk)
} }
pub struct WireguardConfig<T: tun::Tun, B: udp::PlatformUDP> { pub struct WireguardConfig<T: tun::Tun, B: udp::PlatformUDP>(Arc<Mutex<Inner<T, B>>>);
struct State<B: udp::PlatformUDP> {
port: u16,
bind: Option<B::Owner>,
fwmark: Option<u32>,
}
struct Inner<T: tun::Tun, B: udp::PlatformUDP> {
wireguard: Wireguard<T, B>, wireguard: Wireguard<T, B>,
fwmark: Mutex<Option<u32>>, port: u16,
network: Mutex<Option<B::Owner>>, bind: Option<B::Owner>,
fwmark: Option<u32>,
}
impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> {
fn lock(&self) -> MutexGuard<Inner<T, B>> {
self.0.lock().unwrap()
}
} }
impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> { impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> {
pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> { pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
WireguardConfig { WireguardConfig(Arc::new(Mutex::new(Inner {
wireguard: wg, wireguard: wg,
fwmark: Mutex::new(None), port: 0,
network: Mutex::new(None), bind: None,
} fwmark: None,
})))
}
}
impl<T: tun::Tun, B: udp::PlatformUDP> Clone for WireguardConfig<T, B> {
fn clone(&self) -> Self {
WireguardConfig(self.0.clone())
} }
} }
/// Exposed configuration interface /// Exposed configuration interface
pub trait Configuration { pub trait Configuration {
fn up(&self, mtu: usize);
fn down(&self);
fn start_listener(&self) -> Result<(), ConfigError>;
fn stop_listener(&self) -> Result<(), ConfigError>;
/// Updates the private key of the device /// Updates the private key of the device
/// ///
/// # Arguments /// # Arguments
@@ -65,7 +96,7 @@ pub trait Configuration {
/// An integer indicating the protocol version /// An integer indicating the protocol version
fn get_protocol_version(&self) -> usize; fn get_protocol_version(&self) -> usize;
fn set_listen_port(&self, port: Option<u16>) -> Result<(), ConfigError>; fn set_listen_port(&self, port: u16) -> Result<(), ConfigError>;
/// Set the firewall mark (or similar, depending on platform) /// Set the firewall mark (or similar, depending on platform)
/// ///
@@ -171,19 +202,24 @@ pub trait Configuration {
} }
impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> { impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> {
fn up(&self, mtu: usize) {
self.lock().wireguard.up(mtu);
}
fn down(&self) {
self.lock().wireguard.down();
}
fn get_fwmark(&self) -> Option<u32> { fn get_fwmark(&self) -> Option<u32> {
self.network self.lock().bind.as_ref().and_then(|own| own.get_fwmark())
.lock()
.as_ref()
.and_then(|bind| bind.get_fwmark())
} }
fn set_private_key(&self, sk: Option<StaticSecret>) { fn set_private_key(&self, sk: Option<StaticSecret>) {
self.wireguard.set_key(sk) self.lock().wireguard.set_key(sk)
} }
fn get_private_key(&self) -> Option<StaticSecret> { fn get_private_key(&self) -> Option<StaticSecret> {
self.wireguard.get_sk() self.lock().wireguard.get_sk()
} }
fn get_protocol_version(&self) -> usize { fn get_protocol_version(&self) -> usize {
@@ -191,49 +227,75 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> {
} }
fn get_listen_port(&self) -> Option<u16> { fn get_listen_port(&self) -> Option<u16> {
let bind = self.network.lock(); let st = self.lock();
log::trace!("Config, Get listen port, bound: {}", bind.is_some()); log::trace!("Config, Get listen port, bound: {}", st.bind.is_some());
bind.as_ref().map(|bind| bind.get_port()) st.bind.as_ref().map(|bind| bind.get_port())
} }
fn set_listen_port(&self, port: Option<u16>) -> Result<(), ConfigError> { fn stop_listener(&self) -> Result<(), ConfigError> {
log::trace!("Config, Set listen port: {:?}", port); self.lock().bind = None;
Ok(())
}
let mut bind = self.network.lock(); fn start_listener(&self) -> Result<(), ConfigError> {
let mut cfg = self.lock();
// close the current listener // check if already listening
*bind = None; if cfg.bind.is_some() {
return Ok(());
// bind to new port
if let Some(port) = port {
// create new listener
let (mut readers, writer, mut owner) = match B::bind(port) {
Ok(r) => r,
Err(_) => {
return Err(ConfigError::FailedToBind);
}
};
// set fwmark
let _ = owner.set_fwmark(*self.fwmark.lock()); // TODO: handle
// add readers/writer to wireguard
self.wireguard.set_writer(writer);
while let Some(reader) = readers.pop() {
self.wireguard.add_reader(reader);
}
// create new UDP state
*bind = Some(owner);
} }
// create new listener
let (mut readers, writer, mut owner) = match B::bind(cfg.port) {
Ok(r) => r,
Err(_) => {
return Err(ConfigError::FailedToBind);
}
};
// set fwmark
let _ = owner.set_fwmark(cfg.fwmark); // TODO: handle
// set writer on wireguard
cfg.wireguard.set_writer(writer);
// add readers
while let Some(reader) = readers.pop() {
cfg.wireguard.add_reader(reader);
}
// create new UDP state
cfg.bind = Some(owner);
Ok(()) Ok(())
} }
fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> {
log::trace!("Config, Set listen port: {:?}", port);
// update port
let listen: bool = {
let mut cfg = self.lock();
cfg.port = port;
if cfg.bind.is_some() {
cfg.bind = None;
true
} else {
false
}
};
// restart listener if bound
if listen {
self.start_listener()
} else {
Ok(())
}
}
fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> { fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> {
log::trace!("Config, Set fwmark: {:?}", mark); log::trace!("Config, Set fwmark: {:?}", mark);
match self.network.lock().as_mut() { match self.lock().bind.as_mut() {
Some(bind) => { Some(bind) => {
bind.set_fwmark(mark).unwrap(); // TODO: handle bind.set_fwmark(mark).unwrap(); // TODO: handle
Ok(()) Ok(())
@@ -243,47 +305,48 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> {
} }
fn replace_peers(&self) { fn replace_peers(&self) {
self.wireguard.clear_peers(); self.lock().wireguard.clear_peers();
} }
fn remove_peer(&self, peer: &PublicKey) { fn remove_peer(&self, peer: &PublicKey) {
self.wireguard.remove_peer(peer); self.lock().wireguard.remove_peer(peer);
} }
fn add_peer(&self, peer: &PublicKey) -> bool { fn add_peer(&self, peer: &PublicKey) -> bool {
self.wireguard.add_peer(*peer) self.lock().wireguard.add_peer(*peer)
} }
fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) { fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) {
self.wireguard.set_psk(*peer, psk); self.lock().wireguard.set_psk(*peer, psk);
} }
fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) { fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) {
if let Some(peer) = self.wireguard.lookup_peer(peer) { if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
peer.router.set_endpoint(B::Endpoint::from_address(addr)); peer.router.set_endpoint(B::Endpoint::from_address(addr));
} }
} }
fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) { fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) {
if let Some(peer) = self.wireguard.lookup_peer(peer) { if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
peer.set_persistent_keepalive_interval(secs); peer.set_persistent_keepalive_interval(secs);
} }
} }
fn replace_allowed_ips(&self, peer: &PublicKey) { fn replace_allowed_ips(&self, peer: &PublicKey) {
if let Some(peer) = self.wireguard.lookup_peer(peer) { if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
peer.router.remove_allowed_ips(); peer.router.remove_allowed_ips();
} }
} }
fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) { fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) {
if let Some(peer) = self.wireguard.lookup_peer(peer) { if let Some(peer) = self.lock().wireguard.lookup_peer(peer) {
peer.router.add_allowed_ip(ip, masklen); peer.router.add_allowed_ip(ip, masklen);
} }
} }
fn get_peers(&self) -> Vec<PeerState> { fn get_peers(&self) -> Vec<PeerState> {
let peers = self.wireguard.list_peers(); let cfg = self.lock();
let peers = cfg.wireguard.list_peers();
let mut state = Vec::with_capacity(peers.len()); let mut state = Vec::with_capacity(peers.len());
for p in peers { for p in peers {
@@ -295,7 +358,7 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> {
Some((duration.as_secs(), duration.subsec_nanos() as u64)) Some((duration.as_secs(), duration.subsec_nanos() as u64))
}); });
if let Some(psk) = self.wireguard.get_psk(&p.pk) { if let Some(psk) = cfg.wireguard.get_psk(&p.pk) {
// extract state into PeerState // extract state into PeerState
state.push(PeerState { state.push(PeerState {
preshared_key: psk, preshared_key: psk,

View File

@@ -116,7 +116,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
// opt: set listen port // opt: set listen port
"listen_port" => match value.parse() { "listen_port" => match value.parse() {
Ok(port) => { Ok(port) => {
self.config.set_listen_port(Some(port))?; self.config.set_listen_port(port)?;
Ok(()) Ok(())
} }
Err(_) => Err(ConfigError::InvalidPortNumber), Err(_) => Err(ConfigError::InvalidPortNumber),

View File

@@ -4,6 +4,7 @@
use log; use log;
use daemonize::Daemonize; use daemonize::Daemonize;
use std::env; use std::env;
use std::process::exit; use std::process::exit;
use std::thread; use std::thread;
@@ -12,7 +13,9 @@ mod configuration;
mod platform; mod platform;
mod wireguard; mod wireguard;
use platform::tun::PlatformTun; use configuration::Configuration;
use platform::tun::{PlatformTun, Status};
use platform::uapi::{BindUAPI, PlatformUAPI}; use platform::uapi::{BindUAPI, PlatformUAPI};
use platform::*; use platform::*;
@@ -81,34 +84,56 @@ fn main() {
// create WireGuard device // create WireGuard device
let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(readers, writer); let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(readers, writer);
wg.set_mtu(1420); // wrap in configuration interface
let cfg = configuration::WireguardConfig::new(wg);
// start Tun event thread // start Tun event thread
/*
{ {
let wg = wg.clone(); let cfg = cfg.clone();
let mut status = status; let mut status = status;
thread::spawn(move || loop { thread::spawn(move || loop {
match status.event() { match status.event() {
Err(_) => break, Err(e) => {
Ok(tun::TunEvent::Up(mtu)) => { log::info!("Tun device error {}", e);
wg.mtu.store(mtu, Ordering::Relaxed); exit(0);
}
Ok(tun::TunEvent::Up(mtu)) => {
log::info!("Tun up (mtu = {})", mtu);
// bring the wireguard device up
cfg.up(mtu);
// start listening on UDP
let _ = cfg
.start_listener()
.map_err(|e| log::info!("Failed to start UDP listener: {}", e));
}
Ok(tun::TunEvent::Down) => {
log::info!("Tun down");
// set wireguard device down
cfg.down();
// close UDP listener
let _ = cfg
.stop_listener()
.map_err(|e| log::info!("Failed to stop UDP listener {}", e));
} }
Ok(tun::TunEvent::Down) => {}
} }
}); });
} }
*/
// handle TUN updates up/down // start UAPI server
// wrap in configuration interface and start UAPI server
let cfg = configuration::WireguardConfig::new(wg);
loop { loop {
match uapi.connect() { match uapi.connect() {
Ok(mut stream) => configuration::uapi::handle(&mut stream, &cfg), Ok(mut stream) => {
let cfg = cfg.clone();
thread::spawn(move || {
configuration::uapi::handle(&mut stream, &cfg);
});
}
Err(err) => { Err(err) => {
log::info!("UAPI error: {:}", err); log::info!("UAPI error: {}", err);
break; break;
} }
} }

View File

@@ -150,7 +150,6 @@ impl Status for TunStatus {
impl Tun for TunTest { impl Tun for TunTest {
type Writer = TunWriter; type Writer = TunWriter;
type Reader = TunReader; type Reader = TunReader;
type Status = TunStatus;
type Error = TunError; type Error = TunError;
} }
@@ -167,7 +166,7 @@ impl TunFakeIO {
} }
impl TunTest { impl TunTest {
pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunStatus) { pub fn create(store: bool) -> (TunFakeIO, TunReader, TunWriter, TunStatus) {
let (tx1, rx1) = if store { let (tx1, rx1) = if store {
sync_channel(32) sync_channel(32)
} else { } else {
@@ -200,6 +199,8 @@ impl TunTest {
} }
impl PlatformTun for TunTest { impl PlatformTun for TunTest {
type Status = TunStatus;
fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error> { fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error> {
Err(TunError::Disconnected) Err(TunError::Disconnected)
} }

View File

@@ -87,10 +87,12 @@ impl Reader for LinuxTunReader {
type Error = LinuxTunError; type Error = LinuxTunError;
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> { fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
/*
debug_assert!( debug_assert!(
offset < buf.len(), offset < buf.len(),
"There is no space for the body of the read" "There is no space for the body of the read"
); );
*/
let n: isize = let n: isize =
unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) }; unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) };
if n < 0 { if n < 0 {
@@ -132,10 +134,11 @@ impl Tun for LinuxTun {
type Error = LinuxTunError; type Error = LinuxTunError;
type Reader = LinuxTunReader; type Reader = LinuxTunReader;
type Writer = LinuxTunWriter; type Writer = LinuxTunWriter;
type Status = LinuxTunStatus;
} }
impl PlatformTun for LinuxTun { impl PlatformTun for LinuxTun {
type Status = LinuxTunStatus;
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error> { fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error> {
// construct request struct // construct request struct
let mut req = Ifreq { let mut req = Ifreq {

View File

@@ -51,11 +51,12 @@ pub trait Reader: Send + 'static {
pub trait Tun: Send + Sync + 'static { pub trait Tun: Send + Sync + 'static {
type Writer: Writer; type Writer: Writer;
type Reader: Reader; type Reader: Reader;
type Status: Status;
type Error: Error; type Error: Error;
} }
/// On some platforms the application can create the TUN device itself. /// On some platforms the application can create the TUN device itself.
pub trait PlatformTun: Tun { pub trait PlatformTun: Tun {
type Status: Status;
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error>; fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error>;
} }

View File

@@ -139,7 +139,7 @@ mod tests {
} }
// create device // create device
let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(false);
let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> = let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> =
Device::new(num_cpus::get(), tun_writer); Device::new(num_cpus::get(), tun_writer);
@@ -169,7 +169,7 @@ mod tests {
init(); init();
// create device // create device
let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(false);
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer); let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
router.set_outbound_writer(dummy::VoidBind::new()); router.set_outbound_writer(dummy::VoidBind::new());
@@ -315,8 +315,8 @@ mod tests {
dummy::PairBind::pair(); dummy::PairBind::pair();
// create matching device // create matching device
let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false); let (_fake, _, tun_writer1, _) = dummy::TunTest::create(false);
let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false); let (_fake, _, tun_writer2, _) = dummy::TunTest::create(false);
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1); let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
router1.set_outbound_writer(bind_writer1); router1.set_outbound_writer(bind_writer1);

View File

@@ -84,17 +84,17 @@ fn test_pure_wireguard() {
// create WG instances for dummy TUN devices // create WG instances for dummy TUN devices
let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(1500, true); let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true);
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> = let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader1], tun_writer1); Wireguard::new(vec![tun_reader1], tun_writer1);
wg1.set_mtu(1500); wg1.up(1500);
let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(1500, true); let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true);
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> = let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader2], tun_writer2); Wireguard::new(vec![tun_reader2], tun_writer2);
wg2.set_mtu(1500); wg2.up(1500);
// create pair bind to connect the interfaces "over the internet" // create pair bind to connect the interfaces "over the internet"

View File

@@ -147,6 +147,9 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
// ensure exclusive access (to avoid race with "up" call) // ensure exclusive access (to avoid race with "up" call)
let peers = self.peers.write(); let peers = self.peers.write();
// set mtu
self.state.mtu.store(0, Ordering::Relaxed);
// avoid tranmission from router // avoid tranmission from router
self.router.down(); self.router.down();
@@ -158,10 +161,13 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
/// Brings the WireGuard device up. /// Brings the WireGuard device up.
/// Usually called when the associated interface is brought up. /// Usually called when the associated interface is brought up.
pub fn up(&self) { pub fn up(&self, mtu: usize) {
// ensure exclusive access (to avoid race with "down" call) // ensure exclusive access (to avoid race with "down" call)
let peers = self.peers.write(); let peers = self.peers.write();
// set mtu
self.state.mtu.store(mtu, Ordering::Relaxed);
// enable tranmission from router // enable tranmission from router
self.router.up(); self.router.up();
@@ -338,10 +344,6 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
}); });
} }
pub fn set_mtu(&self, mtu: usize) {
self.mtu.store(mtu, Ordering::Relaxed);
}
pub fn set_writer(&self, writer: B::Writer) { pub fn set_writer(&self, writer: B::Writer) {
// TODO: Consider unifying these and avoid Clone requirement on writer // TODO: Consider unifying these and avoid Clone requirement on writer
*self.state.send.write() = Some(writer.clone()); *self.state.send.write() = Some(writer.clone());