Distributed rate limit, fix search panic, add migration task (#3419)

* Distributed rate limit, fix search panic, add migration task

* Add binary info to root endpoint
This commit is contained in:
Jai Agrawal 2025-03-25 01:10:43 -07:00 committed by GitHub
parent 5fbf5b22c0
commit b5a9a93323
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 317 additions and 277 deletions

74
Cargo.lock generated
View File

@ -2469,6 +2469,12 @@ dependencies = [
"const-random", "const-random",
] ]
[[package]]
name = "dotenv-build"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4547f16c17f6051a12cdb8c62b803f94bee6807c74aa7c530b30b737df981fc"
[[package]] [[package]]
name = "dotenvy" name = "dotenvy"
version = "0.15.7" version = "0.15.7"
@ -3345,26 +3351,6 @@ dependencies = [
"system-deps", "system-deps",
] ]
[[package]]
name = "governor"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
dependencies = [
"cfg-if 1.0.0",
"dashmap 5.5.3",
"futures 0.3.30",
"futures-timer",
"no-std-compat",
"nonzero_ext",
"parking_lot 0.12.3",
"portable-atomic",
"quanta",
"rand 0.8.5",
"smallvec 1.13.2",
"spinning_top",
]
[[package]] [[package]]
name = "group" name = "group"
version = "0.13.0" version = "0.13.0"
@ -4529,17 +4515,18 @@ dependencies = [
"dashmap 5.5.3", "dashmap 5.5.3",
"deadpool-redis", "deadpool-redis",
"derive-new", "derive-new",
"dotenv-build",
"dotenvy", "dotenvy",
"either", "either",
"flate2", "flate2",
"futures 0.3.30", "futures 0.3.30",
"futures-timer", "futures-timer",
"futures-util", "futures-util",
"governor",
"hex", "hex",
"hmac 0.11.0", "hmac 0.11.0",
"hyper 0.14.31", "hyper 0.14.31",
"hyper-tls 0.5.0", "hyper-tls 0.5.0",
"iana-time-zone",
"image 0.24.9", "image 0.24.9",
"itertools 0.12.1", "itertools 0.12.1",
"jemalloc_pprof", "jemalloc_pprof",
@ -5253,12 +5240,6 @@ dependencies = [
"memoffset 0.9.1", "memoffset 0.9.1",
] ]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]] [[package]]
name = "nodrop" name = "nodrop"
version = "0.1.14" version = "0.1.14"
@ -5275,12 +5256,6 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]] [[package]]
name = "normpath" name = "normpath"
version = "1.3.0" version = "1.3.0"
@ -6564,21 +6539,6 @@ dependencies = [
"bytemuck", "bytemuck",
] ]
[[package]]
name = "quanta"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
dependencies = [
"crossbeam-utils 0.8.20",
"libc",
"once_cell",
"raw-cpuid",
"wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi 0.3.9",
]
[[package]] [[package]]
name = "quick-error" name = "quick-error"
version = "2.0.1" version = "2.0.1"
@ -6792,15 +6752,6 @@ dependencies = [
"rand_core 0.5.1", "rand_core 0.5.1",
] ]
[[package]]
name = "raw-cpuid"
version = "11.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0"
dependencies = [
"bitflags 2.6.0",
]
[[package]] [[package]]
name = "raw-window-handle" name = "raw-window-handle"
version = "0.5.2" version = "0.5.2"
@ -8284,15 +8235,6 @@ dependencies = [
"lock_api 0.4.12", "lock_api 0.4.12",
] ]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api 0.4.12",
]
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.7.3" version = "0.7.3"

View File

@ -1,2 +1,3 @@
BASE_URL=https://api.modrinth.com/v2/ BASE_URL=https://api.modrinth.com/v2/
BROWSER_BASE_URL=https://api.modrinth.com/v2/ BROWSER_BASE_URL=https://api.modrinth.com/v2/
PYRO_BASE_URL=https://archon.modrinth.com/

View File

@ -19,14 +19,13 @@ actix-ws = "0.3.0"
actix-files = "0.6.5" actix-files = "0.6.5"
prometheus = "0.13.4" prometheus = "0.13.4"
actix-web-prom = { version = "0.9.0", features = ["process"] } actix-web-prom = { version = "0.9.0", features = ["process"] }
governor = "0.6.3"
tracing = "0.1.41" tracing = "0.1.41"
tracing-subscriber = "0.3.19" tracing-subscriber = "0.3.19"
tracing-actix-web = "0.7.16" tracing-actix-web = "0.7.16"
console-subscriber = "0.4.1" console-subscriber = "0.4.1"
tokio = { version = "1.35.1", features = ["sync"] } tokio = { version = "1.35.1", features = ["sync", "rt-multi-thread"] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
futures = "0.3.30" futures = "0.3.30"
@ -132,6 +131,7 @@ json-patch = "*"
ariadne = { path = "../../packages/ariadne" } ariadne = { path = "../../packages/ariadne" }
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
iana-time-zone = "0.1.61"
[target.'cfg(not(target_env = "msvc"))'.dependencies] [target.'cfg(not(target_env = "msvc"))'.dependencies]
tikv-jemallocator = { version = "0.6.0", features = ["profiling", "unprefixed_malloc_on_supported_platforms"] } tikv-jemallocator = { version = "0.6.0", features = ["profiling", "unprefixed_malloc_on_supported_platforms"] }
@ -140,3 +140,8 @@ jemalloc_pprof = { version = "0.7.0", features = ["flamegraph"] }
[dev-dependencies] [dev-dependencies]
actix-http = "3.4.0" actix-http = "3.4.0"
[build-dependencies]
dotenv-build = "0.1.1"
chrono = "0.4.38"
iana-time-zone = "0.1.60"

View File

@ -1,3 +1,47 @@
use std::path::Path;
use std::process::Command;
use chrono::Local;
use dotenv_build::Config;
fn main() { fn main() {
let output = Command::new("git")
.arg("rev-parse")
.arg("HEAD")
.output()
.expect("`git` invocation to succeed");
let git_hash = String::from_utf8(output.stdout)
.expect("valid UTF-8 output from `git` invocation");
println!("cargo::rerun-if-changed=.git/HEAD");
println!("cargo::rustc-env=GIT_HASH={}", git_hash.trim());
let timedate_fmt = Local::now().format("%F @ %I:%M %p");
let timezone_fmt = iana_time_zone::get_timezone()
.map(|tz| format!(" ({tz})"))
.unwrap_or_default();
let comptime = timedate_fmt.to_string() + timezone_fmt.as_str();
println!("cargo::rustc-env=COMPILATION_DATE={comptime}");
// trick to get compilation profile
let profile = std::env::var("OUT_DIR")
.expect("OUT_DIR to be set")
.split(std::path::MAIN_SEPARATOR)
.nth_back(3)
.unwrap_or("unknown")
.to_string();
println!("cargo::rustc-env=COMPILATION_PROFILE={profile}");
dotenv_build::output(Config {
filename: Path::new(".build.env"),
recursive_search: true,
fail_if_missing_dotenv: false,
})
.unwrap();
println!("cargo:rerun-if-changed=migrations"); println!("cargo:rerun-if-changed=migrations");
} }

View File

@ -1,7 +1,7 @@
use crate::database::redis::RedisPool; use crate::database::redis::RedisPool;
use crate::queue::payouts::process_payout; use crate::queue::payouts::process_payout;
use crate::search;
use crate::search::indexing::index_projects; use crate::search::indexing::index_projects;
use crate::{database, search};
use clap::ValueEnum; use clap::ValueEnum;
use sqlx::Postgres; use sqlx::Postgres;
use tracing::{info, warn}; use tracing::{info, warn};
@ -15,6 +15,7 @@ pub enum BackgroundTask {
Payouts, Payouts,
IndexBilling, IndexBilling,
IndexSubscriptions, IndexSubscriptions,
Migrations,
} }
impl BackgroundTask { impl BackgroundTask {
@ -28,6 +29,7 @@ impl BackgroundTask {
) { ) {
use BackgroundTask::*; use BackgroundTask::*;
match self { match self {
Migrations => run_migrations().await,
IndexSearch => index_search(pool, redis_pool, search_config).await, IndexSearch => index_search(pool, redis_pool, search_config).await,
ReleaseScheduled => release_scheduled(pool).await, ReleaseScheduled => release_scheduled(pool).await,
UpdateVersions => update_versions(pool, redis_pool).await, UpdateVersions => update_versions(pool, redis_pool).await,
@ -50,6 +52,12 @@ impl BackgroundTask {
} }
} }
pub async fn run_migrations() {
database::check_for_migrations()
.await
.expect("An error occurred while running migrations.");
}
pub async fn index_search( pub async fn index_search(
pool: sqlx::Pool<Postgres>, pool: sqlx::Pool<Postgres>,
redis_pool: RedisPool, redis_pool: RedisPool,

View File

@ -1,4 +1,3 @@
use std::num::NonZeroU32;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -13,14 +12,12 @@ use tracing::{info, warn};
extern crate clickhouse as clickhouse_crate; extern crate clickhouse as clickhouse_crate;
use clickhouse_crate::Client; use clickhouse_crate::Client;
use governor::middleware::StateInformationMiddleware;
use governor::{Quota, RateLimiter};
use util::cors::default_cors; use util::cors::default_cors;
use crate::background_task::update_versions; use crate::background_task::update_versions;
use crate::queue::moderation::AutomatedModerationQueue; use crate::queue::moderation::AutomatedModerationQueue;
use crate::util::env::{parse_strings_from_var, parse_var}; use crate::util::env::{parse_strings_from_var, parse_var};
use crate::util::ratelimit::KeyedRateLimiter; use crate::util::ratelimit::{AsyncRateLimiter, GCRAParameters};
use sync::friends::handle_pubsub; use sync::friends::handle_pubsub;
pub mod auth; pub mod auth;
@ -57,7 +54,7 @@ pub struct LabrinthConfig {
pub analytics_queue: Arc<AnalyticsQueue>, pub analytics_queue: Arc<AnalyticsQueue>,
pub active_sockets: web::Data<ActiveSockets>, pub active_sockets: web::Data<ActiveSockets>,
pub automated_moderation_queue: web::Data<AutomatedModerationQueue>, pub automated_moderation_queue: web::Data<AutomatedModerationQueue>,
pub rate_limiter: KeyedRateLimiter, pub rate_limiter: web::Data<AsyncRateLimiter>,
pub stripe_client: stripe::Client, pub stripe_client: stripe::Client,
} }
@ -93,24 +90,10 @@ pub fn app_setup(
let mut scheduler = scheduler::Scheduler::new(); let mut scheduler = scheduler::Scheduler::new();
let limiter: KeyedRateLimiter = Arc::new( let limiter = web::Data::new(AsyncRateLimiter::new(
RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(300).unwrap())) redis_pool.clone(),
.with_middleware::<StateInformationMiddleware>(), GCRAParameters::new(300, 300),
); ));
let limiter_clone = Arc::clone(&limiter);
scheduler.run(Duration::from_secs(60), move || {
info!(
"Clearing ratelimiter, storage size: {}",
limiter_clone.len()
);
limiter_clone.retain_recent();
info!(
"Done clearing ratelimiter, storage size: {}",
limiter_clone.len()
);
async move {}
});
if enable_background_tasks { if enable_background_tasks {
// The interval in seconds at which the local database is indexed // The interval in seconds at which the local database is indexed
@ -329,6 +312,7 @@ pub fn app_config(
.app_data(labrinth_config.active_sockets.clone()) .app_data(labrinth_config.active_sockets.clone())
.app_data(labrinth_config.automated_moderation_queue.clone()) .app_data(labrinth_config.automated_moderation_queue.clone())
.app_data(web::Data::new(labrinth_config.stripe_client.clone())) .app_data(web::Data::new(labrinth_config.stripe_client.clone()))
.app_data(labrinth_config.rate_limiter.clone())
.configure( .configure(
#[allow(unused_variables)] #[allow(unused_variables)]
|cfg| { |cfg| {

View File

@ -1,3 +1,4 @@
use actix_web::middleware::from_fn;
use actix_web::{App, HttpServer}; use actix_web::{App, HttpServer};
use actix_web_prom::PrometheusMetricsBuilder; use actix_web_prom::PrometheusMetricsBuilder;
use clap::Parser; use clap::Parser;
@ -5,7 +6,7 @@ use labrinth::background_task::BackgroundTask;
use labrinth::database::redis::RedisPool; use labrinth::database::redis::RedisPool;
use labrinth::file_hosting::S3Host; use labrinth::file_hosting::S3Host;
use labrinth::search; use labrinth::search;
use labrinth::util::ratelimit::RateLimit; use labrinth::util::ratelimit::rate_limit_middleware;
use labrinth::{check_env_vars, clickhouse, database, file_hosting, queue}; use labrinth::{check_env_vars, clickhouse, database, file_hosting, queue};
use std::sync::Arc; use std::sync::Arc;
use tracing::{error, info}; use tracing::{error, info};
@ -33,6 +34,10 @@ struct Args {
#[arg(long)] #[arg(long)]
no_background_tasks: bool, no_background_tasks: bool,
/// Don't automatically run migrations. This means the migrations should be run via --run-background-task.
#[arg(long)]
no_migrations: bool,
/// Run a single background task and then exit. Perfect for cron jobs. /// Run a single background task and then exit. Perfect for cron jobs.
#[arg(long, value_enum, id = "task")] #[arg(long, value_enum, id = "task")]
run_background_task: Option<BackgroundTask>, run_background_task: Option<BackgroundTask>,
@ -67,9 +72,11 @@ async fn main() -> std::io::Result<()> {
dotenvy::var("BIND_ADDR").unwrap() dotenvy::var("BIND_ADDR").unwrap()
); );
database::check_for_migrations() if !args.no_migrations {
.await database::check_for_migrations()
.expect("An error occurred while running migrations."); .await
.expect("An error occurred while running migrations.");
}
} }
// Database Connector // Database Connector
@ -164,7 +171,7 @@ async fn main() -> std::io::Result<()> {
App::new() App::new()
.wrap(TracingLogger::default()) .wrap(TracingLogger::default())
.wrap(prometheus.clone()) .wrap(prometheus.clone())
.wrap(RateLimit(Arc::clone(&labrinth_config.rate_limiter))) .wrap(from_fn(rate_limit_middleware))
.wrap(actix_web::middleware::Compress::default()) .wrap(actix_web::middleware::Compress::default())
.wrap(sentry_actix::Sentry::new()) .wrap(sentry_actix::Sentry::new())
.configure(|cfg| labrinth::app_config(cfg, labrinth_config.clone())) .configure(|cfg| labrinth::app_config(cfg, labrinth_config.clone()))

View File

@ -7,7 +7,13 @@ pub async fn index_get() -> HttpResponse {
"name": "modrinth-labrinth", "name": "modrinth-labrinth",
"version": env!("CARGO_PKG_VERSION"), "version": env!("CARGO_PKG_VERSION"),
"documentation": "https://docs.modrinth.com", "documentation": "https://docs.modrinth.com",
"about": "Welcome traveler!" "about": "Welcome traveler!",
"build_info": {
"comp_date": env!("COMPILATION_DATE"),
"git_hash": env!("GIT_HASH", "unknown"),
"profile": env!("COMPILATION_PROFILE"),
}
}); });
HttpResponse::Ok().json(data) HttpResponse::Ok().json(data)

View File

@ -91,30 +91,30 @@ pub async fn ws_init(
let friend_statuses = if !friends.is_empty() { let friend_statuses = if !friends.is_empty() {
let db = db.clone(); let db = db.clone();
let redis = redis.clone(); let redis = redis.clone();
tokio_stream::iter(friends.iter())
let statuses = tokio_stream::iter(friends.iter())
.map(|x| { .map(|x| {
let db = db.clone(); let db = db.clone();
let redis = redis.clone(); let redis = redis.clone();
async move { async move {
async move { get_user_status(
get_user_status( if x.user_id == user_id.into() {
if x.user_id == user_id.into() { x.friend_id
x.friend_id } else {
} else { x.user_id
x.user_id }
} .into(),
.into(), &db,
&db, &redis,
&redis, )
) .await
.await
}
} }
}) })
.buffer_unordered(16) .buffer_unordered(16)
.filter_map(|x| x)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.await .await;
statuses.into_iter().flatten().collect()
} else { } else {
Vec::new() Vec::new()
}; };

View File

@ -209,8 +209,9 @@ pub async fn search_for_project(
let mut filter_string = String::new(); let mut filter_string = String::new();
// Convert offset and limit to page and hits_per_page // Convert offset and limit to page and hits_per_page
let hits_per_page = limit; let hits_per_page = if limit == 0 { 1 } else { limit };
let page = offset / limit + 1;
let page = offset / hits_per_page + 1;
let results = { let results = {
let mut query = meilisearch_index.search(); let mut query = meilisearch_index.search();

View File

@ -1,196 +1,238 @@
use governor::clock::{Clock, DefaultClock}; use crate::database::redis::RedisPool;
use governor::{middleware, state, RateLimiter};
use std::str::FromStr;
use std::sync::Arc;
use crate::routes::ApiError; use crate::routes::ApiError;
use crate::util::env::parse_var; use crate::util::env::parse_var;
use actix_web::{ use actix_web::{
body::EitherBody, body::{EitherBody, MessageBody},
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, dev::{ServiceRequest, ServiceResponse},
Error, ResponseError, middleware::Next,
web, Error, ResponseError,
}; };
use futures_util::future::LocalBoxFuture; use chrono::Utc;
use futures_util::future::{ready, Ready}; use std::str::FromStr;
use std::sync::Arc;
pub type KeyedRateLimiter< const RATE_LIMIT_NAMESPACE: &str = "rate_limit";
K = String, const RATE_LIMIT_EXPIRY: i64 = 300; // 5 minutes
MW = middleware::StateInformationMiddleware, const MINUTE_IN_NANOS: i64 = 60_000_000_000;
> = Arc<
RateLimiter<K, state::keyed::DefaultKeyedStateStore<K>, DefaultClock, MW>,
>;
pub struct RateLimit(pub KeyedRateLimiter); pub struct GCRAParameters {
emission_interval: i64,
burst_size: u32,
}
impl<S, B> Transform<S, ServiceRequest> for RateLimit impl GCRAParameters {
where pub(crate) fn new(requests_per_minute: u32, burst_size: u32) -> Self {
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>, // Calculate emission interval in nanoseconds
S::Future: 'static, let emission_interval = MINUTE_IN_NANOS / requests_per_minute as i64;
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Transform = RateLimitService<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future { Self {
ready(Ok(RateLimitService { emission_interval,
service, burst_size,
rate_limiter: Arc::clone(&self.0), }
}))
} }
} }
#[doc(hidden)] pub struct RateLimitDecision {
pub struct RateLimitService<S> { pub allowed: bool,
service: S, pub limit: u32,
rate_limiter: KeyedRateLimiter, pub remaining: u32,
pub reset_after_ms: i64,
pub retry_after_ms: Option<i64>,
} }
impl<S, B> Service<ServiceRequest> for RateLimitService<S> #[derive(Clone)]
where pub struct AsyncRateLimiter {
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>, redis_pool: RedisPool,
S::Future: 'static, params: Arc<GCRAParameters>,
B: 'static, }
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service); impl AsyncRateLimiter {
pub fn new(redis_pool: RedisPool, params: GCRAParameters) -> Self {
fn call(&self, req: ServiceRequest) -> Self::Future { Self {
if let Some(key) = req.headers().get("x-ratelimit-key") { redis_pool,
if key.to_str().ok() params: Arc::new(params),
== dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok().as_deref()
{
let res = self.service.call(req);
return Box::pin(async move {
let service_response = res.await?;
Ok(service_response.map_into_left_body())
});
}
} }
}
let conn_info = req.connection_info().clone(); pub async fn check_rate_limit(&self, key: &str) -> RateLimitDecision {
let ip = if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) { let mut conn = match self.redis_pool.connect().await {
if let Some(header) = req.headers().get("CF-Connecting-IP") { Ok(conn) => conn,
header.to_str().ok() Err(_) => {
} else { // If Redis is unavailable, allow the request but with reduced limit
conn_info.peer_addr() return RateLimitDecision {
allowed: true,
limit: self.params.burst_size,
remaining: 1,
reset_after_ms: 60_000, // 1 minute
retry_after_ms: None,
};
} }
} else {
conn_info.peer_addr()
}; };
if let Some(ip) = ip { // Get current time in nanoseconds since UNIX epoch
let ip = ip.to_string(); let now = Utc::now().timestamp_nanos_opt().unwrap_or(0);
match self.rate_limiter.check_key(&ip) { // Get the current TAT from Redis (if it exists)
Ok(snapshot) => { let tat_str = conn.get(RATE_LIMIT_NAMESPACE, key).await.ok().flatten();
let fut = self.service.call(req);
Box::pin(async move { // Parse the TAT or use current time if not found
match fut.await { let current_tat = match tat_str {
Ok(mut service_response) => { Some(tat_str) => tat_str.parse::<i64>().unwrap_or(now),
// Now you have a mutable reference to the ServiceResponse, so you can modify its headers. None => now,
let headers = service_response.headers_mut(); };
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-limit",
)
.unwrap(),
snapshot.quota().burst_size().get().into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-remaining",
)
.unwrap(),
snapshot.remaining_burst_capacity().into(),
);
headers.insert( // Calculate the new TAT using GCRA
actix_web::http::header::HeaderName::from_str( let increment = self.params.emission_interval;
"x-ratelimit-reset", let max_tat_delta = increment * self.params.burst_size as i64;
)
.unwrap(),
snapshot
.quota()
.burst_size_replenished_in()
.as_secs()
.into(),
);
// Return the modified response as Ok. // Calculate allowance: how much time has passed since the TAT
Ok(service_response.map_into_left_body()) let allowance = now - current_tat;
}
Err(e) => {
// Handle error case
Err(e)
}
}
})
}
Err(negative) => {
let wait_time =
negative.wait_time_from(DefaultClock::default().now());
let mut response = ApiError::RateLimitError( if allowance < -max_tat_delta {
wait_time.as_millis(), // Too many requests, rate limit exceeded
negative.quota().burst_size().get(), // Calculate when the client can retry
) let retry_after_ms = (-allowance - max_tat_delta) / 1_000_000;
.error_response();
let headers = response.headers_mut(); return RateLimitDecision {
allowed: false,
limit: self.params.burst_size,
remaining: 0,
reset_after_ms: -allowance / 1_000_000,
retry_after_ms: Some(retry_after_ms.max(0)),
};
}
headers.insert( let new_tat = std::cmp::max(current_tat + increment, now);
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-limit",
)
.unwrap(),
negative.quota().burst_size().get().into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-remaining",
)
.unwrap(),
0.into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-reset",
)
.unwrap(),
wait_time.as_secs().into(),
);
// TODO: Sentralize CORS in the CORS util. let _ = conn
headers.insert( .set(
actix_web::http::header::HeaderName::from_str( RATE_LIMIT_NAMESPACE,
"Access-Control-Allow-Origin", key,
) &new_tat.to_string(),
.unwrap(), Some(RATE_LIMIT_EXPIRY),
"*".parse().unwrap(), )
); .await;
Box::pin(async { let remaining_capacity =
Ok(req.into_response(response.map_into_right_body())) ((max_tat_delta - (new_tat - now)) / increment).max(0) as u32;
})
} RateLimitDecision {
} allowed: true,
limit: self.params.burst_size,
remaining: remaining_capacity,
reset_after_ms: (new_tat - now) / 1_000_000,
retry_after_ms: None,
}
}
}
pub async fn rate_limit_middleware(
req: ServiceRequest,
next: Next<impl MessageBody>,
) -> Result<ServiceResponse<EitherBody<impl MessageBody>>, Error> {
let rate_limiter = req
.app_data::<web::Data<AsyncRateLimiter>>()
.expect("Rate limiter not configured properly")
.clone();
if let Some(key) = req.headers().get("x-ratelimit-key") {
if key.to_str().ok()
== dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok().as_deref()
{
return Ok(next.call(req).await?.map_into_left_body());
}
}
let conn_info = req.connection_info().clone();
let ip = if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) {
if let Some(header) = req.headers().get("CF-Connecting-IP") {
header.to_str().ok()
} else { } else {
let response = ApiError::CustomAuthentication( conn_info.peer_addr()
"Unable to obtain user IP address!".to_string(), }
} else {
conn_info.peer_addr()
};
if let Some(ip) = ip {
let decision = rate_limiter.check_rate_limit(ip).await;
if decision.allowed {
let mut service_response = next.call(req).await?;
// Add rate limit headers
let headers = service_response.headers_mut();
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-limit",
)
.unwrap(),
decision.limit.into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-remaining",
)
.unwrap(),
decision.remaining.into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-reset",
)
.unwrap(),
(decision.reset_after_ms / 1000).into(),
);
Ok(service_response.map_into_left_body())
} else {
let mut response = ApiError::RateLimitError(
decision.retry_after_ms.unwrap_or(0) as u128,
decision.limit,
) )
.error_response(); .error_response();
Box::pin(async { // Add rate limit headers
Ok(req.into_response(response.map_into_right_body())) let headers = response.headers_mut();
}) headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-limit",
)
.unwrap(),
decision.limit.into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-remaining",
)
.unwrap(),
0.into(),
);
headers.insert(
actix_web::http::header::HeaderName::from_str(
"x-ratelimit-reset",
)
.unwrap(),
(decision.reset_after_ms / 1000).into(),
);
// TODO: Centralize CORS in the CORS util.
headers.insert(
actix_web::http::header::HeaderName::from_str(
"Access-Control-Allow-Origin",
)
.unwrap(),
"*".parse().unwrap(),
);
Ok(req.into_response(response.map_into_right_body()))
} }
} else {
let response = ApiError::CustomAuthentication(
"Unable to obtain user IP address!".to_string(),
)
.error_response();
Ok(req.into_response(response.map_into_right_body()))
} }
} }