refactor: cleaned up unwrap_file_keys

This commit is contained in:
2024-02-26 10:31:25 +01:00
parent 825f837e0e
commit 412d9ac43a

View File

@@ -1,15 +1,15 @@
use age_core::format::{FileKey, Stanza}; use age_core::format::{FileKey, Stanza};
use age_core::primitives::{aead_encrypt, aead_decrypt}; use age_core::primitives::{aead_decrypt, aead_encrypt};
use age_core::secrecy::ExposeSecret; use age_core::secrecy::ExposeSecret;
use age_plugin::{ use age_plugin::{
identity::{self, IdentityPluginV1}, identity::{self, IdentityPluginV1},
recipient::{self, RecipientPluginV1}, recipient::{self, RecipientPluginV1},
Callbacks, run_state_machine, run_state_machine, Callbacks,
}; };
use base64::prelude::*;
use bech32::ToBase32; use bech32::ToBase32;
use clap::Parser; use clap::Parser;
use xwing_kem::{XwingPublicKey, XwingSecretKey, XwingCiphertext}; use xwing_kem::{XwingCiphertext, XwingPublicKey, XwingSecretKey};
use base64::prelude::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::io; use std::io;
@@ -20,7 +20,7 @@ const STANZA_TAG: &str = "xwing";
#[derive(Default)] #[derive(Default)]
struct RecipientPlugin { struct RecipientPlugin {
recipients: Vec<XwingPublicKey> recipients: Vec<XwingPublicKey>,
} }
impl RecipientPluginV1 for RecipientPlugin { impl RecipientPluginV1 for RecipientPlugin {
@@ -32,7 +32,12 @@ impl RecipientPluginV1 for RecipientPlugin {
) -> Result<(), recipient::Error> { ) -> Result<(), recipient::Error> {
let bytes = match bytes.try_into() { let bytes = match bytes.try_into() {
Ok(x) => x, Ok(x) => x,
_ => return Err(recipient::Error::Recipient { index, message: "Invalid recipient".to_owned() }) _ => {
return Err(recipient::Error::Recipient {
index,
message: "Invalid recipient".to_owned(),
})
}
}; };
self.recipients.push(XwingPublicKey::from(bytes)); self.recipients.push(XwingPublicKey::from(bytes));
Ok(()) Ok(())
@@ -42,7 +47,7 @@ impl RecipientPluginV1 for RecipientPlugin {
&mut self, &mut self,
_index: usize, _index: usize,
_plugin_name: &str, _plugin_name: &str,
_bytes: &[u8] _bytes: &[u8],
) -> Result<(), recipient::Error> { ) -> Result<(), recipient::Error> {
unimplemented!() unimplemented!()
} }
@@ -60,7 +65,11 @@ impl RecipientPluginV1 for RecipientPlugin {
.map(|recipient| { .map(|recipient| {
let (ss, ct) = recipient.encapsulate(); let (ss, ct) = recipient.encapsulate();
let wrapped_key = aead_encrypt(&ss.to_bytes(), file_key.expose_secret()); let wrapped_key = aead_encrypt(&ss.to_bytes(), file_key.expose_secret());
Stanza { tag: STANZA_TAG.to_string(), args: vec![BASE64_STANDARD.encode(ct.to_bytes())], body: wrapped_key } Stanza {
tag: STANZA_TAG.to_string(),
args: vec![BASE64_STANDARD.encode(ct.to_bytes())],
body: wrapped_key,
}
}) })
.collect() .collect()
}) })
@@ -70,7 +79,7 @@ impl RecipientPluginV1 for RecipientPlugin {
#[derive(Default)] #[derive(Default)]
struct IdentityPlugin { struct IdentityPlugin {
identities: Vec<XwingSecretKey> identities: Vec<XwingSecretKey>,
} }
impl IdentityPluginV1 for IdentityPlugin { impl IdentityPluginV1 for IdentityPlugin {
@@ -78,11 +87,16 @@ impl IdentityPluginV1 for IdentityPlugin {
&mut self, &mut self,
index: usize, index: usize,
_plugin_name: &str, _plugin_name: &str,
bytes: &[u8] bytes: &[u8],
) -> Result<(), identity::Error> { ) -> Result<(), identity::Error> {
let bytes = match bytes.try_into() { let bytes = match bytes.try_into() {
Ok(x) => x, Ok(x) => x,
_ => return Err(identity::Error::Identity { index, message: "Invalid identity".to_owned() }) _ => {
return Err(identity::Error::Identity {
index,
message: "Invalid identity".to_owned(),
})
}
}; };
self.identities.push(XwingSecretKey::from(bytes)); self.identities.push(XwingSecretKey::from(bytes));
Ok(()) Ok(())
@@ -93,42 +107,51 @@ impl IdentityPluginV1 for IdentityPlugin {
files: Vec<Vec<Stanza>>, files: Vec<Vec<Stanza>>,
mut _callbacks: impl Callbacks<identity::Error>, mut _callbacks: impl Callbacks<identity::Error>,
) -> io::Result<HashMap<usize, Result<FileKey, Vec<identity::Error>>>> { ) -> io::Result<HashMap<usize, Result<FileKey, Vec<identity::Error>>>> {
Ok(files.iter().enumerate().map(|(file_index, stanzas)| { Ok(files
let result: Vec<Result<[u8; 16], identity::Error>> = stanzas.into_iter().enumerate().map(|(stanza_index, stanza)| { .into_iter()
let decryptions: Vec<Option<[u8; 16]>> = self.identities.iter().map(|identity| { .map(|file| try_decrypt_file(&self.identities, file))
let ct = match BASE64_STANDARD.decode(&stanza.args.get(0).unwrap_or(&"".to_string())).unwrap_or_default().try_into() { .enumerate()
Ok(x) => x, .map(|(file_index, file_key)| {
_ => return None file_key.ok_or(vec![identity::Error::Stanza {
}; file_index,
let ss = identity.decapsulate(XwingCiphertext::from(ct)); stanza_index: 1,
message: "Invalid stanzas".to_string(),
match aead_decrypt(&ss.to_bytes(), 16, &stanza.body) { }])
Ok(file_key) => { })
let file_key: [u8; 16] = file_key.try_into().expect("This should never fail"); .enumerate()
Some(file_key) .collect())
},
_ => None
}
}).collect();
let file_key = decryptions.into_iter().filter(|file_key| file_key.is_some()).map(|file_key| file_key.expect("This should never fail")).next();
if let Some(file_key) = file_key {
return Ok(file_key)
}
Err(identity::Error::Stanza { file_index, stanza_index, message: "Invalid stanza".to_owned() })
}).collect();
let file_key = result.iter().filter(|file_key| file_key.is_ok()).map(|file_key| file_key.as_ref().ok().expect("This should never fail")).next();
if let Some(file_key) = file_key {
return (file_index, Ok(FileKey::from(file_key.to_owned())));
}
(file_index, Err(result.into_iter().map(|identity_error| identity_error.err().expect("This should never fail")).collect()))
}).collect())
} }
} }
fn try_decrypt_file(keys: &Vec<XwingSecretKey>, stanzas: Vec<Stanza>) -> Option<FileKey> {
stanzas
.iter()
.map(|stanza| try_decrypt(keys, stanza))
.filter(|file_key| file_key.is_some())
.map(|file_key| file_key.expect("This should never fail"))
.next()
}
fn try_decrypt(keys: &Vec<XwingSecretKey>, stanza: &Stanza) -> Option<FileKey> {
let ct = BASE64_STANDARD
.decode(&stanza.args.get(0).unwrap_or(&"".to_string()))
.ok()?;
let ct: [u8; 1120] = ct.try_into().ok()?;
let ct = XwingCiphertext::from(ct);
for key in keys {
let ss = key.decapsulate(ct);
let file_key = aead_decrypt(&ss.to_bytes(), 16, &stanza.body);
if let Ok(file_key) = file_key {
let file_key: [u8; 16] = file_key.try_into().expect("This should never fail");
return Some(FileKey::from(file_key));
}
}
None
}
#[derive(Debug, Parser)] #[derive(Debug, Parser)]
struct PluginOptions { struct PluginOptions {
#[arg(help = "run the given age plugin state machine", long)] #[arg(help = "run the given age plugin state machine", long)]
@@ -157,11 +180,26 @@ fn main() -> io::Result<()> {
); );
println!( println!(
"# public key: {}", "# public key: {}",
bech32::encode(RECIPIENT_PREFIX, pk.to_vec().to_base32(), bech32::Variant::Bech32).unwrap().to_lowercase().as_str() bech32::encode(
RECIPIENT_PREFIX,
pk.to_vec().to_base32(),
bech32::Variant::Bech32
)
.unwrap()
.to_lowercase()
.as_str()
);
println!(
"{}",
bech32::encode(
IDENTITY_PREFIX,
sk.to_vec().to_base32(),
bech32::Variant::Bech32
)
.unwrap()
.to_ascii_uppercase()
.as_str()
); );
println!("{}", bech32::encode(IDENTITY_PREFIX, sk.to_vec().to_base32(), bech32::Variant::Bech32).unwrap().to_ascii_uppercase().as_str());
Ok(()) Ok(())
} }