use std::{borrow::Cow, sync::Arc}; use axum::{Router, http::StatusCode, routing::post}; use bank_core::{ApiError, Name, NameOrUuid, make_schemas, transaction::Transaction}; use deadpool_postgres::GenericClient; use garde::Validate; use schemars::JsonSchema; use serde::Deserialize; use tokio_postgres::Statement; use tracing::{error, instrument}; use uuid::Uuid; use crate::model::{Accounts, Transactions, Users}; use super::{ AppState, EState, Error, InteropState, Json, State, auth::Auth, socket::{SocketEvent, SocketMessage}, }; pub(super) fn router() -> Router> { Router::new().route("/", post(make_payment)) } make_schemas!((MakePayment); ()); #[derive(Debug, Validate, PartialEq)] #[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))] pub struct AccountSelector { #[garde(dive)] pub user: NameOrUuid, #[garde(dive)] pub account: Option, } impl AccountSelector { #[instrument(skip(client))] pub async fn account_id( &self, client: &impl GenericClient, ) -> Result, tokio_postgres::Error> { let user_id = match &self.user { NameOrUuid::Id(uuid) => *uuid, NameOrUuid::Name(name) => match Users::info_by_name(client, &*name).await? { Some(info) => info.id, None => return Ok(None), }, }; let account_id = match self.account.as_ref() { Some(name) => match Accounts::get_for_user(client, user_id, &*name).await? { Some(info) => info.id, None => return Ok(None), }, None => user_id, }; Ok(Some((user_id, account_id))) } } #[doc(hidden)] #[derive(Debug, Deserialize, Validate, PartialEq)] #[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))] #[serde(untagged)] pub enum UnvalidatedAccountSelector { Username(#[garde(dive)] Name), Object { #[garde(dive)] user: NameOrUuid, #[garde(dive)] account: Option, }, } impl<'de> Deserialize<'de> for AccountSelector { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { Ok( match UnvalidatedAccountSelector::deserialize(deserializer)? { UnvalidatedAccountSelector::Username(name) => Self { user: NameOrUuid::Name(name), account: None, }, UnvalidatedAccountSelector::Object { user, account } => Self { user, account }, }, ) } } pub trait ValidateTransform: Sized { type Src: Validate; fn schema_name() -> Cow<'static, str>; fn schema_id() -> Cow<'static, str> { Cow::Borrowed(std::any::type_name::()) } fn validate_transform_into( src: Self::Src, ctx: &::Context, parent: &mut dyn FnMut() -> garde::Path, report: &mut garde::Report, ) -> Option; } #[derive(Debug, PartialEq)] pub struct UnvalidatedTransform { src: Dest::Src, } impl<'de, Dest> Deserialize<'de> for UnvalidatedTransform where Dest: ValidateTransform, Dest::Src: Deserialize<'de>, { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { Ok(UnvalidatedTransform { src: >::deserialize(deserializer)?, }) } } #[derive(Debug, PartialEq)] pub enum AccountTarget { Selector(AccountSelector), Interop(String), } #[derive(Debug, PartialEq, Validate)] #[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemas", serde(untagged))] pub enum UnvalidatedAccountTarget { /// Interop user Interop(#[garde(pattern("^.{2,}@[a-z0-9]{2,4}$"))] String), Selector(#[garde(dive)] UnvalidatedAccountSelector), } impl<'de> Deserialize<'de> for UnvalidatedAccountTarget { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { #[derive(Deserialize)] #[serde(untagged)] enum AccountTargetHelper { Text(String), Selector(UnvalidatedAccountSelector), } Ok(match AccountTargetHelper::deserialize(deserializer)? { AccountTargetHelper::Text(text) => { // TODO: don't hardcode prefix if let Some(text) = text.strip_suffix("@thc") { Self::Interop(text.into()) } else { Self::Selector(UnvalidatedAccountSelector::Username(Name(text))) } } AccountTargetHelper::Selector(selector) => Self::Selector(selector), }) } } impl ValidateTransform for AccountTarget { type Src = UnvalidatedAccountTarget; fn schema_name() -> Cow<'static, str> { Cow::Borrowed("AccountTarget") } fn validate_transform_into( src: Self::Src, ctx: &::Context, parent: &mut dyn FnMut() -> garde::Path, report: &mut garde::Report, ) -> Option { match src { UnvalidatedAccountTarget::Interop(user) => Some(AccountTarget::Interop(user)), UnvalidatedAccountTarget::Selector(selector) => { AccountSelector::validate_transform_into(selector, ctx, parent, report) .map(Self::Selector) } } } } impl schemars::JsonSchema for UnvalidatedTransform where Dest: ValidateTransform, Dest::Src: JsonSchema, { fn schema_name() -> Cow<'static, str> { Dest::schema_name() } fn schema_id() -> Cow<'static, str> { Dest::schema_id() } fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema { ::json_schema(generator) } fn always_inline_schema() -> bool { ::always_inline_schema() } } impl UnvalidatedTransform { pub fn validate_transform_into( self, ctx: &::Context, parent: &mut dyn FnMut() -> garde::Path, report: &mut garde::Report, ) -> Option { Dest::validate_transform_into(self.src, ctx, parent, report) } pub fn validate_with( self, ctx: &::Context, ) -> Result { let mut report = garde::Report::new(); let result = Dest::validate_transform_into(self.src, ctx, &mut garde::Path::empty, &mut report); match report.is_empty() { true => Ok(result.unwrap()), false => Err(report), } } } impl ValidateTransform for AccountSelector { type Src = UnvalidatedAccountSelector; fn schema_name() -> Cow<'static, str> { Cow::Borrowed("AccountSelector") } fn validate_transform_into( src: Self::Src, ctx: &::Context, mut parent: &mut dyn FnMut() -> garde::Path, report: &mut garde::Report, ) -> Option { let count = report.iter().count(); match src { UnvalidatedAccountSelector::Username(name) => { name.validate_into(ctx, parent, report); if count != report.iter().count() { return None; } Some(AccountSelector { user: NameOrUuid::Name(name), account: None, }) } UnvalidatedAccountSelector::Object { user, account } => { { let mut path = garde::util::nested_path!(parent, "user"); user.validate_into(ctx, &mut path, report); } let mut path = garde::util::nested_path!(parent, "account"); account.validate_into(ctx, &mut path, report); if count != report.iter().count() { return None; } Some(AccountSelector { user, account }) } } } } #[derive(Debug, Deserialize, PartialEq)] #[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))] pub struct MakePayment { from: NameOrUuid, to: UnvalidatedTransform, amount: u64, } impl MakePayment { pub fn validate(self) -> Result<(NameOrUuid, AccountTarget, u64), garde::Report> { let mut report = garde::Report::new(); self.from .validate_into(&(), &mut || garde::Path::new("from"), &mut report); let to = self .to .validate_transform_into(&(), &mut || garde::Path::new("to"), &mut report); if let Err(error) = garde::rules::range::apply(&self.amount, (Some(1), None)) { report.append(garde::Path::new("amount"), error); } if !report.is_empty() { return Err(report); } Ok((self.from, to.unwrap(), self.amount)) } } const TARGET_NOT_FOUND: ApiError<'static> = ApiError::const_new( StatusCode::NOT_FOUND, "transaction.target.not_found", "Not Found", ); pub async fn make_payment( EState(state): State, auth: Auth, Json(body): Json, ) -> Result, Error> { let (from, to, amount) = body.validate()?; let user_id = auth.user_id(); let mut client = state.conn().await?; let mut client = client.transaction().await?; let Some(from) = (match from { NameOrUuid::Id(uuid) => Accounts::by_id(&client, uuid).await?, NameOrUuid::Name(name) => Accounts::get_for_user(&client, user_id, &*name).await?, }) else { return Err(ApiError::const_new( StatusCode::NOT_FOUND, "transaction.from.not_found", "Not Found", ) .into()); }; let transaction = match to { AccountTarget::Interop(user) => { let Some(interop) = &state.interop else { return Err(TARGET_NOT_FOUND.into()); }; let builder = TransactionBuilder::new(&mut client, amount).await?; let transaction = builder.interop_pay(from.id, user.clone(), None).await?; if amount % 100 != 0 { todo!() } if let Err(err) = send_interop_payment(&interop, &from.name, &user, (amount / 100) as u32).await { return Err(match err { InteropError::NotFound => TARGET_NOT_FOUND.into(), InteropError::Http(err) => { error!("{err}"); ApiError::INTERNAL_SERVER_ERROR.into() } InteropError::Other(response) => { error!("{response:?}"); ApiError::INTERNAL_SERVER_ERROR.into() } }); } transaction } AccountTarget::Selector(selector) => { let Some((to_user, to)) = selector.account_id(&client).await? else { return Err(TARGET_NOT_FOUND.into()); }; if from.balance < amount { return Err(ApiError::const_new( StatusCode::BAD_REQUEST, "transaction.insufficient_funds", "Insufficient funds", ) .into()); } let builder = TransactionBuilder::new(&mut client, amount).await?; let (transaction, notification) = builder.normal(from.id, to, None).await?; client.commit().await?; state.sockets.send(to_user, notification).await; transaction } }; Ok(Json(transaction)) } pub struct TransactionBuilder<'a, T> { client: &'a mut T, update_balance: Statement, amount: u64, } impl<'a, T: GenericClient> TransactionBuilder<'a, T> { pub async fn new(client: &'a mut T, amount: u64) -> Result { let update_balance = client .prepare_cached("update accounts set balance = balance + $2 where id = $1") .await?; Ok(Self { client, update_balance, amount, }) } async fn update_balance( &mut self, account: Uuid, is_to: bool, ) -> Result<(), tokio_postgres::Error> { let amount = if is_to { self.amount as i64 } else { -(self.amount as i64) }; self.client .execute(&self.update_balance, &[&account, &amount]) .await?; Ok(()) } pub async fn system( mut self, to: Uuid, message: Option, ) -> Result<(Transaction, SocketMessage), tokio_postgres::Error> { self.update_balance(to, true).await?; let transaction = Transactions::create(self.client, None, Some(to), None, self.amount, message).await?; Ok(( transaction, SocketMessage::Event(SocketEvent::PaymentReceived { from: None, to, amount: self.amount, }), )) } pub async fn normal( mut self, from: Uuid, to: Uuid, message: Option, ) -> Result<(Transaction, SocketMessage), tokio_postgres::Error> { self.update_balance(from, false).await?; self.update_balance(to, true).await?; let transaction = Transactions::create( self.client, Some(from), Some(to), None, self.amount, message, ) .await?; Ok(( transaction, SocketMessage::Event(SocketEvent::PaymentReceived { from: Some(NameOrUuid::Id(from)), to, amount: self.amount, }), )) } pub async fn interop_pay( mut self, from: Uuid, name: String, message: Option, ) -> Result { self.update_balance(from, false).await?; let transaction = Transactions::create( self.client, Some(from), None, Some(name), self.amount, message, ) .await?; Ok(transaction) } pub async fn interop_receive( mut self, to: Uuid, name: String, message: Option, ) -> Result<(Transaction, SocketMessage), tokio_postgres::Error> { self.update_balance(to, true).await?; let transaction = Transactions::create( self.client, None, Some(to), Some(name.clone()), self.amount, message, ) .await?; Ok(( transaction, SocketMessage::Event(SocketEvent::PaymentReceived { from: Some(NameOrUuid::Name(Name(name))), to, amount: self.amount, }), )) } } enum InteropError { NotFound, Http(reqwest::Error), Other(reqwest::Response), } async fn send_interop_payment( interop: &InteropState, from: &str, to: &str, amount: u32, ) -> Result<(), InteropError> { let response = interop .client .post(interop.pay_url.clone()) .json(&serde_json::json!({ "from": from, "to": to, "amount": amount })) .send() .await .map_err(InteropError::Http)?; if response.status() == StatusCode::OK { return Ok(()); } if response.status() == StatusCode::NOT_FOUND { return Err(InteropError::NotFound); } Err(InteropError::Other(response)) } #[cfg(test)] mod tests { use bank_core::Name; use crate::api::transactions::{ AccountSelector, AccountTarget, NameOrUuid, UnvalidatedAccountSelector, UnvalidatedAccountTarget, UnvalidatedTransform, }; use super::MakePayment; #[test] fn payment_body() { let uuid = uuid::uuid!("6fd8b7ab-7278-45b5-9916-2b09b4224a38"); let payment = serde_json::from_str::( r#"{"from":"personal", "to": { "user": "6fd8b7ab-7278-45b5-9916-2b09b4224a38" }, "amount": 100}"#, ) .unwrap(); assert_eq!( MakePayment { from: NameOrUuid::Name(Name("personal".into())), to: UnvalidatedTransform { src: UnvalidatedAccountTarget::Selector(UnvalidatedAccountSelector::Object { user: NameOrUuid::Id(uuid), account: None }) }, amount: 100 }, payment ); assert_eq!( payment.validate().unwrap(), ( NameOrUuid::Name(Name("personal".into())), AccountTarget::Selector(AccountSelector { user: NameOrUuid::Id(uuid), account: None }), 100 ) ); let payment = serde_json::from_str::( r#"{"from":"personal", "to": { "user": "test", "account": "abc" }, "amount": 100}"#, ) .unwrap(); assert_eq!( MakePayment { from: NameOrUuid::Name(Name("personal".into())), to: UnvalidatedTransform { src: UnvalidatedAccountTarget::Selector(UnvalidatedAccountSelector::Object { user: NameOrUuid::Name(Name("test".into())), account: Some(Name("abc".into())) }) }, amount: 100 }, payment ); assert_eq!( payment.validate().unwrap(), ( NameOrUuid::Name(Name("personal".into())), AccountTarget::Selector(AccountSelector { user: NameOrUuid::Name(Name("test".into())), account: Some(Name("abc".into())) }), 100 ) ); } }