You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

323 lines
9.3 KiB

use futures::StreamExt;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use rsa::{
pkcs8::{der::zeroize::Zeroizing, EncodePrivateKey, EncodePublicKey, LineEnding},
RsaPrivateKey, RsaPublicKey,
};
use serde::{Deserialize, Serialize};
use std::path::Path;
use thiserror::Error;
use tokio::{fs, io};
pub mod jwt_numeric_date;
/// Modulus size, in bits, of every RSA key pair generated by this module.
const RSA_BITS: usize = 4_096;
/// Directory holding the key PEM files; relative paths resolve against the
/// process's working directory.
const KEYS_DIR: &str = "./keys";
/// File name of the current key pair's private-key PEM.
const CUR_PRIV: &str = "cur_priv.pem";
/// File name of the current key pair's public-key PEM.
const CUR_PUB: &str = "cur_pub.pem";
/// File name of the previous (pre-rotation) key pair's private-key PEM.
const PRE_PRIV: &str = "pre_priv.pem";
/// File name of the previous (pre-rotation) key pair's public-key PEM.
const PRE_PUB: &str = "pre_pub.pem";
/// Owns the RSA key pairs used to sign and verify JWTs.
///
/// New tokens are always signed with `current`. `previous` is kept so that
/// tokens signed before the most recent key rotation can still be verified.
#[derive(Clone)]
pub struct JWTSecretManager {
    /// Key pair used to sign new tokens; it is also tried first when verifying.
    current: RSAKeyPair,
    /// Key pair that was current before the last rotation, if one is known.
    previous: Option<RSAKeyPair>,
}
#[derive(Debug, Error)]
pub enum JWTSecretManagerError {
#[error("Failed to generatie new key pair: {0}")]
KeyGeneration(rsa::Error),
#[error("Failed to save pem files to file system: {0}")]
KeySave(io::Error),
#[error("Failed to create the directory to store keys: {0}")]
KeysDir(io::Error),
#[error("Failed to read existing keys from file system: {0}")]
KeysRead(io::Error),
#[error("Failed to encode a new JWT: {0}")]
EncodeFailed(jsonwebtoken::errors::Error),
}
#[derive(Debug, PartialEq)]
pub enum ValidationOutcome<T>
where
T: Serialize,
for<'de> T: Deserialize<'de>,
{
Ok(T),
Outdated(T, String),
Error(jsonwebtoken::errors::Error),
Unauthorized,
}
impl<T> From<jsonwebtoken::errors::Error> for ValidationOutcome<T>
where
T: Serialize,
for<'de> T: Deserialize<'de>,
{
fn from(value: jsonwebtoken::errors::Error) -> Self {
ValidationOutcome::Error(value)
}
}
impl JWTSecretManager {
pub async fn init() -> Result<Self, JWTSecretManagerError> {
// Check if we have any keys
let keys_dir = Path::new(KEYS_DIR);
if !keys_dir.exists() {
fs::create_dir(KEYS_DIR)
.await
.map_err(JWTSecretManagerError::KeysDir)?;
}
let current = if !keys_dir.join(CUR_PUB).exists() || !keys_dir.join(CUR_PRIV).exists() {
let keys = RSAKeyPair::generate_new()
.await
.map_err(JWTSecretManagerError::KeyGeneration)?;
keys.save_pem_files(keys_dir, CUR_PRIV, CUR_PUB)
.await
.map_err(JWTSecretManagerError::KeySave)?;
keys
} else {
let public = fs::read_to_string(keys_dir.join(CUR_PUB))
.await
.map_err(JWTSecretManagerError::KeysRead)?;
let private = fs::read_to_string(keys_dir.join(CUR_PRIV))
.await
.map_err(JWTSecretManagerError::KeysRead)?;
RSAKeyPair {
private: Zeroizing::new(private),
public,
}
};
let previous = if !keys_dir.join(PRE_PUB).exists() || !keys_dir.join(PRE_PRIV).exists() {
None
} else {
let public = fs::read_to_string(keys_dir.join(PRE_PUB))
.await
.map_err(JWTSecretManagerError::KeysRead)?;
let private = fs::read_to_string(keys_dir.join(PRE_PRIV))
.await
.map_err(JWTSecretManagerError::KeysRead)?;
Some(RSAKeyPair {
private: Zeroizing::new(private),
public,
})
};
Ok(JWTSecretManager { current, previous })
}
pub async fn rotate(&mut self) -> Result<(), JWTSecretManagerError> {
self.previous = Some(self.current.clone());
self.current
.save_pem_files(KEYS_DIR, PRE_PRIV, PRE_PUB)
.await
.map_err(JWTSecretManagerError::KeySave)?;
self.current = RSAKeyPair::generate_new()
.await
.map_err(JWTSecretManagerError::KeyGeneration)?;
self.current
.save_pem_files(KEYS_DIR, CUR_PRIV, CUR_PUB)
.await
.map_err(JWTSecretManagerError::KeySave)?;
Ok(())
}
pub async fn decode<T>(&self, token: &str) -> ValidationOutcome<T>
where
T: Serialize + Clone,
for<'de> T: Deserialize<'de>,
{
if let Ok(claim) = self.current.verify_jwt(token).await {
return ValidationOutcome::Ok(claim);
}
match &self.previous {
Some(k) => {
if let Ok(claim) = k.verify_jwt::<T>(token).await {
let new_token = match self.current.encode(&claim).await {
Ok(t) => t,
Err(e) => return ValidationOutcome::Error(e),
};
return ValidationOutcome::Outdated(claim, new_token);
}
}
None => (),
}
ValidationOutcome::Unauthorized
}
pub async fn encode_new<T>(&self, claim: &T) -> Result<String, JWTSecretManagerError>
where
T: Serialize + Clone,
for<'de> T: Deserialize<'de>,
{
self.current
.encode(&claim)
.await
.map_err(JWTSecretManagerError::EncodeFailed)
}
}
/// An RSA key pair held as PEM-encoded strings.
///
/// No `Debug` impl is derived, which keeps the private key out of `{:?}` output.
#[derive(Clone)]
pub struct RSAKeyPair {
    /// Private-key PEM, wrapped in `Zeroizing` so its buffer is wiped on drop.
    private: Zeroizing<String>,
    /// Public-key PEM.
    public: String,
}
impl RSAKeyPair {
    /// Generates a new `RSA_BITS`-bit RSA key pair.
    ///
    /// The private key is PEM-encoded as PKCS#8 and the public key as
    /// SubjectPublicKeyInfo, both with LF line endings.
    ///
    /// # Errors
    /// Returns the `rsa` error if key generation or private-key encoding fails.
    // NOTE(review): 4096-bit RSA generation is CPU-bound and runs inline in this
    // `async fn` without yielding, so it blocks the executor thread while it
    // runs. Consider `tokio::task::spawn_blocking`.
    async fn generate_new() -> Result<RSAKeyPair, rsa::Error> {
        let mut rng = rand::thread_rng();
        // Propagate instead of panicking: callers already map this error to
        // `JWTSecretManagerError::KeyGeneration`.
        let private_key = RsaPrivateKey::new(&mut rng, RSA_BITS)?;
        let public_key = RsaPublicKey::from(&private_key);
        let private = private_key.to_pkcs8_pem(LineEnding::LF)?;
        // The error type here (spki) does not convert into `rsa::Error`.
        // PEM-encoding a freshly derived public key is not expected to fail.
        let public = public_key
            .to_public_key_pem(LineEnding::LF)
            .expect("PEM-encoding a freshly generated RSA public key should not fail");
        Ok(RSAKeyPair { private, public })
    }

    /// Writes the public and private PEM strings to `path/public_name` and
    /// `path/private_name`, running both writes concurrently.
    ///
    /// Both writes are driven to completion. If either fails, the error that
    /// completed first is returned.
    // NOTE(review): the private key file is created with default permissions
    // (typically world-readable under a 022 umask); consider restricting it,
    // e.g. mode 0o600 on Unix.
    async fn save_pem_files(
        &self,
        path: impl AsRef<Path>,
        private_name: &str,
        public_name: &str,
    ) -> io::Result<()> {
        let dir = path.as_ref();
        let writes = vec![
            fs::write(dir.join(public_name), self.public.as_bytes()),
            fs::write(dir.join(private_name), self.private.as_bytes()),
        ];
        futures::stream::iter(writes)
            .buffer_unordered(2)
            .collect::<Vec<io::Result<()>>>()
            .await
            .into_iter()
            // Short-circuits on the first error in completion order.
            .collect()
    }

    /// Verifies an RS256-signed `token` against this pair's public key and
    /// returns its claims.
    ///
    /// Checks are those of `Validation::new(Algorithm::RS256)`, including
    /// jsonwebtoken's default `exp` validation.
    ///
    /// # Errors
    /// Returns the `jsonwebtoken` error if the public key PEM cannot be parsed
    /// or the token fails verification or decoding.
    async fn verify_jwt<T>(&self, token: &str) -> Result<T, jsonwebtoken::errors::Error>
    where
        for<'de> T: Deserialize<'de>,
    {
        let key = DecodingKey::from_rsa_pem(self.public.as_bytes())?;
        decode::<T>(token, &key, &Validation::new(Algorithm::RS256)).map(|data| data.claims)
    }

    /// Signs `claim` as an RS256 JWT using this pair's private key.
    ///
    /// # Errors
    /// Returns the `jsonwebtoken` error if the private key PEM cannot be parsed
    /// or the claim cannot be encoded.
    async fn encode<T>(&self, claim: &T) -> Result<String, jsonwebtoken::errors::Error>
    where
        T: Serialize,
    {
        let key = EncodingKey::from_rsa_pem(self.private.as_bytes())?;
        encode(&Header::new(Algorithm::RS256), claim, &key)
    }
}
#[cfg(test)]
mod test {
    use cookie::time::OffsetDateTime;
    use rand::RngCore;
    use serde::{Deserialize, Serialize};
    use time::Duration;

    use super::{jwt_numeric_date, JWTSecretManager, RSAKeyPair, ValidationOutcome};

    // NOTE(review): the `JWTSecretManager` tests read and write the shared on-disk
    // `./keys` directory and may run in parallel. A rotation in one test can race
    // with `init` in another. Consider making the key directory configurable.

    // Round trips compare only `data`. `exp` goes through `jwt_numeric_date`,
    // which presumably drops sub-second precision.

    /// Minimal claim set: a random payload plus an expiry.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
    struct Claim {
        data: Vec<u8>,
        #[serde(with = "jwt_numeric_date")]
        exp: OffsetDateTime,
    }

    /// Returns 16 random bytes to use as a claim payload.
    fn create_random_claim_data() -> Vec<u8> {
        let mut bytes = [0u8; 16];
        rand::thread_rng().fill_bytes(&mut bytes);
        bytes.into()
    }

    /// Builds a claim with random data that expires one day from now.
    fn new_claim() -> Claim {
        Claim {
            data: create_random_claim_data(),
            exp: OffsetDateTime::now_utc() + Duration::days(1),
        }
    }

    #[tokio::test]
    async fn generate_encode_and_decode_jwt_using_rsa() {
        let rsa_keys = RSAKeyPair::generate_new()
            .await
            .expect("It should be possible to generate a new RSA key pair");
        let claim = new_claim();
        let token = rsa_keys
            .encode(&claim)
            .await
            .expect("It should be possible to encode a claim");
        let decoded_claim = rsa_keys
            .verify_jwt::<Claim>(&token)
            .await
            .expect("It should be possible to verify a token");
        assert_eq!(decoded_claim.data, claim.data);
    }

    #[tokio::test]
    async fn jwt_secret_manager_can_encode_and_decode() {
        let jwt_secret_manager = JWTSecretManager::init()
            .await
            .expect("JWTSecretManager should be able to init");
        let claim = new_claim();
        let token = jwt_secret_manager
            .encode_new(&claim)
            .await
            .expect("It should be possible to encode a claim into a JWT");
        match jwt_secret_manager.decode::<Claim>(&token).await {
            ValidationOutcome::Ok(c) => assert_eq!(c.data, claim.data),
            other => panic!("expected a valid token, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn jwt_secret_manager_can_rotate() {
        let mut jwt_secret_manager = JWTSecretManager::init()
            .await
            .expect("JWTSecretManager should be able to init");
        let claim = new_claim();
        let old_token = jwt_secret_manager
            .encode_new(&claim)
            .await
            .expect("It should be possible to encode a claim into a JWT");
        jwt_secret_manager
            .rotate()
            .await
            .expect("JWTSecretManager should be able to rotate rsa keys");

        // A token signed before the rotation is still accepted through the
        // previous key, and comes back with a replacement signed by the new key.
        let refreshed_token = match jwt_secret_manager.decode::<Claim>(&old_token).await {
            ValidationOutcome::Outdated(c, new_token) => {
                assert_eq!(c.data, claim.data);
                new_token
            }
            other => panic!("expected an outdated token after rotation, got {:?}", other),
        };
        match jwt_secret_manager.decode::<Claim>(&refreshed_token).await {
            ValidationOutcome::Ok(c) => assert_eq!(c.data, claim.data),
            other => panic!("expected the refreshed token to be valid, got {:?}", other),
        }
    }
}