Distributed rate limit, fix search panic, add migration task (#3419)

* Distributed rate limit, fix search panic, add migration task

* Add binary info to root endpoint
This commit is contained in:
Jai Agrawal 2025-03-25 01:10:43 -07:00 committed by GitHub
parent 5fbf5b22c0
commit b5a9a93323
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 317 additions and 277 deletions

74
Cargo.lock generated
View File

@ -2469,6 +2469,12 @@ dependencies = [
"const-random",
]
[[package]]
name = "dotenv-build"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4547f16c17f6051a12cdb8c62b803f94bee6807c74aa7c530b30b737df981fc"
[[package]]
name = "dotenvy"
version = "0.15.7"
@ -3345,26 +3351,6 @@ dependencies = [
"system-deps",
]
[[package]]
name = "governor"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
dependencies = [
"cfg-if 1.0.0",
"dashmap 5.5.3",
"futures 0.3.30",
"futures-timer",
"no-std-compat",
"nonzero_ext",
"parking_lot 0.12.3",
"portable-atomic",
"quanta",
"rand 0.8.5",
"smallvec 1.13.2",
"spinning_top",
]
[[package]]
name = "group"
version = "0.13.0"
@ -4529,17 +4515,18 @@ dependencies = [
"dashmap 5.5.3",
"deadpool-redis",
"derive-new",
"dotenv-build",
"dotenvy",
"either",
"flate2",
"futures 0.3.30",
"futures-timer",
"futures-util",
"governor",
"hex",
"hmac 0.11.0",
"hyper 0.14.31",
"hyper-tls 0.5.0",
"iana-time-zone",
"image 0.24.9",
"itertools 0.12.1",
"jemalloc_pprof",
@ -5253,12 +5240,6 @@ dependencies = [
"memoffset 0.9.1",
]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]]
name = "nodrop"
version = "0.1.14"
@ -5275,12 +5256,6 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "normpath"
version = "1.3.0"
@ -6564,21 +6539,6 @@ dependencies = [
"bytemuck",
]
[[package]]
name = "quanta"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
dependencies = [
"crossbeam-utils 0.8.20",
"libc",
"once_cell",
"raw-cpuid",
"wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi 0.3.9",
]
[[package]]
name = "quick-error"
version = "2.0.1"
@ -6792,15 +6752,6 @@ dependencies = [
"rand_core 0.5.1",
]
[[package]]
name = "raw-cpuid"
version = "11.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0"
dependencies = [
"bitflags 2.6.0",
]
[[package]]
name = "raw-window-handle"
version = "0.5.2"
@ -8284,15 +8235,6 @@ dependencies = [
"lock_api 0.4.12",
]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api 0.4.12",
]
[[package]]
name = "spki"
version = "0.7.3"

View File

@ -1,2 +1,3 @@
BASE_URL=https://api.modrinth.com/v2/
BROWSER_BASE_URL=https://api.modrinth.com/v2/
PYRO_BASE_URL=https://archon.modrinth.com/

View File

@ -19,14 +19,13 @@ actix-ws = "0.3.0"
actix-files = "0.6.5"
prometheus = "0.13.4"
actix-web-prom = { version = "0.9.0", features = ["process"] }
governor = "0.6.3"
tracing = "0.1.41"
tracing-subscriber = "0.3.19"
tracing-actix-web = "0.7.16"
console-subscriber = "0.4.1"
tokio = { version = "1.35.1", features = ["sync"] }
tokio = { version = "1.35.1", features = ["sync", "rt-multi-thread"] }
tokio-stream = "0.1.14"
futures = "0.3.30"
@ -132,6 +131,7 @@ json-patch = "*"
ariadne = { path = "../../packages/ariadne" }
clap = { version = "4.5", features = ["derive"] }
iana-time-zone = "0.1.61"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
tikv-jemallocator = { version = "0.6.0", features = ["profiling", "unprefixed_malloc_on_supported_platforms"] }
@ -140,3 +140,8 @@ jemalloc_pprof = { version = "0.7.0", features = ["flamegraph"] }
[dev-dependencies]
actix-http = "3.4.0"
[build-dependencies]
dotenv-build = "0.1.1"
chrono = "0.4.38"
iana-time-zone = "0.1.60"

View File

@ -1,3 +1,47 @@
use std::path::Path;
use std::process::Command;
use chrono::Local;
use dotenv_build::Config;
/// Build script entry point.
///
/// Exports build metadata to the crate as compile-time environment variables
/// (`GIT_HASH`, `COMPILATION_DATE`, `COMPILATION_PROFILE`), forwards the
/// variables found in an optional `.build.env` file, and asks Cargo to rerun
/// the script when `.git/HEAD` or the `migrations` directory changes.
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set or if `dotenv_build::output` reports an
/// error (a missing `.build.env` is not an error, see
/// `fail_if_missing_dotenv`).
fn main() {
    // Resolve the current commit hash. A missing `git` binary, a build outside
    // of a work tree, or non-UTF-8 output must not break the build (and must
    // not export an empty value), so every failure falls back to "unknown".
    let git_hash = Command::new("git")
        .arg("rev-parse")
        .arg("HEAD")
        .output()
        .ok()
        .filter(|output| output.status.success())
        .and_then(|output| String::from_utf8(output.stdout).ok())
        .map(|hash| hash.trim().to_string())
        .filter(|hash| !hash.is_empty())
        .unwrap_or_else(|| "unknown".to_string());

    // NOTE(review): this path is resolved relative to the package directory;
    // if the crate lives in a workspace subdirectory the repository's
    // `.git/HEAD` is elsewhere — confirm the intended location.
    println!("cargo::rerun-if-changed=.git/HEAD");
    println!("cargo::rustc-env=GIT_HASH={git_hash}");

    // Local build time, e.g. "2025-03-25 @ 01:10 AM", followed by the IANA
    // time zone name in parentheses when it can be determined.
    let timedate_fmt = Local::now().format("%F @ %I:%M %p");
    let timezone_fmt = iana_time_zone::get_timezone()
        .map(|tz| format!(" ({tz})"))
        .unwrap_or_default();
    let comptime = timedate_fmt.to_string() + timezone_fmt.as_str();
    println!("cargo::rustc-env=COMPILATION_DATE={comptime}");

    // Trick to get the compilation profile: take the fourth path component
    // from the end of OUT_DIR, or "unknown" when the path is too short.
    let profile = std::env::var("OUT_DIR")
        .expect("OUT_DIR to be set")
        .split(std::path::MAIN_SEPARATOR)
        .nth_back(3)
        .unwrap_or("unknown")
        .to_string();
    println!("cargo::rustc-env=COMPILATION_PROFILE={profile}");

    // Forward variables from `.build.env` (searched recursively); the file is
    // optional.
    dotenv_build::output(Config {
        filename: Path::new(".build.env"),
        recursive_search: true,
        fail_if_missing_dotenv: false,
    })
    .expect("`.build.env` to be readable and well-formed when present");

    println!("cargo::rerun-if-changed=migrations");
}

View File

@ -1,7 +1,7 @@
use crate::database::redis::RedisPool;
use crate::queue::payouts::process_payout;
use crate::search;
use crate::search::indexing::index_projects;
use crate::{database, search};
use clap::ValueEnum;
use sqlx::Postgres;
use tracing::{info, warn};
@ -15,6 +15,7 @@ pub enum BackgroundTask {
Payouts,
IndexBilling,
IndexSubscriptions,
Migrations,
}
impl BackgroundTask {
@ -28,6 +29,7 @@ impl BackgroundTask {
) {
use BackgroundTask::*;
match self {
Migrations => run_migrations().await,
IndexSearch => index_search(pool, redis_pool, search_config).await,
ReleaseScheduled => release_scheduled(pool).await,
UpdateVersions => update_versions(pool, redis_pool).await,
@ -50,6 +52,12 @@ impl BackgroundTask {
}
}
/// Background-task entry point that applies pending database migrations.
///
/// # Panics
///
/// Panics with "An error occurred while running migrations." when
/// `database::check_for_migrations` returns an error.
pub async fn run_migrations() {
    let migration_outcome = database::check_for_migrations().await;
    migration_outcome.expect("An error occurred while running migrations.");
}
pub async fn index_search(
pool: sqlx::Pool<Postgres>,
redis_pool: RedisPool,

View File

@ -1,4 +1,3 @@
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
@ -13,14 +12,12 @@ use tracing::{info, warn};
extern crate clickhouse as clickhouse_crate;
use clickhouse_crate::Client;
use governor::middleware::StateInformationMiddleware;
use governor::{Quota, RateLimiter};
use util::cors::default_cors;
use crate::background_task::update_versions;
use crate::queue::moderation::AutomatedModerationQueue;
use crate::util::env::{parse_strings_from_var, parse_var};
use crate::util::ratelimit::KeyedRateLimiter;
use crate::util::ratelimit::{AsyncRateLimiter, GCRAParameters};
use sync::friends::handle_pubsub;
pub mod auth;
@ -57,7 +54,7 @@ pub struct LabrinthConfig {
pub analytics_queue: Arc<AnalyticsQueue>,
pub active_sockets: web::Data<ActiveSockets>,
pub automated_moderation_queue: web::Data<AutomatedModerationQueue>,
pub rate_limiter: KeyedRateLimiter,
pub rate_limiter: web::Data<AsyncRateLimiter>,
pub stripe_client: stripe::Client,
}
@ -93,24 +90,10 @@ pub fn app_setup(
let mut scheduler = scheduler::Scheduler::new();
let limiter: KeyedRateLimiter = Arc::new(
RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(300).unwrap()))
.with_middleware::<StateInformationMiddleware>(),
);
let limiter_clone = Arc::clone(&limiter);
scheduler.run(Duration::from_secs(60), move || {
info!(
"Clearing ratelimiter, storage size: {}",
limiter_clone.len()
);
limiter_clone.retain_recent();
info!(
"Done clearing ratelimiter, storage size: {}",
limiter_clone.len()
);
async move {}
});
let limiter = web::Data::new(AsyncRateLimiter::new(
redis_pool.clone(),
GCRAParameters::new(300, 300),
));
if enable_background_tasks {
// The interval in seconds at which the local database is indexed
@ -329,6 +312,7 @@ pub fn app_config(
.app_data(labrinth_config.active_sockets.clone())
.app_data(labrinth_config.automated_moderation_queue.clone())
.app_data(web::Data::new(labrinth_config.stripe_client.clone()))
.app_data(labrinth_config.rate_limiter.clone())
.configure(
#[allow(unused_variables)]
|cfg| {

View File

@ -1,3 +1,4 @@
use actix_web::middleware::from_fn;
use actix_web::{App, HttpServer};
use actix_web_prom::PrometheusMetricsBuilder;
use clap::Parser;
@ -5,7 +6,7 @@ use labrinth::background_task::BackgroundTask;
use labrinth::database::redis::RedisPool;
use labrinth::file_hosting::S3Host;
use labrinth::search;
use labrinth::util::ratelimit::RateLimit;
use labrinth::util::ratelimit::rate_limit_middleware;
use labrinth::{check_env_vars, clickhouse, database, file_hosting, queue};
use std::sync::Arc;
use tracing::{error, info};
@ -33,6 +34,10 @@ struct Args {
#[arg(long)]
no_background_tasks: bool,
/// Don't automatically run migrations. This means the migrations should be run via --run-background-task.
#[arg(long)]
no_migrations: bool,
/// Run a single background task and then exit. Perfect for cron jobs.
#[arg(long, value_enum, id = "task")]
run_background_task: Option<BackgroundTask>,
@ -67,10 +72,12 @@ async fn main() -> std::io::Result<()> {
dotenvy::var("BIND_ADDR").unwrap()
);
if !args.no_migrations {
database::check_for_migrations()
.await
.expect("An error occurred while running migrations.");
}
}
// Database Connector
let pool = database::connect()
@ -164,7 +171,7 @@ async fn main() -> std::io::Result<()> {
App::new()
.wrap(TracingLogger::default())
.wrap(prometheus.clone())
.wrap(RateLimit(Arc::clone(&labrinth_config.rate_limiter)))
.wrap(from_fn(rate_limit_middleware))
.wrap(actix_web::middleware::Compress::default())
.wrap(sentry_actix::Sentry::new())
.configure(|cfg| labrinth::app_config(cfg, labrinth_config.clone()))

View File

@ -7,7 +7,13 @@ pub async fn index_get() -> HttpResponse {
"name": "modrinth-labrinth",
"version": env!("CARGO_PKG_VERSION"),
"documentation": "https://docs.modrinth.com",
"about": "Welcome traveler!"
"about": "Welcome traveler!",
"build_info": {
"comp_date": env!("COMPILATION_DATE"),
"git_hash": env!("GIT_HASH", "unknown"),
"profile": env!("COMPILATION_PROFILE"),
}
});
HttpResponse::Ok().json(data)

View File

@ -91,11 +91,11 @@ pub async fn ws_init(
let friend_statuses = if !friends.is_empty() {
let db = db.clone();
let redis = redis.clone();
tokio_stream::iter(friends.iter())
let statuses = tokio_stream::iter(friends.iter())
.map(|x| {
let db = db.clone();
let redis = redis.clone();
async move {
async move {
get_user_status(
if x.user_id == user_id.into() {
@ -109,12 +109,12 @@ pub async fn ws_init(
)
.await
}
}
})
.buffer_unordered(16)
.filter_map(|x| x)
.collect::<Vec<_>>()
.await
.await;
statuses.into_iter().flatten().collect()
} else {
Vec::new()
};

View File

@ -209,8 +209,9 @@ pub async fn search_for_project(
let mut filter_string = String::new();
// Convert offset and limit to page and hits_per_page
let hits_per_page = limit;
let page = offset / limit + 1;
let hits_per_page = if limit == 0 { 1 } else { limit };
let page = offset / hits_per_page + 1;
let results = {
let mut query = meilisearch_index.search();

View File

@ -1,76 +1,145 @@
use governor::clock::{Clock, DefaultClock};
use governor::{middleware, state, RateLimiter};
use std::str::FromStr;
use std::sync::Arc;
use crate::database::redis::RedisPool;
use crate::routes::ApiError;
use crate::util::env::parse_var;
use actix_web::{
body::EitherBody,
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
Error, ResponseError,
body::{EitherBody, MessageBody},
dev::{ServiceRequest, ServiceResponse},
middleware::Next,
web, Error, ResponseError,
};
use futures_util::future::LocalBoxFuture;
use futures_util::future::{ready, Ready};
use chrono::Utc;
use std::str::FromStr;
use std::sync::Arc;
pub type KeyedRateLimiter<
K = String,
MW = middleware::StateInformationMiddleware,
> = Arc<
RateLimiter<K, state::keyed::DefaultKeyedStateStore<K>, DefaultClock, MW>,
>;
// Namespace passed to the Redis connection's `get`/`set` calls together with the per-client key.
const RATE_LIMIT_NAMESPACE: &str = "rate_limit";
// Expiry handed to `set` for every stored TAT value; presumably in seconds — confirm against `RedisPool`.
const RATE_LIMIT_EXPIRY: i64 = 300; // 5 minutes
// Nanoseconds in one minute; divided by the requests-per-minute quota to obtain the emission interval.
const MINUTE_IN_NANOS: i64 = 60_000_000_000;
pub struct RateLimit(pub KeyedRateLimiter);
/// Tuning parameters for the GCRA-based rate limiter.
pub struct GCRAParameters {
// Nanoseconds between two requests at the sustained rate (one minute divided by the per-minute quota).
emission_interval: i64,
// Multiplied by the emission interval to bound how far the stored TAT may run ahead of "now";
// also the value reported as `limit` in every decision.
burst_size: u32,
}
impl<S, B> Transform<S, ServiceRequest> for RateLimit
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Transform = RateLimitService<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
impl GCRAParameters {
pub(crate) fn new(requests_per_minute: u32, burst_size: u32) -> Self {
// Calculate emission interval in nanoseconds
let emission_interval = MINUTE_IN_NANOS / requests_per_minute as i64;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(RateLimitService {
service,
rate_limiter: Arc::clone(&self.0),
}))
Self {
emission_interval,
burst_size,
}
}
}
#[doc(hidden)]
pub struct RateLimitService<S> {
service: S,
rate_limiter: KeyedRateLimiter,
/// Outcome of a single rate-limit check for one key.
pub struct RateLimitDecision {
// Whether the request may proceed.
pub allowed: bool,
// The configured burst size.
pub limit: u32,
// Requests still available before the limit is hit (0 when rejected).
pub remaining: u32,
// Milliseconds between "now" and the theoretical arrival time (TAT) used for this decision.
pub reset_after_ms: i64,
// Milliseconds the caller should wait before retrying; `Some` only for rejected requests.
pub retry_after_ms: Option<i64>,
}
impl<S, B> Service<ServiceRequest> for RateLimitService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
/// Rate limiter that keeps its per-key state in Redis so that the state is not
/// tied to a single process. Cloning copies the pool handle and bumps the
/// reference count of the shared parameters.
#[derive(Clone)]
pub struct AsyncRateLimiter {
// Pool used to obtain a connection for each check.
redis_pool: RedisPool,
// GCRA parameters shared between clones.
params: Arc<GCRAParameters>,
}
forward_ready!(service);
impl AsyncRateLimiter {
    /// Creates a limiter that stores its state through `redis_pool` and
    /// enforces `params`.
    pub fn new(redis_pool: RedisPool, params: GCRAParameters) -> Self {
        Self {
            redis_pool,
            params: Arc::new(params),
        }
    }

    /// Performs one GCRA (generic cell rate algorithm) check for `key`.
    ///
    /// The per-key state is a single "theoretical arrival time" (TAT), stored
    /// in Redis as a decimal nanosecond timestamp. A request conforms when,
    /// after charging one emission interval for it, the TAT is at most
    /// `burst_size` emission intervals ahead of the current time. Conforming
    /// requests persist the advanced TAT; rejected requests leave the stored
    /// value untouched.
    ///
    /// Redis failures never reject a request: if no connection can be
    /// obtained the request is allowed, a failed read is treated as "no state
    /// stored", and a failed write is ignored.
    ///
    /// NOTE(review): the read and the write are two separate Redis commands,
    /// so concurrent checks for the same key can presumably both observe the
    /// same TAT and under-count — confirm whether an atomic update is needed.
    pub async fn check_rate_limit(&self, key: &str) -> RateLimitDecision {
        let limit = self.params.burst_size;

        let mut conn = match self.redis_pool.connect().await {
            Ok(conn) => conn,
            Err(_) => {
                // Fail open: with Redis unavailable the request is allowed and
                // placeholder header values are reported. No limit is actually
                // enforced in this state.
                return RateLimitDecision {
                    allowed: true,
                    limit,
                    remaining: 1,
                    reset_after_ms: 60_000, // 1 minute
                    retry_after_ms: None,
                };
            }
        };

        // Current time in nanoseconds since the UNIX epoch.
        let now = Utc::now().timestamp_nanos_opt().unwrap_or(0);

        // Current TAT from Redis, if any; read errors count as "not found".
        let tat_str = conn.get(RATE_LIMIT_NAMESPACE, key).await.ok().flatten();

        // A missing or unparsable TAT means the key is idle: start from now.
        let current_tat = match tat_str {
            Some(tat_str) => tat_str.parse::<i64>().unwrap_or(now),
            None => now,
        };

        let increment = self.params.emission_interval;
        let max_tat_delta = increment.saturating_mul(i64::from(limit));

        // Charge this request first, then test the result. Testing the TAT
        // *before* charging would let `burst_size + 1` back-to-back requests
        // through, and `max(tat + T, now)` would leave a request arriving
        // after a long idle period uncharged. Saturating arithmetic protects
        // against absurd values read back from Redis.
        let new_tat = std::cmp::max(current_tat, now).saturating_add(increment);
        let tat_delta = new_tat.saturating_sub(now);

        if tat_delta > max_tat_delta {
            // Limit exceeded. The request becomes conforming once the excess
            // over the allowed window has elapsed.
            let retry_after_ms = (tat_delta - max_tat_delta) / 1_000_000;
            let reset_after_ms =
                (current_tat.saturating_sub(now) / 1_000_000).max(0);

            return RateLimitDecision {
                allowed: false,
                limit,
                remaining: 0,
                reset_after_ms,
                retry_after_ms: Some(retry_after_ms.max(0)),
            };
        }

        // Best effort: a failed write only means this request is not counted.
        let _ = conn
            .set(
                RATE_LIMIT_NAMESPACE,
                key,
                &new_tat.to_string(),
                Some(RATE_LIMIT_EXPIRY),
            )
            .await;

        // Here `increment <= tat_delta <= max_tat_delta`, so the quotient lies
        // in `0..limit` and always fits in a `u32`.
        let remaining_capacity =
            u32::try_from((max_tat_delta - tat_delta) / increment).unwrap_or(0);

        RateLimitDecision {
            allowed: true,
            limit,
            remaining: remaining_capacity,
            reset_after_ms: tat_delta / 1_000_000,
            retry_after_ms: None,
        }
    }
}
pub async fn rate_limit_middleware(
req: ServiceRequest,
next: Next<impl MessageBody>,
) -> Result<ServiceResponse<EitherBody<impl MessageBody>>, Error> {
let rate_limiter = req
.app_data::<web::Data<AsyncRateLimiter>>()
.expect("Rate limiter not configured properly")
.clone();
fn call(&self, req: ServiceRequest) -> Self::Future {
if let Some(key) = req.headers().get("x-ratelimit-key") {
if key.to_str().ok()
== dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok().as_deref()
{
let res = self.service.call(req);
return Box::pin(async move {
let service_response = res.await?;
Ok(service_response.map_into_left_body())
});
return Ok(next.call(req).await?.map_into_left_body());
}
}
@ -86,72 +155,51 @@ where
};
if let Some(ip) = ip {
let ip = ip.to_string();
let decision = rate_limiter.check_rate_limit(ip).await;
match self.rate_limiter.check_key(&ip) {
Ok(snapshot) => {
let fut = self.service.call(req);
if decision.allowed {
let mut service_response = next.call(req).await?;
Box::pin(async move {
match fut.await {
Ok(mut service_response) => {
// Now you have a mutable reference to the ServiceResponse, so you can modify its headers.
// Add rate limit headers
let headers = service_response.headers_mut();
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-limit",
)
.unwrap(),
snapshot.quota().burst_size().get().into(),
decision.limit.into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-remaining",
)
.unwrap(),
snapshot.remaining_burst_capacity().into(),
decision.remaining.into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-reset",
)
.unwrap(),
snapshot
.quota()
.burst_size_replenished_in()
.as_secs()
.into(),
(decision.reset_after_ms / 1000).into(),
);
// Return the modified response as Ok.
Ok(service_response.map_into_left_body())
}
Err(e) => {
// Handle error case
Err(e)
}
}
})
}
Err(negative) => {
let wait_time =
negative.wait_time_from(DefaultClock::default().now());
} else {
let mut response = ApiError::RateLimitError(
wait_time.as_millis(),
negative.quota().burst_size().get(),
decision.retry_after_ms.unwrap_or(0) as u128,
decision.limit,
)
.error_response();
// Add rate limit headers
let headers = response.headers_mut();
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-limit",
)
.unwrap(),
negative.quota().burst_size().get().into(),
decision.limit.into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
@ -165,10 +213,10 @@ where
"x-ratelimit-reset",
)
.unwrap(),
wait_time.as_secs().into(),
(decision.reset_after_ms / 1000).into(),
);
// TODO: Sentralize CORS in the CORS util.
// TODO: Centralize CORS in the CORS util.
headers.insert(
actix_web::http::header::HeaderName::from_str(
"Access-Control-Allow-Origin",
@ -177,10 +225,7 @@ where
"*".parse().unwrap(),
);
Box::pin(async {
Ok(req.into_response(response.map_into_right_body()))
})
}
}
} else {
let response = ApiError::CustomAuthentication(
@ -188,9 +233,6 @@ where
)
.error_response();
Box::pin(async {
Ok(req.into_response(response.map_into_right_body()))
})
}
}
}