You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
323 lines
9.3 KiB
323 lines
9.3 KiB
use futures::StreamExt;
|
|
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
|
use rsa::{
|
|
pkcs8::{der::zeroize::Zeroizing, EncodePrivateKey, EncodePublicKey, LineEnding},
|
|
RsaPrivateKey, RsaPublicKey,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::path::Path;
|
|
use thiserror::Error;
|
|
use tokio::{fs, io};
|
|
|
|
pub mod jwt_numeric_date;
|
|
|
|
/// Modulus size, in bits, requested for every generated RSA key.
const RSA_BITS: usize = 4096;

/// Directory (relative to the working directory) holding the PEM files.
const KEYS_DIR: &str = "./keys";

/// File name of the current private key inside [`KEYS_DIR`].
const CUR_PRIV: &str = "cur_priv.pem";

/// File name of the current public key inside [`KEYS_DIR`].
const CUR_PUB: &str = "cur_pub.pem";

/// File name of the previous (rotated-out) private key inside [`KEYS_DIR`].
const PRE_PRIV: &str = "pre_priv.pem";

/// File name of the previous (rotated-out) public key inside [`KEYS_DIR`].
const PRE_PUB: &str = "pre_pub.pem";
|
|
|
|
#[derive(Clone)]
|
|
pub struct JWTSecretManager {
|
|
current: RSAKeyPair,
|
|
previous: Option<RSAKeyPair>,
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum JWTSecretManagerError {
|
|
#[error("Failed to generatie new key pair: {0}")]
|
|
KeyGeneration(rsa::Error),
|
|
#[error("Failed to save pem files to file system: {0}")]
|
|
KeySave(io::Error),
|
|
#[error("Failed to create the directory to store keys: {0}")]
|
|
KeysDir(io::Error),
|
|
#[error("Failed to read existing keys from file system: {0}")]
|
|
KeysRead(io::Error),
|
|
#[error("Failed to encode a new JWT: {0}")]
|
|
EncodeFailed(jsonwebtoken::errors::Error),
|
|
}
|
|
|
|
#[derive(Debug, PartialEq)]
|
|
pub enum ValidationOutcome<T>
|
|
where
|
|
T: Serialize,
|
|
for<'de> T: Deserialize<'de>,
|
|
{
|
|
Ok(T),
|
|
Outdated(T, String),
|
|
Error(jsonwebtoken::errors::Error),
|
|
Unauthorized,
|
|
}
|
|
|
|
impl<T> From<jsonwebtoken::errors::Error> for ValidationOutcome<T>
|
|
where
|
|
T: Serialize,
|
|
for<'de> T: Deserialize<'de>,
|
|
{
|
|
fn from(value: jsonwebtoken::errors::Error) -> Self {
|
|
ValidationOutcome::Error(value)
|
|
}
|
|
}
|
|
|
|
impl JWTSecretManager {
|
|
pub async fn init() -> Result<Self, JWTSecretManagerError> {
|
|
// Check if we have any keys
|
|
let keys_dir = Path::new(KEYS_DIR);
|
|
if !keys_dir.exists() {
|
|
fs::create_dir(KEYS_DIR)
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeysDir)?;
|
|
}
|
|
|
|
let current = if !keys_dir.join(CUR_PUB).exists() || !keys_dir.join(CUR_PRIV).exists() {
|
|
let keys = RSAKeyPair::generate_new()
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeyGeneration)?;
|
|
|
|
keys.save_pem_files(keys_dir, CUR_PRIV, CUR_PUB)
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeySave)?;
|
|
|
|
keys
|
|
} else {
|
|
let public = fs::read_to_string(keys_dir.join(CUR_PUB))
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeysRead)?;
|
|
let private = fs::read_to_string(keys_dir.join(CUR_PRIV))
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeysRead)?;
|
|
|
|
RSAKeyPair {
|
|
private: Zeroizing::new(private),
|
|
public,
|
|
}
|
|
};
|
|
|
|
let previous = if !keys_dir.join(PRE_PUB).exists() || !keys_dir.join(PRE_PRIV).exists() {
|
|
None
|
|
} else {
|
|
let public = fs::read_to_string(keys_dir.join(PRE_PUB))
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeysRead)?;
|
|
let private = fs::read_to_string(keys_dir.join(PRE_PRIV))
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeysRead)?;
|
|
|
|
Some(RSAKeyPair {
|
|
private: Zeroizing::new(private),
|
|
public,
|
|
})
|
|
};
|
|
|
|
Ok(JWTSecretManager { current, previous })
|
|
}
|
|
|
|
pub async fn rotate(&mut self) -> Result<(), JWTSecretManagerError> {
|
|
self.previous = Some(self.current.clone());
|
|
|
|
self.current
|
|
.save_pem_files(KEYS_DIR, PRE_PRIV, PRE_PUB)
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeySave)?;
|
|
|
|
self.current = RSAKeyPair::generate_new()
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeyGeneration)?;
|
|
|
|
self.current
|
|
.save_pem_files(KEYS_DIR, CUR_PRIV, CUR_PUB)
|
|
.await
|
|
.map_err(JWTSecretManagerError::KeySave)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn decode<T>(&self, token: &str) -> ValidationOutcome<T>
|
|
where
|
|
T: Serialize + Clone,
|
|
for<'de> T: Deserialize<'de>,
|
|
{
|
|
if let Ok(claim) = self.current.verify_jwt(token).await {
|
|
return ValidationOutcome::Ok(claim);
|
|
}
|
|
|
|
match &self.previous {
|
|
Some(k) => {
|
|
if let Ok(claim) = k.verify_jwt::<T>(token).await {
|
|
let new_token = match self.current.encode(&claim).await {
|
|
Ok(t) => t,
|
|
Err(e) => return ValidationOutcome::Error(e),
|
|
};
|
|
return ValidationOutcome::Outdated(claim, new_token);
|
|
}
|
|
}
|
|
None => (),
|
|
}
|
|
|
|
ValidationOutcome::Unauthorized
|
|
}
|
|
|
|
pub async fn encode_new<T>(&self, claim: &T) -> Result<String, JWTSecretManagerError>
|
|
where
|
|
T: Serialize + Clone,
|
|
for<'de> T: Deserialize<'de>,
|
|
{
|
|
self.current
|
|
.encode(&claim)
|
|
.await
|
|
.map_err(JWTSecretManagerError::EncodeFailed)
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct RSAKeyPair {
|
|
private: Zeroizing<String>,
|
|
public: String,
|
|
}
|
|
|
|
impl RSAKeyPair {
|
|
async fn generate_new() -> Result<RSAKeyPair, rsa::Error> {
|
|
let mut rng = rand::thread_rng();
|
|
|
|
let private_key = RsaPrivateKey::new(&mut rng, RSA_BITS).expect("failed to generate a key");
|
|
let public_key = RsaPublicKey::from(&private_key);
|
|
|
|
let private = private_key.to_pkcs8_pem(LineEnding::LF)?;
|
|
let public = public_key.to_public_key_pem(LineEnding::LF).unwrap(); // This is infaillible?
|
|
|
|
Ok(RSAKeyPair { private, public })
|
|
}
|
|
|
|
async fn save_pem_files(
|
|
&self,
|
|
path: impl AsRef<Path>,
|
|
private_name: &str,
|
|
public_name: &str,
|
|
) -> io::Result<()> {
|
|
// There's probably a better looking way to do this
|
|
let futures = vec![
|
|
fs::write(path.as_ref().join(public_name), &self.public),
|
|
fs::write(path.as_ref().join(private_name), &self.private),
|
|
];
|
|
|
|
let stream = futures::stream::iter(futures).buffer_unordered(2);
|
|
let results = stream.collect::<Vec<_>>().await;
|
|
|
|
for res in results {
|
|
res?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn verify_jwt<T>(&self, token: &str) -> Result<T, jsonwebtoken::errors::Error>
|
|
where
|
|
for<'de> T: Deserialize<'de>,
|
|
{
|
|
match decode::<T>(
|
|
token,
|
|
&DecodingKey::from_rsa_pem(self.public.as_bytes())?,
|
|
&Validation::new(Algorithm::RS256),
|
|
) {
|
|
Ok(t) => Ok(t.claims),
|
|
Err(e) => Err(e),
|
|
}
|
|
}
|
|
|
|
async fn encode<T>(&self, claim: &T) -> Result<String, jsonwebtoken::errors::Error>
|
|
where
|
|
T: Serialize,
|
|
{
|
|
let secret = &EncodingKey::from_rsa_pem(self.private.as_bytes())?;
|
|
encode(&Header::new(Algorithm::RS256), claim, secret)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use time::Duration;

    use cookie::time::OffsetDateTime;
    use rand::RngCore;
    use serde::{Deserialize, Serialize};

    use super::jwt_numeric_date;
    use super::{JWTSecretManager, RSAKeyPair, ValidationOutcome};

    /// Minimal claim set used by the tests: a random payload plus an expiry.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
    struct Claim {
        data: Vec<u8>,
        #[serde(with = "jwt_numeric_date")]
        exp: OffsetDateTime,
    }

    /// Returns 16 random bytes to use as claim payload.
    fn create_random_claim_data() -> Vec<u8> {
        let mut bytes = [0u8; 16];
        rand::thread_rng().fill_bytes(&mut bytes);
        bytes.into()
    }

    /// Builds a claim with a random payload that expires one day from now.
    fn create_random_claim() -> Claim {
        Claim {
            data: create_random_claim_data(),
            exp: OffsetDateTime::now_utc() + Duration::days(1),
        }
    }

    #[tokio::test]
    async fn generate_encode_and_decode_jwt_using_rsa() {
        let rsa_keys = RSAKeyPair::generate_new()
            .await
            .expect("It should be possible to generate a new RSA key pair");

        let claim = create_random_claim();

        let token = rsa_keys
            .encode(&claim)
            .await
            .expect("It should be possible to encode a claim");

        let decoded_claim = rsa_keys
            .verify_jwt::<Claim>(&token)
            .await
            .expect("It should be possible to verify a token");

        // Only the payload is compared, as in the original test.
        assert_eq!(decoded_claim.data, claim.data);
    }

    #[tokio::test]
    async fn jwt_secret_manager_can_encode_and_decode() {
        let jwt_secret_manager = JWTSecretManager::init()
            .await
            .expect("JWTSecretManager should be able to init");

        let claim = create_random_claim();

        let token = jwt_secret_manager
            .encode_new(&claim)
            .await
            .expect("It should be possible to encode a claim into a JWT");

        match jwt_secret_manager.decode::<Claim>(&token).await {
            ValidationOutcome::Ok(c) => assert_eq!(c.data, claim.data),
            other => panic!("expected ValidationOutcome::Ok, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn jwt_secret_manager_can_rotate() {
        let mut jwt_secret_manager = JWTSecretManager::init()
            .await
            .expect("JWTSecretManager should be able to init");

        let claim = create_random_claim();

        let token = jwt_secret_manager
            .encode_new(&claim)
            .await
            .expect("It should be possible to encode a claim into a JWT");

        jwt_secret_manager
            .rotate()
            .await
            .expect("JWTSecretManager should be able to rotate rsa keys");

        // A token signed before the rotation must still be accepted, flagged
        // as outdated, and come with a replacement signed by the new key.
        match jwt_secret_manager.decode::<Claim>(&token).await {
            ValidationOutcome::Outdated(c, reissued) => {
                assert_eq!(c.data, claim.data);

                match jwt_secret_manager.decode::<Claim>(&reissued).await {
                    ValidationOutcome::Ok(c) => assert_eq!(c.data, claim.data),
                    other => panic!("expected ValidationOutcome::Ok, got {:?}", other),
                }
            }
            other => panic!("expected ValidationOutcome::Outdated, got {:?}", other),
        }
    }
}
|