Users feature #39

Open
alban wants to merge 8 commits from users into master

1026
client/Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -9,7 +9,7 @@ edition = "2021"
chrono = { version = "0.4.23", features = ["serde"] }
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91"
perseus = { version = "0.4.1", features = ["hydrate"] }
perseus = { version = "0.4.2", features = ["hydrate"] }
sycamore = { version = "^0.8.1", features = [
"ssr",
"serde",
@ -17,6 +17,7 @@ sycamore = { version = "^0.8.1", features = [
"hydrate",
] }
lazy_static = "1"
wasm-cookies = "0.2"
[target.'cfg(engine)'.dev-dependencies]
fantoccini = "^0.19.3"

@ -0,0 +1,184 @@
use crate::api::{
types::{
company::Company,
paginated_response::PaginatedResponse,
transaction::UserTransaction,
user::{FollowCompany, UserProfile},
},
FastInsidersApi,
};
#[cfg(client)]
use super::user::set_token_cookie;
#[cfg(client)]
use serde::Serialize;
#[cfg(engine)]
type Response = reqwest::Response;
#[cfg(client)]
type Response = reqwasm::http::Response;
#[cfg(client)]
fn update_token(resp: &Response) {
if let Some(token) = resp.headers().get("x-new-token") {
set_token_cookie(&token)
}
}
#[cfg(client)]
async fn get_auth_route(route: &str) -> Result<Response, ()> {
let token = wasm_cookies::get_raw("token").unwrap_or_default();
let resp = reqwasm::http::Request::get(route)
.header("Authorization", &format!("Bearer {}", token))
.send()
.await
.map_err(|_| ())?;
update_token(&resp);
Ok(resp)
}
#[cfg(client)]
async fn post_auth_route<T>(route: &str, body: T) -> Result<Response, ()>
where
T: Serialize,
{
let token = wasm_cookies::get_raw("token").unwrap_or_default();
let resp = reqwasm::http::Request::post(route)
.header("Authorization", &format!("Bearer {}", token))
.header("Content-type", "application/json")
.body(serde_json::to_string(&body).unwrap())
.send()
.await
.map_err(|_| ())?;
update_token(&resp);
Ok(resp)
}
impl FastInsidersApi {
/// This is only a route to verify that we are authenticated
pub async fn is_authenticated(&self) -> Result<Response, ()> {
let route = &format!("{}/auth/is_authenticated", self.url);
#[cfg(client)]
let resp = get_auth_route(route).await?;
#[cfg(engine)]
let resp = reqwest::Client::new()
.get(route)
.send()
.await
.map_err(|_| ())?;
Ok(resp)
}
pub async fn get_profile(&self) -> Result<UserProfile, ()> {
let route = &format!("{}/auth/profile", self.url);
#[cfg(client)]
let res = {
let resp = get_auth_route(route).await?;
resp.json::<UserProfile>().await.map_err(|_| ())?
};
#[cfg(engine)]
let res = reqwest::get(route)
.await
.map_err(|_| ())?
.json::<UserProfile>()
.await
.map_err(|_| ())?;
Ok(res)
}
pub async fn follow_company(&self, company_id: i32) -> Result<(), ()> {
let route = &format!("{}/auth/follow_company", self.url);
#[cfg(client)]
let res = {
let body = FollowCompany { company_id };
let resp = post_auth_route(route, body).await?;
};
#[cfg(engine)]
return Err(());
Ok(())
}
pub async fn unfollow_company(&self, company_id: i32) -> Result<(), ()> {
let route = &format!("{}/auth/unfollow_company", self.url);
#[cfg(client)]
let res = {
let body = FollowCompany { company_id };
let resp = post_auth_route(route, body).await?;
};
#[cfg(engine)]
return Err(());
Ok(())
}
pub async fn get_followed_companies(
&self,
page: i64,
size: i64,
) -> Result<PaginatedResponse<Company>, ()> {
let route = &format!(
"{}/auth/get_followed_companies?page={}&size={}",
self.url, page, size
);
#[cfg(client)]
let res = {
let resp = get_auth_route(route).await?;
resp.json::<PaginatedResponse<Company>>()
.await
.map_err(|_| ())?
};
#[cfg(engine)]
let res = reqwest::get(route)
.await
.map_err(|_| ())?
.json::<PaginatedResponse<Company>>()
.await
.map_err(|_| ())?;
Ok(res)
}
pub async fn get_user_transactions(
&self,
page: i64,
size: i64,
) -> Result<PaginatedResponse<UserTransaction>, ()> {
let route = &format!(
"{}/auth/user_transactions?page={}&size={}",
self.url, page, size
);
#[cfg(client)]
let res = {
let resp = get_auth_route(route).await?;
resp.json::<PaginatedResponse<UserTransaction>>()
.await
.map_err(|_| ())?
};
#[cfg(engine)]
let res = reqwest::get(route)
.await
.map_err(|_| ())?
.json::<PaginatedResponse<UserTransaction>>()
.await
.map_err(|_| ())?;
Ok(res)
}
}

@ -1,2 +1,4 @@
pub mod authenticated;
pub mod company;
pub mod transaction;
pub mod user;

@ -24,13 +24,16 @@ impl FastInsidersApi {
);
#[cfg(client)]
let res = reqwasm::http::Request::get(route)
.send()
.await
.map_err(|_| ())?
.json::<PaginatedResponse<TransactionCompany>>()
.await
.map_err(|_| ())?;
let res = {
use reqwasm::http::RequestCredentials;
reqwasm::http::Request::get(route)
.send()
.await
.map_err(|_| ())?
.json::<PaginatedResponse<TransactionCompany>>()
.await
.map_err(|_| ())?
};
#[cfg(engine)]
let res = reqwest::get(route)

@ -0,0 +1,163 @@
use std::{error::Error, fmt::Display, time::Duration};
use serde::{Deserialize, Serialize};
use crate::api::FastInsidersApi;
#[cfg(client)]
#[derive(Serialize)]
pub struct UserLoginBody {
pub name: String,
pub password: String,
}
#[cfg(client)]
#[derive(Serialize)]
pub struct UserRegisterBody {
pub name: String,
pub email: String,
pub password: String,
}
#[cfg(client)]
type Response = reqwasm::http::Response;
#[cfg(client)]
#[derive(Deserialize)]
struct LoginResponse {
pub token: String,
}
#[cfg(client)]
#[derive(Debug, Deserialize)]
pub enum LoginError {
Unknown,
InvalidCredentials,
InternalServer(String),
Server(String),
UserNotFound(String),
}
#[cfg(client)]
#[derive(Deserialize)]
struct ErrorBody {
error: String,
}
#[cfg(client)]
impl Display for LoginError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Unknown => {
write!(f, "There was an unknown error while trying to log you in.").to_owned()
}
Self::InternalServer(e) => write!(f, "Internal server error: {}", e),
Self::Server(e) => write!(f, "There was an error completing the request {}", e),
Self::InvalidCredentials => write!(f, "Invalid username or password"),
Self::UserNotFound(e) => write!(f, "{e}"),
}
}
}
#[cfg(client)]
impl Error for LoginError {}
#[cfg(client)]
#[derive(Debug, Deserialize)]
pub enum RegisterError {
Unknown,
InternalServer(String),
Server(String),
Conflict(String),
}
#[cfg(client)]
impl Display for RegisterError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Unknown => {
write!(f, "There was an unknown error while trying to log you in.").to_owned()
}
Self::InternalServer(e) => write!(f, "Internal server error: {}", e),
Self::Server(e) => write!(f, "There was an error completing the request {}", e),
Self::Conflict(e) => write!(f, "{e}"),
}
}
}
#[cfg(client)]
impl Error for RegisterError {}
#[cfg(client)]
pub fn set_token_cookie(token: &str) {
use wasm_cookies::CookieOptions;
let cookie_options = CookieOptions::default()
.secure()
.with_same_site(wasm_cookies::SameSite::Strict)
.expires_after(Duration::from_secs(60 * 60 * 24 * 5)); // 5 days
wasm_cookies::set("token", token, &cookie_options);
}
#[cfg(client)]
impl FastInsidersApi {
pub async fn login(&self, body: &UserLoginBody) -> Result<(), LoginError> {
let route = &format!("{}/user/login", self.url);
let resp = reqwasm::http::Request::post(route)
.header("Content-type", "application/json")
.body(serde_json::to_string(&body).unwrap())
.send()
.await
.map_err(|e| LoginError::Server(e.to_string()))?;
if resp.status() == 200 {
if let Ok(data) = resp.json::<LoginResponse>().await {
set_token_cookie(&data.token);
return Ok(());
} else {
panic!();
}
}
match resp.status() {
401 => return Err(LoginError::InvalidCredentials),
500 => {
let error = resp.json::<ErrorBody>().await.unwrap().error;
return Err(LoginError::InternalServer(error));
}
404 => {
let error = resp.json::<ErrorBody>().await.unwrap().error;
return Err(LoginError::UserNotFound(error));
}
_ => return Err(LoginError::Unknown),
}
}
pub async fn register(&self, body: &UserRegisterBody) -> Result<(), RegisterError> {
let route = &format!("{}/user/register", self.url);
use wasm_cookies::CookieOptions;
let resp = reqwasm::http::Request::post(route)
.header("Content-type", "application/json")
.body(serde_json::to_string(&body).unwrap())
.send()
.await
.map_err(|e| RegisterError::Server(e.to_string()))?;
if resp.status() == 200 {
return Ok(());
}
if resp.status() == 500 {
let error = resp.json::<ErrorBody>().await.unwrap().error;
return Err(RegisterError::InternalServer(error));
}
if resp.status() == 409 {
let error = resp.json::<ErrorBody>().await.unwrap().error;
return Err(RegisterError::Conflict(error));
}
Err(RegisterError::Unknown)
}
}

@ -1,3 +1,4 @@
pub mod company;
pub mod paginated_response;
pub mod transaction;
pub mod user;

@ -1,9 +1,10 @@
use perseus::{reactor::Reactor, web_log, prelude::navigate};
use serde::{Deserialize, Serialize};
use sycamore::prelude::*;
use sycamore::{prelude::*, futures::spawn_local_scoped, rt::Event};
use crate::components::base_table::TableContent;
use crate::{components::base_table::TableContent, global_state::AppStateRx};
use super::transaction::{TransactionCompany, TransactionsAggregated, LatestTransaction, MajorTransactions};
use super::{transaction::{TransactionCompany, TransactionsAggregated, LatestTransaction, MajorTransactions, UserTransaction}, company::Company};
pub trait IntoTableData<G>
where
@ -12,7 +13,7 @@ where
fn into_table_data(self, cx: Scope) -> TableContent<G>;
}
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone, Deserialize)]
pub struct PaginatedResponse<M> {
pub count: i64,
pub num_pages: i64,
@ -190,3 +191,126 @@ where
}
}
}
impl<G> IntoTableData<G> for PaginatedResponse<Company>
where
G: GenericNode + perseus::prelude::Html,
{
fn into_table_data(self, cx: Scope) -> TableContent<G> {
let headers_view = vec![
view! {cx, "Company" },
view! {cx, "Unfollow" },
];
let global_state = Reactor::<G>::from_cx(cx).get_global_state::<AppStateRx>(cx);
let api = global_state.api.get();
let api_scope_ref = create_ref(cx, api);
/// This function returns a function that knows wich company id to remove
let unfollow_company = move |company_id: i32| {
#[cfg(client)]
return {
move |e: Event| {
spawn_local_scoped(cx, async move {
match api_scope_ref.unfollow_company(company_id).await {
Ok(()) => {
// TODO find a way to only ask for the table to refresh instead of
// the complete page
navigate("/profile")
}
Err(e) => (),
};
})
}
};
#[cfg(engine)]
return {
move |_| {}
};
};
let data_view: Vec<Vec<View<G>>> = self
.list
.into_iter()
.map(|t| {
let mut res = vec![];
res.push(view! {cx,
a (href=format!("transactions/{}", t.slug),
class="text-indigo-800 dark:text-indigo-300 hover:text-indigo-500 hover:underline dark:hover:text-indigo-600",
) {
(t.name.to_owned())
}
});
res.push(view!{cx,
svg (class="m-auto cursor-pointer", on:click=unfollow_company(t.id), xmlns="http://www.w3.org/2000/svg", width="24", height="24", viewBox="0 0 24 24", fill="none", stroke="currentColor", stroke-width="2", stroke-linecap="round", stroke-linejoin="round") {
circle (cx="12", cy="12", r="10") {}
line (x1="8", y1="12", x2="16", y2="12") {}
}
});
res
})
.collect();
TableContent {
headers_view,
data_view,
}
}
}
impl<G> IntoTableData<G> for PaginatedResponse<UserTransaction>
where
G: GenericNode,
{
fn into_table_data(self, cx: Scope) -> TableContent<G> {
let headers_view = vec![
view! {cx, "Company" },
view! {cx, "Date published" },
view! {cx, "Date executed" },
view! {cx, "Person" },
view! {cx, "Nature" },
view! {cx, "ISIN" },
view! {cx, "Instrument" },
view! {cx, "Exchange" },
view! {cx, "Volume" },
view! {cx, "Unit price" },
view! {cx, "Total" },
];
let data_view: Vec<Vec<View<G>>> = self
.list
.into_iter()
.map(|t| {
let mut res = vec![];
res.push(view! {cx,
a (href=format!("transactions/{}", t.company_slug),
class="text-indigo-800 dark:text-indigo-300 hover:text-indigo-500 hover:underline dark:hover:text-indigo-600",
) {
(t.company_name.to_owned())
}
});
res.push(view! {cx, (t.date_published.to_owned()) });
res.push(view! {cx, (t.date_executed.to_owned()) });
res.push(view! {cx, (t.person.to_owned()) });
res.push(view! {cx, (t.nature.to_owned()) });
res.push(view! {cx, (t.isin.to_owned().unwrap_or_else(|| "-".to_string())) });
res.push(view! {cx, (t.instrument.to_owned()) });
res.push(view! {cx, (t.exchange.to_owned()) });
res.push(view! {cx, (t.volume.to_owned()) });
res.push(view! {cx, (t.unit_price.to_owned()) });
res.push(view! {cx, ((t.volume as f32 * t.unit_price).to_string()) });
res
})
.collect();
TableContent {
headers_view,
data_view,
}
}
}

@ -65,3 +65,19 @@ pub struct MajorTransactions {
pub unit_price: f32,
pub total: f32,
}
#[derive(Clone, Deserialize)]
pub struct UserTransaction {
pub company_name: String,
pub company_slug: String,
pub date_published: NaiveDate,
pub date_executed: NaiveDate,
pub person: String,
pub exchange: String,
pub nature: String,
pub isin: Option<String>,
pub instrument: String,
pub volume: i32,
pub unit_price: f32,
pub total: f32,
}

@ -0,0 +1,12 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize)]
pub struct UserProfile {
pub email: Option<String>,
pub name: String,
}
#[derive(Debug, Clone, Serialize)]
pub struct FollowCompany {
pub company_id: i32,
}

@ -11,13 +11,30 @@ lazy_static! {
fn dark_mode_btn<G: Html>(cx: Scope, _props: ()) -> View<G> {
let global_state = Reactor::<G>::from_cx(cx).get_global_state::<AppStateRx>(cx);
let dark_mode = &global_state.dark_mode;
let toggle_dark_mode = move |_| {
global_state.dark_mode.set(!*global_state.dark_mode.get());
dark_mode.set(!*dark_mode.get());
};
view! { cx,
button (on:click=toggle_dark_mode, class="py-1 px-2 mx-1 rounded-full bg-slate-200 dark:bg-slate-800")
{ "Toggle dark mode" }
(if *dark_mode.get() {
view! {cx,
div(on:click=toggle_dark_mode, class="hover:cursor-pointer") {
svg(xmlns="http://www.w3.org/2000/svg", fill="none", viewBox="0 0 24 24", stroke-width="1.5", stroke="currentColor", class="w-6 h-6") {
path(stroke-linecap="round", stroke-linejoin="round", d="M21.752 15.002A9.718 9.718 0 0118 15.75c-5.385 0-9.75-4.365-9.75-9.75 0-1.33.266-2.597.748-3.752A9.753 9.753 0 003 11.25C3 16.635 7.365 21 12.75 21a9.753 9.753 0 009.002-5.998z") {}
}
}
}
} else {
view! {cx,
div(on:click=toggle_dark_mode, class="hover:cursor-pointer") {
svg(xmlns="http://www.w3.org/2000/svg", fill="none", viewBox="0 0 24 24", stroke-width="1.5", stroke="currentColor", class="w-6 h-6") {
path(stroke-linecap="round", stroke-linejoin="round", d="M12 3v2.25m6.364.386l-1.591 1.591M21 12h-2.25m-.386 6.364l-1.591-1.591M12 18.75V21m-4.773-4.227l-1.591 1.591M5.25 12H3m4.227-4.773L5.636 5.636M15.75 12a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0z") {}
}
}
}
})
}
}

@ -1 +1,3 @@
pub mod dark_mode_btn;
pub mod user_header;
pub mod user_icon;

@ -0,0 +1,40 @@
use lazy_static::lazy_static;
use perseus::prelude::*;
use sycamore::{prelude::*, rt::Event};
use crate::global_state::AppStateRx;
lazy_static! {
pub static ref USER_HEADER: Capsule<PerseusNodeType, ()> = get_capsule();
}
fn user_header<G: Html>(cx: Scope, _props: ()) -> View<G> {
let global_state = Reactor::<G>::from_cx(cx).get_global_state::<AppStateRx>(cx);
view! { cx,
(if *global_state.logged_in.get() {
view! { cx,
div (class="px-6 text-left border-l border-slate-700 dark:border-slate-300") {
a (id="header-followed-companies", href="/user_transactions", class="hover:underline") {
"Followed companies"
}
}
}
} else {
view!{cx, }
})
div (class="grow") {}
}
}
fn fallback<G: Html>(cx: Scope, _props: ()) -> View<G> {
view! { cx,
}
}
pub fn get_capsule<G: Html>() -> Capsule<G, ()> {
Capsule::build(Template::build("user_header"))
.fallback(fallback)
.view(user_header)
.build()
}

@ -0,0 +1,94 @@
use lazy_static::lazy_static;
use perseus::prelude::*;
use sycamore::{prelude::*, rt::Event};
use crate::global_state::AppStateRx;
lazy_static! {
pub static ref USER_ICON: Capsule<PerseusNodeType, ()> = get_capsule();
}
fn user_icon<G: Html>(cx: Scope, _props: ()) -> View<G> {
let global_state = Reactor::<G>::from_cx(cx).get_global_state::<AppStateRx>(cx);
let api = global_state.api.get();
let api_scope_ref = create_ref(cx, api);
#[cfg(client)]
spawn_local_scoped(cx, async move {
// Since logged in is set to false by default (on the first page load) we have to check
if *global_state.logged_in.get() {
return;
}
let status = api_scope_ref.is_authenticated().await.unwrap().status();
if status == 200 {
global_state.logged_in.set(true);
} else {
global_state.logged_in.set(false);
}
});
#[cfg(client)]
let logout = move |e: Event| {
wasm_cookies::delete("token");
global_state.logged_in.set(false);
navigate("/")
};
#[cfg(engine)]
let logout = move |_| {};
view! { cx,
(if *global_state.logged_in.get() {
view! { cx,
div(on:click=|_| navigate("/profile"), class="mx-1 hover:cursor-pointer", title="Login") {
// user profile icon
svg(xmlns="http://www.w3.org/2000/svg", fill="none", viewBox="0 0 24 24", stroke-width="1.5", stroke="currentColor", class="w-6 h-6") {
path(stroke-linecap="round", stroke-linejoin="round", d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z") {}
}
}
div(on:click=logout, class="mx-1 hover:cursor-pointer", title="Login") {
// logout icon
svg(xmlns="http://www.w3.org/2000/svg", width="24", height="24", viewBox="0 0 24 24", fill="none", stroke="currentColor", stroke-width="2", stroke-linecap="round", stroke-linejoin="round", class="feather feather-log-out") {
path(d="M9 21H5a2 2 0 0 1-2-2V5a2 2 0 0 1 2-2h4"){}
polyline(points="16 17 21 12 16 7") {}
line(x1="21", y1="12", x2="9", y2="12") {}
}
}
}
} else {
view! { cx,
div(on:click=|_| navigate("/login"), class="mx-1 hover:cursor-pointer", title="Login") {
// login icon
svg(xmlns="http://www.w3.org/2000/svg", width="24", height="24", viewBox="0 0 24 24", fill="none", stroke="currentColor", stroke-width="2", stroke-linecap="round", stroke-linejoin="round", class="feather feather-log-in") {
path(d="M15 3h4a2 2 0 0 1 2 2v14a2 2 0 0 1-2 2h-4") {}
polyline(points="10 17 15 12 10 7") {}
line(x1="15", y1="12", x2="3", y2="12"){}
}
}
}
})
}
}
fn fallback<G: Html>(cx: Scope, _props: ()) -> View<G> {
view! { cx,
div(on:click=|_| navigate("/login"), class="mx-1 hover:cursor-pointer", title="Login") {
// login icon
svg(xmlns="http://www.w3.org/2000/svg", width="24", height="24", viewBox="0 0 24 24", fill="none", stroke="currentColor", stroke-width="2", stroke-linecap="round", stroke-linejoin="round", class="feather feather-log-in") {
path(d="M15 3h4a2 2 0 0 1 2 2v14a2 2 0 0 1-2 2h-4") {}
polyline(points="10 17 15 12 10 7") {}
line(x1="15", y1="12", x2="3", y2="12"){}
}
}
}
}
pub fn get_capsule<G: Html>() -> Capsule<G, ()> {
Capsule::build(Template::build("dark_mode_btn"))
.fallback(fallback)
.view(user_icon)
.build()
}

@ -17,6 +17,7 @@ where
{
pub route: &'a R,
pub selected_item: &'a Signal<Option<T>>,
pub clear: &'a Signal<bool>,
}
#[component]
@ -62,6 +63,14 @@ where
});
});
create_effect(cx, move || {
if *props.clear.get() {
props.clear.set(false);
props.selected_item.set(None);
input.set("".to_string());
}
});
view! { cx,
input (bind:value=input, class="p-2 w-full rounded-md bg-slate-300 dark:bg-slate-800", on:blur=hide_dropdown) {}
div (class="relative") {

@ -16,7 +16,7 @@ where
{
pub headers_view: &'a Signal<Vec<View<G>>>,
pub data_view: &'a Signal<Vec<Vec<View<G>>>>,
pub table_class: &'a String,
pub table_class: &'a str,
}
#[derive(Debug, Clone)]
@ -51,7 +51,7 @@ where
for (idx, row) in v_table.iter().enumerate() {
let views = PEqView(idx, View::new_fragment(
row.iter().map(|cell| {
view!{cx, th (class="m-2 p-2 border-slate-500 border-x border-dashed") { (cell) } }
view!{cx, th (class="p-2 m-2 border-dashed border-slate-500 border-x") { (cell) } }
} ).collect()
));
@ -64,12 +64,12 @@ where
view! { cx,
table (class=format!("{} table-auto bg-slate-200 text-left dark:bg-slate-800 rounded-lg mx-auto my-2", props.table_class)) {
thead {
tr (class="border-b-2 border-slate-500 text-center") {
tr (class="text-center border-b-2 border-slate-500") {
Keyed(
iterable=headers,
view=|cx, v| {
view! {cx,
th (class="m-2 p-2") { (v.1) }
th (class="p-2 m-2") { (v.1) }
}
},
key=|v| v.0,
@ -81,7 +81,7 @@ where
iterable=data,
view=|cx, t| {
view! {cx,
tr (class="m-2 p-2 border-slate-500 border") {
tr (class="p-2 m-2 border border-slate-500") {
(t.1)
}
}

@ -12,8 +12,8 @@ pub fn MainContentContainer<'a, G: Html>(cx: Scope<'a>, props: MainProps<'a, G>)
let children = props.children.call(cx);
view! {cx,
div (id="main", class="flex flex-col items-center justify-center ") {
div (class="w-4/5 m-10 p-3 bg-slate-100 dark:bg-slate-600 rounded-lg items-center justify-center") {
div (id="main", class="flex flex-col justify-center items-center") {
div (class="justify-center items-center p-3 m-10 w-4/5 rounded-lg bg-slate-100 dark:bg-slate-600") {
(children)
}
}

@ -21,6 +21,7 @@ where
pub route: &'a C,
pub filter: Option<String>,
pub table_class: &'a String,
pub refresh: &'a Signal<bool>,
}
impl<'a, M, F, C> PaginatedTableStateRx<'a, M, F, C>
@ -81,21 +82,39 @@ where
page_size_string.track();
page.set(0);
});
let props_sig = create_signal(cx, props);
#[cfg(client)]
create_effect(cx, move || {
let page = *page.get();
let page_size_s = page_size_string.get();
let page_size = page_size_s.parse().unwrap_or(20);
let data_fetch = move || {
spawn_local_scoped(cx, async move {
let res = props_sig.get().get_data(page, page_size).await.unwrap();
let res = props_sig
.get()
.get_data(*page.get(), page_size_string.get().parse().unwrap_or(20))
.await
.unwrap();
paginated_data.set(Some(res.clone()));
n_rows.set(res.count);
let table_content = res.into_table_data(cx);
table_prop.data_view.set(table_content.data_view);
table_prop.headers_view.set(table_content.headers_view);
n_page.set((*paginated_data.get()).as_ref().map_or(0, |t| t.num_pages));
});
})
};
#[cfg(client)]
create_effect(cx, move || {
if *props_sig.get().refresh.get() {
props_sig.get().refresh.set(false);
data_fetch()
}
});
#[cfg(client)]
create_effect(cx, move || {
page.track();
page_size_string.track();
data_fetch();
});
view! { cx,

@ -1,25 +1,33 @@
use perseus::prelude::*;
use sycamore::prelude::*;
use crate::capsules::dark_mode_btn::DARK_MODE_BTN;
use crate::capsules::{
dark_mode_btn::DARK_MODE_BTN, user_header::USER_HEADER, user_icon::USER_ICON,
};
#[component]
pub fn TheHeader<G: Html>(cx: Scope) -> View<G> {
view! { cx,
header (class="p-2 w-full h-11 align-middle bg-gray-100 shadow-md backdrop-blur-lg dark:bg-slate-500/30") {
div (class="flex") {
div (class="flex-none mr-12") {
div (class="flex-none mx-6") {
a (href="/", class="hover:underline") {
"Fast Insiders"
}
}
div (class="text-left grow") {
div (class="px-6 text-left") {
a (id="header-all-transactions", href="/transactions", class="hover:underline") {
"All transactions"
}
}
div (class="flex-none") {
(DARK_MODE_BTN.widget(cx,"",()))
(USER_HEADER.widget(cx, "",()))
div (class="flex items-center pl-3 border-l border-slate-700 dark:border-slate-300") {
div(class="mx-1") {
(DARK_MODE_BTN.widget(cx,"",()))
}
div(class="flex mx-1") {
(USER_ICON.widget(cx,"",()))
}
}
}
}

@ -15,6 +15,7 @@ pub fn get_global_state_creator() -> GlobalStateCreator {
#[rx(alias = "AppStateRx")]
pub struct AppState {
pub dark_mode: bool,
pub logged_in: bool,
pub api: FastInsidersApi,
}
@ -22,6 +23,7 @@ pub struct AppState {
pub async fn get_build_state() -> AppState {
AppState {
dark_mode: true,
logged_in: false,
api: FastInsidersApi::new(""), // It's unfortunately not possible to have a different type
// for the build state and the request state, I would rather
// have left this out
@ -34,9 +36,12 @@ pub async fn get_build_state() -> AppState {
async fn get_request_state(_req: Request) -> AppState {
use crate::env::Config;
let config = Config::new();
let api = FastInsidersApi::new(&config.api_url);
AppState {
dark_mode: true,
api: FastInsidersApi::new(&config.api_url),
logged_in: false,
api,
}
}
@ -44,6 +49,7 @@ async fn get_request_state(_req: Request) -> AppState {
async fn amalgamate_states(build_state: AppState, request_state: AppState) -> AppState {
AppState {
dark_mode: build_state.dark_mode,
logged_in: build_state.logged_in,
api: request_state.api,
}
}

@ -14,7 +14,13 @@ pub fn main<G: Html>() -> PerseusApp<G> {
PerseusApp::new()
.template(crate::templates::index::get_template())
.template(crate::templates::transactions::get_template())
.template(crate::templates::login::get_template())
.template(crate::templates::profile::get_template())
.template(crate::templates::register::get_template())
.template(crate::templates::user_transactions::get_template())
.capsule_ref(&*crate::capsules::dark_mode_btn::DARK_MODE_BTN)
.capsule_ref(&*crate::capsules::user_icon::USER_ICON)
.capsule_ref(&*crate::capsules::user_header::USER_HEADER)
.global_state_creator(crate::global_state::get_global_state_creator())
.error_views(crate::error_pages::get_error_views())
.index_view(|cx| {

@ -29,6 +29,7 @@ fn index_page<G: Html>(cx: BoundedScope) -> View<G> {
route: route_ref,
filter: Some("72".to_string()),
table_class: table_classes,
refresh: create_signal(cx, true),
};
let route_ref = create_ref(cx, move |c, p, s| {
@ -40,6 +41,7 @@ fn index_page<G: Html>(cx: BoundedScope) -> View<G> {
route: route_ref,
filter: Some((24 * 30).to_string()),
table_class: table_classes,
refresh: create_signal(cx, true),
};
let route_ref = create_ref(cx, move |c, p, s| {
@ -51,6 +53,7 @@ fn index_page<G: Html>(cx: BoundedScope) -> View<G> {
route: route_ref,
filter: Some((24 * 30).to_string()),
table_class: table_classes,
refresh: create_signal(cx, true),
};
let dark_mode_class = create_memo(cx, || {

@ -0,0 +1,98 @@
use perseus::prelude::*;
use sycamore::{prelude::*, rt::Event};
use crate::{
components::{main_content_container::MainContentContainer, the_header::TheHeader},
global_state::AppStateRx,
};
fn login_page<G: Html>(cx: BoundedScope) -> View<G> {
let reactor = Reactor::<G>::from_cx(cx);
let global_state = reactor.get_global_state::<AppStateRx>(cx);
let api = global_state.api.get();
let api_scope_ref = create_ref(cx, api);
let dark_mode_class = create_memo(cx, || {
if *global_state.dark_mode.get() {
"dark"
} else {
""
}
});
let username = create_signal(cx, "".to_string());
let password = create_signal(cx, "".to_string());
let error_msg = create_signal(cx, "".to_string());
create_effect(cx, move || {
if *global_state.logged_in.get() {
navigate("/");
}
});
let submit_disabled = create_memo(cx, move || {
username.get().is_empty() || password.get().is_empty()
});
#[cfg(client)]
let submit_login = move |e: Event| {
use crate::api::routes::user::UserLoginBody;
e.prevent_default();
let user_info = UserLoginBody {
name: username.get().to_string(),
password: password.get().to_string(),
};
spawn_local_scoped(cx, async move {
match api_scope_ref.login(&user_info).await {
Ok(()) => {
global_state.logged_in.set(true);
}
Err(e) => error_msg.set(e.to_string()),
};
})
};
#[cfg(engine)]
let submit_login = move |_| {};
view! {cx,
main (class=format!("{} flex flex-1", dark_mode_class)) {
div (class="flex-1 font-sans bg-slate-200 text-slate-700 dark:bg-slate-700 dark:text-slate-100") {
TheHeader()
MainContentContainer(useless_prop=1) {
form(class="flex flex-col justify-center items-center") {
label(for="username") { "Username:" }
input(id="username", bind:value=username, type="text", class="p-2 m-2 w-1/3 rounded-md bg-slate-300 dark:bg-slate-800") {}
label(for="password") { "Password:" }
input(id="password", bind:value=password, type="password", class="p-2 m-2 w-1/3 rounded-md bg-slate-300 dark:bg-slate-800") {}
input(on:click=submit_login,
value="Login",
type="submit",
class="p-2 m-2 rounded-md hover:cursor-pointer disabled:cursor-not-allowed bg-slate-300 dark:hover:bg-slate-900 dark:bg-slate-800 hover:bg-slate-400",
disabled=*submit_disabled.get()
) {}
p (class="text-red-700 dark:text-rose-500") {
(error_msg.get())
}
}
a (class="hover:underline", href="/register") {
"Don't have an account? Register here."
}
}
}
}
}
}
pub fn get_template<G: Html>() -> Template<G> {
Template::build("login").head(head).view(login_page).build()
}
#[engine_only_fn]
fn head(cx: Scope) -> View<SsrNode> {
view! {cx,
title { "Fast Insiders" }
}
}

@ -1,2 +1,6 @@
pub mod index;
pub mod login;
pub mod profile;
pub mod register;
pub mod transactions;
pub mod user_transactions;

@ -0,0 +1,157 @@
use std::rc::Rc;
use perseus::prelude::*;
use sycamore::{prelude::*, rt::Event};
use crate::{
components::{
base_async_select::{AsyncSelectRx, BaseAsyncSelect},
base_button::{BaseButton, BaseButtonStateRx},
base_table::BaseTable,
loading::Loading,
main_content_container::MainContentContainer,
paginated_data_table::{PaginatedTable, PaginatedTableStateRx},
the_header::TheHeader,
},
global_state::AppStateRx,
};
fn profile_page<G: Html>(cx: BoundedScope) -> View<G> {
let reactor = Reactor::<G>::from_cx(cx);
let global_state = reactor.get_global_state::<AppStateRx>(cx);
let api = global_state.api.get();
let api_scope_ref = create_ref(cx, api);
let dark_mode_class = create_memo(cx, || {
if *global_state.dark_mode.get() {
"dark"
} else {
""
}
});
let username = create_signal(cx, "".to_string());
let email = create_signal(cx, "".to_string());
let loading: &Signal<bool> = create_signal(cx, true);
#[cfg(client)]
spawn_local_scoped(cx, async move {
let resp = api_scope_ref.get_profile().await.unwrap();
username.set(resp.name);
if let Some(em) = resp.email {
email.set(em);
}
loading.set(false);
});
let displayed_email = create_memo(cx, move || {
if email.get().is_empty() {
"No email set".to_string()
} else {
(*email.get()).clone()
}
});
let async_select_route_ref = create_ref(cx, |n, l| api_scope_ref.get_company_by_name(n, l));
let async_select_prop: AsyncSelectRx<_, _, _> = AsyncSelectRx {
route: async_select_route_ref,
selected_item: create_signal(cx, None),
clear: create_signal(cx, false),
};
let table_route_ref = create_ref(cx, move |_, p, s| {
api_scope_ref.get_followed_companies(p, s)
});
let paginated_table_state: PaginatedTableStateRx<_, _, _> = PaginatedTableStateRx {
record_label: "Companies".to_owned(),
route: table_route_ref,
filter: None,
table_class: create_ref(cx, "w-full".to_string()),
refresh: create_signal(cx, true),
};
let follow_button = BaseButtonStateRx {
label: create_signal(cx, "Follow".to_string()),
disabled: create_memo(cx, move || async_select_prop.selected_item.get().is_none()),
clicked: create_signal(cx, false),
};
create_effect(cx, move || {
if *follow_button.clicked.get() && async_select_prop.selected_item.get_untracked().is_some()
{
follow_button.clicked.set(false);
if let Some(company) = (*async_select_prop.selected_item.get_untracked()).clone() {
spawn_local_scoped(cx, async move {
api_scope_ref.follow_company(company.id).await.unwrap();
paginated_table_state.refresh.set(true);
});
}
async_select_prop.clear.set(true);
}
});
view! {cx,
main (class=format!("{} flex flex-1", dark_mode_class)) {
div (class="flex-1 font-sans bg-slate-200 text-slate-700 dark:bg-slate-700 dark:text-slate-100") {
TheHeader()
MainContentContainer(useless_prop=1) {
h1(class="text-lg text-center") {
"Profile page"
}
div(class="m-auto w-full max-w-2xl") {
(if !*loading.get() {
view! {cx,
p() {
b() {
"Username: "
}
input(
class="p-2 w-full rounded-md bg-slate-300 dark:bg-slate-800",
disabled=true,
value=username.get(),
)
}
p() {
b() {
"Email: "
}
input(
class="p-2 w-full rounded-md bg-slate-300 dark:bg-slate-800",
disabled=true,
value=displayed_email.get(),
)
}
}
} else {
view! {cx,
Loading()
}
})
h2(class="mt-3 font-bold") {
"Follow companies"
}
BaseAsyncSelect(async_select_prop)
BaseButton(follow_button)
PaginatedTable(paginated_table_state)
}
}
}
}
}
}
pub fn get_template<G: Html>() -> Template<G> {
Template::build("profile")
.head(head)
.view(profile_page)
.build()
}
#[engine_only_fn]
fn head(cx: Scope) -> View<SsrNode> {
view! {cx,
title { "Fast Insiders - User Profile" }
}
}

@ -0,0 +1,133 @@
use perseus::prelude::*;
use sycamore::{prelude::*, rt::Event};
use crate::{
components::{main_content_container::MainContentContainer, the_header::TheHeader},
global_state::AppStateRx,
};
fn register_page<G: Html>(cx: BoundedScope) -> View<G> {
let reactor = Reactor::<G>::from_cx(cx);
let global_state = reactor.get_global_state::<AppStateRx>(cx);
let api = global_state.api.get();
let api_scope_ref = create_ref(cx, api);
let dark_mode_class = create_memo(cx, || {
if *global_state.dark_mode.get() {
"dark"
} else {
""
}
});
let username = create_signal(cx, "".to_string());
let email = create_signal(cx, "".to_string());
let password = create_signal(cx, "".to_string());
let confirm_password = create_signal(cx, "".to_string());
let error_msg = create_signal(cx, "".to_string());
create_effect(cx, move || {
if *global_state.logged_in.get() {
navigate("/");
}
});
let passwords_match = move |e: Event| {
if confirm_password.get() != password.get() {
error_msg.set("Passwords do not match".to_string());
} else {
error_msg.set("".to_string());
}
};
let submit_disabled = create_memo(cx, move || {
username.get().is_empty()
|| password.get().is_empty()
|| password.get() != confirm_password.get()
});
#[cfg(client)]
let submit_register = move |e: Event| {
use crate::api::routes::user::UserRegisterBody;
e.prevent_default();
let user_info = UserRegisterBody {
name: username.get().to_string(),
email: email.get().to_string(),
password: password.get().to_string(),
};
spawn_local_scoped(cx, async move {
match api_scope_ref.register(&user_info).await {
Ok(()) => navigate("/login"),
Err(e) => error_msg.set(e.to_string()),
};
})
};
#[cfg(engine)]
let submit_register = move |_| {};
view! {cx,
main (class=format!("{} flex flex-1", dark_mode_class)) {
div (class="flex-1 font-sans bg-slate-200 text-slate-700 dark:bg-slate-700 dark:text-slate-100") {
TheHeader()
MainContentContainer(useless_prop=1) {
form(class="flex flex-col justify-center items-center") {
label(for="username") { "Username:" }
input(id="username",
bind:value=username,
type="text",
class="p-2 m-2 w-1/3 rounded-md bg-slate-300 dark:bg-slate-800"
) {}
label(for="password") { "Password:" }
input(id="password",
bind:value=password,
type="password",
class="p-2 m-2 w-1/3 rounded-md bg-slate-300 dark:bg-slate-800"
) {}
label(for="confirm-password") { "Confirm Password:" }
input(id="confirm-password",
bind:value=confirm_password,
on:blur=passwords_match,
type="password",
class="p-2 m-2 w-1/3 rounded-md bg-slate-300 dark:bg-slate-800"
) {}
label(for="Email") { "Email:" }
input(id="email",
bind:value=email,
type="text",
class="p-2 mx-2 w-1/3 rounded-md bg-slate-300 dark:bg-slate-800"
) {}
p (class="mt-0 italic") {
"Set an email if you want to be able to reset your password"
}
input(on:click=submit_register,
value="Register",
type="submit",
class="p-2 m-2 rounded-md hover:cursor-pointer disabled:cursor-not-allowed bg-slate-300 dark:hover:bg-slate-900 dark:bg-slate-800 hover:bg-slate-400",
disabled=*submit_disabled.get()
) {}
p (class="text-red-700 dark:text-rose-500") {
(error_msg.get())
}
}
}
}
}
}
}
pub fn get_template<G: Html>() -> Template<G> {
Template::build("register")
.head(head)
.view(register_page)
.build()
}
#[engine_only_fn]
fn head(cx: Scope) -> View<SsrNode> {
view! {cx,
title { "Fast Insiders - Register" }
}
}

@ -45,12 +45,14 @@ fn transactions_page<'a, G: Html>(cx: Scope, state: &TransactionsPageStateRx) ->
route: route_ref,
filter: (*state.company_slug.get()).clone(),
table_class: create_ref(cx, "".to_string()),
refresh: create_signal(cx, true),
};
let route_ref = create_ref(cx, |n, l| api_scope_ref.get_company_by_name(n, l));
let async_select_prop: AsyncSelectRx<_, _, _> = AsyncSelectRx {
route: route_ref,
selected_item: create_signal(cx, None),
clear: create_signal(cx, false),
};
let search_button = BaseButtonStateRx {

@ -0,0 +1,89 @@
use perseus::prelude::*;
use serde::{Deserialize, Serialize};
use sycamore::prelude::*;
use crate::{
components::{
base_async_select::{AsyncSelectRx, BaseAsyncSelect},
base_button::{BaseButton, BaseButtonStateRx},
main_content_container::MainContentContainer,
paginated_data_table::{PaginatedTable, PaginatedTableStateRx},
the_header::TheHeader,
},
global_state::AppStateRx,
};
#[derive(Serialize, Deserialize, Clone, ReactiveState)]
#[rx(alias = "TransactionsPageStateRx")]
pub struct TransactionsPageState {
pub company_slug: Option<String>,
}
fn user_transactions_page<'a, G: Html>(cx: Scope) -> View<G> {
let global_state = Reactor::<G>::from_cx(cx).get_global_state::<AppStateRx>(cx);
let api = global_state.api.get();
let api_scope_ref = create_ref(cx, api);
let expand = create_signal(cx, false);
let filter_expand = BaseButtonStateRx {
label: create_signal(cx, "Filters".to_string()),
disabled: create_signal(cx, false),
clicked: create_signal(cx, false),
};
create_effect(cx, move || {
if *filter_expand.clicked.get() {
filter_expand.clicked.set(false);
expand.set(!*expand.get());
}
});
let route_ref = create_ref(cx, move |_, p, s| api_scope_ref.get_user_transactions(p, s));
let paginated_table_state: PaginatedTableStateRx<_, _, _> = PaginatedTableStateRx {
record_label: "transactions".to_owned(),
route: route_ref,
filter: None,
table_class: create_ref(cx, "".to_string()),
refresh: create_signal(cx, true),
};
let dark_mode_class = create_memo(cx, || {
if *global_state.dark_mode.get() {
"dark"
} else {
""
}
});
view! {cx,
main (class=format!("{} flex flex-1", dark_mode_class)) {
div (class="flex-1 font-sans bg-slate-200 text-slate-700 dark:bg-slate-700 dark:text-slate-100") {
TheHeader()
MainContentContainer(useless_prop=1) {
a (class="hover:underline", href="/user_transactions") {
h1 (
class="text-lg text-center"
) {
"Latest transactions from your followed companies"
}
}
PaginatedTable(paginated_table_state)
}
}
}
}
}
pub fn get_template<G: Html>() -> Template<G> {
Template::build("user_transactions")
.head(head)
.view(user_transactions_page)
.build()
}
#[engine_only_fn]
fn head(cx: Scope) -> View<SsrNode> {
view! {cx,
title { "Fast Insiders" }
}
}

File diff suppressed because one or more lines are too long

1
server/.gitignore vendored

@ -1 +1,2 @@
target/
keys/

2427
server/Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -6,18 +6,19 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
chrono = { version = "0.4.23", features = ["serde"] }
serde = { version = "1.0.152", features = ["derive"] }
dotenvy = "0.15.6"
envy = "0.4.2"
serde_json = "1.0.91"
chrono = { version = "0.4", features = ["serde"] }
serde = { version = "1", features = ["derive"] }
dotenvy = "0.15"
envy = "0.4"
serde_json = "1"
migration = { version = "0.1.0", path = "./migration" }
tokio = { version = "^1.20.1", features = ["full"] }
tokio = { version = "1", features = ["full"] }
reqwest = { version = "0.11", features = ["json", "rustls-tls"] }
axum = "0.6.12"
hyper = { version = "0.14.25", features = ["full"] }
axum = { version = "0.6", features = ["headers", "macros"] }
hyper = { version = "0.14", features = ["full"] }
tower = "0.4"
sea-orm = { version = "0.11.0", features = [
tower-cookies = "0.9"
sea-orm = { version = "0.12", features = [
"runtime-tokio-rustls",
"macros",
"sqlx-mysql",
@ -32,4 +33,11 @@ thiserror = "1.0.38"
slug = "0.1.4"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tower-http = { version = "0.4", features = ["trace", "cors"] }
http = "0.2"
tracing = "0.1"
rust-argon2 = "1"
rand = "0.8"
jsonwebtoken = "8"
cookie = { version = "0.17", features = [ "secure" ] }
rsa = { version = "0.9.2", features = [ "pem" ] }
time = "0.3"

File diff suppressed because it is too large Load Diff

@ -12,5 +12,5 @@ path = "src/lib.rs"
async-std = { version = "^1", features = ["attributes", "tokio1"] }
[dependencies.sea-orm-migration]
version = "0.11.0"
version = "0.12.0"
features = ["sqlx-mysql", "runtime-tokio-rustls"]

@ -5,6 +5,8 @@ mod m20230112_115856_create_company_table;
mod m20230112_160440_create_transaction_table;
mod m20230119_112539_create_transactions_in_process_table;
mod m20230303_132528_transactions_created_at;
mod m20230604_113236_user_table;
mod m20231126_093416_create_user_company_table;
pub struct Migrator;
@ -16,6 +18,8 @@ impl MigratorTrait for Migrator {
Box::new(m20230112_160440_create_transaction_table::Migration),
Box::new(m20230119_112539_create_transactions_in_process_table::Migration),
Box::new(m20230303_132528_transactions_created_at::Migration),
Box::new(m20230604_113236_user_table::Migration),
Box::new(m20231126_093416_create_user_company_table::Migration),
]
}
}

@ -0,0 +1,43 @@
use sea_orm_migration::prelude::*;
#[derive(DeriveMigrationName)]
pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.create_table(
Table::create()
.table(User::Table)
.if_not_exists()
.col(
ColumnDef::new(User::Id)
.integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(ColumnDef::new(User::Email).string().unique_key())
.col(ColumnDef::new(User::Name).string().not_null().unique_key())
.col(ColumnDef::new(User::Password).string().not_null())
.to_owned(),
)
.await
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.drop_table(Table::drop().table(User::Table).to_owned())
.await
}
}
#[derive(Iden)]
pub enum User {
Table,
Id,
Email,
Name,
Password,
}

@ -0,0 +1,56 @@
use crate::m20230112_115856_create_company_table as company;
use crate::m20230604_113236_user_table as user;
use sea_orm_migration::prelude::*;
#[derive(DeriveMigrationName)]
pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.create_table(
Table::create()
.table(UserCompany::Table)
.if_not_exists()
.col(ColumnDef::new(UserCompany::UserId).integer().not_null())
.col(ColumnDef::new(UserCompany::CompanyId).integer().not_null())
.primary_key(
Index::create()
.col(UserCompany::CompanyId)
.col(UserCompany::UserId),
)
.foreign_key(
ForeignKey::create()
.name("FK_user")
.from(UserCompany::Table, UserCompany::UserId)
.to(user::User::Table, user::User::Id)
.on_update(ForeignKeyAction::Cascade)
.on_delete(ForeignKeyAction::Cascade),
)
.foreign_key(
ForeignKey::create()
.name("FK_company")
.from(UserCompany::Table, UserCompany::CompanyId)
.to(company::Company::Table, company::Company::Id)
.on_update(ForeignKeyAction::Cascade)
.on_delete(ForeignKeyAction::Cascade),
)
.to_owned(),
)
.await
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.drop_table(Table::drop().table(UserCompany::Table).to_owned())
.await
}
}
#[derive(Iden)]
enum UserCompany {
Table,
UserId,
CompanyId,
}

@ -0,0 +1,21 @@
//! Custom serialization of OffsetDateTime to conform with the JWT spec (RFC 7519 section 2, "Numeric Date")
use serde::{self, Deserialize, Deserializer, Serializer};
use time::OffsetDateTime;
/// Serializes an OffsetDateTime to a Unix timestamp (milliseconds since 1970/1/1T00:00:00T)
pub fn serialize<S>(date: &OffsetDateTime, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let timestamp = date.unix_timestamp();
serializer.serialize_i64(timestamp)
}
/// Attempts to deserialize an i64 and use as a Unix timestamp
pub fn deserialize<'de, D>(deserializer: D) -> Result<OffsetDateTime, D::Error>
where
D: Deserializer<'de>,
{
OffsetDateTime::from_unix_timestamp(i64::deserialize(deserializer)?)
.map_err(|_| serde::de::Error::custom("invalid Unix timestamp value"))
}

@ -0,0 +1,322 @@
use futures::StreamExt;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use rsa::{
pkcs8::{der::zeroize::Zeroizing, EncodePrivateKey, EncodePublicKey, LineEnding},
RsaPrivateKey, RsaPublicKey,
};
use serde::{Deserialize, Serialize};
use std::path::Path;
use thiserror::Error;
use tokio::{fs, io};
pub mod jwt_numeric_date;
const RSA_BITS: usize = 4096;
const KEYS_DIR: &str = "./keys";
const CUR_PRIV: &str = "cur_priv.pem";
const CUR_PUB: &str = "cur_pub.pem";
const PRE_PRIV: &str = "pre_priv.pem";
const PRE_PUB: &str = "pre_pub.pem";
#[derive(Clone)]
pub struct JWTSecretManager {
current: RSAKeyPair,
previous: Option<RSAKeyPair>,
}
#[derive(Debug, Error)]
pub enum JWTSecretManagerError {
#[error("Failed to generatie new key pair: {0}")]
KeyGeneration(rsa::Error),
#[error("Failed to save pem files to file system: {0}")]
KeySave(io::Error),
#[error("Failed to create the directory to store keys: {0}")]
KeysDir(io::Error),
#[error("Failed to read existing keys from file system: {0}")]
KeysRead(io::Error),
#[error("Failed to encode a new JWT: {0}")]
EncodeFailed(jsonwebtoken::errors::Error),
}
#[derive(Debug, PartialEq)]
pub enum ValidationOutcome<T>
where
T: Serialize,
for<'de> T: Deserialize<'de>,
{
Ok(T),
Outdated(T, String),
Error(jsonwebtoken::errors::Error),
Unauthorized,
}
impl<T> From<jsonwebtoken::errors::Error> for ValidationOutcome<T>
where
T: Serialize,
for<'de> T: Deserialize<'de>,
{
fn from(value: jsonwebtoken::errors::Error) -> Self {
ValidationOutcome::Error(value)
}
}
impl JWTSecretManager {
pub async fn init() -> Result<Self, JWTSecretManagerError> {
// Check if we have any keys
let keys_dir = Path::new(KEYS_DIR);
if !keys_dir.exists() {
fs::create_dir(KEYS_DIR)
.await
.map_err(JWTSecretManagerError::KeysDir)?;
}
let current = if !keys_dir.join(CUR_PUB).exists() || !keys_dir.join(CUR_PRIV).exists() {
let keys = RSAKeyPair::generate_new()
.await
.map_err(JWTSecretManagerError::KeyGeneration)?;
keys.save_pem_files(keys_dir, CUR_PRIV, CUR_PUB)
.await
.map_err(JWTSecretManagerError::KeySave)?;
keys
} else {
let public = fs::read_to_string(keys_dir.join(CUR_PUB))
.await
.map_err(JWTSecretManagerError::KeysRead)?;
let private = fs::read_to_string(keys_dir.join(CUR_PRIV))
.await
.map_err(JWTSecretManagerError::KeysRead)?;
RSAKeyPair {
private: Zeroizing::new(private),
public,
}
};
let previous = if !keys_dir.join(PRE_PUB).exists() || !keys_dir.join(PRE_PRIV).exists() {
None
} else {
let public = fs::read_to_string(keys_dir.join(PRE_PUB))
.await
.map_err(JWTSecretManagerError::KeysRead)?;
let private = fs::read_to_string(keys_dir.join(PRE_PRIV))
.await
.map_err(JWTSecretManagerError::KeysRead)?;
Some(RSAKeyPair {
private: Zeroizing::new(private),
public,
})
};
Ok(JWTSecretManager { current, previous })
}
pub async fn rotate(&mut self) -> Result<(), JWTSecretManagerError> {
self.previous = Some(self.current.clone());
self.current
.save_pem_files(KEYS_DIR, PRE_PRIV, PRE_PUB)
.await
.map_err(JWTSecretManagerError::KeySave)?;
self.current = RSAKeyPair::generate_new()
.await
.map_err(JWTSecretManagerError::KeyGeneration)?;
self.current
.save_pem_files(KEYS_DIR, CUR_PRIV, CUR_PUB)
.await
.map_err(JWTSecretManagerError::KeySave)?;
Ok(())
}
pub async fn decode<T>(&self, token: &str) -> ValidationOutcome<T>
where
T: Serialize + Clone,
for<'de> T: Deserialize<'de>,
{
if let Ok(claim) = self.current.verify_jwt(token).await {
return ValidationOutcome::Ok(claim);
}
match &self.previous {
Some(k) => {
if let Ok(claim) = k.verify_jwt::<T>(token).await {
let new_token = match self.current.encode(&claim).await {
Ok(t) => t,
Err(e) => return ValidationOutcome::Error(e),
};
return ValidationOutcome::Outdated(claim, new_token);
}
}
None => (),
}
ValidationOutcome::Unauthorized
}
pub async fn encode_new<T>(&self, claim: &T) -> Result<String, JWTSecretManagerError>
where
T: Serialize + Clone,
for<'de> T: Deserialize<'de>,
{
self.current
.encode(&claim)
.await
.map_err(JWTSecretManagerError::EncodeFailed)
}
}
#[derive(Clone)]
pub struct RSAKeyPair {
private: Zeroizing<String>,
public: String,
}
impl RSAKeyPair {
async fn generate_new() -> Result<RSAKeyPair, rsa::Error> {
let mut rng = rand::thread_rng();
let private_key = RsaPrivateKey::new(&mut rng, RSA_BITS).expect("failed to generate a key");
let public_key = RsaPublicKey::from(&private_key);
let private = private_key.to_pkcs8_pem(LineEnding::LF)?;
let public = public_key.to_public_key_pem(LineEnding::LF).unwrap(); // This is infaillible?
Ok(RSAKeyPair { private, public })
}
async fn save_pem_files(
&self,
path: impl AsRef<Path>,
private_name: &str,
public_name: &str,
) -> io::Result<()> {
// There's probably a better looking way to do this
let futures = vec![
fs::write(path.as_ref().join(public_name), &self.public),
fs::write(path.as_ref().join(private_name), &self.private),
];
let stream = futures::stream::iter(futures).buffer_unordered(2);
let results = stream.collect::<Vec<_>>().await;
for res in results {
res?;
}
Ok(())
}
async fn verify_jwt<T>(&self, token: &str) -> Result<T, jsonwebtoken::errors::Error>
where
for<'de> T: Deserialize<'de>,
{
match decode::<T>(
token,
&DecodingKey::from_rsa_pem(self.public.as_bytes())?,
&Validation::new(Algorithm::RS256),
) {
Ok(t) => Ok(t.claims),
Err(e) => Err(e),
}
}
async fn encode<T>(&self, claim: &T) -> Result<String, jsonwebtoken::errors::Error>
where
T: Serialize,
{
let secret = &EncodingKey::from_rsa_pem(self.private.as_bytes())?;
encode(&Header::new(Algorithm::RS256), claim, secret)
}
}
#[cfg(test)]
mod test {
use time::Duration;
use cookie::time::OffsetDateTime;
use rand::RngCore;
use serde::{Deserialize, Serialize};
use crate::crypto::RSAKeyPair;
use crate::crypto::ValidationOutcome;
use super::jwt_numeric_date;
use super::JWTSecretManager;
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
struct Claim {
data: Vec<u8>,
#[serde(with = "jwt_numeric_date")]
exp: OffsetDateTime,
}
fn create_random_claim_data() -> Vec<u8> {
let mut bytes = [0u8; 16];
rand::thread_rng().fill_bytes(&mut bytes);
bytes.into()
}
#[tokio::test]
async fn generate_encode_and_decode_jwt_using_rsa() {
let rsa_keys = RSAKeyPair::generate_new()
.await
.expect("It should be possible to generate a new RSA key pair");
let claim = Claim {
data: create_random_claim_data(),
exp: OffsetDateTime::now_utc() + Duration::days(1),
};
let token = rsa_keys
.encode(&claim)
.await
.expect("It should be possible to encode a claim");
let decoded_claim = rsa_keys
.verify_jwt::<Claim>(&token)
.await
.expect("It should be possible to verify a token");
assert_eq!(decoded_claim.data, claim.data);
}
#[tokio::test]
async fn jwt_secret_manager_can_encode_and_decode() {
let jwt_secret_manager = JWTSecretManager::init()
.await
.expect("JWTSecretManager should be able to init");
let claim = Claim {
data: create_random_claim_data(),
exp: OffsetDateTime::now_utc() + Duration::days(1),
};
let token = jwt_secret_manager
.encode_new(&claim)
.await
.expect("It should be possible to encode a claim into a JWT");
match jwt_secret_manager.decode::<Claim>(&token).await {
ValidationOutcome::Ok(c) => assert_eq!(c.data, claim.data),
_ => panic!(),
}
}
#[tokio::test]
async fn jwt_secret_manager_can_rotate() {
let mut jwt_secret_manager = JWTSecretManager::init()
.await
.expect("JWTSecretManager should be able to init");
jwt_secret_manager
.rotate()
.await
.expect("JWTSecretManager should be able to rotate rsa keys");
}
}

@ -1,5 +1,6 @@
use sea_orm::{
error::DbErr, prelude::*, sea_query::SimpleExpr, FromQueryResult, Order, QueryOrder,
error::DbErr, prelude::*, sea_query::SimpleExpr, FromQueryResult, ItemsAndPagesNumber, Order,
QueryOrder,
};
use serde::{Deserialize, Serialize};
@ -31,9 +32,10 @@ where
}
let pages = selector.into_model().paginate(db, s);
let count = pages.num_items().await?;
let num_pages = pages.num_pages().await?;
let ItemsAndPagesNumber {
number_of_items: count,
number_of_pages: num_pages,
} = pages.num_items_and_pages().await?;
let p = page.unwrap_or(0).min(num_pages);
@ -48,6 +50,7 @@ where
Ok(res)
}
/// Use for 1-1 relationships, retrieves the entity and the single related record
pub async fn paginate_also_related<E, R, T, K, C>(
db: &DatabaseConnection,
page: Option<u64>,
@ -81,9 +84,10 @@ where
}
let pages = selector.into_model().paginate(db, s);
let count = pages.num_items().await?;
let num_pages = pages.num_pages().await?;
let ItemsAndPagesNumber {
number_of_items: count,
number_of_pages: num_pages,
} = pages.num_items_and_pages().await?;
let p = page.unwrap_or(0).min(num_pages);

@ -8,6 +8,9 @@ pub struct Env {
pub host: String,
#[serde(default = "port_default")]
pub port: String,
#[serde(default = "client_url_default")]
pub client_url: String,
pub api_url: String,
pub mysql_user: String,
pub mysql_password: String,
pub mysql_host: String,
@ -45,6 +48,10 @@ fn port_default() -> String {
"8000".to_string()
}
fn client_url_default() -> String {
"http://localhost:8080".to_string()
}
fn mysql_port_default() -> String {
"3306".to_string()
}
@ -107,6 +114,7 @@ impl Env {
#[derive(Debug)]
pub struct Config {
pub server_address: SocketAddr,
pub client_url: String,
pub database_url: String,
pub max_connections: u32,
pub min_connections: u32,
@ -134,6 +142,7 @@ impl Config {
let mut config = Config {
server_address,
client_url: env.client_url,
database_url,
max_connections: env.max_connections,
min_connections: env.min_connections,

@ -12,6 +12,9 @@ pub enum AppError {
DbErr(DbErr),
InProcessTransaction(InProcessTransactionError),
NotFound(String),
InternalServerError(String),
Unauthorized,
Conflict(String),
}
impl From<DbErr> for AppError {
@ -37,7 +40,10 @@ impl IntoResponse for AppError {
StatusCode::INTERNAL_SERVER_ERROR,
format!("Error in the in process transaction repo: {}", e),
),
AppError::NotFound(e) => (StatusCode::NOT_FOUND, format!("Not found error: {}", e)),
AppError::NotFound(e) => (StatusCode::NOT_FOUND, format!("{e}")),
AppError::InternalServerError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e),
AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()),
AppError::Conflict(e) => (StatusCode::CONFLICT, e),
};
let body = Json(json!({

@ -6,17 +6,26 @@ extern crate lazy_static;
#[macro_use]
extern crate log;
extern crate argon2;
// External crates
use axum::{
extract::MatchedPath,
extract::{MatchedPath, State},
headers::{authorization::Bearer, Authorization},
http::{HeaderValue, Method, Request, StatusCode},
middleware::{from_fn_with_state, Next},
response::Response,
routing::get,
Router,
routing::{get, post},
Router, TypedHeader,
};
use crypto::ValidationOutcome;
use http::header::{ACCESS_CONTROL_EXPOSE_HEADERS, AUTHORIZATION, CONTENT_TYPE};
use sea_orm::DatabaseConnection;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use std::{sync::Arc, time::Duration};
use time::OffsetDateTime;
use tokio::signal;
use tower::ServiceBuilder;
use tower_http::{classify::ServerErrorsFailureClass, cors::CorsLayer, trace::TraceLayer};
use tracing::{info, info_span, Span};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@ -26,6 +35,7 @@ use migration::MigratorTrait;
use route::{company, in_process_transaction, transaction};
mod amf;
mod crypto;
mod db;
mod env;
mod error;
@ -33,7 +43,17 @@ mod model;
mod repo;
mod route;
mod task;
use crate::task::run_tasks;
use crate::{
crypto::{jwt_numeric_date, JWTSecretManager},
route::{
authenticated::{
follow_company_route, get_followed_companies, get_profile,
get_user_followed_companies_transactions, is_authenticated, unfollow_company_route,
},
user,
},
task::run_tasks,
};
// Module imports
use env::Config;
@ -50,6 +70,66 @@ async fn fallback() -> (StatusCode, &'static str) {
#[derive(Clone)]
pub struct AppState {
pub db: DatabaseConnection,
pub jwt_secret_manager: Arc<JWTSecretManager>,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct UserJWTClaim {
pub user_id: i32,
pub username: String,
#[serde(with = "jwt_numeric_date")]
pub exp: OffsetDateTime,
}
async fn authenticator<B>(
TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
State(state): State<AppState>,
mut request: Request<B>,
next: Next<B>,
) -> Result<Response, StatusCode> {
let mut new_token_opt = None;
let authenticated = match state
.jwt_secret_manager
.decode::<UserJWTClaim>(auth.token())
.await
{
ValidationOutcome::Ok(token_data) => {
request.extensions_mut().insert(token_data);
true
}
ValidationOutcome::Unauthorized => false,
ValidationOutcome::Outdated(token_data, t) => {
request.extensions_mut().insert(token_data);
new_token_opt = Some(t);
true
}
ValidationOutcome::Error(e) => {
error!("Error in authentication layer: {}", e);
false
}
};
let mut response = next.run(request).await;
if authenticated {
if let Some(token) = new_token_opt {
// If for some reason the token cannot be put in a header, we just skip sending it
if let Ok(header_value) = HeaderValue::from_str(&token) {
let headers = response.headers_mut();
headers.insert("x-new-token", header_value);
headers.insert(
ACCESS_CONTROL_EXPOSE_HEADERS,
HeaderValue::from_str("x-new-token").unwrap(),
);
} else {
warn!("Failed to put an updated token in a response header");
}
}
Ok(response)
} else {
Err(StatusCode::UNAUTHORIZED)
}
}
#[tokio::main]
@ -62,12 +142,68 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
.with(tracing_subscriber::fmt::layer())
.init();
info!("Initializing JWT secret manager");
let jwt_secret_manager = match JWTSecretManager::init().await {
Ok(j) => {
info!("JWTSecretManager initialized");
j
}
Err(e) => {
error!("Error initializing the the JWT secret manager: {}", e);
Err(e)?
}
};
let shared_state = AppState {
db: db::init().await?,
jwt_secret_manager: Arc::new(jwt_secret_manager),
};
let _ = migration::Migrator::up(&shared_state.db, None).await;
let trace_layer = TraceLayer::new_for_http()
.make_span_with(|request: &Request<_>| {
let matched_path = request
.extensions()
.get::<MatchedPath>()
.map(MatchedPath::as_str);
info_span!(
"http_request",
method = ?request.method(),
full_path = ?request.uri(),
matched_path,
)
})
.on_request(|_request: &Request<_>, _span: &Span| {
info!("New request");
})
.on_response(|response: &Response, latency: Duration, _span: &Span| {
info!(
"Response, status {}, time {}ms",
response.status(),
latency.as_millis()
);
})
.on_failure(
|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| {
error!("There was an error answering this request, the server nonetheless answered in {}ms, error: {}", latency.as_millis(), error);
},
);
let authenticated_routes = Router::<AppState, _>::new()
.route("/is_authenticated", get(is_authenticated))
.route("/profile", get(get_profile))
.route("/get_followed_companies", get(get_followed_companies))
.route(
"/user_transactions",
get(get_user_followed_companies_transactions),
)
.route("/follow_company", post(follow_company_route))
.route("/unfollow_company", post(unfollow_company_route))
.layer(from_fn_with_state(shared_state.clone(), authenticator))
.with_state(shared_state.clone());
let app = Router::new()
.route("/company", get(company::get_all))
.route("/company/:name", get(company::get_by_name))
@ -96,44 +232,19 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
"/in_process_transaction/retry_all",
get(in_process_transaction::retry_all),
)
.with_state(shared_state.clone())
.route("/user/login", post(user::login))
.route("/user/register", post(user::register))
.nest("/auth", authenticated_routes)
.fallback(fallback)
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<_>| {
let matched_path = request
.extensions()
.get::<MatchedPath>()
.map(MatchedPath::as_str);
info_span!(
"http_request",
method = ?request.method(),
full_path = ?request.uri(),
matched_path,
)
})
.on_request(|_request: &Request<_>, _span: &Span| {
info!("New request");
})
.on_response(|response: &Response, latency: Duration, _span: &Span| {
info!(
"Response, status {}, time {}ms",
response.status(),
latency.as_millis()
);
})
.on_failure(
|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| {
error!("There was an error answering this request, the server nonetheless answered in {}ms, error: {}", latency.as_millis(), error);
},
),
ServiceBuilder::new().layer(trace_layer).layer(
CorsLayer::new()
.allow_origin("*".parse::<HeaderValue>().unwrap())
.allow_methods([Method::GET, Method::POST])
.allow_headers([CONTENT_TYPE, AUTHORIZATION]),
),
)
.layer(
CorsLayer::new()
.allow_origin("*".parse::<HeaderValue>().unwrap())
.allow_methods([Method::GET])
);
.with_state(shared_state.clone());
// Run tasks
tokio::task::spawn(async move { run_tasks(&shared_state.db).await });

@ -1,4 +1,4 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.11.0
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.6
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
@ -18,6 +18,8 @@ pub struct Model {
pub enum Relation {
#[sea_orm(has_many = "super::transaction::Entity")]
Transaction,
#[sea_orm(has_many = "super::user_company::Entity")]
UserCompany,
}
impl Related<super::transaction::Entity> for Entity {
@ -26,4 +28,19 @@ impl Related<super::transaction::Entity> for Entity {
}
}
impl Related<super::user_company::Entity> for Entity {
fn to() -> RelationDef {
Relation::UserCompany.def()
}
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
super::user_company::Relation::User.def()
}
fn via() -> Option<RelationDef> {
Some(super::user_company::Relation::Company.def().rev())
}
}
impl ActiveModelBehavior for ActiveModel {}

@ -1,4 +1,4 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.11.0
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.6
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};

@ -1,7 +1,9 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.11.0
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.6
pub mod prelude;
pub mod company;
pub mod in_process_transaction;
pub mod transaction;
pub mod user;
pub mod user_company;

@ -1,5 +1,7 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.11.0
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.6
pub use super::company::Entity as Company;
pub use super::in_process_transaction::Entity as InProcessTransaction;
pub use super::transaction::Entity as Transaction;
pub use super::user::Entity as User;
pub use super::user_company::Entity as UserCompany;

@ -1,4 +1,4 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.11.0
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.6
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};

@ -0,0 +1,39 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.6
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "user")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
#[sea_orm(unique)]
pub email: Option<String>,
#[sea_orm(unique)]
pub name: String,
pub password: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::user_company::Entity")]
UserCompany,
}
impl Related<super::user_company::Entity> for Entity {
fn to() -> RelationDef {
Relation::UserCompany.def()
}
}
impl Related<super::company::Entity> for Entity {
fn to() -> RelationDef {
super::user_company::Relation::Company.def()
}
fn via() -> Option<RelationDef> {
Some(super::user_company::Relation::User.def().rev())
}
}
impl ActiveModelBehavior for ActiveModel {}

@ -0,0 +1,47 @@
//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.6
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "user_company")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub user_id: i32,
#[sea_orm(primary_key, auto_increment = false)]
pub company_id: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::company::Entity",
from = "Column::CompanyId",
to = "super::company::Column::Id",
on_update = "Cascade",
on_delete = "Cascade"
)]
Company,
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id",
on_update = "Cascade",
on_delete = "Cascade"
)]
User,
}
impl Related<super::company::Entity> for Entity {
fn to() -> RelationDef {
Relation::Company.def()
}
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

@ -1,3 +1,4 @@
pub mod company;
pub mod in_process_transaction;
pub mod transaction;
pub mod user;

@ -0,0 +1,51 @@
use crate::model::user::ActiveModel;
use sea_orm::{
ActiveModelTrait, ActiveValue, ConnectionTrait, DbErr, DeriveIntoActiveModel, EntityTrait,
IntoActiveModel,
};
use serde::{Deserialize, Serialize};
use crate::model;
#[derive(Debug, PartialEq, Clone, DeriveIntoActiveModel, Serialize, Deserialize)]
pub struct NewUser {
pub email: Option<String>,
pub name: String,
pub password: String,
}
impl NewUser {
pub async fn create<C>(&self, db: &C) -> Result<model::user::Model, DbErr>
where
C: ConnectionTrait,
{
let res = self.clone().into_active_model().insert(db).await?;
Ok(res)
}
}
pub async fn follow_company<C>(db: &C, user_id: i32, company_id: i32) -> Result<(), DbErr>
where
C: ConnectionTrait,
{
let relation = model::user_company::ActiveModel {
user_id: ActiveValue::Set(user_id),
company_id: ActiveValue::Set(company_id),
};
relation.insert(db).await?;
Ok(())
}
pub async fn unfollow_company<C>(db: &C, user_id: i32, company_id: i32) -> Result<(), DbErr>
where
C: ConnectionTrait,
{
model::user_company::Entity::delete_by_id((user_id, company_id))
.exec(db)
.await?;
Ok(())
}

@ -0,0 +1,200 @@
use axum::{
extract::{Query, State},
Extension, Json,
};
use chrono::NaiveDate;
use sea_orm::{
sea_query::Expr, ColumnTrait, EntityTrait, FromQueryResult, ItemsAndPagesNumber, JoinType,
ModelTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, RelationTrait,
};
use serde::{Deserialize, Serialize};
use crate::{
db::paginate::PaginatedResponse,
error::AppError,
model::{
self,
prelude::{Company, Transaction, User},
},
repo::user::{follow_company, unfollow_company},
AppState, UserJWTClaim,
};
use super::Pagination;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserProfile {
email: Option<String>,
name: String,
}
impl From<model::user::Model> for UserProfile {
fn from(value: model::user::Model) -> Self {
UserProfile {
email: value.email,
name: value.name,
}
}
}
pub async fn is_authenticated() -> Result<(), AppError> {
Ok(())
}
pub async fn get_profile(
Extension(token_data): Extension<UserJWTClaim>,
State(state): State<AppState>,
) -> Result<Json<UserProfile>, AppError> {
let db = &state.db;
let user_opt = model::user::Entity::find()
.filter(model::user::Column::Id.eq(token_data.user_id))
.one(db)
.await?;
if let Some(user) = user_opt {
return Ok(Json(UserProfile {
email: user.email,
name: user.name,
}));
}
Err(AppError::NotFound(
"Authenticated user does not exist".to_string(),
))
}
#[derive(Deserialize)]
pub struct FollowCompany {
company_id: i32,
}
pub async fn follow_company_route(
Extension(token_data): Extension<UserJWTClaim>,
State(state): State<AppState>,
Json(payload): Json<FollowCompany>,
) -> Result<(), AppError> {
let db = &state.db;
follow_company(db, token_data.user_id, payload.company_id).await?;
Ok(())
}
pub async fn unfollow_company_route(
Extension(token_data): Extension<UserJWTClaim>,
State(state): State<AppState>,
Json(payload): Json<FollowCompany>,
) -> Result<(), AppError> {
let db = &state.db;
unfollow_company(db, token_data.user_id, payload.company_id).await?;
Ok(())
}
pub async fn get_followed_companies(
Extension(token_data): Extension<UserJWTClaim>,
State(state): State<AppState>,
Query(Pagination { page, size }): Query<Pagination>,
) -> Result<Json<PaginatedResponse<model::company::Model>>, AppError> {
let db = &state.db;
let s = size.unwrap_or(10);
let p = page.unwrap_or_default();
let user = User::find_by_id(token_data.user_id)
.one(db)
.await?
.ok_or(AppError::NotFound("User not found".to_string()))?;
// Eventually switch these 2 by using cloumn_as to make a count column
let (count_res, companies) = tokio::join!(user.find_related(Company).count(db), async {
user.find_related(Company)
.order_by(model::company::Column::Name, sea_orm::Order::Asc)
.offset(p * s)
.limit(s)
.all(db)
.await
});
let count = count_res?;
let res = PaginatedResponse {
count,
num_pages: count / s + 1,
list: companies?,
};
Ok(Json(res))
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, FromQueryResult)]
pub struct UserTransaction {
pub company_name: String,
pub company_slug: String,
pub date_published: NaiveDate,
pub date_executed: NaiveDate,
pub person: String,
pub exchange: String,
pub nature: String,
pub isin: Option<String>,
pub instrument: String,
pub volume: i32,
pub unit_price: f32,
pub total: f32,
}
pub async fn get_user_followed_companies_transactions(
Extension(token_data): Extension<UserJWTClaim>,
State(state): State<AppState>,
Query(Pagination { page, size }): Query<Pagination>,
) -> Result<Json<PaginatedResponse<UserTransaction>>, AppError> {
let db = &state.db;
let s = size.unwrap_or(20).min(50);
let query = Transaction::find()
.select_only()
.join(
JoinType::InnerJoin,
model::transaction::Relation::Company.def(),
)
.join(
JoinType::InnerJoin,
model::company::Relation::UserCompany.def(),
)
.column_as(model::company::Column::Name, "company_name")
.column_as(model::company::Column::Slug, "company_slug")
.column(model::transaction::Column::DatePublished)
.column(model::transaction::Column::DateExecuted)
.column(model::transaction::Column::Person)
.column(model::transaction::Column::Exchange)
.column(model::transaction::Column::Nature)
.column(model::transaction::Column::Isin)
.column(model::transaction::Column::Instrument)
.column(model::transaction::Column::Volume)
.column(model::transaction::Column::UnitPrice)
.column_as(
Expr::col(model::transaction::Column::UnitPrice)
.mul(Expr::col(model::transaction::Column::Volume)),
"total",
)
.filter(model::user_company::Column::UserId.eq(token_data.user_id))
.order_by_desc(model::transaction::Column::DatePublished)
.into_model::<UserTransaction>()
.paginate(db, s);
let ItemsAndPagesNumber {
number_of_pages: num_pages,
number_of_items: count,
} = query.num_items_and_pages().await?;
let p = page.unwrap_or(0).min(num_pages);
let list = query.fetch_page(p).await?;
let res = PaginatedResponse {
count,
num_pages,
list,
};
Ok(Json(res))
}

@ -2,9 +2,11 @@ use std::{fmt, str::FromStr};
use serde::{de, Deserialize, Deserializer};
pub mod authenticated;
pub mod company;
pub mod in_process_transaction;
pub mod transaction;
pub mod user;
/// Struct to deserialize paginated routes query parameters
#[derive(Deserialize)]

@ -0,0 +1,127 @@
use axum::{extract::State, Json};
use rand::RngCore;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use crate::UserJWTClaim;
use crate::{error::AppError, model, repo::user::NewUser, AppState};
#[derive(Deserialize)]
pub struct UserLoginBody {
pub name: String,
pub password: String,
}
#[derive(Serialize)]
pub struct LoginResponse {
pub token: String,
}
pub async fn login(
State(state): State<AppState>,
Json(payload): Json<UserLoginBody>,
) -> Result<Json<LoginResponse>, AppError> {
let db = &state.db;
let user_opt = model::user::Entity::find()
.filter(model::user::Column::Name.eq(payload.name))
.one(db)
.await?;
if user_opt.is_none() {
// To prevent timing attacks, we use the same verify function on a known password
argon2::verify_encoded("$argon2i$v=19$m=4096,t=3,p=1$CXr/AgSDawghR+GmOhM0wQ$4k2TCyoqkh/YaK9mh6uEa0eRZ/CIx3bfzJs5UnCcKjw", b"1234").unwrap();
return Err(AppError::NotFound(
"User does not exist. Consider registering.".to_string(),
));
}
let user = user_opt.unwrap();
let valid =
argon2::verify_encoded(&user.password, payload.password.as_bytes()).map_err(|e| {
error!("Error verifying the password for user {}: {}", user.name, e);
AppError::InternalServerError(
"There was an error verifying authentication.".to_string(),
)
})?;
if !valid {
return Err(AppError::Unauthorized);
}
// Generate a JWT and store it as a same site cookie
let claim = UserJWTClaim {
user_id: user.id,
username: user.name.clone(),
exp: OffsetDateTime::now_utc() + time::Duration::days(5),
};
let token = state
.jwt_secret_manager
.encode_new(&claim)
.await
.map_err(|e| {
error!("Failed to encode a new JWT for user {}: {}", user.name, e);
AppError::InternalServerError(
"There was an error while encoding the authorization token".to_string(),
)
})?;
Ok(Json(LoginResponse { token }))
}
#[derive(Deserialize)]
pub struct UserRegisterBody {
pub name: String,
pub email: String,
pub password: String,
}
pub async fn register(
State(state): State<AppState>,
Json(payload): Json<UserRegisterBody>,
) -> Result<(), AppError> {
let db = &state.db;
let mut filter = model::user::Column::Name.eq(&payload.name);
let mut email = Some(payload.email.to_string());
if !payload.email.is_empty() {
filter = filter.or(model::user::Column::Email.eq(&payload.email));
} else {
email = None;
}
let user_opt = model::user::Entity::find().filter(filter).one(db).await?;
if user_opt.is_some() {
return Err(AppError::Conflict(
"The username or email is already in use.".to_string(),
));
}
let salt = generate_salt();
let pass_hash =
argon2::hash_encoded(payload.password.as_ref(), &salt, &argon2::Config::default())
.map_err(|e| {
error!("Failed to hash a password: {}", e);
AppError::InternalServerError(
"There was an error in the registration process".to_string(),
)
})?;
let new_user = NewUser {
email,
name: payload.name,
password: pass_hash,
};
new_user.create(db).await?;
Ok(())
}
fn generate_salt() -> Vec<u8> {
let mut salt = [0u8; 16]; // 16 bytes salt length (adjust as needed)
rand::thread_rng().fill_bytes(&mut salt);
salt.into()
}
Loading…
Cancel
Save