update ratelimiter (#897)
* update ratelimiter * Switch to old scheduler
This commit is contained in:
parent
a0aa350a08
commit
0a0837ea02
119
Cargo.lock
generated
119
Cargo.lock
generated
@ -2,31 +2,6 @@
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "actix"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cba56612922b907719d4a01cf11c8d5b458e7d3dba946d0435f20f58d6795ed2"
|
||||
dependencies = [
|
||||
"actix-macros",
|
||||
"actix-rt",
|
||||
"actix_derive",
|
||||
"bitflags 2.4.1",
|
||||
"bytes",
|
||||
"crossbeam-channel",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
"log",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"smallvec",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "actix-codec"
|
||||
version = "0.5.1"
|
||||
@ -309,17 +284,6 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "actix_derive"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c7db3d5a9718568e4cf4a537cfd7070e6e6ff7481510d0237fb529ac850f6d3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
version = "0.21.0"
|
||||
@ -1021,15 +985,6 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-channel"
|
||||
version = "0.5.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.5"
|
||||
@ -1803,6 +1758,26 @@ version = "0.28.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
|
||||
|
||||
[[package]]
|
||||
name = "governor"
|
||||
version = "0.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"dashmap",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"no-std-compat",
|
||||
"nonzero_ext",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"quanta",
|
||||
"rand",
|
||||
"smallvec",
|
||||
"spinning_top",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.23"
|
||||
@ -2304,7 +2279,6 @@ dependencies = [
|
||||
name = "labrinth"
|
||||
version = "2.7.0"
|
||||
dependencies = [
|
||||
"actix",
|
||||
"actix-cors",
|
||||
"actix-files",
|
||||
"actix-http",
|
||||
@ -2330,6 +2304,8 @@ dependencies = [
|
||||
"flate2",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"futures-util",
|
||||
"governor",
|
||||
"hex",
|
||||
"hmac 0.11.0",
|
||||
"hyper",
|
||||
@ -2731,6 +2707,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "no-std-compat"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.3"
|
||||
@ -2741,6 +2723,12 @@ dependencies = [
|
||||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nonzero_ext"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint-dig"
|
||||
version = "0.8.4"
|
||||
@ -3132,6 +3120,12 @@ dependencies = [
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
|
||||
|
||||
[[package]]
|
||||
name = "powerfmt"
|
||||
version = "0.2.0"
|
||||
@ -3252,6 +3246,21 @@ dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.12.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ca0b7bac0b97248c40bb77288fc52029cf1459c0461ea1b05ee32ccf011de2c"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "2.0.1"
|
||||
@ -3339,6 +3348,15 @@ dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d86a7c4638d42c44551f4791a20e687dbb4c3de1f33c43dd71e355cd429def1"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.8.0"
|
||||
@ -4204,6 +4222,15 @@ dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spinning_top"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spki"
|
||||
version = "0.7.3"
|
||||
|
||||
@ -11,7 +11,6 @@ name = "labrinth"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
actix = "0.13.1"
|
||||
actix-web = "4.4.1"
|
||||
actix-rt = "2.9.0"
|
||||
actix-multipart = "0.6.1"
|
||||
@ -19,12 +18,14 @@ actix-cors = "0.7.0"
|
||||
actix-ws = "0.2.5"
|
||||
actix-files = "0.6.5"
|
||||
actix-web-prom = "0.7.0"
|
||||
governor = "0.6.3"
|
||||
|
||||
tokio = { version = "1.35.1", features = ["sync"] }
|
||||
tokio-stream = "0.1.14"
|
||||
|
||||
futures = "0.3.30"
|
||||
futures-timer = "3.0.2"
|
||||
futures-util = "0.3.30"
|
||||
async-trait = "0.1.70"
|
||||
dashmap = "5.4.0"
|
||||
lazy_static = "1.4.0"
|
||||
|
||||
30
src/lib.rs
30
src/lib.rs
@ -1,4 +1,6 @@
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_web::web;
|
||||
use database::redis::RedisPool;
|
||||
@ -6,12 +8,13 @@ use log::{info, warn};
|
||||
use queue::{
|
||||
analytics::AnalyticsQueue, payouts::PayoutsQueue, session::AuthQueue, socket::ActiveSockets,
|
||||
};
|
||||
use scheduler::Scheduler;
|
||||
use sqlx::Postgres;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
extern crate clickhouse as clickhouse_crate;
|
||||
use clickhouse_crate::Client;
|
||||
use governor::{Quota, RateLimiter};
|
||||
use governor::middleware::StateInformationMiddleware;
|
||||
use util::cors::default_cors;
|
||||
|
||||
use crate::queue::moderation::AutomatedModerationQueue;
|
||||
@ -20,6 +23,7 @@ use crate::{
|
||||
search::indexing::index_projects,
|
||||
util::env::{parse_strings_from_var, parse_var},
|
||||
};
|
||||
use crate::util::ratelimit::KeyedRateLimiter;
|
||||
|
||||
pub mod auth;
|
||||
pub mod clickhouse;
|
||||
@ -27,7 +31,6 @@ pub mod database;
|
||||
pub mod file_hosting;
|
||||
pub mod models;
|
||||
pub mod queue;
|
||||
pub mod ratelimit;
|
||||
pub mod routes;
|
||||
pub mod scheduler;
|
||||
pub mod search;
|
||||
@ -46,7 +49,7 @@ pub struct LabrinthConfig {
|
||||
pub clickhouse: Client,
|
||||
pub file_host: Arc<dyn file_hosting::FileHost + Send + Sync>,
|
||||
pub maxmind: Arc<queue::maxmind::MaxMindIndexer>,
|
||||
pub scheduler: Arc<Scheduler>,
|
||||
pub scheduler: Arc<scheduler::Scheduler>,
|
||||
pub ip_salt: Pepper,
|
||||
pub search_config: search::SearchConfig,
|
||||
pub session_queue: web::Data<AuthQueue>,
|
||||
@ -54,6 +57,7 @@ pub struct LabrinthConfig {
|
||||
pub analytics_queue: Arc<AnalyticsQueue>,
|
||||
pub active_sockets: web::Data<RwLock<ActiveSockets>>,
|
||||
pub automated_moderation_queue: web::Data<AutomatedModerationQueue>,
|
||||
pub rate_limiter: KeyedRateLimiter,
|
||||
}
|
||||
|
||||
pub fn app_setup(
|
||||
@ -82,6 +86,25 @@ pub fn app_setup(
|
||||
|
||||
let mut scheduler = scheduler::Scheduler::new();
|
||||
|
||||
let limiter: KeyedRateLimiter = Arc::new(
|
||||
RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(300).unwrap()))
|
||||
.with_middleware::<StateInformationMiddleware>(),
|
||||
);
|
||||
let limiter_clone = Arc::clone(&limiter);
|
||||
scheduler.run(Duration::from_secs(60), move || {
|
||||
info!(
|
||||
"Clearing ratelimiter, storage size: {}",
|
||||
limiter_clone.len()
|
||||
);
|
||||
limiter_clone.retain_recent();
|
||||
info!(
|
||||
"Done clearing ratelimiter, storage size: {}",
|
||||
limiter_clone.len()
|
||||
);
|
||||
|
||||
async move {}
|
||||
});
|
||||
|
||||
// The interval in seconds at which the local database is indexed
|
||||
// for searching. Defaults to 1 hour if unset.
|
||||
let local_index_interval =
|
||||
@ -255,6 +278,7 @@ pub fn app_setup(
|
||||
analytics_queue,
|
||||
active_sockets,
|
||||
automated_moderation_queue,
|
||||
rate_limiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
38
src/main.rs
38
src/main.rs
@ -1,16 +1,17 @@
|
||||
use actix_web::{App, HttpServer};
|
||||
use actix_web_prom::PrometheusMetricsBuilder;
|
||||
use env_logger::Env;
|
||||
use governor::middleware::StateInformationMiddleware;
|
||||
use governor::{Quota, RateLimiter};
|
||||
use labrinth::database::redis::RedisPool;
|
||||
use labrinth::file_hosting::S3Host;
|
||||
use labrinth::ratelimit::errors::ARError;
|
||||
use labrinth::ratelimit::memory::{MemoryStore, MemoryStoreActor};
|
||||
use labrinth::ratelimit::middleware::RateLimiter;
|
||||
use labrinth::search;
|
||||
use labrinth::util::env::parse_var;
|
||||
use labrinth::util::ratelimit::{KeyedRateLimiter, RateLimit};
|
||||
use labrinth::{check_env_vars, clickhouse, database, file_hosting, queue};
|
||||
use log::{error, info};
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(feature = "jemalloc")]
|
||||
#[global_allocator]
|
||||
@ -90,17 +91,14 @@ async fn main() -> std::io::Result<()> {
|
||||
|
||||
let maxmind_reader = Arc::new(queue::maxmind::MaxMindIndexer::new().await.unwrap());
|
||||
|
||||
let store = MemoryStore::new();
|
||||
|
||||
let prometheus = PrometheusMetricsBuilder::new("labrinth")
|
||||
.endpoint("/metrics")
|
||||
.build()
|
||||
.expect("Failed to create prometheus metrics middleware");
|
||||
|
||||
let search_config = search::SearchConfig::new(None);
|
||||
info!("Starting Actix HTTP server!");
|
||||
|
||||
let labrinth_config = labrinth::app_setup(
|
||||
let mut labrinth_config = labrinth::app_setup(
|
||||
pool.clone(),
|
||||
redis_pool.clone(),
|
||||
search_config.clone(),
|
||||
@ -109,32 +107,14 @@ async fn main() -> std::io::Result<()> {
|
||||
maxmind_reader.clone(),
|
||||
);
|
||||
|
||||
info!("Starting Actix HTTP server!");
|
||||
|
||||
// Init App
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
.wrap(prometheus.clone())
|
||||
.wrap(RateLimit(Arc::clone(&labrinth_config.rate_limiter)))
|
||||
.wrap(actix_web::middleware::Compress::default())
|
||||
.wrap(
|
||||
RateLimiter::new(MemoryStoreActor::from(store.clone()).start())
|
||||
.with_identifier(|req| {
|
||||
let connection_info = req.connection_info();
|
||||
let ip =
|
||||
String::from(if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) {
|
||||
if let Some(header) = req.headers().get("CF-Connecting-IP") {
|
||||
header.to_str().map_err(|_| ARError::Identification)?
|
||||
} else {
|
||||
connection_info.peer_addr().ok_or(ARError::Identification)?
|
||||
}
|
||||
} else {
|
||||
connection_info.peer_addr().ok_or(ARError::Identification)?
|
||||
});
|
||||
|
||||
Ok(ip)
|
||||
})
|
||||
.with_interval(std::time::Duration::from_secs(60))
|
||||
.with_max_requests(300)
|
||||
.with_ignore_key(dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok()),
|
||||
)
|
||||
.wrap(sentry_actix::Sentry::new())
|
||||
.configure(|cfg| labrinth::app_config(cfg, labrinth_config.clone()))
|
||||
})
|
||||
|
||||
@ -1,52 +0,0 @@
|
||||
//! Errors that can occur during middleware processing stage
|
||||
use crate::models::error::ApiError;
|
||||
use actix_web::ResponseError;
|
||||
use log::*;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Custom error type. Useful for logging and debugging different kinds of errors.
|
||||
/// This type can be converted to Actix Error, which defaults to
|
||||
/// InternalServerError
|
||||
///
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ARError {
|
||||
/// Read/Write error on store
|
||||
#[error("read/write operation failed: {0}")]
|
||||
ReadWrite(String),
|
||||
|
||||
/// Identifier error
|
||||
#[error("client identification failed")]
|
||||
Identification,
|
||||
/// Limited Error
|
||||
#[error("You are being rate-limited. Please wait {reset} seconds. {remaining}/{max_requests} remaining.")]
|
||||
Limited {
|
||||
max_requests: usize,
|
||||
remaining: usize,
|
||||
reset: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl ResponseError for ARError {
|
||||
fn error_response(&self) -> actix_web::HttpResponse {
|
||||
match self {
|
||||
Self::Limited {
|
||||
max_requests,
|
||||
remaining,
|
||||
reset,
|
||||
} => {
|
||||
let mut response = actix_web::HttpResponse::TooManyRequests();
|
||||
response.insert_header(("x-ratelimit-limit", max_requests.to_string()));
|
||||
response.insert_header(("x-ratelimit-remaining", remaining.to_string()));
|
||||
response.insert_header(("x-ratelimit-reset", reset.to_string()));
|
||||
response.json(ApiError {
|
||||
error: "ratelimit_error",
|
||||
description: self.to_string(),
|
||||
})
|
||||
}
|
||||
_ => actix_web::HttpResponse::build(self.status_code()).json(ApiError {
|
||||
error: "ratelimit_error",
|
||||
description: self.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,143 +0,0 @@
|
||||
//! In memory store for rate limiting
|
||||
use actix::prelude::*;
|
||||
use dashmap::DashMap;
|
||||
use futures::future::{self};
|
||||
use log::*;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::ratelimit::errors::ARError;
|
||||
use crate::ratelimit::{ActorMessage, ActorResponse};
|
||||
|
||||
/// Type used to create a concurrent hashmap store
|
||||
#[derive(Clone)]
|
||||
pub struct MemoryStore {
|
||||
inner: Arc<DashMap<String, (usize, Duration)>>,
|
||||
}
|
||||
|
||||
impl Default for MemoryStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MemoryStore {
|
||||
/// Create a new hashmap
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use labrinth::ratelimit::memory::MemoryStore;
|
||||
///
|
||||
/// let store = MemoryStore::new();
|
||||
/// ```
|
||||
pub fn new() -> Self {
|
||||
debug!("Creating new MemoryStore");
|
||||
MemoryStore {
|
||||
inner: Arc::new(DashMap::<String, (usize, Duration)>::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Actor for memory store
|
||||
pub struct MemoryStoreActor {
|
||||
inner: Arc<DashMap<String, (usize, Duration)>>,
|
||||
}
|
||||
|
||||
impl From<MemoryStore> for MemoryStoreActor {
|
||||
fn from(store: MemoryStore) -> Self {
|
||||
MemoryStoreActor { inner: store.inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl MemoryStoreActor {
|
||||
/// Starts the memory actor and returns it's address
|
||||
pub fn start(self) -> Addr<Self> {
|
||||
debug!("Started memory store");
|
||||
Supervisor::start(|_| self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Actor for MemoryStoreActor {
|
||||
type Context = Context<Self>;
|
||||
}
|
||||
|
||||
impl Supervised for MemoryStoreActor {
|
||||
fn restarting(&mut self, _: &mut Self::Context) {
|
||||
debug!("Restarting memory store");
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler<ActorMessage> for MemoryStoreActor {
|
||||
type Result = ActorResponse;
|
||||
fn handle(&mut self, msg: ActorMessage, ctx: &mut Self::Context) -> Self::Result {
|
||||
match msg {
|
||||
ActorMessage::Set { key, value, expiry } => {
|
||||
debug!("Inserting key {} with expiry {}", &key, &expiry.as_secs());
|
||||
let future_key = String::from(&key);
|
||||
let now = SystemTime::now();
|
||||
let now = now.duration_since(UNIX_EPOCH).unwrap();
|
||||
self.inner.insert(key, (value, now + expiry));
|
||||
ctx.notify_later(ActorMessage::Remove(future_key), expiry);
|
||||
ActorResponse::Set(Box::pin(future::ready(Ok(()))))
|
||||
}
|
||||
ActorMessage::Update { key, value } => match self.inner.get_mut(&key) {
|
||||
Some(mut c) => {
|
||||
let val_mut: &mut (usize, Duration) = c.value_mut();
|
||||
if val_mut.0 > value {
|
||||
val_mut.0 -= value;
|
||||
} else {
|
||||
val_mut.0 = 0;
|
||||
}
|
||||
let new_val = val_mut.0;
|
||||
ActorResponse::Update(Box::pin(future::ready(Ok(new_val))))
|
||||
}
|
||||
None => ActorResponse::Update(Box::pin(future::ready(Err(ARError::ReadWrite(
|
||||
"memory store: read failed!".to_string(),
|
||||
))))),
|
||||
},
|
||||
ActorMessage::Get(key) => {
|
||||
if self.inner.contains_key(&key) {
|
||||
let val = match self.inner.get(&key) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return ActorResponse::Get(Box::pin(future::ready(Err(
|
||||
ARError::ReadWrite("memory store: read failed!".to_string()),
|
||||
))))
|
||||
}
|
||||
};
|
||||
let val = val.value().0;
|
||||
ActorResponse::Get(Box::pin(future::ready(Ok(Some(val)))))
|
||||
} else {
|
||||
ActorResponse::Get(Box::pin(future::ready(Ok(None))))
|
||||
}
|
||||
}
|
||||
ActorMessage::Expire(key) => {
|
||||
let c = match self.inner.get(&key) {
|
||||
Some(d) => d,
|
||||
None => {
|
||||
return ActorResponse::Expire(Box::pin(future::ready(Err(
|
||||
ARError::ReadWrite("memory store: read failed!".to_string()),
|
||||
))))
|
||||
}
|
||||
};
|
||||
let dur = c.value().1;
|
||||
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
|
||||
let res = dur.checked_sub(now).unwrap_or_else(|| Duration::new(0, 0));
|
||||
ActorResponse::Expire(Box::pin(future::ready(Ok(res))))
|
||||
}
|
||||
ActorMessage::Remove(key) => {
|
||||
debug!("Removing key: {}", &key);
|
||||
let val = match self.inner.remove::<String>(&key) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return ActorResponse::Remove(Box::pin(future::ready(Err(
|
||||
ARError::ReadWrite("memory store: remove failed!".to_string()),
|
||||
))))
|
||||
}
|
||||
};
|
||||
let val = val.1;
|
||||
ActorResponse::Remove(Box::pin(future::ready(Ok(val.0))))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,260 +0,0 @@
|
||||
use crate::ratelimit::errors::ARError;
|
||||
use crate::ratelimit::{ActorMessage, ActorResponse};
|
||||
use actix::dev::*;
|
||||
use actix_web::{
|
||||
dev::{Service, ServiceRequest, ServiceResponse, Transform},
|
||||
error::Error as AWError,
|
||||
http::header::{HeaderName, HeaderValue},
|
||||
};
|
||||
use futures::future::{ok, Ready};
|
||||
use log::*;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
future::Future,
|
||||
ops::Fn,
|
||||
pin::Pin,
|
||||
rc::Rc,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
type RateLimiterIdentifier = Rc<Box<dyn Fn(&ServiceRequest) -> Result<String, ARError> + 'static>>;
|
||||
|
||||
pub struct RateLimiter<T>
|
||||
where
|
||||
T: Handler<ActorMessage> + Send + Sync + 'static,
|
||||
T::Context: ToEnvelope<T, ActorMessage>,
|
||||
{
|
||||
interval: Duration,
|
||||
max_requests: usize,
|
||||
store: Addr<T>,
|
||||
identifier: RateLimiterIdentifier,
|
||||
ignore_key: Option<String>,
|
||||
}
|
||||
|
||||
impl<T> RateLimiter<T>
|
||||
where
|
||||
T: Handler<ActorMessage> + Send + Sync + 'static,
|
||||
<T as Actor>::Context: ToEnvelope<T, ActorMessage>,
|
||||
{
|
||||
/// Creates a new instance of `RateLimiter` with the provided address of `StoreActor`.
|
||||
pub fn new(store: Addr<T>) -> Self {
|
||||
let identifier = |req: &ServiceRequest| {
|
||||
let connection_info = req.connection_info();
|
||||
let ip = connection_info.peer_addr().ok_or(ARError::Identification)?;
|
||||
Ok(String::from(ip))
|
||||
};
|
||||
RateLimiter {
|
||||
interval: Duration::from_secs(0),
|
||||
max_requests: 0,
|
||||
store,
|
||||
identifier: Rc::new(Box::new(identifier)),
|
||||
ignore_key: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Specify the interval. The counter for a client is reset after this interval
|
||||
pub fn with_interval(mut self, interval: Duration) -> Self {
|
||||
self.interval = interval;
|
||||
self
|
||||
}
|
||||
|
||||
/// Specify the maximum number of requests allowed in the given interval.
|
||||
pub fn with_max_requests(mut self, max_requests: usize) -> Self {
|
||||
self.max_requests = max_requests;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets key which can be used to bypass rate-limiter
|
||||
pub fn with_ignore_key(mut self, ignore_key: Option<String>) -> Self {
|
||||
self.ignore_key = ignore_key;
|
||||
self
|
||||
}
|
||||
|
||||
/// Function to get the identifier for the client request
|
||||
pub fn with_identifier<F: Fn(&ServiceRequest) -> Result<String, ARError> + 'static>(
|
||||
mut self,
|
||||
identifier: F,
|
||||
) -> Self {
|
||||
self.identifier = Rc::new(Box::new(identifier));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S, B> Transform<S, ServiceRequest> for RateLimiter<T>
|
||||
where
|
||||
T: Handler<ActorMessage> + Send + Sync + 'static,
|
||||
T::Context: ToEnvelope<T, ActorMessage>,
|
||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = AWError> + 'static,
|
||||
S::Future: 'static,
|
||||
B: 'static,
|
||||
{
|
||||
type Response = ServiceResponse<B>;
|
||||
type Error = S::Error;
|
||||
type Transform = RateLimitMiddleware<S, T>;
|
||||
type InitError = ();
|
||||
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
||||
|
||||
fn new_transform(&self, service: S) -> Self::Future {
|
||||
ok(RateLimitMiddleware {
|
||||
service: Rc::new(RefCell::new(service)),
|
||||
store: self.store.clone(),
|
||||
max_requests: self.max_requests,
|
||||
interval: self.interval.as_secs(),
|
||||
identifier: self.identifier.clone(),
|
||||
ignore_key: self.ignore_key.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Service factory for RateLimiter
|
||||
pub struct RateLimitMiddleware<S, T>
|
||||
where
|
||||
S: 'static,
|
||||
T: Handler<ActorMessage> + 'static,
|
||||
{
|
||||
service: Rc<RefCell<S>>,
|
||||
store: Addr<T>,
|
||||
// Exists here for the sole purpose of knowing the max_requests and interval from RateLimiter
|
||||
max_requests: usize,
|
||||
interval: u64,
|
||||
identifier: RateLimiterIdentifier,
|
||||
ignore_key: Option<String>,
|
||||
}
|
||||
|
||||
impl<T, S, B> Service<ServiceRequest> for RateLimitMiddleware<S, T>
|
||||
where
|
||||
T: Handler<ActorMessage> + 'static,
|
||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = AWError> + 'static,
|
||||
S::Future: 'static,
|
||||
B: 'static,
|
||||
T::Context: ToEnvelope<T, ActorMessage>,
|
||||
{
|
||||
type Response = ServiceResponse<B>;
|
||||
type Error = S::Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
|
||||
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.service.borrow_mut().poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&self, req: ServiceRequest) -> Self::Future {
|
||||
let store = self.store.clone();
|
||||
let srv = self.service.clone();
|
||||
let max_requests = self.max_requests;
|
||||
let interval = Duration::from_secs(self.interval);
|
||||
let identifier = self.identifier.clone();
|
||||
let ignore_key = self.ignore_key.clone();
|
||||
Box::pin(async move {
|
||||
let identifier: String = (identifier)(&req)?;
|
||||
|
||||
if let Some(ignore_key) = ignore_key {
|
||||
if let Some(key) = req.headers().get("x-ratelimit-key") {
|
||||
if key.to_str().ok().unwrap_or_default() == &*ignore_key {
|
||||
let fut = srv.call(req);
|
||||
let res = fut.await?;
|
||||
return Ok(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let remaining: ActorResponse = store
|
||||
.send(ActorMessage::Get(String::from(&identifier)))
|
||||
.await
|
||||
.map_err(|_| ARError::Identification)?;
|
||||
match remaining {
|
||||
ActorResponse::Get(opt) => {
|
||||
let opt = opt.await?;
|
||||
if let Some(c) = opt {
|
||||
// Existing entry in store
|
||||
let expiry = store
|
||||
.send(ActorMessage::Expire(String::from(&identifier)))
|
||||
.await
|
||||
.map_err(|_| ARError::ReadWrite("Setting timeout".to_string()))?;
|
||||
let reset: Duration = match expiry {
|
||||
ActorResponse::Expire(dur) => dur.await?,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if c == 0 {
|
||||
info!("Limit exceeded for client: {}", &identifier);
|
||||
Err(ARError::Limited {
|
||||
max_requests,
|
||||
remaining: c,
|
||||
reset: reset.as_secs(),
|
||||
}
|
||||
.into())
|
||||
} else {
|
||||
// Decrement value
|
||||
let res: ActorResponse = store
|
||||
.send(ActorMessage::Update {
|
||||
key: identifier,
|
||||
value: 1,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| {
|
||||
ARError::ReadWrite("Decrementing ratelimit".to_string())
|
||||
})?;
|
||||
let updated_value: usize = match res {
|
||||
ActorResponse::Update(c) => c.await?,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
// Execute the request
|
||||
let fut = srv.call(req);
|
||||
let mut res = fut.await?;
|
||||
let headers = res.headers_mut();
|
||||
// Safe unwraps, since usize is always convertible to string
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-ratelimit-limit"),
|
||||
HeaderValue::from_str(max_requests.to_string().as_str())?,
|
||||
);
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-ratelimit-remaining"),
|
||||
HeaderValue::from_str(updated_value.to_string().as_str())?,
|
||||
);
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-ratelimit-reset"),
|
||||
HeaderValue::from_str(reset.as_secs().to_string().as_str())?,
|
||||
);
|
||||
Ok(res)
|
||||
}
|
||||
} else {
|
||||
// New client, create entry in store
|
||||
let current_value = max_requests - 1;
|
||||
let res = store
|
||||
.send(ActorMessage::Set {
|
||||
key: String::from(&identifier),
|
||||
value: current_value,
|
||||
expiry: interval,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| ARError::ReadWrite("Creating store entry".to_string()))?;
|
||||
match res {
|
||||
ActorResponse::Set(c) => c.await?,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let fut = srv.call(req);
|
||||
let mut res = fut.await?;
|
||||
let headers = res.headers_mut();
|
||||
// Safe unwraps, since usize is always convertible to string
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-ratelimit-limit"),
|
||||
HeaderValue::from_str(max_requests.to_string().as_str()).unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-ratelimit-remaining"),
|
||||
HeaderValue::from_str(current_value.to_string().as_str()).unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-ratelimit-reset"),
|
||||
HeaderValue::from_str(interval.as_secs().to_string().as_str()).unwrap(),
|
||||
);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,64 +0,0 @@
|
||||
use std::future::Future;
|
||||
use std::marker::Send;
|
||||
use std::pin::Pin;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::ratelimit::errors::ARError;
|
||||
use actix::dev::*;
|
||||
|
||||
pub mod errors;
|
||||
pub mod memory;
|
||||
/// The code for this module was directly taken from https://github.com/TerminalWitchcraft/actix-ratelimit
|
||||
/// with some modifications including upgrading it to Actix 4!
|
||||
pub mod middleware;
|
||||
|
||||
/// Represents message that can be handled by a `StoreActor`
|
||||
pub enum ActorMessage {
|
||||
/// Get the remaining count based on the provided identifier
|
||||
Get(String),
|
||||
/// Set the count of the client identified by `key` to `value` valid for `expiry`
|
||||
Set {
|
||||
key: String,
|
||||
value: usize,
|
||||
expiry: Duration,
|
||||
},
|
||||
/// Change the value of count for the client identified by `key` by `value`
|
||||
Update { key: String, value: usize },
|
||||
/// Get the expiration time for the client.
|
||||
Expire(String),
|
||||
/// Remove the client from the store
|
||||
Remove(String),
|
||||
}
|
||||
|
||||
impl Message for ActorMessage {
|
||||
type Result = ActorResponse;
|
||||
}
|
||||
|
||||
/// Wrapper type for `Pin<Box<dyn Future>>` type
|
||||
pub type Output<T> = Pin<Box<dyn Future<Output = Result<T, ARError>> + Send>>;
|
||||
|
||||
/// Represents data returned in response to `Messages` by a `StoreActor`
|
||||
pub enum ActorResponse {
|
||||
/// Returned in response to [Messages::Get](enum.Messages.html)
|
||||
Get(Output<Option<usize>>),
|
||||
/// Returned in response to [Messages::Set](enum.Messages.html)
|
||||
Set(Output<()>),
|
||||
/// Returned in response to [Messages::Update](enum.Messages.html)
|
||||
Update(Output<usize>),
|
||||
/// Returned in response to [Messages::Expire](enum.Messages.html)
|
||||
Expire(Output<Duration>),
|
||||
/// Returned in response to [Messages::Remove](enum.Messages.html)
|
||||
Remove(Output<usize>),
|
||||
}
|
||||
|
||||
impl<A, M> MessageResponse<A, M> for ActorResponse
|
||||
where
|
||||
A: Actor,
|
||||
M: actix::Message<Result = ActorResponse>,
|
||||
{
|
||||
fn handle(self, _: &mut A::Context, tx: Option<OneshotSender<Self>>) {
|
||||
if let Some(tx) = tx {
|
||||
let _ = tx.send(self);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -129,6 +129,8 @@ pub enum ApiError {
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Resource not found")]
|
||||
NotFound,
|
||||
#[error("You are being rate-limited. Please wait {0} milliseconds. 0/{1} remaining.")]
|
||||
RateLimitError(u128, u32),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
@ -160,6 +162,7 @@ impl ApiError {
|
||||
ApiError::NotFound => "not_found",
|
||||
ApiError::Zip(..) => "zip_error",
|
||||
ApiError::Io(..) => "io_error",
|
||||
ApiError::RateLimitError(..) => "ratelimit_error",
|
||||
},
|
||||
description: self.to_string(),
|
||||
}
|
||||
@ -194,6 +197,7 @@ impl actix_web::ResponseError for ApiError {
|
||||
ApiError::NotFound => StatusCode::NOT_FOUND,
|
||||
ApiError::Zip(..) => StatusCode::BAD_REQUEST,
|
||||
ApiError::Io(..) => StatusCode::BAD_REQUEST,
|
||||
ApiError::RateLimitError(..) => StatusCode::TOO_MANY_REQUESTS,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -19,9 +19,9 @@ impl Scheduler {
|
||||
}
|
||||
|
||||
pub fn run<F, R>(&mut self, interval: std::time::Duration, mut task: F)
|
||||
where
|
||||
F: FnMut() -> R + Send + 'static,
|
||||
R: std::future::Future<Output = ()> + Send + 'static,
|
||||
where
|
||||
F: FnMut() -> R + Send + 'static,
|
||||
R: std::future::Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let future = IntervalStream::new(actix_rt::time::interval(interval))
|
||||
.for_each_concurrent(2, move |_| task());
|
||||
@ -207,4 +207,4 @@ async fn update_versions(
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -7,6 +7,7 @@ pub mod env;
|
||||
pub mod ext;
|
||||
pub mod guards;
|
||||
pub mod img;
|
||||
pub mod ratelimit;
|
||||
pub mod redis;
|
||||
pub mod routes;
|
||||
pub mod validate;
|
||||
|
||||
167
src/util/ratelimit.rs
Normal file
167
src/util/ratelimit.rs
Normal file
@ -0,0 +1,167 @@
|
||||
use governor::clock::{Clock, DefaultClock};
|
||||
use governor::{middleware, state, RateLimiter};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::routes::ApiError;
|
||||
use crate::util::env::parse_var;
|
||||
use actix_web::{
|
||||
body::EitherBody,
|
||||
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
|
||||
Error, ResponseError,
|
||||
};
|
||||
use futures_util::future::LocalBoxFuture;
|
||||
use futures_util::future::{ready, Ready};
|
||||
|
||||
pub type KeyedRateLimiter<K = String, MW = middleware::StateInformationMiddleware> =
|
||||
Arc<RateLimiter<K, state::keyed::DefaultKeyedStateStore<K>, DefaultClock, MW>>;
|
||||
|
||||
pub struct RateLimit(pub KeyedRateLimiter);
|
||||
|
||||
impl<S, B> Transform<S, ServiceRequest> for RateLimit
|
||||
where
|
||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
||||
S::Future: 'static,
|
||||
B: 'static,
|
||||
{
|
||||
type Response = ServiceResponse<EitherBody<B>>;
|
||||
type Error = Error;
|
||||
type Transform = RateLimitService<S>;
|
||||
type InitError = ();
|
||||
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
||||
|
||||
fn new_transform(&self, service: S) -> Self::Future {
|
||||
ready(Ok(RateLimitService {
|
||||
service,
|
||||
rate_limiter: Arc::clone(&self.0),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub struct RateLimitService<S> {
|
||||
service: S,
|
||||
rate_limiter: KeyedRateLimiter,
|
||||
}
|
||||
|
||||
impl<S, B> Service<ServiceRequest> for RateLimitService<S>
|
||||
where
|
||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
||||
S::Future: 'static,
|
||||
B: 'static,
|
||||
{
|
||||
type Response = ServiceResponse<EitherBody<B>>;
|
||||
type Error = Error;
|
||||
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
forward_ready!(service);
|
||||
|
||||
fn call(&self, req: ServiceRequest) -> Self::Future {
|
||||
if let Some(key) = req.headers().get("x-ratelimit-key") {
|
||||
if key.to_str().ok() == dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok().as_deref() {
|
||||
let res = self.service.call(req);
|
||||
|
||||
return Box::pin(async move {
|
||||
let service_response = res.await?;
|
||||
Ok(service_response.map_into_left_body())
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let conn_info = req.connection_info().clone();
|
||||
let ip = if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) {
|
||||
if let Some(header) = req.headers().get("CF-Connecting-IP") {
|
||||
header.to_str().ok()
|
||||
} else {
|
||||
conn_info.peer_addr()
|
||||
}
|
||||
} else {
|
||||
conn_info.peer_addr()
|
||||
};
|
||||
|
||||
if let Some(ip) = ip {
|
||||
let ip = ip.to_string();
|
||||
|
||||
match self.rate_limiter.check_key(&ip) {
|
||||
Ok(snapshot) => {
|
||||
let fut = self.service.call(req);
|
||||
|
||||
Box::pin(async move {
|
||||
match fut.await {
|
||||
Ok(mut service_response) => {
|
||||
// Now you have a mutable reference to the ServiceResponse, so you can modify its headers.
|
||||
let headers = service_response.headers_mut();
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-limit",
|
||||
)
|
||||
.unwrap(),
|
||||
snapshot.quota().burst_size().get().into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-remaining",
|
||||
)
|
||||
.unwrap(),
|
||||
snapshot.remaining_burst_capacity().into(),
|
||||
);
|
||||
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str(
|
||||
"x-ratelimit-reset",
|
||||
)
|
||||
.unwrap(),
|
||||
snapshot
|
||||
.quota()
|
||||
.burst_size_replenished_in()
|
||||
.as_secs()
|
||||
.into(),
|
||||
);
|
||||
|
||||
// Return the modified response as Ok.
|
||||
Ok(service_response.map_into_left_body())
|
||||
}
|
||||
Err(e) => {
|
||||
// Handle error case
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Err(negative) => {
|
||||
let wait_time = negative.wait_time_from(DefaultClock::default().now());
|
||||
|
||||
let mut response = ApiError::RateLimitError(
|
||||
wait_time.as_millis(),
|
||||
negative.quota().burst_size().get(),
|
||||
)
|
||||
.error_response();
|
||||
|
||||
let headers = response.headers_mut();
|
||||
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str("x-ratelimit-limit").unwrap(),
|
||||
negative.quota().burst_size().get().into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str("x-ratelimit-remaining")
|
||||
.unwrap(),
|
||||
0.into(),
|
||||
);
|
||||
headers.insert(
|
||||
actix_web::http::header::HeaderName::from_str("x-ratelimit-reset").unwrap(),
|
||||
wait_time.as_secs().into(),
|
||||
);
|
||||
|
||||
Box::pin(async { Ok(req.into_response(response.map_into_right_body())) })
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let response =
|
||||
ApiError::CustomAuthentication("Unable to obtain user IP address!".to_string())
|
||||
.error_response();
|
||||
|
||||
Box::pin(async { Ok(req.into_response(response.map_into_right_body())) })
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user