diff --git a/src/auth/oauth/errors.rs b/src/auth/oauth/errors.rs index 2b82da351..72a65abb3 100644 --- a/src/auth/oauth/errors.rs +++ b/src/auth/oauth/errors.rs @@ -2,7 +2,7 @@ use super::ValidatedRedirectUri; use crate::auth::AuthenticationError; use crate::models::error::ApiError; use crate::models::ids::DecodingError; -use actix_web::http::StatusCode; +use actix_web::http::{header::LOCATION, StatusCode}; use actix_web::HttpResponse; #[derive(thiserror::Error, Debug)] @@ -63,7 +63,7 @@ impl actix_web::ResponseError for OAuthError { | OAuthErrorType::ScopesTooBroad | OAuthErrorType::AccessDenied => { if self.valid_redirect_uri.is_some() { - StatusCode::FOUND + StatusCode::OK } else { StatusCode::INTERNAL_SERVER_ERROR } @@ -94,10 +94,9 @@ impl actix_web::ResponseError for OAuthError { redirect_uri = format!("{}&state={}", redirect_uri, state); } - redirect_uri = urlencoding::encode(&redirect_uri).to_string(); - HttpResponse::Found() - .append_header(("Location".to_string(), redirect_uri)) - .finish() + HttpResponse::Ok() + .append_header((LOCATION, redirect_uri.clone())) + .body(redirect_uri) } else { HttpResponse::build(self.status_code()).json(ApiError { error: &self.error_type.error_name(), diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index 0d64b53f5..51a26cc69 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -7,19 +7,21 @@ use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; use crate::database::models::oauth_token_item::OAuthAccessToken; use crate::database::models::{ generate_oauth_access_token_id, generate_oauth_client_authorization_id, - OAuthClientAuthorizationId, OAuthClientId, + OAuthClientAuthorizationId, }; use crate::database::redis::RedisPool; use crate::models; +use crate::models::ids::OAuthClientId; use crate::models::pats::Scopes; use crate::queue::session::AuthQueue; +use actix_web::http::header::LOCATION; use actix_web::web::{scope, Data, Query, ServiceConfig}; use actix_web::{get, post, web, HttpRequest, HttpResponse}; use chrono::Duration; use rand::distributions::Alphanumeric; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; -use reqwest::header::{CACHE_CONTROL, LOCATION, PRAGMA}; +use reqwest::header::{CACHE_CONTROL, PRAGMA}; use serde::{Deserialize, Serialize}; use sqlx::postgres::PgPool; @@ -75,7 +77,7 @@ pub async fn init_oauth( .await? .1; - let client_id = oauth_info.client_id; + let client_id = oauth_info.client_id.into(); let client = DBOAuthClient::get(client_id, &**pool).await?; if let Some(client) = client { @@ -118,7 +120,7 @@ pub async fn init_oauth( { init_oauth_code_flow( user.id.into(), - client.id, + client.id.into(), existing_authorization.id, requested_scopes, redirect_uris, @@ -141,7 +143,7 @@ pub async fn init_oauth( .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; let access_request = OAuthClientAccessRequest { - client_id: client.id, + client_id: client.id.into(), client_name: client.name, client_icon: client.icon_url, flow_id, @@ -341,7 +343,7 @@ pub async fn accept_or_reject_client_scopes( init_oauth_code_flow( user_id, - client_id, + client_id.into(), auth_id, scopes, redirect_uris, @@ -396,7 +398,7 @@ async fn init_oauth_code_flow( ) -> Result { let code = Flow::OAuthAuthorizationCodeSupplied { user_id, - client_id, + client_id: client_id.into(), authorization_id, scopes, original_redirect_uri: redirect_uris.original.clone(), @@ -413,9 +415,9 @@ async fn init_oauth_code_flow( let redirect_uri = append_params_to_uri(&redirect_uris.validated.0, &redirect_params); // IETF RFC 6749 Section 4.1.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2) - Ok(HttpResponse::Found() - .append_header((LOCATION, redirect_uri)) - .finish()) + Ok(HttpResponse::Ok() + .append_header((LOCATION, redirect_uri.clone())) + .body(redirect_uri)) } fn append_params_to_uri(uri: &str, params: &[impl AsRef]) -> String { diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs index 16795aa3e..f13eb97be 100644 --- a/src/models/oauth_clients.rs +++ b/src/models/oauth_clients.rs @@ -4,6 +4,7 @@ use super::{ }; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use serde_with::serde_as; use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization as DBOAuthClientAuthorization; use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; @@ -64,9 +65,13 @@ pub struct OAuthClientAuthorization { pub created: DateTime, } +#[serde_as] #[derive(Deserialize, Serialize)] pub struct GetOAuthClientsRequest { - pub ids: Vec, + #[serde_as( + as = "serde_with::StringWithSeparator::" + )] + pub ids: Vec, } #[derive(Deserialize, Serialize)] diff --git a/src/models/pats.rs b/src/models/pats.rs index 07a58692b..83f9b1c56 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -132,7 +132,9 @@ impl Scopes { } pub fn parse_from_oauth_scopes(scopes: &str) -> Result { - let scopes = scopes.replace(' ', "|").replace("%20", "|"); + let scopes = scopes + .replace(['+', ' '], "|") + .replace("%20", "|"); bitflags::parser::from_str(&scopes) } diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 5f3839b60..b04cb9bb2 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -15,7 +15,8 @@ use validator::Validate; use super::ApiError; use crate::{ - auth::checks::ValidateAllAuthorized, models::oauth_clients::DeleteOAuthClientQueryParam, + auth::checks::ValidateAllAuthorized, + models::{ids::base62_impl::parse_base62, oauth_clients::DeleteOAuthClientQueryParam}, }; use crate::{ auth::{checks::ValidateAuthorized, get_user_from_headers}, @@ -111,13 +112,19 @@ pub async fn get_client( #[get("apps")] pub async fn get_clients( req: HttpRequest, - info: web::Json, + info: web::Query, pool: web::Data, redis: web::Data, session_queue: web::Data, ) -> Result { - let clients = - get_clients_inner(&info.into_inner().ids, req, pool, redis, session_queue).await?; + let ids: Vec<_> = info + .ids + .iter() + .map(|id| parse_base62(id).map(ApiOAuthClientId)) + .collect::>()?; + + let clients = get_clients_inner(&ids, req, pool, redis, session_queue).await?; + Ok(HttpResponse::Ok().json(clients)) } diff --git a/tests/common/api_v3/oauth.rs b/tests/common/api_v3/oauth.rs index 6212dffac..ee78b5d93 100644 --- a/tests/common/api_v3/oauth.rs +++ b/tests/common/api_v3/oauth.rs @@ -125,7 +125,7 @@ pub async fn get_authorize_accept_flow_id(response: ServiceResponse) -> String { } pub async fn get_auth_code_from_redirect_params(response: &ServiceResponse) -> String { - assert_status(response, StatusCode::FOUND); + assert_status(response, StatusCode::OK); let query_params = get_redirect_location_query_params(response); query_params.get("code").unwrap().to_string() } @@ -140,7 +140,13 @@ pub async fn get_access_token(response: ServiceResponse) -> String { pub fn get_redirect_location_query_params( response: &ServiceResponse, ) -> actix_web::web::Query> { - let redirect_location = response.headers().get(LOCATION).unwrap().to_str().unwrap(); + let redirect_location = response + .headers() + .get(LOCATION) + .unwrap() + .to_str() + .unwrap() + .to_string(); actix_web::web::Query::>::from_query( redirect_location.split_once('?').unwrap().1, ) diff --git a/tests/oauth.rs b/tests/oauth.rs index 1ae59d32a..63deb036b 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -42,7 +42,7 @@ async fn oauth_flow_happy_path() { // Accept the authorization request let resp = env.v3.oauth_accept(&flow_id, FRIEND_USER_PAT).await; - assert_status(&resp, StatusCode::FOUND); + assert_status(&resp, StatusCode::OK); let query = get_redirect_location_query_params(&resp); let auth_code = query.get("code").unwrap(); @@ -105,7 +105,7 @@ async fn oauth_authorize_for_already_authorized_scopes_returns_auth_code() { USER_USER_PAT, ) .await; - assert_status(&resp, StatusCode::FOUND); + assert_status(&resp, StatusCode::OK); }) .await; } @@ -231,10 +231,10 @@ async fn reject_authorize_ends_authorize_flow() { let flow_id = get_authorize_accept_flow_id(resp).await; let resp = env.v3.oauth_reject(&flow_id, USER_USER_PAT).await; - assert_status(&resp, StatusCode::FOUND); + assert_status(&resp, StatusCode::OK); let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; - assert_any_status_except(&resp, StatusCode::FOUND); + assert_any_status_except(&resp, StatusCode::OK); }) .await; } @@ -249,7 +249,7 @@ async fn accept_authorize_after_already_accepting_fails() { .await; let flow_id = get_authorize_accept_flow_id(resp).await; let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; - assert_status(&resp, StatusCode::FOUND); + assert_status(&resp, StatusCode::OK); let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; assert_status(&resp, StatusCode::BAD_REQUEST);