use futures::StreamExt; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; use rsa::{ pkcs8::{der::zeroize::Zeroizing, EncodePrivateKey, EncodePublicKey, LineEnding}, RsaPrivateKey, RsaPublicKey, }; use serde::{Deserialize, Serialize}; use std::path::Path; use thiserror::Error; use tokio::{fs, io}; pub mod jwt_numeric_date; const RSA_BITS: usize = 4096; const KEYS_DIR: &str = "./keys"; const CUR_PRIV: &str = "cur_priv.pem"; const CUR_PUB: &str = "cur_pub.pem"; const PRE_PRIV: &str = "pre_priv.pem"; const PRE_PUB: &str = "pre_pub.pem"; #[derive(Clone)] pub struct JWTSecretManager { current: RSAKeyPair, previous: Option, } #[derive(Debug, Error)] pub enum JWTSecretManagerError { #[error("Failed to generatie new key pair: {0}")] KeyGeneration(rsa::Error), #[error("Failed to save pem files to file system: {0}")] KeySave(io::Error), #[error("Failed to create the directory to store keys: {0}")] KeysDir(io::Error), #[error("Failed to read existing keys from file system: {0}")] KeysRead(io::Error), #[error("Failed to encode a new JWT: {0}")] EncodeFailed(jsonwebtoken::errors::Error), } #[derive(Debug, PartialEq)] pub enum ValidationOutcome where T: Serialize, for<'de> T: Deserialize<'de>, { Ok(T), Outdated(T, String), Error(jsonwebtoken::errors::Error), Unauthorized, } impl From for ValidationOutcome where T: Serialize, for<'de> T: Deserialize<'de>, { fn from(value: jsonwebtoken::errors::Error) -> Self { ValidationOutcome::Error(value) } } impl JWTSecretManager { pub async fn init() -> Result { // Check if we have any keys let keys_dir = Path::new(KEYS_DIR); if !keys_dir.exists() { fs::create_dir(KEYS_DIR) .await .map_err(JWTSecretManagerError::KeysDir)?; } let current = if !keys_dir.join(CUR_PUB).exists() || !keys_dir.join(CUR_PRIV).exists() { let keys = RSAKeyPair::generate_new() .await .map_err(JWTSecretManagerError::KeyGeneration)?; keys.save_pem_files(keys_dir, CUR_PRIV, CUR_PUB) .await .map_err(JWTSecretManagerError::KeySave)?; keys } else { let public = fs::read_to_string(keys_dir.join(CUR_PUB)) .await .map_err(JWTSecretManagerError::KeysRead)?; let private = fs::read_to_string(keys_dir.join(CUR_PRIV)) .await .map_err(JWTSecretManagerError::KeysRead)?; RSAKeyPair { private: Zeroizing::new(private), public, } }; let previous = if !keys_dir.join(PRE_PUB).exists() || !keys_dir.join(PRE_PRIV).exists() { None } else { let public = fs::read_to_string(keys_dir.join(PRE_PUB)) .await .map_err(JWTSecretManagerError::KeysRead)?; let private = fs::read_to_string(keys_dir.join(PRE_PRIV)) .await .map_err(JWTSecretManagerError::KeysRead)?; Some(RSAKeyPair { private: Zeroizing::new(private), public, }) }; Ok(JWTSecretManager { current, previous }) } pub async fn rotate(&mut self) -> Result<(), JWTSecretManagerError> { self.previous = Some(self.current.clone()); self.current .save_pem_files(KEYS_DIR, PRE_PRIV, PRE_PUB) .await .map_err(JWTSecretManagerError::KeySave)?; self.current = RSAKeyPair::generate_new() .await .map_err(JWTSecretManagerError::KeyGeneration)?; self.current .save_pem_files(KEYS_DIR, CUR_PRIV, CUR_PUB) .await .map_err(JWTSecretManagerError::KeySave)?; Ok(()) } pub async fn decode(&self, token: &str) -> ValidationOutcome where T: Serialize + Clone, for<'de> T: Deserialize<'de>, { if let Ok(claim) = self.current.verify_jwt(token).await { return ValidationOutcome::Ok(claim); } match &self.previous { Some(k) => { if let Ok(claim) = k.verify_jwt::(token).await { let new_token = match self.current.encode(&claim).await { Ok(t) => t, Err(e) => return ValidationOutcome::Error(e), }; return ValidationOutcome::Outdated(claim, new_token); } } None => (), } ValidationOutcome::Unauthorized } pub async fn encode_new(&self, claim: &T) -> Result where T: Serialize + Clone, for<'de> T: Deserialize<'de>, { self.current .encode(&claim) .await .map_err(JWTSecretManagerError::EncodeFailed) } } #[derive(Clone)] pub struct RSAKeyPair { private: Zeroizing, public: String, } impl RSAKeyPair { async fn generate_new() -> Result { let mut rng = rand::thread_rng(); let private_key = RsaPrivateKey::new(&mut rng, RSA_BITS).expect("failed to generate a key"); let public_key = RsaPublicKey::from(&private_key); let private = private_key.to_pkcs8_pem(LineEnding::LF)?; let public = public_key.to_public_key_pem(LineEnding::LF).unwrap(); // This is infaillible? Ok(RSAKeyPair { private, public }) } async fn save_pem_files( &self, path: impl AsRef, private_name: &str, public_name: &str, ) -> io::Result<()> { // There's probably a better looking way to do this let futures = vec![ fs::write(path.as_ref().join(public_name), &self.public), fs::write(path.as_ref().join(private_name), &self.private), ]; let stream = futures::stream::iter(futures).buffer_unordered(2); let results = stream.collect::>().await; for res in results { res?; } Ok(()) } async fn verify_jwt(&self, token: &str) -> Result where for<'de> T: Deserialize<'de>, { match decode::( token, &DecodingKey::from_rsa_pem(self.public.as_bytes())?, &Validation::new(Algorithm::RS256), ) { Ok(t) => Ok(t.claims), Err(e) => Err(e), } } async fn encode(&self, claim: &T) -> Result where T: Serialize, { let secret = &EncodingKey::from_rsa_pem(self.private.as_bytes())?; encode(&Header::new(Algorithm::RS256), claim, secret) } } #[cfg(test)] mod test { use time::Duration; use cookie::time::OffsetDateTime; use rand::RngCore; use serde::{Deserialize, Serialize}; use crate::crypto::RSAKeyPair; use crate::crypto::ValidationOutcome; use super::jwt_numeric_date; use super::JWTSecretManager; #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] struct Claim { data: Vec, #[serde(with = "jwt_numeric_date")] exp: OffsetDateTime, } fn create_random_claim_data() -> Vec { let mut bytes = [0u8; 16]; rand::thread_rng().fill_bytes(&mut bytes); bytes.into() } #[tokio::test] async fn generate_encode_and_decode_jwt_using_rsa() { let rsa_keys = RSAKeyPair::generate_new() .await .expect("It should be possible to generate a new RSA key pair"); let claim = Claim { data: create_random_claim_data(), exp: OffsetDateTime::now_utc() + Duration::days(1), }; let token = rsa_keys .encode(&claim) .await .expect("It should be possible to encode a claim"); let decoded_claim = rsa_keys .verify_jwt::(&token) .await .expect("It should be possible to verify a token"); assert_eq!(decoded_claim.data, claim.data); } #[tokio::test] async fn jwt_secret_manager_can_encode_and_decode() { let jwt_secret_manager = JWTSecretManager::init() .await .expect("JWTSecretManager should be able to init"); let claim = Claim { data: create_random_claim_data(), exp: OffsetDateTime::now_utc() + Duration::days(1), }; let token = jwt_secret_manager .encode_new(&claim) .await .expect("It should be possible to encode a claim into a JWT"); match jwt_secret_manager.decode::(&token).await { ValidationOutcome::Ok(c) => assert_eq!(c.data, claim.data), _ => panic!(), } } #[tokio::test] async fn jwt_secret_manager_can_rotate() { let mut jwt_secret_manager = JWTSecretManager::init() .await .expect("JWTSecretManager should be able to init"); jwt_secret_manager .rotate() .await .expect("JWTSecretManager should be able to rotate rsa keys"); } }