Update UAPI semantics for remove

This commit is contained in:
Mathias Hall-Andersen
2019-11-15 15:32:36 +01:00
parent a85725eede
commit 05710c455f
18 changed files with 289 additions and 141 deletions

View File

@@ -19,16 +19,16 @@ pub struct PeerState {
pub last_handshake_time_nsec: u64, pub last_handshake_time_nsec: u64,
pub public_key: PublicKey, pub public_key: PublicKey,
pub allowed_ips: Vec<(IpAddr, u32)>, pub allowed_ips: Vec<(IpAddr, u32)>,
pub preshared_key: Option<[u8; 32]>, pub preshared_key: [u8; 32], // 0^32 is the "default value"
} }
pub struct WireguardConfig<T: tun::Tun, B: bind::Platform> { pub struct WireguardConfig<T: tun::Tun, B: bind::PlatformBind> {
wireguard: Wireguard<T, B>, wireguard: Wireguard<T, B>,
network: Mutex<Option<B::Owner>>, network: Mutex<Option<B::Owner>>,
} }
impl<T: tun::Tun, B: bind::Platform> WireguardConfig<T, B> { impl<T: tun::Tun, B: bind::PlatformBind> WireguardConfig<T, B> {
fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> { pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
WireguardConfig { WireguardConfig {
wireguard: wg, wireguard: wg,
network: Mutex::new(None), network: Mutex::new(None),
@@ -110,7 +110,7 @@ pub trait Configuration {
/// # Returns /// # Returns
/// ///
/// An error if no such peer exists /// An error if no such peer exists
fn set_preshared_key(&self, peer: &PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError>; fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError>;
/// Update the endpoint of the /// Update the endpoint of the
/// ///
@@ -170,7 +170,7 @@ pub trait Configuration {
fn get_fwmark(&self) -> Option<u32>; fn get_fwmark(&self) -> Option<u32>;
} }
impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> { impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B> {
fn get_fwmark(&self) -> Option<u32> { fn get_fwmark(&self) -> Option<u32> {
self.network self.network
.lock() .lock()
@@ -246,7 +246,7 @@ impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> {
false false
} }
fn set_preshared_key(&self, peer: &PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError> { fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError> {
if self.wireguard.set_psk(*peer, psk) { if self.wireguard.set_psk(*peer, psk) {
None None
} else { } else {
@@ -308,16 +308,18 @@ impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> {
.duration_since(SystemTime::UNIX_EPOCH) .duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0)); // any time before epoch is mapped to epoch .unwrap_or(Duration::from_secs(0)); // any time before epoch is mapped to epoch
// extract state into PeerState if let Some(psk) = self.wireguard.get_psk(&p.pk) {
state.push(PeerState { // extract state into PeerState
preshared_key: self.wireguard.get_psk(&p.pk), state.push(PeerState {
rx_bytes: p.rx_bytes.load(Ordering::Relaxed), preshared_key: psk,
tx_bytes: p.tx_bytes.load(Ordering::Relaxed), rx_bytes: p.rx_bytes.load(Ordering::Relaxed),
allowed_ips: p.router.list_allowed_ips(), tx_bytes: p.tx_bytes.load(Ordering::Relaxed),
last_handshake_time_nsec: last_handshake.subsec_nanos() as u64, allowed_ips: p.router.list_allowed_ips(),
last_handshake_time_sec: last_handshake.as_secs(), last_handshake_time_nsec: last_handshake.subsec_nanos() as u64,
public_key: p.pk, last_handshake_time_sec: last_handshake.as_secs(),
}) public_key: p.pk,
})
}
} }
state state
} }

View File

@@ -1,6 +1,6 @@
mod config; mod config;
mod error; mod error;
mod uapi; pub mod uapi;
use super::platform::Endpoint; use super::platform::Endpoint;
use super::platform::{bind, tun}; use super::platform::{bind, tun};

View File

@@ -1,6 +1,8 @@
use hex::FromHex; use hex::FromHex;
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use log;
use super::Configuration; use super::Configuration;
use std::io; use std::io;
@@ -8,9 +10,11 @@ pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) ->
let mut write = |key: &'static str, value: String| { let mut write = |key: &'static str, value: String| {
debug_assert!(value.is_ascii()); debug_assert!(value.is_ascii());
debug_assert!(key.is_ascii()); debug_assert!(key.is_ascii());
log::trace!("UAPI: return : {} = {}", key, value);
writer.write(key.as_ref())?; writer.write(key.as_ref())?;
writer.write(b"=")?; writer.write(b"=")?;
writer.write(value.as_ref()) writer.write(value.as_ref())?;
writer.write(b"\n")
}; };
// serialize interface // serialize interface
@@ -40,9 +44,7 @@ pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) ->
p.last_handshake_time_nsec.to_string(), p.last_handshake_time_nsec.to_string(),
)?; )?;
write("public_key", hex::encode(p.public_key.as_bytes()))?; write("public_key", hex::encode(p.public_key.as_bytes()))?;
if let Some(psk) = p.preshared_key { write("preshared_key", hex::encode(p.preshared_key))?;
write("preshared_key", hex::encode(psk))?;
}
for (ip, cidr) in p.allowed_ips { for (ip, cidr) in p.allowed_ips {
write("allowed_ip", ip.to_string() + "/" + &cidr.to_string())?; write("allowed_ip", ip.to_string() + "/" + &cidr.to_string())?;
} }

View File

@@ -1,6 +1,7 @@
mod get; mod get;
mod set; mod set;
use log;
use std::io::{Read, Write}; use std::io::{Read, Write};
use super::{ConfigError, Configuration}; use super::{ConfigError, Configuration};
@@ -10,10 +11,9 @@ use set::LineParser;
const MAX_LINE_LENGTH: usize = 256; const MAX_LINE_LENGTH: usize = 256;
pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut W, config: &C) { pub fn handle<S: Read + Write, C: Configuration>(stream: &mut S, config: &C) {
fn operation<R: Read, W: Write, C: Configuration>( fn operation<S: Read + Write, C: Configuration>(
reader: &mut R, stream: &mut S,
writer: &mut W,
config: &C, config: &C,
) -> Result<(), ConfigError> { ) -> Result<(), ConfigError> {
// read string up to maximum length (why is this not in std?) // read string up to maximum length (why is this not in std?)
@@ -23,6 +23,7 @@ pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut
while let Ok(_) = reader.read_exact(&mut m) { while let Ok(_) = reader.read_exact(&mut m) {
let c = m[0] as char; let c = m[0] as char;
if c == '\n' { if c == '\n' {
log::trace!("UAPI, line: {}", l);
return Ok(l); return Ok(l);
}; };
l.push(c); l.push(c);
@@ -43,12 +44,16 @@ pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut
}; };
// read operation line // read operation line
match readline(reader)?.as_str() { match readline(stream)?.as_str() {
"get=1" => serialize(writer, config).map_err(|_| ConfigError::IOError), "get=1" => {
log::debug!("UAPI, Get operation");
serialize(stream, config).map_err(|_| ConfigError::IOError)
}
"set=1" => { "set=1" => {
log::debug!("UAPI, Set operation");
let mut parser = LineParser::new(config); let mut parser = LineParser::new(config);
loop { loop {
let ln = readline(reader)?; let ln = readline(stream)?;
if ln == "" { if ln == "" {
break Ok(()); break Ok(());
}; };
@@ -61,17 +66,17 @@ pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut
} }
// process operation // process operation
let res = operation(reader, writer, config); let res = operation(stream, config);
log::debug!("{:?}", res); log::debug!("UAPI, Result of operation: {:?}", res);
// return errno // return errno
let _ = writer.write("errno=".as_ref()); let _ = stream.write("errno=".as_ref());
let _ = writer.write( let _ = stream.write(
match res { match res {
Err(e) => e.errno().to_string(), Err(e) => e.errno().to_string(),
Ok(()) => "0".to_owned(), Ok(()) => "0".to_owned(),
} }
.as_ref(), .as_ref(),
); );
let _ = writer.write("\n\n".as_ref()); let _ = stream.write("\n\n".as_ref());
} }

View File

@@ -1,18 +1,27 @@
use hex::FromHex; use hex::FromHex;
use std::net::{IpAddr, SocketAddr};
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use x25519_dalek::{PublicKey, StaticSecret}; use x25519_dalek::{PublicKey, StaticSecret};
use super::{ConfigError, Configuration}; use super::{ConfigError, Configuration};
#[derive(Copy, Clone)]
enum ParserState { enum ParserState {
Peer { Peer(ParsedPeer),
public_key: PublicKey,
update_only: bool,
},
Interface, Interface,
} }
struct ParsedPeer {
public_key: PublicKey,
update_only: bool,
allowed_ips: Vec<(IpAddr, u32)>,
remove: bool,
preshared_key: Option<[u8; 32]>,
replace_allowed_ips: bool,
persistent_keepalive_interval: Option<u64>,
protocol_version: Option<usize>,
endpoint: Option<SocketAddr>,
}
pub struct LineParser<'a, C: Configuration> { pub struct LineParser<'a, C: Configuration> {
config: &'a C, config: &'a C,
state: ParserState, state: ParserState,
@@ -28,45 +37,71 @@ impl<'a, C: Configuration> LineParser<'a, C> {
fn new_peer(value: &str) -> Result<ParserState, ConfigError> { fn new_peer(value: &str) -> Result<ParserState, ConfigError> {
match <[u8; 32]>::from_hex(value) { match <[u8; 32]>::from_hex(value) {
Ok(pk) => Ok(ParserState::Peer { Ok(pk) => Ok(ParserState::Peer(ParsedPeer {
public_key: PublicKey::from(pk), public_key: PublicKey::from(pk),
remove: false,
update_only: false, update_only: false,
}), allowed_ips: vec![],
preshared_key: None,
replace_allowed_ips: false,
persistent_keepalive_interval: None,
protocol_version: None,
endpoint: None,
})),
Err(_) => Err(ConfigError::InvalidHexValue), Err(_) => Err(ConfigError::InvalidHexValue),
} }
} }
pub fn parse_line(&mut self, key: &str, value: &str) -> Result<(), ConfigError> { pub fn parse_line(&mut self, key: &str, value: &str) -> Result<(), ConfigError> {
// add the peer if not update_only // flush peer updates to configuration
let flush_peer = |st: ParserState| -> ParserState { fn flush_peer<C: Configuration>(config: &C, peer: &ParsedPeer) -> Option<ConfigError> {
match st { if peer.remove {
ParserState::Peer { config.remove_peer(&peer.public_key);
public_key, return None;
update_only: false,
} => {
self.config.add_peer(&public_key);
ParserState::Peer {
public_key,
update_only: true,
}
}
_ => st,
} }
if !peer.update_only {
config.add_peer(&peer.public_key);
}
for (ip, masklen) in &peer.allowed_ips {
config.add_allowed_ip(&peer.public_key, *ip, *masklen);
}
if let Some(psk) = peer.preshared_key {
config.set_preshared_key(&peer.public_key, psk);
}
if let Some(secs) = peer.persistent_keepalive_interval {
config.set_persistent_keepalive_interval(&peer.public_key, secs);
}
if let Some(version) = peer.protocol_version {
if version == 0 || version > config.get_protocol_version() {
return Some(ConfigError::UnsupportedProtocolVersion);
}
}
if let Some(endpoint) = peer.endpoint {
config.set_endpoint(&peer.public_key, endpoint);
};
None
}; };
// parse line and update parser state // parse line and update parser state
self.state = match self.state { match self.state {
// configure the interface // configure the interface
ParserState::Interface => match key { ParserState::Interface => match key {
// opt: set private key // opt: set private key
"private_key" => match <[u8; 32]>::from_hex(value) { "private_key" => match <[u8; 32]>::from_hex(value) {
Ok(sk) => { Ok(sk) => {
self.config.set_private_key(if sk == [0u8; 32] { self.config.set_private_key(if sk.ct_eq(&[0u8; 32]).into() {
None None
} else { } else {
Some(StaticSecret::from(sk)) Some(StaticSecret::from(sk))
}); });
Ok(self.state) Ok(())
} }
Err(_) => Err(ConfigError::InvalidHexValue), Err(_) => Err(ConfigError::InvalidHexValue),
}, },
@@ -75,7 +110,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
"listen_port" => match value.parse() { "listen_port" => match value.parse() {
Ok(port) => { Ok(port) => {
self.config.set_listen_port(Some(port)); self.config.set_listen_port(Some(port));
Ok(self.state) Ok(())
} }
Err(_) => Err(ConfigError::InvalidPortNumber), Err(_) => Err(ConfigError::InvalidPortNumber),
}, },
@@ -85,7 +120,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
Ok(fwmark) => { Ok(fwmark) => {
self.config self.config
.set_fwmark(if fwmark == 0 { None } else { Some(fwmark) }); .set_fwmark(if fwmark == 0 { None } else { Some(fwmark) });
Ok(self.state) Ok(())
} }
Err(_) => Err(ConfigError::InvalidFwmark), Err(_) => Err(ConfigError::InvalidFwmark),
}, },
@@ -96,51 +131,47 @@ impl<'a, C: Configuration> LineParser<'a, C> {
for p in self.config.get_peers() { for p in self.config.get_peers() {
self.config.remove_peer(&p.public_key) self.config.remove_peer(&p.public_key)
} }
Ok(self.state) Ok(())
} }
_ => Err(ConfigError::UnsupportedValue), _ => Err(ConfigError::UnsupportedValue),
}, },
// opt: transition to peer configuration // opt: transition to peer configuration
"public_key" => Self::new_peer(value), "public_key" => {
self.state = Self::new_peer(value)?;
Ok(())
}
// unknown key // unknown key
_ => Err(ConfigError::InvalidKey), _ => Err(ConfigError::InvalidKey),
}, },
// configure peers // configure peers
ParserState::Peer { public_key, .. } => match key { ParserState::Peer(ref mut peer) => match key {
// opt: new peer // opt: new peer
"public_key" => { "public_key" => {
flush_peer(self.state); flush_peer(self.config, &peer);
Self::new_peer(value) self.state = Self::new_peer(value)?;
Ok(())
} }
// opt: remove peer // opt: remove peer
"remove" => { "remove" => {
self.config.remove_peer(&public_key); peer.remove = true;
Ok(self.state) Ok(())
} }
// opt: update only // opt: update only
"update_only" => Ok(ParserState::Peer { "update_only" => {
public_key, peer.update_only = true;
update_only: true, Ok(())
}), }
// opt: set preshared key // opt: set preshared key
"preshared_key" => match <[u8; 32]>::from_hex(value) { "preshared_key" => match <[u8; 32]>::from_hex(value) {
Ok(psk) => { Ok(psk) => {
let st = flush_peer(self.state); peer.preshared_key = Some(psk);
self.config.set_preshared_key( Ok(())
&public_key,
if psk.ct_eq(&[0u8; 32]).into() {
None
} else {
Some(psk)
},
);
Ok(st)
} }
Err(_) => Err(ConfigError::InvalidHexValue), Err(_) => Err(ConfigError::InvalidHexValue),
}, },
@@ -148,9 +179,8 @@ impl<'a, C: Configuration> LineParser<'a, C> {
// opt: set endpoint // opt: set endpoint
"endpoint" => match value.parse() { "endpoint" => match value.parse() {
Ok(endpoint) => { Ok(endpoint) => {
let st = flush_peer(self.state); peer.endpoint = Some(endpoint);
self.config.set_endpoint(&public_key, endpoint); Ok(())
Ok(st)
} }
Err(_) => Err(ConfigError::InvalidSocketAddr), Err(_) => Err(ConfigError::InvalidSocketAddr),
}, },
@@ -158,19 +188,17 @@ impl<'a, C: Configuration> LineParser<'a, C> {
// opt: set persistent keepalive interval // opt: set persistent keepalive interval
"persistent_keepalive_interval" => match value.parse() { "persistent_keepalive_interval" => match value.parse() {
Ok(secs) => { Ok(secs) => {
let st = flush_peer(self.state); peer.persistent_keepalive_interval = Some(secs);
self.config Ok(())
.set_persistent_keepalive_interval(&public_key, secs);
Ok(st)
} }
Err(_) => Err(ConfigError::InvalidKeepaliveInterval), Err(_) => Err(ConfigError::InvalidKeepaliveInterval),
}, },
// opt replace allowed ips // opt replace allowed ips
"replace_allowed_ips" => { "replace_allowed_ips" => {
let st = flush_peer(self.state); peer.replace_allowed_ips = true;
self.config.replace_allowed_ips(&public_key); peer.allowed_ips.clear();
Ok(st) Ok(())
} }
// opt add allowed ips // opt add allowed ips
@@ -180,9 +208,8 @@ impl<'a, C: Configuration> LineParser<'a, C> {
let cidr = split.next().and_then(|x| x.parse().ok()); let cidr = split.next().and_then(|x| x.parse().ok());
match (addr, cidr) { match (addr, cidr) {
(Some(addr), Some(cidr)) => { (Some(addr), Some(cidr)) => {
let st = flush_peer(self.state); peer.allowed_ips.push((addr, cidr));
self.config.add_allowed_ip(&public_key, addr, cidr); Ok(())
Ok(st)
} }
_ => Err(ConfigError::InvalidAllowedIp), _ => Err(ConfigError::InvalidAllowedIp),
} }
@@ -193,11 +220,8 @@ impl<'a, C: Configuration> LineParser<'a, C> {
let parse_res: Result<usize, _> = value.parse(); let parse_res: Result<usize, _> = value.parse();
match parse_res { match parse_res {
Ok(version) => { Ok(version) => {
if version == 0 || version > self.config.get_protocol_version() { peer.protocol_version = Some(version);
Err(ConfigError::UnsupportedProtocolVersion) Ok(())
} else {
Ok(self.state)
}
} }
Err(_) => Err(ConfigError::UnsupportedProtocolVersion), Err(_) => Err(ConfigError::UnsupportedProtocolVersion),
} }
@@ -206,8 +230,6 @@ impl<'a, C: Configuration> LineParser<'a, C> {
// unknown key // unknown key
_ => Err(ConfigError::InvalidKey), _ => Err(ConfigError::InvalidKey),
}, },
}?; }
Ok(())
} }
} }

View File

@@ -10,21 +10,35 @@ mod configuration;
mod platform; mod platform;
mod wireguard; mod wireguard;
use platform::tun; use platform::tun::PlatformTun;
use platform::uapi::PlatformUAPI;
use platform::*;
use configuration::WireguardConfig; use std::sync::Arc;
use std::thread;
use std::time::Duration;
fn main() { fn main() {
/* let name = "wg0";
let (mut readers, writer, mtu) = platform::TunInstance::create("test").unwrap();
let wg = wireguard::Wireguard::new(readers, writer, mtu);
*/
}
/* let _ = env_logger::builder().is_test(true).try_init();
fn test_wg_configuration() {
let (mut readers, writer, mtu) = platform::dummy::
let wg = wireguard::Wireguard::new(readers, writer, mtu); // create UAPI socket
let uapi = plt::UAPI::bind(name).unwrap();
// create TUN device
let (readers, writer, mtu) = plt::Tun::create(name).unwrap();
// create WireGuard device
let wg: wireguard::Wireguard<plt::Tun, plt::Bind> =
wireguard::Wireguard::new(readers, writer, mtu);
// wrap in configuration interface and start UAPI server
let cfg = configuration::WireguardConfig::new(wg);
loop {
let mut stream = uapi.accept().unwrap();
configuration::uapi::handle(&mut stream.0, &cfg);
}
thread::sleep(Duration::from_secs(600));
} }
*/

View File

@@ -37,7 +37,7 @@ pub trait Owner: Send {
/// On some platforms the application can itself bind to a socket. /// On some platforms the application can itself bind to a socket.
/// This enables configuration using the UAPI interface. /// This enables configuration using the UAPI interface.
pub trait Platform: Bind { pub trait PlatformBind: Bind {
type Owner: Owner; type Owner: Owner;
/// Bind to a new port, returning the reader/writer and /// Bind to a new port, returning the reader/writer and

View File

@@ -216,7 +216,7 @@ impl Owner for VoidOwner {
} }
} }
impl Platform for PairBind { impl PlatformBind for PairBind {
type Owner = VoidOwner; type Owner = VoidOwner;
fn bind(_port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { fn bind(_port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
Err(BindError::Disconnected) Err(BindError::Disconnected)

View File

@@ -192,7 +192,7 @@ impl TunTest {
} }
} }
impl Platform for TunTest { impl PlatformTun for TunTest {
fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> { fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> {
Err(TunError::Disconnected) Err(TunError::Disconnected)
} }

View File

@@ -1,5 +1,7 @@
mod tun; mod tun;
mod uapi;
mod udp; mod udp;
pub use tun::LinuxTun; pub use tun::LinuxTun as Tun;
pub use udp::LinuxBind; pub use uapi::LinuxUAPI as UAPI;
pub use udp::LinuxBind as Bind;

View File

@@ -125,7 +125,7 @@ impl Tun for LinuxTun {
type MTU = LinuxTunMTU; type MTU = LinuxTunMTU;
} }
impl Platform for LinuxTun { impl PlatformTun for LinuxTun {
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> { fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> {
// construct request struct // construct request struct
let mut req = Ifreq { let mut req = Ifreq {

View File

@@ -0,0 +1,31 @@
use super::super::uapi::*;
use std::fs;
use std::io;
use std::os::unix::net::{UnixListener, UnixStream};
const SOCK_DIR: &str = "/var/run/wireguard/";
pub struct LinuxUAPI {}
impl PlatformUAPI for LinuxUAPI {
type Error = io::Error;
type Bind = UnixListener;
fn bind(name: &str) -> Result<UnixListener, io::Error> {
let socket_path = format!("{}{}.sock", SOCK_DIR, name);
let _ = fs::create_dir_all(SOCK_DIR);
let _ = fs::remove_file(&socket_path);
UnixListener::bind(socket_path)
}
}
impl BindUAPI for UnixListener {
type Stream = UnixStream;
type Error = io::Error;
fn accept(&self) -> Result<UnixStream, io::Error> {
let (stream, _) = self.accept()?;
Ok(stream)
}
}

View File

@@ -1,26 +1,82 @@
use super::super::bind::*; use super::super::bind::*;
use super::super::Endpoint; use super::super::Endpoint;
use std::net::SocketAddr; use std::io;
use std::net::{SocketAddr, UdpSocket};
use std::sync::Arc;
pub struct LinuxEndpoint {} #[derive(Clone)]
pub struct LinuxBind(Arc<UdpSocket>);
pub struct LinuxBind {} pub struct LinuxOwner(Arc<UdpSocket>);
impl Endpoint for LinuxEndpoint { impl Endpoint for SocketAddr {
fn clear_src(&mut self) {} fn clear_src(&mut self) {}
fn from_address(addr: SocketAddr) -> Self { fn from_address(addr: SocketAddr) -> Self {
LinuxEndpoint {} addr
} }
fn into_address(&self) -> SocketAddr { fn into_address(&self) -> SocketAddr {
"127.0.0.1:6060".parse().unwrap() *self
} }
} }
/* impl Reader<SocketAddr> for LinuxBind {
impl Bind for PlatformBind { type Error = io::Error;
type Endpoint = PlatformEndpoint;
fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
self.0.recv_from(buf)
}
}
impl Writer<SocketAddr> for LinuxBind {
type Error = io::Error;
fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> {
self.0.send_to(buf, dst)?;
Ok(())
}
}
impl Owner for LinuxOwner {
type Error = io::Error;
fn get_port(&self) -> u16 {
1337
}
fn get_fwmark(&self) -> Option<u32> {
None
}
fn set_fwmark(&mut self, value: Option<u32>) -> Option<Self::Error> {
None
}
}
impl Drop for LinuxOwner {
fn drop(&mut self) {}
}
impl Bind for LinuxBind {
type Error = io::Error;
type Endpoint = SocketAddr;
type Reader = LinuxBind;
type Writer = LinuxBind;
}
impl PlatformBind for LinuxBind {
type Owner = LinuxOwner;
fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?;
let socket = Arc::new(socket);
Ok((
vec![LinuxBind(socket.clone())],
LinuxBind(socket.clone()),
LinuxOwner(socket),
))
}
} }
*/

View File

@@ -2,14 +2,15 @@ mod endpoint;
pub mod bind; pub mod bind;
pub mod tun; pub mod tun;
pub mod uapi;
pub use endpoint::Endpoint; pub use endpoint::Endpoint;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
mod linux; pub mod linux;
#[cfg(test)] #[cfg(test)]
pub mod dummy; pub mod dummy;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub use linux::LinuxTun as TunInstance; pub use linux as plt;

View File

@@ -56,6 +56,6 @@ pub trait Tun: Send + Sync + 'static {
} }
/// On some platforms the application can create the TUN device itself. /// On some platforms the application can create the TUN device itself.
pub trait Platform: Tun { pub trait PlatformTun: Tun {
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error>; fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error>;
} }

16
src/platform/uapi.rs Normal file
View File

@@ -0,0 +1,16 @@
use std::error::Error;
use std::io::{Read, Write};
pub trait BindUAPI {
type Stream: Read + Write;
type Error: Error;
fn accept(&self) -> Result<Self::Stream, Self::Error>;
}
pub trait PlatformUAPI {
type Error: Error;
type Bind: BindUAPI;
fn bind(name: &str) -> Result<Self::Bind, Self::Error>;
}

View File

@@ -178,13 +178,10 @@ impl Device {
/// # Returns /// # Returns
/// ///
/// The call might fail if the public key is not found /// The call might fail if the public key is not found
pub fn set_psk(&mut self, pk: PublicKey, psk: Option<Psk>) -> Result<(), ConfigError> { pub fn set_psk(&mut self, pk: PublicKey, psk: Psk) -> Result<(), ConfigError> {
match self.pk_map.get_mut(pk.as_bytes()) { match self.pk_map.get_mut(pk.as_bytes()) {
Some(mut peer) => { Some(mut peer) => {
peer.psk = match psk { peer.psk = psk;
Some(v) => v,
None => [0u8; 32],
};
Ok(()) Ok(())
} }
_ => Err(ConfigError::new("No such public key")), _ => Err(ConfigError::new("No such public key")),
@@ -466,8 +463,8 @@ mod tests {
dev1.add(pk2).unwrap(); dev1.add(pk2).unwrap();
dev2.add(pk1).unwrap(); dev2.add(pk1).unwrap();
dev1.set_psk(pk2, Some(psk)).unwrap(); dev1.set_psk(pk2, psk).unwrap();
dev2.set_psk(pk1, Some(psk)).unwrap(); dev2.set_psk(pk1, psk).unwrap();
(pk1, dev1, pk2, dev2) (pk1, dev1, pk2, dev2)
} }

View File

@@ -201,7 +201,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
.map(|sk| StaticSecret::from(sk.to_bytes())) .map(|sk| StaticSecret::from(sk.to_bytes()))
} }
pub fn set_psk(&self, pk: PublicKey, psk: Option<[u8; 32]>) -> bool { pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool {
self.state.handshake.write().set_psk(pk, psk).is_ok() self.state.handshake.write().set_psk(pk, psk).is_ok()
} }
pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> { pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> {