add notification socket

This commit is contained in:
DSeeLP 2025-03-16 14:21:27 +01:00
parent 4c72fe6bab
commit 0e90a4e2e4
11 changed files with 424 additions and 15 deletions

151
Cargo.lock generated
View File

@ -17,6 +17,19 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "ahash"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
"getrandom 0.2.15",
"once_cell",
"version_check",
"zerocopy 0.7.35",
]
[[package]] [[package]]
name = "aho-corasick" name = "aho-corasick"
version = "1.1.3" version = "1.1.3"
@ -41,6 +54,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "arc-swap"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]] [[package]]
name = "argon2" name = "argon2"
version = "0.5.3" version = "0.5.3"
@ -53,6 +72,18 @@ dependencies = [
"password-hash", "password-hash",
] ]
[[package]]
name = "async-broadcast"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "435a87a52755b8f27fcf321ac4f04b2802e337c8c4872923137471ec39c37532"
dependencies = [
"event-listener",
"event-listener-strategy",
"futures-core",
"pin-project-lite",
]
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.86" version = "0.1.86"
@ -77,6 +108,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8"
dependencies = [ dependencies = [
"axum-core", "axum-core",
"base64",
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
"futures-util", "futures-util",
@ -96,8 +128,10 @@ dependencies = [
"serde_json", "serde_json",
"serde_path_to_error", "serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sha1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite",
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
@ -143,8 +177,10 @@ dependencies = [
name = "bankserver" name = "bankserver"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"async-broadcast",
"axum", "axum",
"chrono", "chrono",
"concread",
"config", "config",
"dbmigrator", "dbmigrator",
"deadpool", "deadpool",
@ -274,6 +310,30 @@ dependencies = [
"static_assertions", "static_assertions",
] ]
[[package]]
name = "concread"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a06c26e76cd1d7a88a44324d0cf18b11589be552e97af09bee345f7e7334c6d"
dependencies = [
"ahash",
"arc-swap",
"crossbeam-utils",
"smallvec",
"sptr",
"tokio",
"tracing",
]
[[package]]
name = "concurrent-queue"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
dependencies = [
"crossbeam-utils",
]
[[package]] [[package]]
name = "config" name = "config"
version = "0.15.8" version = "0.15.8"
@ -300,6 +360,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]] [[package]]
name = "crypto-common" name = "crypto-common"
version = "0.1.6" version = "0.1.6"
@ -345,6 +411,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "data-encoding"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010"
[[package]] [[package]]
name = "dbmigrator" name = "dbmigrator"
version = "0.4.4-alpha" version = "0.4.4-alpha"
@ -452,6 +524,27 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "event-listener"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae"
dependencies = [
"concurrent-queue",
"parking",
"pin-project-lite",
]
[[package]]
name = "event-listener-strategy"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3e4e0dd3673c1139bf041f3008816d9cf2946bbfac2945c09e523b8d7b05b2"
dependencies = [
"event-listener",
"pin-project-lite",
]
[[package]] [[package]]
name = "fallible-iterator" name = "fallible-iterator"
version = "0.2.0" version = "0.2.0"
@ -929,6 +1022,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.12.3" version = "0.12.3"
@ -1366,6 +1465,17 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "sha1"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.10.8" version = "0.10.8"
@ -1435,6 +1545,12 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "sptr"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b9b39299b249ad65f3b7e96443bad61c02ca5cd3589f46cb6d610a0fd6c0d6a"
[[package]] [[package]]
name = "static_assertions" name = "static_assertions"
version = "1.1.0" version = "1.1.0"
@ -1612,6 +1728,18 @@ dependencies = [
"whoami", "whoami",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.13" version = "0.7.13"
@ -1725,6 +1853,23 @@ dependencies = [
"tracing-log", "tracing-log",
] ]
[[package]]
name = "tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand",
"sha1",
"thiserror",
"utf-8",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.18.0" version = "1.18.0"
@ -1764,6 +1909,12 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.15.1" version = "1.15.1"

View File

@ -12,8 +12,10 @@ name = "generate-schemas"
features = ["schemas"] features = ["schemas"]
[dependencies] [dependencies]
axum = "0.8" async-broadcast = "0.7.2"
axum = { version = "0.8", features = ["ws"] }
chrono = { version = "0.4.40", features = ["serde"] } chrono = { version = "0.4.40", features = ["serde"] }
concread = { version = "0.5.4", default-features = false, features = ["ahash", "asynch", "maps"] }
config = { version = "0.15.8", default-features = false } config = { version = "0.15.8", default-features = false }
dbmigrator = { git = "https://github.com/DSeeLP/dbmigrator.git", branch = "macros", version = "0.4.4-alpha", features = ["tokio-postgres"] } dbmigrator = { git = "https://github.com/DSeeLP/dbmigrator.git", branch = "macros", version = "0.4.4-alpha", features = ["tokio-postgres"] }
deadpool = "0.12" deadpool = "0.12"

View File

@ -405,6 +405,19 @@ paths:
$ref: '#/components/responses/InvalidBody' $ref: '#/components/responses/InvalidBody'
default: default:
$ref: '#/components/responses/Default' $ref: '#/components/responses/Default'
/api/socket:
get:
operationId: websocket-events
summary: Open websocket to receive events
security:
- bearer: []
responses:
101:
description: Switching protocols
401:
$ref: '#/components/responses/Unauthorized'
default:
$ref: '#/components/responses/Default'
components: components:
parameters: parameters:
Direction: Direction:

View File

@ -1,6 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use axum::{Router, extract::Path, http::StatusCode, routing::get}; use axum::{Router, extract::Path, routing::get};
use tracing::instrument; use tracing::instrument;
use uuid::Uuid; use uuid::Uuid;

View File

@ -12,7 +12,10 @@ use crate::{
}; };
use super::{ use super::{
AppState, EState, Error, Json, Pagination, RequestPagination, State, auth::Auth, make_schemas, AppState, EState, Error, Json, Pagination, RequestPagination, State,
auth::Auth,
make_schemas,
socket::{SocketEvent, SocketMessage},
transactions::NameOrUuid, transactions::NameOrUuid,
}; };
@ -108,5 +111,16 @@ pub async fn send_message(
let mut client = state.conn().await?; let mut client = state.conn().await?;
check_chat(&client, id, auth.user_id()).await?; check_chat(&client, id, auth.user_id()).await?;
let message = Chat::send(&mut client, id, auth.user_id(), body.text, body.extra).await?; let message = Chat::send(&mut client, id, auth.user_id(), body.text, body.extra).await?;
let notfication = SocketMessage::Event(SocketEvent::MessageReceived {
chat: id,
from: auth.user_id(),
message_id: message.id,
});
for id in Chat::member_ids(&client, id).await? {
if id == auth.user_id() {
continue;
}
state.sockets.send(id, notfication.clone()).await;
}
Ok(Json(message)) Ok(Json(message))
} }

View File

@ -14,10 +14,13 @@ use tracing_error::SpanTrace;
pub use axum::extract::State as EState; pub use axum::extract::State as EState;
pub use socket::Sockets;
mod account; mod account;
mod auth; mod auth;
mod chats; mod chats;
mod docs; mod docs;
mod socket;
mod transactions; mod transactions;
mod user; mod user;
@ -307,11 +310,11 @@ impl IntoResponse for Error {
} }
} }
#[derive(Clone)]
pub struct AppState { pub struct AppState {
pub pool: deadpool_postgres::Pool, pub pool: deadpool_postgres::Pool,
pub encoding_key: EncodingKey, pub encoding_key: EncodingKey,
pub decoding_key: DecodingKey, pub decoding_key: DecodingKey,
pub sockets: Arc<Sockets>,
} }
impl AppState { impl AppState {
@ -330,6 +333,7 @@ pub fn router() -> Router<Arc<AppState>> {
.nest("/accounts", account::router()) .nest("/accounts", account::router())
.nest("/transactions", transactions::router()) .nest("/transactions", transactions::router())
.nest("/chats", chats::router()) .nest("/chats", chats::router())
.nest("/socket", socket::router())
} }
make_schemas!((); (ApiError, _ValidationErrors), [crate::model::schemas, auth::schemas, user::schemas, transactions::schemas, account::schemas, chats::schemas]); make_schemas!((); (ApiError, _ValidationErrors), [crate::model::schemas, auth::schemas, user::schemas, transactions::schemas, account::schemas, chats::schemas]);
@ -382,6 +386,7 @@ macro_rules! make_schemas {
} }
pub(crate) use make_schemas; pub(crate) use make_schemas;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))]

194
src/api/socket.rs Normal file
View File

@ -0,0 +1,194 @@
use std::{
fmt::Debug,
sync::{Arc, Weak},
};
use async_broadcast::{Receiver, Sender};
use axum::{
Router,
extract::{
WebSocketUpgrade,
ws::{Message, Utf8Bytes, WebSocket},
},
response::Response,
routing::get,
};
use concread::hashmap::asynch::HashMap;
use serde::Serialize;
use tracing::{error, info, instrument};
use uuid::Uuid;
use super::{AppState, EState, Error, State, auth::Auth, make_schemas};
pub(super) fn router() -> Router<Arc<AppState>> {
Router::new().route("/", get(handle))
}
make_schemas!((); ());
#[instrument(skip(state))]
pub async fn handle(
EState(state): State,
auth: Auth,
ws: WebSocketUpgrade,
) -> Result<Response, Error> {
let sockets = state.sockets.clone();
let response = ws.on_upgrade(|socket| async move {
SocketHandler::new(auth.user_id(), socket, sockets)
.await
.await
});
Ok(response)
}
pub struct Sockets {
sockets: HashMap<Uuid, Sender<Utf8Bytes>>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case", tag = "type", content = "data")]
pub enum SocketMessage {
Event(SocketEvent),
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum SocketEvent {
MessageReceived {
chat: Uuid,
from: Uuid,
message_id: Uuid,
},
PaymentReceived {
from: Uuid,
to: Uuid,
amount: u64,
},
}
impl Sockets {
pub fn new() -> Self {
Self {
sockets: HashMap::new(),
}
}
pub async fn send(&self, id: Uuid, message: SocketMessage) {
let json = serde_json::to_string(&message).unwrap();
let tx = self.sockets.read();
let Some(sender) = tx.get(&id) else {
return;
};
let _ = sender.broadcast_direct(json.into()).await;
}
pub async fn send_multiple(&self, iter: impl IntoIterator<Item = (Uuid, SocketMessage)>) {
let tx = self.sockets.read();
for (id, message) in iter {
let json = serde_json::to_string(&message).unwrap();
let Some(sender) = tx.get(&id) else {
continue;
};
let _ = sender.broadcast_direct(json.into()).await;
}
}
}
struct SocketHandler {
id: Uuid,
state: Option<Weak<Sockets>>,
}
impl Debug for SocketHandler {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SocketHandler")
.field("id", &self.id)
.finish()
}
}
impl SocketHandler {
pub async fn new(
id: Uuid,
socket: WebSocket,
sockets: Arc<Sockets>,
) -> impl Future<Output = ()> {
let handler = Self {
id,
state: Some(Arc::downgrade(&sockets)),
};
let mut tx = sockets.sockets.write().await;
let rx = if let Some(sender) = tx.get(&id) {
sender.new_receiver()
} else {
let (mut sender, receiver) = async_broadcast::broadcast(16);
sender.set_await_active(false);
tx.insert(id, sender);
tx.commit();
receiver
};
handler.run(socket, rx)
}
async fn shutdown(&self, state: Arc<Sockets>) {
let mut tx = state.sockets.write().await;
let Some(sender) = tx.get(&self.id) else {
return;
};
if sender.receiver_count() == 0 {
tx.remove(&self.id);
}
tx.commit();
}
#[instrument(skip(socket, rx))]
async fn run(mut self, mut socket: WebSocket, mut rx: Receiver<Utf8Bytes>) {
loop {
match Self::handle(&mut socket, &mut rx).await {
Ok(true) => {}
Ok(false) => break,
Err(err) => {
error!("{:?}", err);
break;
}
}
}
let Some(state) = self.state.take() else {
return;
};
let Some(state) = state.upgrade() else { return };
self.shutdown(state).await;
}
async fn handle(
socket: &mut WebSocket,
rx: &mut Receiver<Utf8Bytes>,
) -> Result<bool, axum::Error> {
tokio::select! {
msg = socket.recv() => {
let msg = match msg {
Some(v) => v?,
None => return Ok(false),
};
info!("Received message: {:?}", msg);
},
message = rx.recv_direct() => {
let message = message.map_err(|err| axum::Error::new(err))?;
socket.send(Message::Text(message)).await?;
}
}
Ok(true)
}
}
impl Drop for SocketHandler {
fn drop(&mut self) {
let Some(state) = self.state.take() else {
return;
};
let Some(state) = state.upgrade() else { return };
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(self.shutdown(state))
});
}
}

View File

@ -14,7 +14,9 @@ use crate::model::{
use super::{ use super::{
AppState, EState, Error, Json, Pagination, PaginationType, Query, RequestPagination, State, AppState, EState, Error, Json, Pagination, PaginationType, Query, RequestPagination, State,
auth::Auth, make_schemas, auth::Auth,
make_schemas,
socket::{SocketEvent, SocketMessage},
}; };
pub(super) fn router() -> Router<Arc<AppState>> { pub(super) fn router() -> Router<Arc<AppState>> {
@ -154,7 +156,7 @@ impl AccountSelector {
async fn account_id( async fn account_id(
&self, &self,
client: &impl GenericClient, client: &impl GenericClient,
) -> Result<Option<Uuid>, tokio_postgres::Error> { ) -> Result<Option<(Uuid, Uuid)>, tokio_postgres::Error> {
let user_id = match &self.user { let user_id = match &self.user {
NameOrUuid::Id(uuid) => *uuid, NameOrUuid::Id(uuid) => *uuid,
NameOrUuid::Name(name) => match User::info_by_name(client, &*name).await? { NameOrUuid::Name(name) => match User::info_by_name(client, &*name).await? {
@ -169,7 +171,7 @@ impl AccountSelector {
}, },
None => user_id, None => user_id,
}; };
Ok(Some(account_id)) Ok(Some((user_id, account_id)))
} }
} }
@ -384,7 +386,7 @@ pub async fn make_payment(
}) else { }) else {
todo!("from account doesn't exist") todo!("from account doesn't exist")
}; };
let Some(to) = to.account_id(&client).await? else { let Some((to_user, to)) = to.account_id(&client).await? else {
todo!("to account doesn't exist") todo!("to account doesn't exist")
}; };
if from.balance < amount { if from.balance < amount {
@ -401,6 +403,17 @@ pub async fn make_payment(
.await?; .await?;
let transaction = Transaction::create(&mut client, from.id, to, amount, None).await?; let transaction = Transaction::create(&mut client, from.id, to, amount, None).await?;
client.commit().await?; client.commit().await?;
state
.sockets
.send(
to_user,
SocketMessage::Event(SocketEvent::PaymentReceived {
from: from.id,
to,
amount,
}),
)
.await;
Ok(Json(transaction)) Ok(Json(transaction))
} }

View File

@ -1,6 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use axum::{Router, extract::Path, http::StatusCode, routing::get}; use axum::{Router, extract::Path, routing::get};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::instrument; use tracing::instrument;
use uuid::Uuid; use uuid::Uuid;

View File

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use axum::Router; use axum::Router;
use bankserver::{Config, setup_db}; use bankserver::{Config, api::Sockets, setup_db};
use jsonwebtoken::{DecodingKey, EncodingKey}; use jsonwebtoken::{DecodingKey, EncodingKey};
use tokio::{net::TcpListener, signal}; use tokio::{net::TcpListener, signal};
use tracing::{info, level_filters::LevelFilter}; use tracing::{info, level_filters::LevelFilter};
@ -39,6 +39,7 @@ async fn main() {
pool, pool,
encoding_key, encoding_key,
decoding_key, decoding_key,
sockets: Arc::new(Sockets::new()),
})); }));
let listener = TcpListener::bind(config.socket_addr()).await.unwrap(); let listener = TcpListener::bind(config.socket_addr()).await.unwrap();

View File

@ -26,11 +26,11 @@ pub struct ChatInfo {
#[derive(Serialize)] #[derive(Serialize)]
#[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))]
pub struct ChatMessage { pub struct ChatMessage {
id: Uuid, pub id: Uuid,
sender: Uuid, pub sender: Uuid,
time: DateTime<Utc>, pub time: DateTime<Utc>,
text: String, pub text: String,
extra: Option<serde_json::Value>, pub extra: Option<serde_json::Value>,
} }
impl PaginationType for Chat { impl PaginationType for Chat {
@ -134,6 +134,22 @@ impl Chat {
Ok(res.map(ChatInfo::from)) Ok(res.map(ChatInfo::from))
} }
pub async fn member_ids(
client: &impl GenericClient,
chat: Uuid,
) -> Result<Vec<Uuid>, tokio_postgres::Error> {
let stmt = client
.prepare_cached("select \"user\" from chat_members where chat = $1")
.await?;
let res = client
.query(&stmt, &[&chat])
.await?
.into_iter()
.map(|row| row.get(0))
.collect();
Ok(res)
}
pub async fn exists_between( pub async fn exists_between(
client: &impl GenericClient, client: &impl GenericClient,
peers: (Uuid, Uuid), peers: (Uuid, Uuid),