Update UAPI semantics for remove
This commit is contained in:
@@ -19,16 +19,16 @@ pub struct PeerState {
|
||||
pub last_handshake_time_nsec: u64,
|
||||
pub public_key: PublicKey,
|
||||
pub allowed_ips: Vec<(IpAddr, u32)>,
|
||||
pub preshared_key: Option<[u8; 32]>,
|
||||
pub preshared_key: [u8; 32], // 0^32 is the "default value"
|
||||
}
|
||||
|
||||
pub struct WireguardConfig<T: tun::Tun, B: bind::Platform> {
|
||||
pub struct WireguardConfig<T: tun::Tun, B: bind::PlatformBind> {
|
||||
wireguard: Wireguard<T, B>,
|
||||
network: Mutex<Option<B::Owner>>,
|
||||
}
|
||||
|
||||
impl<T: tun::Tun, B: bind::Platform> WireguardConfig<T, B> {
|
||||
fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
|
||||
impl<T: tun::Tun, B: bind::PlatformBind> WireguardConfig<T, B> {
|
||||
pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
|
||||
WireguardConfig {
|
||||
wireguard: wg,
|
||||
network: Mutex::new(None),
|
||||
@@ -110,7 +110,7 @@ pub trait Configuration {
|
||||
/// # Returns
|
||||
///
|
||||
/// An error if no such peer exists
|
||||
fn set_preshared_key(&self, peer: &PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError>;
|
||||
fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError>;
|
||||
|
||||
/// Update the endpoint of the
|
||||
///
|
||||
@@ -170,7 +170,7 @@ pub trait Configuration {
|
||||
fn get_fwmark(&self) -> Option<u32>;
|
||||
}
|
||||
|
||||
impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> {
|
||||
impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B> {
|
||||
fn get_fwmark(&self) -> Option<u32> {
|
||||
self.network
|
||||
.lock()
|
||||
@@ -246,7 +246,7 @@ impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> {
|
||||
false
|
||||
}
|
||||
|
||||
fn set_preshared_key(&self, peer: &PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError> {
|
||||
fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError> {
|
||||
if self.wireguard.set_psk(*peer, psk) {
|
||||
None
|
||||
} else {
|
||||
@@ -308,9 +308,10 @@ impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> {
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or(Duration::from_secs(0)); // any time before epoch is mapped to epoch
|
||||
|
||||
if let Some(psk) = self.wireguard.get_psk(&p.pk) {
|
||||
// extract state into PeerState
|
||||
state.push(PeerState {
|
||||
preshared_key: self.wireguard.get_psk(&p.pk),
|
||||
preshared_key: psk,
|
||||
rx_bytes: p.rx_bytes.load(Ordering::Relaxed),
|
||||
tx_bytes: p.tx_bytes.load(Ordering::Relaxed),
|
||||
allowed_ips: p.router.list_allowed_ips(),
|
||||
@@ -319,6 +320,7 @@ impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> {
|
||||
public_key: p.pk,
|
||||
})
|
||||
}
|
||||
}
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
mod config;
|
||||
mod error;
|
||||
mod uapi;
|
||||
pub mod uapi;
|
||||
|
||||
use super::platform::Endpoint;
|
||||
use super::platform::{bind, tun};
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use hex::FromHex;
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
use log;
|
||||
|
||||
use super::Configuration;
|
||||
use std::io;
|
||||
|
||||
@@ -8,9 +10,11 @@ pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) ->
|
||||
let mut write = |key: &'static str, value: String| {
|
||||
debug_assert!(value.is_ascii());
|
||||
debug_assert!(key.is_ascii());
|
||||
log::trace!("UAPI: return : {} = {}", key, value);
|
||||
writer.write(key.as_ref())?;
|
||||
writer.write(b"=")?;
|
||||
writer.write(value.as_ref())
|
||||
writer.write(value.as_ref())?;
|
||||
writer.write(b"\n")
|
||||
};
|
||||
|
||||
// serialize interface
|
||||
@@ -40,9 +44,7 @@ pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) ->
|
||||
p.last_handshake_time_nsec.to_string(),
|
||||
)?;
|
||||
write("public_key", hex::encode(p.public_key.as_bytes()))?;
|
||||
if let Some(psk) = p.preshared_key {
|
||||
write("preshared_key", hex::encode(psk))?;
|
||||
}
|
||||
write("preshared_key", hex::encode(p.preshared_key))?;
|
||||
for (ip, cidr) in p.allowed_ips {
|
||||
write("allowed_ip", ip.to_string() + "/" + &cidr.to_string())?;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod get;
|
||||
mod set;
|
||||
|
||||
use log;
|
||||
use std::io::{Read, Write};
|
||||
|
||||
use super::{ConfigError, Configuration};
|
||||
@@ -10,10 +11,9 @@ use set::LineParser;
|
||||
|
||||
const MAX_LINE_LENGTH: usize = 256;
|
||||
|
||||
pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut W, config: &C) {
|
||||
fn operation<R: Read, W: Write, C: Configuration>(
|
||||
reader: &mut R,
|
||||
writer: &mut W,
|
||||
pub fn handle<S: Read + Write, C: Configuration>(stream: &mut S, config: &C) {
|
||||
fn operation<S: Read + Write, C: Configuration>(
|
||||
stream: &mut S,
|
||||
config: &C,
|
||||
) -> Result<(), ConfigError> {
|
||||
// read string up to maximum length (why is this not in std?)
|
||||
@@ -23,6 +23,7 @@ pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut
|
||||
while let Ok(_) = reader.read_exact(&mut m) {
|
||||
let c = m[0] as char;
|
||||
if c == '\n' {
|
||||
log::trace!("UAPI, line: {}", l);
|
||||
return Ok(l);
|
||||
};
|
||||
l.push(c);
|
||||
@@ -43,12 +44,16 @@ pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut
|
||||
};
|
||||
|
||||
// read operation line
|
||||
match readline(reader)?.as_str() {
|
||||
"get=1" => serialize(writer, config).map_err(|_| ConfigError::IOError),
|
||||
match readline(stream)?.as_str() {
|
||||
"get=1" => {
|
||||
log::debug!("UAPI, Get operation");
|
||||
serialize(stream, config).map_err(|_| ConfigError::IOError)
|
||||
}
|
||||
"set=1" => {
|
||||
log::debug!("UAPI, Set operation");
|
||||
let mut parser = LineParser::new(config);
|
||||
loop {
|
||||
let ln = readline(reader)?;
|
||||
let ln = readline(stream)?;
|
||||
if ln == "" {
|
||||
break Ok(());
|
||||
};
|
||||
@@ -61,17 +66,17 @@ pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut
|
||||
}
|
||||
|
||||
// process operation
|
||||
let res = operation(reader, writer, config);
|
||||
log::debug!("{:?}", res);
|
||||
let res = operation(stream, config);
|
||||
log::debug!("UAPI, Result of operation: {:?}", res);
|
||||
|
||||
// return errno
|
||||
let _ = writer.write("errno=".as_ref());
|
||||
let _ = writer.write(
|
||||
let _ = stream.write("errno=".as_ref());
|
||||
let _ = stream.write(
|
||||
match res {
|
||||
Err(e) => e.errno().to_string(),
|
||||
Ok(()) => "0".to_owned(),
|
||||
}
|
||||
.as_ref(),
|
||||
);
|
||||
let _ = writer.write("\n\n".as_ref());
|
||||
let _ = stream.write("\n\n".as_ref());
|
||||
}
|
||||
|
||||
@@ -1,16 +1,25 @@
|
||||
use hex::FromHex;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use subtle::ConstantTimeEq;
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
use super::{ConfigError, Configuration};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
enum ParserState {
|
||||
Peer {
|
||||
Peer(ParsedPeer),
|
||||
Interface,
|
||||
}
|
||||
|
||||
struct ParsedPeer {
|
||||
public_key: PublicKey,
|
||||
update_only: bool,
|
||||
},
|
||||
Interface,
|
||||
allowed_ips: Vec<(IpAddr, u32)>,
|
||||
remove: bool,
|
||||
preshared_key: Option<[u8; 32]>,
|
||||
replace_allowed_ips: bool,
|
||||
persistent_keepalive_interval: Option<u64>,
|
||||
protocol_version: Option<usize>,
|
||||
endpoint: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
pub struct LineParser<'a, C: Configuration> {
|
||||
@@ -28,45 +37,71 @@ impl<'a, C: Configuration> LineParser<'a, C> {
|
||||
|
||||
fn new_peer(value: &str) -> Result<ParserState, ConfigError> {
|
||||
match <[u8; 32]>::from_hex(value) {
|
||||
Ok(pk) => Ok(ParserState::Peer {
|
||||
Ok(pk) => Ok(ParserState::Peer(ParsedPeer {
|
||||
public_key: PublicKey::from(pk),
|
||||
remove: false,
|
||||
update_only: false,
|
||||
}),
|
||||
allowed_ips: vec![],
|
||||
preshared_key: None,
|
||||
replace_allowed_ips: false,
|
||||
persistent_keepalive_interval: None,
|
||||
protocol_version: None,
|
||||
endpoint: None,
|
||||
})),
|
||||
Err(_) => Err(ConfigError::InvalidHexValue),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_line(&mut self, key: &str, value: &str) -> Result<(), ConfigError> {
|
||||
// add the peer if not update_only
|
||||
let flush_peer = |st: ParserState| -> ParserState {
|
||||
match st {
|
||||
ParserState::Peer {
|
||||
public_key,
|
||||
update_only: false,
|
||||
} => {
|
||||
self.config.add_peer(&public_key);
|
||||
ParserState::Peer {
|
||||
public_key,
|
||||
update_only: true,
|
||||
// flush peer updates to configuration
|
||||
fn flush_peer<C: Configuration>(config: &C, peer: &ParsedPeer) -> Option<ConfigError> {
|
||||
if peer.remove {
|
||||
config.remove_peer(&peer.public_key);
|
||||
return None;
|
||||
}
|
||||
|
||||
if !peer.update_only {
|
||||
config.add_peer(&peer.public_key);
|
||||
}
|
||||
|
||||
for (ip, masklen) in &peer.allowed_ips {
|
||||
config.add_allowed_ip(&peer.public_key, *ip, *masklen);
|
||||
}
|
||||
|
||||
if let Some(psk) = peer.preshared_key {
|
||||
config.set_preshared_key(&peer.public_key, psk);
|
||||
}
|
||||
|
||||
if let Some(secs) = peer.persistent_keepalive_interval {
|
||||
config.set_persistent_keepalive_interval(&peer.public_key, secs);
|
||||
}
|
||||
|
||||
if let Some(version) = peer.protocol_version {
|
||||
if version == 0 || version > config.get_protocol_version() {
|
||||
return Some(ConfigError::UnsupportedProtocolVersion);
|
||||
}
|
||||
}
|
||||
_ => st,
|
||||
}
|
||||
|
||||
if let Some(endpoint) = peer.endpoint {
|
||||
config.set_endpoint(&peer.public_key, endpoint);
|
||||
};
|
||||
|
||||
None
|
||||
};
|
||||
|
||||
// parse line and update parser state
|
||||
self.state = match self.state {
|
||||
match self.state {
|
||||
// configure the interface
|
||||
ParserState::Interface => match key {
|
||||
// opt: set private key
|
||||
"private_key" => match <[u8; 32]>::from_hex(value) {
|
||||
Ok(sk) => {
|
||||
self.config.set_private_key(if sk == [0u8; 32] {
|
||||
self.config.set_private_key(if sk.ct_eq(&[0u8; 32]).into() {
|
||||
None
|
||||
} else {
|
||||
Some(StaticSecret::from(sk))
|
||||
});
|
||||
Ok(self.state)
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => Err(ConfigError::InvalidHexValue),
|
||||
},
|
||||
@@ -75,7 +110,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
|
||||
"listen_port" => match value.parse() {
|
||||
Ok(port) => {
|
||||
self.config.set_listen_port(Some(port));
|
||||
Ok(self.state)
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => Err(ConfigError::InvalidPortNumber),
|
||||
},
|
||||
@@ -85,7 +120,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
|
||||
Ok(fwmark) => {
|
||||
self.config
|
||||
.set_fwmark(if fwmark == 0 { None } else { Some(fwmark) });
|
||||
Ok(self.state)
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => Err(ConfigError::InvalidFwmark),
|
||||
},
|
||||
@@ -96,51 +131,47 @@ impl<'a, C: Configuration> LineParser<'a, C> {
|
||||
for p in self.config.get_peers() {
|
||||
self.config.remove_peer(&p.public_key)
|
||||
}
|
||||
Ok(self.state)
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(ConfigError::UnsupportedValue),
|
||||
},
|
||||
|
||||
// opt: transition to peer configuration
|
||||
"public_key" => Self::new_peer(value),
|
||||
"public_key" => {
|
||||
self.state = Self::new_peer(value)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// unknown key
|
||||
_ => Err(ConfigError::InvalidKey),
|
||||
},
|
||||
|
||||
// configure peers
|
||||
ParserState::Peer { public_key, .. } => match key {
|
||||
ParserState::Peer(ref mut peer) => match key {
|
||||
// opt: new peer
|
||||
"public_key" => {
|
||||
flush_peer(self.state);
|
||||
Self::new_peer(value)
|
||||
flush_peer(self.config, &peer);
|
||||
self.state = Self::new_peer(value)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// opt: remove peer
|
||||
"remove" => {
|
||||
self.config.remove_peer(&public_key);
|
||||
Ok(self.state)
|
||||
peer.remove = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// opt: update only
|
||||
"update_only" => Ok(ParserState::Peer {
|
||||
public_key,
|
||||
update_only: true,
|
||||
}),
|
||||
"update_only" => {
|
||||
peer.update_only = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// opt: set preshared key
|
||||
"preshared_key" => match <[u8; 32]>::from_hex(value) {
|
||||
Ok(psk) => {
|
||||
let st = flush_peer(self.state);
|
||||
self.config.set_preshared_key(
|
||||
&public_key,
|
||||
if psk.ct_eq(&[0u8; 32]).into() {
|
||||
None
|
||||
} else {
|
||||
Some(psk)
|
||||
},
|
||||
);
|
||||
Ok(st)
|
||||
peer.preshared_key = Some(psk);
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => Err(ConfigError::InvalidHexValue),
|
||||
},
|
||||
@@ -148,9 +179,8 @@ impl<'a, C: Configuration> LineParser<'a, C> {
|
||||
// opt: set endpoint
|
||||
"endpoint" => match value.parse() {
|
||||
Ok(endpoint) => {
|
||||
let st = flush_peer(self.state);
|
||||
self.config.set_endpoint(&public_key, endpoint);
|
||||
Ok(st)
|
||||
peer.endpoint = Some(endpoint);
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => Err(ConfigError::InvalidSocketAddr),
|
||||
},
|
||||
@@ -158,19 +188,17 @@ impl<'a, C: Configuration> LineParser<'a, C> {
|
||||
// opt: set persistent keepalive interval
|
||||
"persistent_keepalive_interval" => match value.parse() {
|
||||
Ok(secs) => {
|
||||
let st = flush_peer(self.state);
|
||||
self.config
|
||||
.set_persistent_keepalive_interval(&public_key, secs);
|
||||
Ok(st)
|
||||
peer.persistent_keepalive_interval = Some(secs);
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => Err(ConfigError::InvalidKeepaliveInterval),
|
||||
},
|
||||
|
||||
// opt replace allowed ips
|
||||
"replace_allowed_ips" => {
|
||||
let st = flush_peer(self.state);
|
||||
self.config.replace_allowed_ips(&public_key);
|
||||
Ok(st)
|
||||
peer.replace_allowed_ips = true;
|
||||
peer.allowed_ips.clear();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// opt add allowed ips
|
||||
@@ -180,9 +208,8 @@ impl<'a, C: Configuration> LineParser<'a, C> {
|
||||
let cidr = split.next().and_then(|x| x.parse().ok());
|
||||
match (addr, cidr) {
|
||||
(Some(addr), Some(cidr)) => {
|
||||
let st = flush_peer(self.state);
|
||||
self.config.add_allowed_ip(&public_key, addr, cidr);
|
||||
Ok(st)
|
||||
peer.allowed_ips.push((addr, cidr));
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(ConfigError::InvalidAllowedIp),
|
||||
}
|
||||
@@ -193,11 +220,8 @@ impl<'a, C: Configuration> LineParser<'a, C> {
|
||||
let parse_res: Result<usize, _> = value.parse();
|
||||
match parse_res {
|
||||
Ok(version) => {
|
||||
if version == 0 || version > self.config.get_protocol_version() {
|
||||
Err(ConfigError::UnsupportedProtocolVersion)
|
||||
} else {
|
||||
Ok(self.state)
|
||||
}
|
||||
peer.protocol_version = Some(version);
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => Err(ConfigError::UnsupportedProtocolVersion),
|
||||
}
|
||||
@@ -206,8 +230,6 @@ impl<'a, C: Configuration> LineParser<'a, C> {
|
||||
// unknown key
|
||||
_ => Err(ConfigError::InvalidKey),
|
||||
},
|
||||
}?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
38
src/main.rs
38
src/main.rs
@@ -10,21 +10,35 @@ mod configuration;
|
||||
mod platform;
|
||||
mod wireguard;
|
||||
|
||||
use platform::tun;
|
||||
use platform::tun::PlatformTun;
|
||||
use platform::uapi::PlatformUAPI;
|
||||
use platform::*;
|
||||
|
||||
use configuration::WireguardConfig;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
fn main() {
|
||||
/*
|
||||
let (mut readers, writer, mtu) = platform::TunInstance::create("test").unwrap();
|
||||
let wg = wireguard::Wireguard::new(readers, writer, mtu);
|
||||
*/
|
||||
}
|
||||
let name = "wg0";
|
||||
|
||||
/*
|
||||
fn test_wg_configuration() {
|
||||
let (mut readers, writer, mtu) = platform::dummy::
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
|
||||
let wg = wireguard::Wireguard::new(readers, writer, mtu);
|
||||
// create UAPI socket
|
||||
let uapi = plt::UAPI::bind(name).unwrap();
|
||||
|
||||
// create TUN device
|
||||
let (readers, writer, mtu) = plt::Tun::create(name).unwrap();
|
||||
|
||||
// create WireGuard device
|
||||
let wg: wireguard::Wireguard<plt::Tun, plt::Bind> =
|
||||
wireguard::Wireguard::new(readers, writer, mtu);
|
||||
|
||||
// wrap in configuration interface and start UAPI server
|
||||
let cfg = configuration::WireguardConfig::new(wg);
|
||||
loop {
|
||||
let mut stream = uapi.accept().unwrap();
|
||||
configuration::uapi::handle(&mut stream.0, &cfg);
|
||||
}
|
||||
|
||||
thread::sleep(Duration::from_secs(600));
|
||||
}
|
||||
*/
|
||||
|
||||
@@ -37,7 +37,7 @@ pub trait Owner: Send {
|
||||
|
||||
/// On some platforms the application can itself bind to a socket.
|
||||
/// This enables configuration using the UAPI interface.
|
||||
pub trait Platform: Bind {
|
||||
pub trait PlatformBind: Bind {
|
||||
type Owner: Owner;
|
||||
|
||||
/// Bind to a new port, returning the reader/writer and
|
||||
|
||||
@@ -216,7 +216,7 @@ impl Owner for VoidOwner {
|
||||
}
|
||||
}
|
||||
|
||||
impl Platform for PairBind {
|
||||
impl PlatformBind for PairBind {
|
||||
type Owner = VoidOwner;
|
||||
fn bind(_port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
|
||||
Err(BindError::Disconnected)
|
||||
|
||||
@@ -192,7 +192,7 @@ impl TunTest {
|
||||
}
|
||||
}
|
||||
|
||||
impl Platform for TunTest {
|
||||
impl PlatformTun for TunTest {
|
||||
fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> {
|
||||
Err(TunError::Disconnected)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
mod tun;
|
||||
mod uapi;
|
||||
mod udp;
|
||||
|
||||
pub use tun::LinuxTun;
|
||||
pub use udp::LinuxBind;
|
||||
pub use tun::LinuxTun as Tun;
|
||||
pub use uapi::LinuxUAPI as UAPI;
|
||||
pub use udp::LinuxBind as Bind;
|
||||
|
||||
@@ -125,7 +125,7 @@ impl Tun for LinuxTun {
|
||||
type MTU = LinuxTunMTU;
|
||||
}
|
||||
|
||||
impl Platform for LinuxTun {
|
||||
impl PlatformTun for LinuxTun {
|
||||
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> {
|
||||
// construct request struct
|
||||
let mut req = Ifreq {
|
||||
|
||||
31
src/platform/linux/uapi.rs
Normal file
31
src/platform/linux/uapi.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use super::super::uapi::*;
|
||||
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::os::unix::net::{UnixListener, UnixStream};
|
||||
|
||||
const SOCK_DIR: &str = "/var/run/wireguard/";
|
||||
|
||||
pub struct LinuxUAPI {}
|
||||
|
||||
impl PlatformUAPI for LinuxUAPI {
|
||||
type Error = io::Error;
|
||||
type Bind = UnixListener;
|
||||
|
||||
fn bind(name: &str) -> Result<UnixListener, io::Error> {
|
||||
let socket_path = format!("{}{}.sock", SOCK_DIR, name);
|
||||
let _ = fs::create_dir_all(SOCK_DIR);
|
||||
let _ = fs::remove_file(&socket_path);
|
||||
UnixListener::bind(socket_path)
|
||||
}
|
||||
}
|
||||
|
||||
impl BindUAPI for UnixListener {
|
||||
type Stream = UnixStream;
|
||||
type Error = io::Error;
|
||||
|
||||
fn accept(&self) -> Result<UnixStream, io::Error> {
|
||||
let (stream, _) = self.accept()?;
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
@@ -1,26 +1,82 @@
|
||||
use super::super::bind::*;
|
||||
use super::super::Endpoint;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::io;
|
||||
use std::net::{SocketAddr, UdpSocket};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct LinuxEndpoint {}
|
||||
#[derive(Clone)]
|
||||
pub struct LinuxBind(Arc<UdpSocket>);
|
||||
|
||||
pub struct LinuxBind {}
|
||||
pub struct LinuxOwner(Arc<UdpSocket>);
|
||||
|
||||
impl Endpoint for LinuxEndpoint {
|
||||
impl Endpoint for SocketAddr {
|
||||
fn clear_src(&mut self) {}
|
||||
|
||||
fn from_address(addr: SocketAddr) -> Self {
|
||||
LinuxEndpoint {}
|
||||
addr
|
||||
}
|
||||
|
||||
fn into_address(&self) -> SocketAddr {
|
||||
"127.0.0.1:6060".parse().unwrap()
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
impl Bind for PlatformBind {
|
||||
type Endpoint = PlatformEndpoint;
|
||||
impl Reader<SocketAddr> for LinuxBind {
|
||||
type Error = io::Error;
|
||||
|
||||
fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
|
||||
self.0.recv_from(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Writer<SocketAddr> for LinuxBind {
|
||||
type Error = io::Error;
|
||||
|
||||
fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> {
|
||||
self.0.send_to(buf, dst)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Owner for LinuxOwner {
|
||||
type Error = io::Error;
|
||||
|
||||
fn get_port(&self) -> u16 {
|
||||
1337
|
||||
}
|
||||
|
||||
fn get_fwmark(&self) -> Option<u32> {
|
||||
None
|
||||
}
|
||||
|
||||
fn set_fwmark(&mut self, value: Option<u32>) -> Option<Self::Error> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LinuxOwner {
|
||||
fn drop(&mut self) {}
|
||||
}
|
||||
|
||||
impl Bind for LinuxBind {
|
||||
type Error = io::Error;
|
||||
type Endpoint = SocketAddr;
|
||||
type Reader = LinuxBind;
|
||||
type Writer = LinuxBind;
|
||||
}
|
||||
|
||||
impl PlatformBind for LinuxBind {
|
||||
type Owner = LinuxOwner;
|
||||
|
||||
fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
|
||||
let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?;
|
||||
let socket = Arc::new(socket);
|
||||
|
||||
Ok((
|
||||
vec![LinuxBind(socket.clone())],
|
||||
LinuxBind(socket.clone()),
|
||||
LinuxOwner(socket),
|
||||
))
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
@@ -2,14 +2,15 @@ mod endpoint;
|
||||
|
||||
pub mod bind;
|
||||
pub mod tun;
|
||||
pub mod uapi;
|
||||
|
||||
pub use endpoint::Endpoint;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
mod linux;
|
||||
pub mod linux;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod dummy;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub use linux::LinuxTun as TunInstance;
|
||||
pub use linux as plt;
|
||||
|
||||
@@ -56,6 +56,6 @@ pub trait Tun: Send + Sync + 'static {
|
||||
}
|
||||
|
||||
/// On some platforms the application can create the TUN device itself.
|
||||
pub trait Platform: Tun {
|
||||
pub trait PlatformTun: Tun {
|
||||
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error>;
|
||||
}
|
||||
|
||||
16
src/platform/uapi.rs
Normal file
16
src/platform/uapi.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use std::error::Error;
|
||||
use std::io::{Read, Write};
|
||||
|
||||
pub trait BindUAPI {
|
||||
type Stream: Read + Write;
|
||||
type Error: Error;
|
||||
|
||||
fn accept(&self) -> Result<Self::Stream, Self::Error>;
|
||||
}
|
||||
|
||||
pub trait PlatformUAPI {
|
||||
type Error: Error;
|
||||
type Bind: BindUAPI;
|
||||
|
||||
fn bind(name: &str) -> Result<Self::Bind, Self::Error>;
|
||||
}
|
||||
@@ -178,13 +178,10 @@ impl Device {
|
||||
/// # Returns
|
||||
///
|
||||
/// The call might fail if the public key is not found
|
||||
pub fn set_psk(&mut self, pk: PublicKey, psk: Option<Psk>) -> Result<(), ConfigError> {
|
||||
pub fn set_psk(&mut self, pk: PublicKey, psk: Psk) -> Result<(), ConfigError> {
|
||||
match self.pk_map.get_mut(pk.as_bytes()) {
|
||||
Some(mut peer) => {
|
||||
peer.psk = match psk {
|
||||
Some(v) => v,
|
||||
None => [0u8; 32],
|
||||
};
|
||||
peer.psk = psk;
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(ConfigError::new("No such public key")),
|
||||
@@ -466,8 +463,8 @@ mod tests {
|
||||
dev1.add(pk2).unwrap();
|
||||
dev2.add(pk1).unwrap();
|
||||
|
||||
dev1.set_psk(pk2, Some(psk)).unwrap();
|
||||
dev2.set_psk(pk1, Some(psk)).unwrap();
|
||||
dev1.set_psk(pk2, psk).unwrap();
|
||||
dev2.set_psk(pk1, psk).unwrap();
|
||||
|
||||
(pk1, dev1, pk2, dev2)
|
||||
}
|
||||
|
||||
@@ -201,7 +201,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
|
||||
.map(|sk| StaticSecret::from(sk.to_bytes()))
|
||||
}
|
||||
|
||||
pub fn set_psk(&self, pk: PublicKey, psk: Option<[u8; 32]>) -> bool {
|
||||
pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool {
|
||||
self.state.handshake.write().set_psk(pk, psk).is_ok()
|
||||
}
|
||||
pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> {
|
||||
|
||||
Reference in New Issue
Block a user