From 0e90a4e2e42574c3a54ecf3056b8ca8ce40c0ec8 Mon Sep 17 00:00:00 2001 From: DSeeLP <46624152+DSeeLP@users.noreply.github.com> Date: Sun, 16 Mar 2025 14:21:27 +0100 Subject: [PATCH] add notification socket --- Cargo.lock | 151 +++++++++++++++++++++++++++++++ Cargo.toml | 4 +- openapi-def.yaml | 13 +++ src/api/account.rs | 2 +- src/api/chats.rs | 16 +++- src/api/mod.rs | 7 +- src/api/socket.rs | 194 ++++++++++++++++++++++++++++++++++++++++ src/api/transactions.rs | 21 ++++- src/api/user.rs | 2 +- src/main.rs | 3 +- src/model/chats.rs | 26 ++++-- 11 files changed, 424 insertions(+), 15 deletions(-) create mode 100644 src/api/socket.rs diff --git a/Cargo.lock b/Cargo.lock index ab5c5e8..c1bc11c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,19 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom 0.2.15", + "once_cell", + "version_check", + "zerocopy 0.7.35", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -41,6 +54,12 @@ dependencies = [ "libc", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "argon2" version = "0.5.3" @@ -53,6 +72,18 @@ dependencies = [ "password-hash", ] +[[package]] +name = "async-broadcast" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "435a87a52755b8f27fcf321ac4f04b2802e337c8c4872923137471ec39c37532" +dependencies = [ + "event-listener", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + [[package]] name = "async-trait" version = "0.1.86" @@ -77,6 +108,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" dependencies = [ "axum-core", + "base64", "bytes", "form_urlencoded", "futures-util", @@ -96,8 +128,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -143,8 +177,10 @@ dependencies = [ name = "bankserver" version = "0.1.0" dependencies = [ + "async-broadcast", "axum", "chrono", + "concread", "config", "dbmigrator", "deadpool", @@ -274,6 +310,30 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "concread" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a06c26e76cd1d7a88a44324d0cf18b11589be552e97af09bee345f7e7334c6d" +dependencies = [ + "ahash", + "arc-swap", + "crossbeam-utils", + "smallvec", + "sptr", + "tokio", + "tracing", +] + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "config" version = "0.15.8" @@ -300,6 +360,12 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crypto-common" version = "0.1.6" @@ -345,6 +411,12 @@ dependencies = [ "syn", ] +[[package]] +name = "data-encoding" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010" + [[package]] name = "dbmigrator" version = "0.4.4-alpha" @@ -452,6 +524,27 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "event-listener" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3e4e0dd3673c1139bf041f3008816d9cf2946bbfac2945c09e523b8d7b05b2" +dependencies = [ + "event-listener", + "pin-project-lite", +] + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -929,6 +1022,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot" version = "0.12.3" @@ -1366,6 +1465,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -1435,6 +1545,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "sptr" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b9b39299b249ad65f3b7e96443bad61c02ca5cd3589f46cb6d610a0fd6c0d6a" + [[package]] name = "static_assertions" version = "1.1.0" @@ -1612,6 +1728,18 @@ dependencies = [ "whoami", ] +[[package]] +name = "tokio-tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.13" @@ -1725,6 +1853,23 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typenum" version = "1.18.0" @@ -1764,6 +1909,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "uuid" version = "1.15.1" diff --git a/Cargo.toml b/Cargo.toml index e2161c6..e909369 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,8 +12,10 @@ name = "generate-schemas" features = ["schemas"] [dependencies] -axum = "0.8" +async-broadcast = "0.7.2" +axum = { version = "0.8", features = ["ws"] } chrono = { version = "0.4.40", features = ["serde"] } +concread = { version = "0.5.4", default-features = false, features = ["ahash", "asynch", "maps"] } config = { version = "0.15.8", default-features = false } dbmigrator = { git = "https://github.com/DSeeLP/dbmigrator.git", branch = "macros", version = "0.4.4-alpha", features = ["tokio-postgres"] } deadpool = "0.12" diff --git a/openapi-def.yaml b/openapi-def.yaml index 12b52ff..d276a1f 100644 --- a/openapi-def.yaml +++ b/openapi-def.yaml @@ -405,6 +405,19 @@ paths: $ref: '#/components/responses/InvalidBody' default: $ref: '#/components/responses/Default' + /api/socket: + get: + operationId: websocket-events + summary: Open websocket to receive events + security: + - bearer: [] + responses: + 101: + description: Switching protocols + 401: + $ref: '#/components/responses/Unauthorized' + default: + $ref: '#/components/responses/Default' components: parameters: Direction: diff --git a/src/api/account.rs b/src/api/account.rs index c959ac8..4e5df65 100644 --- a/src/api/account.rs +++ b/src/api/account.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::{Router, extract::Path, http::StatusCode, routing::get}; +use axum::{Router, extract::Path, routing::get}; use tracing::instrument; use uuid::Uuid; diff --git a/src/api/chats.rs b/src/api/chats.rs index 8346bdf..11e2ccd 100644 --- a/src/api/chats.rs +++ b/src/api/chats.rs @@ -12,7 +12,10 @@ use crate::{ }; use super::{ - AppState, EState, Error, Json, Pagination, RequestPagination, State, auth::Auth, make_schemas, + AppState, EState, Error, Json, Pagination, RequestPagination, State, + auth::Auth, + make_schemas, + socket::{SocketEvent, SocketMessage}, transactions::NameOrUuid, }; @@ -108,5 +111,16 @@ pub async fn send_message( let mut client = state.conn().await?; check_chat(&client, id, auth.user_id()).await?; let message = Chat::send(&mut client, id, auth.user_id(), body.text, body.extra).await?; + let notfication = SocketMessage::Event(SocketEvent::MessageReceived { + chat: id, + from: auth.user_id(), + message_id: message.id, + }); + for id in Chat::member_ids(&client, id).await? { + if id == auth.user_id() { + continue; + } + state.sockets.send(id, notfication.clone()).await; + } Ok(Json(message)) } diff --git a/src/api/mod.rs b/src/api/mod.rs index c6b8199..8f930e4 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -14,10 +14,13 @@ use tracing_error::SpanTrace; pub use axum::extract::State as EState; +pub use socket::Sockets; + mod account; mod auth; mod chats; mod docs; +mod socket; mod transactions; mod user; @@ -307,11 +310,11 @@ impl IntoResponse for Error { } } -#[derive(Clone)] pub struct AppState { pub pool: deadpool_postgres::Pool, pub encoding_key: EncodingKey, pub decoding_key: DecodingKey, + pub sockets: Arc, } impl AppState { @@ -330,6 +333,7 @@ pub fn router() -> Router> { .nest("/accounts", account::router()) .nest("/transactions", transactions::router()) .nest("/chats", chats::router()) + .nest("/socket", socket::router()) } make_schemas!((); (ApiError, _ValidationErrors), [crate::model::schemas, auth::schemas, user::schemas, transactions::schemas, account::schemas, chats::schemas]); @@ -382,6 +386,7 @@ macro_rules! make_schemas { } pub(crate) use make_schemas; +use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))] diff --git a/src/api/socket.rs b/src/api/socket.rs new file mode 100644 index 0000000..84cb69a --- /dev/null +++ b/src/api/socket.rs @@ -0,0 +1,194 @@ +use std::{ + fmt::Debug, + sync::{Arc, Weak}, +}; + +use async_broadcast::{Receiver, Sender}; +use axum::{ + Router, + extract::{ + WebSocketUpgrade, + ws::{Message, Utf8Bytes, WebSocket}, + }, + response::Response, + routing::get, +}; +use concread::hashmap::asynch::HashMap; +use serde::Serialize; +use tracing::{error, info, instrument}; +use uuid::Uuid; + +use super::{AppState, EState, Error, State, auth::Auth, make_schemas}; + +pub(super) fn router() -> Router> { + Router::new().route("/", get(handle)) +} + +make_schemas!((); ()); + +#[instrument(skip(state))] +pub async fn handle( + EState(state): State, + auth: Auth, + ws: WebSocketUpgrade, +) -> Result { + let sockets = state.sockets.clone(); + let response = ws.on_upgrade(|socket| async move { + SocketHandler::new(auth.user_id(), socket, sockets) + .await + .await + }); + Ok(response) +} + +pub struct Sockets { + sockets: HashMap>, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case", tag = "type", content = "data")] +pub enum SocketMessage { + Event(SocketEvent), +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum SocketEvent { + MessageReceived { + chat: Uuid, + from: Uuid, + message_id: Uuid, + }, + PaymentReceived { + from: Uuid, + to: Uuid, + amount: u64, + }, +} + +impl Sockets { + pub fn new() -> Self { + Self { + sockets: HashMap::new(), + } + } + + pub async fn send(&self, id: Uuid, message: SocketMessage) { + let json = serde_json::to_string(&message).unwrap(); + let tx = self.sockets.read(); + let Some(sender) = tx.get(&id) else { + return; + }; + let _ = sender.broadcast_direct(json.into()).await; + } + + pub async fn send_multiple(&self, iter: impl IntoIterator) { + let tx = self.sockets.read(); + for (id, message) in iter { + let json = serde_json::to_string(&message).unwrap(); + let Some(sender) = tx.get(&id) else { + continue; + }; + let _ = sender.broadcast_direct(json.into()).await; + } + } +} + +struct SocketHandler { + id: Uuid, + state: Option>, +} + +impl Debug for SocketHandler { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SocketHandler") + .field("id", &self.id) + .finish() + } +} + +impl SocketHandler { + pub async fn new( + id: Uuid, + socket: WebSocket, + sockets: Arc, + ) -> impl Future { + let handler = Self { + id, + state: Some(Arc::downgrade(&sockets)), + }; + let mut tx = sockets.sockets.write().await; + let rx = if let Some(sender) = tx.get(&id) { + sender.new_receiver() + } else { + let (mut sender, receiver) = async_broadcast::broadcast(16); + sender.set_await_active(false); + tx.insert(id, sender); + tx.commit(); + receiver + }; + handler.run(socket, rx) + } + + async fn shutdown(&self, state: Arc) { + let mut tx = state.sockets.write().await; + let Some(sender) = tx.get(&self.id) else { + return; + }; + if sender.receiver_count() == 0 { + tx.remove(&self.id); + } + tx.commit(); + } + + #[instrument(skip(socket, rx))] + async fn run(mut self, mut socket: WebSocket, mut rx: Receiver) { + loop { + match Self::handle(&mut socket, &mut rx).await { + Ok(true) => {} + Ok(false) => break, + Err(err) => { + error!("{:?}", err); + break; + } + } + } + let Some(state) = self.state.take() else { + return; + }; + let Some(state) = state.upgrade() else { return }; + self.shutdown(state).await; + } + + async fn handle( + socket: &mut WebSocket, + rx: &mut Receiver, + ) -> Result { + tokio::select! { + msg = socket.recv() => { + let msg = match msg { + Some(v) => v?, + None => return Ok(false), + }; + info!("Received message: {:?}", msg); + }, + message = rx.recv_direct() => { + let message = message.map_err(|err| axum::Error::new(err))?; + socket.send(Message::Text(message)).await?; + } + } + Ok(true) + } +} + +impl Drop for SocketHandler { + fn drop(&mut self) { + let Some(state) = self.state.take() else { + return; + }; + let Some(state) = state.upgrade() else { return }; + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(self.shutdown(state)) + }); + } +} diff --git a/src/api/transactions.rs b/src/api/transactions.rs index c70be22..ac8e1be 100644 --- a/src/api/transactions.rs +++ b/src/api/transactions.rs @@ -14,7 +14,9 @@ use crate::model::{ use super::{ AppState, EState, Error, Json, Pagination, PaginationType, Query, RequestPagination, State, - auth::Auth, make_schemas, + auth::Auth, + make_schemas, + socket::{SocketEvent, SocketMessage}, }; pub(super) fn router() -> Router> { @@ -154,7 +156,7 @@ impl AccountSelector { async fn account_id( &self, client: &impl GenericClient, - ) -> Result, tokio_postgres::Error> { + ) -> Result, tokio_postgres::Error> { let user_id = match &self.user { NameOrUuid::Id(uuid) => *uuid, NameOrUuid::Name(name) => match User::info_by_name(client, &*name).await? { @@ -169,7 +171,7 @@ impl AccountSelector { }, None => user_id, }; - Ok(Some(account_id)) + Ok(Some((user_id, account_id))) } } @@ -384,7 +386,7 @@ pub async fn make_payment( }) else { todo!("from account doesn't exist") }; - let Some(to) = to.account_id(&client).await? else { + let Some((to_user, to)) = to.account_id(&client).await? else { todo!("to account doesn't exist") }; if from.balance < amount { @@ -401,6 +403,17 @@ pub async fn make_payment( .await?; let transaction = Transaction::create(&mut client, from.id, to, amount, None).await?; client.commit().await?; + state + .sockets + .send( + to_user, + SocketMessage::Event(SocketEvent::PaymentReceived { + from: from.id, + to, + amount, + }), + ) + .await; Ok(Json(transaction)) } diff --git a/src/api/user.rs b/src/api/user.rs index 380c645..0dbdd95 100644 --- a/src/api/user.rs +++ b/src/api/user.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::{Router, extract::Path, http::StatusCode, routing::get}; +use axum::{Router, extract::Path, routing::get}; use serde::{Deserialize, Serialize}; use tracing::instrument; use uuid::Uuid; diff --git a/src/main.rs b/src/main.rs index bc08831..1ab88a0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::Router; -use bankserver::{Config, setup_db}; +use bankserver::{Config, api::Sockets, setup_db}; use jsonwebtoken::{DecodingKey, EncodingKey}; use tokio::{net::TcpListener, signal}; use tracing::{info, level_filters::LevelFilter}; @@ -39,6 +39,7 @@ async fn main() { pool, encoding_key, decoding_key, + sockets: Arc::new(Sockets::new()), })); let listener = TcpListener::bind(config.socket_addr()).await.unwrap(); diff --git a/src/model/chats.rs b/src/model/chats.rs index 747c886..bcd44bc 100644 --- a/src/model/chats.rs +++ b/src/model/chats.rs @@ -26,11 +26,11 @@ pub struct ChatInfo { #[derive(Serialize)] #[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))] pub struct ChatMessage { - id: Uuid, - sender: Uuid, - time: DateTime, - text: String, - extra: Option, + pub id: Uuid, + pub sender: Uuid, + pub time: DateTime, + pub text: String, + pub extra: Option, } impl PaginationType for Chat { @@ -134,6 +134,22 @@ impl Chat { Ok(res.map(ChatInfo::from)) } + pub async fn member_ids( + client: &impl GenericClient, + chat: Uuid, + ) -> Result, tokio_postgres::Error> { + let stmt = client + .prepare_cached("select \"user\" from chat_members where chat = $1") + .await?; + let res = client + .query(&stmt, &[&chat]) + .await? + .into_iter() + .map(|row| row.get(0)) + .collect(); + Ok(res) + } + pub async fn exists_between( client: &impl GenericClient, peers: (Uuid, Uuid),