From b5a9a93323144193dc2e81db3c58e4b8677d0389 Mon Sep 17 00:00:00 2001 From: Jai Agrawal <18202329+Geometrically@users.noreply.github.com> Date: Tue, 25 Mar 2025 01:10:43 -0700 Subject: [PATCH] Distributed rate limit, fix search panic, add migration task (#3419) * Distributed rate limit, fix search panic, add migration task * Add binary info to root endpoint --- Cargo.lock | 74 +--- apps/frontend/.env.example | 1 + apps/labrinth/Cargo.toml | 9 +- apps/labrinth/build.rs | 44 +++ apps/labrinth/src/background_task.rs | 10 +- apps/labrinth/src/lib.rs | 30 +- apps/labrinth/src/main.rs | 17 +- apps/labrinth/src/routes/index.rs | 8 +- apps/labrinth/src/routes/internal/statuses.rs | 32 +- apps/labrinth/src/search/mod.rs | 5 +- apps/labrinth/src/util/ratelimit.rs | 364 ++++++++++-------- 11 files changed, 317 insertions(+), 277 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3396996df..1b29561b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2469,6 +2469,12 @@ dependencies = [ "const-random", ] +[[package]] +name = "dotenv-build" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4547f16c17f6051a12cdb8c62b803f94bee6807c74aa7c530b30b737df981fc" + [[package]] name = "dotenvy" version = "0.15.7" @@ -3345,26 +3351,6 @@ dependencies = [ "system-deps", ] -[[package]] -name = "governor" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b" -dependencies = [ - "cfg-if 1.0.0", - "dashmap 5.5.3", - "futures 0.3.30", - "futures-timer", - "no-std-compat", - "nonzero_ext", - "parking_lot 0.12.3", - "portable-atomic", - "quanta", - "rand 0.8.5", - "smallvec 1.13.2", - "spinning_top", -] - [[package]] name = "group" version = "0.13.0" @@ -4529,17 +4515,18 @@ dependencies = [ "dashmap 5.5.3", "deadpool-redis", "derive-new", + "dotenv-build", "dotenvy", "either", "flate2", "futures 0.3.30", "futures-timer", "futures-util", - "governor", "hex", "hmac 0.11.0", "hyper 0.14.31", "hyper-tls 0.5.0", + "iana-time-zone", "image 0.24.9", "itertools 0.12.1", "jemalloc_pprof", @@ -5253,12 +5240,6 @@ dependencies = [ "memoffset 0.9.1", ] -[[package]] -name = "no-std-compat" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" - [[package]] name = "nodrop" version = "0.1.14" @@ -5275,12 +5256,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "nonzero_ext" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" - [[package]] name = "normpath" version = "1.3.0" @@ -6564,21 +6539,6 @@ dependencies = [ "bytemuck", ] -[[package]] -name = "quanta" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" -dependencies = [ - "crossbeam-utils 0.8.20", - "libc", - "once_cell", - "raw-cpuid", - "wasi 0.11.0+wasi-snapshot-preview1", - "web-sys", - "winapi 0.3.9", -] - [[package]] name = "quick-error" version = "2.0.1" @@ -6792,15 +6752,6 @@ dependencies = [ "rand_core 0.5.1", ] -[[package]] -name = "raw-cpuid" -version = "11.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" -dependencies = [ - "bitflags 2.6.0", -] - [[package]] name = "raw-window-handle" version = "0.5.2" @@ -8284,15 +8235,6 @@ dependencies = [ "lock_api 0.4.12", ] -[[package]] -name = "spinning_top" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" -dependencies = [ - "lock_api 0.4.12", -] - [[package]] name = "spki" version = "0.7.3" diff --git a/apps/frontend/.env.example b/apps/frontend/.env.example index bd54ceb10..43ceb1d53 100644 --- a/apps/frontend/.env.example +++ b/apps/frontend/.env.example @@ -1,2 +1,3 @@ BASE_URL=https://api.modrinth.com/v2/ BROWSER_BASE_URL=https://api.modrinth.com/v2/ +PYRO_BASE_URL=https://archon.modrinth.com/ diff --git a/apps/labrinth/Cargo.toml b/apps/labrinth/Cargo.toml index 07bb1fbfe..bed7b380f 100644 --- a/apps/labrinth/Cargo.toml +++ b/apps/labrinth/Cargo.toml @@ -19,14 +19,13 @@ actix-ws = "0.3.0" actix-files = "0.6.5" prometheus = "0.13.4" actix-web-prom = { version = "0.9.0", features = ["process"] } -governor = "0.6.3" tracing = "0.1.41" tracing-subscriber = "0.3.19" tracing-actix-web = "0.7.16" console-subscriber = "0.4.1" -tokio = { version = "1.35.1", features = ["sync"] } +tokio = { version = "1.35.1", features = ["sync", "rt-multi-thread"] } tokio-stream = "0.1.14" futures = "0.3.30" @@ -132,6 +131,7 @@ json-patch = "*" ariadne = { path = "../../packages/ariadne" } clap = { version = "4.5", features = ["derive"] } +iana-time-zone = "0.1.61" [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = { version = "0.6.0", features = ["profiling", "unprefixed_malloc_on_supported_platforms"] } @@ -140,3 +140,8 @@ jemalloc_pprof = { version = "0.7.0", features = ["flamegraph"] } [dev-dependencies] actix-http = "3.4.0" + +[build-dependencies] +dotenv-build = "0.1.1" +chrono = "0.4.38" +iana-time-zone = "0.1.60" \ No newline at end of file diff --git a/apps/labrinth/build.rs b/apps/labrinth/build.rs index 3a8149ef0..0f5fdf2b5 100644 --- a/apps/labrinth/build.rs +++ b/apps/labrinth/build.rs @@ -1,3 +1,47 @@ +use std::path::Path; +use std::process::Command; + +use chrono::Local; +use dotenv_build::Config; + fn main() { + let output = Command::new("git") + .arg("rev-parse") + .arg("HEAD") + .output() + .expect("`git` invocation to succeed"); + + let git_hash = String::from_utf8(output.stdout) + .expect("valid UTF-8 output from `git` invocation"); + + println!("cargo::rerun-if-changed=.git/HEAD"); + println!("cargo::rustc-env=GIT_HASH={}", git_hash.trim()); + + let timedate_fmt = Local::now().format("%F @ %I:%M %p"); + let timezone_fmt = iana_time_zone::get_timezone() + .map(|tz| format!(" ({tz})")) + .unwrap_or_default(); + + let comptime = timedate_fmt.to_string() + timezone_fmt.as_str(); + + println!("cargo::rustc-env=COMPILATION_DATE={comptime}"); + + // trick to get compilation profile + let profile = std::env::var("OUT_DIR") + .expect("OUT_DIR to be set") + .split(std::path::MAIN_SEPARATOR) + .nth_back(3) + .unwrap_or("unknown") + .to_string(); + + println!("cargo::rustc-env=COMPILATION_PROFILE={profile}"); + + dotenv_build::output(Config { + filename: Path::new(".build.env"), + recursive_search: true, + fail_if_missing_dotenv: false, + }) + .unwrap(); + println!("cargo:rerun-if-changed=migrations"); } diff --git a/apps/labrinth/src/background_task.rs b/apps/labrinth/src/background_task.rs index c534a83a4..be3ebd673 100644 --- a/apps/labrinth/src/background_task.rs +++ b/apps/labrinth/src/background_task.rs @@ -1,7 +1,7 @@ use crate::database::redis::RedisPool; use crate::queue::payouts::process_payout; -use crate::search; use crate::search::indexing::index_projects; +use crate::{database, search}; use clap::ValueEnum; use sqlx::Postgres; use tracing::{info, warn}; @@ -15,6 +15,7 @@ pub enum BackgroundTask { Payouts, IndexBilling, IndexSubscriptions, + Migrations, } impl BackgroundTask { @@ -28,6 +29,7 @@ impl BackgroundTask { ) { use BackgroundTask::*; match self { + Migrations => run_migrations().await, IndexSearch => index_search(pool, redis_pool, search_config).await, ReleaseScheduled => release_scheduled(pool).await, UpdateVersions => update_versions(pool, redis_pool).await, @@ -50,6 +52,12 @@ impl BackgroundTask { } } +pub async fn run_migrations() { + database::check_for_migrations() + .await + .expect("An error occurred while running migrations."); +} + pub async fn index_search( pool: sqlx::Pool, redis_pool: RedisPool, diff --git a/apps/labrinth/src/lib.rs b/apps/labrinth/src/lib.rs index 97052fb78..18b94a724 100644 --- a/apps/labrinth/src/lib.rs +++ b/apps/labrinth/src/lib.rs @@ -1,4 +1,3 @@ -use std::num::NonZeroU32; use std::sync::Arc; use std::time::Duration; @@ -13,14 +12,12 @@ use tracing::{info, warn}; extern crate clickhouse as clickhouse_crate; use clickhouse_crate::Client; -use governor::middleware::StateInformationMiddleware; -use governor::{Quota, RateLimiter}; use util::cors::default_cors; use crate::background_task::update_versions; use crate::queue::moderation::AutomatedModerationQueue; use crate::util::env::{parse_strings_from_var, parse_var}; -use crate::util::ratelimit::KeyedRateLimiter; +use crate::util::ratelimit::{AsyncRateLimiter, GCRAParameters}; use sync::friends::handle_pubsub; pub mod auth; @@ -57,7 +54,7 @@ pub struct LabrinthConfig { pub analytics_queue: Arc, pub active_sockets: web::Data, pub automated_moderation_queue: web::Data, - pub rate_limiter: KeyedRateLimiter, + pub rate_limiter: web::Data, pub stripe_client: stripe::Client, } @@ -93,24 +90,10 @@ pub fn app_setup( let mut scheduler = scheduler::Scheduler::new(); - let limiter: KeyedRateLimiter = Arc::new( - RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(300).unwrap())) - .with_middleware::(), - ); - let limiter_clone = Arc::clone(&limiter); - scheduler.run(Duration::from_secs(60), move || { - info!( - "Clearing ratelimiter, storage size: {}", - limiter_clone.len() - ); - limiter_clone.retain_recent(); - info!( - "Done clearing ratelimiter, storage size: {}", - limiter_clone.len() - ); - - async move {} - }); + let limiter = web::Data::new(AsyncRateLimiter::new( + redis_pool.clone(), + GCRAParameters::new(300, 300), + )); if enable_background_tasks { // The interval in seconds at which the local database is indexed @@ -329,6 +312,7 @@ pub fn app_config( .app_data(labrinth_config.active_sockets.clone()) .app_data(labrinth_config.automated_moderation_queue.clone()) .app_data(web::Data::new(labrinth_config.stripe_client.clone())) + .app_data(labrinth_config.rate_limiter.clone()) .configure( #[allow(unused_variables)] |cfg| { diff --git a/apps/labrinth/src/main.rs b/apps/labrinth/src/main.rs index d114002ba..1c4f9c87c 100644 --- a/apps/labrinth/src/main.rs +++ b/apps/labrinth/src/main.rs @@ -1,3 +1,4 @@ +use actix_web::middleware::from_fn; use actix_web::{App, HttpServer}; use actix_web_prom::PrometheusMetricsBuilder; use clap::Parser; @@ -5,7 +6,7 @@ use labrinth::background_task::BackgroundTask; use labrinth::database::redis::RedisPool; use labrinth::file_hosting::S3Host; use labrinth::search; -use labrinth::util::ratelimit::RateLimit; +use labrinth::util::ratelimit::rate_limit_middleware; use labrinth::{check_env_vars, clickhouse, database, file_hosting, queue}; use std::sync::Arc; use tracing::{error, info}; @@ -33,6 +34,10 @@ struct Args { #[arg(long)] no_background_tasks: bool, + /// Don't automatically run migrations. This means the migrations should be run via --run-background-task. + #[arg(long)] + no_migrations: bool, + /// Run a single background task and then exit. Perfect for cron jobs. #[arg(long, value_enum, id = "task")] run_background_task: Option, @@ -67,9 +72,11 @@ async fn main() -> std::io::Result<()> { dotenvy::var("BIND_ADDR").unwrap() ); - database::check_for_migrations() - .await - .expect("An error occurred while running migrations."); + if !args.no_migrations { + database::check_for_migrations() + .await + .expect("An error occurred while running migrations."); + } } // Database Connector @@ -164,7 +171,7 @@ async fn main() -> std::io::Result<()> { App::new() .wrap(TracingLogger::default()) .wrap(prometheus.clone()) - .wrap(RateLimit(Arc::clone(&labrinth_config.rate_limiter))) + .wrap(from_fn(rate_limit_middleware)) .wrap(actix_web::middleware::Compress::default()) .wrap(sentry_actix::Sentry::new()) .configure(|cfg| labrinth::app_config(cfg, labrinth_config.clone())) diff --git a/apps/labrinth/src/routes/index.rs b/apps/labrinth/src/routes/index.rs index 8e332fe33..ba937d7f3 100644 --- a/apps/labrinth/src/routes/index.rs +++ b/apps/labrinth/src/routes/index.rs @@ -7,7 +7,13 @@ pub async fn index_get() -> HttpResponse { "name": "modrinth-labrinth", "version": env!("CARGO_PKG_VERSION"), "documentation": "https://docs.modrinth.com", - "about": "Welcome traveler!" + "about": "Welcome traveler!", + + "build_info": { + "comp_date": env!("COMPILATION_DATE"), + "git_hash": env!("GIT_HASH", "unknown"), + "profile": env!("COMPILATION_PROFILE"), + } }); HttpResponse::Ok().json(data) diff --git a/apps/labrinth/src/routes/internal/statuses.rs b/apps/labrinth/src/routes/internal/statuses.rs index ac5721088..800bf681f 100644 --- a/apps/labrinth/src/routes/internal/statuses.rs +++ b/apps/labrinth/src/routes/internal/statuses.rs @@ -91,30 +91,30 @@ pub async fn ws_init( let friend_statuses = if !friends.is_empty() { let db = db.clone(); let redis = redis.clone(); - tokio_stream::iter(friends.iter()) + + let statuses = tokio_stream::iter(friends.iter()) .map(|x| { let db = db.clone(); let redis = redis.clone(); async move { - async move { - get_user_status( - if x.user_id == user_id.into() { - x.friend_id - } else { - x.user_id - } - .into(), - &db, - &redis, - ) - .await - } + get_user_status( + if x.user_id == user_id.into() { + x.friend_id + } else { + x.user_id + } + .into(), + &db, + &redis, + ) + .await } }) .buffer_unordered(16) - .filter_map(|x| x) .collect::>() - .await + .await; + + statuses.into_iter().flatten().collect() } else { Vec::new() }; diff --git a/apps/labrinth/src/search/mod.rs b/apps/labrinth/src/search/mod.rs index a6e07ef54..88895c508 100644 --- a/apps/labrinth/src/search/mod.rs +++ b/apps/labrinth/src/search/mod.rs @@ -209,8 +209,9 @@ pub async fn search_for_project( let mut filter_string = String::new(); // Convert offset and limit to page and hits_per_page - let hits_per_page = limit; - let page = offset / limit + 1; + let hits_per_page = if limit == 0 { 1 } else { limit }; + + let page = offset / hits_per_page + 1; let results = { let mut query = meilisearch_index.search(); diff --git a/apps/labrinth/src/util/ratelimit.rs b/apps/labrinth/src/util/ratelimit.rs index aa3fd81ff..7694c6faf 100644 --- a/apps/labrinth/src/util/ratelimit.rs +++ b/apps/labrinth/src/util/ratelimit.rs @@ -1,196 +1,238 @@ -use governor::clock::{Clock, DefaultClock}; -use governor::{middleware, state, RateLimiter}; -use std::str::FromStr; -use std::sync::Arc; - +use crate::database::redis::RedisPool; use crate::routes::ApiError; use crate::util::env::parse_var; use actix_web::{ - body::EitherBody, - dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, - Error, ResponseError, + body::{EitherBody, MessageBody}, + dev::{ServiceRequest, ServiceResponse}, + middleware::Next, + web, Error, ResponseError, }; -use futures_util::future::LocalBoxFuture; -use futures_util::future::{ready, Ready}; +use chrono::Utc; +use std::str::FromStr; +use std::sync::Arc; -pub type KeyedRateLimiter< - K = String, - MW = middleware::StateInformationMiddleware, -> = Arc< - RateLimiter, DefaultClock, MW>, ->; +const RATE_LIMIT_NAMESPACE: &str = "rate_limit"; +const RATE_LIMIT_EXPIRY: i64 = 300; // 5 minutes +const MINUTE_IN_NANOS: i64 = 60_000_000_000; -pub struct RateLimit(pub KeyedRateLimiter); +pub struct GCRAParameters { + emission_interval: i64, + burst_size: u32, +} -impl Transform for RateLimit -where - S: Service, Error = Error>, - S::Future: 'static, - B: 'static, -{ - type Response = ServiceResponse>; - type Error = Error; - type Transform = RateLimitService; - type InitError = (); - type Future = Ready>; +impl GCRAParameters { + pub(crate) fn new(requests_per_minute: u32, burst_size: u32) -> Self { + // Calculate emission interval in nanoseconds + let emission_interval = MINUTE_IN_NANOS / requests_per_minute as i64; - fn new_transform(&self, service: S) -> Self::Future { - ready(Ok(RateLimitService { - service, - rate_limiter: Arc::clone(&self.0), - })) + Self { + emission_interval, + burst_size, + } } } -#[doc(hidden)] -pub struct RateLimitService { - service: S, - rate_limiter: KeyedRateLimiter, +pub struct RateLimitDecision { + pub allowed: bool, + pub limit: u32, + pub remaining: u32, + pub reset_after_ms: i64, + pub retry_after_ms: Option, } -impl Service for RateLimitService -where - S: Service, Error = Error>, - S::Future: 'static, - B: 'static, -{ - type Response = ServiceResponse>; - type Error = Error; - type Future = LocalBoxFuture<'static, Result>; +#[derive(Clone)] +pub struct AsyncRateLimiter { + redis_pool: RedisPool, + params: Arc, +} - forward_ready!(service); - - fn call(&self, req: ServiceRequest) -> Self::Future { - if let Some(key) = req.headers().get("x-ratelimit-key") { - if key.to_str().ok() - == dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok().as_deref() - { - let res = self.service.call(req); - - return Box::pin(async move { - let service_response = res.await?; - Ok(service_response.map_into_left_body()) - }); - } +impl AsyncRateLimiter { + pub fn new(redis_pool: RedisPool, params: GCRAParameters) -> Self { + Self { + redis_pool, + params: Arc::new(params), } + } - let conn_info = req.connection_info().clone(); - let ip = if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) { - if let Some(header) = req.headers().get("CF-Connecting-IP") { - header.to_str().ok() - } else { - conn_info.peer_addr() + pub async fn check_rate_limit(&self, key: &str) -> RateLimitDecision { + let mut conn = match self.redis_pool.connect().await { + Ok(conn) => conn, + Err(_) => { + // If Redis is unavailable, allow the request but with reduced limit + return RateLimitDecision { + allowed: true, + limit: self.params.burst_size, + remaining: 1, + reset_after_ms: 60_000, // 1 minute + retry_after_ms: None, + }; } - } else { - conn_info.peer_addr() }; - if let Some(ip) = ip { - let ip = ip.to_string(); + // Get current time in nanoseconds since UNIX epoch + let now = Utc::now().timestamp_nanos_opt().unwrap_or(0); - match self.rate_limiter.check_key(&ip) { - Ok(snapshot) => { - let fut = self.service.call(req); + // Get the current TAT from Redis (if it exists) + let tat_str = conn.get(RATE_LIMIT_NAMESPACE, key).await.ok().flatten(); - Box::pin(async move { - match fut.await { - Ok(mut service_response) => { - // Now you have a mutable reference to the ServiceResponse, so you can modify its headers. - let headers = service_response.headers_mut(); - headers.insert( - actix_web::http::header::HeaderName::from_str( - "x-ratelimit-limit", - ) - .unwrap(), - snapshot.quota().burst_size().get().into(), - ); - headers.insert( - actix_web::http::header::HeaderName::from_str( - "x-ratelimit-remaining", - ) - .unwrap(), - snapshot.remaining_burst_capacity().into(), - ); + // Parse the TAT or use current time if not found + let current_tat = match tat_str { + Some(tat_str) => tat_str.parse::().unwrap_or(now), + None => now, + }; - headers.insert( - actix_web::http::header::HeaderName::from_str( - "x-ratelimit-reset", - ) - .unwrap(), - snapshot - .quota() - .burst_size_replenished_in() - .as_secs() - .into(), - ); + // Calculate the new TAT using GCRA + let increment = self.params.emission_interval; + let max_tat_delta = increment * self.params.burst_size as i64; - // Return the modified response as Ok. - Ok(service_response.map_into_left_body()) - } - Err(e) => { - // Handle error case - Err(e) - } - } - }) - } - Err(negative) => { - let wait_time = - negative.wait_time_from(DefaultClock::default().now()); + // Calculate allowance: how much time has passed since the TAT + let allowance = now - current_tat; - let mut response = ApiError::RateLimitError( - wait_time.as_millis(), - negative.quota().burst_size().get(), - ) - .error_response(); + if allowance < -max_tat_delta { + // Too many requests, rate limit exceeded + // Calculate when the client can retry + let retry_after_ms = (-allowance - max_tat_delta) / 1_000_000; - let headers = response.headers_mut(); + return RateLimitDecision { + allowed: false, + limit: self.params.burst_size, + remaining: 0, + reset_after_ms: -allowance / 1_000_000, + retry_after_ms: Some(retry_after_ms.max(0)), + }; + } - headers.insert( - actix_web::http::header::HeaderName::from_str( - "x-ratelimit-limit", - ) - .unwrap(), - negative.quota().burst_size().get().into(), - ); - headers.insert( - actix_web::http::header::HeaderName::from_str( - "x-ratelimit-remaining", - ) - .unwrap(), - 0.into(), - ); - headers.insert( - actix_web::http::header::HeaderName::from_str( - "x-ratelimit-reset", - ) - .unwrap(), - wait_time.as_secs().into(), - ); + let new_tat = std::cmp::max(current_tat + increment, now); - // TODO: Sentralize CORS in the CORS util. - headers.insert( - actix_web::http::header::HeaderName::from_str( - "Access-Control-Allow-Origin", - ) - .unwrap(), - "*".parse().unwrap(), - ); + let _ = conn + .set( + RATE_LIMIT_NAMESPACE, + key, + &new_tat.to_string(), + Some(RATE_LIMIT_EXPIRY), + ) + .await; - Box::pin(async { - Ok(req.into_response(response.map_into_right_body())) - }) - } - } + let remaining_capacity = + ((max_tat_delta - (new_tat - now)) / increment).max(0) as u32; + + RateLimitDecision { + allowed: true, + limit: self.params.burst_size, + remaining: remaining_capacity, + reset_after_ms: (new_tat - now) / 1_000_000, + retry_after_ms: None, + } + } +} + +pub async fn rate_limit_middleware( + req: ServiceRequest, + next: Next, +) -> Result>, Error> { + let rate_limiter = req + .app_data::>() + .expect("Rate limiter not configured properly") + .clone(); + + if let Some(key) = req.headers().get("x-ratelimit-key") { + if key.to_str().ok() + == dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok().as_deref() + { + return Ok(next.call(req).await?.map_into_left_body()); + } + } + + let conn_info = req.connection_info().clone(); + let ip = if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) { + if let Some(header) = req.headers().get("CF-Connecting-IP") { + header.to_str().ok() } else { - let response = ApiError::CustomAuthentication( - "Unable to obtain user IP address!".to_string(), + conn_info.peer_addr() + } + } else { + conn_info.peer_addr() + }; + + if let Some(ip) = ip { + let decision = rate_limiter.check_rate_limit(ip).await; + + if decision.allowed { + let mut service_response = next.call(req).await?; + + // Add rate limit headers + let headers = service_response.headers_mut(); + headers.insert( + actix_web::http::header::HeaderName::from_str( + "x-ratelimit-limit", + ) + .unwrap(), + decision.limit.into(), + ); + headers.insert( + actix_web::http::header::HeaderName::from_str( + "x-ratelimit-remaining", + ) + .unwrap(), + decision.remaining.into(), + ); + headers.insert( + actix_web::http::header::HeaderName::from_str( + "x-ratelimit-reset", + ) + .unwrap(), + (decision.reset_after_ms / 1000).into(), + ); + + Ok(service_response.map_into_left_body()) + } else { + let mut response = ApiError::RateLimitError( + decision.retry_after_ms.unwrap_or(0) as u128, + decision.limit, ) .error_response(); - Box::pin(async { - Ok(req.into_response(response.map_into_right_body())) - }) + // Add rate limit headers + let headers = response.headers_mut(); + headers.insert( + actix_web::http::header::HeaderName::from_str( + "x-ratelimit-limit", + ) + .unwrap(), + decision.limit.into(), + ); + headers.insert( + actix_web::http::header::HeaderName::from_str( + "x-ratelimit-remaining", + ) + .unwrap(), + 0.into(), + ); + headers.insert( + actix_web::http::header::HeaderName::from_str( + "x-ratelimit-reset", + ) + .unwrap(), + (decision.reset_after_ms / 1000).into(), + ); + + // TODO: Centralize CORS in the CORS util. + headers.insert( + actix_web::http::header::HeaderName::from_str( + "Access-Control-Allow-Origin", + ) + .unwrap(), + "*".parse().unwrap(), + ); + + Ok(req.into_response(response.map_into_right_body())) } + } else { + let response = ApiError::CustomAuthentication( + "Unable to obtain user IP address!".to_string(), + ) + .error_response(); + + Ok(req.into_response(response.map_into_right_body())) } }