use std::{ fmt::Debug, sync::{Arc, Weak}, }; use async_broadcast::{Receiver, Sender}; use axum::{ Router, extract::{ WebSocketUpgrade, ws::{Message, Utf8Bytes, WebSocket}, }, response::Response, routing::get, }; use bank_core::NameOrUuid; use concread::hashmap::asynch::HashMap; use serde::Serialize; use tracing::{error, info, instrument}; use uuid::Uuid; use super::{AppState, EState, Error, State, auth::Auth}; pub(super) fn router() -> Router> { Router::new().route("/", get(handle)) } #[instrument(skip(state))] pub async fn handle( EState(state): State, auth: Auth, ws: WebSocketUpgrade, ) -> Result { let sockets = state.sockets.clone(); let response = ws.on_upgrade(|socket| async move { SocketHandler::new(auth.user_id(), socket, sockets) .await .await }); Ok(response) } pub struct Sockets { sockets: HashMap>, } #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "snake_case", tag = "type", content = "data")] pub enum SocketMessage { Event(SocketEvent), } #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "snake_case", tag = "type")] pub enum SocketEvent { MessageReceived { chat: Uuid, from: Uuid, message_id: Uuid, }, PaymentReceived { from: Option, to: Uuid, amount: u64, }, } impl Sockets { pub fn new() -> Self { Self { sockets: HashMap::new(), } } pub async fn send(&self, id: Uuid, message: SocketMessage) { let json = serde_json::to_string(&message).unwrap(); let tx = self.sockets.read(); let Some(sender) = tx.get(&id) else { return; }; let _ = sender.broadcast_direct(json.into()).await; } pub async fn send_multiple(&self, iter: impl IntoIterator) { let tx = self.sockets.read(); for (id, message) in iter { let json = serde_json::to_string(&message).unwrap(); let Some(sender) = tx.get(&id) else { continue; }; let _ = sender.broadcast_direct(json.into()).await; } } } struct SocketHandler { id: Uuid, state: Option>, } impl Debug for SocketHandler { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SocketHandler") .field("id", &self.id) .finish() } } impl SocketHandler { pub async fn new( id: Uuid, socket: WebSocket, sockets: Arc, ) -> impl Future { let handler = Self { id, state: Some(Arc::downgrade(&sockets)), }; let mut tx = sockets.sockets.write().await; let rx = if let Some(sender) = tx.get(&id) { sender.new_receiver() } else { let (mut sender, receiver) = async_broadcast::broadcast(16); sender.set_await_active(false); tx.insert(id, sender); tx.commit(); receiver }; handler.run(socket, rx) } async fn shutdown(&self, state: Arc) { let mut tx = state.sockets.write().await; let Some(sender) = tx.get(&self.id) else { return; }; if sender.receiver_count() == 0 { tx.remove(&self.id); } tx.commit(); } #[instrument(skip(socket, rx))] async fn run(mut self, mut socket: WebSocket, mut rx: Receiver) { loop { match Self::handle(&mut socket, &mut rx).await { Ok(true) => {} Ok(false) => break, Err(err) => { error!("{:?}", err); break; } } } let Some(state) = self.state.take() else { return; }; let Some(state) = state.upgrade() else { return }; self.shutdown(state).await; } async fn handle( socket: &mut WebSocket, rx: &mut Receiver, ) -> Result { tokio::select! { msg = socket.recv() => { let msg = match msg { Some(v) => v?, None => return Ok(false), }; info!("Received message: {:?}", msg); }, message = rx.recv_direct() => { let message = message.map_err(|err| axum::Error::new(err))?; socket.send(Message::Text(message)).await?; } } Ok(true) } } impl Drop for SocketHandler { fn drop(&mut self) { let Some(state) = self.state.take() else { return; }; let Some(state) = state.upgrade() else { return }; tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(self.shutdown(state)) }); } }