From 6cfd4637db73cb773ef37da7e74f031a1d254f10 Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Mon, 30 Oct 2023 11:14:38 -0500 Subject: [PATCH] OAuth 2.0 Authorization Server [MOD-559] (#733) * WIP end-of-day push * Authorize endpoint, accept endpoints, DB stuff for oauth clients, their redirects, and client authorizations * OAuth Client create route * Get user clients * Client delete * Edit oauth client * Include redirects in edit client route * Database stuff for tokens * Reorg oauth stuff out of auth/flows and into its own module * Impl OAuth get access token endpoint * Accept oauth access tokens as auth and update through AuthQueue * User OAuth authorization management routes * Forgot to actually add the routes lol * Bit o cleanup * Happy path test for OAuth and minor fixes for things it found * Add dummy data oauth client (and detect/handle dummy data version changes) * More tests * Another test * More tests and reject endpoint * Test oauth client and authorization management routes * cargo sqlx prepare * dead code warning * Auto clippy fixes * Uri refactoring * minor name improvement * Don't compile-time check the test sqlx queries * Trying to fix db concurrency problem to get tests to pass * Try fix from test PR * Fixes for updated sqlx * Prevent restricted scopes from being requested or issued * Get OAuth client(s) * Remove joined oauth client info from authorization returns * Add default conversion to OAuthError::error so we can use ? * Rework routes * Consolidate scopes into SESSION_ACCESS * Cargo sqlx prepare * Parse to OAuthClientId automatically through serde and actix * Cargo clippy * Remove validation requiring 1 redirect URI on oauth client creation * Use serde(flatten) on OAuthClientCreationResult --- ...438d6222a26926f4b06865b84979fc92564ba.json | 15 - ...bab9da5fdd1b7ee72d411a9989deb4ee506bb.json | 15 + ...7734a9af0e944f1671df71f9f4e25d835ffd9.json | 22 + ...31044f41d210dd64bbbb5e7c2347acc2304e9.json | 22 + ...103cac7e4b850d446b29d2efd9757b642fc1c.json | 15 + ...cff1889e5490f0d0d62170ed2b9515ffc5104.json | 47 ++ ...3c2efd07808c4a859ab1b1e9e65e16439a8f3.json | 17 + ...e26e68b10522d0f1df3f006d58f6b72be9911.json | 32 ++ ...2876c5ae253df538f4cd4c3701e63137fb01b.json | 70 +++ ...a2593a1556ef214f3bed519de6b6e21c7d477.json | 70 +++ ...2278605311a311def3cbe38846b8ca465737f.json | 19 + ...b81f2f47ccdded1e764c04d8b7651d9796ce0.json | 16 + ...17ef99fae60e5fbff5f0396f70787156de322.json | 46 ++ ...7aaaaf533a4409633cf00c071049bb6816c96.json | 14 + ...f501784de06acdddeed1db8f570cb04755f1a.json | 22 + ...41c55bf62928a96d9582d3e223d6473335428.json | 15 + ...f3412349dd6bc5c836d2634bbcb376a6f7c12.json | 22 + ...b8f3133bea215a89964d78cb652f930465faf.json | 14 + ...23f75ffdaebc76d67a5d35218fa9273d46d53.json | 17 + ...549ab5c23306758168af38f955c06d251b0b7.json | 70 +++ migrations/20231016190056_oauth_provider.sql | 34 ++ src/auth/checks.rs | 36 ++ src/auth/mod.rs | 2 +- src/auth/oauth/errors.rs | 176 +++++++ src/auth/oauth/mod.rs | 430 +++++++++++++++++ src/auth/oauth/uris.rs | 94 ++++ src/auth/validate.rs | 35 +- src/database/models/flow_item.rs | 34 +- src/database/models/ids.rs | 68 +++ src/database/models/mod.rs | 4 + .../models/oauth_client_authorization_item.rs | 126 +++++ src/database/models/oauth_client_item.rs | 245 ++++++++++ src/database/models/oauth_token_item.rs | 95 ++++ src/models/ids.rs | 5 + src/models/mod.rs | 1 + src/models/oauth_clients.rs | 110 +++++ src/models/pats.rs | 78 +++ src/queue/session.rs | 77 ++- src/routes/v3/mod.rs | 8 +- src/routes/v3/oauth_clients.rs | 444 ++++++++++++++++++ src/util/validate.rs | 12 + tests/common/api_v2/project.rs | 18 +- tests/common/api_v2/team.rs | 18 +- tests/common/api_v3/mod.rs | 19 + tests/common/api_v3/oauth.rs | 156 ++++++ tests/common/api_v3/oauth_clients.rs | 107 +++++ tests/common/asserts.rs | 11 +- tests/common/database.rs | 58 ++- tests/common/dummy_data.rs | 172 ++++--- tests/common/environment.rs | 27 +- tests/common/mod.rs | 9 + tests/files/dummy_data.sql | 19 + tests/oauth.rs | 292 ++++++++++++ tests/oauth_clients.rs | 193 ++++++++ 54 files changed, 3658 insertions(+), 135 deletions(-) delete mode 100644 .sqlx/query-0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba.json create mode 100644 .sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json create mode 100644 .sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json create mode 100644 .sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json create mode 100644 .sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json create mode 100644 .sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json create mode 100644 .sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json create mode 100644 .sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json create mode 100644 .sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json create mode 100644 .sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json create mode 100644 .sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json create mode 100644 .sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json create mode 100644 .sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json create mode 100644 .sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json create mode 100644 .sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json create mode 100644 .sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json create mode 100644 .sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json create mode 100644 .sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json create mode 100644 .sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json create mode 100644 .sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json create mode 100644 migrations/20231016190056_oauth_provider.sql create mode 100644 src/auth/oauth/errors.rs create mode 100644 src/auth/oauth/mod.rs create mode 100644 src/auth/oauth/uris.rs create mode 100644 src/database/models/oauth_client_authorization_item.rs create mode 100644 src/database/models/oauth_client_item.rs create mode 100644 src/database/models/oauth_token_item.rs create mode 100644 src/models/oauth_clients.rs create mode 100644 src/routes/v3/oauth_clients.rs create mode 100644 tests/common/api_v3/mod.rs create mode 100644 tests/common/api_v3/oauth.rs create mode 100644 tests/common/api_v3/oauth_clients.rs create mode 100644 tests/oauth.rs create mode 100644 tests/oauth_clients.rs diff --git a/.sqlx/query-0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba.json b/.sqlx/query-0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba.json deleted file mode 100644 index f4c25dc7c..000000000 --- a/.sqlx/query-0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE pats\n SET last_used = $2\n WHERE (id = $1)\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Int8", - "Timestamptz" - ] - }, - "nullable": [] - }, - "hash": "0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba" -} diff --git a/.sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json b/.sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json new file mode 100644 index 000000000..42c7c646a --- /dev/null +++ b/.sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pats\n SET last_used = $2\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb" +} diff --git a/.sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json b/.sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json new file mode 100644 index 000000000..bd8065151 --- /dev/null +++ b/.sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT EXISTS(SELECT 1 FROM oauth_client_authorizations WHERE id=$1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9" +} diff --git a/.sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json b/.sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json new file mode 100644 index 000000000..7348f7595 --- /dev/null +++ b/.sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT EXISTS(SELECT 1 FROM oauth_client_redirect_uris WHERE id=$1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9" +} diff --git a/.sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json b/.sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json new file mode 100644 index 000000000..f94f7a466 --- /dev/null +++ b/.sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE oauth_access_tokens\n SET last_used = $2\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c" +} diff --git a/.sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json b/.sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json new file mode 100644 index 000000000..de61d14e0 --- /dev/null +++ b/.sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json @@ -0,0 +1,47 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, client_id, user_id, scopes, created\n FROM oauth_client_authorizations\n WHERE client_id=$1 AND user_id=$2\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "client_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "scopes", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "created", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104" +} diff --git a/.sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json b/.sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json new file mode 100644 index 000000000..bbf10def8 --- /dev/null +++ b/.sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO oauth_client_authorizations (\n id, client_id, user_id, scopes\n )\n VALUES (\n $1, $2, $3, $4\n )\n ON CONFLICT (id)\n DO UPDATE SET scopes = EXCLUDED.scopes\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3" +} diff --git a/.sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json b/.sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json new file mode 100644 index 000000000..797214d3b --- /dev/null +++ b/.sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json @@ -0,0 +1,32 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO oauth_access_tokens (\n id, authorization_id, token_hash, scopes, last_used\n )\n VALUES (\n $1, $2, $3, $4, $5\n )\n RETURNING created, expires\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "created", + "type_info": "Timestamptz" + }, + { + "ordinal": 1, + "name": "expires", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Text", + "Int8", + "Timestamptz" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911" +} diff --git a/.sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json b/.sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json new file mode 100644 index 000000000..454d523ce --- /dev/null +++ b/.sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json @@ -0,0 +1,70 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n tokens.id,\n tokens.authorization_id,\n tokens.token_hash,\n tokens.scopes,\n tokens.created,\n tokens.expires,\n tokens.last_used,\n auths.client_id,\n auths.user_id\n FROM oauth_access_tokens tokens\n JOIN oauth_client_authorizations auths\n ON tokens.authorization_id = auths.id\n WHERE tokens.token_hash = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "authorization_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "token_hash", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "scopes", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "created", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "expires", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "last_used", + "type_info": "Timestamptz" + }, + { + "ordinal": 7, + "name": "client_id", + "type_info": "Int8" + }, + { + "ordinal": 8, + "name": "user_id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + true, + false, + false + ] + }, + "hash": "7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b" +} diff --git a/.sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json b/.sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json new file mode 100644 index 000000000..9be628e85 --- /dev/null +++ b/.sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json @@ -0,0 +1,70 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n clients.id as \"id!\",\n clients.name as \"name!\",\n clients.icon_url as \"icon_url?\",\n clients.max_scopes as \"max_scopes!\",\n clients.secret_hash as \"secret_hash!\",\n clients.created as \"created!\",\n clients.created_by as \"created_by!\",\n uris.uri_ids as \"uri_ids?\",\n uris.uri_vals as \"uri_vals?\"\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE created_by = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id!", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "name!", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "icon_url?", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "max_scopes!", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "secret_hash!", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "created!", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "created_by!", + "type_info": "Int8" + }, + { + "ordinal": 7, + "name": "uri_ids?", + "type_info": "Int8Array" + }, + { + "ordinal": 8, + "name": "uri_vals?", + "type_info": "TextArray" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + true, + false, + false, + false, + false, + null, + null + ] + }, + "hash": "8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477" +} diff --git a/.sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json b/.sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json new file mode 100644 index 000000000..598dfeffc --- /dev/null +++ b/.sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json @@ -0,0 +1,19 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO oauth_clients (\n id, name, icon_url, max_scopes, secret_hash, created_by\n )\n VALUES (\n $1, $2, $3, $4, $5, $6\n )\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Text", + "Text", + "Int8", + "Text", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f" +} diff --git a/.sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json b/.sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json new file mode 100644 index 000000000..4c3c291ba --- /dev/null +++ b/.sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO oauth_client_redirect_uris (id, client_id, uri)\n SELECT * FROM UNNEST($1::bigint[], $2::bigint[], $3::varchar[])\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array", + "Int8Array", + "VarcharArray" + ] + }, + "nullable": [] + }, + "hash": "9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0" +} diff --git a/.sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json b/.sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json new file mode 100644 index 000000000..722f05896 --- /dev/null +++ b/.sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json @@ -0,0 +1,46 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, client_id, user_id, scopes, created\n FROM oauth_client_authorizations\n WHERE user_id=$1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "client_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "scopes", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "created", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322" +} diff --git a/.sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json b/.sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json new file mode 100644 index 000000000..0833383c0 --- /dev/null +++ b/.sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM oauth_client_redirect_uris\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array" + ] + }, + "nullable": [] + }, + "hash": "cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96" +} diff --git a/.sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json b/.sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json new file mode 100644 index 000000000..9a0a54921 --- /dev/null +++ b/.sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT EXISTS(SELECT 1 FROM oauth_clients WHERE id=$1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a" +} diff --git a/.sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json b/.sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json new file mode 100644 index 000000000..d80d7c906 --- /dev/null +++ b/.sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM oauth_client_authorizations\n WHERE client_id=$1 AND user_id=$2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428" +} diff --git a/.sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json b/.sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json new file mode 100644 index 000000000..5ed8687e5 --- /dev/null +++ b/.sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12" +} diff --git a/.sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json b/.sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json new file mode 100644 index 000000000..29200a656 --- /dev/null +++ b/.sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM oauth_clients\n WHERE id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [] + }, + "hash": "e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf" +} diff --git a/.sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json b/.sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json new file mode 100644 index 000000000..d01d58769 --- /dev/null +++ b/.sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE oauth_clients\n SET name = $1, icon_url = $2, max_scopes = $3\n WHERE (id = $4)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Text", + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53" +} diff --git a/.sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json b/.sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json new file mode 100644 index 000000000..08fa78f0b --- /dev/null +++ b/.sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json @@ -0,0 +1,70 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n clients.id as \"id!\",\n clients.name as \"name!\",\n clients.icon_url as \"icon_url?\",\n clients.max_scopes as \"max_scopes!\",\n clients.secret_hash as \"secret_hash!\",\n clients.created as \"created!\",\n clients.created_by as \"created_by!\",\n uris.uri_ids as \"uri_ids?\",\n uris.uri_vals as \"uri_vals?\"\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE clients.id = ANY($1::bigint[])", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id!", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "name!", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "icon_url?", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "max_scopes!", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "secret_hash!", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "created!", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "created_by!", + "type_info": "Int8" + }, + { + "ordinal": 7, + "name": "uri_ids?", + "type_info": "Int8Array" + }, + { + "ordinal": 8, + "name": "uri_vals?", + "type_info": "TextArray" + } + ], + "parameters": { + "Left": [ + "Int8Array" + ] + }, + "nullable": [ + true, + true, + true, + true, + true, + true, + true, + null, + null + ] + }, + "hash": "fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7" +} diff --git a/migrations/20231016190056_oauth_provider.sql b/migrations/20231016190056_oauth_provider.sql new file mode 100644 index 000000000..a6a9406e6 --- /dev/null +++ b/migrations/20231016190056_oauth_provider.sql @@ -0,0 +1,34 @@ +CREATE TABLE oauth_clients ( + id bigint PRIMARY KEY, + name text NOT NULL, + icon_url text NULL, + max_scopes bigint NOT NULL, + secret_hash text NOT NULL, + created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by bigint NOT NULL REFERENCES users(id) +); +CREATE TABLE oauth_client_redirect_uris ( + id bigint PRIMARY KEY, + client_id bigint NOT NULL REFERENCES oauth_clients (id) ON DELETE CASCADE, + uri text +); +CREATE TABLE oauth_client_authorizations ( + id bigint PRIMARY KEY, + client_id bigint NOT NULL REFERENCES oauth_clients (id) ON DELETE CASCADE, + user_id bigint NOT NULL REFERENCES users (id) ON DELETE CASCADE, + scopes bigint NOT NULL, + created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE (client_id, user_id) +); +CREATE TABLE oauth_access_tokens ( + id bigint PRIMARY KEY, + authorization_id bigint NOT NULL REFERENCES oauth_client_authorizations(id) ON DELETE CASCADE, + token_hash text NOT NULL UNIQUE, + scopes bigint NOT NULL, + created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP + interval '14 days', + last_used timestamptz NULL +); +CREATE INDEX oauth_client_creator ON oauth_clients(created_by); +CREATE INDEX oauth_redirect_client ON oauth_client_redirect_uris(client_id); +CREATE INDEX oauth_access_token_hash ON oauth_access_tokens(token_hash); \ No newline at end of file diff --git a/src/auth/checks.rs b/src/auth/checks.rs index b358494cf..4d47e72c3 100644 --- a/src/auth/checks.rs +++ b/src/auth/checks.rs @@ -8,6 +8,25 @@ use crate::routes::ApiError; use actix_web::web; use sqlx::PgPool; +pub trait ValidateAuthorized { + fn validate_authorized(&self, user_option: Option<&User>) -> Result<(), ApiError>; +} + +pub trait ValidateAllAuthorized { + fn validate_all_authorized(self, user_option: Option<&User>) -> Result<(), ApiError>; +} + +impl<'a, T, A> ValidateAllAuthorized for T +where + T: IntoIterator, + A: ValidateAuthorized + 'a, +{ + fn validate_all_authorized(self, user_option: Option<&User>) -> Result<(), ApiError> { + self.into_iter() + .try_for_each(|c| c.validate_authorized(user_option)) + } +} + pub async fn is_authorized( project_data: &Project, user_option: &Option, @@ -156,6 +175,23 @@ pub async fn is_authorized_version( Ok(authorized) } +impl ValidateAuthorized for crate::database::models::OAuthClient { + fn validate_authorized(&self, user_option: Option<&User>) -> Result<(), ApiError> { + if let Some(user) = user_option { + if user.role.is_mod() || user.id == self.created_by.into() { + return Ok(()); + } else { + return Err(crate::routes::ApiError::CustomAuthentication( + "You don't have sufficient permissions to interact with this OAuth application" + .to_string(), + )); + } + } + + Ok(()) + } +} + pub async fn filter_authorized_versions( versions: Vec, user_option: &Option, diff --git a/src/auth/mod.rs b/src/auth/mod.rs index eec82ec56..9c65c914b 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,11 +1,11 @@ pub mod checks; pub mod email; pub mod flows; +pub mod oauth; pub mod pats; pub mod session; mod templates; pub mod validate; - pub use checks::{ filter_authorized_projects, filter_authorized_versions, is_authorized, is_authorized_version, }; diff --git a/src/auth/oauth/errors.rs b/src/auth/oauth/errors.rs new file mode 100644 index 000000000..2b82da351 --- /dev/null +++ b/src/auth/oauth/errors.rs @@ -0,0 +1,176 @@ +use super::ValidatedRedirectUri; +use crate::auth::AuthenticationError; +use crate::models::error::ApiError; +use crate::models::ids::DecodingError; +use actix_web::http::StatusCode; +use actix_web::HttpResponse; + +#[derive(thiserror::Error, Debug)] +#[error("{}", .error_type)] +pub struct OAuthError { + #[source] + pub error_type: OAuthErrorType, + + pub state: Option, + pub valid_redirect_uri: Option, +} + +impl From for OAuthError +where + T: Into, +{ + fn from(value: T) -> Self { + OAuthError::error(value.into()) + } +} + +impl OAuthError { + /// The OAuth request failed either because of an invalid redirection URI + /// or before we could validate the one we were given, so return an error + /// directly to the caller + /// + /// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1) + pub fn error(error_type: impl Into) -> Self { + Self { + error_type: error_type.into(), + valid_redirect_uri: None, + state: None, + } + } + + /// The OAuth request failed for a reason other than an invalid redirection URI + /// So send the error in url-encoded form to the redirect URI + /// + /// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1) + pub fn redirect( + err: impl Into, + state: &Option, + valid_redirect_uri: &ValidatedRedirectUri, + ) -> Self { + Self { + error_type: err.into(), + state: state.clone(), + valid_redirect_uri: Some(valid_redirect_uri.clone()), + } + } +} + +impl actix_web::ResponseError for OAuthError { + fn status_code(&self) -> StatusCode { + match self.error_type { + OAuthErrorType::AuthenticationError(_) + | OAuthErrorType::FailedScopeParse(_) + | OAuthErrorType::ScopesTooBroad + | OAuthErrorType::AccessDenied => { + if self.valid_redirect_uri.is_some() { + StatusCode::FOUND + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + } + OAuthErrorType::RedirectUriNotConfigured(_) + | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } + | OAuthErrorType::InvalidAcceptFlowId + | OAuthErrorType::MalformedId(_) + | OAuthErrorType::InvalidClientId(_) + | OAuthErrorType::InvalidAuthCode + | OAuthErrorType::OnlySupportsAuthorizationCodeGrant(_) + | OAuthErrorType::RedirectUriChanged(_) + | OAuthErrorType::UnauthorizedClient => StatusCode::BAD_REQUEST, + OAuthErrorType::ClientAuthenticationFailed => StatusCode::UNAUTHORIZED, + } + } + + fn error_response(&self) -> HttpResponse { + if let Some(ValidatedRedirectUri(mut redirect_uri)) = self.valid_redirect_uri.clone() { + redirect_uri = format!( + "{}?error={}&error_description={}", + redirect_uri, + self.error_type.error_name(), + self.error_type, + ); + + if let Some(state) = self.state.as_ref() { + redirect_uri = format!("{}&state={}", redirect_uri, state); + } + + redirect_uri = urlencoding::encode(&redirect_uri).to_string(); + HttpResponse::Found() + .append_header(("Location".to_string(), redirect_uri)) + .finish() + } else { + HttpResponse::build(self.status_code()).json(ApiError { + error: &self.error_type.error_name(), + description: &self.error_type.to_string(), + }) + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum OAuthErrorType { + #[error(transparent)] + AuthenticationError(#[from] AuthenticationError), + #[error("Client {} has no redirect URIs specified", .client_id.0)] + ClientMissingRedirectURI { + client_id: crate::database::models::OAuthClientId, + }, + #[error("The provided redirect URI did not match any configured in the client")] + RedirectUriNotConfigured(String), + #[error("The provided scope was malformed or did not correspond to known scopes ({0})")] + FailedScopeParse(bitflags::parser::ParseError), + #[error( + "The provided scope requested scopes broader than the developer app is configured with" + )] + ScopesTooBroad, + #[error("The provided flow id was invalid")] + InvalidAcceptFlowId, + #[error("The provided client id was invalid")] + InvalidClientId(crate::database::models::OAuthClientId), + #[error("The provided ID could not be decoded: {0}")] + MalformedId(#[from] DecodingError), + #[error("Failed to authenticate client")] + ClientAuthenticationFailed, + #[error("The provided authorization grant code was invalid")] + InvalidAuthCode, + #[error("The provided client id did not match the id this authorization code was granted to")] + UnauthorizedClient, + #[error("The provided redirect URI did not exactly match the uri originally provided when this flow began")] + RedirectUriChanged(Option), + #[error("The provided grant type ({0}) must be \"authorization_code\"")] + OnlySupportsAuthorizationCodeGrant(String), + #[error("The resource owner denied the request")] + AccessDenied, +} + +impl From for OAuthErrorType { + fn from(value: crate::database::models::DatabaseError) -> Self { + OAuthErrorType::AuthenticationError(value.into()) + } +} + +impl From for OAuthErrorType { + fn from(value: sqlx::Error) -> Self { + OAuthErrorType::AuthenticationError(value.into()) + } +} + +impl OAuthErrorType { + pub fn error_name(&self) -> String { + // IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#autoid-38) + // And 5.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2) + match self { + Self::RedirectUriNotConfigured(_) | Self::ClientMissingRedirectURI { client_id: _ } => { + "invalid_uri" + } + Self::AuthenticationError(_) | Self::InvalidAcceptFlowId => "server_error", + Self::RedirectUriChanged(_) | Self::MalformedId(_) => "invalid_request", + Self::FailedScopeParse(_) | Self::ScopesTooBroad => "invalid_scope", + Self::InvalidClientId(_) | Self::ClientAuthenticationFailed => "invalid_client", + Self::InvalidAuthCode | Self::OnlySupportsAuthorizationCodeGrant(_) => "invalid_grant", + Self::UnauthorizedClient => "unauthorized_client", + Self::AccessDenied => "access_denied", + } + .to_string() + } +} diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs new file mode 100644 index 000000000..0d64b53f5 --- /dev/null +++ b/src/auth/oauth/mod.rs @@ -0,0 +1,430 @@ +use crate::auth::get_user_from_headers; +use crate::auth::oauth::uris::{OAuthRedirectUris, ValidatedRedirectUri}; +use crate::auth::validate::extract_authorization_header; +use crate::database::models::flow_item::Flow; +use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization; +use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; +use crate::database::models::oauth_token_item::OAuthAccessToken; +use crate::database::models::{ + generate_oauth_access_token_id, generate_oauth_client_authorization_id, + OAuthClientAuthorizationId, OAuthClientId, +}; +use crate::database::redis::RedisPool; +use crate::models; +use crate::models::pats::Scopes; +use crate::queue::session::AuthQueue; +use actix_web::web::{scope, Data, Query, ServiceConfig}; +use actix_web::{get, post, web, HttpRequest, HttpResponse}; +use chrono::Duration; +use rand::distributions::Alphanumeric; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha20Rng; +use reqwest::header::{CACHE_CONTROL, LOCATION, PRAGMA}; +use serde::{Deserialize, Serialize}; +use sqlx::postgres::PgPool; + +use self::errors::{OAuthError, OAuthErrorType}; + +use super::AuthenticationError; + +pub mod errors; +pub mod uris; + +pub fn config(cfg: &mut ServiceConfig) { + cfg.service( + scope("auth/oauth") + .service(init_oauth) + .service(accept_client_scopes) + .service(reject_client_scopes) + .service(request_token), + ); +} + +#[derive(Serialize, Deserialize)] +pub struct OAuthInit { + pub client_id: OAuthClientId, + pub redirect_uri: Option, + pub scope: Option, + pub state: Option, +} + +#[derive(Serialize, Deserialize)] +pub struct OAuthClientAccessRequest { + pub flow_id: String, + pub client_id: OAuthClientId, + pub client_name: String, + pub client_icon: Option, + pub requested_scopes: Scopes, +} + +#[get("authorize")] +pub async fn init_oauth( + req: HttpRequest, + Query(oauth_info): Query, + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + let user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_AUTH_WRITE]), + ) + .await? + .1; + + let client_id = oauth_info.client_id; + let client = DBOAuthClient::get(client_id, &**pool).await?; + + if let Some(client) = client { + let redirect_uri = ValidatedRedirectUri::validate( + &oauth_info.redirect_uri, + client.redirect_uris.iter().map(|r| r.uri.as_ref()), + client.id, + )?; + + let requested_scopes = oauth_info + .scope + .as_ref() + .map_or(Ok(client.max_scopes), |s| { + Scopes::parse_from_oauth_scopes(s).map_err(|e| { + OAuthError::redirect( + OAuthErrorType::FailedScopeParse(e), + &oauth_info.state, + &redirect_uri, + ) + }) + })?; + + if !client.max_scopes.contains(requested_scopes) { + return Err(OAuthError::redirect( + OAuthErrorType::ScopesTooBroad, + &oauth_info.state, + &redirect_uri, + )); + } + + let existing_authorization = + OAuthClientAuthorization::get(client.id, user.id.into(), &**pool) + .await + .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; + let redirect_uris = + OAuthRedirectUris::new(oauth_info.redirect_uri.clone(), redirect_uri.clone()); + match existing_authorization { + Some(existing_authorization) + if existing_authorization.scopes.contains(requested_scopes) => + { + init_oauth_code_flow( + user.id.into(), + client.id, + existing_authorization.id, + requested_scopes, + redirect_uris, + oauth_info.state, + &redis, + ) + .await + } + _ => { + let flow_id = Flow::InitOAuthAppApproval { + user_id: user.id.into(), + client_id: client.id, + existing_authorization_id: existing_authorization.map(|a| a.id), + scopes: requested_scopes, + redirect_uris, + state: oauth_info.state.clone(), + } + .insert(Duration::minutes(30), &redis) + .await + .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; + + let access_request = OAuthClientAccessRequest { + client_id: client.id, + client_name: client.name, + client_icon: client.icon_url, + flow_id, + requested_scopes, + }; + Ok(HttpResponse::Ok().json(access_request)) + } + } + } else { + Err(OAuthError::error(OAuthErrorType::InvalidClientId( + client_id, + ))) + } +} + +#[derive(Serialize, Deserialize)] +pub struct RespondToOAuthClientScopes { + pub flow: String, +} + +#[post("accept")] +pub async fn accept_client_scopes( + req: HttpRequest, + accept_body: web::Json, + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + accept_or_reject_client_scopes(true, req, accept_body, pool, redis, session_queue).await +} + +#[post("reject")] +pub async fn reject_client_scopes( + req: HttpRequest, + body: web::Json, + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + accept_or_reject_client_scopes(false, req, body, pool, redis, session_queue).await +} + +#[derive(Serialize, Deserialize)] +pub struct TokenRequest { + pub grant_type: String, + pub code: String, + pub redirect_uri: Option, + pub client_id: models::ids::OAuthClientId, +} + +#[derive(Serialize, Deserialize)] +pub struct TokenResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: i64, +} + +#[post("token")] +/// Params should be in the urlencoded request body +/// And client secret should be in the HTTP basic authorization header +/// Per IETF RFC6749 Section 4.1.3 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3) +pub async fn request_token( + req: HttpRequest, + req_params: web::Form, + pool: Data, + redis: Data, +) -> Result { + let req_client_id = req_params.client_id; + let client = DBOAuthClient::get(req_client_id.into(), &**pool).await?; + if let Some(client) = client { + authenticate_client_token_request(&req, &client)?; + + // Ensure auth code is single use + // per IETF RFC6749 Section 10.5 (https://datatracker.ietf.org/doc/html/rfc6749#section-10.5) + let flow = Flow::take_if( + &req_params.code, + |f| matches!(f, Flow::OAuthAuthorizationCodeSupplied { .. }), + &redis, + ) + .await?; + if let Some(Flow::OAuthAuthorizationCodeSupplied { + user_id, + client_id, + authorization_id, + scopes, + original_redirect_uri, + }) = flow + { + // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 + if req_client_id != client_id.into() { + return Err(OAuthError::error(OAuthErrorType::UnauthorizedClient)); + } + + if original_redirect_uri != req_params.redirect_uri { + return Err(OAuthError::error(OAuthErrorType::RedirectUriChanged( + req_params.redirect_uri.clone(), + ))); + } + + if req_params.grant_type != "authorization_code" { + return Err(OAuthError::error( + OAuthErrorType::OnlySupportsAuthorizationCodeGrant( + req_params.grant_type.clone(), + ), + )); + } + + let scopes = scopes - Scopes::restricted(); + + let mut transaction = pool.begin().await?; + let token_id = generate_oauth_access_token_id(&mut transaction).await?; + let token = generate_access_token(); + let token_hash = OAuthAccessToken::hash_token(&token); + let time_until_expiration = OAuthAccessToken { + id: token_id, + authorization_id, + token_hash, + scopes, + created: Default::default(), + expires: Default::default(), + last_used: None, + client_id, + user_id, + } + .insert(&mut *transaction) + .await?; + + transaction.commit().await?; + + // IETF RFC6749 Section 5.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-5.1) + Ok(HttpResponse::Ok() + .append_header((CACHE_CONTROL, "no-store")) + .append_header((PRAGMA, "no-cache")) + .json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: time_until_expiration.num_seconds(), + })) + } else { + Err(OAuthError::error(OAuthErrorType::InvalidAuthCode)) + } + } else { + Err(OAuthError::error(OAuthErrorType::InvalidClientId( + req_client_id.into(), + ))) + } +} + +pub async fn accept_or_reject_client_scopes( + accept: bool, + req: HttpRequest, + body: web::Json, + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::SESSION_ACCESS]), + ) + .await? + .1; + + let flow = Flow::take_if( + &body.flow, + |f| matches!(f, Flow::InitOAuthAppApproval { .. }), + &redis, + ) + .await?; + if let Some(Flow::InitOAuthAppApproval { + user_id, + client_id, + existing_authorization_id, + scopes, + redirect_uris, + state, + }) = flow + { + if current_user.id != user_id.into() { + return Err(OAuthError::error(AuthenticationError::InvalidCredentials)); + } + + if accept { + let mut transaction = pool.begin().await?; + + let auth_id = match existing_authorization_id { + Some(id) => id, + None => generate_oauth_client_authorization_id(&mut transaction).await?, + }; + OAuthClientAuthorization::upsert(auth_id, client_id, user_id, scopes, &mut transaction) + .await?; + + transaction.commit().await?; + + init_oauth_code_flow( + user_id, + client_id, + auth_id, + scopes, + redirect_uris, + state, + &redis, + ) + .await + } else { + Err(OAuthError::redirect( + OAuthErrorType::AccessDenied, + &state, + &redirect_uris.validated, + )) + } + } else { + Err(OAuthError::error(OAuthErrorType::InvalidAcceptFlowId)) + } +} + +fn authenticate_client_token_request( + req: &HttpRequest, + client: &DBOAuthClient, +) -> Result<(), OAuthError> { + let client_secret = extract_authorization_header(req)?; + let hashed_client_secret = DBOAuthClient::hash_secret(client_secret); + if client.secret_hash != hashed_client_secret { + Err(OAuthError::error( + OAuthErrorType::ClientAuthenticationFailed, + )) + } else { + Ok(()) + } +} + +fn generate_access_token() -> String { + let random = ChaCha20Rng::from_entropy() + .sample_iter(&Alphanumeric) + .take(60) + .map(char::from) + .collect::(); + format!("mro_{}", random) +} + +async fn init_oauth_code_flow( + user_id: crate::database::models::UserId, + client_id: OAuthClientId, + authorization_id: OAuthClientAuthorizationId, + scopes: Scopes, + redirect_uris: OAuthRedirectUris, + state: Option, + redis: &RedisPool, +) -> Result { + let code = Flow::OAuthAuthorizationCodeSupplied { + user_id, + client_id, + authorization_id, + scopes, + original_redirect_uri: redirect_uris.original.clone(), + } + .insert(Duration::minutes(10), redis) + .await + .map_err(|e| OAuthError::redirect(e, &state, &redirect_uris.validated.clone()))?; + + let mut redirect_params = vec![format!("code={code}")]; + if let Some(state) = state { + redirect_params.push(format!("state={state}")); + } + + let redirect_uri = append_params_to_uri(&redirect_uris.validated.0, &redirect_params); + + // IETF RFC 6749 Section 4.1.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2) + Ok(HttpResponse::Found() + .append_header((LOCATION, redirect_uri)) + .finish()) +} + +fn append_params_to_uri(uri: &str, params: &[impl AsRef]) -> String { + let mut uri = uri.to_string(); + let mut connector = if uri.contains('?') { "&" } else { "?" }; + for param in params { + uri.push_str(&format!("{}{}", connector, param.as_ref())); + connector = "&"; + } + + uri +} diff --git a/src/auth/oauth/uris.rs b/src/auth/oauth/uris.rs new file mode 100644 index 000000000..708aa8a02 --- /dev/null +++ b/src/auth/oauth/uris.rs @@ -0,0 +1,94 @@ +use super::errors::OAuthError; +use crate::auth::oauth::OAuthErrorType; +use crate::database::models::OAuthClientId; +use serde::{Deserialize, Serialize}; + +#[derive(derive_new::new, Serialize, Deserialize)] +pub struct OAuthRedirectUris { + pub original: Option, + pub validated: ValidatedRedirectUri, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ValidatedRedirectUri(pub String); + +impl ValidatedRedirectUri { + pub fn validate<'a>( + to_validate: &Option, + validate_against: impl IntoIterator + Clone, + client_id: OAuthClientId, + ) -> Result { + if let Some(first_client_redirect_uri) = validate_against.clone().into_iter().next() { + if let Some(to_validate) = to_validate { + if validate_against + .into_iter() + .any(|uri| same_uri_except_query_components(uri, to_validate)) + { + Ok(ValidatedRedirectUri(to_validate.clone())) + } else { + Err(OAuthError::error(OAuthErrorType::RedirectUriNotConfigured( + to_validate.clone(), + ))) + } + } else { + Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string())) + } + } else { + Err(OAuthError::error( + OAuthErrorType::ClientMissingRedirectURI { client_id }, + )) + } + } +} + +fn same_uri_except_query_components(a: &str, b: &str) -> bool { + let mut a_components = a.split('?'); + let mut b_components = b.split('?'); + a_components.next() == b_components.next() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_for_none_returns_first_valid_uri() { + let validate_against = vec!["https://modrinth.com/a"]; + + let validated = + ValidatedRedirectUri::validate(&None, validate_against.clone(), OAuthClientId(0)) + .unwrap(); + + assert_eq!(validate_against[0], validated.0); + } + + #[test] + fn validate_for_valid_uri_returns_first_matching_uri_ignoring_query_params() { + let validate_against = vec![ + "https://modrinth.com/a?q3=p3&q4=p4", + "https://modrinth.com/a/b/c?q1=p1&q2=p2", + ]; + let to_validate = "https://modrinth.com/a/b/c?query0=param0&query1=param1".to_string(); + + let validated = ValidatedRedirectUri::validate( + &Some(to_validate.clone()), + validate_against, + OAuthClientId(0), + ) + .unwrap(); + + assert_eq!(to_validate, validated.0); + } + + #[test] + fn validate_for_invalid_uri_returns_err() { + let validate_against = vec!["https://modrinth.com/a"]; + let to_validate = "https://modrinth.com/a/b".to_string(); + + let validated = + ValidatedRedirectUri::validate(&Some(to_validate), validate_against, OAuthClientId(0)); + + assert!(validated + .is_err_and(|e| matches!(e.error_type, OAuthErrorType::RedirectUriNotConfigured(_)))); + } +} diff --git a/src/auth/validate.rs b/src/auth/validate.rs index cb9793548..8c2a92e15 100644 --- a/src/auth/validate.rs +++ b/src/auth/validate.rs @@ -91,12 +91,7 @@ where let token = if let Some(token) = token { token } else { - let headers = req.headers(); - let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION); - token_val - .ok_or_else(|| AuthenticationError::InvalidAuthMethod)? - .to_str() - .map_err(|_| AuthenticationError::InvalidCredentials)? + extract_authorization_header(req)? }; let possible_user = match token.split_once('_') { @@ -142,6 +137,25 @@ where user.map(|x| (Scopes::all(), x)) } + Some(("mro", _)) => { + use crate::database::models::oauth_token_item::OAuthAccessToken; + + let hash = OAuthAccessToken::hash_token(token); + let access_token = + crate::database::models::oauth_token_item::OAuthAccessToken::get(hash, executor) + .await? + .ok_or(AuthenticationError::InvalidCredentials)?; + + if access_token.expires < Utc::now() { + return Err(AuthenticationError::InvalidCredentials); + } + + let user = user_item::User::get_id(access_token.user_id, executor, redis).await?; + + session_queue.add_oauth_access_token(access_token.id).await; + + user.map(|u| (access_token.scopes, u)) + } Some(("github", _)) | Some(("gho", _)) | Some(("ghp", _)) => { let user = AuthProvider::GitHub.get_user(token).await?; let id = AuthProvider::GitHub.get_user_id(&user.id, executor).await?; @@ -160,6 +174,15 @@ where Ok(possible_user) } +pub fn extract_authorization_header(req: &HttpRequest) -> Result<&str, AuthenticationError> { + let headers = req.headers(); + let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION); + token_val + .ok_or_else(|| AuthenticationError::InvalidAuthMethod)? + .to_str() + .map_err(|_| AuthenticationError::InvalidCredentials) +} + pub async fn check_is_moderator_from_headers<'a, 'b, E>( req: &HttpRequest, executor: E, diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index 5bcd26712..fe81e4a8d 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -1,7 +1,8 @@ use super::ids::*; -use crate::auth::flows::AuthProvider; +use crate::auth::oauth::uris::OAuthRedirectUris; use crate::database::models::DatabaseError; use crate::database::redis::RedisPool; +use crate::{auth::flows::AuthProvider, models::pats::Scopes}; use chrono::Duration; use rand::distributions::Alphanumeric; use rand::Rng; @@ -34,6 +35,21 @@ pub enum Flow { confirm_email: String, }, MinecraftAuth, + InitOAuthAppApproval { + user_id: UserId, + client_id: OAuthClientId, + existing_authorization_id: Option, + scopes: Scopes, + redirect_uris: OAuthRedirectUris, + state: Option, + }, + OAuthAuthorizationCodeSupplied { + user_id: UserId, + client_id: OAuthClientId, + authorization_id: OAuthClientAuthorizationId, + scopes: Scopes, + original_redirect_uri: Option, // Needed for https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 + }, } impl Flow { @@ -58,6 +74,22 @@ impl Flow { redis.get_deserialized_from_json(FLOWS_NAMESPACE, id).await } + /// Gets the flow and removes it from the cache, but only removes if the flow was present and the predicate returned true + /// The predicate should validate that the flow being removed is the correct one, as a security measure + pub async fn take_if( + id: &str, + predicate: impl FnOnce(&Flow) -> bool, + redis: &RedisPool, + ) -> Result, DatabaseError> { + let flow = Self::get(id, redis).await?; + if let Some(flow) = flow.as_ref() { + if predicate(flow) { + Self::remove(id, redis).await?; + } + } + Ok(flow) + } + pub async fn remove(id: &str, redis: &RedisPool) -> Result, DatabaseError> { redis.delete(FLOWS_NAMESPACE, id).await?; Ok(Some(())) diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index a2638249f..b8953462f 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -152,6 +152,38 @@ generate_ids!( ImageId ); +generate_ids!( + pub generate_oauth_client_authorization_id, + OAuthClientAuthorizationId, + 8, + "SELECT EXISTS(SELECT 1 FROM oauth_client_authorizations WHERE id=$1)", + OAuthClientAuthorizationId +); + +generate_ids!( + pub generate_oauth_client_id, + OAuthClientId, + 8, + "SELECT EXISTS(SELECT 1 FROM oauth_clients WHERE id=$1)", + OAuthClientId +); + +generate_ids!( + pub generate_oauth_redirect_id, + OAuthRedirectUriId, + 8, + "SELECT EXISTS(SELECT 1 FROM oauth_client_redirect_uris WHERE id=$1)", + OAuthRedirectUriId +); + +generate_ids!( + pub generate_oauth_access_token_id, + OAuthAccessTokenId, + 8, + "SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)", + OAuthAccessTokenId +); + #[derive(Copy, Clone, Debug, PartialEq, Eq, Type, Hash, Serialize, Deserialize)] #[sqlx(transparent)] pub struct UserId(pub i64); @@ -238,6 +270,22 @@ pub struct SessionId(pub i64); #[sqlx(transparent)] pub struct ImageId(pub i64); +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct OAuthClientId(pub i64); + +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct OAuthClientAuthorizationId(pub i64); + +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct OAuthRedirectUriId(pub i64); + +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct OAuthAccessTokenId(pub i64); + use crate::models::ids; impl From for ProjectId { @@ -360,3 +408,23 @@ impl From for ids::PatId { ids::PatId(id.0 as u64) } } +impl From for ids::OAuthClientId { + fn from(id: OAuthClientId) -> Self { + ids::OAuthClientId(id.0 as u64) + } +} +impl From for OAuthClientId { + fn from(id: ids::OAuthClientId) -> Self { + Self(id.0 as i64) + } +} +impl From for ids::OAuthRedirectUriId { + fn from(id: OAuthRedirectUriId) -> Self { + ids::OAuthRedirectUriId(id.0 as u64) + } +} +impl From for ids::OAuthClientAuthorizationId { + fn from(id: OAuthClientAuthorizationId) -> Self { + ids::OAuthClientAuthorizationId(id.0 as u64) + } +} diff --git a/src/database/models/mod.rs b/src/database/models/mod.rs index bfd6e7815..5d5bc34fa 100644 --- a/src/database/models/mod.rs +++ b/src/database/models/mod.rs @@ -6,6 +6,9 @@ pub mod flow_item; pub mod ids; pub mod image_item; pub mod notification_item; +pub mod oauth_client_authorization_item; +pub mod oauth_client_item; +pub mod oauth_token_item; pub mod organization_item; pub mod pat_item; pub mod project_item; @@ -19,6 +22,7 @@ pub mod version_item; pub use collection_item::Collection; pub use ids::*; pub use image_item::Image; +pub use oauth_client_item::OAuthClient; pub use organization_item::Organization; pub use project_item::Project; pub use team_item::Team; diff --git a/src/database/models/oauth_client_authorization_item.rs b/src/database/models/oauth_client_authorization_item.rs new file mode 100644 index 000000000..617e6fcd8 --- /dev/null +++ b/src/database/models/oauth_client_authorization_item.rs @@ -0,0 +1,126 @@ +use chrono::{DateTime, Utc}; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use crate::models::pats::Scopes; + +use super::{DatabaseError, OAuthClientAuthorizationId, OAuthClientId, UserId}; + +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct OAuthClientAuthorization { + pub id: OAuthClientAuthorizationId, + pub client_id: OAuthClientId, + pub user_id: UserId, + pub scopes: Scopes, + pub created: DateTime, +} + +struct AuthorizationQueryResult { + id: i64, + client_id: i64, + user_id: i64, + scopes: i64, + created: DateTime, +} + +impl From for OAuthClientAuthorization { + fn from(value: AuthorizationQueryResult) -> Self { + OAuthClientAuthorization { + id: OAuthClientAuthorizationId(value.id), + client_id: OAuthClientId(value.client_id), + user_id: UserId(value.user_id), + scopes: Scopes::from_postgres(value.scopes), + created: value.created, + } + } +} + +impl OAuthClientAuthorization { + pub async fn get( + client_id: OAuthClientId, + user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let value = sqlx::query_as!( + AuthorizationQueryResult, + " + SELECT id, client_id, user_id, scopes, created + FROM oauth_client_authorizations + WHERE client_id=$1 AND user_id=$2 + ", + client_id.0, + user_id.0, + ) + .fetch_optional(exec) + .await?; + + Ok(value.map(|r| r.into())) + } + + pub async fn get_all_for_user( + user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let results = sqlx::query_as!( + AuthorizationQueryResult, + " + SELECT id, client_id, user_id, scopes, created + FROM oauth_client_authorizations + WHERE user_id=$1 + ", + user_id.0 + ) + .fetch_all(exec) + .await?; + + Ok(results.into_iter().map(|r| r.into()).collect_vec()) + } + + pub async fn upsert( + id: OAuthClientAuthorizationId, + client_id: OAuthClientId, + user_id: UserId, + scopes: Scopes, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + INSERT INTO oauth_client_authorizations ( + id, client_id, user_id, scopes + ) + VALUES ( + $1, $2, $3, $4 + ) + ON CONFLICT (id) + DO UPDATE SET scopes = EXCLUDED.scopes + ", + id.0, + client_id.0, + user_id.0, + scopes.bits() as i64, + ) + .execute(&mut **transaction) + .await?; + + Ok(()) + } + + pub async fn remove( + client_id: OAuthClientId, + user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + DELETE FROM oauth_client_authorizations + WHERE client_id=$1 AND user_id=$2 + ", + client_id.0, + user_id.0 + ) + .execute(exec) + .await?; + + Ok(()) + } +} diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs new file mode 100644 index 000000000..48870c670 --- /dev/null +++ b/src/database/models/oauth_client_item.rs @@ -0,0 +1,245 @@ +use chrono::{DateTime, Utc}; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use sha2::Digest; + +use super::{DatabaseError, OAuthClientId, OAuthRedirectUriId, UserId}; +use crate::models::pats::Scopes; + +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct OAuthRedirectUri { + pub id: OAuthRedirectUriId, + pub client_id: OAuthClientId, + pub uri: String, +} + +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct OAuthClient { + pub id: OAuthClientId, + pub name: String, + pub icon_url: Option, + pub max_scopes: Scopes, + pub secret_hash: String, + pub redirect_uris: Vec, + pub created: DateTime, + pub created_by: UserId, +} + +struct ClientQueryResult { + id: i64, + name: String, + icon_url: Option, + max_scopes: i64, + secret_hash: String, + created: DateTime, + created_by: i64, + uri_ids: Option>, + uri_vals: Option>, +} + +macro_rules! select_clients_with_predicate { + ($predicate:tt, $param:ident) => { + // The columns in this query have nullability type hints, because for some reason + // the combination of the JOIN and filter using ANY makes sqlx think all columns are nullable + // https://docs.rs/sqlx/latest/sqlx/macro.query.html#force-nullable + sqlx::query_as!( + ClientQueryResult, + r#" + SELECT + clients.id as "id!", + clients.name as "name!", + clients.icon_url as "icon_url?", + clients.max_scopes as "max_scopes!", + clients.secret_hash as "secret_hash!", + clients.created as "created!", + clients.created_by as "created_by!", + uris.uri_ids as "uri_ids?", + uris.uri_vals as "uri_vals?" + FROM oauth_clients clients + LEFT JOIN ( + SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals + FROM oauth_client_redirect_uris + GROUP BY client_id + ) uris ON clients.id = uris.client_id + "# + + $predicate, + $param + ) + }; +} + +impl OAuthClient { + pub async fn get( + id: OAuthClientId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + Ok(Self::get_many(&[id], exec).await?.into_iter().next()) + } + + pub async fn get_many( + ids: &[OAuthClientId], + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let ids = ids.iter().map(|id| id.0).collect_vec(); + let ids_ref: &[i64] = &ids; + let results = + select_clients_with_predicate!("WHERE clients.id = ANY($1::bigint[])", ids_ref) + .fetch_all(exec) + .await?; + + Ok(results.into_iter().map(|r| r.into()).collect_vec()) + } + + pub async fn get_all_user_clients( + user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let user_id_param = user_id.0; + let clients = select_clients_with_predicate!("WHERE created_by = $1", user_id_param) + .fetch_all(exec) + .await?; + + Ok(clients.into_iter().map(|r| r.into()).collect()) + } + + pub async fn remove( + id: OAuthClientId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + // Cascades to oauth_client_redirect_uris, oauth_client_authorizations + sqlx::query!( + " + DELETE FROM oauth_clients + WHERE id = $1 + ", + id.0 + ) + .execute(exec) + .await?; + + Ok(()) + } + + pub async fn insert( + &self, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + INSERT INTO oauth_clients ( + id, name, icon_url, max_scopes, secret_hash, created_by + ) + VALUES ( + $1, $2, $3, $4, $5, $6 + ) + ", + self.id.0, + self.name, + self.icon_url, + self.max_scopes.to_postgres(), + self.secret_hash, + self.created_by.0 + ) + .execute(&mut **transaction) + .await?; + + Self::insert_redirect_uris(&self.redirect_uris, &mut **transaction).await?; + + Ok(()) + } + + pub async fn update_editable_fields( + &self, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + UPDATE oauth_clients + SET name = $1, icon_url = $2, max_scopes = $3 + WHERE (id = $4) + ", + self.name, + self.icon_url, + self.max_scopes.to_postgres(), + self.id.0, + ) + .execute(exec) + .await?; + + Ok(()) + } + + pub async fn remove_redirect_uris( + ids: impl IntoIterator, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + let ids = ids.into_iter().map(|id| id.0).collect_vec(); + sqlx::query!( + " + DELETE FROM oauth_client_redirect_uris + WHERE id IN + (SELECT * FROM UNNEST($1::bigint[])) + ", + &ids[..] + ) + .execute(exec) + .await?; + + Ok(()) + } + + pub async fn insert_redirect_uris( + uris: &[OAuthRedirectUri], + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + let (ids, client_ids, uris): (Vec<_>, Vec<_>, Vec<_>) = uris + .iter() + .map(|r| (r.id.0, r.client_id.0, r.uri.clone())) + .multiunzip(); + sqlx::query!( + " + INSERT INTO oauth_client_redirect_uris (id, client_id, uri) + SELECT * FROM UNNEST($1::bigint[], $2::bigint[], $3::varchar[]) + ", + &ids[..], + &client_ids[..], + &uris[..], + ) + .execute(exec) + .await?; + + Ok(()) + } + + pub fn hash_secret(secret: &str) -> String { + format!("{:x}", sha2::Sha512::digest(secret.as_bytes())) + } +} + +impl From for OAuthClient { + fn from(r: ClientQueryResult) -> Self { + let redirects = if let (Some(ids), Some(uris)) = (r.uri_ids.as_ref(), r.uri_vals.as_ref()) { + ids.iter() + .zip(uris.iter()) + .map(|(id, uri)| OAuthRedirectUri { + id: OAuthRedirectUriId(*id), + client_id: OAuthClientId(r.id), + uri: uri.to_string(), + }) + .collect() + } else { + vec![] + }; + + OAuthClient { + id: OAuthClientId(r.id), + name: r.name, + icon_url: r.icon_url, + max_scopes: Scopes::from_postgres(r.max_scopes), + secret_hash: r.secret_hash, + redirect_uris: redirects, + created: r.created, + created_by: UserId(r.created_by), + } + } +} diff --git a/src/database/models/oauth_token_item.rs b/src/database/models/oauth_token_item.rs new file mode 100644 index 000000000..9c12f3836 --- /dev/null +++ b/src/database/models/oauth_token_item.rs @@ -0,0 +1,95 @@ +use super::{DatabaseError, OAuthAccessTokenId, OAuthClientAuthorizationId, OAuthClientId, UserId}; +use crate::models::pats::Scopes; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sha2::Digest; + +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct OAuthAccessToken { + pub id: OAuthAccessTokenId, + pub authorization_id: OAuthClientAuthorizationId, + pub token_hash: String, + pub scopes: Scopes, + pub created: DateTime, + pub expires: DateTime, + pub last_used: Option>, + + // Stored separately inside oauth_client_authorizations table + pub client_id: OAuthClientId, + pub user_id: UserId, +} + +impl OAuthAccessToken { + pub async fn get( + token_hash: String, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let value = sqlx::query!( + " + SELECT + tokens.id, + tokens.authorization_id, + tokens.token_hash, + tokens.scopes, + tokens.created, + tokens.expires, + tokens.last_used, + auths.client_id, + auths.user_id + FROM oauth_access_tokens tokens + JOIN oauth_client_authorizations auths + ON tokens.authorization_id = auths.id + WHERE tokens.token_hash = $1 + ", + token_hash + ) + .fetch_optional(exec) + .await?; + + Ok(value.map(|r| OAuthAccessToken { + id: OAuthAccessTokenId(r.id), + authorization_id: OAuthClientAuthorizationId(r.authorization_id), + token_hash: r.token_hash, + scopes: Scopes::from_postgres(r.scopes), + created: r.created, + expires: r.expires, + last_used: r.last_used, + client_id: OAuthClientId(r.client_id), + user_id: UserId(r.user_id), + })) + } + + /// Inserts and returns the time until the token expires + pub async fn insert( + &self, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result { + let r = sqlx::query!( + " + INSERT INTO oauth_access_tokens ( + id, authorization_id, token_hash, scopes, last_used + ) + VALUES ( + $1, $2, $3, $4, $5 + ) + RETURNING created, expires + ", + self.id.0, + self.authorization_id.0, + self.token_hash, + self.scopes.to_postgres(), + Option::>::None + ) + .fetch_one(exec) + .await?; + + let (created, expires) = (r.created, r.expires); + let time_until_expiration = expires - created; + + Ok(time_until_expiration) + } + + pub fn hash_token(token: &str) -> String { + format!("{:x}", sha2::Sha512::digest(token.as_bytes())) + } +} diff --git a/src/models/ids.rs b/src/models/ids.rs index 20166b798..8cea089f7 100644 --- a/src/models/ids.rs +++ b/src/models/ids.rs @@ -3,6 +3,8 @@ use thiserror::Error; pub use super::collections::CollectionId; pub use super::images::ImageId; pub use super::notifications::NotificationId; +pub use super::oauth_clients::OAuthClientAuthorizationId; +pub use super::oauth_clients::{OAuthClientId, OAuthRedirectUriId}; pub use super::organizations::OrganizationId; pub use super::pats::PatId; pub use super::projects::{ProjectId, VersionId}; @@ -122,6 +124,9 @@ base62_id_impl!(ThreadMessageId, ThreadMessageId); base62_id_impl!(SessionId, SessionId); base62_id_impl!(PatId, PatId); base62_id_impl!(ImageId, ImageId); +base62_id_impl!(OAuthClientId, OAuthClientId); +base62_id_impl!(OAuthRedirectUriId, OAuthRedirectUriId); +base62_id_impl!(OAuthClientAuthorizationId, OAuthClientAuthorizationId); pub mod base62_impl { use serde::de::{self, Deserializer, Visitor}; diff --git a/src/models/mod.rs b/src/models/mod.rs index e1d4ace9f..7c97ad31f 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -4,6 +4,7 @@ pub mod error; pub mod ids; pub mod images; pub mod notifications; +pub mod oauth_clients; pub mod organizations; pub mod pack; pub mod pats; diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs new file mode 100644 index 000000000..16795aa3e --- /dev/null +++ b/src/models/oauth_clients.rs @@ -0,0 +1,110 @@ +use super::{ + ids::{Base62Id, UserId}, + pats::Scopes, +}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization as DBOAuthClientAuthorization; +use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; +use crate::database::models::oauth_client_item::OAuthRedirectUri as DBOAuthRedirectUri; + +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(from = "Base62Id")] +#[serde(into = "Base62Id")] +pub struct OAuthClientId(pub u64); + +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(from = "Base62Id")] +#[serde(into = "Base62Id")] +pub struct OAuthClientAuthorizationId(pub u64); + +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(from = "Base62Id")] +#[serde(into = "Base62Id")] +pub struct OAuthRedirectUriId(pub u64); + +#[derive(Deserialize, Serialize)] +pub struct OAuthRedirectUri { + pub id: OAuthRedirectUriId, + pub client_id: OAuthClientId, + pub uri: String, +} + +#[derive(Serialize, Deserialize)] +pub struct OAuthClientCreationResult { + #[serde(flatten)] + pub client: OAuthClient, + + pub client_secret: String, +} + +#[derive(Deserialize, Serialize)] +pub struct OAuthClient { + pub id: OAuthClientId, + pub name: String, + pub icon_url: Option, + + // The maximum scopes the client can request for OAuth + pub max_scopes: Scopes, + + // The valid URIs that can be redirected to during an authorization request + pub redirect_uris: Vec, + + // The user that created (and thus controls) this client + pub created_by: UserId, +} + +#[derive(Deserialize, Serialize)] +pub struct OAuthClientAuthorization { + pub id: OAuthClientAuthorizationId, + pub app_id: OAuthClientId, + pub user_id: UserId, + pub scopes: Scopes, + pub created: DateTime, +} + +#[derive(Deserialize, Serialize)] +pub struct GetOAuthClientsRequest { + pub ids: Vec, +} + +#[derive(Deserialize, Serialize)] +pub struct DeleteOAuthClientQueryParam { + pub client_id: OAuthClientId, +} + +impl From for OAuthClient { + fn from(value: DBOAuthClient) -> Self { + Self { + id: value.id.into(), + name: value.name, + icon_url: value.icon_url, + max_scopes: value.max_scopes, + redirect_uris: value.redirect_uris.into_iter().map(|r| r.into()).collect(), + created_by: value.created_by.into(), + } + } +} + +impl From for OAuthRedirectUri { + fn from(value: DBOAuthRedirectUri) -> Self { + Self { + id: value.id.into(), + client_id: value.client_id.into(), + uri: value.uri, + } + } +} + +impl From for OAuthClientAuthorization { + fn from(value: DBOAuthClientAuthorization) -> Self { + Self { + id: value.id.into(), + app_id: value.client_id.into(), + user_id: value.user_id.into(), + scopes: value.scopes, + created: value.created, + } + } +} diff --git a/src/models/pats.rs b/src/models/pats.rs index 6ac2afc8d..07a58692b 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -103,6 +103,9 @@ bitflags::bitflags! { // delete an organization const ORGANIZATION_DELETE = 1 << 38; + // only accessible by modrinth-issued sessions + const SESSION_ACCESS = 1 << 39; + const NONE = 0b0; } } @@ -118,6 +121,7 @@ impl Scopes { | Scopes::PAT_DELETE | Scopes::SESSION_READ | Scopes::SESSION_DELETE + | Scopes::SESSION_ACCESS | Scopes::USER_AUTH_WRITE | Scopes::USER_DELETE | Scopes::PERFORM_ANALYTICS @@ -126,6 +130,19 @@ impl Scopes { pub fn is_restricted(&self) -> bool { self.intersects(Self::restricted()) } + + pub fn parse_from_oauth_scopes(scopes: &str) -> Result { + let scopes = scopes.replace(' ', "|").replace("%20", "|"); + bitflags::parser::from_str(&scopes) + } + + pub fn to_postgres(&self) -> i64 { + self.bits() as i64 + } + + pub fn from_postgres(value: i64) -> Self { + Self::from_bits(value as u64).unwrap_or(Scopes::NONE) + } } #[derive(Serialize, Deserialize)] @@ -161,3 +178,64 @@ impl PersonalAccessToken { } } } + +#[cfg(test)] +mod test { + use super::*; + use itertools::Itertools; + + #[test] + fn test_parse_from_oauth_scopes_well_formed() { + let raw = "USER_READ_EMAIL SESSION_READ ORGANIZATION_CREATE"; + let expected = Scopes::USER_READ_EMAIL | Scopes::SESSION_READ | Scopes::ORGANIZATION_CREATE; + + let parsed = Scopes::parse_from_oauth_scopes(raw).unwrap(); + + assert_same_flags(expected, parsed); + } + + #[test] + fn test_parse_from_oauth_scopes_empty() { + let raw = ""; + let expected = Scopes::empty(); + + let parsed = Scopes::parse_from_oauth_scopes(raw).unwrap(); + + assert_same_flags(expected, parsed); + } + + #[test] + fn test_parse_from_oauth_scopes_invalid_scopes() { + let raw = "notascope"; + + let parsed = Scopes::parse_from_oauth_scopes(raw); + + assert!(parsed.is_err()); + } + + #[test] + fn test_parse_from_oauth_scopes_invalid_separator() { + let raw = "USER_READ_EMAIL & SESSION_READ"; + + let parsed = Scopes::parse_from_oauth_scopes(raw); + + assert!(parsed.is_err()); + } + + #[test] + fn test_parse_from_oauth_scopes_url_encoded() { + let raw = urlencoding::encode("PAT_WRITE COLLECTION_DELETE").to_string(); + let expected = Scopes::PAT_WRITE | Scopes::COLLECTION_DELETE; + + let parsed = Scopes::parse_from_oauth_scopes(&raw).unwrap(); + + assert_same_flags(expected, parsed); + } + + fn assert_same_flags(expected: Scopes, actual: Scopes) { + assert_eq!( + expected.iter_names().map(|(name, _)| name).collect_vec(), + actual.iter_names().map(|(name, _)| name).collect_vec() + ); + } +} diff --git a/src/queue/session.rs b/src/queue/session.rs index 8948810db..ee4568a47 100644 --- a/src/queue/session.rs +++ b/src/queue/session.rs @@ -1,9 +1,10 @@ use crate::auth::session::SessionMetadata; use crate::database::models::pat_item::PersonalAccessToken; use crate::database::models::session_item::Session; -use crate::database::models::{DatabaseError, PatId, SessionId, UserId}; +use crate::database::models::{DatabaseError, OAuthAccessTokenId, PatId, SessionId, UserId}; use crate::database::redis::RedisPool; use chrono::Utc; +use itertools::Itertools; use sqlx::PgPool; use std::collections::{HashMap, HashSet}; use tokio::sync::Mutex; @@ -11,6 +12,7 @@ use tokio::sync::Mutex; pub struct AuthQueue { session_queue: Mutex>, pat_queue: Mutex>, + oauth_access_token_queue: Mutex>, } impl Default for AuthQueue { @@ -25,6 +27,7 @@ impl AuthQueue { AuthQueue { session_queue: Mutex::new(HashMap::with_capacity(1000)), pat_queue: Mutex::new(HashSet::with_capacity(1000)), + oauth_access_token_queue: Mutex::new(HashSet::with_capacity(1000)), } } pub async fn add_session(&self, id: SessionId, metadata: SessionMetadata) { @@ -35,6 +38,10 @@ impl AuthQueue { self.pat_queue.lock().await.insert(id); } + pub async fn add_oauth_access_token(&self, id: crate::database::models::OAuthAccessTokenId) { + self.oauth_access_token_queue.lock().await.insert(id); + } + pub async fn take_sessions(&self) -> HashMap { let mut queue = self.session_queue.lock().await; let len = queue.len(); @@ -42,8 +49,8 @@ impl AuthQueue { std::mem::replace(&mut queue, HashMap::with_capacity(len)) } - pub async fn take_pats(&self) -> HashSet { - let mut queue = self.pat_queue.lock().await; + pub async fn take_hashset(queue: &Mutex>) -> HashSet { + let mut queue = queue.lock().await; let len = queue.len(); std::mem::replace(&mut queue, HashSet::with_capacity(len)) @@ -51,9 +58,13 @@ impl AuthQueue { pub async fn index(&self, pool: &PgPool, redis: &RedisPool) -> Result<(), DatabaseError> { let session_queue = self.take_sessions().await; - let pat_queue = self.take_pats().await; + let pat_queue = Self::take_hashset(&self.pat_queue).await; + let oauth_access_token_queue = Self::take_hashset(&self.oauth_access_token_queue).await; - if !session_queue.is_empty() || !pat_queue.is_empty() { + if !session_queue.is_empty() + || !pat_queue.is_empty() + || !oauth_access_token_queue.is_empty() + { let mut transaction = pool.begin().await?; let mut clear_cache_sessions = Vec::new(); @@ -102,29 +113,51 @@ impl AuthQueue { Session::clear_cache(clear_cache_sessions, redis).await?; - let mut clear_cache_pats = Vec::new(); - - for id in pat_queue { - clear_cache_pats.push((Some(id), None, None)); - - sqlx::query!( - " - UPDATE pats - SET last_used = $2 - WHERE (id = $1) - ", - id as PatId, - Utc::now(), - ) - .execute(&mut *transaction) - .await?; - } + let ids = pat_queue.iter().map(|id| id.0).collect_vec(); + let clear_cache_pats = pat_queue + .into_iter() + .map(|id| (Some(id), None, None)) + .collect_vec(); + sqlx::query!( + " + UPDATE pats + SET last_used = $2 + WHERE id IN + (SELECT * FROM UNNEST($1::bigint[])) + ", + &ids[..], + Utc::now(), + ) + .execute(&mut *transaction) + .await?; PersonalAccessToken::clear_cache(clear_cache_pats, redis).await?; + update_oauth_access_token_last_used(oauth_access_token_queue, &mut transaction).await?; + transaction.commit().await?; } Ok(()) } } + +async fn update_oauth_access_token_last_used( + oauth_access_token_queue: HashSet, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), DatabaseError> { + let ids = oauth_access_token_queue.iter().map(|id| id.0).collect_vec(); + sqlx::query!( + " + UPDATE oauth_access_tokens + SET last_used = $2 + WHERE id IN + (SELECT * FROM UNNEST($1::bigint[])) + ", + &ids[..], + Utc::now() + ) + .execute(&mut **transaction) + .await?; + Ok(()) +} diff --git a/src/routes/v3/mod.rs b/src/routes/v3/mod.rs index ddfb05e5f..d90429c28 100644 --- a/src/routes/v3/mod.rs +++ b/src/routes/v3/mod.rs @@ -1,13 +1,17 @@ pub use super::ApiError; -use crate::util::cors::default_cors; +use crate::{auth::oauth, util::cors::default_cors}; use actix_web::{web, HttpResponse}; use serde_json::json; +pub mod oauth_clients; + pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( web::scope("v3") .wrap(default_cors()) - .route("", web::get().to(hello_world)), + .route("", web::get().to(hello_world)) + .configure(oauth::config) + .configure(oauth_clients::config), ); } diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs new file mode 100644 index 000000000..5f3839b60 --- /dev/null +++ b/src/routes/v3/oauth_clients.rs @@ -0,0 +1,444 @@ +use std::{collections::HashSet, fmt::Display}; + +use actix_web::{ + delete, get, patch, post, + web::{self, scope}, + HttpRequest, HttpResponse, +}; +use chrono::Utc; +use itertools::Itertools; +use rand::{distributions::Alphanumeric, Rng, SeedableRng}; +use rand_chacha::ChaCha20Rng; +use serde::{Deserialize, Serialize}; +use sqlx::PgPool; +use validator::Validate; + +use super::ApiError; +use crate::{ + auth::checks::ValidateAllAuthorized, models::oauth_clients::DeleteOAuthClientQueryParam, +}; +use crate::{ + auth::{checks::ValidateAuthorized, get_user_from_headers}, + database::{ + models::{ + generate_oauth_client_id, generate_oauth_redirect_id, + oauth_client_authorization_item::OAuthClientAuthorization, + oauth_client_item::{OAuthClient, OAuthRedirectUri}, + DatabaseError, OAuthClientId, User, + }, + redis::RedisPool, + }, + models::{ + self, + oauth_clients::{GetOAuthClientsRequest, OAuthClientCreationResult}, + pats::Scopes, + }, + queue::session::AuthQueue, + routes::v2::project_creation::CreateError, + util::validate::validation_errors_to_string, +}; + +use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; +use crate::models::ids::OAuthClientId as ApiOAuthClientId; + +pub fn config(cfg: &mut web::ServiceConfig) { + cfg.service(get_user_clients); + cfg.service( + scope("oauth") + .service(oauth_client_create) + .service(oauth_client_edit) + .service(oauth_client_delete) + .service(get_client) + .service(get_clients) + .service(get_user_oauth_authorizations) + .service(revoke_oauth_authorization), + ); +} + +#[get("user/{user_id}/oauth_apps")] +pub async fn get_user_clients( + req: HttpRequest, + info: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::SESSION_ACCESS]), + ) + .await? + .1; + + let target_user = User::get(&info.into_inner(), &**pool, &redis).await?; + + if let Some(target_user) = target_user { + let clients = OAuthClient::get_all_user_clients(target_user.id, &**pool).await?; + clients + .iter() + .validate_all_authorized(Some(¤t_user))?; + + let response = clients + .into_iter() + .map(models::oauth_clients::OAuthClient::from) + .collect_vec(); + + Ok(HttpResponse::Ok().json(response)) + } else { + Ok(HttpResponse::NotFound().body("")) + } +} + +#[get("app/{id}")] +pub async fn get_client( + req: HttpRequest, + id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let clients = get_clients_inner(&[id.into_inner()], req, pool, redis, session_queue).await?; + if let Some(client) = clients.into_iter().next() { + Ok(HttpResponse::Ok().json(client)) + } else { + Ok(HttpResponse::NotFound().body("")) + } +} + +#[get("apps")] +pub async fn get_clients( + req: HttpRequest, + info: web::Json, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let clients = + get_clients_inner(&info.into_inner().ids, req, pool, redis, session_queue).await?; + Ok(HttpResponse::Ok().json(clients)) +} + +#[derive(Deserialize, Validate)] +pub struct NewOAuthApp { + #[validate( + custom(function = "crate::util::validate::validate_name"), + length(min = 3, max = 255) + )] + pub name: String, + + #[validate( + custom(function = "crate::util::validate::validate_url"), + length(max = 255) + )] + pub icon_url: Option, + + #[validate(custom(function = "crate::util::validate::validate_no_restricted_scopes"))] + pub max_scopes: Scopes, + + pub redirect_uris: Vec, +} + +#[post("app")] +pub async fn oauth_client_create<'a>( + req: HttpRequest, + new_oauth_app: web::Json, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::SESSION_ACCESS]), + ) + .await? + .1; + + new_oauth_app + .validate() + .map_err(|e| CreateError::ValidationError(validation_errors_to_string(e, None)))?; + + let mut transaction = pool.begin().await?; + + let client_id = generate_oauth_client_id(&mut transaction).await?; + + let client_secret = generate_oauth_client_secret(); + let client_secret_hash = DBOAuthClient::hash_secret(&client_secret); + + let redirect_uris = + create_redirect_uris(&new_oauth_app.redirect_uris, client_id, &mut transaction).await?; + + let client = OAuthClient { + id: client_id, + icon_url: new_oauth_app.icon_url.clone(), + max_scopes: new_oauth_app.max_scopes, + name: new_oauth_app.name.clone(), + redirect_uris, + created: Utc::now(), + created_by: current_user.id.into(), + secret_hash: client_secret_hash, + }; + client.clone().insert(&mut transaction).await?; + + transaction.commit().await?; + + let client = models::oauth_clients::OAuthClient::from(client); + + Ok(HttpResponse::Ok().json(OAuthClientCreationResult { + client, + client_secret, + })) +} + +#[delete("app/{id}")] +pub async fn oauth_client_delete<'a>( + req: HttpRequest, + client_id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::SESSION_ACCESS]), + ) + .await? + .1; + + let client = OAuthClient::get(client_id.into_inner().into(), &**pool).await?; + if let Some(client) = client { + client.validate_authorized(Some(¤t_user))?; + OAuthClient::remove(client.id, &**pool).await?; + + Ok(HttpResponse::NoContent().body("")) + } else { + Ok(HttpResponse::NotFound().body("")) + } +} + +#[derive(Serialize, Deserialize, Validate)] +pub struct OAuthClientEdit { + #[validate( + custom(function = "crate::util::validate::validate_name"), + length(min = 3, max = 255) + )] + pub name: Option, + + #[validate( + custom(function = "crate::util::validate::validate_url"), + length(max = 255) + )] + pub icon_url: Option>, + + pub max_scopes: Option, + + #[validate(length(min = 1))] + pub redirect_uris: Option>, +} + +#[patch("app/{id}")] +pub async fn oauth_client_edit( + req: HttpRequest, + client_id: web::Path, + client_updates: web::Json, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::SESSION_ACCESS]), + ) + .await? + .1; + + client_updates + .validate() + .map_err(|e| ApiError::Validation(validation_errors_to_string(e, None)))?; + + if client_updates.icon_url.is_none() + && client_updates.name.is_none() + && client_updates.max_scopes.is_none() + { + return Err(ApiError::InvalidInput("No changes provided".to_string())); + } + + if let Some(existing_client) = OAuthClient::get(client_id.into_inner().into(), &**pool).await? { + existing_client.validate_authorized(Some(¤t_user))?; + + let mut updated_client = existing_client.clone(); + let OAuthClientEdit { + name, + icon_url, + max_scopes, + redirect_uris, + } = client_updates.into_inner(); + if let Some(name) = name { + updated_client.name = name; + } + + if let Some(icon_url) = icon_url { + updated_client.icon_url = icon_url; + } + + if let Some(max_scopes) = max_scopes { + updated_client.max_scopes = max_scopes; + } + + let mut transaction = pool.begin().await?; + updated_client + .update_editable_fields(&mut *transaction) + .await?; + + if let Some(redirects) = redirect_uris { + edit_redirects(redirects, &existing_client, &mut transaction).await?; + } + + transaction.commit().await?; + + Ok(HttpResponse::Ok().body("")) + } else { + Ok(HttpResponse::NotFound().body("")) + } +} + +#[get("authorizations")] +pub async fn get_user_oauth_authorizations( + req: HttpRequest, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::SESSION_ACCESS]), + ) + .await? + .1; + + let authorizations = + OAuthClientAuthorization::get_all_for_user(current_user.id.into(), &**pool).await?; + + let mapped: Vec = + authorizations.into_iter().map(|a| a.into()).collect_vec(); + + Ok(HttpResponse::Ok().json(mapped)) +} + +#[delete("authorizations")] +pub async fn revoke_oauth_authorization( + req: HttpRequest, + info: web::Query, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::SESSION_ACCESS]), + ) + .await? + .1; + + OAuthClientAuthorization::remove(info.client_id.into(), current_user.id.into(), &**pool) + .await?; + + Ok(HttpResponse::Ok().body("")) +} + +fn generate_oauth_client_secret() -> String { + ChaCha20Rng::from_entropy() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect::() +} + +async fn create_redirect_uris( + uri_strings: impl IntoIterator, + client_id: OAuthClientId, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result, DatabaseError> { + let mut redirect_uris = vec![]; + for uri in uri_strings.into_iter() { + let id = generate_oauth_redirect_id(transaction).await?; + redirect_uris.push(OAuthRedirectUri { + id, + client_id, + uri: uri.to_string(), + }); + } + + Ok(redirect_uris) +} + +async fn edit_redirects( + redirects: Vec, + existing_client: &OAuthClient, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), DatabaseError> { + let updated_redirects: HashSet = redirects.into_iter().collect(); + let original_redirects: HashSet = existing_client + .redirect_uris + .iter() + .map(|r| r.uri.to_string()) + .collect(); + + let redirects_to_add = create_redirect_uris( + updated_redirects.difference(&original_redirects), + existing_client.id, + &mut *transaction, + ) + .await?; + OAuthClient::insert_redirect_uris(&redirects_to_add, &mut **transaction).await?; + + let mut redirects_to_remove = existing_client.redirect_uris.clone(); + redirects_to_remove.retain(|r| !updated_redirects.contains(&r.uri)); + OAuthClient::remove_redirect_uris(redirects_to_remove.iter().map(|r| r.id), &mut **transaction) + .await?; + + Ok(()) +} + +pub async fn get_clients_inner( + ids: &[ApiOAuthClientId], + req: HttpRequest, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result, ApiError> { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::SESSION_ACCESS]), + ) + .await? + .1; + + let ids: Vec = ids.iter().map(|i| (*i).into()).collect(); + let clients = OAuthClient::get_many(&ids, &**pool).await?; + clients + .iter() + .validate_all_authorized(Some(¤t_user))?; + + Ok(clients.into_iter().map(|c| c.into()).collect_vec()) +} diff --git a/src/util/validate.rs b/src/util/validate.rs index 474bdf7af..7ebac27c9 100644 --- a/src/util/validate.rs +++ b/src/util/validate.rs @@ -3,6 +3,8 @@ use lazy_static::lazy_static; use regex::Regex; use validator::{ValidationErrors, ValidationErrorsKind}; +use crate::models::pats::Scopes; + lazy_static! { pub static ref RE_URL_SAFE: Regex = Regex::new(r#"^[a-zA-Z0-9!@$()`.+,_"-]*$"#).unwrap(); } @@ -91,6 +93,16 @@ pub fn validate_url(value: &str) -> Result<(), validator::ValidationError> { Ok(()) } +pub fn validate_no_restricted_scopes(value: &Scopes) -> Result<(), validator::ValidationError> { + if value.is_restricted() { + return Err(validator::ValidationError::new( + "Restricted scopes not allowed", + )); + } + + Ok(()) +} + pub fn validate_name(value: &str) -> Result<(), validator::ValidationError> { if value.trim().is_empty() { return Err(validator::ValidationError::new( diff --git a/tests/common/api_v2/project.rs b/tests/common/api_v2/project.rs index d8f5f8580..7b3af1329 100644 --- a/tests/common/api_v2/project.rs +++ b/tests/common/api_v2/project.rs @@ -29,7 +29,7 @@ impl ApiV2 { .set_multipart(creation_data.segment_data) .to_request(); let resp = self.call(req).await; - assert_status(resp, StatusCode::OK); + assert_status(&resp, StatusCode::OK); // Approve as a moderator. let req = TestRequest::patch() @@ -42,7 +42,7 @@ impl ApiV2 { )) .to_request(); let resp = self.call(req).await; - assert_status(resp, StatusCode::NO_CONTENT); + assert_status(&resp, StatusCode::NO_CONTENT); let project = self .get_project_deserialized(&creation_data.slug, pat) @@ -82,16 +82,20 @@ impl ApiV2 { test::read_body_json(resp).await } + pub async fn get_user_projects(&self, user_id_or_username: &str, pat: &str) -> ServiceResponse { + let req = test::TestRequest::get() + .uri(&format!("/v2/user/{}/projects", user_id_or_username)) + .append_header(("Authorization", pat)) + .to_request(); + self.call(req).await + } + pub async fn get_user_projects_deserialized( &self, user_id_or_username: &str, pat: &str, ) -> Vec { - let req = test::TestRequest::get() - .uri(&format!("/v2/user/{}/projects", user_id_or_username)) - .append_header(("Authorization", pat)) - .to_request(); - let resp = self.call(req).await; + let resp = self.get_user_projects(user_id_or_username, pat).await; assert_eq!(resp.status(), 200); test::read_body_json(resp).await } diff --git a/tests/common/api_v2/team.rs b/tests/common/api_v2/team.rs index f1d6ef735..1a772053f 100644 --- a/tests/common/api_v2/team.rs +++ b/tests/common/api_v2/team.rs @@ -1,3 +1,4 @@ +use actix_http::StatusCode; use actix_web::{dev::ServiceResponse, test}; use labrinth::models::{ notifications::Notification, @@ -5,6 +6,8 @@ use labrinth::models::{ }; use serde_json::json; +use crate::common::asserts::assert_status; + use super::ApiV2; impl ApiV2 { @@ -114,16 +117,21 @@ impl ApiV2 { self.call(req).await } + pub async fn get_user_notifications(&self, user_id: &str, pat: &str) -> ServiceResponse { + let req = test::TestRequest::get() + .uri(&format!("/v2/user/{user_id}/notifications")) + .append_header(("Authorization", pat)) + .to_request(); + self.call(req).await + } + pub async fn get_user_notifications_deserialized( &self, user_id: &str, pat: &str, ) -> Vec { - let req = test::TestRequest::get() - .uri(&format!("/v2/user/{user_id}/notifications")) - .append_header(("Authorization", pat)) - .to_request(); - let resp = self.call(req).await; + let resp = self.get_user_notifications(user_id, pat).await; + assert_status(&resp, StatusCode::OK); test::read_body_json(resp).await } diff --git a/tests/common/api_v3/mod.rs b/tests/common/api_v3/mod.rs new file mode 100644 index 000000000..2155aa3c5 --- /dev/null +++ b/tests/common/api_v3/mod.rs @@ -0,0 +1,19 @@ +#![allow(dead_code)] + +use super::environment::LocalService; +use actix_web::dev::ServiceResponse; +use std::rc::Rc; + +pub mod oauth; +pub mod oauth_clients; + +#[derive(Clone)] +pub struct ApiV3 { + pub test_app: Rc, +} + +impl ApiV3 { + pub async fn call(&self, req: actix_http::Request) -> ServiceResponse { + self.test_app.call(req).await.unwrap() + } +} diff --git a/tests/common/api_v3/oauth.rs b/tests/common/api_v3/oauth.rs new file mode 100644 index 000000000..6212dffac --- /dev/null +++ b/tests/common/api_v3/oauth.rs @@ -0,0 +1,156 @@ +use std::collections::HashMap; + +use actix_http::StatusCode; +use actix_web::{ + dev::ServiceResponse, + test::{self, TestRequest}, +}; +use labrinth::auth::oauth::{ + OAuthClientAccessRequest, RespondToOAuthClientScopes, TokenRequest, TokenResponse, +}; +use reqwest::header::{AUTHORIZATION, LOCATION}; + +use crate::common::asserts::assert_status; + +use super::ApiV3; + +impl ApiV3 { + pub async fn complete_full_authorize_flow( + &self, + client_id: &str, + client_secret: &str, + scope: Option<&str>, + redirect_uri: Option<&str>, + state: Option<&str>, + user_pat: &str, + ) -> String { + let auth_resp = self + .oauth_authorize(client_id, scope, redirect_uri, state, user_pat) + .await; + let flow_id = get_authorize_accept_flow_id(auth_resp).await; + let redirect_resp = self.oauth_accept(&flow_id, user_pat).await; + let auth_code = get_auth_code_from_redirect_params(&redirect_resp).await; + let token_resp = self + .oauth_token(auth_code, None, client_id.to_string(), client_secret) + .await; + get_access_token(token_resp).await + } + + pub async fn oauth_authorize( + &self, + client_id: &str, + scope: Option<&str>, + redirect_uri: Option<&str>, + state: Option<&str>, + pat: &str, + ) -> ServiceResponse { + let uri = generate_authorize_uri(client_id, scope, redirect_uri, state); + let req = TestRequest::get() + .uri(&uri) + .append_header((AUTHORIZATION, pat)) + .to_request(); + self.call(req).await + } + + pub async fn oauth_accept(&self, flow: &str, pat: &str) -> ServiceResponse { + self.call( + TestRequest::post() + .uri("/v3/auth/oauth/accept") + .append_header((AUTHORIZATION, pat)) + .set_json(RespondToOAuthClientScopes { + flow: flow.to_string(), + }) + .to_request(), + ) + .await + } + + pub async fn oauth_reject(&self, flow: &str, pat: &str) -> ServiceResponse { + self.call( + TestRequest::post() + .uri("/v3/auth/oauth/reject") + .append_header((AUTHORIZATION, pat)) + .set_json(RespondToOAuthClientScopes { + flow: flow.to_string(), + }) + .to_request(), + ) + .await + } + + pub async fn oauth_token( + &self, + auth_code: String, + original_redirect_uri: Option, + client_id: String, + client_secret: &str, + ) -> ServiceResponse { + self.call( + TestRequest::post() + .uri("/v3/auth/oauth/token") + .append_header((AUTHORIZATION, client_secret)) + .set_form(TokenRequest { + grant_type: "authorization_code".to_string(), + code: auth_code, + redirect_uri: original_redirect_uri, + client_id: serde_json::from_str(&format!("\"{}\"", client_id)).unwrap(), + }) + .to_request(), + ) + .await + } +} + +pub fn generate_authorize_uri( + client_id: &str, + scope: Option<&str>, + redirect_uri: Option<&str>, + state: Option<&str>, +) -> String { + format!( + "/v3/auth/oauth/authorize?client_id={}{}{}{}", + urlencoding::encode(client_id), + optional_query_param("redirect_uri", redirect_uri), + optional_query_param("scope", scope), + optional_query_param("state", state), + ) + .to_string() +} + +pub async fn get_authorize_accept_flow_id(response: ServiceResponse) -> String { + assert_status(&response, StatusCode::OK); + test::read_body_json::(response) + .await + .flow_id +} + +pub async fn get_auth_code_from_redirect_params(response: &ServiceResponse) -> String { + assert_status(response, StatusCode::FOUND); + let query_params = get_redirect_location_query_params(response); + query_params.get("code").unwrap().to_string() +} + +pub async fn get_access_token(response: ServiceResponse) -> String { + assert_status(&response, StatusCode::OK); + test::read_body_json::(response) + .await + .access_token +} + +pub fn get_redirect_location_query_params( + response: &ServiceResponse, +) -> actix_web::web::Query> { + let redirect_location = response.headers().get(LOCATION).unwrap().to_str().unwrap(); + actix_web::web::Query::>::from_query( + redirect_location.split_once('?').unwrap().1, + ) + .unwrap() +} + +fn optional_query_param(key: &str, value: Option<&str>) -> String { + if let Some(val) = value { + format!("&{key}={}", urlencoding::encode(val)) + } else { + "".to_string() + } +} diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs new file mode 100644 index 000000000..a08ad1d6f --- /dev/null +++ b/tests/common/api_v3/oauth_clients.rs @@ -0,0 +1,107 @@ +use actix_http::StatusCode; +use actix_web::{ + dev::ServiceResponse, + test::{self, TestRequest}, +}; +use labrinth::{ + models::{ + oauth_clients::{OAuthClient, OAuthClientAuthorization}, + pats::Scopes, + }, + routes::v3::oauth_clients::OAuthClientEdit, +}; +use reqwest::header::AUTHORIZATION; +use serde_json::json; + +use crate::common::asserts::assert_status; + +use super::ApiV3; + +impl ApiV3 { + pub async fn add_oauth_client( + &self, + name: String, + max_scopes: Scopes, + redirect_uris: Vec, + pat: &str, + ) -> ServiceResponse { + let max_scopes = max_scopes.bits(); + let req = TestRequest::post() + .uri("/v3/oauth/app") + .append_header((AUTHORIZATION, pat)) + .set_json(json!({ + "name": name, + "max_scopes": max_scopes, + "redirect_uris": redirect_uris + })) + .to_request(); + + self.call(req).await + } + + pub async fn get_user_oauth_clients(&self, user_id: &str, pat: &str) -> Vec { + let req = TestRequest::get() + .uri(&format!("/v3/user/{}/oauth_apps", user_id)) + .append_header((AUTHORIZATION, pat)) + .to_request(); + let resp = self.call(req).await; + assert_status(&resp, StatusCode::OK); + + test::read_body_json(resp).await + } + + pub async fn get_oauth_client(&self, client_id: String, pat: &str) -> ServiceResponse { + let req = TestRequest::get() + .uri(&format!("/v3/oauth/app/{}", client_id)) + .append_header((AUTHORIZATION, pat)) + .to_request(); + + self.call(req).await + } + + pub async fn edit_oauth_client( + &self, + client_id: &str, + edit: OAuthClientEdit, + pat: &str, + ) -> ServiceResponse { + let req = TestRequest::patch() + .uri(&format!("/v3/oauth/app/{}", urlencoding::encode(client_id))) + .set_json(edit) + .append_header((AUTHORIZATION, pat)) + .to_request(); + + self.call(req).await + } + + pub async fn delete_oauth_client(&self, client_id: &str, pat: &str) -> ServiceResponse { + let req = TestRequest::delete() + .uri(&format!("/v3/oauth/app/{}", client_id)) + .append_header((AUTHORIZATION, pat)) + .to_request(); + + self.call(req).await + } + + pub async fn revoke_oauth_authorization(&self, client_id: &str, pat: &str) -> ServiceResponse { + let req = TestRequest::delete() + .uri(&format!( + "/v3/oauth/authorizations?client_id={}", + urlencoding::encode(client_id) + )) + .append_header((AUTHORIZATION, pat)) + .to_request(); + self.call(req).await + } + + pub async fn get_user_oauth_authorizations(&self, pat: &str) -> Vec { + let req = TestRequest::get() + .uri("/v3/oauth/authorizations") + .append_header((AUTHORIZATION, pat)) + .to_request(); + let resp = self.call(req).await; + assert_status(&resp, StatusCode::OK); + + test::read_body_json(resp).await + } +} diff --git a/tests/common/asserts.rs b/tests/common/asserts.rs index c98dbd39d..3c7f585ac 100644 --- a/tests/common/asserts.rs +++ b/tests/common/asserts.rs @@ -1,3 +1,12 @@ -pub fn assert_status(response: actix_web::dev::ServiceResponse, status: actix_http::StatusCode) { +#![allow(dead_code)] + +pub fn assert_status(response: &actix_web::dev::ServiceResponse, status: actix_http::StatusCode) { assert_eq!(response.status(), status, "{:#?}", response.response()); } + +pub fn assert_any_status_except( + response: &actix_web::dev::ServiceResponse, + status: actix_http::StatusCode, +) { + assert_ne!(response.status(), status, "{:#?}", response.response()); +} diff --git a/tests/common/database.rs b/tests/common/database.rs index e30bd0773..988cc0284 100644 --- a/tests/common/database.rs +++ b/tests/common/database.rs @@ -7,6 +7,8 @@ use url::Url; use crate::common::{dummy_data, environment::TestEnvironment}; +use super::dummy_data::DUMMY_DATA_UPDATE; + // The dummy test database adds a fair bit of 'dummy' data to test with. // Some constants are used to refer to that data, and are described here. // The rest can be accessed in the TestEnvironment 'dummy' field. @@ -119,11 +121,7 @@ impl TemporaryDatabase { .await .unwrap(); if db_exists.is_none() { - let create_db_query = format!("CREATE DATABASE {TEMPLATE_DATABASE_NAME}"); - sqlx::query(&create_db_query) - .execute(&main_pool) - .await - .expect("Database creation failed"); + create_template_database(&main_pool).await; } // Switch to template @@ -135,30 +133,52 @@ impl TemporaryDatabase { .await .expect("Connection to database failed"); - // Run migrations on the template - let migrations = sqlx::migrate!("./migrations"); - migrations.run(&pool).await.expect("Migrations failed"); - // Check if dummy data exists- a fake 'dummy_data' table is created if it does - let dummy_data_exists: bool = + let mut dummy_data_exists: bool = sqlx::query_scalar("SELECT to_regclass('dummy_data') IS NOT NULL") .fetch_one(&pool) .await .unwrap(); + if dummy_data_exists { + // Check if the dummy data needs to be updated + let dummy_data_update = + sqlx::query_scalar::<_, i64>("SELECT update_id FROM dummy_data") + .fetch_optional(&pool) + .await + .unwrap(); + let needs_update = !dummy_data_update.is_some_and(|d| d == DUMMY_DATA_UPDATE); + if needs_update { + println!("Dummy data updated, so template DB tables will be dropped and re-created"); + // Drop all tables in the database so they can be re-created and later filled with updated dummy data + sqlx::query("DROP SCHEMA public CASCADE;") + .execute(&pool) + .await + .unwrap(); + sqlx::query("CREATE SCHEMA public;") + .execute(&pool) + .await + .unwrap(); + dummy_data_exists = false; + } + } + + // Run migrations on the template + let migrations = sqlx::migrate!("./migrations"); + migrations.run(&pool).await.expect("Migrations failed"); + if !dummy_data_exists { // Add dummy data let temporary_test_env = TestEnvironment::build_with_db(TemporaryDatabase { pool: pool.clone(), database_name: TEMPLATE_DATABASE_NAME.to_string(), - redis_pool: RedisPool::new(None), + redis_pool: RedisPool::new(Some(generate_random_name("test_template_"))), }) .await; dummy_data::add_dummy_data(&temporary_test_env).await; + temporary_test_env.db.pool.close().await; } pool.close().await; - - // Switch back to main database (as we cant create from template while connected to it) - let pool = PgPool::connect(url.as_str()).await.unwrap(); + drop(pool); // Create the temporary database from the template let create_db_query = format!( @@ -167,7 +187,7 @@ impl TemporaryDatabase { ); sqlx::query(&create_db_query) - .execute(&pool) + .execute(&main_pool) .await .expect("Database creation failed"); @@ -216,6 +236,14 @@ impl TemporaryDatabase { } } +async fn create_template_database(pool: &sqlx::Pool) { + let create_db_query = format!("CREATE DATABASE {TEMPLATE_DATABASE_NAME}"); + sqlx::query(&create_db_query) + .execute(pool) + .await + .expect("Database creation failed"); +} + // Appends a random 8-digit number to the end of the str pub fn generate_random_name(str: &str) -> String { let mut str = String::from(str); diff --git a/tests/common/dummy_data.rs b/tests/common/dummy_data.rs index 7f7e10a43..ed3b7f081 100644 --- a/tests/common/dummy_data.rs +++ b/tests/common/dummy_data.rs @@ -1,8 +1,11 @@ #![allow(dead_code)] +use actix_http::StatusCode; use actix_web::test::{self, TestRequest}; use labrinth::{ models::projects::Project, - models::{organizations::Organization, pats::Scopes, projects::Version}, + models::{ + oauth_clients::OAuthClient, organizations::Organization, pats::Scopes, projects::Version, + }, }; use serde_json::json; use sqlx::Executor; @@ -11,11 +14,14 @@ use crate::common::{actix::AppendsMultipart, database::USER_USER_PAT}; use super::{ actix::{MultipartSegment, MultipartSegmentData}, + asserts::assert_status, + database::USER_USER_ID, environment::TestEnvironment, + get_json_val_str, request_data::get_public_project_creation_data, }; -pub const DUMMY_DATA_UPDATE: i64 = 1; +pub const DUMMY_DATA_UPDATE: i64 = 3; #[allow(dead_code)] pub const DUMMY_CATEGORIES: &[&str] = &[ @@ -28,6 +34,8 @@ pub const DUMMY_CATEGORIES: &[&str] = &[ "optimization", ]; +pub const DUMMY_OAUTH_CLIENT_ALPHA_SECRET: &str = "abcdefghijklmnopqrstuvwxyz"; + #[allow(dead_code)] pub enum DummyJarFile { DummyProjectAlpha, @@ -43,16 +51,80 @@ pub enum DummyImage { #[derive(Clone)] pub struct DummyData { + /// Alpha project: + /// This is a dummy project created by USER user. + /// It's approved, listed, and visible to the public. pub project_alpha: DummyProjectAlpha, + + /// Beta project: + /// This is a dummy project created by USER user. + /// It's not approved, unlisted, and not visible to the public. pub project_beta: DummyProjectBeta, + + /// Zeta organization: + /// This is a dummy organization created by USER user. + /// There are no projects in it. pub organization_zeta: DummyOrganizationZeta, + + /// Alpha OAuth Client: + /// This is a dummy OAuth client created by USER user. + /// + /// All scopes are included in its max scopes + /// + /// It has one valid redirect URI + pub oauth_client_alpha: DummyOAuthClientAlpha, +} + +impl DummyData { + pub fn new( + project_alpha: Project, + project_alpha_version: Version, + project_beta: Project, + project_beta_version: Version, + organization_zeta: Organization, + oauth_client_alpha: OAuthClient, + ) -> Self { + DummyData { + project_alpha: DummyProjectAlpha { + team_id: project_alpha.team.to_string(), + project_id: project_alpha.id.to_string(), + project_slug: project_alpha.slug.unwrap(), + version_id: project_alpha_version.id.to_string(), + thread_id: project_alpha.thread_id.to_string(), + file_hash: project_alpha_version.files[0].hashes["sha1"].clone(), + }, + + project_beta: DummyProjectBeta { + team_id: project_beta.team.to_string(), + project_id: project_beta.id.to_string(), + project_slug: project_beta.slug.unwrap(), + version_id: project_beta_version.id.to_string(), + thread_id: project_beta.thread_id.to_string(), + file_hash: project_beta_version.files[0].hashes["sha1"].clone(), + }, + + organization_zeta: DummyOrganizationZeta { + organization_id: organization_zeta.id.to_string(), + team_id: organization_zeta.team_id.to_string(), + organization_title: organization_zeta.title, + }, + + oauth_client_alpha: DummyOAuthClientAlpha { + client_id: get_json_val_str(oauth_client_alpha.id), + client_secret: DUMMY_OAUTH_CLIENT_ALPHA_SECRET.to_string(), + valid_redirect_uri: oauth_client_alpha + .redirect_uris + .first() + .unwrap() + .uri + .clone(), + }, + } + } } #[derive(Clone)] pub struct DummyProjectAlpha { - // Alpha project: - // This is a dummy project created by USER user. - // It's approved, listed, and visible to the public. pub project_id: String, pub project_slug: String, pub version_id: String, @@ -63,9 +135,6 @@ pub struct DummyProjectAlpha { #[derive(Clone)] pub struct DummyProjectBeta { - // Beta project: - // This is a dummy project created by USER user. - // It's not approved, unlisted, and not visible to the public. pub project_id: String, pub project_slug: String, pub version_id: String, @@ -76,14 +145,18 @@ pub struct DummyProjectBeta { #[derive(Clone)] pub struct DummyOrganizationZeta { - // Zeta organization: - // This is a dummy organization created by USER user. - // There are no projects in it. pub organization_id: String, pub organization_title: String, pub team_id: String, } +#[derive(Clone)] +pub struct DummyOAuthClientAlpha { + pub client_id: String, + pub client_secret: String, + pub valid_redirect_uri: String, +} + pub async fn add_dummy_data(test_env: &TestEnvironment) -> DummyData { // Adds basic dummy data to the database directly with sql (user, pats) let pool = &test_env.db.pool.clone(); @@ -101,37 +174,22 @@ pub async fn add_dummy_data(test_env: &TestEnvironment) -> DummyData { let zeta_organization = add_organization_zeta(test_env).await; + let oauth_client_alpha = get_oauth_client_alpha(test_env).await; + sqlx::query("INSERT INTO dummy_data (update_id) VALUES ($1)") .bind(DUMMY_DATA_UPDATE) .execute(pool) .await .unwrap(); - DummyData { - project_alpha: DummyProjectAlpha { - team_id: alpha_project.team.to_string(), - project_id: alpha_project.id.to_string(), - project_slug: alpha_project.slug.unwrap(), - version_id: alpha_version.id.to_string(), - thread_id: alpha_project.thread_id.to_string(), - file_hash: alpha_version.files[0].hashes["sha1"].clone(), - }, - - project_beta: DummyProjectBeta { - team_id: beta_project.team.to_string(), - project_id: beta_project.id.to_string(), - project_slug: beta_project.slug.unwrap(), - version_id: beta_version.id.to_string(), - thread_id: beta_project.thread_id.to_string(), - file_hash: beta_version.files[0].hashes["sha1"].clone(), - }, - - organization_zeta: DummyOrganizationZeta { - organization_id: zeta_organization.id.to_string(), - team_id: zeta_organization.team_id.to_string(), - organization_title: zeta_organization.title, - }, - } + DummyData::new( + alpha_project, + alpha_version, + beta_project, + beta_version, + zeta_organization, + oauth_client_alpha, + ) } pub async fn get_dummy_data(test_env: &TestEnvironment) -> DummyData { @@ -139,31 +197,17 @@ pub async fn get_dummy_data(test_env: &TestEnvironment) -> DummyData { let (beta_project, beta_version) = get_project_beta(test_env).await; let zeta_organization = get_organization_zeta(test_env).await; - DummyData { - project_alpha: DummyProjectAlpha { - team_id: alpha_project.team.to_string(), - project_id: alpha_project.id.to_string(), - project_slug: alpha_project.slug.unwrap(), - version_id: alpha_version.id.to_string(), - thread_id: alpha_project.thread_id.to_string(), - file_hash: alpha_version.files[0].hashes["sha1"].clone(), - }, - project_beta: DummyProjectBeta { - team_id: beta_project.team.to_string(), - project_id: beta_project.id.to_string(), - project_slug: beta_project.slug.unwrap(), - version_id: beta_version.id.to_string(), - thread_id: beta_project.thread_id.to_string(), - file_hash: beta_version.files[0].hashes["sha1"].clone(), - }, + let oauth_client_alpha = get_oauth_client_alpha(test_env).await; - organization_zeta: DummyOrganizationZeta { - organization_id: zeta_organization.id.to_string(), - team_id: zeta_organization.team_id.to_string(), - organization_title: zeta_organization.title, - }, - } + DummyData::new( + alpha_project, + alpha_version, + beta_project, + beta_version, + zeta_organization, + oauth_client_alpha, + ) } pub async fn add_project_alpha(test_env: &TestEnvironment) -> (Project, Version) { @@ -282,6 +326,7 @@ pub async fn get_project_beta(test_env: &TestEnvironment) -> (Project, Version) .append_header(("Authorization", USER_USER_PAT)) .to_request(); let resp = test_env.call(req).await; + assert_status(&resp, StatusCode::OK); let project: Project = test::read_body_json(resp).await; // Get project's versions @@ -290,6 +335,7 @@ pub async fn get_project_beta(test_env: &TestEnvironment) -> (Project, Version) .append_header(("Authorization", USER_USER_PAT)) .to_request(); let resp = test_env.call(req).await; + assert_status(&resp, StatusCode::OK); let versions: Vec = test::read_body_json(resp).await; let version = versions.into_iter().next().unwrap(); @@ -308,6 +354,14 @@ pub async fn get_organization_zeta(test_env: &TestEnvironment) -> Organization { organization } +pub async fn get_oauth_client_alpha(test_env: &TestEnvironment) -> OAuthClient { + let oauth_clients = test_env + .v3 + .get_user_oauth_clients(USER_USER_ID, USER_USER_PAT) + .await; + oauth_clients.into_iter().next().unwrap() +} + impl DummyJarFile { pub fn filename(&self) -> String { match self { diff --git a/tests/common/environment.rs b/tests/common/environment.rs index abeaf730f..55fd82fad 100644 --- a/tests/common/environment.rs +++ b/tests/common/environment.rs @@ -4,6 +4,7 @@ use std::{rc::Rc, sync::Arc}; use super::{ api_v2::ApiV2, + api_v3::ApiV3, asserts::assert_status, database::{TemporaryDatabase, FRIEND_USER_ID, USER_USER_PAT}, dummy_data, @@ -34,6 +35,7 @@ pub struct TestEnvironment { test_app: Rc, // Rc as it's not Send pub db: TemporaryDatabase, pub v2: ApiV2, + pub v3: ApiV3, pub dummy: Option>, } @@ -56,6 +58,9 @@ impl TestEnvironment { v2: ApiV2 { test_app: test_app.clone(), }, + v3: ApiV3 { + test_app: test_app.clone(), + }, test_app, db, dummy: None, @@ -81,7 +86,27 @@ impl TestEnvironment { USER_USER_PAT, ) .await; - assert_status(resp, StatusCode::NO_CONTENT); + assert_status(&resp, StatusCode::NO_CONTENT); + } + + pub async fn assert_read_notifications_status( + &self, + user_id: &str, + pat: &str, + status_code: StatusCode, + ) { + let resp = self.v2.get_user_notifications(user_id, pat).await; + assert_status(&resp, status_code); + } + + pub async fn assert_read_user_projects_status( + &self, + user_id: &str, + pat: &str, + status_code: StatusCode, + ) { + let resp = self.v2.get_user_projects(user_id, pat).await; + assert_status(&resp, status_code); } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 39b3305ae..b2a317bb0 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -6,6 +6,7 @@ use self::database::TemporaryDatabase; pub mod actix; pub mod api_v2; +pub mod api_v3; pub mod asserts; pub mod database; pub mod dummy_data; @@ -42,3 +43,11 @@ pub async fn setup(db: &TemporaryDatabase) -> LabrinthConfig { maxmind_reader.clone(), ) } + +pub fn get_json_val_str(val: impl serde::Serialize) -> String { + serde_json::to_value(val) + .unwrap() + .as_str() + .unwrap() + .to_string() +} diff --git a/tests/files/dummy_data.sql b/tests/files/dummy_data.sql index 487397a5c..aaa8c1b7f 100644 --- a/tests/files/dummy_data.sql +++ b/tests/files/dummy_data.sql @@ -45,6 +45,25 @@ INSERT INTO categories (id, category, project_type) VALUES (106, 'mobs', 2), (107, 'optimization', 2); +-- Create dummy oauth client, secret_hash is SHA512 hash of full lowercase alphabet +INSERT INTO oauth_clients ( + id, + name, + icon_url, + max_scopes, + secret_hash, + created_by + ) +VALUES ( + 1, + 'oauth_client_alpha', + NULL, + $1, + '4dbff86cc2ca1bae1e16468a05cb9881c97f1753bce3619034898faa1aabe429955a1bf8ec483d7421fe3c1646613a59ed5441fb0f321389f77f48a879c7b1f1', + 3 + ); +INSERT INTO oauth_client_redirect_uris (id, client_id, uri) VALUES (1, 1, 'https://modrinth.com/oauth_callback'); + -- Create dummy data table to mark that this file has been run CREATE TABLE dummy_data ( update_id bigint PRIMARY KEY diff --git a/tests/oauth.rs b/tests/oauth.rs new file mode 100644 index 000000000..1ae59d32a --- /dev/null +++ b/tests/oauth.rs @@ -0,0 +1,292 @@ +use crate::common::{ + api_v3::oauth::get_redirect_location_query_params, database::FRIEND_USER_ID, + dummy_data::DummyOAuthClientAlpha, +}; +use actix_http::StatusCode; +use actix_web::test::{self}; +use common::{ + api_v3::oauth::{get_auth_code_from_redirect_params, get_authorize_accept_flow_id}, + asserts::{assert_any_status_except, assert_status}, + database::{FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT}, + environment::with_test_environment, +}; +use labrinth::auth::oauth::TokenResponse; +use reqwest::header::{CACHE_CONTROL, PRAGMA}; + +mod common; + +#[actix_rt::test] +async fn oauth_flow_happy_path() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { + valid_redirect_uri: base_redirect_uri, + client_id, + client_secret, + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + + // Initiate authorization + let redirect_uri = format!("{}?foo=bar", base_redirect_uri); + let original_state = "1234"; + let resp = env + .v3 + .oauth_authorize( + &client_id, + Some("USER_READ NOTIFICATION_READ"), + Some(&redirect_uri), + Some(original_state), + FRIEND_USER_PAT, + ) + .await; + assert_status(&resp, StatusCode::OK); + let flow_id = get_authorize_accept_flow_id(resp).await; + + // Accept the authorization request + let resp = env.v3.oauth_accept(&flow_id, FRIEND_USER_PAT).await; + assert_status(&resp, StatusCode::FOUND); + let query = get_redirect_location_query_params(&resp); + + let auth_code = query.get("code").unwrap(); + let state = query.get("state").unwrap(); + let foo_val = query.get("foo").unwrap(); + assert_eq!(state, original_state); + assert_eq!(foo_val, "bar"); + + // Get the token + let resp = env + .v3 + .oauth_token( + auth_code.to_string(), + Some(redirect_uri.clone()), + client_id.to_string(), + &client_secret, + ) + .await; + assert_status(&resp, StatusCode::OK); + assert_eq!(resp.headers().get(CACHE_CONTROL).unwrap(), "no-store"); + assert_eq!(resp.headers().get(PRAGMA).unwrap(), "no-cache"); + let token_resp: TokenResponse = test::read_body_json(resp).await; + + // Validate the token works + env.assert_read_notifications_status( + FRIEND_USER_ID, + &token_resp.access_token, + StatusCode::OK, + ) + .await; + }) + .await; +} + +#[actix_rt::test] +async fn oauth_authorize_for_already_authorized_scopes_returns_auth_code() { + with_test_environment(|env| async { + let DummyOAuthClientAlpha { client_id, .. } = env.dummy.unwrap().oauth_client_alpha.clone(); + + let resp = env + .v3 + .oauth_authorize( + &client_id, + Some("USER_READ NOTIFICATION_READ"), + None, + Some("1234"), + USER_USER_PAT, + ) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; + env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + + let resp = env + .v3 + .oauth_authorize( + &client_id, + Some("USER_READ"), + None, + Some("5678"), + USER_USER_PAT, + ) + .await; + assert_status(&resp, StatusCode::FOUND); + }) + .await; +} + +#[actix_rt::test] +async fn get_oauth_token_with_already_used_auth_code_fails() { + with_test_environment(|env| async { + let DummyOAuthClientAlpha { + client_id, + client_secret, + .. + } = env.dummy.unwrap().oauth_client_alpha.clone(); + + let resp = env + .v3 + .oauth_authorize(&client_id, None, None, None, USER_USER_PAT) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; + + let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + let auth_code = get_auth_code_from_redirect_params(&resp).await; + + let resp = env + .v3 + .oauth_token(auth_code.clone(), None, client_id.clone(), &client_secret) + .await; + assert_status(&resp, StatusCode::OK); + + let resp = env + .v3 + .oauth_token(auth_code, None, client_id, &client_secret) + .await; + assert_status(&resp, StatusCode::BAD_REQUEST); + }) + .await; +} + +#[actix_rt::test] +async fn authorize_with_broader_scopes_can_complete_flow() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { + client_id, + client_secret, + .. + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + + let first_access_token = env + .v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + Some("PROJECT_READ"), + None, + None, + USER_USER_PAT, + ) + .await; + let second_access_token = env + .v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + Some("PROJECT_READ NOTIFICATION_READ"), + None, + None, + USER_USER_PAT, + ) + .await; + + env.assert_read_notifications_status( + USER_USER_ID, + &first_access_token, + StatusCode::UNAUTHORIZED, + ) + .await; + env.assert_read_user_projects_status(USER_USER_ID, &first_access_token, StatusCode::OK) + .await; + + env.assert_read_notifications_status(USER_USER_ID, &second_access_token, StatusCode::OK) + .await; + env.assert_read_user_projects_status(USER_USER_ID, &second_access_token, StatusCode::OK) + .await; + }) + .await; +} + +#[actix_rt::test] +async fn oauth_authorize_with_broader_scopes_requires_user_accept() { + with_test_environment(|env| async { + let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone(); + let resp = env + .v3 + .oauth_authorize(&client_id, Some("USER_READ"), None, None, USER_USER_PAT) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; + env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + + let resp = env + .v3 + .oauth_authorize( + &client_id, + Some("USER_READ NOTIFICATION_READ"), + None, + None, + USER_USER_PAT, + ) + .await; + + assert_status(&resp, StatusCode::OK); + get_authorize_accept_flow_id(resp).await; // ensure we can deser this without error to really confirm + }) + .await; +} + +#[actix_rt::test] +async fn reject_authorize_ends_authorize_flow() { + with_test_environment(|env| async move { + let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone(); + let resp = env + .v3 + .oauth_authorize(&client_id, None, None, None, USER_USER_PAT) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; + + let resp = env.v3.oauth_reject(&flow_id, USER_USER_PAT).await; + assert_status(&resp, StatusCode::FOUND); + + let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + assert_any_status_except(&resp, StatusCode::FOUND); + }) + .await; +} + +#[actix_rt::test] +async fn accept_authorize_after_already_accepting_fails() { + with_test_environment(|env| async move { + let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone(); + let resp = env + .v3 + .oauth_authorize(&client_id, None, None, None, USER_USER_PAT) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; + let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + assert_status(&resp, StatusCode::FOUND); + + let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + assert_status(&resp, StatusCode::BAD_REQUEST); + }) + .await; +} + +#[actix_rt::test] +async fn revoke_authorization_after_issuing_token_revokes_token() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { + client_id, + client_secret, + .. + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + let access_token = env + .v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + Some("NOTIFICATION_READ"), + None, + None, + USER_USER_PAT, + ) + .await; + env.assert_read_notifications_status(USER_USER_ID, &access_token, StatusCode::OK) + .await; + + let resp = env + .v3 + .revoke_oauth_authorization(&client_id, USER_USER_PAT) + .await; + assert_status(&resp, StatusCode::OK); + + env.assert_read_notifications_status(USER_USER_ID, &access_token, StatusCode::UNAUTHORIZED) + .await; + }) + .await; +} diff --git a/tests/oauth_clients.rs b/tests/oauth_clients.rs new file mode 100644 index 000000000..257bbc90c --- /dev/null +++ b/tests/oauth_clients.rs @@ -0,0 +1,193 @@ +use actix_http::StatusCode; +use actix_web::test; +use common::{ + database::{FRIEND_USER_ID, FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT}, + dummy_data::DummyOAuthClientAlpha, + environment::with_test_environment, + get_json_val_str, +}; +use labrinth::{ + models::{ + oauth_clients::{OAuthClient, OAuthClientCreationResult}, + pats::Scopes, + }, + routes::v3::oauth_clients::OAuthClientEdit, +}; + +use crate::common::{asserts::assert_status, database::USER_USER_ID_PARSED}; + +mod common; + +#[actix_rt::test] +async fn can_create_edit_get_oauth_client() { + with_test_environment(|env| async move { + let client_name = "test_client".to_string(); + let redirect_uris = vec![ + "https://modrinth.com".to_string(), + "https://modrinth.com/a".to_string(), + ]; + let resp = env + .v3 + .add_oauth_client( + client_name.clone(), + Scopes::all() - Scopes::restricted(), + redirect_uris.clone(), + FRIEND_USER_PAT, + ) + .await; + assert_status(&resp, StatusCode::OK); + let creation_result: OAuthClientCreationResult = test::read_body_json(resp).await; + let client_id = get_json_val_str(creation_result.client.id); + + let icon_url = Some("https://modrinth.com/icon".to_string()); + let edited_redirect_uris = vec![ + redirect_uris[0].clone(), + "https://modrinth.com/b".to_string(), + ]; + let edit = OAuthClientEdit { + name: None, + icon_url: Some(icon_url.clone()), + max_scopes: None, + redirect_uris: Some(edited_redirect_uris.clone()), + }; + let resp = env + .v3 + .edit_oauth_client(&client_id, edit, FRIEND_USER_PAT) + .await; + assert_status(&resp, StatusCode::OK); + + let clients = env + .v3 + .get_user_oauth_clients(FRIEND_USER_ID, FRIEND_USER_PAT) + .await; + assert_eq!(1, clients.len()); + assert_eq!(icon_url, clients[0].icon_url); + assert_eq!(client_name, clients[0].name); + assert_eq!(2, clients[0].redirect_uris.len()); + assert_eq!(edited_redirect_uris[0], clients[0].redirect_uris[0].uri); + assert_eq!(edited_redirect_uris[1], clients[0].redirect_uris[1].uri); + }) + .await; +} + +#[actix_rt::test] +async fn create_oauth_client_with_restricted_scopes_fails() { + with_test_environment(|env| async move { + let resp = env + .v3 + .add_oauth_client( + "test_client".to_string(), + Scopes::restricted(), + vec!["https://modrinth.com".to_string()], + FRIEND_USER_PAT, + ) + .await; + + assert_status(&resp, StatusCode::BAD_REQUEST); + }) + .await; +} + +#[actix_rt::test] +async fn get_oauth_client_for_client_creator_succeeds() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { client_id, .. } = + env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + + let resp = env + .v3 + .get_oauth_client(client_id.clone(), USER_USER_PAT) + .await; + + assert_status(&resp, StatusCode::OK); + let client: OAuthClient = test::read_body_json(resp).await; + assert_eq!(get_json_val_str(client.id), client_id); + }) + .await; +} + +#[actix_rt::test] +async fn get_oauth_client_for_unrelated_user_fails() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { client_id, .. } = + env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + + let resp = env + .v3 + .get_oauth_client(client_id.clone(), FRIEND_USER_PAT) + .await; + + assert_status(&resp, StatusCode::UNAUTHORIZED); + }) + .await; +} + +#[actix_rt::test] +async fn can_delete_oauth_client() { + with_test_environment(|env| async move { + let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone(); + let resp = env.v3.delete_oauth_client(&client_id, USER_USER_PAT).await; + assert_status(&resp, StatusCode::NO_CONTENT); + + let clients = env + .v3 + .get_user_oauth_clients(USER_USER_ID, USER_USER_PAT) + .await; + assert_eq!(0, clients.len()); + }) + .await; +} + +#[actix_rt::test] +async fn delete_oauth_client_after_issuing_access_tokens_revokes_tokens() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { + client_id, + client_secret, + .. + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + let access_token = env + .v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + Some("NOTIFICATION_READ"), + None, + None, + USER_USER_PAT, + ) + .await; + + env.v3.delete_oauth_client(&client_id, USER_USER_PAT).await; + + env.assert_read_notifications_status(USER_USER_ID, &access_token, StatusCode::UNAUTHORIZED) + .await; + }) + .await; +} + +#[actix_rt::test] +async fn can_list_user_oauth_authorizations() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { + client_id, + client_secret, + .. + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + env.v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + None, + None, + None, + USER_USER_PAT, + ) + .await; + + let authorizations = env.v3.get_user_oauth_authorizations(USER_USER_PAT).await; + assert_eq!(1, authorizations.len()); + assert_eq!(USER_USER_ID_PARSED, authorizations[0].user_id.0 as i64); + }) + .await; +}