Finish authentication (#659)

This commit is contained in:
Geometrically 2023-07-18 15:02:54 -07:00 committed by GitHub
parent ec80c2b9db
commit 4bb47d7e01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 217 additions and 75 deletions

3
.env
View File

@ -80,3 +80,6 @@ SMTP_HOST=none
SITE_VERIFY_EMAIL_PATH=none SITE_VERIFY_EMAIL_PATH=none
SITE_RESET_PASSWORD_PATH=none SITE_RESET_PASSWORD_PATH=none
BEEHIIV_PUBLICATION_ID=none
BEEHIIV_API_KEY=none

View File

@ -45,7 +45,8 @@ pub fn config(cfg: &mut ServiceConfig) {
.service(change_password) .service(change_password)
.service(resend_verify_email) .service(resend_verify_email)
.service(set_email) .service(set_email)
.service(verify_email), .service(verify_email)
.service(subscribe_newsletter),
); );
} }
@ -1022,11 +1023,17 @@ pub async fn auth_callback(
let session = issue_session(req, user_id, &mut transaction, &redis).await?; let session = issue_session(req, user_id, &mut transaction, &redis).await?;
transaction.commit().await?; transaction.commit().await?;
let redirect_url = if url.contains('?') { let redirect_url = format!(
format!("{}&code={}", url, session.session) "{}{}code={}{}",
} else { url,
format!("{}?code={}", url, session.session) if url.contains('?') { '&' } else { '?' },
}; session.session,
if user_id_opt.is_none() {
"&new_account=true"
} else {
""
}
);
Ok(HttpResponse::TemporaryRedirect() Ok(HttpResponse::TemporaryRedirect()
.append_header(("Location", &*redirect_url)) .append_header(("Location", &*redirect_url))
@ -1091,6 +1098,32 @@ pub async fn delete_auth_provider(
Ok(HttpResponse::NoContent().finish()) Ok(HttpResponse::NoContent().finish())
} }
pub async fn sign_up_beehiiv(email: &str) -> Result<(), AuthenticationError> {
let id = dotenvy::var("BEEHIIV_PUBLICATION_ID")?;
let api_key = dotenvy::var("BEEHIIV_API_KEY")?;
let site_url = dotenvy::var("SITE_URL")?;
let client = reqwest::Client::new();
client
.post(&format!(
"https://api.beehiiv.com/v2/publications/{id}/subscriptions"
))
.header(AUTHORIZATION, format!("Bearer {}", api_key))
.json(&serde_json::json!({
"email": email,
"utm_source": "modrinth",
"utm_medium": "account_creation",
"referring_site": site_url,
}))
.send()
.await?
.error_for_status()?
.text()
.await?;
Ok(())
}
#[derive(Deserialize, Validate)] #[derive(Deserialize, Validate)]
pub struct NewAccount { pub struct NewAccount {
#[validate(length(min = 1, max = 39), regex = "RE_URL_SAFE")] #[validate(length(min = 1, max = 39), regex = "RE_URL_SAFE")]
@ -1100,6 +1133,7 @@ pub struct NewAccount {
#[validate(email)] #[validate(email)]
pub email: String, pub email: String,
pub challenge: String, pub challenge: String,
pub sign_up_newsletter: Option<bool>,
} }
#[post("create")] #[post("create")]
@ -1170,7 +1204,7 @@ pub async fn create_account_with_password(
send_email_verify( send_email_verify(
new_account.email.clone(), new_account.email.clone(),
flow, flow,
&format!("Welcome to Modritnh, {}!", new_account.username), &format!("Welcome to Modrinth, {}!", new_account.username),
)?; )?;
crate::database::models::User { crate::database::models::User {
@ -1185,7 +1219,7 @@ pub async fn create_account_with_password(
totp_secret: None, totp_secret: None,
username: new_account.username.clone(), username: new_account.username.clone(),
name: Some(new_account.username), name: Some(new_account.username),
email: Some(new_account.email), email: Some(new_account.email.clone()),
email_verified: false, email_verified: false,
avatar_url: None, avatar_url: None,
bio: None, bio: None,
@ -1201,7 +1235,12 @@ pub async fn create_account_with_password(
.await?; .await?;
let session = issue_session(req, user_id, &mut transaction, &redis).await?; let session = issue_session(req, user_id, &mut transaction, &redis).await?;
let res = crate::models::sessions::Session::from(session, true); let res = crate::models::sessions::Session::from(session, true, None);
if new_account.sign_up_newsletter.unwrap_or(false) {
sign_up_beehiiv(&new_account.email).await?;
}
transaction.commit().await?; transaction.commit().await?;
Ok(HttpResponse::Ok().json(res)) Ok(HttpResponse::Ok().json(res))
@ -1264,7 +1303,7 @@ pub async fn login_password(
} else { } else {
let mut transaction = pool.begin().await?; let mut transaction = pool.begin().await?;
let session = issue_session(req, user.id, &mut transaction, &redis).await?; let session = issue_session(req, user.id, &mut transaction, &redis).await?;
let res = crate::models::sessions::Session::from(session, true); let res = crate::models::sessions::Session::from(session, true, None);
transaction.commit().await?; transaction.commit().await?;
Ok(HttpResponse::Ok().json(res)) Ok(HttpResponse::Ok().json(res))
@ -1277,7 +1316,15 @@ pub struct Login2FA {
pub flow: String, pub flow: String,
} }
fn get_2fa_code(secret: String) -> Result<String, AuthenticationError> { async fn validate_2fa_code(
input: String,
secret: String,
allow_backup: bool,
user_id: crate::database::models::UserId,
redis: &deadpool_redis::Pool,
pool: &PgPool,
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
) -> Result<bool, AuthenticationError> {
let totp = totp_rs::TOTP::new( let totp = totp_rs::TOTP::new(
totp_rs::Algorithm::SHA1, totp_rs::Algorithm::SHA1,
6, 6,
@ -1292,7 +1339,34 @@ fn get_2fa_code(secret: String) -> Result<String, AuthenticationError> {
.generate_current() .generate_current()
.map_err(|_| AuthenticationError::InvalidCredentials)?; .map_err(|_| AuthenticationError::InvalidCredentials)?;
Ok(token) if input == token {
Ok(true)
} else if allow_backup {
let backup_codes = crate::database::models::User::get_backup_codes(user_id, pool).await?;
if !backup_codes.contains(&input) {
Ok(false)
} else {
let code = parse_base62(&input).unwrap_or_default();
sqlx::query!(
"
DELETE FROM user_backup_codes
WHERE user_id = $1 AND code = $2
",
user_id as crate::database::models::ids::UserId,
code as i64,
)
.execute(&mut *transaction)
.await?;
crate::database::models::User::clear_caches(&[(user_id, None)], redis).await?;
Ok(true)
}
} else {
Err(AuthenticationError::InvalidCredentials)
}
} }
#[post("login/2fa")] #[post("login/2fa")]
@ -1311,41 +1385,27 @@ pub async fn login_2fa(
.await? .await?
.ok_or_else(|| AuthenticationError::InvalidCredentials)?; .ok_or_else(|| AuthenticationError::InvalidCredentials)?;
let token = get_2fa_code( let mut transaction = pool.begin().await?;
if !validate_2fa_code(
login.code.clone(),
user.totp_secret user.totp_secret
.ok_or_else(|| AuthenticationError::InvalidCredentials)?, .ok_or_else(|| AuthenticationError::InvalidCredentials)?,
)?; true,
user.id,
let mut transaction = pool.begin().await?; &redis,
if token != login.code { &pool,
let backup_codes = &mut transaction,
crate::database::models::User::get_backup_codes(user_id, &**pool).await?; )
.await?
if !backup_codes.contains(&login.code) { {
return Err(ApiError::Authentication( return Err(ApiError::Authentication(
AuthenticationError::InvalidCredentials, AuthenticationError::InvalidCredentials,
)); ));
} else {
let code = parse_base62(&login.code).unwrap_or_default();
sqlx::query!(
"
DELETE FROM user_backup_codes
WHERE user_id = $1 AND code = $2
",
user_id as crate::database::models::ids::UserId,
code as i64,
)
.execute(&mut *transaction)
.await?;
crate::database::models::User::clear_caches(&[(user_id, None)], &redis).await?;
}
} }
Flow::remove(&login.flow, &redis).await?; Flow::remove(&login.flow, &redis).await?;
let session = issue_session(req, user_id, &mut transaction, &redis).await?; let session = issue_session(req, user_id, &mut transaction, &redis).await?;
let res = crate::models::sessions::Session::from(session, true); let res = crate::models::sessions::Session::from(session, true, None);
transaction.commit().await?; transaction.commit().await?;
Ok(HttpResponse::Ok().json(res)) Ok(HttpResponse::Ok().json(res))
@ -1424,16 +1484,25 @@ pub async fn finish_2fa_flow(
)); ));
} }
let token = get_2fa_code(secret.clone())?; let mut transaction = pool.begin().await?;
if token != login.code { if !validate_2fa_code(
login.code.clone(),
secret.clone(),
false,
user.id.into(),
&redis,
&pool,
&mut transaction,
)
.await?
{
return Err(ApiError::Authentication( return Err(ApiError::Authentication(
AuthenticationError::InvalidCredentials, AuthenticationError::InvalidCredentials,
)); ));
} }
Flow::remove(&login.flow, &redis).await?;
let mut transaction = pool.begin().await?; Flow::remove(&login.flow, &redis).await?;
sqlx::query!( sqlx::query!(
" "
@ -1528,18 +1597,26 @@ pub async fn remove_2fa(
)); ));
} }
let token = get_2fa_code(user.totp_secret.ok_or_else(|| { let mut transaction = pool.begin().await?;
ApiError::InvalidInput("User does not have 2FA enabled on the account!".to_string())
})?)?;
if token != login.code { if !validate_2fa_code(
login.code.clone(),
user.totp_secret.ok_or_else(|| {
ApiError::InvalidInput("User does not have 2FA enabled on the account!".to_string())
})?,
true,
user.id,
&redis,
&pool,
&mut transaction,
)
.await?
{
return Err(ApiError::Authentication( return Err(ApiError::Authentication(
AuthenticationError::InvalidCredentials, AuthenticationError::InvalidCredentials,
)); ));
} }
let mut transaction = pool.begin().await?;
sqlx::query!( sqlx::query!(
" "
UPDATE users UPDATE users
@ -1930,6 +2007,34 @@ pub async fn verify_email(
} }
} }
#[post("email/subscribe")]
pub async fn subscribe_newsletter(
req: HttpRequest,
pool: Data<PgPool>,
redis: Data<deadpool_redis::Pool>,
session_queue: Data<AuthQueue>,
) -> Result<HttpResponse, ApiError> {
let user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::USER_AUTH_WRITE]),
)
.await?
.1;
if let Some(email) = user.email {
sign_up_beehiiv(&email).await?;
Ok(HttpResponse::NoContent().finish())
} else {
Err(ApiError::InvalidInput(
"User does not have an email.".to_string(),
))
}
}
fn send_email_verify( fn send_email_verify(
email: String, email: String,
flow: String, flow: String,

View File

@ -27,7 +27,7 @@ pub enum AuthenticationError {
Database(#[from] crate::database::models::DatabaseError), Database(#[from] crate::database::models::DatabaseError),
#[error("Error while parsing JSON: {0}")] #[error("Error while parsing JSON: {0}")]
SerDe(#[from] serde_json::Error), SerDe(#[from] serde_json::Error),
#[error("Error while communicating to external oauth provider")] #[error("Error while communicating to external provider")]
Reqwest(#[from] reqwest::Error), Reqwest(#[from] reqwest::Error),
#[error("Error uploading user profile picture")] #[error("Error uploading user profile picture")]
FileHosting(#[from] FileHostingError), FileHosting(#[from] FileHostingError),

View File

@ -115,6 +115,16 @@ pub async fn issue_session(
.await? .await?
.ok_or_else(|| AuthenticationError::InvalidCredentials)?; .ok_or_else(|| AuthenticationError::InvalidCredentials)?;
DBSession::clear_cache(
vec![(
Some(session.id),
Some(session.session.clone()),
Some(session.user_id),
)],
redis,
)
.await?;
Ok(session) Ok(session)
} }
@ -135,12 +145,18 @@ pub async fn list(
.await? .await?
.1; .1;
let session = req
.headers()
.get(AUTHORIZATION)
.and_then(|x| x.to_str().ok())
.ok_or_else(|| AuthenticationError::InvalidCredentials)?;
let session_ids = DBSession::get_user_sessions(current_user.id.into(), &**pool, &redis).await?; let session_ids = DBSession::get_user_sessions(current_user.id.into(), &**pool, &redis).await?;
let sessions = DBSession::get_many_ids(&session_ids, &**pool, &redis) let sessions = DBSession::get_many_ids(&session_ids, &**pool, &redis)
.await? .await?
.into_iter() .into_iter()
.filter(|x| x.expires > Utc::now()) .filter(|x| x.expires > Utc::now())
.map(|x| Session::from(x, false)) .map(|x| Session::from(x, false, Some(session)))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Ok(HttpResponse::Ok().json(sessions)) Ok(HttpResponse::Ok().json(sessions))
@ -227,7 +243,7 @@ pub async fn refresh(
transaction.commit().await?; transaction.commit().await?;
Ok(HttpResponse::Ok().json(Session::from(new_session, true))) Ok(HttpResponse::Ok().json(Session::from(new_session, true, None)))
} else { } else {
Err(ApiError::Authentication( Err(ApiError::Authentication(
AuthenticationError::InvalidCredentials, AuthenticationError::InvalidCredentials,

View File

@ -182,7 +182,7 @@ pub struct ReportTypeId(pub i32);
#[sqlx(transparent)] #[sqlx(transparent)]
pub struct FileId(pub i64); pub struct FileId(pub i64);
#[derive(Copy, Clone, Debug, Type, Deserialize, Serialize)] #[derive(Copy, Clone, Debug, Type, Deserialize, Serialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)] #[sqlx(transparent)]
pub struct PatId(pub i64); pub struct PatId(pub i64);
@ -200,7 +200,7 @@ pub struct ThreadId(pub i64);
#[sqlx(transparent)] #[sqlx(transparent)]
pub struct ThreadMessageId(pub i64); pub struct ThreadMessageId(pub i64);
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize)] #[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)] #[sqlx(transparent)]
pub struct SessionId(pub i64); pub struct SessionId(pub i64);

View File

@ -146,7 +146,7 @@ impl PersonalAccessToken {
} }
if !remaining_strings.is_empty() { if !remaining_strings.is_empty() {
let pat_ids_parsed: Vec<i64> = pat_strings let pat_ids_parsed: Vec<i64> = remaining_strings
.iter() .iter()
.flat_map(|x| parse_base62(&x.to_string()).ok()) .flat_map(|x| parse_base62(&x.to_string()).ok())
.map(|x| x as i64) .map(|x| x as i64)
@ -159,7 +159,7 @@ impl PersonalAccessToken {
ORDER BY created DESC ORDER BY created DESC
", ",
&pat_ids_parsed, &pat_ids_parsed,
&pat_strings &remaining_strings
.into_iter() .into_iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
@ -214,8 +214,9 @@ impl PersonalAccessToken {
let mut redis = redis.get().await?; let mut redis = redis.get().await?;
let res = cmd("GET") let res = cmd("GET")
.arg(format!("{}:{}", PATS_USERS_NAMESPACE, user_id.0)) .arg(format!("{}:{}", PATS_USERS_NAMESPACE, user_id.0))
.query_async::<_, Option<Vec<i64>>>(&mut redis) .query_async::<_, Option<String>>(&mut redis)
.await?; .await?
.and_then(|x| serde_json::from_str::<Vec<i64>>(&x).ok());
if let Some(res) = res { if let Some(res) = res {
return Ok(res.into_iter().map(PatId).collect()); return Ok(res.into_iter().map(PatId).collect());
@ -251,6 +252,10 @@ impl PersonalAccessToken {
clear_pats: Vec<(Option<PatId>, Option<String>, Option<UserId>)>, clear_pats: Vec<(Option<PatId>, Option<String>, Option<UserId>)>,
redis: &deadpool_redis::Pool, redis: &deadpool_redis::Pool,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
if clear_pats.is_empty() {
return Ok(());
}
let mut redis = redis.get().await?; let mut redis = redis.get().await?;
let mut cmd = cmd("DEL"); let mut cmd = cmd("DEL");

View File

@ -187,7 +187,7 @@ impl Session {
} }
if !remaining_strings.is_empty() { if !remaining_strings.is_empty() {
let session_ids_parsed: Vec<i64> = session_strings let session_ids_parsed: Vec<i64> = remaining_strings
.iter() .iter()
.flat_map(|x| parse_base62(&x.to_string()).ok()) .flat_map(|x| parse_base62(&x.to_string()).ok())
.map(|x| x as i64) .map(|x| x as i64)
@ -201,7 +201,7 @@ impl Session {
ORDER BY created DESC ORDER BY created DESC
", ",
&session_ids_parsed, &session_ids_parsed,
&session_strings.into_iter().map(|x| x.to_string()).collect::<Vec<_>>(), &remaining_strings.into_iter().map(|x| x.to_string()).collect::<Vec<_>>(),
) )
.fetch_many(exec) .fetch_many(exec)
.try_filter_map(|e| async { .try_filter_map(|e| async {
@ -258,8 +258,9 @@ impl Session {
let mut redis = redis.get().await?; let mut redis = redis.get().await?;
let res = cmd("GET") let res = cmd("GET")
.arg(format!("{}:{}", SESSIONS_USERS_NAMESPACE, user_id.0)) .arg(format!("{}:{}", SESSIONS_USERS_NAMESPACE, user_id.0))
.query_async::<_, Option<Vec<i64>>>(&mut redis) .query_async::<_, Option<String>>(&mut redis)
.await?; .await?
.and_then(|x| serde_json::from_str::<Vec<i64>>(&x).ok());
if let Some(res) = res { if let Some(res) = res {
return Ok(res.into_iter().map(SessionId).collect()); return Ok(res.into_iter().map(SessionId).collect());
@ -295,6 +296,10 @@ impl Session {
clear_sessions: Vec<(Option<SessionId>, Option<String>, Option<UserId>)>, clear_sessions: Vec<(Option<SessionId>, Option<String>, Option<UserId>)>,
redis: &deadpool_redis::Pool, redis: &deadpool_redis::Pool,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
if clear_sessions.is_empty() {
return Ok(());
}
let mut redis = redis.get().await?; let mut redis = redis.get().await?;
let mut cmd = cmd("DEL"); let mut cmd = cmd("DEL");

View File

@ -459,5 +459,8 @@ fn check_env_vars() -> bool {
failed |= check_var::<String>("SITE_VERIFY_EMAIL_PATH"); failed |= check_var::<String>("SITE_VERIFY_EMAIL_PATH");
failed |= check_var::<String>("SITE_RESET_PASSWORD_PATH"); failed |= check_var::<String>("SITE_RESET_PASSWORD_PATH");
failed |= check_var::<String>("BEEHIIV_PUBLICATION_ID");
failed |= check_var::<String>("BEEHIIV_API_KEY");
failed failed
} }

View File

@ -26,15 +26,19 @@ pub struct Session {
pub city: Option<String>, pub city: Option<String>,
pub country: Option<String>, pub country: Option<String>,
pub ip: String, pub ip: String,
pub current: bool,
} }
impl Session { impl Session {
pub fn from( pub fn from(
data: crate::database::models::session_item::Session, data: crate::database::models::session_item::Session,
include_session: bool, include_session: bool,
current_session: Option<&str>,
) -> Self { ) -> Self {
Session { Session {
id: data.id.into(), id: data.id.into(),
current: Some(&*data.session) == current_session,
session: if include_session { session: if include_session {
Some(data.session) Some(data.session)
} else { } else {

View File

@ -4,41 +4,42 @@ use crate::database::models::session_item::Session;
use crate::database::models::{DatabaseError, PatId, SessionId, UserId}; use crate::database::models::{DatabaseError, PatId, SessionId, UserId};
use chrono::Utc; use chrono::Utc;
use sqlx::PgPool; use sqlx::PgPool;
use std::collections::{HashMap, HashSet};
use tokio::sync::Mutex; use tokio::sync::Mutex;
pub struct AuthQueue { pub struct AuthQueue {
session_queue: Mutex<Vec<(SessionId, SessionMetadata)>>, session_queue: Mutex<HashMap<SessionId, SessionMetadata>>,
pat_queue: Mutex<Vec<PatId>>, pat_queue: Mutex<HashSet<PatId>>,
} }
// Batches session accessing transactions every 30 seconds // Batches session accessing transactions every 30 seconds
impl AuthQueue { impl AuthQueue {
pub fn new() -> Self { pub fn new() -> Self {
AuthQueue { AuthQueue {
session_queue: Mutex::new(Vec::with_capacity(1000)), session_queue: Mutex::new(HashMap::with_capacity(1000)),
pat_queue: Mutex::new(Vec::with_capacity(1000)), pat_queue: Mutex::new(HashSet::with_capacity(1000)),
} }
} }
pub async fn add_session(&self, id: SessionId, metadata: SessionMetadata) { pub async fn add_session(&self, id: SessionId, metadata: SessionMetadata) {
self.session_queue.lock().await.push((id, metadata)); self.session_queue.lock().await.insert(id, metadata);
} }
pub async fn add_pat(&self, id: PatId) { pub async fn add_pat(&self, id: PatId) {
self.pat_queue.lock().await.push(id); self.pat_queue.lock().await.insert(id);
} }
pub async fn take_sessions(&self) -> Vec<(SessionId, SessionMetadata)> { pub async fn take_sessions(&self) -> HashMap<SessionId, SessionMetadata> {
let mut queue = self.session_queue.lock().await; let mut queue = self.session_queue.lock().await;
let len = queue.len(); let len = queue.len();
std::mem::replace(&mut queue, Vec::with_capacity(len)) std::mem::replace(&mut queue, HashMap::with_capacity(len))
} }
pub async fn take_pats(&self) -> Vec<PatId> { pub async fn take_pats(&self) -> HashSet<PatId> {
let mut queue = self.pat_queue.lock().await; let mut queue = self.pat_queue.lock().await;
let len = queue.len(); let len = queue.len();
std::mem::replace(&mut queue, Vec::with_capacity(len)) std::mem::replace(&mut queue, HashSet::with_capacity(len))
} }
pub async fn index( pub async fn index(