Distributed rate limit, fix search panic, add migration task (#3419)
* Distributed rate limit, fix search panic, add migration task * Add binary info to root endpoint
This commit is contained in:
parent
5fbf5b22c0
commit
b5a9a93323
74
Cargo.lock
generated
74
Cargo.lock
generated
@ -2469,6 +2469,12 @@ dependencies = [
|
||||
"const-random",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dotenv-build"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4547f16c17f6051a12cdb8c62b803f94bee6807c74aa7c530b30b737df981fc"
|
||||
|
||||
[[package]]
|
||||
name = "dotenvy"
|
||||
version = "0.15.7"
|
||||
@ -3345,26 +3351,6 @@ dependencies = [
|
||||
"system-deps",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "governor"
|
||||
version = "0.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"dashmap 5.5.3",
|
||||
"futures 0.3.30",
|
||||
"futures-timer",
|
||||
"no-std-compat",
|
||||
"nonzero_ext",
|
||||
"parking_lot 0.12.3",
|
||||
"portable-atomic",
|
||||
"quanta",
|
||||
"rand 0.8.5",
|
||||
"smallvec 1.13.2",
|
||||
"spinning_top",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "group"
|
||||
version = "0.13.0"
|
||||
@ -4529,17 +4515,18 @@ dependencies = [
|
||||
"dashmap 5.5.3",
|
||||
"deadpool-redis",
|
||||
"derive-new",
|
||||
"dotenv-build",
|
||||
"dotenvy",
|
||||
"either",
|
||||
"flate2",
|
||||
"futures 0.3.30",
|
||||
"futures-timer",
|
||||
"futures-util",
|
||||
"governor",
|
||||
"hex",
|
||||
"hmac 0.11.0",
|
||||
"hyper 0.14.31",
|
||||
"hyper-tls 0.5.0",
|
||||
"iana-time-zone",
|
||||
"image 0.24.9",
|
||||
"itertools 0.12.1",
|
||||
"jemalloc_pprof",
|
||||
@ -5253,12 +5240,6 @@ dependencies = [
|
||||
"memoffset 0.9.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "no-std-compat"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
|
||||
|
||||
[[package]]
|
||||
name = "nodrop"
|
||||
version = "0.1.14"
|
||||
@ -5275,12 +5256,6 @@ dependencies = [
|
||||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nonzero_ext"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
|
||||
|
||||
[[package]]
|
||||
name = "normpath"
|
||||
version = "1.3.0"
|
||||
@ -6564,21 +6539,6 @@ dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
|
||||
dependencies = [
|
||||
"crossbeam-utils 0.8.20",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"web-sys",
|
||||
"winapi 0.3.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "2.0.1"
|
||||
@ -6792,15 +6752,6 @@ dependencies = [
|
||||
"rand_core 0.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0"
|
||||
dependencies = [
|
||||
"bitflags 2.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-window-handle"
|
||||
version = "0.5.2"
|
||||
@ -8284,15 +8235,6 @@ dependencies = [
|
||||
"lock_api 0.4.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spinning_top"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
|
||||
dependencies = [
|
||||
"lock_api 0.4.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spki"
|
||||
version = "0.7.3"
|
||||
|
||||
@ -1,2 +1,3 @@
|
||||
BASE_URL=https://api.modrinth.com/v2/
|
||||
BROWSER_BASE_URL=https://api.modrinth.com/v2/
|
||||
PYRO_BASE_URL=https://archon.modrinth.com/
|
||||
|
||||
@ -19,14 +19,13 @@ actix-ws = "0.3.0"
|
||||
actix-files = "0.6.5"
|
||||
prometheus = "0.13.4"
|
||||
actix-web-prom = { version = "0.9.0", features = ["process"] }
|
||||
governor = "0.6.3"
|
||||
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = "0.3.19"
|
||||
tracing-actix-web = "0.7.16"
|
||||
console-subscriber = "0.4.1"
|
||||
|
||||
tokio = { version = "1.35.1", features = ["sync"] }
|
||||
tokio = { version = "1.35.1", features = ["sync", "rt-multi-thread"] }
|
||||
tokio-stream = "0.1.14"
|
||||
|
||||
futures = "0.3.30"
|
||||
@ -132,6 +131,7 @@ json-patch = "*"
|
||||
ariadne = { path = "../../packages/ariadne" }
|
||||
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
iana-time-zone = "0.1.61"
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
tikv-jemallocator = { version = "0.6.0", features = ["profiling", "unprefixed_malloc_on_supported_platforms"] }
|
||||
@ -140,3 +140,8 @@ jemalloc_pprof = { version = "0.7.0", features = ["flamegraph"] }
|
||||
|
||||
[dev-dependencies]
|
||||
actix-http = "3.4.0"
|
||||
|
||||
[build-dependencies]
|
||||
dotenv-build = "0.1.1"
|
||||
chrono = "0.4.38"
|
||||
iana-time-zone = "0.1.60"
|
||||
@ -1,3 +1,47 @@
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
use chrono::Local;
|
||||
use dotenv_build::Config;
|
||||
|
||||
fn main() {
|
||||
let output = Command::new("git")
|
||||
.arg("rev-parse")
|
||||
.arg("HEAD")
|
||||
.output()
|
||||
.expect("`git` invocation to succeed");
|
||||
|
||||
let git_hash = String::from_utf8(output.stdout)
|
||||
.expect("valid UTF-8 output from `git` invocation");
|
||||
|
||||
println!("cargo::rerun-if-changed=.git/HEAD");
|
||||
println!("cargo::rustc-env=GIT_HASH={}", git_hash.trim());
|
||||
|
||||
let timedate_fmt = Local::now().format("%F @ %I:%M %p");
|
||||
let timezone_fmt = iana_time_zone::get_timezone()
|
||||
.map(|tz| format!(" ({tz})"))
|
||||
.unwrap_or_default();
|
||||
|
||||
let comptime = timedate_fmt.to_string() + timezone_fmt.as_str();
|
||||
|
||||
println!("cargo::rustc-env=COMPILATION_DATE={comptime}");
|
||||
|
||||
// trick to get compilation profile
|
||||
let profile = std::env::var("OUT_DIR")
|
||||
.expect("OUT_DIR to be set")
|
||||
.split(std::path::MAIN_SEPARATOR)
|
||||
.nth_back(3)
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
println!("cargo::rustc-env=COMPILATION_PROFILE={profile}");
|
||||
|
||||
dotenv_build::output(Config {
|
||||
filename: Path::new(".build.env"),
|
||||
recursive_search: true,
|
||||
fail_if_missing_dotenv: false,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
println!("cargo:rerun-if-changed=migrations");
|
||||
}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
use crate::database::redis::RedisPool;
|
||||
use crate::queue::payouts::process_payout;
|
||||
use crate::search;
|
||||
use crate::search::indexing::index_projects;
|
||||
use crate::{database, search};
|
||||
use clap::ValueEnum;
|
||||
use sqlx::Postgres;
|
||||
use tracing::{info, warn};
|
||||
@ -15,6 +15,7 @@ pub enum BackgroundTask {
|
||||
Payouts,
|
||||
IndexBilling,
|
||||
IndexSubscriptions,
|
||||
Migrations,
|
||||
}
|
||||
|
||||
impl BackgroundTask {
|
||||
@ -28,6 +29,7 @@ impl BackgroundTask {
|
||||
) {
|
||||
use BackgroundTask::*;
|
||||
match self {
|
||||
Migrations => run_migrations().await,
|
||||
IndexSearch => index_search(pool, redis_pool, search_config).await,
|
||||
ReleaseScheduled => release_scheduled(pool).await,
|
||||
UpdateVersions => update_versions(pool, redis_pool).await,
|
||||
@ -50,6 +52,12 @@ impl BackgroundTask {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_migrations() {
|
||||
database::check_for_migrations()
|
||||
.await
|
||||
.expect("An error occurred while running migrations.");
|
||||
}
|
||||
|
||||
pub async fn index_search(
|
||||
pool: sqlx::Pool<Postgres>,
|
||||
redis_pool: RedisPool,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@ -13,14 +12,12 @@ use tracing::{info, warn};
|
||||
|
||||
extern crate clickhouse as clickhouse_crate;
|
||||
use clickhouse_crate::Client;
|
||||
use governor::middleware::StateInformationMiddleware;
|
||||
use governor::{Quota, RateLimiter};
|
||||
use util::cors::default_cors;
|
||||
|
||||
use crate::background_task::update_versions;
|
||||
use crate::queue::moderation::AutomatedModerationQueue;
|
||||
use crate::util::env::{parse_strings_from_var, parse_var};
|
||||
use crate::util::ratelimit::KeyedRateLimiter;
|
||||
use crate::util::ratelimit::{AsyncRateLimiter, GCRAParameters};
|
||||
use sync::friends::handle_pubsub;
|
||||
|
||||
pub mod auth;
|
||||
@ -57,7 +54,7 @@ pub struct LabrinthConfig {
|
||||
pub analytics_queue: Arc<AnalyticsQueue>,
|
||||
pub active_sockets: web::Data<ActiveSockets>,
|
||||
pub automated_moderation_queue: web::Data<AutomatedModerationQueue>,
|
||||
pub rate_limiter: KeyedRateLimiter,
|
||||
pub rate_limiter: web::Data<AsyncRateLimiter>,
|
||||
pub stripe_client: stripe::Client,
|
||||
}
|
||||
|
||||
@ -93,24 +90,10 @@ pub fn app_setup(
|
||||
|
||||
let mut scheduler = scheduler::Scheduler::new();
|
||||
|
||||
let limiter: KeyedRateLimiter = Arc::new(
|
||||
RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(300).unwrap()))
|
||||
.with_middleware::<StateInformationMiddleware>(),
|
||||
);
|
||||
let limiter_clone = Arc::clone(&limiter);
|
||||
scheduler.run(Duration::from_secs(60), move || {
|
||||
info!(
|
||||
"Clearing ratelimiter, storage size: {}",
|
||||
limiter_clone.len()
|
||||
);
|
||||
limiter_clone.retain_recent();
|
||||
info!(
|
||||
"Done clearing ratelimiter, storage size: {}",
|
||||
limiter_clone.len()
|
||||
);
|
||||
|
||||
async move {}
|
||||
});
|
||||
let limiter = web::Data::new(AsyncRateLimiter::new(
|
||||
redis_pool.clone(),
|
||||
GCRAParameters::new(300, 300),
|
||||
));
|
||||
|
||||
if enable_background_tasks {
|
||||
// The interval in seconds at which the local database is indexed
|
||||
@ -329,6 +312,7 @@ pub fn app_config(
|
||||
.app_data(labrinth_config.active_sockets.clone())
|
||||
.app_data(labrinth_config.automated_moderation_queue.clone())
|
||||
.app_data(web::Data::new(labrinth_config.stripe_client.clone()))
|
||||
.app_data(labrinth_config.rate_limiter.clone())
|
||||
.configure(
|
||||
#[allow(unused_variables)]
|
||||
|cfg| {
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use actix_web::middleware::from_fn;
|
||||
use actix_web::{App, HttpServer};
|
||||
use actix_web_prom::PrometheusMetricsBuilder;
|
||||
use clap::Parser;
|
||||
@ -5,7 +6,7 @@ use labrinth::background_task::BackgroundTask;
|
||||
use labrinth::database::redis::RedisPool;
|
||||
use labrinth::file_hosting::S3Host;
|
||||
use labrinth::search;
|
||||
use labrinth::util::ratelimit::RateLimit;
|
||||
use labrinth::util::ratelimit::rate_limit_middleware;
|
||||
use labrinth::{check_env_vars, clickhouse, database, file_hosting, queue};
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
@ -33,6 +34,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
no_background_tasks: bool,
|
||||
|
||||
/// Don't automatically run migrations. This means the migrations should be run via --run-background-task.
|
||||
#[arg(long)]
|
||||
no_migrations: bool,
|
||||
|
||||
/// Run a single background task and then exit. Perfect for cron jobs.
|
||||
#[arg(long, value_enum, id = "task")]
|
||||
run_background_task: Option<BackgroundTask>,
|
||||
@ -67,9 +72,11 @@ async fn main() -> std::io::Result<()> {
|
||||
dotenvy::var("BIND_ADDR").unwrap()
|
||||
);
|
||||
|
||||
database::check_for_migrations()
|
||||
.await
|
||||
.expect("An error occurred while running migrations.");
|
||||
if !args.no_migrations {
|
||||
database::check_for_migrations()
|
||||
.await
|
||||
.expect("An error occurred while running migrations.");
|
||||
}
|
||||
}
|
||||
|
||||
// Database Connector
|
||||
@ -164,7 +171,7 @@ async fn main() -> std::io::Result<()> {
|
||||
App::new()
|
||||
.wrap(TracingLogger::default())
|
||||
.wrap(prometheus.clone())
|
||||
.wrap(RateLimit(Arc::clone(&labrinth_config.rate_limiter)))
|
||||
.wrap(from_fn(rate_limit_middleware))
|
||||
.wrap(actix_web::middleware::Compress::default())
|
||||
.wrap(sentry_actix::Sentry::new())
|
||||
.configure(|cfg| labrinth::app_config(cfg, labrinth_config.clone()))
|
||||
|
||||
@ -7,7 +7,13 @@ pub async fn index_get() -> HttpResponse {
|
||||
"name": "modrinth-labrinth",
|
||||
"version": env!("CARGO_PKG_VERSION"),
|
||||
"documentation": "https://docs.modrinth.com",
|
||||
"about": "Welcome traveler!"
|
||||
"about": "Welcome traveler!",
|
||||
|
||||
"build_info": {
|
||||
"comp_date": env!("COMPILATION_DATE"),
|
||||
"git_hash": env!("GIT_HASH", "unknown"),
|
||||
"profile": env!("COMPILATION_PROFILE"),
|
||||
}
|
||||
});
|
||||
|
||||
HttpResponse::Ok().json(data)
|
||||
|
||||
@ -91,30 +91,30 @@ pub async fn ws_init(
|
||||
let friend_statuses = if !friends.is_empty() {
|
||||
let db = db.clone();
|
||||
let redis = redis.clone();
|
||||
tokio_stream::iter(friends.iter())
|
||||
|
||||
let statuses = tokio_stream::iter(friends.iter())
|
||||
.map(|x| {
|
||||
let db = db.clone();
|
||||
let redis = redis.clone();
|
||||
async move {
|
||||
async move {
|
||||
get_user_status(
|
||||
if x.user_id == user_id.into() {
|
||||
x.friend_id
|
||||
} else {
|
||||
x.user_id
|
||||
}
|
||||
.into(),
|
||||
&db,
|
||||
&redis,
|
||||
)
|
||||
.await
|
||||
}
|
||||
get_user_status(
|
||||
if x.user_id == user_id.into() {
|
||||
x.friend_id
|
||||
} else {
|
||||
x.user_id
|
||||
}
|
||||
.into(),
|
||||
&db,
|
||||
&redis,
|
||||
)
|
||||
.await
|
||||
}
|
||||
})
|
||||
.buffer_unordered(16)
|
||||
.filter_map(|x| x)
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
.await;
|
||||
|
||||
statuses.into_iter().flatten().collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
@ -209,8 +209,9 @@ pub async fn search_for_project(
|
||||
let mut filter_string = String::new();
|
||||
|
||||
// Convert offset and limit to page and hits_per_page
|
||||
let hits_per_page = limit;
|
||||
let page = offset / limit + 1;
|
||||
let hits_per_page = if limit == 0 { 1 } else { limit };
|
||||
|
||||
let page = offset / hits_per_page + 1;
|
||||
|
||||
let results = {
|
||||
let mut query = meilisearch_index.search();
|
||||
|
||||
@ -1,196 +1,238 @@
|
||||
use governor::clock::{Clock, DefaultClock};
|
||||
use governor::{middleware, state, RateLimiter};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::database::redis::RedisPool;
|
||||
use crate::routes::ApiError;
|
||||
use crate::util::env::parse_var;
|
||||
use actix_web::{
|
||||
body::EitherBody,
|
||||
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
|
||||
Error, ResponseError,
|
||||
body::{EitherBody, MessageBody},
|
||||
dev::{ServiceRequest, ServiceResponse},
|
||||
middleware::Next,
|
||||
web, Error, ResponseError,
|
||||
};
|
||||
use futures_util::future::LocalBoxFuture;
|
||||
use futures_util::future::{ready, Ready};
|
||||
use chrono::Utc;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub type KeyedRateLimiter<
|
||||
K = String,
|
||||
MW = middleware::StateInformationMiddleware,
|
||||
> = Arc<
|
||||
RateLimiter<K, state::keyed::DefaultKeyedStateStore<K>, DefaultClock, MW>,
|
||||
>;
|
||||
const RATE_LIMIT_NAMESPACE: &str = "rate_limit";
|
||||
const RATE_LIMIT_EXPIRY: i64 = 300; // 5 minutes
|
||||
const MINUTE_IN_NANOS: i64 = 60_000_000_000;
|
||||
|
||||
pub struct RateLimit(pub KeyedRateLimiter);
|
||||
pub struct GCRAParameters {
|
||||
emission_interval: i64,
|
||||
burst_size: u32,
|
||||
}
|
||||
|
||||
impl<S, B> Transform<S, ServiceRequest> for RateLimit
|
||||
where
|
||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
||||
S::Future: 'static,
|
||||
B: 'static,
|
||||
{
|
||||
type Response = ServiceResponse<EitherBody<B>>;
|
||||
type Error = Error;
|
||||
type Transform = RateLimitService<S>;
|
||||
type InitError = ();
|
||||
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
||||
impl GCRAParameters {
|
||||
pub(crate) fn new(requests_per_minute: u32, burst_size: u32) -> Self {
|
||||
// Calculate emission interval in nanoseconds
|
||||
let emission_interval = MINUTE_IN_NANOS / requests_per_minute as i64;
|
||||
|
||||
fn new_transform(&self, service: S) -> Self::Future {
|
||||
ready(Ok(RateLimitService {
|
||||
service,
|
||||
rate_limiter: Arc::clone(&self.0),
|
||||
}))
|
||||
Self {
|
||||
emission_interval,
|
||||
burst_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub struct RateLimitService<S> {
|
||||
service: S,
|
||||
rate_limiter: KeyedRateLimiter,
|
||||
pub struct RateLimitDecision {
|
||||
pub allowed: bool,
|
||||
pub limit: u32,
|
||||
pub remaining: u32,
|
||||
pub reset_after_ms: i64,
|
||||
pub retry_after_ms: Option<i64>,
|
||||
}
|
||||
|
||||
impl<S, B> Service<ServiceRequest> for RateLimitService<S>
|
||||
where
|
||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
||||
S::Future: 'static,
|
||||
B: 'static,
|
||||
{
|
||||
type Response = ServiceResponse<EitherBody<B>>;
|
||||
type Error = Error;
|
||||
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
#[derive(Clone)]
|
||||
pub struct AsyncRateLimiter {
|
||||
redis_pool: RedisPool,
|
||||
params: Arc<GCRAParameters>,
|
||||
}
|
||||
|
||||
forward_ready!(service);
|
||||
|
||||
fn call(&self, req: ServiceRequest) -> Self::Future {
|
||||
if let Some(key) = req.headers().get("x-ratelimit-key") {
|
||||
if key.to_str().ok()
|
||||
== dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok().as_deref()
|
||||
{
|
||||
let res = self.service.call(req);
|
||||
|
||||
return Box::pin(async move {
|
||||
let service_response = res.await?;
|
||||
Ok(service_response.map_into_left_body())
|
||||
});
|
||||
}
|
||||
impl AsyncRateLimiter {
|
||||
pub fn new(redis_pool: RedisPool, params: GCRAParameters) -> Self {
|
||||
Self {
|
||||
redis_pool,
|
||||
params: Arc::new(params),
|
||||
}
|
||||
}
|
||||
|
||||
let conn_info = req.connection_info().clone();
|
||||
let ip = if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) {
|
||||
if let Some(header) = req.headers().get("CF-Connecting-IP") {
|
||||
header.to_str().ok()
|
||||
} else {
|
||||
conn_info.peer_addr()
|
||||
pub async fn check_rate_limit(&self, key: &str) -> RateLimitDecision {
|
||||
let mut conn = match self.redis_pool.connect().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => {
|
||||
// If Redis is unavailable, allow the request but with reduced limit
|
||||
return RateLimitDecision {
|
||||
allowed: true,
|
||||
limit: self.params.burst_size,
|
||||
remaining: 1,
|
||||
reset_after_ms: 60_000, // 1 minute
|
||||
retry_after_ms: None,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
conn_info.peer_addr()
|
||||
};
|
||||
|
||||
if let Some(ip) = ip {
|
||||
let ip = ip.to_string();
|
||||
// Get current time in nanoseconds since UNIX epoch
|
||||
let now = Utc::now().timestamp_nanos_opt().unwrap_or(0);
|
||||
|
||||
match self.rate_limiter.check_key(&ip) {
|
||||
Ok(snapshot) => {
|
||||
let fut = self.service.call(req);
|
||||
// Get the current TAT from Redis (if it exists)
|
||||
let tat_str = conn.get(RATE_LIMIT_NAMESPACE, key).await.ok().flatten();
|
||||
|
||||
Box::pin(async move {
|
||||
match fut.await {
|
||||
Ok(mut service_response) => {
|
||||
// Now you have a mutable reference to the ServiceResponse, so you can modify its headers.
|
||||
let headers = service_response.headers_mut();
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-limit",
|
||||
)
|
||||
.unwrap(),
|
||||
snapshot.quota().burst_size().get().into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-remaining",
|
||||
)
|
||||
.unwrap(),
|
||||
snapshot.remaining_burst_capacity().into(),
|
||||
);
|
||||
// Parse the TAT or use current time if not found
|
||||
let current_tat = match tat_str {
|
||||
Some(tat_str) => tat_str.parse::<i64>().unwrap_or(now),
|
||||
None => now,
|
||||
};
|
||||
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-reset",
|
||||
)
|
||||
.unwrap(),
|
||||
snapshot
|
||||
.quota()
|
||||
.burst_size_replenished_in()
|
||||
.as_secs()
|
||||
.into(),
|
||||
);
|
||||
// Calculate the new TAT using GCRA
|
||||
let increment = self.params.emission_interval;
|
||||
let max_tat_delta = increment * self.params.burst_size as i64;
|
||||
|
||||
// Return the modified response as Ok.
|
||||
Ok(service_response.map_into_left_body())
|
||||
}
|
||||
Err(e) => {
|
||||
// Handle error case
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Err(negative) => {
|
||||
let wait_time =
|
||||
negative.wait_time_from(DefaultClock::default().now());
|
||||
// Calculate allowance: how much time has passed since the TAT
|
||||
let allowance = now - current_tat;
|
||||
|
||||
let mut response = ApiError::RateLimitError(
|
||||
wait_time.as_millis(),
|
||||
negative.quota().burst_size().get(),
|
||||
)
|
||||
.error_response();
|
||||
if allowance < -max_tat_delta {
|
||||
// Too many requests, rate limit exceeded
|
||||
// Calculate when the client can retry
|
||||
let retry_after_ms = (-allowance - max_tat_delta) / 1_000_000;
|
||||
|
||||
let headers = response.headers_mut();
|
||||
return RateLimitDecision {
|
||||
allowed: false,
|
||||
limit: self.params.burst_size,
|
||||
remaining: 0,
|
||||
reset_after_ms: -allowance / 1_000_000,
|
||||
retry_after_ms: Some(retry_after_ms.max(0)),
|
||||
};
|
||||
}
|
||||
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-limit",
|
||||
)
|
||||
.unwrap(),
|
||||
negative.quota().burst_size().get().into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-remaining",
|
||||
)
|
||||
.unwrap(),
|
||||
0.into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-reset",
|
||||
)
|
||||
.unwrap(),
|
||||
wait_time.as_secs().into(),
|
||||
);
|
||||
let new_tat = std::cmp::max(current_tat + increment, now);
|
||||
|
||||
// TODO: Sentralize CORS in the CORS util.
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"Access-Control-Allow-Origin",
|
||||
)
|
||||
.unwrap(),
|
||||
"*".parse().unwrap(),
|
||||
);
|
||||
let _ = conn
|
||||
.set(
|
||||
RATE_LIMIT_NAMESPACE,
|
||||
key,
|
||||
&new_tat.to_string(),
|
||||
Some(RATE_LIMIT_EXPIRY),
|
||||
)
|
||||
.await;
|
||||
|
||||
Box::pin(async {
|
||||
Ok(req.into_response(response.map_into_right_body()))
|
||||
})
|
||||
}
|
||||
}
|
||||
let remaining_capacity =
|
||||
((max_tat_delta - (new_tat - now)) / increment).max(0) as u32;
|
||||
|
||||
RateLimitDecision {
|
||||
allowed: true,
|
||||
limit: self.params.burst_size,
|
||||
remaining: remaining_capacity,
|
||||
reset_after_ms: (new_tat - now) / 1_000_000,
|
||||
retry_after_ms: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn rate_limit_middleware(
|
||||
req: ServiceRequest,
|
||||
next: Next<impl MessageBody>,
|
||||
) -> Result<ServiceResponse<EitherBody<impl MessageBody>>, Error> {
|
||||
let rate_limiter = req
|
||||
.app_data::<web::Data<AsyncRateLimiter>>()
|
||||
.expect("Rate limiter not configured properly")
|
||||
.clone();
|
||||
|
||||
if let Some(key) = req.headers().get("x-ratelimit-key") {
|
||||
if key.to_str().ok()
|
||||
== dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok().as_deref()
|
||||
{
|
||||
return Ok(next.call(req).await?.map_into_left_body());
|
||||
}
|
||||
}
|
||||
|
||||
let conn_info = req.connection_info().clone();
|
||||
let ip = if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) {
|
||||
if let Some(header) = req.headers().get("CF-Connecting-IP") {
|
||||
header.to_str().ok()
|
||||
} else {
|
||||
let response = ApiError::CustomAuthentication(
|
||||
"Unable to obtain user IP address!".to_string(),
|
||||
conn_info.peer_addr()
|
||||
}
|
||||
} else {
|
||||
conn_info.peer_addr()
|
||||
};
|
||||
|
||||
if let Some(ip) = ip {
|
||||
let decision = rate_limiter.check_rate_limit(ip).await;
|
||||
|
||||
if decision.allowed {
|
||||
let mut service_response = next.call(req).await?;
|
||||
|
||||
// Add rate limit headers
|
||||
let headers = service_response.headers_mut();
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-limit",
|
||||
)
|
||||
.unwrap(),
|
||||
decision.limit.into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-remaining",
|
||||
)
|
||||
.unwrap(),
|
||||
decision.remaining.into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-reset",
|
||||
)
|
||||
.unwrap(),
|
||||
(decision.reset_after_ms / 1000).into(),
|
||||
);
|
||||
|
||||
Ok(service_response.map_into_left_body())
|
||||
} else {
|
||||
let mut response = ApiError::RateLimitError(
|
||||
decision.retry_after_ms.unwrap_or(0) as u128,
|
||||
decision.limit,
|
||||
)
|
||||
.error_response();
|
||||
|
||||
Box::pin(async {
|
||||
Ok(req.into_response(response.map_into_right_body()))
|
||||
})
|
||||
// Add rate limit headers
|
||||
let headers = response.headers_mut();
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-limit",
|
||||
)
|
||||
.unwrap(),
|
||||
decision.limit.into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-remaining",
|
||||
)
|
||||
.unwrap(),
|
||||
0.into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-reset",
|
||||
)
|
||||
.unwrap(),
|
||||
(decision.reset_after_ms / 1000).into(),
|
||||
);
|
||||
|
||||
// TODO: Centralize CORS in the CORS util.
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"Access-Control-Allow-Origin",
|
||||
)
|
||||
.unwrap(),
|
||||
"*".parse().unwrap(),
|
||||
);
|
||||
|
||||
Ok(req.into_response(response.map_into_right_body()))
|
||||
}
|
||||
} else {
|
||||
let response = ApiError::CustomAuthentication(
|
||||
"Unable to obtain user IP address!".to_string(),
|
||||
)
|
||||
.error_response();
|
||||
|
||||
Ok(req.into_response(response.map_into_right_body()))
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user