Update configuration API

This commit is contained in:
Mathias Hall-Andersen
2019-11-17 19:52:40 +01:00
parent 05710c455f
commit 64707b0471
15 changed files with 124 additions and 107 deletions

View File

@@ -10,6 +10,9 @@ use bind::Owner;
/// The goal of the configuration interface is, among others, /// The goal of the configuration interface is, among others,
/// to hide the IO implementations (over which the WG device is generic), /// to hide the IO implementations (over which the WG device is generic),
/// from the configuration and UAPI code. /// from the configuration and UAPI code.
///
/// Furthermore it forms the simpler interface for embedding WireGuard in other applications,
/// and hides the complex types of the implementation from the host application.
/// Describes a snapshot of the state of a peer /// Describes a snapshot of the state of a peer
pub struct PeerState { pub struct PeerState {
@@ -24,6 +27,7 @@ pub struct PeerState {
pub struct WireguardConfig<T: tun::Tun, B: bind::PlatformBind> { pub struct WireguardConfig<T: tun::Tun, B: bind::PlatformBind> {
wireguard: Wireguard<T, B>, wireguard: Wireguard<T, B>,
fwmark: Mutex<Option<u32>>,
network: Mutex<Option<B::Owner>>, network: Mutex<Option<B::Owner>>,
} }
@@ -31,6 +35,7 @@ impl<T: tun::Tun, B: bind::PlatformBind> WireguardConfig<T, B> {
pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> { pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
WireguardConfig { WireguardConfig {
wireguard: wg, wireguard: wg,
fwmark: Mutex::new(None),
network: Mutex::new(None), network: Mutex::new(None),
} }
} }
@@ -59,7 +64,7 @@ pub trait Configuration {
/// An integer indicating the protocol version /// An integer indicating the protocol version
fn get_protocol_version(&self) -> usize; fn get_protocol_version(&self) -> usize;
fn set_listen_port(&self, port: Option<u16>) -> Option<ConfigError>; fn set_listen_port(&self, port: Option<u16>) -> Result<(), ConfigError>;
/// Set the firewall mark (or similar, depending on platform) /// Set the firewall mark (or similar, depending on platform)
/// ///
@@ -71,7 +76,7 @@ pub trait Configuration {
/// ///
/// An error if this operation is not supported by the underlying /// An error if this operation is not supported by the underlying
/// "bind" implementation. /// "bind" implementation.
fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError>; fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError>;
/// Removes all peers from the device /// Removes all peers from the device
fn replace_peers(&self); fn replace_peers(&self);
@@ -110,7 +115,7 @@ pub trait Configuration {
/// # Returns /// # Returns
/// ///
/// An error if no such peer exists /// An error if no such peer exists
fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError>; fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]);
/// Update the endpoint of the /// Update the endpoint of the
/// ///
@@ -118,7 +123,7 @@ pub trait Configuration {
/// ///
/// - `peer': The public key of the peer /// - `peer': The public key of the peer
/// - `psk` /// - `psk`
fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) -> Option<ConfigError>; fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr);
/// Update the endpoint of the /// Update the endpoint of the
/// ///
@@ -126,8 +131,7 @@ pub trait Configuration {
/// ///
/// - `peer': The public key of the peer /// - `peer': The public key of the peer
/// - `psk` /// - `psk`
fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64);
-> Option<ConfigError>;
/// Remove all allowed IPs from the peer /// Remove all allowed IPs from the peer
/// ///
@@ -138,7 +142,7 @@ pub trait Configuration {
/// # Returns /// # Returns
/// ///
/// An error if no such peer exists /// An error if no such peer exists
fn replace_allowed_ips(&self, peer: &PublicKey) -> Option<ConfigError>; fn replace_allowed_ips(&self, peer: &PublicKey);
/// Add a new allowed subnet to the peer /// Add a new allowed subnet to the peer
/// ///
@@ -151,12 +155,7 @@ pub trait Configuration {
/// # Returns /// # Returns
/// ///
/// An error if the peer does not exist /// An error if the peer does not exist
/// fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32);
/// # Note:
///
/// The API must itself sanitize the (ip, masklen) set:
/// The ip should be masked to remove any set bits right of the first "masklen" bits.
fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError>;
fn get_listen_port(&self) -> Option<u16>; fn get_listen_port(&self) -> Option<u16>;
@@ -191,10 +190,14 @@ impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B>
} }
fn get_listen_port(&self) -> Option<u16> { fn get_listen_port(&self) -> Option<u16> {
self.network.lock().as_ref().map(|bind| bind.get_port()) let bind = self.network.lock();
log::trace!("Config, Get listen port, bound: {}", bind.is_some());
bind.as_ref().map(|bind| bind.get_port())
} }
fn set_listen_port(&self, port: Option<u16>) -> Option<ConfigError> { fn set_listen_port(&self, port: Option<u16>) -> Result<(), ConfigError> {
log::trace!("Config, Set listen port: {:?}", port);
let mut bind = self.network.lock(); let mut bind = self.network.lock();
// close the current listener // close the current listener
@@ -203,13 +206,16 @@ impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B>
// bind to new port // bind to new port
if let Some(port) = port { if let Some(port) = port {
// create new listener // create new listener
let (mut readers, writer, owner) = match B::bind(port) { let (mut readers, writer, mut owner) = match B::bind(port) {
Ok(r) => r, Ok(r) => r,
Err(_) => { Err(_) => {
return Some(ConfigError::FailedToBind); return Err(ConfigError::FailedToBind);
} }
}; };
// set fwmark
let _ = owner.set_fwmark(*self.fwmark.lock()); // TODO: handle
// add readers/writer to wireguard // add readers/writer to wireguard
self.wireguard.set_writer(writer); self.wireguard.set_writer(writer);
while let Some(reader) = readers.pop() { while let Some(reader) = readers.pop() {
@@ -220,16 +226,18 @@ impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B>
*bind = Some(owner); *bind = Some(owner);
} }
None Ok(())
} }
fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError> { fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> {
log::trace!("Config, Set fwmark: {:?}", mark);
match self.network.lock().as_mut() { match self.network.lock().as_mut() {
Some(bind) => { Some(bind) => {
bind.set_fwmark(mark).unwrap(); // TODO: handle bind.set_fwmark(mark).unwrap(); // TODO: handle
None Ok(())
} }
None => Some(ConfigError::NotListening), None => Err(ConfigError::NotListening),
} }
} }
@@ -242,59 +250,34 @@ impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B>
} }
fn add_peer(&self, peer: &PublicKey) -> bool { fn add_peer(&self, peer: &PublicKey) -> bool {
self.wireguard.add_peer(*peer); self.wireguard.add_peer(*peer)
false
} }
fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError> { fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) {
if self.wireguard.set_psk(*peer, psk) { self.wireguard.set_psk(*peer, psk);
None }
} else {
Some(ConfigError::NoSuchPeer) fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) {
if let Some(peer) = self.wireguard.lookup_peer(peer) {
peer.router.set_endpoint(B::Endpoint::from_address(addr));
} }
} }
fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) -> Option<ConfigError> { fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) {
match self.wireguard.lookup_peer(peer) { if let Some(peer) = self.wireguard.lookup_peer(peer) {
Some(peer) => { peer.set_persistent_keepalive_interval(secs);
peer.router.set_endpoint(B::Endpoint::from_address(addr));
None
}
None => Some(ConfigError::NoSuchPeer),
} }
} }
fn set_persistent_keepalive_interval( fn replace_allowed_ips(&self, peer: &PublicKey) {
&self, if let Some(peer) = self.wireguard.lookup_peer(peer) {
peer: &PublicKey, peer.router.remove_allowed_ips();
secs: u64,
) -> Option<ConfigError> {
match self.wireguard.lookup_peer(peer) {
Some(peer) => {
peer.set_persistent_keepalive_interval(secs);
None
}
None => Some(ConfigError::NoSuchPeer),
} }
} }
fn replace_allowed_ips(&self, peer: &PublicKey) -> Option<ConfigError> { fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) {
match self.wireguard.lookup_peer(peer) { if let Some(peer) = self.wireguard.lookup_peer(peer) {
Some(peer) => { peer.router.add_allowed_ip(ip, masklen);
peer.router.remove_allowed_ips();
None
}
None => Some(ConfigError::NoSuchPeer),
}
}
fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError> {
match self.wireguard.lookup_peer(peer) {
Some(peer) => {
peer.router.add_allowed_ip(ip, masklen);
None
}
None => Some(ConfigError::NoSuchPeer),
} }
} }

View File

@@ -1,10 +1,7 @@
use hex::FromHex;
use subtle::ConstantTimeEq;
use log; use log;
use std::io;
use super::Configuration; use super::Configuration;
use std::io;
pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) -> io::Result<()> { pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) -> io::Result<()> {
let mut write = |key: &'static str, value: String| { let mut write = |key: &'static str, value: String| {

View File

@@ -55,10 +55,13 @@ pub fn handle<S: Read + Write, C: Configuration>(stream: &mut S, config: &C) {
loop { loop {
let ln = readline(stream)?; let ln = readline(stream)?;
if ln == "" { if ln == "" {
// end of transcript
parser.parse_line("", "")?; // flush final peer
break Ok(()); break Ok(());
} else {
let (k, v) = keypair(ln.as_str())?;
parser.parse_line(k, v)?;
}; };
let (k, v) = keypair(ln.as_str())?;
parser.parse_line(k, v)?;
} }
} }
_ => Err(ConfigError::InvalidOperation), _ => Err(ConfigError::InvalidOperation),

View File

@@ -109,7 +109,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
// opt: set listen port // opt: set listen port
"listen_port" => match value.parse() { "listen_port" => match value.parse() {
Ok(port) => { Ok(port) => {
self.config.set_listen_port(Some(port)); self.config.set_listen_port(Some(port))?;
Ok(()) Ok(())
} }
Err(_) => Err(ConfigError::InvalidPortNumber), Err(_) => Err(ConfigError::InvalidPortNumber),
@@ -119,7 +119,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
"fwmark" => match value.parse() { "fwmark" => match value.parse() {
Ok(fwmark) => { Ok(fwmark) => {
self.config self.config
.set_fwmark(if fwmark == 0 { None } else { Some(fwmark) }); .set_fwmark(if fwmark == 0 { None } else { Some(fwmark) })?;
Ok(()) Ok(())
} }
Err(_) => Err(ConfigError::InvalidFwmark), Err(_) => Err(ConfigError::InvalidFwmark),
@@ -142,6 +142,9 @@ impl<'a, C: Configuration> LineParser<'a, C> {
Ok(()) Ok(())
} }
// ignore (end of transcript)
"" => Ok(()),
// unknown key // unknown key
_ => Err(ConfigError::InvalidKey), _ => Err(ConfigError::InvalidKey),
}, },
@@ -227,6 +230,12 @@ impl<'a, C: Configuration> LineParser<'a, C> {
} }
} }
// flush (used at end of transcipt)
"" => {
flush_peer(self.config, &peer);
Ok(())
}
// unknown key // unknown key
_ => Err(ConfigError::InvalidKey), _ => Err(ConfigError::InvalidKey),
}, },

View File

@@ -10,24 +10,37 @@ mod configuration;
mod platform; mod platform;
mod wireguard; mod wireguard;
use log;
use std::env;
use platform::tun::PlatformTun; use platform::tun::PlatformTun;
use platform::uapi::PlatformUAPI; use platform::uapi::{BindUAPI, PlatformUAPI};
use platform::*; use platform::*;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
fn main() { fn main() {
let name = "wg0"; let mut name = String::new();
let mut foreground = false;
for arg in env::args() {
if arg == "--foreground" || arg == "-f" {
foreground = true;
} else {
name = arg;
}
}
if name == "" {
return;
}
let _ = env_logger::builder().is_test(true).try_init(); let _ = env_logger::builder().is_test(true).try_init();
// create UAPI socket // create UAPI socket
let uapi = plt::UAPI::bind(name).unwrap(); let uapi = plt::UAPI::bind(name.as_str()).unwrap();
// create TUN device // create TUN device
let (readers, writer, mtu) = plt::Tun::create(name).unwrap(); let (readers, writer, mtu) = plt::Tun::create(name.as_str()).unwrap();
// create WireGuard device // create WireGuard device
let wg: wireguard::Wireguard<plt::Tun, plt::Bind> = let wg: wireguard::Wireguard<plt::Tun, plt::Bind> =
@@ -36,9 +49,12 @@ fn main() {
// wrap in configuration interface and start UAPI server // wrap in configuration interface and start UAPI server
let cfg = configuration::WireguardConfig::new(wg); let cfg = configuration::WireguardConfig::new(wg);
loop { loop {
let mut stream = uapi.accept().unwrap(); match uapi.connect() {
configuration::uapi::handle(&mut stream.0, &cfg); Ok(mut stream) => configuration::uapi::handle(&mut stream, &cfg),
Err(err) => {
log::info!("UAPI error: {:}", err);
break;
}
}
} }
thread::sleep(Duration::from_secs(600));
} }

View File

@@ -32,7 +32,7 @@ pub trait Owner: Send {
fn get_fwmark(&self) -> Option<u32>; fn get_fwmark(&self) -> Option<u32>;
fn set_fwmark(&mut self, value: Option<u32>) -> Option<Self::Error>; fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error>;
} }
/// On some platforms the application can itself bind to a socket. /// On some platforms the application can itself bind to a socket.

View File

@@ -203,8 +203,8 @@ impl Bind for PairBind {
impl Owner for VoidOwner { impl Owner for VoidOwner {
type Error = BindError; type Error = BindError;
fn set_fwmark(&mut self, _value: Option<u32>) -> Option<Self::Error> { fn set_fwmark(&mut self, _value: Option<u32>) -> Result<(), Self::Error> {
None Ok(())
} }
fn get_port(&self) -> u16 { fn get_port(&self) -> u16 {

View File

@@ -24,7 +24,7 @@ impl BindUAPI for UnixListener {
type Stream = UnixStream; type Stream = UnixStream;
type Error = io::Error; type Error = io::Error;
fn accept(&self) -> Result<UnixStream, io::Error> { fn connect(&self) -> Result<UnixStream, io::Error> {
let (stream, _) = self.accept()?; let (stream, _) = self.accept()?;
Ok(stream) Ok(stream)
} }

View File

@@ -43,15 +43,15 @@ impl Owner for LinuxOwner {
type Error = io::Error; type Error = io::Error;
fn get_port(&self) -> u16 { fn get_port(&self) -> u16 {
1337 self.0.local_addr().unwrap().port() // todo handle
} }
fn get_fwmark(&self) -> Option<u32> { fn get_fwmark(&self) -> Option<u32> {
None None
} }
fn set_fwmark(&mut self, value: Option<u32>) -> Option<Self::Error> { fn set_fwmark(&mut self, _value: Option<u32>) -> Result<(), Self::Error> {
None Ok(())
} }
} }

View File

@@ -5,7 +5,7 @@ pub trait BindUAPI {
type Stream: Read + Write; type Stream: Read + Write;
type Error: Error; type Error: Error;
fn accept(&self) -> Result<Self::Stream, Self::Error>; fn connect(&self) -> Result<Self::Stream, Self::Error>;
} }
pub trait PlatformUAPI { pub trait PlatformUAPI {

View File

@@ -469,6 +469,10 @@ mod tests {
(pk1, dev1, pk2, dev2) (pk1, dev1, pk2, dev2)
} }
fn wait() {
thread::sleep(Duration::from_millis(20));
}
/* Test longest possible handshake interaction (7 messages): /* Test longest possible handshake interaction (7 messages):
* *
* 1. I -> R (initation) * 1. I -> R (initation)
@@ -502,8 +506,8 @@ mod tests {
_ => panic!("unexpected response"), _ => panic!("unexpected response"),
} }
// avoid initation flood // avoid initation flood detection
thread::sleep(Duration::from_millis(20)); wait();
// 3. device-1 : create second initation // 3. device-1 : create second initation
let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
@@ -529,8 +533,8 @@ mod tests {
_ => panic!("unexpected response"), _ => panic!("unexpected response"),
} }
// avoid initation flood // avoid initation flood detection
thread::sleep(Duration::from_millis(20)); wait();
// 6. device-1 : create third initation // 6. device-1 : create third initation
let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
@@ -600,8 +604,8 @@ mod tests {
dev1.release(ks_i.send.id); dev1.release(ks_i.send.id);
dev2.release(ks_r.send.id); dev2.release(ks_r.send.id);
// to avoid flood detection // avoid initation flood detection
thread::sleep(Duration::from_millis(20)); wait();
} }
dev1.remove(pk2).unwrap(); dev1.remove(pk2).unwrap();

View File

@@ -7,7 +7,6 @@ use generic_array::typenum::U32;
use generic_array::GenericArray; use generic_array::GenericArray;
use x25519_dalek::PublicKey; use x25519_dalek::PublicKey;
use x25519_dalek::SharedSecret;
use x25519_dalek::StaticSecret; use x25519_dalek::StaticSecret;
use clear_on_drop::clear::Clear; use clear_on_drop::clear::Clear;

View File

@@ -1,4 +1,3 @@
use super::constants::*;
use super::router; use super::router;
use super::timers::{Events, Timers}; use super::timers::{Events, Timers};
use super::HandshakeJob; use super::HandshakeJob;
@@ -9,7 +8,7 @@ use super::wireguard::WireguardInner;
use std::fmt; use std::fmt;
use std::ops::Deref; use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Instant, SystemTime}; use std::time::{Instant, SystemTime};

View File

@@ -63,7 +63,7 @@ impl<T: tun::Tun, B: bind::Bind> PeerInner<T, B> {
// take a write lock preventing simultaneous "stop_timers" call // take a write lock preventing simultaneous "stop_timers" call
let mut timers = self.timers_mut(); let mut timers = self.timers_mut();
// set flag to renable timer events // set flag to reenable timer events
if timers.enabled { if timers.enabled {
return; return;
} }

View File

@@ -18,6 +18,7 @@ use std::sync::Arc;
use std::thread; use std::thread;
use std::time::{Duration, Instant, SystemTime}; use std::time::{Duration, Instant, SystemTime};
use std::collections::hash_map::Entry;
use std::collections::HashMap; use std::collections::HashMap;
use log::debug; use log::debug;
@@ -208,9 +209,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
self.state.handshake.read().get_psk(pk).ok() self.state.handshake.read().get_psk(pk).ok()
} }
pub fn add_peer(&self, pk: PublicKey) { pub fn add_peer(&self, pk: PublicKey) -> bool {
if self.state.peers.read().contains_key(pk.as_bytes()) { if self.state.peers.read().contains_key(pk.as_bytes()) {
return; return false;
} }
let mut rng = OsRng::new().unwrap(); let mut rng = OsRng::new().unwrap();
@@ -243,10 +244,16 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// finally, add the peer to the wireguard device // finally, add the peer to the wireguard device
let mut peers = self.state.peers.write(); let mut peers = self.state.peers.write();
peers.entry(*pk.as_bytes()).or_insert(peer); match peers.entry(*pk.as_bytes()) {
Entry::Occupied(_) => false,
// add to the handshake device Entry::Vacant(vacancy) => {
self.state.handshake.write().add(pk).unwrap(); // TODO: handle adding of public key for interface let ok_pk = self.state.handshake.write().add(pk).is_ok();
if ok_pk {
vacancy.insert(peer);
}
ok_pk
}
}
} }
/// Begin consuming messages from the reader. /// Begin consuming messages from the reader.