diff --git a/src/main.rs b/src/main.rs index ec27f94..bc3a8b3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,15 @@ use age_core::format::{FileKey, Stanza}; -use age_core::primitives::{aead_encrypt, aead_decrypt}; +use age_core::primitives::{aead_decrypt, aead_encrypt}; use age_core::secrecy::ExposeSecret; use age_plugin::{ identity::{self, IdentityPluginV1}, recipient::{self, RecipientPluginV1}, - Callbacks, run_state_machine, + run_state_machine, Callbacks, }; +use base64::prelude::*; use bech32::ToBase32; use clap::Parser; -use xwing_kem::{XwingPublicKey, XwingSecretKey, XwingCiphertext}; -use base64::prelude::*; +use xwing_kem::{XwingCiphertext, XwingPublicKey, XwingSecretKey}; use std::collections::HashMap; use std::io; @@ -20,7 +20,7 @@ const STANZA_TAG: &str = "xwing"; #[derive(Default)] struct RecipientPlugin { - recipients: Vec + recipients: Vec, } impl RecipientPluginV1 for RecipientPlugin { @@ -32,7 +32,12 @@ impl RecipientPluginV1 for RecipientPlugin { ) -> Result<(), recipient::Error> { let bytes = match bytes.try_into() { Ok(x) => x, - _ => return Err(recipient::Error::Recipient { index, message: "Invalid recipient".to_owned() }) + _ => { + return Err(recipient::Error::Recipient { + index, + message: "Invalid recipient".to_owned(), + }) + } }; self.recipients.push(XwingPublicKey::from(bytes)); Ok(()) @@ -42,7 +47,7 @@ impl RecipientPluginV1 for RecipientPlugin { &mut self, _index: usize, _plugin_name: &str, - _bytes: &[u8] + _bytes: &[u8], ) -> Result<(), recipient::Error> { unimplemented!() } @@ -60,7 +65,11 @@ impl RecipientPluginV1 for RecipientPlugin { .map(|recipient| { let (ss, ct) = recipient.encapsulate(); let wrapped_key = aead_encrypt(&ss.to_bytes(), file_key.expose_secret()); - Stanza { tag: STANZA_TAG.to_string(), args: vec![BASE64_STANDARD.encode(ct.to_bytes())], body: wrapped_key } + Stanza { + tag: STANZA_TAG.to_string(), + args: vec![BASE64_STANDARD.encode(ct.to_bytes())], + body: wrapped_key, + } }) .collect() }) @@ -70,7 +79,7 @@ impl RecipientPluginV1 for RecipientPlugin { #[derive(Default)] struct IdentityPlugin { - identities: Vec + identities: Vec, } impl IdentityPluginV1 for IdentityPlugin { @@ -78,11 +87,16 @@ impl IdentityPluginV1 for IdentityPlugin { &mut self, index: usize, _plugin_name: &str, - bytes: &[u8] + bytes: &[u8], ) -> Result<(), identity::Error> { let bytes = match bytes.try_into() { Ok(x) => x, - _ => return Err(identity::Error::Identity { index, message: "Invalid identity".to_owned() }) + _ => { + return Err(identity::Error::Identity { + index, + message: "Invalid identity".to_owned(), + }) + } }; self.identities.push(XwingSecretKey::from(bytes)); Ok(()) @@ -93,42 +107,51 @@ impl IdentityPluginV1 for IdentityPlugin { files: Vec>, mut _callbacks: impl Callbacks, ) -> io::Result>>> { - Ok(files.iter().enumerate().map(|(file_index, stanzas)| { - let result: Vec> = stanzas.into_iter().enumerate().map(|(stanza_index, stanza)| { - let decryptions: Vec> = self.identities.iter().map(|identity| { - let ct = match BASE64_STANDARD.decode(&stanza.args.get(0).unwrap_or(&"".to_string())).unwrap_or_default().try_into() { - Ok(x) => x, - _ => return None - }; - let ss = identity.decapsulate(XwingCiphertext::from(ct)); - - match aead_decrypt(&ss.to_bytes(), 16, &stanza.body) { - Ok(file_key) => { - let file_key: [u8; 16] = file_key.try_into().expect("This should never fail"); - Some(file_key) - }, - _ => None - } - }).collect(); - - let file_key = decryptions.into_iter().filter(|file_key| file_key.is_some()).map(|file_key| file_key.expect("This should never fail")).next(); - if let Some(file_key) = file_key { - return Ok(file_key) - } - - Err(identity::Error::Stanza { file_index, stanza_index, message: "Invalid stanza".to_owned() }) - }).collect(); - - let file_key = result.iter().filter(|file_key| file_key.is_ok()).map(|file_key| file_key.as_ref().ok().expect("This should never fail")).next(); - if let Some(file_key) = file_key { - return (file_index, Ok(FileKey::from(file_key.to_owned()))); - } - - (file_index, Err(result.into_iter().map(|identity_error| identity_error.err().expect("This should never fail")).collect())) - }).collect()) + Ok(files + .into_iter() + .map(|file| try_decrypt_file(&self.identities, file)) + .enumerate() + .map(|(file_index, file_key)| { + file_key.ok_or(vec![identity::Error::Stanza { + file_index, + stanza_index: 1, + message: "Invalid stanzas".to_string(), + }]) + }) + .enumerate() + .collect()) } } +fn try_decrypt_file(keys: &Vec, stanzas: Vec) -> Option { + stanzas + .iter() + .map(|stanza| try_decrypt(keys, stanza)) + .filter(|file_key| file_key.is_some()) + .map(|file_key| file_key.expect("This should never fail")) + .next() +} + +fn try_decrypt(keys: &Vec, stanza: &Stanza) -> Option { + let ct = BASE64_STANDARD + .decode(&stanza.args.get(0).unwrap_or(&"".to_string())) + .ok()?; + let ct: [u8; 1120] = ct.try_into().ok()?; + let ct = XwingCiphertext::from(ct); + + for key in keys { + let ss = key.decapsulate(ct); + let file_key = aead_decrypt(&ss.to_bytes(), 16, &stanza.body); + + if let Ok(file_key) = file_key { + let file_key: [u8; 16] = file_key.try_into().expect("This should never fail"); + return Some(FileKey::from(file_key)); + } + } + + None +} + #[derive(Debug, Parser)] struct PluginOptions { #[arg(help = "run the given age plugin state machine", long)] @@ -157,11 +180,26 @@ fn main() -> io::Result<()> { ); println!( "# public key: {}", - bech32::encode(RECIPIENT_PREFIX, pk.to_vec().to_base32(), bech32::Variant::Bech32).unwrap().to_lowercase().as_str() - + bech32::encode( + RECIPIENT_PREFIX, + pk.to_vec().to_base32(), + bech32::Variant::Bech32 + ) + .unwrap() + .to_lowercase() + .as_str() + ); + println!( + "{}", + bech32::encode( + IDENTITY_PREFIX, + sk.to_vec().to_base32(), + bech32::Variant::Bech32 + ) + .unwrap() + .to_ascii_uppercase() + .as_str() ); - println!("{}", bech32::encode(IDENTITY_PREFIX, sk.to_vec().to_base32(), bech32::Variant::Bech32).unwrap().to_ascii_uppercase().as_str()); - Ok(()) }