Commonized networking (#3310)

* Fix not being able to connect to local friends socket

* Start basic work on tunneling protocol and move some code into a common crate

* Commonize message serialization logic

* Serialize Base62Ids as u64 when human-readability is not required

* Move ActiveSockets tuple into struct

* Make CI run when rust-common is updated

CI is currently broken for labrinth, however

* Fix theseus-release.yml to reference itself correctly

* Implement Labrinth side of tunneling

* Implement non-friend part of theseus tunneling

* Implement client-side except for socket loop

* Implement the socket loop

Doesn't work though. Debugging time!

* Fix config.rs

* Fix deadlock in labrinth socket handling

* Update dockerfile

* switch to workspace prepare at root level

* Wait for connection before tunneling in playground

* Move rust-common into labrinth

* Remove rust-common references from Actions

* Revert "Update dockerfile"

This reverts commit 3caad59bb474ce425d0b8928d7cee7ae1a5011bd.

* Fix Docker build

* Rebuild Theseus if common code changes

* Allow multiple connections from the same user

* Fix test building

* Move FriendSocketListening and FriendSocketStoppedListening to non-panicking TODO for now

* Make message_serialization macro take varargs for binary messages

* Improve syntax of message_serialization macro

* Remove the ability to connect to a virtual socket, and disable the ability to listen on one

* Allow the app to compile without running labrinth

* Clippy fix

* Update Rust and Clippy fix again

---------

Co-authored-by: Jai A <jaiagr+gpg@pm.me>
This commit is contained in:
Josiah Glosson 2025-02-28 12:52:47 -06:00 committed by GitHub
parent 90def724c2
commit 650ab71a83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
72 changed files with 1132 additions and 584 deletions

View File

@ -6,9 +6,11 @@ on:
tags:
- 'v*'
paths:
- .github/workflows/app-release.yml
- .github/workflows/theseus-release.yml
- 'apps/app/**'
- 'apps/app-frontend/**'
- 'apps/labrinth/src/common/**'
- 'apps/labrinth/Cargo.toml'
- 'packages/app-lib/**'
- 'packages/app-macros/**'
- 'packages/assets/**'

1
.idea/code.iml generated
View File

@ -10,6 +10,7 @@
<sourceFolder url="file://$MODULE_DIR$/apps/labrinth/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/apps/labrinth/tests" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/packages/app-lib/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/packages/rust-common/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />

64
Cargo.lock generated
View File

@ -1204,7 +1204,7 @@ checksum = "d38f2da7a0a2c4ccf0065be06397cc26a81f4e528be095826eee9d4adbb8c60f"
dependencies = [
"byteorder",
"fnv",
"uuid 1.10.0",
"uuid 1.12.0",
]
[[package]]
@ -1291,7 +1291,7 @@ dependencies = [
"time",
"tokio 1.42.0",
"url",
"uuid 1.10.0",
"uuid 1.12.0",
]
[[package]]
@ -1995,7 +1995,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d"
dependencies = [
"serde",
"uuid 1.10.0",
"uuid 1.12.0",
]
[[package]]
@ -2511,7 +2511,7 @@ checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4"
dependencies = [
"bit_field",
"flume",
"half",
"half 2.4.1",
"lebe",
"miniz_oxide 0.7.4",
"rayon-core",
@ -3242,6 +3242,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "half"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]]
name = "half"
version = "2.4.1"
@ -4245,6 +4251,7 @@ dependencies = [
"deadpool-redis",
"derive-new",
"dotenvy",
"either",
"env_logger",
"flate2",
"futures 0.3.30",
@ -4277,6 +4284,8 @@ dependencies = [
"sentry",
"sentry-actix",
"serde",
"serde_bytes",
"serde_cbor",
"serde_json",
"serde_with",
"sha1 0.6.1",
@ -4290,7 +4299,7 @@ dependencies = [
"totp-rs",
"url",
"urlencoding",
"uuid 1.10.0",
"uuid 1.12.0",
"validator",
"webp",
"woothee",
@ -4657,7 +4666,7 @@ dependencies = [
"serde_json",
"thiserror 1.0.64",
"time",
"uuid 1.10.0",
"uuid 1.12.0",
"wasm-bindgen-futures",
"web-sys",
"yaup",
@ -6724,7 +6733,7 @@ dependencies = [
"rkyv_derive",
"seahash",
"tinyvec",
"uuid 1.10.0",
"uuid 1.12.0",
]
[[package]]
@ -7108,7 +7117,7 @@ dependencies = [
"serde",
"serde_json",
"url",
"uuid 1.10.0",
"uuid 1.12.0",
]
[[package]]
@ -7361,7 +7370,7 @@ dependencies = [
"thiserror 1.0.64",
"time",
"url",
"uuid 1.10.0",
"uuid 1.12.0",
]
[[package]]
@ -7396,6 +7405,25 @@ dependencies = [
"xml-rs",
]
[[package]]
name = "serde_bytes"
version = "0.11.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a"
dependencies = [
"serde",
]
[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half 1.8.3",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.210"
@ -8536,7 +8564,7 @@ dependencies = [
"thiserror 2.0.7",
"time",
"url",
"uuid 1.10.0",
"uuid 1.12.0",
"walkdir",
]
@ -8629,7 +8657,7 @@ dependencies = [
"thiserror 2.0.7",
"toml 0.8.19",
"url",
"uuid 1.10.0",
"uuid 1.12.0",
]
[[package]]
@ -8810,7 +8838,7 @@ dependencies = [
"toml 0.8.19",
"url",
"urlpattern",
"uuid 1.10.0",
"uuid 1.12.0",
"walkdir",
]
@ -8884,9 +8912,11 @@ dependencies = [
"dirs 5.0.1",
"discord-rich-presence",
"dunce",
"either",
"flate2",
"futures 0.3.30",
"indicatif",
"labrinth",
"lazy_static",
"notify",
"notify-debouncer-mini",
@ -8913,7 +8943,7 @@ dependencies = [
"tracing-subscriber",
"url",
"urlencoding",
"uuid 1.10.0",
"uuid 1.12.0",
"whoami",
"winreg 0.52.0",
"zip 0.6.6",
@ -8955,7 +8985,7 @@ dependencies = [
"tracing",
"tracing-error",
"url",
"uuid 1.10.0",
"uuid 1.12.0",
"window-shadows",
]
@ -8974,7 +9004,7 @@ dependencies = [
"tracing-error",
"tracing-subscriber",
"url",
"uuid 1.10.0",
"uuid 1.12.0",
"webbrowser",
]
@ -9819,9 +9849,9 @@ dependencies = [
[[package]]
name = "uuid"
version = "1.10.0"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314"
checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4"
dependencies = [
"getrandom 0.2.15",
"rand 0.8.5",

View File

@ -21,4 +21,4 @@ strip = true # Remove debug symbols
opt-level = 3
[patch.crates-io]
wry = { git = "https://github.com/modrinth/wry", rev = "51907c6" }
wry = { git = "https://github.com/modrinth/wry", rev = "51907c6" }

View File

@ -0,0 +1,2 @@
[env]
SQLX_OFFLINE = "true"

View File

@ -3,9 +3,9 @@
windows_subsystem = "windows"
)]
use std::time::Duration;
use theseus::prelude::*;
use theseus::profile::create::profile_create;
use tokio::signal::ctrl_c;
// A simple Rust implementation of the authentication run
// 1) call the authenticate_begin_flow() function to get the URL to open (like you would in the frontend)
@ -41,54 +41,21 @@ async fn main() -> theseus::Result<()> {
// Initialize state
State::init().await?;
if minecraft_auth::users().await?.is_empty() {
println!("No users found, authenticating.");
authenticate_run().await?; // could take credentials from here direct, but also deposited in state users
}
//
// st.settings
// .write()
// .await
// .java_globals
// .insert(JAVA_8_KEY.to_string(), check_jre(path).await?.unwrap());
// Clear profiles
println!("Clearing profiles.");
{
let h = profile::list().await?;
for profile in h.into_iter() {
profile::remove(&profile.path).await?;
loop {
if State::get().await?.friends_socket.is_connected().await {
break;
}
tokio::time::sleep(Duration::from_millis(500)).await;
}
println!("Creating/adding profile.");
tracing::info!("Starting host");
let name = "Example".to_string();
let game_version = "1.16.1".to_string();
let modloader = ModLoader::Forge;
let loader_version = "stable".to_string();
let socket = State::get().await?.friends_socket.open_port(25565).await?;
tracing::info!("Running host on socket {}", socket.socket_id());
let profile_path = profile_create(
name,
game_version,
modloader,
Some(loader_version),
None,
None,
None,
)
.await?;
println!("running");
// Run a profile, running minecraft and store the RwLock to the process
let process = profile::run(&profile_path).await?;
println!("Minecraft UUID: {}", process.uuid);
println!("All running process UUID {:?}", process::get_all().await?);
// hold the lock to the process until it ends
println!("Waiting for process to end...");
process::wait_for(process.uuid).await?;
ctrl_c().await?;
tracing::info!("Stopping host");
socket.shutdown().await?;
Ok(())
}

View File

@ -0,0 +1,2 @@
[env]
SQLX_OFFLINE = "true"

View File

@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE team_members\n SET\n is_owner = TRUE,\n accepted = TRUE,\n permissions = $2,\n organization_permissions = NULL,\n role = 'Inherited Owner'\n WHERE (id = $1)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Int8"
]
},
"nullable": []
},
"hash": "11344e920ea606504c2fdc3c5a3cb1b1e990def66cf260cb5d648cab72cc34f1"
}

View File

@ -1,22 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT u.id \n FROM team_members\n INNER JOIN users u ON u.id = team_members.user_id\n WHERE team_id = $1 AND is_owner = TRUE\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false
]
},
"hash": "2b097a9a1b24b9648d3558e348c7d8cd467e589504c6e754f1f6836203946590"
}

View File

@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n DELETE FROM version_fields\n WHERE version_id = $1\n AND field_id = ANY($2)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Int4Array"
]
},
"nullable": []
},
"hash": "527291243eb3684e956d7d49c579857ce857ff462c830dd0cb74574f415d4105"
}

View File

@ -0,0 +1,22 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT u.id\n FROM team_members\n INNER JOIN users u ON u.id = team_members.user_id\n WHERE team_id = $1 AND is_owner = TRUE\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false
]
},
"hash": "96ebe21d1430779e88dcaf8872a8c939b3889f91df9a0e404d4c63d466869fe5"
}

View File

@ -1,15 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n DELETE FROM version_fields \n WHERE version_id = $1\n AND field_id = ANY($2)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Int4Array"
]
},
"nullable": []
},
"hash": "acd2e72610008d4fe240cdfadc1c70c997443f7319a5c535df967d56d24bd54a"
}

View File

@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO mods (\n id, team_id, name, summary, description,\n published, downloads, icon_url, raw_icon_url, status, requested_status,\n license_url, license,\n slug, color, monetization_status, organization_id\n )\n VALUES (\n $1, $2, $3, $4, $5, $6, \n $7, $8, $9, $10, $11,\n $12, $13,\n LOWER($14), $15, $16, $17\n )\n ",
"query": "\n INSERT INTO mods (\n id, team_id, name, summary, description,\n published, downloads, icon_url, raw_icon_url, status, requested_status,\n license_url, license,\n slug, color, monetization_status, organization_id\n )\n VALUES (\n $1, $2, $3, $4, $5, $6,\n $7, $8, $9, $10, $11,\n $12, $13,\n LOWER($14), $15, $16, $17\n )\n ",
"describe": {
"columns": [],
"parameters": {
@ -26,5 +26,5 @@
},
"nullable": []
},
"hash": "f899b378fad8fcfa1ebf527146b565b7c4466205e0bfd84f299123329926fe3f"
"hash": "bcbcac3c0b2b2b0327577d3095fa744ab42f7f1dcd2b7f3c3dace12b899b3f38"
}

View File

@ -1,15 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE team_members\n SET \n is_owner = TRUE,\n accepted = TRUE,\n permissions = $2,\n organization_permissions = NULL,\n role = 'Inherited Owner'\n WHERE (id = $1)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Int8"
]
},
"nullable": []
},
"hash": "dc64653d72645b76e42a1834124ce3f9225c5b6b8b941812167b3b7002bfdb2a"
}

View File

@ -36,8 +36,10 @@ reqwest = { version = "0.11.18", features = ["json", "multipart"] }
hyper = { version = "0.14", features = ["full"] }
hyper-tls = "0.5.0"
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
serde_bytes = "0.11"
serde_json = "1.0"
serde_cbor = "0.11"
serde_with = "3.0.0"
chrono = { version = "0.4.26", features = ["serde"] }
yaserde = "0.12.0"
@ -74,6 +76,7 @@ dotenvy = "0.15.7"
log = "0.4.20"
env_logger = "0.10.1"
thiserror = "1.0.56"
either = "1.13"
sqlx = { version = "0.8.2", features = [
"runtime-tokio-rustls",

View File

@ -34,7 +34,7 @@ pub enum AuthenticationError {
#[error("Error uploading user profile picture")]
FileHosting(#[from] FileHostingError),
#[error("Error while decoding PAT: {0}")]
Decoding(#[from] crate::models::ids::DecodingError),
Decoding(#[from] crate::common::ids::DecodingError),
#[error("{0}")]
Mail(#[from] email::MailError),
#[error("Invalid Authentication Credentials")]

View File

@ -1,7 +1,7 @@
use super::ValidatedRedirectUri;
use crate::auth::AuthenticationError;
use crate::common::ids::DecodingError;
use crate::models::error::ApiError;
use crate::models::ids::DecodingError;
use actix_web::http::{header::LOCATION, StatusCode};
use actix_web::HttpResponse;

View File

@ -0,0 +1,218 @@
pub use super::users::UserId;
use thiserror::Error;
/// Generates a random 64 bit integer that is exactly `n` characters
/// long when encoded as base62.
///
/// Uses `rand`'s thread rng on every call.
///
/// # Panics
///
/// This method panics if `n` is 0 or greater than 11, since a `u64`
/// can only represent up to 11 character base62 strings
#[inline]
pub fn random_base62(n: usize) -> u64 {
random_base62_rng(&mut rand::thread_rng(), n)
}
/// Generates a random 64 bit integer that is exactly `n` characters
/// long when encoded as base62, using the given rng.
///
/// # Panics
///
/// This method panics if `n` is 0 or greater than 11, since a `u64`
/// can only represent up to 11 character base62 strings
pub fn random_base62_rng<R: rand::RngCore>(rng: &mut R, n: usize) -> u64 {
random_base62_rng_range(rng, n, n)
}
pub fn random_base62_rng_range<R: rand::RngCore>(
rng: &mut R,
n_min: usize,
n_max: usize,
) -> u64 {
use rand::Rng;
assert!(n_min > 0 && n_max <= 11 && n_min <= n_max);
// gen_range is [low, high): max value is `MULTIPLES[n] - 1`,
// which is n characters long when encoded
rng.gen_range(MULTIPLES[n_min - 1]..MULTIPLES[n_max])
}
const MULTIPLES: [u64; 12] = [
1,
62,
62 * 62,
62 * 62 * 62,
62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62 * 62 * 62 * 62 * 62,
u64::MAX,
];
/// An ID encoded as base62 for use in the API.
///
/// All ids should be random and encode to 8-10 character base62 strings,
/// to avoid enumeration and other attacks.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct Base62Id(pub u64);
/// An error decoding a number from base62.
#[derive(Error, Debug)]
pub enum DecodingError {
/// Encountered a non-base62 character in a base62 string
#[error("Invalid character {0:?} in base62 encoding")]
InvalidBase62(char),
/// Encountered integer overflow when decoding a base62 id.
#[error("Base62 decoding overflowed")]
Overflow,
}
#[macro_export]
macro_rules! from_base62id {
($($struct:ty, $con:expr;)+) => {
$(
impl From<Base62Id> for $struct {
fn from(id: Base62Id) -> $struct {
$con(id.0)
}
}
impl From<$struct> for Base62Id {
fn from(id: $struct) -> Base62Id {
Base62Id(id.0)
}
}
)+
};
}
#[macro_export]
macro_rules! impl_base62_display {
($struct:ty) => {
impl std::fmt::Display for $struct {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&$crate::common::ids::base62_impl::to_base62(
self.0,
))
}
}
};
}
impl_base62_display!(Base62Id);
#[macro_export]
macro_rules! base62_id_impl {
($struct:ty, $cons:expr) => {
$crate::common::ids::from_base62id!($struct, $cons;);
$crate::common::ids::impl_base62_display!($struct);
}
}
base62_id_impl!(UserId, UserId);
pub use {base62_id_impl, from_base62id, impl_base62_display};
pub mod base62_impl {
use serde::de::{self, Deserializer, Visitor};
use serde::ser::Serializer;
use serde::{Deserialize, Serialize};
use super::{Base62Id, DecodingError};
impl<'de> Deserialize<'de> for Base62Id {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Base62Visitor;
impl Visitor<'_> for Base62Visitor {
type Value = Base62Id;
fn expecting(
&self,
formatter: &mut std::fmt::Formatter,
) -> std::fmt::Result {
formatter.write_str("a base62 string id")
}
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Base62Id(v))
}
fn visit_str<E>(self, string: &str) -> Result<Base62Id, E>
where
E: de::Error,
{
parse_base62(string).map(Base62Id).map_err(E::custom)
}
}
if deserializer.is_human_readable() {
deserializer.deserialize_str(Base62Visitor)
} else {
deserializer.deserialize_u64(Base62Visitor)
}
}
}
impl Serialize for Base62Id {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
if serializer.is_human_readable() {
serializer.serialize_str(&to_base62(self.0))
} else {
serializer.serialize_u64(self.0)
}
}
}
const BASE62_CHARS: [u8; 62] =
*b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
pub fn to_base62(mut num: u64) -> String {
let length = (num as f64).log(62.0).ceil() as usize;
let mut output = String::with_capacity(length);
while num > 0 {
// Could be done more efficiently, but requires byte
// manipulation of strings & Vec<u8> -> String conversion
output.insert(0, BASE62_CHARS[(num % 62) as usize] as char);
num /= 62;
}
output
}
pub fn parse_base62(string: &str) -> Result<u64, DecodingError> {
let mut num: u64 = 0;
for c in string.chars() {
let next_digit;
if c.is_ascii_digit() {
next_digit = (c as u8 - b'0') as u64;
} else if c.is_ascii_uppercase() {
next_digit = 10 + (c as u8 - b'A') as u64;
} else if c.is_ascii_lowercase() {
next_digit = 36 + (c as u8 - b'a') as u64;
} else {
return Err(DecodingError::InvalidBase62(c));
}
// We don't want this panicking or wrapping on integer overflow
if let Some(n) =
num.checked_mul(62).and_then(|n| n.checked_add(next_digit))
{
num = n;
} else {
return Err(DecodingError::Overflow);
}
}
Ok(num)
}
}

View File

@ -0,0 +1,3 @@
pub mod ids;
pub mod networking;
pub mod users;

View File

@ -0,0 +1,65 @@
use crate::common::ids::UserId;
use crate::common::users::UserStatus;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientToServerMessage {
StatusUpdate {
profile_name: Option<String>,
},
SocketListen {
socket: Uuid,
},
SocketClose {
socket: Uuid,
},
SocketSend {
socket: Uuid,
#[serde(with = "serde_bytes")]
data: Vec<u8>,
},
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerToClientMessage {
StatusUpdate {
status: UserStatus,
},
UserOffline {
id: UserId,
},
FriendStatuses {
statuses: Vec<UserStatus>,
},
FriendRequest {
from: UserId,
},
FriendRequestRejected {
from: UserId,
},
FriendSocketListening {
user: UserId,
socket: Uuid,
},
FriendSocketStoppedListening {
user: UserId,
},
SocketConnected {
to_socket: Uuid,
new_socket: Uuid,
},
SocketClosed {
socket: Uuid,
},
SocketData {
socket: Uuid,
#[serde(with = "serde_bytes")]
data: Vec<u8>,
},
}

View File

@ -0,0 +1,2 @@
pub mod message;
pub mod serialization;

View File

@ -0,0 +1,56 @@
use super::message::{ClientToServerMessage, ServerToClientMessage};
use either::Either;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum SerializationError {
#[error("Failed to (de)serialize message: {0}")]
SerializationFailed(#[from] serde_json::Error),
#[error("Failed to (de)serialize binary message: {0}")]
BinarySerializationFailed(#[from] serde_cbor::Error),
}
macro_rules! message_serialization {
($message_enum:ty $(,$binary_pattern:pat_param)* $(,)?) => {
impl $message_enum {
pub fn is_binary(&self) -> bool {
match self {
$(
$binary_pattern => true,
)*
_ => false,
}
}
pub fn serialize(
&self,
) -> Result<Either<String, Vec<u8>>, SerializationError> {
Ok(match self {
$(
$binary_pattern => Either::Right(serde_cbor::to_vec(self)?),
)*
_ => Either::Left(serde_json::to_string(self)?),
})
}
pub fn deserialize(
msg: Either<&str, &[u8]>,
) -> Result<Self, SerializationError> {
Ok(match msg {
Either::Left(text) => serde_json::from_str(&text)?,
Either::Right(bytes) => serde_cbor::from_slice(&bytes)?,
})
}
}
};
}
message_serialization!(
ClientToServerMessage,
ClientToServerMessage::SocketSend { .. },
);
message_serialization!(
ServerToClientMessage,
ServerToClientMessage::SocketData { .. },
);

View File

@ -0,0 +1,15 @@
use super::ids::Base62Id;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
#[serde(from = "Base62Id")]
#[serde(into = "Base62Id")]
pub struct UserId(pub u64);
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UserStatus {
pub user_id: UserId,
pub profile_name: Option<String>,
pub last_update: DateTime<Utc>,
}

View File

@ -1,6 +1,6 @@
use super::DatabaseError;
use crate::models::ids::base62_impl::to_base62;
use crate::models::ids::{random_base62_rng, random_base62_rng_range};
use crate::common::ids::base62_impl::to_base62;
use crate::common::ids::{random_base62_rng, random_base62_rng_range};
use censor::Censor;
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;

View File

@ -1,6 +1,5 @@
use crate::{
database::redis::RedisPool, models::ids::base62_impl::parse_base62,
};
use crate::common::ids::base62_impl::parse_base62;
use crate::database::redis::RedisPool;
use dashmap::DashMap;
use futures::TryStreamExt;
use std::fmt::{Debug, Display};

View File

@ -1,7 +1,7 @@
use super::ids::*;
use crate::common::ids::base62_impl::parse_base62;
use crate::database::models::DatabaseError;
use crate::database::redis::RedisPool;
use crate::models::ids::base62_impl::parse_base62;
use crate::models::pats::Scopes;
use chrono::{DateTime, Utc};
use dashmap::DashMap;

View File

@ -3,10 +3,10 @@ use super::loader_fields::{
VersionField,
};
use super::{ids::*, User};
use crate::common::ids::base62_impl::parse_base62;
use crate::database::models;
use crate::database::models::DatabaseError;
use crate::database::redis::RedisPool;
use crate::models::ids::base62_impl::parse_base62;
use crate::models::projects::{MonetizationStatus, ProjectStatus};
use chrono::{DateTime, Utc};
use dashmap::{DashMap, DashSet};
@ -300,7 +300,7 @@ impl Project {
slug, color, monetization_status, organization_id
)
VALUES (
$1, $2, $3, $4, $5, $6,
$1, $2, $3, $4, $5, $6,
$7, $8, $9, $10, $11,
$12, $13,
LOWER($14), $15, $16, $17

View File

@ -1,7 +1,7 @@
use super::ids::*;
use crate::common::ids::base62_impl::parse_base62;
use crate::database::models::DatabaseError;
use crate::database::redis::RedisPool;
use crate::models::ids::base62_impl::parse_base62;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};

View File

@ -1,9 +1,9 @@
use super::ids::{ProjectId, UserId};
use super::{CollectionId, ReportId, ThreadId};
use crate::common::ids::base62_impl::{parse_base62, to_base62};
use crate::database::models;
use crate::database::models::{DatabaseError, OrganizationId};
use crate::database::redis::RedisPool;
use crate::models::ids::base62_impl::{parse_base62, to_base62};
use crate::models::users::Badges;
use chrono::{DateTime, Utc};
use dashmap::DashMap;

View File

@ -1,5 +1,5 @@
use super::models::DatabaseError;
use crate::models::ids::base62_impl::{parse_base62, to_base62};
use crate::common::ids::base62_impl::{parse_base62, to_base62};
use chrono::{TimeZone, Utc};
use dashmap::DashMap;
use deadpool_redis::{Config, Runtime};

View File

@ -25,6 +25,8 @@ use crate::{
util::env::{parse_strings_from_var, parse_var},
};
pub mod common;
pub mod auth;
pub mod clickhouse;
pub mod database;
@ -297,8 +299,10 @@ pub fn app_setup(
}
let ip_salt = Pepper {
pepper: models::ids::Base62Id(models::ids::random_base62(11))
.to_string(),
pepper: crate::common::ids::Base62Id(
crate::common::ids::random_base62(11),
)
.to_string(),
};
let payouts_queue = web::Data::new(PayoutsQueue::new());

View File

@ -13,117 +13,13 @@ pub use super::teams::TeamId;
pub use super::threads::ThreadId;
pub use super::threads::ThreadMessageId;
pub use super::users::UserId;
use crate::common::ids::base62_id_impl;
pub use crate::common::ids::Base62Id;
pub use crate::models::billing::{
ChargeId, ProductId, ProductPriceId, UserSubscriptionId,
};
use thiserror::Error;
/// Generates a random 64 bit integer that is exactly `n` characters
/// long when encoded as base62.
///
/// Uses `rand`'s thread rng on every call.
///
/// # Panics
///
/// This method panics if `n` is 0 or greater than 11, since a `u64`
/// can only represent up to 11 character base62 strings
#[inline]
pub fn random_base62(n: usize) -> u64 {
random_base62_rng(&mut rand::thread_rng(), n)
}
/// Generates a random 64 bit integer that is exactly `n` characters
/// long when encoded as base62, using the given rng.
///
/// # Panics
///
/// This method panics if `n` is 0 or greater than 11, since a `u64`
/// can only represent up to 11 character base62 strings
pub fn random_base62_rng<R: rand::RngCore>(rng: &mut R, n: usize) -> u64 {
random_base62_rng_range(rng, n, n)
}
pub fn random_base62_rng_range<R: rand::RngCore>(
rng: &mut R,
n_min: usize,
n_max: usize,
) -> u64 {
use rand::Rng;
assert!(n_min > 0 && n_max <= 11 && n_min <= n_max);
// gen_range is [low, high): max value is `MULTIPLES[n] - 1`,
// which is n characters long when encoded
rng.gen_range(MULTIPLES[n_min - 1]..MULTIPLES[n_max])
}
const MULTIPLES: [u64; 12] = [
1,
62,
62 * 62,
62 * 62 * 62,
62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62 * 62 * 62 * 62,
62 * 62 * 62 * 62 * 62 * 62 * 62 * 62 * 62 * 62,
u64::MAX,
];
/// An ID encoded as base62 for use in the API.
///
/// All ids should be random and encode to 8-10 character base62 strings,
/// to avoid enumeration and other attacks.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct Base62Id(pub u64);
/// An error decoding a number from base62.
#[derive(Error, Debug)]
pub enum DecodingError {
/// Encountered a non-base62 character in a base62 string
#[error("Invalid character {0:?} in base62 encoding")]
InvalidBase62(char),
/// Encountered integer overflow when decoding a base62 id.
#[error("Base62 decoding overflowed")]
Overflow,
}
macro_rules! from_base62id {
($($struct:ty, $con:expr;)+) => {
$(
impl From<Base62Id> for $struct {
fn from(id: Base62Id) -> $struct {
$con(id.0)
}
}
impl From<$struct> for Base62Id {
fn from(id: $struct) -> Base62Id {
Base62Id(id.0)
}
}
)+
};
}
macro_rules! impl_base62_display {
($struct:ty) => {
impl std::fmt::Display for $struct {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&base62_impl::to_base62(self.0))
}
}
};
}
impl_base62_display!(Base62Id);
macro_rules! base62_id_impl {
($struct:ty, $cons:expr) => {
from_base62id!($struct, $cons;);
impl_base62_display!($struct);
}
}
base62_id_impl!(ProjectId, ProjectId);
base62_id_impl!(UserId, UserId);
base62_id_impl!(VersionId, VersionId);
base62_id_impl!(CollectionId, CollectionId);
base62_id_impl!(TeamId, TeamId);
@ -143,91 +39,3 @@ base62_id_impl!(ProductId, ProductId);
base62_id_impl!(ProductPriceId, ProductPriceId);
base62_id_impl!(UserSubscriptionId, UserSubscriptionId);
base62_id_impl!(ChargeId, ChargeId);
pub mod base62_impl {
use serde::de::{self, Deserializer, Visitor};
use serde::ser::Serializer;
use serde::{Deserialize, Serialize};
use super::{Base62Id, DecodingError};
impl<'de> Deserialize<'de> for Base62Id {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Base62Visitor;
impl Visitor<'_> for Base62Visitor {
type Value = Base62Id;
fn expecting(
&self,
formatter: &mut std::fmt::Formatter,
) -> std::fmt::Result {
formatter.write_str("a base62 string id")
}
fn visit_str<E>(self, string: &str) -> Result<Base62Id, E>
where
E: de::Error,
{
parse_base62(string).map(Base62Id).map_err(E::custom)
}
}
deserializer.deserialize_str(Base62Visitor)
}
}
impl Serialize for Base62Id {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&to_base62(self.0))
}
}
const BASE62_CHARS: [u8; 62] =
*b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
pub fn to_base62(mut num: u64) -> String {
let length = (num as f64).log(62.0).ceil() as usize;
let mut output = String::with_capacity(length);
while num > 0 {
// Could be done more efficiently, but requires byte
// manipulation of strings & Vec<u8> -> String conversion
output.insert(0, BASE62_CHARS[(num % 62) as usize] as char);
num /= 62;
}
output
}
pub fn parse_base62(string: &str) -> Result<u64, DecodingError> {
let mut num: u64 = 0;
for c in string.chars() {
let next_digit;
if c.is_ascii_digit() {
next_digit = (c as u8 - b'0') as u64;
} else if c.is_ascii_uppercase() {
next_digit = 10 + (c as u8 - b'A') as u64;
} else if c.is_ascii_lowercase() {
next_digit = 36 + (c as u8 - b'a') as u64;
} else {
return Err(DecodingError::InvalidBase62(c));
}
// We don't want this panicking or wrapping on integer overflow
if let Some(n) =
num.checked_mul(62).and_then(|n| n.checked_add(next_digit))
{
num = n;
} else {
return Err(DecodingError::Overflow);
}
}
Ok(num)
}
}

View File

@ -178,7 +178,7 @@ impl From<DBNotification> for Notification {
name.clone(),
text.clone(),
link.clone(),
actions.clone().into_iter().map(Into::into).collect(),
actions.clone().into_iter().collect(),
),
NotificationBody::Unknown => {
("".to_string(), "".to_string(), "#".to_string(), vec![])

View File

@ -1,14 +1,9 @@
use super::ids::Base62Id;
pub use crate::common::users::{UserId, UserStatus};
use crate::{auth::AuthProvider, bitflags_serde_impl};
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
#[serde(from = "Base62Id")]
#[serde(into = "Base62Id")]
pub struct UserId(pub u64);
pub const DELETED_USER: UserId = UserId(127155982985829);
bitflags::bitflags! {
@ -211,10 +206,3 @@ impl UserFriend {
}
}
}
#[derive(Serialize, Deserialize, Clone)]
pub struct UserStatus {
pub user_id: UserId,
pub profile_name: Option<String>,
pub last_update: DateTime<Utc>,
}

View File

@ -1,16 +1,68 @@
//! "Database" for Hydra
use crate::models::users::{UserId, UserStatus};
use actix_ws::Session;
use dashmap::DashMap;
use dashmap::{DashMap, DashSet};
use std::sync::atomic::AtomicU32;
use uuid::Uuid;
pub type SocketId = u32;
pub struct ActiveSockets {
pub auth_sockets: DashMap<UserId, (UserStatus, Session)>,
pub sockets: DashMap<SocketId, ActiveSocket>,
pub sockets_by_user_id: DashMap<UserId, DashSet<SocketId>>,
pub next_socket_id: AtomicU32,
pub tunnel_sockets: DashMap<Uuid, TunnelSocket>,
}
impl Default for ActiveSockets {
fn default() -> Self {
Self {
auth_sockets: DashMap::new(),
sockets: DashMap::new(),
sockets_by_user_id: DashMap::new(),
next_socket_id: AtomicU32::new(0),
tunnel_sockets: DashMap::new(),
}
}
}
impl ActiveSockets {
pub fn get_status(&self, user: UserId) -> Option<UserStatus> {
self.sockets_by_user_id
.get(&user)
.and_then(|x| x.iter().next().and_then(|x| self.sockets.get(&*x)))
.map(|x| x.status.clone())
}
}
pub struct ActiveSocket {
pub status: UserStatus,
pub socket: Session,
pub owned_tunnel_sockets: DashSet<Uuid>,
}
impl ActiveSocket {
pub fn new(status: UserStatus, session: Session) -> Self {
Self {
status,
socket: session,
owned_tunnel_sockets: DashSet::new(),
}
}
}
pub struct TunnelSocket {
pub owner: SocketId,
pub socket_type: TunnelSocketType,
}
impl TunnelSocket {
pub fn new(owner: SocketId, socket_type: TunnelSocketType) -> Self {
Self { owner, socket_type }
}
}
pub enum TunnelSocketType {
Listening,
Connected { connected_to: Uuid },
}

View File

@ -74,7 +74,7 @@ pub async fn count_download(
let project_id: crate::database::models::ids::ProjectId =
download_body.project_id.into();
let id_option = crate::models::ids::base62_impl::parse_base62(
let id_option = crate::common::ids::base62_impl::parse_base62(
&download_body.version_name,
)
.ok()

View File

@ -1,4 +1,5 @@
use crate::auth::{get_user_from_headers, send_email};
use crate::common::ids::base62_impl::{parse_base62, to_base62};
use crate::database::models::charge_item::ChargeItem;
use crate::database::models::{
generate_charge_id, generate_user_subscription_id, product_item,
@ -10,7 +11,6 @@ use crate::models::billing::{
Product, ProductMetadata, ProductPrice, SubscriptionMetadata,
SubscriptionStatus, UserSubscription,
};
use crate::models::ids::base62_impl::{parse_base62, to_base62};
use crate::models::pats::Scopes;
use crate::models::users::Badges;
use crate::queue::session::AuthQueue;

View File

@ -1,11 +1,11 @@
use crate::auth::email::send_email;
use crate::auth::validate::get_user_record_from_bearer_token;
use crate::auth::{get_user_from_headers, AuthProvider, AuthenticationError};
use crate::common::ids::base62_impl::{parse_base62, to_base62};
use crate::common::ids::random_base62_rng;
use crate::database::models::flow_item::Flow;
use crate::database::redis::RedisPool;
use crate::file_hosting::FileHost;
use crate::models::ids::base62_impl::{parse_base62, to_base62};
use crate::models::ids::random_base62_rng;
use crate::models::pats::Scopes;
use crate::models::users::{Badges, Role};
use crate::queue::session::AuthQueue;

View File

@ -1,7 +1,7 @@
use super::ApiError;
use crate::common::ids::random_base62;
use crate::database;
use crate::database::redis::RedisPool;
use crate::models::ids::random_base62;
use crate::models::projects::ProjectStatus;
use crate::queue::moderation::{ApprovalType, IdentifiedFile, MissingMetadata};
use crate::queue::session::AuthQueue;

View File

@ -1,41 +1,33 @@
use crate::auth::validate::get_user_record_from_bearer_token;
use crate::auth::AuthenticationError;
use crate::common::ids::UserId;
use crate::common::networking::message::{
ClientToServerMessage, ServerToClientMessage,
};
use crate::common::users::UserStatus;
use crate::database::models::friend_item::FriendItem;
use crate::database::redis::RedisPool;
use crate::models::ids::UserId;
use crate::models::pats::Scopes;
use crate::models::users::{User, UserStatus};
use crate::models::users::User;
use crate::queue::session::AuthQueue;
use crate::queue::socket::ActiveSockets;
use crate::queue::socket::{
ActiveSocket, ActiveSockets, SocketId, TunnelSocketType,
};
use crate::routes::ApiError;
use actix_web::web::{Data, Payload};
use actix_web::{get, web, HttpRequest, HttpResponse};
use actix_ws::Message;
use chrono::Utc;
use either::Either;
use futures_util::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use sqlx::PgPool;
use std::sync::atomic::Ordering;
pub fn config(cfg: &mut web::ServiceConfig) {
cfg.service(ws_init);
}
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientToServerMessage {
StatusUpdate { profile_name: Option<String> },
}
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerToClientMessage {
StatusUpdate { status: UserStatus },
UserOffline { id: UserId },
FriendStatuses { statuses: Vec<UserStatus> },
FriendRequest { from: UserId },
FriendRequestRejected { from: UserId },
}
#[derive(Deserialize)]
struct LauncherHeartbeatInit {
code: String,
@ -71,10 +63,6 @@ pub async fn ws_init(
let user = User::from_full(db_user);
if let Some((_, (_, session))) = db.auth_sockets.remove(&user.id) {
let _ = session.close(None).await;
}
let (res, mut session, msg_stream) = match actix_ws::handle(&req, body) {
Ok(x) => x,
Err(e) => return Ok(e.error_response()),
@ -94,8 +82,8 @@ pub async fn ws_init(
friends
.iter()
.filter_map(|x| {
db.auth_sockets.get(
&if x.user_id == user.id.into() {
db.get_status(
if x.user_id == user.id.into() {
x.friend_id
} else {
x.user_id
@ -103,7 +91,6 @@ pub async fn ws_init(
.into(),
)
})
.map(|x| x.value().0.clone())
.collect::<Vec<_>>()
} else {
Vec::new()
@ -117,7 +104,17 @@ pub async fn ws_init(
)?)
.await;
db.auth_sockets.insert(user.id, (status.clone(), session));
let db = db.clone();
let socket_id = db.next_socket_id.fetch_add(1, Ordering::Relaxed);
db.sockets
.insert(socket_id, ActiveSocket::new(status.clone(), session));
db.sockets_by_user_id
.entry(user.id)
.or_default()
.insert(socket_id);
#[cfg(debug_assertions)]
log::info!("Connection {socket_id} opened by {}", user.id);
broadcast_friends(
user.id,
@ -133,68 +130,182 @@ pub async fn ws_init(
actix_web::rt::spawn(async move {
// receive messages from websocket
while let Some(msg) = stream.next().await {
match msg {
let message = match msg {
Ok(Message::Text(text)) => {
if let Ok(message) =
serde_json::from_str::<ClientToServerMessage>(&text)
{
match message {
ClientToServerMessage::StatusUpdate {
profile_name,
} => {
if let Some(mut pair) =
db.auth_sockets.get_mut(&user.id)
{
let (status, _) = pair.value_mut();
ClientToServerMessage::deserialize(Either::Left(&text))
}
if status
.profile_name
.as_ref()
.map(|x| x.len() > 64)
.unwrap_or(false)
{
continue;
}
status.profile_name = profile_name;
status.last_update = Utc::now();
let user_status = status.clone();
// We drop the pair to avoid holding the lock for too long
drop(pair);
let _ = broadcast_friends(
user.id,
ServerToClientMessage::StatusUpdate {
status: user_status,
},
&pool,
&db,
None,
)
.await;
}
}
}
}
Ok(Message::Binary(bytes)) => {
ClientToServerMessage::deserialize(Either::Right(&bytes))
}
Ok(Message::Close(_)) => {
let _ = close_socket(user.id, &pool, &db).await;
let _ = close_socket(socket_id, &pool, &db).await;
continue;
}
Ok(Message::Ping(msg)) => {
if let Some(socket) = db.auth_sockets.get(&user.id) {
let (_, socket) = socket.value();
let _ = socket.clone().pong(&msg).await;
if let Some(socket) = db.sockets.get(&socket_id) {
let _ = socket.socket.clone().pong(&msg).await;
}
continue;
}
_ => continue,
};
if message.is_err() {
continue;
}
let message = message.unwrap();
#[cfg(debug_assertions)]
if !message.is_binary() {
log::info!("Received message from {socket_id}: {:?}", message);
}
match message {
ClientToServerMessage::StatusUpdate { profile_name } => {
if let Some(mut pair) = db.sockets.get_mut(&socket_id) {
let ActiveSocket { status, .. } = pair.value_mut();
if status
.profile_name
.as_ref()
.map(|x| x.len() > 64)
.unwrap_or(false)
{
return;
}
status.profile_name = profile_name;
status.last_update = Utc::now();
let user_status = status.clone();
// We drop the pair to avoid holding the lock for too long
drop(pair);
let _ = broadcast_friends(
user.id,
ServerToClientMessage::StatusUpdate {
status: user_status,
},
&pool,
&db,
None,
)
.await;
}
}
_ => {}
ClientToServerMessage::SocketListen { .. } => {
// TODO: Listen to socket
// The code below probably won't need changes, but there's no way to connect to
// a tunnel socket yet, so we shouldn't be storing them
// let Some(active_socket) = db.sockets.get(&socket_id) else {
// return;
// };
// let Vacant(entry) = db.tunnel_sockets.entry(socket) else {
// continue;
// };
// entry.insert(TunnelSocket::new(
// socket_id,
// TunnelSocketType::Listening,
// ));
// active_socket.owned_tunnel_sockets.insert(socket);
// let _ = broadcast_friends(
// user.id,
// ServerToClientMessage::FriendSocketListening {
// user: user.id,
// socket,
// },
// &pool,
// &db,
// None,
// )
// .await;
}
ClientToServerMessage::SocketClose { socket } => {
let Some(active_socket) = db.sockets.get(&socket_id) else {
return;
};
if active_socket
.owned_tunnel_sockets
.remove(&socket)
.is_none()
{
continue;
}
let Some((_, tunnel_socket)) =
db.tunnel_sockets.remove(&socket)
else {
continue;
};
match tunnel_socket.socket_type {
TunnelSocketType::Listening => {
let _ = broadcast_friends(
user.id,
ServerToClientMessage::FriendSocketStoppedListening { user: user.id },
&pool,
&db,
None,
)
.await;
}
TunnelSocketType::Connected { connected_to } => {
let Some((_, other)) =
db.tunnel_sockets.remove(&connected_to)
else {
continue;
};
let Some(other_user) = db.sockets.get(&other.owner)
else {
continue;
};
let _ = send_message(
&other_user,
&ServerToClientMessage::SocketClosed { socket },
)
.await;
}
}
}
ClientToServerMessage::SocketSend { socket, data } => {
let Some(tunnel_socket) = db.tunnel_sockets.get(&socket)
else {
continue;
};
if tunnel_socket.owner != socket_id {
continue;
}
let TunnelSocketType::Connected { connected_to } =
tunnel_socket.socket_type
else {
continue;
};
let Some(other_tunnel) =
db.tunnel_sockets.get(&connected_to)
else {
continue;
};
let Some(other_user) = db.sockets.get(&other_tunnel.owner)
else {
continue;
};
let _ = send_message(
&other_user,
&ServerToClientMessage::SocketData {
socket: connected_to,
data,
},
)
.await;
}
}
}
let _ = close_socket(user.id, &pool, &db).await;
let _ = close_socket(socket_id, &pool, &db).await;
});
Ok(res)
@ -207,6 +318,7 @@ pub async fn broadcast_friends(
sockets: &ActiveSockets,
friends: Option<Vec<FriendItem>>,
) -> Result<(), crate::database::models::DatabaseError> {
// FIXME Probably shouldn't be using database errors for this. Maybe ApiError?
let friends = if let Some(friends) = friends {
friends
} else {
@ -221,11 +333,46 @@ pub async fn broadcast_friends(
};
if friend.accepted {
if let Some(socket) = sockets.auth_sockets.get(&friend_id.into()) {
let (_, socket) = socket.value();
if let Some(socket_ids) =
sockets.sockets_by_user_id.get(&friend_id.into())
{
for socket_id in socket_ids.iter() {
if let Some(socket) = sockets.sockets.get(&socket_id) {
let _ = send_message(socket.value(), &message).await;
}
}
}
}
}
let _ =
socket.clone().text(serde_json::to_string(&message)?).await;
Ok(())
}
pub async fn send_message(
socket: &ActiveSocket,
message: &ServerToClientMessage,
) -> Result<(), crate::database::models::DatabaseError> {
let mut socket = socket.socket.clone();
// FIXME Probably shouldn't swallow sending errors
let _ = match message.serialize() {
Ok(Either::Left(text)) => socket.text(text).await,
Ok(Either::Right(bytes)) => socket.binary(bytes).await,
Err(_) => Ok(()), // TODO: Maybe should log these? Though it is the backend
};
Ok(())
}
pub async fn send_message_to_user(
db: &ActiveSockets,
user: UserId,
message: &ServerToClientMessage,
) -> Result<(), crate::database::models::DatabaseError> {
if let Some(socket_ids) = db.sockets_by_user_id.get(&user) {
for socket_id in socket_ids.iter() {
if let Some(socket) = db.sockets.get(&socket_id) {
send_message(&socket, message).await?;
}
}
}
@ -234,21 +381,66 @@ pub async fn broadcast_friends(
}
pub async fn close_socket(
id: UserId,
id: SocketId,
pool: &PgPool,
sockets: &ActiveSockets,
db: &ActiveSockets,
) -> Result<(), crate::database::models::DatabaseError> {
if let Some((_, (_, socket))) = sockets.auth_sockets.remove(&id) {
let _ = socket.close(None).await;
if let Some((_, socket)) = db.sockets.remove(&id) {
let user_id = socket.status.user_id;
db.sockets_by_user_id.remove_if(&user_id, |_, sockets| {
sockets.remove(&id);
sockets.is_empty()
});
let _ = socket.socket.close(None).await;
broadcast_friends(
id,
ServerToClientMessage::UserOffline { id },
user_id,
ServerToClientMessage::UserOffline { id: user_id },
pool,
sockets,
db,
None,
)
.await?;
for owned_socket in socket.owned_tunnel_sockets {
let Some((_, tunnel_socket)) =
db.tunnel_sockets.remove(&owned_socket)
else {
continue;
};
match tunnel_socket.socket_type {
TunnelSocketType::Listening => {
let _ = broadcast_friends(
user_id,
ServerToClientMessage::SocketClosed {
socket: owned_socket,
},
pool,
db,
None,
)
.await;
}
TunnelSocketType::Connected { connected_to } => {
let Some((_, other)) =
db.tunnel_sockets.remove(&connected_to)
else {
continue;
};
let Some(other_user) = db.sockets.get(&other.owner) else {
continue;
};
let _ = send_message(
&other_user,
&ServerToClientMessage::SocketClosed {
socket: connected_to,
},
)
.await;
}
}
}
}
Ok(())

View File

@ -164,7 +164,7 @@ async fn find_version(
pool: &PgPool,
redis: &RedisPool,
) -> Result<Option<QueryVersion>, ApiError> {
let id_option = crate::models::ids::base62_impl::parse_base62(vcoords)
let id_option = crate::common::ids::base62_impl::parse_base62(vcoords)
.ok()
.map(|x| x as i64);

View File

@ -117,7 +117,7 @@ pub enum ApiError {
#[error("Captcha Error. Try resubmitting the form.")]
Turnstile,
#[error("Error while decoding Base62: {0}")]
Decoding(#[from] crate::models::ids::DecodingError),
Decoding(#[from] crate::common::ids::DecodingError),
#[error("Image Parsing Error: {0}")]
ImageParse(#[from] image::ImageError),
#[error("Password Hashing Error: {0}")]

View File

@ -1,4 +1,5 @@
use super::ApiError;
use crate::common::ids::base62_impl::to_base62;
use crate::database;
use crate::database::redis::RedisPool;
use crate::models::teams::ProjectPermissions;
@ -6,7 +7,7 @@ use crate::{
auth::get_user_from_headers,
database::models::user_item,
models::{
ids::{base62_impl::to_base62, ProjectId, VersionId},
ids::{ProjectId, VersionId},
pats::Scopes,
},
queue::session::AuthQueue,

View File

@ -1,12 +1,12 @@
use crate::auth::checks::is_visible_collection;
use crate::auth::{filter_visible_collections, get_user_from_headers};
use crate::common::ids::base62_impl::parse_base62;
use crate::database::models::{
collection_item, generate_collection_id, project_item,
};
use crate::database::redis::RedisPool;
use crate::file_hosting::FileHost;
use crate::models::collections::{Collection, CollectionStatus};
use crate::models::ids::base62_impl::parse_base62;
use crate::models::ids::{CollectionId, ProjectId};
use crate::models::pats::Scopes;
use crate::queue::session::AuthQueue;

View File

@ -1,11 +1,12 @@
use crate::auth::get_user_from_headers;
use crate::common::networking::message::ServerToClientMessage;
use crate::database::models::UserId;
use crate::database::redis::RedisPool;
use crate::models::pats::Scopes;
use crate::models::users::UserFriend;
use crate::queue::session::AuthQueue;
use crate::queue::socket::ActiveSockets;
use crate::routes::internal::statuses::{close_socket, ServerToClientMessage};
use crate::routes::internal::statuses::send_message_to_user;
use crate::routes::ApiError;
use actix_web::{delete, get, post, web, HttpRequest, HttpResponse};
use chrono::Utc;
@ -76,22 +77,16 @@ pub async fn add_friend(
friend_id: UserId,
sockets: &ActiveSockets,
) -> Result<(), ApiError> {
if let Some(pair) = sockets.auth_sockets.get(&user_id.into()) {
let (friend_status, _) = pair.value();
if let Some(socket) =
sockets.auth_sockets.get(&friend_id.into())
{
let (_, socket) = socket.value();
let _ = socket
.clone()
.text(serde_json::to_string(
&ServerToClientMessage::StatusUpdate {
status: friend_status.clone(),
},
)?)
.await;
}
if let Some(friend_status) = sockets.get_status(user_id.into())
{
send_message_to_user(
sockets,
friend_id.into(),
&ServerToClientMessage::StatusUpdate {
status: friend_status.clone(),
},
)
.await?;
}
Ok(())
@ -121,20 +116,12 @@ pub async fn add_friend(
.insert(&mut transaction)
.await?;
if let Some(socket) = db.auth_sockets.get(&friend.id.into()) {
let (_, socket) = socket.value();
if socket
.clone()
.text(serde_json::to_string(
&ServerToClientMessage::FriendRequest { from: user.id },
)?)
.await
.is_err()
{
close_socket(user.id, &pool, &db).await?;
}
}
send_message_to_user(
&db,
friend.id.into(),
&ServerToClientMessage::FriendRequest { from: user.id },
)
.await?;
}
transaction.commit().await?;
@ -178,18 +165,12 @@ pub async fn remove_friend(
)
.await?;
if let Some(socket) = db.auth_sockets.get(&friend.id.into()) {
let (_, socket) = socket.value();
let _ = socket
.clone()
.text(serde_json::to_string(
&ServerToClientMessage::FriendRequestRejected {
from: user.id,
},
)?)
.await;
}
send_message_to_user(
&db,
friend.id.into(),
&ServerToClientMessage::FriendRequestRejected { from: user.id },
)
.await?;
transaction.commit().await?;

View File

@ -1,19 +1,7 @@
use std::{collections::HashSet, fmt::Display, sync::Arc};
use actix_web::{
delete, get, patch, post,
web::{self, scope},
HttpRequest, HttpResponse,
};
use chrono::Utc;
use itertools::Itertools;
use rand::{distributions::Alphanumeric, Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use validator::Validate;
use super::ApiError;
use crate::common::ids::base62_impl::parse_base62;
use crate::{
auth::{checks::ValidateAuthorized, get_user_from_headers},
database::{
@ -35,13 +23,21 @@ use crate::{
util::validate::validation_errors_to_string,
};
use crate::{
file_hosting::FileHost,
models::{
ids::base62_impl::parse_base62,
oauth_clients::DeleteOAuthClientQueryParam,
},
file_hosting::FileHost, models::oauth_clients::DeleteOAuthClientQueryParam,
util::routes::read_from_payload,
};
use actix_web::{
delete, get, patch, post,
web::{self, scope},
HttpRequest, HttpResponse,
};
use chrono::Utc;
use itertools::Itertools;
use rand::{distributions::Alphanumeric, Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use validator::Validate;
use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient;
use crate::models::ids::OAuthClientId as ApiOAuthClientId;

View File

@ -3,13 +3,13 @@ use std::sync::Arc;
use super::ApiError;
use crate::auth::{filter_visible_projects, get_user_from_headers};
use crate::common::ids::base62_impl::parse_base62;
use crate::database::models::team_item::TeamMember;
use crate::database::models::{
generate_organization_id, team_item, Organization,
};
use crate::database::redis::RedisPool;
use crate::file_hosting::FileHost;
use crate::models::ids::base62_impl::parse_base62;
use crate::models::ids::UserId;
use crate::models::organizations::OrganizationId;
use crate::models::pats::Scopes;
@ -786,7 +786,7 @@ pub async fn organization_projects_add(
let organization_owner_user_id = sqlx::query!(
"
SELECT u.id
SELECT u.id
FROM team_members
INNER JOIN users u ON u.id = team_members.user_id
WHERE team_id = $1 AND is_owner = TRUE
@ -969,7 +969,7 @@ pub async fn organization_projects_remove(
sqlx::query!(
"
UPDATE team_members
SET
SET
is_owner = TRUE,
accepted = TRUE,
permissions = $2,

View File

@ -1,5 +1,6 @@
use super::version_creation::{try_create_version_fields, InitialVersionData};
use crate::auth::{get_user_from_headers, AuthenticationError};
use crate::common::ids::base62_impl::to_base62;
use crate::database::models::loader_fields::{
Loader, LoaderField, LoaderFieldEnumValue,
};
@ -8,7 +9,6 @@ use crate::database::models::{self, image_item, User};
use crate::database::redis::RedisPool;
use crate::file_hosting::{FileHost, FileHostingError};
use crate::models::error::ApiError;
use crate::models::ids::base62_impl::to_base62;
use crate::models::ids::{ImageId, OrganizationId};
use crate::models::images::{Image, ImageContext};
use crate::models::pats::Scopes;

View File

@ -3,6 +3,7 @@ use std::sync::Arc;
use crate::auth::checks::{filter_visible_versions, is_visible_project};
use crate::auth::{filter_visible_projects, get_user_from_headers};
use crate::common::ids::base62_impl::parse_base62;
use crate::database::models::notification_item::NotificationBuilder;
use crate::database::models::project_item::{GalleryItem, ModCategory};
use crate::database::models::thread_item::ThreadMessageBuilder;
@ -11,7 +12,6 @@ use crate::database::redis::RedisPool;
use crate::database::{self, models as db_models};
use crate::file_hosting::FileHost;
use crate::models;
use crate::models::ids::base62_impl::parse_base62;
use crate::models::images::ImageContext;
use crate::models::notifications::NotificationBody;
use crate::models::pats::Scopes;

View File

@ -1,4 +1,5 @@
use crate::auth::{check_is_moderator_from_headers, get_user_from_headers};
use crate::common::ids::base62_impl::parse_base62;
use crate::database;
use crate::database::models::image_item;
use crate::database::models::thread_item::{
@ -6,9 +7,7 @@ use crate::database::models::thread_item::{
};
use crate::database::redis::RedisPool;
use crate::models::ids::ImageId;
use crate::models::ids::{
base62_impl::parse_base62, ProjectId, UserId, VersionId,
};
use crate::models::ids::{ProjectId, UserId, VersionId};
use crate::models::images::{Image, ImageContext};
use crate::models::pats::Scopes;
use crate::models::reports::{ItemType, Report};

View File

@ -5,6 +5,7 @@ use crate::auth::checks::{
filter_visible_versions, is_visible_project, is_visible_version,
};
use crate::auth::get_user_from_headers;
use crate::common::ids::base62_impl::parse_base62;
use crate::database;
use crate::database::models::loader_fields::{
self, LoaderField, LoaderFieldEnumValue, VersionField,
@ -13,7 +14,6 @@ use crate::database::models::version_item::{DependencyBuilder, LoaderVersion};
use crate::database::models::{image_item, Organization};
use crate::database::redis::RedisPool;
use crate::models;
use crate::models::ids::base62_impl::parse_base62;
use crate::models::ids::VersionId;
use crate::models::images::ImageContext;
use crate::models::pats::Scopes;
@ -444,7 +444,7 @@ pub async fn version_edit_helper(
.collect::<Vec<i32>>();
sqlx::query!(
"
DELETE FROM version_fields
DELETE FROM version_fields
WHERE version_id = $1
AND field_id = ANY($2)
",

View File

@ -1,8 +1,8 @@
/// This module is used for the indexing from any source.
pub mod local_import;
use crate::common::ids::base62_impl::to_base62;
use crate::database::redis::RedisPool;
use crate::models::ids::base62_impl::to_base62;
use crate::search::{SearchConfig, UploadSearchProject};
use local_import::index_local;
use log::info;

View File

@ -25,7 +25,7 @@ pub fn get_color_from_img(data: &[u8]) -> Result<Option<u32>, ImageError> {
)
.ok()
.and_then(|x| x.first().copied())
.map(|x| (x.r as u32) << 16 | (x.g as u32) << 8 | (x.b as u32));
.map(|x| ((x.r as u32) << 16) | ((x.g as u32) << 8) | (x.b as u32));
Ok(color)
}

View File

@ -1,6 +1,6 @@
use crate::common::ids::base62_impl::to_base62;
use crate::database::models::legacy_loader_fields::MinecraftGameVersion;
use crate::database::redis::RedisPool;
use crate::models::ids::base62_impl::to_base62;
use crate::models::projects::ProjectId;
use crate::routes::ApiError;
use chrono::{DateTime, Utc};

View File

@ -7,7 +7,7 @@ use common::{
environment::{with_test_environment, TestEnvironment},
};
use itertools::Itertools;
use labrinth::models::ids::base62_impl::parse_base62;
use labrinth::common::ids::base62_impl::parse_base62;
use labrinth::models::teams::ProjectPermissions;
use labrinth::queue::payouts;
use rust_decimal::{prelude::ToPrimitive, Decimal};

View File

@ -9,10 +9,10 @@ use common::environment::{
};
use common::permissions::{PermissionsTest, PermissionsTestContext};
use futures::StreamExt;
use labrinth::common::ids::base62_impl::parse_base62;
use labrinth::database::models::project_item::{
PROJECTS_NAMESPACE, PROJECTS_SLUGS_NAMESPACE,
};
use labrinth::models::ids::base62_impl::parse_base62;
use labrinth::models::projects::ProjectId;
use labrinth::models::teams::ProjectPermissions;
use labrinth::util::actix::{MultipartSegment, MultipartSegmentData};

View File

@ -18,7 +18,7 @@ use common::environment::{
with_test_environment, with_test_environment_all, TestEnvironment,
};
use common::{database::*, scopes::ScopeTest};
use labrinth::models::ids::base62_impl::parse_base62;
use labrinth::common::ids::base62_impl::parse_base62;
use labrinth::models::pats::Scopes;
use labrinth::models::projects::ProjectId;
use labrinth::models::users::UserId;

View File

@ -8,7 +8,7 @@ use common::environment::with_test_environment;
use common::environment::TestEnvironment;
use common::search::setup_search_projects;
use futures::stream::StreamExt;
use labrinth::models::ids::base62_impl::parse_base62;
use labrinth::common::ids::base62_impl::parse_base62;
use serde_json::json;
use crate::common::api_common::Api;

View File

@ -18,12 +18,10 @@ use actix_http::StatusCode;
use actix_web::test;
use futures::StreamExt;
use itertools::Itertools;
use labrinth::common::ids::base62_impl::parse_base62;
use labrinth::{
database::models::project_item::PROJECTS_SLUGS_NAMESPACE,
models::{
ids::base62_impl::parse_base62, projects::ProjectId,
teams::ProjectPermissions,
},
models::{projects::ProjectId, teams::ProjectPermissions},
util::actix::{AppendsMultipart, MultipartSegment, MultipartSegmentData},
};
use serde_json::json;

View File

@ -6,7 +6,7 @@ use crate::common::dummy_data::TestFile;
use crate::common::environment::with_test_environment;
use crate::common::environment::TestEnvironment;
use crate::common::scopes::ScopeTest;
use labrinth::models::ids::base62_impl::parse_base62;
use labrinth::common::ids::base62_impl::parse_base62;
use labrinth::models::pats::Scopes;
use labrinth::models::projects::ProjectId;

View File

@ -11,7 +11,7 @@ use crate::common::environment::with_test_environment;
use crate::common::environment::TestEnvironment;
use actix_http::StatusCode;
use futures::stream::StreamExt;
use labrinth::models::ids::base62_impl::parse_base62;
use labrinth::common::ids::base62_impl::parse_base62;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;

View File

@ -13,8 +13,8 @@ use common::asserts::assert_common_version_ids;
use common::database::USER_USER_PAT;
use common::environment::{with_test_environment, with_test_environment_all};
use futures::StreamExt;
use labrinth::common::ids::base62_impl::parse_base62;
use labrinth::database::models::version_item::VERSIONS_NAMESPACE;
use labrinth::models::ids::base62_impl::parse_base62;
use labrinth::models::projects::{
Dependency, DependencyType, VersionId, VersionStatus, VersionType,
};

View File

@ -0,0 +1,2 @@
[env]
SQLX_OFFLINE = "true"

View File

@ -29,6 +29,7 @@ regex = "1.5"
sys-info = "0.9.0"
sysinfo = "0.30.8"
thiserror = "1.0"
either = "1.13"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.18", features = ["chrono", "env-filter"] }
@ -62,6 +63,8 @@ base64 = "0.22.0"
sqlx = { version = "0.8.2", features = [ "runtime-tokio", "sqlite", "macros" ] }
labrinth = { path = "../../apps/labrinth" }
[target.'cfg(windows)'.dependencies]
winreg = "0.52.0"

View File

@ -1,4 +1,5 @@
use crate::state::{FriendsSocket, UserFriend, UserStatus};
use crate::state::{FriendsSocket, UserFriend};
use labrinth::common::users::UserStatus;
#[tracing::instrument]
pub async fn friends() -> crate::Result<Vec<UserFriend>> {

View File

@ -19,8 +19,9 @@ pub mod data {
Hooks, JavaVersion, LinkedData, MemorySettings, ModLoader,
ModrinthCredentials, Organization, ProcessMetadata, ProfileFile,
Project, ProjectType, SearchResult, SearchResults, Settings,
TeamMember, Theme, User, UserFriend, UserStatus, Version, WindowSize,
TeamMember, Theme, User, UserFriend, Version, WindowSize,
};
pub use labrinth::common::users::UserStatus;
}
pub mod prelude {

View File

@ -13,6 +13,11 @@ pub enum ErrorKind {
#[error("Serialization error (JSON): {0}")]
JSONError(#[from] serde_json::Error),
#[error("Serialization error (websocket): {0}")]
WebsocketSerializationError(
#[from] labrinth::common::networking::serialization::SerializationError,
),
#[error("Error parsing UUID: {0}")]
UUIDError(#[from] uuid::Error),

View File

@ -1,6 +1,6 @@
//! Theseus state management system
use crate::state::UserStatus;
use dashmap::DashMap;
use labrinth::common::users::{UserId, UserStatus};
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
#[cfg(feature = "tauri")]
@ -262,8 +262,8 @@ pub enum EventError {
#[serde(rename_all = "snake_case")]
#[serde(tag = "event")]
pub enum FriendPayload {
FriendRequest { from: String },
UserOffline { id: String },
FriendRequest { from: UserId },
UserOffline { id: UserId },
StatusUpdate { user_status: UserStatus },
StatusSync,
}

View File

@ -2,7 +2,8 @@ use crate::config::{MODRINTH_API_URL_V3, MODRINTH_SOCKET_URL};
use crate::data::ModrinthCredentials;
use crate::event::emit::emit_friend;
use crate::event::FriendPayload;
use crate::state::{ProcessManager, Profile};
use crate::state::tunnel::InternalTunnelSocket;
use crate::state::{ProcessManager, Profile, TunnelSocket};
use crate::util::fetch::{fetch_advanced, fetch_json, FetchSemaphore};
use async_tungstenite::tokio::{connect_async, ConnectStream};
use async_tungstenite::tungstenite::client::IntoClientRequest;
@ -10,20 +11,33 @@ use async_tungstenite::tungstenite::Message;
use async_tungstenite::WebSocketStream;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use either::Either;
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use labrinth::common::networking::message::{
ClientToServerMessage, ServerToClientMessage,
};
use labrinth::common::users::{UserId, UserStatus};
use reqwest::header::HeaderValue;
use reqwest::Method;
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::ops::Deref;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::OwnedReadHalf;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, RwLock};
use uuid::Uuid;
type WriteSocket =
pub(super) type WriteSocket =
Arc<RwLock<Option<SplitSink<WebSocketStream<ConnectStream>, Message>>>>;
pub(super) type TunnelSockets = Arc<DashMap<Uuid, Arc<InternalTunnelSocket>>>;
pub struct FriendsSocket {
write: WriteSocket,
user_statuses: Arc<DashMap<String, UserStatus>>,
user_statuses: Arc<DashMap<UserId, UserStatus>>,
tunnel_sockets: TunnelSockets,
}
#[derive(Deserialize, Serialize)]
@ -34,28 +48,6 @@ pub struct UserFriend {
pub created: DateTime<Utc>,
}
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientToServerMessage {
StatusUpdate { profile_name: Option<String> },
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerToClientMessage {
StatusUpdate { status: UserStatus },
UserOffline { id: String },
FriendStatuses { statuses: Vec<UserStatus> },
FriendRequest { from: String },
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct UserStatus {
pub user_id: String,
pub profile_name: Option<String>,
pub last_update: DateTime<Utc>,
}
impl Default for FriendsSocket {
fn default() -> Self {
Self::new()
@ -67,6 +59,7 @@ impl FriendsSocket {
Self {
write: Arc::new(RwLock::new(None)),
user_statuses: Arc::new(DashMap::new()),
tunnel_sockets: Arc::new(DashMap::new()),
}
}
@ -120,6 +113,7 @@ impl FriendsSocket {
let write_handle = self.write.clone();
let statuses = self.user_statuses.clone();
let sockets = self.tunnel_sockets.clone();
tokio::spawn(async move {
let mut read_stream = read;
@ -128,18 +122,14 @@ impl FriendsSocket {
Ok(msg) => {
let server_message = match msg {
Message::Text(text) => {
serde_json::from_str::<
ServerToClientMessage,
>(
&text
ServerToClientMessage::deserialize(
Either::Left(&text),
)
.ok()
}
Message::Binary(bytes) => {
serde_json::from_slice::<
ServerToClientMessage,
>(
&bytes
ServerToClientMessage::deserialize(
Either::Right(&bytes),
)
.ok()
}
@ -165,7 +155,7 @@ impl FriendsSocket {
{
match server_message {
ServerToClientMessage::StatusUpdate { status } => {
statuses.insert(status.user_id.clone(), status.clone());
statuses.insert(status.user_id, status.clone());
let _ = emit_friend(FriendPayload::StatusUpdate { user_status: status }).await;
},
ServerToClientMessage::UserOffline { id } => {
@ -175,13 +165,41 @@ impl FriendsSocket {
ServerToClientMessage::FriendStatuses { statuses: new_statuses } => {
statuses.clear();
new_statuses.into_iter().for_each(|status| {
statuses.insert(status.user_id.clone(), status);
statuses.insert(status.user_id, status);
});
let _ = emit_friend(FriendPayload::StatusSync).await;
}
ServerToClientMessage::FriendRequest { from } => {
let _ = emit_friend(FriendPayload::FriendRequest { from }).await;
}
ServerToClientMessage::FriendRequestRejected { .. } => todo!(),
ServerToClientMessage::FriendSocketListening { .. } => {}, // TODO
ServerToClientMessage::FriendSocketStoppedListening { .. } => {}, // TODO
ServerToClientMessage::SocketConnected { to_socket, new_socket } => {
if let Some(connected_to) = sockets.get(&to_socket) {
if let InternalTunnelSocket::Listening(local_addr) = *connected_to.value().clone() {
if let Ok(new_stream) = TcpStream::connect(local_addr).await {
let (read, write) = new_stream.into_split();
sockets.insert(new_socket, Arc::new(InternalTunnelSocket::Connected(Mutex::new(write))));
Self::socket_read_loop(write_handle.clone(), read, new_socket);
continue;
}
}
}
let _ = Self::send_message(&write_handle, ClientToServerMessage::SocketClose { socket: new_socket }).await;
},
ServerToClientMessage::SocketClosed { socket } => {
sockets.remove_if(&socket, |_, x| matches!(*x.clone(), InternalTunnelSocket::Connected(_)));
},
ServerToClientMessage::SocketData { socket, data } => {
if let Some(mut socket) = sockets.get_mut(&socket) {
if let InternalTunnelSocket::Connected(ref stream) = *socket.value_mut().clone() {
let _ = stream.lock().await.write_all(&data).await;
}
}
},
}
}
}
@ -217,10 +235,7 @@ impl FriendsSocket {
let mut last_ping = Utc::now();
loop {
let connected = {
let read = state.friends_socket.write.read().await;
read.is_some()
};
let connected = state.friends_socket.is_connected().await;
if !connected
&& Utc::now().signed_duration_since(last_connection)
@ -269,16 +284,11 @@ impl FriendsSocket {
&self,
profile_name: Option<String>,
) -> crate::Result<()> {
let mut write_lock = self.write.write().await;
if let Some(ref mut write_half) = *write_lock {
write_half
.send(Message::Text(serde_json::to_string(
&ClientToServerMessage::StatusUpdate { profile_name },
)?))
.await?;
}
Ok(())
Self::send_message(
&self.write,
ClientToServerMessage::StatusUpdate { profile_name },
)
.await
}
#[tracing::instrument(skip_all)]
@ -346,4 +356,81 @@ impl FriendsSocket {
Ok(())
}
#[tracing::instrument(skip(self))]
pub async fn open_port(&self, port: u16) -> crate::Result<TunnelSocket> {
let socket_id = Uuid::new_v4();
let socket = self.tunnel_sockets.entry(socket_id).insert(Arc::new(
InternalTunnelSocket::Listening(SocketAddr::new(
"127.0.0.1".parse().unwrap(),
port,
)),
));
Self::send_message(
&self.write,
ClientToServerMessage::SocketListen { socket: socket_id },
)
.await?;
self.create_tunnel_socket(socket_id, socket)
}
pub async fn is_connected(&self) -> bool {
self.write.read().await.is_some()
}
fn create_tunnel_socket(
&self,
socket_id: Uuid,
socket: impl Deref<Target = Arc<InternalTunnelSocket>>,
) -> crate::Result<TunnelSocket> {
Ok(TunnelSocket {
socket_id,
write: self.write.clone(),
sockets: self.tunnel_sockets.clone(),
internal: socket.clone(),
})
}
fn socket_read_loop(
write: WriteSocket,
mut read_half: OwnedReadHalf,
socket_id: Uuid,
) {
tokio::spawn(async move {
let mut read_buffer = [0u8; 8192];
loop {
match read_half.read(&mut read_buffer).await {
Ok(0) | Err(_) => break,
Ok(n) => {
let _ = Self::send_message(
&write,
ClientToServerMessage::SocketSend {
socket: socket_id,
data: read_buffer[..n].to_vec(),
},
)
.await;
}
};
}
});
}
#[tracing::instrument(skip(write))]
pub(super) async fn send_message(
write: &WriteSocket,
message: ClientToServerMessage,
) -> crate::Result<()> {
let serialized = match message.serialize()? {
Either::Left(text) => Message::text(text),
Either::Right(bytes) => Message::binary(bytes),
};
let mut write_lock = write.write().await;
if let Some(ref mut write_half) = *write_lock {
write_half.send(serialized).await?;
}
Ok(())
}
}

View File

@ -34,6 +34,9 @@ pub use self::cache::*;
mod friends;
pub use self::friends::*;
mod tunnel;
pub use self::tunnel::*;
pub mod db;
pub mod fs_watcher;
mod mr_auth;

View File

@ -0,0 +1,61 @@
use crate::state::friends::{TunnelSockets, WriteSocket};
use crate::state::FriendsSocket;
use labrinth::common::networking::message::ClientToServerMessage;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::net::tcp::OwnedWriteHalf;
use tokio::sync::Mutex;
use uuid::Uuid;
pub(super) enum InternalTunnelSocket {
Listening(SocketAddr),
Connected(Mutex<OwnedWriteHalf>),
}
pub struct TunnelSocket {
pub(super) socket_id: Uuid,
pub(super) write: WriteSocket,
pub(super) sockets: TunnelSockets,
pub(super) internal: Arc<InternalTunnelSocket>,
}
impl TunnelSocket {
pub fn socket_id(&self) -> Uuid {
self.socket_id
}
pub async fn shutdown(self) -> crate::Result<()> {
if self.sockets.remove(&self.socket_id).is_some() {
FriendsSocket::send_message(
&self.write,
ClientToServerMessage::SocketClose {
socket: self.socket_id,
},
)
.await?;
if let InternalTunnelSocket::Connected(ref stream) =
*self.internal.clone()
{
stream.lock().await.shutdown().await?
}
}
Ok(())
}
}
impl Drop for TunnelSocket {
fn drop(&mut self) {
if self.sockets.remove(&self.socket_id).is_some() {
let write = self.write.clone();
let socket_id = self.socket_id;
tokio::spawn(async move {
let _ = FriendsSocket::send_message(
&write,
ClientToServerMessage::SocketClose { socket: socket_id },
)
.await;
});
}
}
}

View File

@ -1,4 +1,6 @@
//! Functions for fetching infromation from the Internet
use super::io::{self, IOError};
use crate::config::{MODRINTH_API_URL, MODRINTH_API_URL_V3};
use crate::event::emit::emit_loading;
use crate::event::LoadingBarId;
use bytes::Bytes;
@ -11,8 +13,6 @@ use std::time::{self};
use tokio::sync::Semaphore;
use tokio::{fs::File, io::AsyncWriteExt};
use super::io::{self, IOError};
#[derive(Debug)]
pub struct IoSemaphore(pub Semaphore);
#[derive(Debug)]
@ -87,7 +87,8 @@ pub async fn fetch_advanced(
.map(|x| &*x.0.to_lowercase() == "authorization")
.unwrap_or(false)
&& (url.starts_with("https://cdn.modrinth.com")
|| url.starts_with("https://api.modrinth.com"))
|| url.starts_with(MODRINTH_API_URL)
|| url.starts_with(MODRINTH_API_URL_V3))
{
crate::state::ModrinthCredentials::get_active(exec).await?
} else {