Move parser code to zerocopy

This commit is contained in:
Mathias Hall-Andersen
2019-07-26 15:46:24 +02:00
parent 43b56dfb58
commit 5efb318171
5 changed files with 138 additions and 151 deletions

View File

@@ -202,7 +202,6 @@ mod tests {
use super::*;
use hex;
use messages::*;
use std::convert::TryFrom;
#[test]
fn handshake() {
@@ -234,7 +233,7 @@ mod tests {
let msg1 = dev1.begin(&pk2).unwrap();
println!("msg1 = {}", hex::encode(&msg1[..]));
println!("msg1 = {:?}", Initiation::try_from(&msg1[..]).unwrap());
println!("msg1 = {:?}", Initiation::parse(&msg1[..]).unwrap());
// process initiation and create response
@@ -244,7 +243,7 @@ mod tests {
let msg2 = msg2.unwrap();
println!("msg2 = {}", hex::encode(&msg2[..]));
println!("msg2 = {:?}", Response::try_from(&msg2[..]).unwrap());
println!("msg2 = {:?}", Response::parse(&msg2[..]).unwrap());
assert!(!ks_r.confirmed, "Responders key-pair is confirmed");

View File

@@ -1,8 +1,11 @@
use crate::types::*;
use hex;
use std::convert::TryFrom;
use std::fmt;
use std::mem;
use byteorder::LittleEndian;
use zerocopy::byteorder::U32;
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
use crate::types::*;
const SIZE_TAG: usize = 16;
const SIZE_X25519_POINT: usize = 32;
@@ -11,17 +14,11 @@ const SIZE_TIMESTAMP: usize = 12;
pub const TYPE_INITIATION: u8 = 1;
pub const TYPE_RESPONSE: u8 = 2;
/* Functions related to the packing / unpacking of
* the fixed-sized noise handshake messages.
*
* The unpacked types are unexposed implementation details.
*/
#[repr(C)]
#[derive(Copy, Clone)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct Initiation {
f_type: u32,
pub f_sender: u32,
f_type: U32<LittleEndian>,
pub f_sender: U32<LittleEndian>,
pub f_ephemeral: [u8; SIZE_X25519_POINT],
pub f_static: [u8; SIZE_X25519_POINT],
pub f_static_tag: [u8; SIZE_TAG],
@@ -29,63 +26,11 @@ pub struct Initiation {
pub f_timestamp_tag: [u8; SIZE_TAG],
}
impl TryFrom<&[u8]> for Initiation {
type Error = HandshakeError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
// check length of slice matches message
if value.len() != mem::size_of::<Self>() {
return Err(HandshakeError::InvalidMessageFormat);
}
// create owned copy
let mut owned = [0u8; mem::size_of::<Self>()];
let mut msg: Self;
owned.copy_from_slice(value);
// cast to Initiation
unsafe {
msg = mem::transmute::<[u8; mem::size_of::<Self>()], Self>(owned);
};
// correct endianness
msg.f_type = msg.f_type.to_le();
msg.f_sender = msg.f_sender.to_le();
// check type and reserved fields
if msg.f_type != (TYPE_INITIATION as u32) {
return Err(HandshakeError::InvalidMessageFormat);
}
Ok(msg)
}
}
impl Into<Vec<u8>> for Initiation {
fn into(self) -> Vec<u8> {
// correct endianness
let mut msg = self;
msg.f_type = msg.f_type.to_le();
msg.f_sender = msg.f_sender.to_le();
// cast to array
let array: [u8; mem::size_of::<Self>()];
unsafe { array = mem::transmute::<Self, [u8; mem::size_of::<Self>()]>(msg) };
array.to_vec()
}
}
impl Default for Initiation {
fn default() -> Self {
Self {
f_type: TYPE_INITIATION as u32,
f_sender: 0,
f_type: <U32<LittleEndian>>::new(TYPE_INITIATION as u32),
f_sender: <U32<LittleEndian>>::new(0),
f_ephemeral: [0u8; SIZE_X25519_POINT],
f_static: [0u8; SIZE_X25519_POINT],
f_static_tag: [0u8; SIZE_TAG],
@@ -95,12 +40,25 @@ impl Default for Initiation {
}
}
impl Initiation {
pub fn parse<B : ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
let msg: LayoutVerified<B, Self> =
LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
if msg.f_type.get() != (TYPE_INITIATION as u32) {
return Err(HandshakeError::InvalidMessageFormat);
}
Ok(msg)
}
}
impl fmt::Debug for Initiation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f,
"MessageInitiation {{ type = {}, sender = {}, ephemeral = {}, static = {}|{}, timestamp = {}|{} }}",
self.f_type,
self.f_sender,
self.f_type.get(),
self.f_sender.get(),
hex::encode(self.f_ephemeral),
hex::encode(self.f_static),
hex::encode(self.f_static_tag),
@@ -113,8 +71,8 @@ impl fmt::Debug for Initiation {
#[cfg(test)]
impl PartialEq for Initiation {
fn eq(&self, other: &Self) -> bool {
self.f_type == other.f_type
&& self.f_sender == other.f_sender
self.f_type.get() == other.f_type.get()
&& self.f_sender.get() == other.f_sender.get()
&& self.f_ephemeral[..] == other.f_ephemeral[..]
&& self.f_static[..] == other.f_static[..]
&& self.f_static_tag[..] == other.f_static_tag[..]
@@ -127,46 +85,21 @@ impl PartialEq for Initiation {
impl Eq for Initiation {}
#[repr(C)]
#[derive(Copy, Clone)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
pub struct Response {
f_type: u32,
pub f_sender: u32,
pub f_receiver: u32,
f_type: U32<LittleEndian>,
pub f_sender: U32<LittleEndian>,
pub f_receiver: U32<LittleEndian>,
pub f_ephemeral: [u8; SIZE_X25519_POINT],
pub f_empty_tag: [u8; SIZE_TAG],
}
impl TryFrom<&[u8]> for Response {
type Error = HandshakeError;
impl Response {
pub fn parse<B : ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
let msg: LayoutVerified<B, Self> =
LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
// check length of slice matches message
if value.len() != mem::size_of::<Self>() {
return Err(HandshakeError::InvalidMessageFormat);
}
// create owned copy
let mut owned = [0u8; mem::size_of::<Self>()];
let mut msg: Self;
owned.copy_from_slice(value);
// cast to MessageResponse
unsafe {
msg = mem::transmute::<[u8; mem::size_of::<Self>()], Self>(owned);
};
// correct endianness
msg.f_type = msg.f_type.to_le();
msg.f_sender = msg.f_sender.to_le();
msg.f_receiver = msg.f_receiver.to_le();
// check type and reserved fields
if msg.f_type != (TYPE_RESPONSE as u32) {
if msg.f_type.get() != (TYPE_RESPONSE as u32) {
return Err(HandshakeError::InvalidMessageFormat);
}
@@ -174,28 +107,12 @@ impl TryFrom<&[u8]> for Response {
}
}
impl Into<Vec<u8>> for Response {
fn into(self) -> Vec<u8> {
// correct endianness
let mut msg = self;
msg.f_type = msg.f_type.to_le();
msg.f_sender = msg.f_sender.to_le();
msg.f_receiver = msg.f_receiver.to_le();
// cast to array
let array: [u8; mem::size_of::<Self>()];
unsafe { array = mem::transmute::<Self, [u8; mem::size_of::<Self>()]>(msg) };
array.to_vec()
}
}
impl Default for Response {
fn default() -> Self {
Self {
f_type: TYPE_RESPONSE as u32,
f_sender: 0,
f_receiver: 0,
f_type: <U32<LittleEndian>>::new(TYPE_RESPONSE as u32),
f_sender: <U32<LittleEndian>>::ZERO,
f_receiver: <U32<LittleEndian>>::ZERO,
f_ephemeral: [0u8; SIZE_X25519_POINT],
f_empty_tag: [0u8; SIZE_TAG],
}
@@ -234,8 +151,8 @@ mod tests {
fn message_response_identity() {
let mut msg: Response = Default::default();
msg.f_sender = 146252;
msg.f_receiver = 554442;
msg.f_sender.set(146252);
msg.f_receiver.set(554442);
msg.f_ephemeral = [
0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
@@ -246,16 +163,16 @@ mod tests {
0x2f, 0xde,
];
let buf: Vec<u8> = msg.into();
let msg_p: Response = Response::try_from(&buf[..]).unwrap();
assert_eq!(msg, msg_p);
let buf: Vec<u8> = msg.as_bytes().to_vec();
let msg_p = Response::parse(&buf[..]).unwrap();
assert_eq!(msg, *msg_p.into_ref());
}
#[test]
fn message_initiate_identity() {
let mut msg: Initiation = Default::default();
msg.f_sender = 575757;
msg.f_sender.set(575757);
msg.f_ephemeral = [
0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
@@ -278,7 +195,8 @@ mod tests {
0x2f, 0xde,
];
let buf: Vec<u8> = msg.into();
assert_eq!(msg, Initiation::try_from(&buf[..]).unwrap());
let buf: Vec<u8> = msg.as_bytes().to_vec();
let msg_p = Initiation::parse(&buf[..]).unwrap();
assert_eq!(msg, *msg_p.into_ref());
}
}

View File

@@ -1,5 +1,3 @@
use std::convert::TryFrom;
// DH
use x25519_dalek::PublicKey;
use x25519_dalek::StaticSecret;
@@ -17,6 +15,8 @@ use rand::rngs::OsRng;
use generic_array::typenum::*;
use generic_array::GenericArray;
use zerocopy::AsBytes;
use crate::device::Device;
use crate::messages::{Initiation, Response};
use crate::peer::{Peer, State};
@@ -148,7 +148,7 @@ pub fn create_initiation<T: Copy>(
let hs = INITIAL_HS;
let hs = HASH!(&hs, peer.pk.as_bytes());
msg.f_sender = sender;
msg.f_sender.set(sender);
// (E_priv, E_pub) := DH-Generate()
@@ -214,7 +214,7 @@ pub fn create_initiation<T: Copy>(
// return message as vector
Ok(Initiation::into(msg))
Ok(msg.as_bytes().to_vec())
}
pub fn consume_initiation<'a, T: Copy>(
@@ -223,7 +223,7 @@ pub fn consume_initiation<'a, T: Copy>(
) -> Result<(&'a Peer<T>, TemporaryState), HandshakeError> {
// parse message
let msg = Initiation::try_from(msg)?;
let msg = Initiation::parse(msg)?;
// initialize state
@@ -288,7 +288,7 @@ pub fn consume_initiation<'a, T: Copy>(
// return state (to create response)
Ok((peer, (msg.f_sender, eph_r_pk, hs, ck)))
Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck)))
}
pub fn create_response<T: Copy>(
@@ -301,8 +301,8 @@ pub fn create_response<T: Copy>(
let (receiver, eph_r_pk, hs, ck) = state;
msg.f_sender = sender;
msg.f_receiver = receiver;
msg.f_sender.set(sender);
msg.f_receiver.set(receiver);
// (E_priv, E_pub) := DH-Generate()
@@ -361,7 +361,7 @@ pub fn create_response<T: Copy>(
Ok((
peer.identifier,
Some(Response::into(msg)),
Some(msg.as_bytes().to_vec()),
Some(KeyPair {
confirmed: false,
send: Key {
@@ -376,17 +376,15 @@ pub fn create_response<T: Copy>(
))
}
pub fn consume_response<T: Copy>(
device: &Device<T>,
msg: &[u8],
) -> Result<Output<T>, HandshakeError> {
pub fn consume_response<T: Copy>(device: &Device<T>, msg: &[u8]) -> Result<Output<T>, HandshakeError> {
// parse message
let msg = Response::try_from(msg)?;
let msg = Response::parse(msg)?;
// retrieve peer and associated state
let peer = device.lookup_id(msg.f_receiver)?;
let peer = device.lookup_id(msg.f_receiver.get())?;
let (hs, ck, sender, eph_sk) = match peer.get_state() {
State::Reset => Err(HandshakeError::InvalidState),
State::InitiationSent {
@@ -448,7 +446,7 @@ pub fn consume_response<T: Copy>(
key: key_send.into(),
},
recv: Key {
id: msg.f_sender,
id: msg.f_sender.get(),
key: key_recv.into(),
},
}),