OAuth 2.0 Authorization Server [MOD-559] (#733)
* WIP end-of-day push * Authorize endpoint, accept endpoints, DB stuff for oauth clients, their redirects, and client authorizations * OAuth Client create route * Get user clients * Client delete * Edit oauth client * Include redirects in edit client route * Database stuff for tokens * Reorg oauth stuff out of auth/flows and into its own module * Impl OAuth get access token endpoint * Accept oauth access tokens as auth and update through AuthQueue * User OAuth authorization management routes * Forgot to actually add the routes lol * Bit o cleanup * Happy path test for OAuth and minor fixes for things it found * Add dummy data oauth client (and detect/handle dummy data version changes) * More tests * Another test * More tests and reject endpoint * Test oauth client and authorization management routes * cargo sqlx prepare * dead code warning * Auto clippy fixes * Uri refactoring * minor name improvement * Don't compile-time check the test sqlx queries * Trying to fix db concurrency problem to get tests to pass * Try fix from test PR * Fixes for updated sqlx * Prevent restricted scopes from being requested or issued * Get OAuth client(s) * Remove joined oauth client info from authorization returns * Add default conversion to OAuthError::error so we can use ? * Rework routes * Consolidate scopes into SESSION_ACCESS * Cargo sqlx prepare * Parse to OAuthClientId automatically through serde and actix * Cargo clippy * Remove validation requiring 1 redirect URI on oauth client creation * Use serde(flatten) on OAuthClientCreationResult
This commit is contained in:
parent
8803e11945
commit
6cfd4637db
@ -1,15 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE pats\n SET last_used = $2\n WHERE (id = $1)\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba"
|
||||
}
|
||||
15
.sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json
generated
Normal file
15
.sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json
generated
Normal file
@ -0,0 +1,15 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE pats\n SET last_used = $2\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8Array",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb"
|
||||
}
|
||||
22
.sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json
generated
Normal file
22
.sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json
generated
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT EXISTS(SELECT 1 FROM oauth_client_authorizations WHERE id=$1)",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "exists",
|
||||
"type_info": "Bool"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9"
|
||||
}
|
||||
22
.sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json
generated
Normal file
22
.sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json
generated
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT EXISTS(SELECT 1 FROM oauth_client_redirect_uris WHERE id=$1)",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "exists",
|
||||
"type_info": "Bool"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9"
|
||||
}
|
||||
15
.sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json
generated
Normal file
15
.sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json
generated
Normal file
@ -0,0 +1,15 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE oauth_access_tokens\n SET last_used = $2\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8Array",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c"
|
||||
}
|
||||
47
.sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json
generated
Normal file
47
.sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json
generated
Normal file
@ -0,0 +1,47 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT id, client_id, user_id, scopes, created\n FROM oauth_client_authorizations\n WHERE client_id=$1 AND user_id=$2\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "client_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "user_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "scopes",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "created",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104"
|
||||
}
|
||||
17
.sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json
generated
Normal file
17
.sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json
generated
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO oauth_client_authorizations (\n id, client_id, user_id, scopes\n )\n VALUES (\n $1, $2, $3, $4\n )\n ON CONFLICT (id)\n DO UPDATE SET scopes = EXCLUDED.scopes\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Int8",
|
||||
"Int8",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3"
|
||||
}
|
||||
32
.sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json
generated
Normal file
32
.sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json
generated
Normal file
@ -0,0 +1,32 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO oauth_access_tokens (\n id, authorization_id, token_hash, scopes, last_used\n )\n VALUES (\n $1, $2, $3, $4, $5\n )\n RETURNING created, expires\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "created",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "expires",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Int8",
|
||||
"Text",
|
||||
"Int8",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911"
|
||||
}
|
||||
70
.sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json
generated
Normal file
70
.sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json
generated
Normal file
@ -0,0 +1,70 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n tokens.id,\n tokens.authorization_id,\n tokens.token_hash,\n tokens.scopes,\n tokens.created,\n tokens.expires,\n tokens.last_used,\n auths.client_id,\n auths.user_id\n FROM oauth_access_tokens tokens\n JOIN oauth_client_authorizations auths\n ON tokens.authorization_id = auths.id\n WHERE tokens.token_hash = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "authorization_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "token_hash",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "scopes",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "created",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "expires",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "last_used",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "client_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "user_id",
|
||||
"type_info": "Int8"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b"
|
||||
}
|
||||
70
.sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json
generated
Normal file
70
.sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json
generated
Normal file
@ -0,0 +1,70 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n clients.id as \"id!\",\n clients.name as \"name!\",\n clients.icon_url as \"icon_url?\",\n clients.max_scopes as \"max_scopes!\",\n clients.secret_hash as \"secret_hash!\",\n clients.created as \"created!\",\n clients.created_by as \"created_by!\",\n uris.uri_ids as \"uri_ids?\",\n uris.uri_vals as \"uri_vals?\"\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE created_by = $1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id!",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "name!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "icon_url?",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "max_scopes!",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "secret_hash!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "created!",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "created_by!",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "uri_ids?",
|
||||
"type_info": "Int8Array"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "uri_vals?",
|
||||
"type_info": "TextArray"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
null,
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477"
|
||||
}
|
||||
19
.sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json
generated
Normal file
19
.sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json
generated
Normal file
@ -0,0 +1,19 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO oauth_clients (\n id, name, icon_url, max_scopes, secret_hash, created_by\n )\n VALUES (\n $1, $2, $3, $4, $5, $6\n )\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Text",
|
||||
"Text",
|
||||
"Int8",
|
||||
"Text",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f"
|
||||
}
|
||||
16
.sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json
generated
Normal file
16
.sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json
generated
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO oauth_client_redirect_uris (id, client_id, uri)\n SELECT * FROM UNNEST($1::bigint[], $2::bigint[], $3::varchar[])\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8Array",
|
||||
"Int8Array",
|
||||
"VarcharArray"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0"
|
||||
}
|
||||
46
.sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json
generated
Normal file
46
.sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json
generated
Normal file
@ -0,0 +1,46 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT id, client_id, user_id, scopes, created\n FROM oauth_client_authorizations\n WHERE user_id=$1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "client_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "user_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "scopes",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "created",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322"
|
||||
}
|
||||
14
.sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json
generated
Normal file
14
.sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json
generated
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n DELETE FROM oauth_client_redirect_uris\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8Array"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96"
|
||||
}
|
||||
22
.sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json
generated
Normal file
22
.sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json
generated
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT EXISTS(SELECT 1 FROM oauth_clients WHERE id=$1)",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "exists",
|
||||
"type_info": "Bool"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a"
|
||||
}
|
||||
15
.sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json
generated
Normal file
15
.sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json
generated
Normal file
@ -0,0 +1,15 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n DELETE FROM oauth_client_authorizations\n WHERE client_id=$1 AND user_id=$2\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428"
|
||||
}
|
||||
22
.sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json
generated
Normal file
22
.sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json
generated
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "exists",
|
||||
"type_info": "Bool"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12"
|
||||
}
|
||||
14
.sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json
generated
Normal file
14
.sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json
generated
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n DELETE FROM oauth_clients\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf"
|
||||
}
|
||||
17
.sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json
generated
Normal file
17
.sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json
generated
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE oauth_clients\n SET name = $1, icon_url = $2, max_scopes = $3\n WHERE (id = $4)\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text",
|
||||
"Text",
|
||||
"Int8",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53"
|
||||
}
|
||||
70
.sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json
generated
Normal file
70
.sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json
generated
Normal file
@ -0,0 +1,70 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n clients.id as \"id!\",\n clients.name as \"name!\",\n clients.icon_url as \"icon_url?\",\n clients.max_scopes as \"max_scopes!\",\n clients.secret_hash as \"secret_hash!\",\n clients.created as \"created!\",\n clients.created_by as \"created_by!\",\n uris.uri_ids as \"uri_ids?\",\n uris.uri_vals as \"uri_vals?\"\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE clients.id = ANY($1::bigint[])",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id!",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "name!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "icon_url?",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "max_scopes!",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "secret_hash!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "created!",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "created_by!",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "uri_ids?",
|
||||
"type_info": "Int8Array"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "uri_vals?",
|
||||
"type_info": "TextArray"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8Array"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
null,
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7"
|
||||
}
|
||||
34
migrations/20231016190056_oauth_provider.sql
Normal file
34
migrations/20231016190056_oauth_provider.sql
Normal file
@ -0,0 +1,34 @@
|
||||
CREATE TABLE oauth_clients (
|
||||
id bigint PRIMARY KEY,
|
||||
name text NOT NULL,
|
||||
icon_url text NULL,
|
||||
max_scopes bigint NOT NULL,
|
||||
secret_hash text NOT NULL,
|
||||
created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
created_by bigint NOT NULL REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE oauth_client_redirect_uris (
|
||||
id bigint PRIMARY KEY,
|
||||
client_id bigint NOT NULL REFERENCES oauth_clients (id) ON DELETE CASCADE,
|
||||
uri text
|
||||
);
|
||||
CREATE TABLE oauth_client_authorizations (
|
||||
id bigint PRIMARY KEY,
|
||||
client_id bigint NOT NULL REFERENCES oauth_clients (id) ON DELETE CASCADE,
|
||||
user_id bigint NOT NULL REFERENCES users (id) ON DELETE CASCADE,
|
||||
scopes bigint NOT NULL,
|
||||
created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE (client_id, user_id)
|
||||
);
|
||||
CREATE TABLE oauth_access_tokens (
|
||||
id bigint PRIMARY KEY,
|
||||
authorization_id bigint NOT NULL REFERENCES oauth_client_authorizations(id) ON DELETE CASCADE,
|
||||
token_hash text NOT NULL UNIQUE,
|
||||
scopes bigint NOT NULL,
|
||||
created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
expires timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP + interval '14 days',
|
||||
last_used timestamptz NULL
|
||||
);
|
||||
CREATE INDEX oauth_client_creator ON oauth_clients(created_by);
|
||||
CREATE INDEX oauth_redirect_client ON oauth_client_redirect_uris(client_id);
|
||||
CREATE INDEX oauth_access_token_hash ON oauth_access_tokens(token_hash);
|
||||
@ -8,6 +8,25 @@ use crate::routes::ApiError;
|
||||
use actix_web::web;
|
||||
use sqlx::PgPool;
|
||||
|
||||
pub trait ValidateAuthorized {
|
||||
fn validate_authorized(&self, user_option: Option<&User>) -> Result<(), ApiError>;
|
||||
}
|
||||
|
||||
pub trait ValidateAllAuthorized {
|
||||
fn validate_all_authorized(self, user_option: Option<&User>) -> Result<(), ApiError>;
|
||||
}
|
||||
|
||||
impl<'a, T, A> ValidateAllAuthorized for T
|
||||
where
|
||||
T: IntoIterator<Item = &'a A>,
|
||||
A: ValidateAuthorized + 'a,
|
||||
{
|
||||
fn validate_all_authorized(self, user_option: Option<&User>) -> Result<(), ApiError> {
|
||||
self.into_iter()
|
||||
.try_for_each(|c| c.validate_authorized(user_option))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_authorized(
|
||||
project_data: &Project,
|
||||
user_option: &Option<User>,
|
||||
@ -156,6 +175,23 @@ pub async fn is_authorized_version(
|
||||
Ok(authorized)
|
||||
}
|
||||
|
||||
impl ValidateAuthorized for crate::database::models::OAuthClient {
|
||||
fn validate_authorized(&self, user_option: Option<&User>) -> Result<(), ApiError> {
|
||||
if let Some(user) = user_option {
|
||||
if user.role.is_mod() || user.id == self.created_by.into() {
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(crate::routes::ApiError::CustomAuthentication(
|
||||
"You don't have sufficient permissions to interact with this OAuth application"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn filter_authorized_versions(
|
||||
versions: Vec<QueryVersion>,
|
||||
user_option: &Option<User>,
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
pub mod checks;
|
||||
pub mod email;
|
||||
pub mod flows;
|
||||
pub mod oauth;
|
||||
pub mod pats;
|
||||
pub mod session;
|
||||
mod templates;
|
||||
pub mod validate;
|
||||
|
||||
pub use checks::{
|
||||
filter_authorized_projects, filter_authorized_versions, is_authorized, is_authorized_version,
|
||||
};
|
||||
|
||||
176
src/auth/oauth/errors.rs
Normal file
176
src/auth/oauth/errors.rs
Normal file
@ -0,0 +1,176 @@
|
||||
use super::ValidatedRedirectUri;
|
||||
use crate::auth::AuthenticationError;
|
||||
use crate::models::error::ApiError;
|
||||
use crate::models::ids::DecodingError;
|
||||
use actix_web::http::StatusCode;
|
||||
use actix_web::HttpResponse;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[error("{}", .error_type)]
|
||||
pub struct OAuthError {
|
||||
#[source]
|
||||
pub error_type: OAuthErrorType,
|
||||
|
||||
pub state: Option<String>,
|
||||
pub valid_redirect_uri: Option<ValidatedRedirectUri>,
|
||||
}
|
||||
|
||||
impl<T> From<T> for OAuthError
|
||||
where
|
||||
T: Into<OAuthErrorType>,
|
||||
{
|
||||
fn from(value: T) -> Self {
|
||||
OAuthError::error(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthError {
|
||||
/// The OAuth request failed either because of an invalid redirection URI
|
||||
/// or before we could validate the one we were given, so return an error
|
||||
/// directly to the caller
|
||||
///
|
||||
/// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1)
|
||||
pub fn error(error_type: impl Into<OAuthErrorType>) -> Self {
|
||||
Self {
|
||||
error_type: error_type.into(),
|
||||
valid_redirect_uri: None,
|
||||
state: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// The OAuth request failed for a reason other than an invalid redirection URI
|
||||
/// So send the error in url-encoded form to the redirect URI
|
||||
///
|
||||
/// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1)
|
||||
pub fn redirect(
|
||||
err: impl Into<OAuthErrorType>,
|
||||
state: &Option<String>,
|
||||
valid_redirect_uri: &ValidatedRedirectUri,
|
||||
) -> Self {
|
||||
Self {
|
||||
error_type: err.into(),
|
||||
state: state.clone(),
|
||||
valid_redirect_uri: Some(valid_redirect_uri.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl actix_web::ResponseError for OAuthError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self.error_type {
|
||||
OAuthErrorType::AuthenticationError(_)
|
||||
| OAuthErrorType::FailedScopeParse(_)
|
||||
| OAuthErrorType::ScopesTooBroad
|
||||
| OAuthErrorType::AccessDenied => {
|
||||
if self.valid_redirect_uri.is_some() {
|
||||
StatusCode::FOUND
|
||||
} else {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
||||
OAuthErrorType::RedirectUriNotConfigured(_)
|
||||
| OAuthErrorType::ClientMissingRedirectURI { client_id: _ }
|
||||
| OAuthErrorType::InvalidAcceptFlowId
|
||||
| OAuthErrorType::MalformedId(_)
|
||||
| OAuthErrorType::InvalidClientId(_)
|
||||
| OAuthErrorType::InvalidAuthCode
|
||||
| OAuthErrorType::OnlySupportsAuthorizationCodeGrant(_)
|
||||
| OAuthErrorType::RedirectUriChanged(_)
|
||||
| OAuthErrorType::UnauthorizedClient => StatusCode::BAD_REQUEST,
|
||||
OAuthErrorType::ClientAuthenticationFailed => StatusCode::UNAUTHORIZED,
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
if let Some(ValidatedRedirectUri(mut redirect_uri)) = self.valid_redirect_uri.clone() {
|
||||
redirect_uri = format!(
|
||||
"{}?error={}&error_description={}",
|
||||
redirect_uri,
|
||||
self.error_type.error_name(),
|
||||
self.error_type,
|
||||
);
|
||||
|
||||
if let Some(state) = self.state.as_ref() {
|
||||
redirect_uri = format!("{}&state={}", redirect_uri, state);
|
||||
}
|
||||
|
||||
redirect_uri = urlencoding::encode(&redirect_uri).to_string();
|
||||
HttpResponse::Found()
|
||||
.append_header(("Location".to_string(), redirect_uri))
|
||||
.finish()
|
||||
} else {
|
||||
HttpResponse::build(self.status_code()).json(ApiError {
|
||||
error: &self.error_type.error_name(),
|
||||
description: &self.error_type.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum OAuthErrorType {
|
||||
#[error(transparent)]
|
||||
AuthenticationError(#[from] AuthenticationError),
|
||||
#[error("Client {} has no redirect URIs specified", .client_id.0)]
|
||||
ClientMissingRedirectURI {
|
||||
client_id: crate::database::models::OAuthClientId,
|
||||
},
|
||||
#[error("The provided redirect URI did not match any configured in the client")]
|
||||
RedirectUriNotConfigured(String),
|
||||
#[error("The provided scope was malformed or did not correspond to known scopes ({0})")]
|
||||
FailedScopeParse(bitflags::parser::ParseError),
|
||||
#[error(
|
||||
"The provided scope requested scopes broader than the developer app is configured with"
|
||||
)]
|
||||
ScopesTooBroad,
|
||||
#[error("The provided flow id was invalid")]
|
||||
InvalidAcceptFlowId,
|
||||
#[error("The provided client id was invalid")]
|
||||
InvalidClientId(crate::database::models::OAuthClientId),
|
||||
#[error("The provided ID could not be decoded: {0}")]
|
||||
MalformedId(#[from] DecodingError),
|
||||
#[error("Failed to authenticate client")]
|
||||
ClientAuthenticationFailed,
|
||||
#[error("The provided authorization grant code was invalid")]
|
||||
InvalidAuthCode,
|
||||
#[error("The provided client id did not match the id this authorization code was granted to")]
|
||||
UnauthorizedClient,
|
||||
#[error("The provided redirect URI did not exactly match the uri originally provided when this flow began")]
|
||||
RedirectUriChanged(Option<String>),
|
||||
#[error("The provided grant type ({0}) must be \"authorization_code\"")]
|
||||
OnlySupportsAuthorizationCodeGrant(String),
|
||||
#[error("The resource owner denied the request")]
|
||||
AccessDenied,
|
||||
}
|
||||
|
||||
impl From<crate::database::models::DatabaseError> for OAuthErrorType {
|
||||
fn from(value: crate::database::models::DatabaseError) -> Self {
|
||||
OAuthErrorType::AuthenticationError(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<sqlx::Error> for OAuthErrorType {
|
||||
fn from(value: sqlx::Error) -> Self {
|
||||
OAuthErrorType::AuthenticationError(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthErrorType {
|
||||
pub fn error_name(&self) -> String {
|
||||
// IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#autoid-38)
|
||||
// And 5.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
|
||||
match self {
|
||||
Self::RedirectUriNotConfigured(_) | Self::ClientMissingRedirectURI { client_id: _ } => {
|
||||
"invalid_uri"
|
||||
}
|
||||
Self::AuthenticationError(_) | Self::InvalidAcceptFlowId => "server_error",
|
||||
Self::RedirectUriChanged(_) | Self::MalformedId(_) => "invalid_request",
|
||||
Self::FailedScopeParse(_) | Self::ScopesTooBroad => "invalid_scope",
|
||||
Self::InvalidClientId(_) | Self::ClientAuthenticationFailed => "invalid_client",
|
||||
Self::InvalidAuthCode | Self::OnlySupportsAuthorizationCodeGrant(_) => "invalid_grant",
|
||||
Self::UnauthorizedClient => "unauthorized_client",
|
||||
Self::AccessDenied => "access_denied",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
430
src/auth/oauth/mod.rs
Normal file
430
src/auth/oauth/mod.rs
Normal file
@ -0,0 +1,430 @@
|
||||
use crate::auth::get_user_from_headers;
|
||||
use crate::auth::oauth::uris::{OAuthRedirectUris, ValidatedRedirectUri};
|
||||
use crate::auth::validate::extract_authorization_header;
|
||||
use crate::database::models::flow_item::Flow;
|
||||
use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization;
|
||||
use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient;
|
||||
use crate::database::models::oauth_token_item::OAuthAccessToken;
|
||||
use crate::database::models::{
|
||||
generate_oauth_access_token_id, generate_oauth_client_authorization_id,
|
||||
OAuthClientAuthorizationId, OAuthClientId,
|
||||
};
|
||||
use crate::database::redis::RedisPool;
|
||||
use crate::models;
|
||||
use crate::models::pats::Scopes;
|
||||
use crate::queue::session::AuthQueue;
|
||||
use actix_web::web::{scope, Data, Query, ServiceConfig};
|
||||
use actix_web::{get, post, web, HttpRequest, HttpResponse};
|
||||
use chrono::Duration;
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha20Rng;
|
||||
use reqwest::header::{CACHE_CONTROL, LOCATION, PRAGMA};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::postgres::PgPool;
|
||||
|
||||
use self::errors::{OAuthError, OAuthErrorType};
|
||||
|
||||
use super::AuthenticationError;
|
||||
|
||||
pub mod errors;
|
||||
pub mod uris;
|
||||
|
||||
pub fn config(cfg: &mut ServiceConfig) {
|
||||
cfg.service(
|
||||
scope("auth/oauth")
|
||||
.service(init_oauth)
|
||||
.service(accept_client_scopes)
|
||||
.service(reject_client_scopes)
|
||||
.service(request_token),
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct OAuthInit {
|
||||
pub client_id: OAuthClientId,
|
||||
pub redirect_uri: Option<String>,
|
||||
pub scope: Option<String>,
|
||||
pub state: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct OAuthClientAccessRequest {
|
||||
pub flow_id: String,
|
||||
pub client_id: OAuthClientId,
|
||||
pub client_name: String,
|
||||
pub client_icon: Option<String>,
|
||||
pub requested_scopes: Scopes,
|
||||
}
|
||||
|
||||
#[get("authorize")]
|
||||
pub async fn init_oauth(
|
||||
req: HttpRequest,
|
||||
Query(oauth_info): Query<OAuthInit>,
|
||||
pool: Data<PgPool>,
|
||||
redis: Data<RedisPool>,
|
||||
session_queue: Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, OAuthError> {
|
||||
let user = get_user_from_headers(
|
||||
&req,
|
||||
&**pool,
|
||||
&redis,
|
||||
&session_queue,
|
||||
Some(&[Scopes::USER_AUTH_WRITE]),
|
||||
)
|
||||
.await?
|
||||
.1;
|
||||
|
||||
let client_id = oauth_info.client_id;
|
||||
let client = DBOAuthClient::get(client_id, &**pool).await?;
|
||||
|
||||
if let Some(client) = client {
|
||||
let redirect_uri = ValidatedRedirectUri::validate(
|
||||
&oauth_info.redirect_uri,
|
||||
client.redirect_uris.iter().map(|r| r.uri.as_ref()),
|
||||
client.id,
|
||||
)?;
|
||||
|
||||
let requested_scopes = oauth_info
|
||||
.scope
|
||||
.as_ref()
|
||||
.map_or(Ok(client.max_scopes), |s| {
|
||||
Scopes::parse_from_oauth_scopes(s).map_err(|e| {
|
||||
OAuthError::redirect(
|
||||
OAuthErrorType::FailedScopeParse(e),
|
||||
&oauth_info.state,
|
||||
&redirect_uri,
|
||||
)
|
||||
})
|
||||
})?;
|
||||
|
||||
if !client.max_scopes.contains(requested_scopes) {
|
||||
return Err(OAuthError::redirect(
|
||||
OAuthErrorType::ScopesTooBroad,
|
||||
&oauth_info.state,
|
||||
&redirect_uri,
|
||||
));
|
||||
}
|
||||
|
||||
let existing_authorization =
|
||||
OAuthClientAuthorization::get(client.id, user.id.into(), &**pool)
|
||||
.await
|
||||
.map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?;
|
||||
let redirect_uris =
|
||||
OAuthRedirectUris::new(oauth_info.redirect_uri.clone(), redirect_uri.clone());
|
||||
match existing_authorization {
|
||||
Some(existing_authorization)
|
||||
if existing_authorization.scopes.contains(requested_scopes) =>
|
||||
{
|
||||
init_oauth_code_flow(
|
||||
user.id.into(),
|
||||
client.id,
|
||||
existing_authorization.id,
|
||||
requested_scopes,
|
||||
redirect_uris,
|
||||
oauth_info.state,
|
||||
&redis,
|
||||
)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
let flow_id = Flow::InitOAuthAppApproval {
|
||||
user_id: user.id.into(),
|
||||
client_id: client.id,
|
||||
existing_authorization_id: existing_authorization.map(|a| a.id),
|
||||
scopes: requested_scopes,
|
||||
redirect_uris,
|
||||
state: oauth_info.state.clone(),
|
||||
}
|
||||
.insert(Duration::minutes(30), &redis)
|
||||
.await
|
||||
.map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?;
|
||||
|
||||
let access_request = OAuthClientAccessRequest {
|
||||
client_id: client.id,
|
||||
client_name: client.name,
|
||||
client_icon: client.icon_url,
|
||||
flow_id,
|
||||
requested_scopes,
|
||||
};
|
||||
Ok(HttpResponse::Ok().json(access_request))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(OAuthError::error(OAuthErrorType::InvalidClientId(
|
||||
client_id,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct RespondToOAuthClientScopes {
|
||||
pub flow: String,
|
||||
}
|
||||
|
||||
#[post("accept")]
|
||||
pub async fn accept_client_scopes(
|
||||
req: HttpRequest,
|
||||
accept_body: web::Json<RespondToOAuthClientScopes>,
|
||||
pool: Data<PgPool>,
|
||||
redis: Data<RedisPool>,
|
||||
session_queue: Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, OAuthError> {
|
||||
accept_or_reject_client_scopes(true, req, accept_body, pool, redis, session_queue).await
|
||||
}
|
||||
|
||||
#[post("reject")]
|
||||
pub async fn reject_client_scopes(
|
||||
req: HttpRequest,
|
||||
body: web::Json<RespondToOAuthClientScopes>,
|
||||
pool: Data<PgPool>,
|
||||
redis: Data<RedisPool>,
|
||||
session_queue: Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, OAuthError> {
|
||||
accept_or_reject_client_scopes(false, req, body, pool, redis, session_queue).await
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TokenRequest {
|
||||
pub grant_type: String,
|
||||
pub code: String,
|
||||
pub redirect_uri: Option<String>,
|
||||
pub client_id: models::ids::OAuthClientId,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TokenResponse {
|
||||
pub access_token: String,
|
||||
pub token_type: String,
|
||||
pub expires_in: i64,
|
||||
}
|
||||
|
||||
#[post("token")]
|
||||
/// Params should be in the urlencoded request body
|
||||
/// And client secret should be in the HTTP basic authorization header
|
||||
/// Per IETF RFC6749 Section 4.1.3 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3)
|
||||
pub async fn request_token(
|
||||
req: HttpRequest,
|
||||
req_params: web::Form<TokenRequest>,
|
||||
pool: Data<PgPool>,
|
||||
redis: Data<RedisPool>,
|
||||
) -> Result<HttpResponse, OAuthError> {
|
||||
let req_client_id = req_params.client_id;
|
||||
let client = DBOAuthClient::get(req_client_id.into(), &**pool).await?;
|
||||
if let Some(client) = client {
|
||||
authenticate_client_token_request(&req, &client)?;
|
||||
|
||||
// Ensure auth code is single use
|
||||
// per IETF RFC6749 Section 10.5 (https://datatracker.ietf.org/doc/html/rfc6749#section-10.5)
|
||||
let flow = Flow::take_if(
|
||||
&req_params.code,
|
||||
|f| matches!(f, Flow::OAuthAuthorizationCodeSupplied { .. }),
|
||||
&redis,
|
||||
)
|
||||
.await?;
|
||||
if let Some(Flow::OAuthAuthorizationCodeSupplied {
|
||||
user_id,
|
||||
client_id,
|
||||
authorization_id,
|
||||
scopes,
|
||||
original_redirect_uri,
|
||||
}) = flow
|
||||
{
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
|
||||
if req_client_id != client_id.into() {
|
||||
return Err(OAuthError::error(OAuthErrorType::UnauthorizedClient));
|
||||
}
|
||||
|
||||
if original_redirect_uri != req_params.redirect_uri {
|
||||
return Err(OAuthError::error(OAuthErrorType::RedirectUriChanged(
|
||||
req_params.redirect_uri.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
if req_params.grant_type != "authorization_code" {
|
||||
return Err(OAuthError::error(
|
||||
OAuthErrorType::OnlySupportsAuthorizationCodeGrant(
|
||||
req_params.grant_type.clone(),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let scopes = scopes - Scopes::restricted();
|
||||
|
||||
let mut transaction = pool.begin().await?;
|
||||
let token_id = generate_oauth_access_token_id(&mut transaction).await?;
|
||||
let token = generate_access_token();
|
||||
let token_hash = OAuthAccessToken::hash_token(&token);
|
||||
let time_until_expiration = OAuthAccessToken {
|
||||
id: token_id,
|
||||
authorization_id,
|
||||
token_hash,
|
||||
scopes,
|
||||
created: Default::default(),
|
||||
expires: Default::default(),
|
||||
last_used: None,
|
||||
client_id,
|
||||
user_id,
|
||||
}
|
||||
.insert(&mut *transaction)
|
||||
.await?;
|
||||
|
||||
transaction.commit().await?;
|
||||
|
||||
// IETF RFC6749 Section 5.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-5.1)
|
||||
Ok(HttpResponse::Ok()
|
||||
.append_header((CACHE_CONTROL, "no-store"))
|
||||
.append_header((PRAGMA, "no-cache"))
|
||||
.json(TokenResponse {
|
||||
access_token: token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: time_until_expiration.num_seconds(),
|
||||
}))
|
||||
} else {
|
||||
Err(OAuthError::error(OAuthErrorType::InvalidAuthCode))
|
||||
}
|
||||
} else {
|
||||
Err(OAuthError::error(OAuthErrorType::InvalidClientId(
|
||||
req_client_id.into(),
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn accept_or_reject_client_scopes(
|
||||
accept: bool,
|
||||
req: HttpRequest,
|
||||
body: web::Json<RespondToOAuthClientScopes>,
|
||||
pool: Data<PgPool>,
|
||||
redis: Data<RedisPool>,
|
||||
session_queue: Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, OAuthError> {
|
||||
let current_user = get_user_from_headers(
|
||||
&req,
|
||||
&**pool,
|
||||
&redis,
|
||||
&session_queue,
|
||||
Some(&[Scopes::SESSION_ACCESS]),
|
||||
)
|
||||
.await?
|
||||
.1;
|
||||
|
||||
let flow = Flow::take_if(
|
||||
&body.flow,
|
||||
|f| matches!(f, Flow::InitOAuthAppApproval { .. }),
|
||||
&redis,
|
||||
)
|
||||
.await?;
|
||||
if let Some(Flow::InitOAuthAppApproval {
|
||||
user_id,
|
||||
client_id,
|
||||
existing_authorization_id,
|
||||
scopes,
|
||||
redirect_uris,
|
||||
state,
|
||||
}) = flow
|
||||
{
|
||||
if current_user.id != user_id.into() {
|
||||
return Err(OAuthError::error(AuthenticationError::InvalidCredentials));
|
||||
}
|
||||
|
||||
if accept {
|
||||
let mut transaction = pool.begin().await?;
|
||||
|
||||
let auth_id = match existing_authorization_id {
|
||||
Some(id) => id,
|
||||
None => generate_oauth_client_authorization_id(&mut transaction).await?,
|
||||
};
|
||||
OAuthClientAuthorization::upsert(auth_id, client_id, user_id, scopes, &mut transaction)
|
||||
.await?;
|
||||
|
||||
transaction.commit().await?;
|
||||
|
||||
init_oauth_code_flow(
|
||||
user_id,
|
||||
client_id,
|
||||
auth_id,
|
||||
scopes,
|
||||
redirect_uris,
|
||||
state,
|
||||
&redis,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
Err(OAuthError::redirect(
|
||||
OAuthErrorType::AccessDenied,
|
||||
&state,
|
||||
&redirect_uris.validated,
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(OAuthError::error(OAuthErrorType::InvalidAcceptFlowId))
|
||||
}
|
||||
}
|
||||
|
||||
fn authenticate_client_token_request(
|
||||
req: &HttpRequest,
|
||||
client: &DBOAuthClient,
|
||||
) -> Result<(), OAuthError> {
|
||||
let client_secret = extract_authorization_header(req)?;
|
||||
let hashed_client_secret = DBOAuthClient::hash_secret(client_secret);
|
||||
if client.secret_hash != hashed_client_secret {
|
||||
Err(OAuthError::error(
|
||||
OAuthErrorType::ClientAuthenticationFailed,
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_access_token() -> String {
|
||||
let random = ChaCha20Rng::from_entropy()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(60)
|
||||
.map(char::from)
|
||||
.collect::<String>();
|
||||
format!("mro_{}", random)
|
||||
}
|
||||
|
||||
async fn init_oauth_code_flow(
|
||||
user_id: crate::database::models::UserId,
|
||||
client_id: OAuthClientId,
|
||||
authorization_id: OAuthClientAuthorizationId,
|
||||
scopes: Scopes,
|
||||
redirect_uris: OAuthRedirectUris,
|
||||
state: Option<String>,
|
||||
redis: &RedisPool,
|
||||
) -> Result<HttpResponse, OAuthError> {
|
||||
let code = Flow::OAuthAuthorizationCodeSupplied {
|
||||
user_id,
|
||||
client_id,
|
||||
authorization_id,
|
||||
scopes,
|
||||
original_redirect_uri: redirect_uris.original.clone(),
|
||||
}
|
||||
.insert(Duration::minutes(10), redis)
|
||||
.await
|
||||
.map_err(|e| OAuthError::redirect(e, &state, &redirect_uris.validated.clone()))?;
|
||||
|
||||
let mut redirect_params = vec![format!("code={code}")];
|
||||
if let Some(state) = state {
|
||||
redirect_params.push(format!("state={state}"));
|
||||
}
|
||||
|
||||
let redirect_uri = append_params_to_uri(&redirect_uris.validated.0, &redirect_params);
|
||||
|
||||
// IETF RFC 6749 Section 4.1.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2)
|
||||
Ok(HttpResponse::Found()
|
||||
.append_header((LOCATION, redirect_uri))
|
||||
.finish())
|
||||
}
|
||||
|
||||
fn append_params_to_uri(uri: &str, params: &[impl AsRef<str>]) -> String {
|
||||
let mut uri = uri.to_string();
|
||||
let mut connector = if uri.contains('?') { "&" } else { "?" };
|
||||
for param in params {
|
||||
uri.push_str(&format!("{}{}", connector, param.as_ref()));
|
||||
connector = "&";
|
||||
}
|
||||
|
||||
uri
|
||||
}
|
||||
94
src/auth/oauth/uris.rs
Normal file
94
src/auth/oauth/uris.rs
Normal file
@ -0,0 +1,94 @@
|
||||
use super::errors::OAuthError;
|
||||
use crate::auth::oauth::OAuthErrorType;
|
||||
use crate::database::models::OAuthClientId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(derive_new::new, Serialize, Deserialize)]
|
||||
pub struct OAuthRedirectUris {
|
||||
pub original: Option<String>,
|
||||
pub validated: ValidatedRedirectUri,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ValidatedRedirectUri(pub String);
|
||||
|
||||
impl ValidatedRedirectUri {
|
||||
pub fn validate<'a>(
|
||||
to_validate: &Option<String>,
|
||||
validate_against: impl IntoIterator<Item = &'a str> + Clone,
|
||||
client_id: OAuthClientId,
|
||||
) -> Result<Self, OAuthError> {
|
||||
if let Some(first_client_redirect_uri) = validate_against.clone().into_iter().next() {
|
||||
if let Some(to_validate) = to_validate {
|
||||
if validate_against
|
||||
.into_iter()
|
||||
.any(|uri| same_uri_except_query_components(uri, to_validate))
|
||||
{
|
||||
Ok(ValidatedRedirectUri(to_validate.clone()))
|
||||
} else {
|
||||
Err(OAuthError::error(OAuthErrorType::RedirectUriNotConfigured(
|
||||
to_validate.clone(),
|
||||
)))
|
||||
}
|
||||
} else {
|
||||
Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string()))
|
||||
}
|
||||
} else {
|
||||
Err(OAuthError::error(
|
||||
OAuthErrorType::ClientMissingRedirectURI { client_id },
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn same_uri_except_query_components(a: &str, b: &str) -> bool {
|
||||
let mut a_components = a.split('?');
|
||||
let mut b_components = b.split('?');
|
||||
a_components.next() == b_components.next()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn validate_for_none_returns_first_valid_uri() {
|
||||
let validate_against = vec!["https://modrinth.com/a"];
|
||||
|
||||
let validated =
|
||||
ValidatedRedirectUri::validate(&None, validate_against.clone(), OAuthClientId(0))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(validate_against[0], validated.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_for_valid_uri_returns_first_matching_uri_ignoring_query_params() {
|
||||
let validate_against = vec![
|
||||
"https://modrinth.com/a?q3=p3&q4=p4",
|
||||
"https://modrinth.com/a/b/c?q1=p1&q2=p2",
|
||||
];
|
||||
let to_validate = "https://modrinth.com/a/b/c?query0=param0&query1=param1".to_string();
|
||||
|
||||
let validated = ValidatedRedirectUri::validate(
|
||||
&Some(to_validate.clone()),
|
||||
validate_against,
|
||||
OAuthClientId(0),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(to_validate, validated.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_for_invalid_uri_returns_err() {
|
||||
let validate_against = vec!["https://modrinth.com/a"];
|
||||
let to_validate = "https://modrinth.com/a/b".to_string();
|
||||
|
||||
let validated =
|
||||
ValidatedRedirectUri::validate(&Some(to_validate), validate_against, OAuthClientId(0));
|
||||
|
||||
assert!(validated
|
||||
.is_err_and(|e| matches!(e.error_type, OAuthErrorType::RedirectUriNotConfigured(_))));
|
||||
}
|
||||
}
|
||||
@ -91,12 +91,7 @@ where
|
||||
let token = if let Some(token) = token {
|
||||
token
|
||||
} else {
|
||||
let headers = req.headers();
|
||||
let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION);
|
||||
token_val
|
||||
.ok_or_else(|| AuthenticationError::InvalidAuthMethod)?
|
||||
.to_str()
|
||||
.map_err(|_| AuthenticationError::InvalidCredentials)?
|
||||
extract_authorization_header(req)?
|
||||
};
|
||||
|
||||
let possible_user = match token.split_once('_') {
|
||||
@ -142,6 +137,25 @@ where
|
||||
|
||||
user.map(|x| (Scopes::all(), x))
|
||||
}
|
||||
Some(("mro", _)) => {
|
||||
use crate::database::models::oauth_token_item::OAuthAccessToken;
|
||||
|
||||
let hash = OAuthAccessToken::hash_token(token);
|
||||
let access_token =
|
||||
crate::database::models::oauth_token_item::OAuthAccessToken::get(hash, executor)
|
||||
.await?
|
||||
.ok_or(AuthenticationError::InvalidCredentials)?;
|
||||
|
||||
if access_token.expires < Utc::now() {
|
||||
return Err(AuthenticationError::InvalidCredentials);
|
||||
}
|
||||
|
||||
let user = user_item::User::get_id(access_token.user_id, executor, redis).await?;
|
||||
|
||||
session_queue.add_oauth_access_token(access_token.id).await;
|
||||
|
||||
user.map(|u| (access_token.scopes, u))
|
||||
}
|
||||
Some(("github", _)) | Some(("gho", _)) | Some(("ghp", _)) => {
|
||||
let user = AuthProvider::GitHub.get_user(token).await?;
|
||||
let id = AuthProvider::GitHub.get_user_id(&user.id, executor).await?;
|
||||
@ -160,6 +174,15 @@ where
|
||||
Ok(possible_user)
|
||||
}
|
||||
|
||||
pub fn extract_authorization_header(req: &HttpRequest) -> Result<&str, AuthenticationError> {
|
||||
let headers = req.headers();
|
||||
let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION);
|
||||
token_val
|
||||
.ok_or_else(|| AuthenticationError::InvalidAuthMethod)?
|
||||
.to_str()
|
||||
.map_err(|_| AuthenticationError::InvalidCredentials)
|
||||
}
|
||||
|
||||
pub async fn check_is_moderator_from_headers<'a, 'b, E>(
|
||||
req: &HttpRequest,
|
||||
executor: E,
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
use super::ids::*;
|
||||
use crate::auth::flows::AuthProvider;
|
||||
use crate::auth::oauth::uris::OAuthRedirectUris;
|
||||
use crate::database::models::DatabaseError;
|
||||
use crate::database::redis::RedisPool;
|
||||
use crate::{auth::flows::AuthProvider, models::pats::Scopes};
|
||||
use chrono::Duration;
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::Rng;
|
||||
@ -34,6 +35,21 @@ pub enum Flow {
|
||||
confirm_email: String,
|
||||
},
|
||||
MinecraftAuth,
|
||||
InitOAuthAppApproval {
|
||||
user_id: UserId,
|
||||
client_id: OAuthClientId,
|
||||
existing_authorization_id: Option<OAuthClientAuthorizationId>,
|
||||
scopes: Scopes,
|
||||
redirect_uris: OAuthRedirectUris,
|
||||
state: Option<String>,
|
||||
},
|
||||
OAuthAuthorizationCodeSupplied {
|
||||
user_id: UserId,
|
||||
client_id: OAuthClientId,
|
||||
authorization_id: OAuthClientAuthorizationId,
|
||||
scopes: Scopes,
|
||||
original_redirect_uri: Option<String>, // Needed for https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
|
||||
},
|
||||
}
|
||||
|
||||
impl Flow {
|
||||
@ -58,6 +74,22 @@ impl Flow {
|
||||
redis.get_deserialized_from_json(FLOWS_NAMESPACE, id).await
|
||||
}
|
||||
|
||||
/// Gets the flow and removes it from the cache, but only removes if the flow was present and the predicate returned true
|
||||
/// The predicate should validate that the flow being removed is the correct one, as a security measure
|
||||
pub async fn take_if(
|
||||
id: &str,
|
||||
predicate: impl FnOnce(&Flow) -> bool,
|
||||
redis: &RedisPool,
|
||||
) -> Result<Option<Flow>, DatabaseError> {
|
||||
let flow = Self::get(id, redis).await?;
|
||||
if let Some(flow) = flow.as_ref() {
|
||||
if predicate(flow) {
|
||||
Self::remove(id, redis).await?;
|
||||
}
|
||||
}
|
||||
Ok(flow)
|
||||
}
|
||||
|
||||
pub async fn remove(id: &str, redis: &RedisPool) -> Result<Option<()>, DatabaseError> {
|
||||
redis.delete(FLOWS_NAMESPACE, id).await?;
|
||||
Ok(Some(()))
|
||||
|
||||
@ -152,6 +152,38 @@ generate_ids!(
|
||||
ImageId
|
||||
);
|
||||
|
||||
generate_ids!(
|
||||
pub generate_oauth_client_authorization_id,
|
||||
OAuthClientAuthorizationId,
|
||||
8,
|
||||
"SELECT EXISTS(SELECT 1 FROM oauth_client_authorizations WHERE id=$1)",
|
||||
OAuthClientAuthorizationId
|
||||
);
|
||||
|
||||
generate_ids!(
|
||||
pub generate_oauth_client_id,
|
||||
OAuthClientId,
|
||||
8,
|
||||
"SELECT EXISTS(SELECT 1 FROM oauth_clients WHERE id=$1)",
|
||||
OAuthClientId
|
||||
);
|
||||
|
||||
generate_ids!(
|
||||
pub generate_oauth_redirect_id,
|
||||
OAuthRedirectUriId,
|
||||
8,
|
||||
"SELECT EXISTS(SELECT 1 FROM oauth_client_redirect_uris WHERE id=$1)",
|
||||
OAuthRedirectUriId
|
||||
);
|
||||
|
||||
generate_ids!(
|
||||
pub generate_oauth_access_token_id,
|
||||
OAuthAccessTokenId,
|
||||
8,
|
||||
"SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)",
|
||||
OAuthAccessTokenId
|
||||
);
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Type, Hash, Serialize, Deserialize)]
|
||||
#[sqlx(transparent)]
|
||||
pub struct UserId(pub i64);
|
||||
@ -238,6 +270,22 @@ pub struct SessionId(pub i64);
|
||||
#[sqlx(transparent)]
|
||||
pub struct ImageId(pub i64);
|
||||
|
||||
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
#[sqlx(transparent)]
|
||||
pub struct OAuthClientId(pub i64);
|
||||
|
||||
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
#[sqlx(transparent)]
|
||||
pub struct OAuthClientAuthorizationId(pub i64);
|
||||
|
||||
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
#[sqlx(transparent)]
|
||||
pub struct OAuthRedirectUriId(pub i64);
|
||||
|
||||
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
#[sqlx(transparent)]
|
||||
pub struct OAuthAccessTokenId(pub i64);
|
||||
|
||||
use crate::models::ids;
|
||||
|
||||
impl From<ids::ProjectId> for ProjectId {
|
||||
@ -360,3 +408,23 @@ impl From<PatId> for ids::PatId {
|
||||
ids::PatId(id.0 as u64)
|
||||
}
|
||||
}
|
||||
impl From<OAuthClientId> for ids::OAuthClientId {
|
||||
fn from(id: OAuthClientId) -> Self {
|
||||
ids::OAuthClientId(id.0 as u64)
|
||||
}
|
||||
}
|
||||
impl From<ids::OAuthClientId> for OAuthClientId {
|
||||
fn from(id: ids::OAuthClientId) -> Self {
|
||||
Self(id.0 as i64)
|
||||
}
|
||||
}
|
||||
impl From<OAuthRedirectUriId> for ids::OAuthRedirectUriId {
|
||||
fn from(id: OAuthRedirectUriId) -> Self {
|
||||
ids::OAuthRedirectUriId(id.0 as u64)
|
||||
}
|
||||
}
|
||||
impl From<OAuthClientAuthorizationId> for ids::OAuthClientAuthorizationId {
|
||||
fn from(id: OAuthClientAuthorizationId) -> Self {
|
||||
ids::OAuthClientAuthorizationId(id.0 as u64)
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,6 +6,9 @@ pub mod flow_item;
|
||||
pub mod ids;
|
||||
pub mod image_item;
|
||||
pub mod notification_item;
|
||||
pub mod oauth_client_authorization_item;
|
||||
pub mod oauth_client_item;
|
||||
pub mod oauth_token_item;
|
||||
pub mod organization_item;
|
||||
pub mod pat_item;
|
||||
pub mod project_item;
|
||||
@ -19,6 +22,7 @@ pub mod version_item;
|
||||
pub use collection_item::Collection;
|
||||
pub use ids::*;
|
||||
pub use image_item::Image;
|
||||
pub use oauth_client_item::OAuthClient;
|
||||
pub use organization_item::Organization;
|
||||
pub use project_item::Project;
|
||||
pub use team_item::Team;
|
||||
|
||||
126
src/database/models/oauth_client_authorization_item.rs
Normal file
126
src/database/models/oauth_client_authorization_item.rs
Normal file
@ -0,0 +1,126 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::models::pats::Scopes;
|
||||
|
||||
use super::{DatabaseError, OAuthClientAuthorizationId, OAuthClientId, UserId};
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug)]
|
||||
pub struct OAuthClientAuthorization {
|
||||
pub id: OAuthClientAuthorizationId,
|
||||
pub client_id: OAuthClientId,
|
||||
pub user_id: UserId,
|
||||
pub scopes: Scopes,
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
struct AuthorizationQueryResult {
|
||||
id: i64,
|
||||
client_id: i64,
|
||||
user_id: i64,
|
||||
scopes: i64,
|
||||
created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<AuthorizationQueryResult> for OAuthClientAuthorization {
|
||||
fn from(value: AuthorizationQueryResult) -> Self {
|
||||
OAuthClientAuthorization {
|
||||
id: OAuthClientAuthorizationId(value.id),
|
||||
client_id: OAuthClientId(value.client_id),
|
||||
user_id: UserId(value.user_id),
|
||||
scopes: Scopes::from_postgres(value.scopes),
|
||||
created: value.created,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthClientAuthorization {
|
||||
pub async fn get(
|
||||
client_id: OAuthClientId,
|
||||
user_id: UserId,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<Option<OAuthClientAuthorization>, DatabaseError> {
|
||||
let value = sqlx::query_as!(
|
||||
AuthorizationQueryResult,
|
||||
"
|
||||
SELECT id, client_id, user_id, scopes, created
|
||||
FROM oauth_client_authorizations
|
||||
WHERE client_id=$1 AND user_id=$2
|
||||
",
|
||||
client_id.0,
|
||||
user_id.0,
|
||||
)
|
||||
.fetch_optional(exec)
|
||||
.await?;
|
||||
|
||||
Ok(value.map(|r| r.into()))
|
||||
}
|
||||
|
||||
pub async fn get_all_for_user(
|
||||
user_id: UserId,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<Vec<OAuthClientAuthorization>, DatabaseError> {
|
||||
let results = sqlx::query_as!(
|
||||
AuthorizationQueryResult,
|
||||
"
|
||||
SELECT id, client_id, user_id, scopes, created
|
||||
FROM oauth_client_authorizations
|
||||
WHERE user_id=$1
|
||||
",
|
||||
user_id.0
|
||||
)
|
||||
.fetch_all(exec)
|
||||
.await?;
|
||||
|
||||
Ok(results.into_iter().map(|r| r.into()).collect_vec())
|
||||
}
|
||||
|
||||
pub async fn upsert(
|
||||
id: OAuthClientAuthorizationId,
|
||||
client_id: OAuthClientId,
|
||||
user_id: UserId,
|
||||
scopes: Scopes,
|
||||
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
|
||||
) -> Result<(), DatabaseError> {
|
||||
sqlx::query!(
|
||||
"
|
||||
INSERT INTO oauth_client_authorizations (
|
||||
id, client_id, user_id, scopes
|
||||
)
|
||||
VALUES (
|
||||
$1, $2, $3, $4
|
||||
)
|
||||
ON CONFLICT (id)
|
||||
DO UPDATE SET scopes = EXCLUDED.scopes
|
||||
",
|
||||
id.0,
|
||||
client_id.0,
|
||||
user_id.0,
|
||||
scopes.bits() as i64,
|
||||
)
|
||||
.execute(&mut **transaction)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove(
|
||||
client_id: OAuthClientId,
|
||||
user_id: UserId,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<(), DatabaseError> {
|
||||
sqlx::query!(
|
||||
"
|
||||
DELETE FROM oauth_client_authorizations
|
||||
WHERE client_id=$1 AND user_id=$2
|
||||
",
|
||||
client_id.0,
|
||||
user_id.0
|
||||
)
|
||||
.execute(exec)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
245
src/database/models/oauth_client_item.rs
Normal file
245
src/database/models/oauth_client_item.rs
Normal file
@ -0,0 +1,245 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Digest;
|
||||
|
||||
use super::{DatabaseError, OAuthClientId, OAuthRedirectUriId, UserId};
|
||||
use crate::models::pats::Scopes;
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug)]
|
||||
pub struct OAuthRedirectUri {
|
||||
pub id: OAuthRedirectUriId,
|
||||
pub client_id: OAuthClientId,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug)]
|
||||
pub struct OAuthClient {
|
||||
pub id: OAuthClientId,
|
||||
pub name: String,
|
||||
pub icon_url: Option<String>,
|
||||
pub max_scopes: Scopes,
|
||||
pub secret_hash: String,
|
||||
pub redirect_uris: Vec<OAuthRedirectUri>,
|
||||
pub created: DateTime<Utc>,
|
||||
pub created_by: UserId,
|
||||
}
|
||||
|
||||
struct ClientQueryResult {
|
||||
id: i64,
|
||||
name: String,
|
||||
icon_url: Option<String>,
|
||||
max_scopes: i64,
|
||||
secret_hash: String,
|
||||
created: DateTime<Utc>,
|
||||
created_by: i64,
|
||||
uri_ids: Option<Vec<i64>>,
|
||||
uri_vals: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
macro_rules! select_clients_with_predicate {
|
||||
($predicate:tt, $param:ident) => {
|
||||
// The columns in this query have nullability type hints, because for some reason
|
||||
// the combination of the JOIN and filter using ANY makes sqlx think all columns are nullable
|
||||
// https://docs.rs/sqlx/latest/sqlx/macro.query.html#force-nullable
|
||||
sqlx::query_as!(
|
||||
ClientQueryResult,
|
||||
r#"
|
||||
SELECT
|
||||
clients.id as "id!",
|
||||
clients.name as "name!",
|
||||
clients.icon_url as "icon_url?",
|
||||
clients.max_scopes as "max_scopes!",
|
||||
clients.secret_hash as "secret_hash!",
|
||||
clients.created as "created!",
|
||||
clients.created_by as "created_by!",
|
||||
uris.uri_ids as "uri_ids?",
|
||||
uris.uri_vals as "uri_vals?"
|
||||
FROM oauth_clients clients
|
||||
LEFT JOIN (
|
||||
SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals
|
||||
FROM oauth_client_redirect_uris
|
||||
GROUP BY client_id
|
||||
) uris ON clients.id = uris.client_id
|
||||
"#
|
||||
+ $predicate,
|
||||
$param
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
impl OAuthClient {
|
||||
pub async fn get(
|
||||
id: OAuthClientId,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<Option<OAuthClient>, DatabaseError> {
|
||||
Ok(Self::get_many(&[id], exec).await?.into_iter().next())
|
||||
}
|
||||
|
||||
pub async fn get_many(
|
||||
ids: &[OAuthClientId],
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<Vec<OAuthClient>, DatabaseError> {
|
||||
let ids = ids.iter().map(|id| id.0).collect_vec();
|
||||
let ids_ref: &[i64] = &ids;
|
||||
let results =
|
||||
select_clients_with_predicate!("WHERE clients.id = ANY($1::bigint[])", ids_ref)
|
||||
.fetch_all(exec)
|
||||
.await?;
|
||||
|
||||
Ok(results.into_iter().map(|r| r.into()).collect_vec())
|
||||
}
|
||||
|
||||
pub async fn get_all_user_clients(
|
||||
user_id: UserId,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<Vec<OAuthClient>, DatabaseError> {
|
||||
let user_id_param = user_id.0;
|
||||
let clients = select_clients_with_predicate!("WHERE created_by = $1", user_id_param)
|
||||
.fetch_all(exec)
|
||||
.await?;
|
||||
|
||||
Ok(clients.into_iter().map(|r| r.into()).collect())
|
||||
}
|
||||
|
||||
pub async fn remove(
|
||||
id: OAuthClientId,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<(), DatabaseError> {
|
||||
// Cascades to oauth_client_redirect_uris, oauth_client_authorizations
|
||||
sqlx::query!(
|
||||
"
|
||||
DELETE FROM oauth_clients
|
||||
WHERE id = $1
|
||||
",
|
||||
id.0
|
||||
)
|
||||
.execute(exec)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert(
|
||||
&self,
|
||||
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
|
||||
) -> Result<(), DatabaseError> {
|
||||
sqlx::query!(
|
||||
"
|
||||
INSERT INTO oauth_clients (
|
||||
id, name, icon_url, max_scopes, secret_hash, created_by
|
||||
)
|
||||
VALUES (
|
||||
$1, $2, $3, $4, $5, $6
|
||||
)
|
||||
",
|
||||
self.id.0,
|
||||
self.name,
|
||||
self.icon_url,
|
||||
self.max_scopes.to_postgres(),
|
||||
self.secret_hash,
|
||||
self.created_by.0
|
||||
)
|
||||
.execute(&mut **transaction)
|
||||
.await?;
|
||||
|
||||
Self::insert_redirect_uris(&self.redirect_uris, &mut **transaction).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_editable_fields(
|
||||
&self,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<(), DatabaseError> {
|
||||
sqlx::query!(
|
||||
"
|
||||
UPDATE oauth_clients
|
||||
SET name = $1, icon_url = $2, max_scopes = $3
|
||||
WHERE (id = $4)
|
||||
",
|
||||
self.name,
|
||||
self.icon_url,
|
||||
self.max_scopes.to_postgres(),
|
||||
self.id.0,
|
||||
)
|
||||
.execute(exec)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_redirect_uris(
|
||||
ids: impl IntoIterator<Item = OAuthRedirectUriId>,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let ids = ids.into_iter().map(|id| id.0).collect_vec();
|
||||
sqlx::query!(
|
||||
"
|
||||
DELETE FROM oauth_client_redirect_uris
|
||||
WHERE id IN
|
||||
(SELECT * FROM UNNEST($1::bigint[]))
|
||||
",
|
||||
&ids[..]
|
||||
)
|
||||
.execute(exec)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert_redirect_uris(
|
||||
uris: &[OAuthRedirectUri],
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let (ids, client_ids, uris): (Vec<_>, Vec<_>, Vec<_>) = uris
|
||||
.iter()
|
||||
.map(|r| (r.id.0, r.client_id.0, r.uri.clone()))
|
||||
.multiunzip();
|
||||
sqlx::query!(
|
||||
"
|
||||
INSERT INTO oauth_client_redirect_uris (id, client_id, uri)
|
||||
SELECT * FROM UNNEST($1::bigint[], $2::bigint[], $3::varchar[])
|
||||
",
|
||||
&ids[..],
|
||||
&client_ids[..],
|
||||
&uris[..],
|
||||
)
|
||||
.execute(exec)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn hash_secret(secret: &str) -> String {
|
||||
format!("{:x}", sha2::Sha512::digest(secret.as_bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientQueryResult> for OAuthClient {
|
||||
fn from(r: ClientQueryResult) -> Self {
|
||||
let redirects = if let (Some(ids), Some(uris)) = (r.uri_ids.as_ref(), r.uri_vals.as_ref()) {
|
||||
ids.iter()
|
||||
.zip(uris.iter())
|
||||
.map(|(id, uri)| OAuthRedirectUri {
|
||||
id: OAuthRedirectUriId(*id),
|
||||
client_id: OAuthClientId(r.id),
|
||||
uri: uri.to_string(),
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
OAuthClient {
|
||||
id: OAuthClientId(r.id),
|
||||
name: r.name,
|
||||
icon_url: r.icon_url,
|
||||
max_scopes: Scopes::from_postgres(r.max_scopes),
|
||||
secret_hash: r.secret_hash,
|
||||
redirect_uris: redirects,
|
||||
created: r.created,
|
||||
created_by: UserId(r.created_by),
|
||||
}
|
||||
}
|
||||
}
|
||||
95
src/database/models/oauth_token_item.rs
Normal file
95
src/database/models/oauth_token_item.rs
Normal file
@ -0,0 +1,95 @@
|
||||
use super::{DatabaseError, OAuthAccessTokenId, OAuthClientAuthorizationId, OAuthClientId, UserId};
|
||||
use crate::models::pats::Scopes;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Digest;
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug)]
|
||||
pub struct OAuthAccessToken {
|
||||
pub id: OAuthAccessTokenId,
|
||||
pub authorization_id: OAuthClientAuthorizationId,
|
||||
pub token_hash: String,
|
||||
pub scopes: Scopes,
|
||||
pub created: DateTime<Utc>,
|
||||
pub expires: DateTime<Utc>,
|
||||
pub last_used: Option<DateTime<Utc>>,
|
||||
|
||||
// Stored separately inside oauth_client_authorizations table
|
||||
pub client_id: OAuthClientId,
|
||||
pub user_id: UserId,
|
||||
}
|
||||
|
||||
impl OAuthAccessToken {
|
||||
pub async fn get(
|
||||
token_hash: String,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<Option<OAuthAccessToken>, DatabaseError> {
|
||||
let value = sqlx::query!(
|
||||
"
|
||||
SELECT
|
||||
tokens.id,
|
||||
tokens.authorization_id,
|
||||
tokens.token_hash,
|
||||
tokens.scopes,
|
||||
tokens.created,
|
||||
tokens.expires,
|
||||
tokens.last_used,
|
||||
auths.client_id,
|
||||
auths.user_id
|
||||
FROM oauth_access_tokens tokens
|
||||
JOIN oauth_client_authorizations auths
|
||||
ON tokens.authorization_id = auths.id
|
||||
WHERE tokens.token_hash = $1
|
||||
",
|
||||
token_hash
|
||||
)
|
||||
.fetch_optional(exec)
|
||||
.await?;
|
||||
|
||||
Ok(value.map(|r| OAuthAccessToken {
|
||||
id: OAuthAccessTokenId(r.id),
|
||||
authorization_id: OAuthClientAuthorizationId(r.authorization_id),
|
||||
token_hash: r.token_hash,
|
||||
scopes: Scopes::from_postgres(r.scopes),
|
||||
created: r.created,
|
||||
expires: r.expires,
|
||||
last_used: r.last_used,
|
||||
client_id: OAuthClientId(r.client_id),
|
||||
user_id: UserId(r.user_id),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Inserts and returns the time until the token expires
|
||||
pub async fn insert(
|
||||
&self,
|
||||
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
) -> Result<chrono::Duration, DatabaseError> {
|
||||
let r = sqlx::query!(
|
||||
"
|
||||
INSERT INTO oauth_access_tokens (
|
||||
id, authorization_id, token_hash, scopes, last_used
|
||||
)
|
||||
VALUES (
|
||||
$1, $2, $3, $4, $5
|
||||
)
|
||||
RETURNING created, expires
|
||||
",
|
||||
self.id.0,
|
||||
self.authorization_id.0,
|
||||
self.token_hash,
|
||||
self.scopes.to_postgres(),
|
||||
Option::<DateTime<Utc>>::None
|
||||
)
|
||||
.fetch_one(exec)
|
||||
.await?;
|
||||
|
||||
let (created, expires) = (r.created, r.expires);
|
||||
let time_until_expiration = expires - created;
|
||||
|
||||
Ok(time_until_expiration)
|
||||
}
|
||||
|
||||
pub fn hash_token(token: &str) -> String {
|
||||
format!("{:x}", sha2::Sha512::digest(token.as_bytes()))
|
||||
}
|
||||
}
|
||||
@ -3,6 +3,8 @@ use thiserror::Error;
|
||||
pub use super::collections::CollectionId;
|
||||
pub use super::images::ImageId;
|
||||
pub use super::notifications::NotificationId;
|
||||
pub use super::oauth_clients::OAuthClientAuthorizationId;
|
||||
pub use super::oauth_clients::{OAuthClientId, OAuthRedirectUriId};
|
||||
pub use super::organizations::OrganizationId;
|
||||
pub use super::pats::PatId;
|
||||
pub use super::projects::{ProjectId, VersionId};
|
||||
@ -122,6 +124,9 @@ base62_id_impl!(ThreadMessageId, ThreadMessageId);
|
||||
base62_id_impl!(SessionId, SessionId);
|
||||
base62_id_impl!(PatId, PatId);
|
||||
base62_id_impl!(ImageId, ImageId);
|
||||
base62_id_impl!(OAuthClientId, OAuthClientId);
|
||||
base62_id_impl!(OAuthRedirectUriId, OAuthRedirectUriId);
|
||||
base62_id_impl!(OAuthClientAuthorizationId, OAuthClientAuthorizationId);
|
||||
|
||||
pub mod base62_impl {
|
||||
use serde::de::{self, Deserializer, Visitor};
|
||||
|
||||
@ -4,6 +4,7 @@ pub mod error;
|
||||
pub mod ids;
|
||||
pub mod images;
|
||||
pub mod notifications;
|
||||
pub mod oauth_clients;
|
||||
pub mod organizations;
|
||||
pub mod pack;
|
||||
pub mod pats;
|
||||
|
||||
110
src/models/oauth_clients.rs
Normal file
110
src/models/oauth_clients.rs
Normal file
@ -0,0 +1,110 @@
|
||||
use super::{
|
||||
ids::{Base62Id, UserId},
|
||||
pats::Scopes,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization as DBOAuthClientAuthorization;
|
||||
use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient;
|
||||
use crate::database::models::oauth_client_item::OAuthRedirectUri as DBOAuthRedirectUri;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(from = "Base62Id")]
|
||||
#[serde(into = "Base62Id")]
|
||||
pub struct OAuthClientId(pub u64);
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(from = "Base62Id")]
|
||||
#[serde(into = "Base62Id")]
|
||||
pub struct OAuthClientAuthorizationId(pub u64);
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(from = "Base62Id")]
|
||||
#[serde(into = "Base62Id")]
|
||||
pub struct OAuthRedirectUriId(pub u64);
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct OAuthRedirectUri {
|
||||
pub id: OAuthRedirectUriId,
|
||||
pub client_id: OAuthClientId,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct OAuthClientCreationResult {
|
||||
#[serde(flatten)]
|
||||
pub client: OAuthClient,
|
||||
|
||||
pub client_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct OAuthClient {
|
||||
pub id: OAuthClientId,
|
||||
pub name: String,
|
||||
pub icon_url: Option<String>,
|
||||
|
||||
// The maximum scopes the client can request for OAuth
|
||||
pub max_scopes: Scopes,
|
||||
|
||||
// The valid URIs that can be redirected to during an authorization request
|
||||
pub redirect_uris: Vec<OAuthRedirectUri>,
|
||||
|
||||
// The user that created (and thus controls) this client
|
||||
pub created_by: UserId,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct OAuthClientAuthorization {
|
||||
pub id: OAuthClientAuthorizationId,
|
||||
pub app_id: OAuthClientId,
|
||||
pub user_id: UserId,
|
||||
pub scopes: Scopes,
|
||||
pub created: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct GetOAuthClientsRequest {
|
||||
pub ids: Vec<OAuthClientId>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct DeleteOAuthClientQueryParam {
|
||||
pub client_id: OAuthClientId,
|
||||
}
|
||||
|
||||
impl From<DBOAuthClient> for OAuthClient {
|
||||
fn from(value: DBOAuthClient) -> Self {
|
||||
Self {
|
||||
id: value.id.into(),
|
||||
name: value.name,
|
||||
icon_url: value.icon_url,
|
||||
max_scopes: value.max_scopes,
|
||||
redirect_uris: value.redirect_uris.into_iter().map(|r| r.into()).collect(),
|
||||
created_by: value.created_by.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DBOAuthRedirectUri> for OAuthRedirectUri {
|
||||
fn from(value: DBOAuthRedirectUri) -> Self {
|
||||
Self {
|
||||
id: value.id.into(),
|
||||
client_id: value.client_id.into(),
|
||||
uri: value.uri,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DBOAuthClientAuthorization> for OAuthClientAuthorization {
|
||||
fn from(value: DBOAuthClientAuthorization) -> Self {
|
||||
Self {
|
||||
id: value.id.into(),
|
||||
app_id: value.client_id.into(),
|
||||
user_id: value.user_id.into(),
|
||||
scopes: value.scopes,
|
||||
created: value.created,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -103,6 +103,9 @@ bitflags::bitflags! {
|
||||
// delete an organization
|
||||
const ORGANIZATION_DELETE = 1 << 38;
|
||||
|
||||
// only accessible by modrinth-issued sessions
|
||||
const SESSION_ACCESS = 1 << 39;
|
||||
|
||||
const NONE = 0b0;
|
||||
}
|
||||
}
|
||||
@ -118,6 +121,7 @@ impl Scopes {
|
||||
| Scopes::PAT_DELETE
|
||||
| Scopes::SESSION_READ
|
||||
| Scopes::SESSION_DELETE
|
||||
| Scopes::SESSION_ACCESS
|
||||
| Scopes::USER_AUTH_WRITE
|
||||
| Scopes::USER_DELETE
|
||||
| Scopes::PERFORM_ANALYTICS
|
||||
@ -126,6 +130,19 @@ impl Scopes {
|
||||
pub fn is_restricted(&self) -> bool {
|
||||
self.intersects(Self::restricted())
|
||||
}
|
||||
|
||||
pub fn parse_from_oauth_scopes(scopes: &str) -> Result<Scopes, bitflags::parser::ParseError> {
|
||||
let scopes = scopes.replace(' ', "|").replace("%20", "|");
|
||||
bitflags::parser::from_str(&scopes)
|
||||
}
|
||||
|
||||
pub fn to_postgres(&self) -> i64 {
|
||||
self.bits() as i64
|
||||
}
|
||||
|
||||
pub fn from_postgres(value: i64) -> Self {
|
||||
Self::from_bits(value as u64).unwrap_or(Scopes::NONE)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@ -161,3 +178,64 @@ impl PersonalAccessToken {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use itertools::Itertools;
|
||||
|
||||
#[test]
|
||||
fn test_parse_from_oauth_scopes_well_formed() {
|
||||
let raw = "USER_READ_EMAIL SESSION_READ ORGANIZATION_CREATE";
|
||||
let expected = Scopes::USER_READ_EMAIL | Scopes::SESSION_READ | Scopes::ORGANIZATION_CREATE;
|
||||
|
||||
let parsed = Scopes::parse_from_oauth_scopes(raw).unwrap();
|
||||
|
||||
assert_same_flags(expected, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_from_oauth_scopes_empty() {
|
||||
let raw = "";
|
||||
let expected = Scopes::empty();
|
||||
|
||||
let parsed = Scopes::parse_from_oauth_scopes(raw).unwrap();
|
||||
|
||||
assert_same_flags(expected, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_from_oauth_scopes_invalid_scopes() {
|
||||
let raw = "notascope";
|
||||
|
||||
let parsed = Scopes::parse_from_oauth_scopes(raw);
|
||||
|
||||
assert!(parsed.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_from_oauth_scopes_invalid_separator() {
|
||||
let raw = "USER_READ_EMAIL & SESSION_READ";
|
||||
|
||||
let parsed = Scopes::parse_from_oauth_scopes(raw);
|
||||
|
||||
assert!(parsed.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_from_oauth_scopes_url_encoded() {
|
||||
let raw = urlencoding::encode("PAT_WRITE COLLECTION_DELETE").to_string();
|
||||
let expected = Scopes::PAT_WRITE | Scopes::COLLECTION_DELETE;
|
||||
|
||||
let parsed = Scopes::parse_from_oauth_scopes(&raw).unwrap();
|
||||
|
||||
assert_same_flags(expected, parsed);
|
||||
}
|
||||
|
||||
fn assert_same_flags(expected: Scopes, actual: Scopes) {
|
||||
assert_eq!(
|
||||
expected.iter_names().map(|(name, _)| name).collect_vec(),
|
||||
actual.iter_names().map(|(name, _)| name).collect_vec()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
use crate::auth::session::SessionMetadata;
|
||||
use crate::database::models::pat_item::PersonalAccessToken;
|
||||
use crate::database::models::session_item::Session;
|
||||
use crate::database::models::{DatabaseError, PatId, SessionId, UserId};
|
||||
use crate::database::models::{DatabaseError, OAuthAccessTokenId, PatId, SessionId, UserId};
|
||||
use crate::database::redis::RedisPool;
|
||||
use chrono::Utc;
|
||||
use itertools::Itertools;
|
||||
use sqlx::PgPool;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use tokio::sync::Mutex;
|
||||
@ -11,6 +12,7 @@ use tokio::sync::Mutex;
|
||||
pub struct AuthQueue {
|
||||
session_queue: Mutex<HashMap<SessionId, SessionMetadata>>,
|
||||
pat_queue: Mutex<HashSet<PatId>>,
|
||||
oauth_access_token_queue: Mutex<HashSet<OAuthAccessTokenId>>,
|
||||
}
|
||||
|
||||
impl Default for AuthQueue {
|
||||
@ -25,6 +27,7 @@ impl AuthQueue {
|
||||
AuthQueue {
|
||||
session_queue: Mutex::new(HashMap::with_capacity(1000)),
|
||||
pat_queue: Mutex::new(HashSet::with_capacity(1000)),
|
||||
oauth_access_token_queue: Mutex::new(HashSet::with_capacity(1000)),
|
||||
}
|
||||
}
|
||||
pub async fn add_session(&self, id: SessionId, metadata: SessionMetadata) {
|
||||
@ -35,6 +38,10 @@ impl AuthQueue {
|
||||
self.pat_queue.lock().await.insert(id);
|
||||
}
|
||||
|
||||
pub async fn add_oauth_access_token(&self, id: crate::database::models::OAuthAccessTokenId) {
|
||||
self.oauth_access_token_queue.lock().await.insert(id);
|
||||
}
|
||||
|
||||
pub async fn take_sessions(&self) -> HashMap<SessionId, SessionMetadata> {
|
||||
let mut queue = self.session_queue.lock().await;
|
||||
let len = queue.len();
|
||||
@ -42,8 +49,8 @@ impl AuthQueue {
|
||||
std::mem::replace(&mut queue, HashMap::with_capacity(len))
|
||||
}
|
||||
|
||||
pub async fn take_pats(&self) -> HashSet<PatId> {
|
||||
let mut queue = self.pat_queue.lock().await;
|
||||
pub async fn take_hashset<T>(queue: &Mutex<HashSet<T>>) -> HashSet<T> {
|
||||
let mut queue = queue.lock().await;
|
||||
let len = queue.len();
|
||||
|
||||
std::mem::replace(&mut queue, HashSet::with_capacity(len))
|
||||
@ -51,9 +58,13 @@ impl AuthQueue {
|
||||
|
||||
pub async fn index(&self, pool: &PgPool, redis: &RedisPool) -> Result<(), DatabaseError> {
|
||||
let session_queue = self.take_sessions().await;
|
||||
let pat_queue = self.take_pats().await;
|
||||
let pat_queue = Self::take_hashset(&self.pat_queue).await;
|
||||
let oauth_access_token_queue = Self::take_hashset(&self.oauth_access_token_queue).await;
|
||||
|
||||
if !session_queue.is_empty() || !pat_queue.is_empty() {
|
||||
if !session_queue.is_empty()
|
||||
|| !pat_queue.is_empty()
|
||||
|| !oauth_access_token_queue.is_empty()
|
||||
{
|
||||
let mut transaction = pool.begin().await?;
|
||||
let mut clear_cache_sessions = Vec::new();
|
||||
|
||||
@ -102,29 +113,51 @@ impl AuthQueue {
|
||||
|
||||
Session::clear_cache(clear_cache_sessions, redis).await?;
|
||||
|
||||
let mut clear_cache_pats = Vec::new();
|
||||
|
||||
for id in pat_queue {
|
||||
clear_cache_pats.push((Some(id), None, None));
|
||||
|
||||
sqlx::query!(
|
||||
"
|
||||
UPDATE pats
|
||||
SET last_used = $2
|
||||
WHERE (id = $1)
|
||||
",
|
||||
id as PatId,
|
||||
Utc::now(),
|
||||
)
|
||||
.execute(&mut *transaction)
|
||||
.await?;
|
||||
}
|
||||
let ids = pat_queue.iter().map(|id| id.0).collect_vec();
|
||||
let clear_cache_pats = pat_queue
|
||||
.into_iter()
|
||||
.map(|id| (Some(id), None, None))
|
||||
.collect_vec();
|
||||
sqlx::query!(
|
||||
"
|
||||
UPDATE pats
|
||||
SET last_used = $2
|
||||
WHERE id IN
|
||||
(SELECT * FROM UNNEST($1::bigint[]))
|
||||
",
|
||||
&ids[..],
|
||||
Utc::now(),
|
||||
)
|
||||
.execute(&mut *transaction)
|
||||
.await?;
|
||||
|
||||
PersonalAccessToken::clear_cache(clear_cache_pats, redis).await?;
|
||||
|
||||
update_oauth_access_token_last_used(oauth_access_token_queue, &mut transaction).await?;
|
||||
|
||||
transaction.commit().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_oauth_access_token_last_used(
|
||||
oauth_access_token_queue: HashSet<OAuthAccessTokenId>,
|
||||
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let ids = oauth_access_token_queue.iter().map(|id| id.0).collect_vec();
|
||||
sqlx::query!(
|
||||
"
|
||||
UPDATE oauth_access_tokens
|
||||
SET last_used = $2
|
||||
WHERE id IN
|
||||
(SELECT * FROM UNNEST($1::bigint[]))
|
||||
",
|
||||
&ids[..],
|
||||
Utc::now()
|
||||
)
|
||||
.execute(&mut **transaction)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
pub use super::ApiError;
|
||||
use crate::util::cors::default_cors;
|
||||
use crate::{auth::oauth, util::cors::default_cors};
|
||||
use actix_web::{web, HttpResponse};
|
||||
use serde_json::json;
|
||||
|
||||
pub mod oauth_clients;
|
||||
|
||||
pub fn config(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(
|
||||
web::scope("v3")
|
||||
.wrap(default_cors())
|
||||
.route("", web::get().to(hello_world)),
|
||||
.route("", web::get().to(hello_world))
|
||||
.configure(oauth::config)
|
||||
.configure(oauth_clients::config),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
444
src/routes/v3/oauth_clients.rs
Normal file
444
src/routes/v3/oauth_clients.rs
Normal file
@ -0,0 +1,444 @@
|
||||
use std::{collections::HashSet, fmt::Display};
|
||||
|
||||
use actix_web::{
|
||||
delete, get, patch, post,
|
||||
web::{self, scope},
|
||||
HttpRequest, HttpResponse,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use itertools::Itertools;
|
||||
use rand::{distributions::Alphanumeric, Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha20Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use validator::Validate;
|
||||
|
||||
use super::ApiError;
|
||||
use crate::{
|
||||
auth::checks::ValidateAllAuthorized, models::oauth_clients::DeleteOAuthClientQueryParam,
|
||||
};
|
||||
use crate::{
|
||||
auth::{checks::ValidateAuthorized, get_user_from_headers},
|
||||
database::{
|
||||
models::{
|
||||
generate_oauth_client_id, generate_oauth_redirect_id,
|
||||
oauth_client_authorization_item::OAuthClientAuthorization,
|
||||
oauth_client_item::{OAuthClient, OAuthRedirectUri},
|
||||
DatabaseError, OAuthClientId, User,
|
||||
},
|
||||
redis::RedisPool,
|
||||
},
|
||||
models::{
|
||||
self,
|
||||
oauth_clients::{GetOAuthClientsRequest, OAuthClientCreationResult},
|
||||
pats::Scopes,
|
||||
},
|
||||
queue::session::AuthQueue,
|
||||
routes::v2::project_creation::CreateError,
|
||||
util::validate::validation_errors_to_string,
|
||||
};
|
||||
|
||||
use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient;
|
||||
use crate::models::ids::OAuthClientId as ApiOAuthClientId;
|
||||
|
||||
pub fn config(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(get_user_clients);
|
||||
cfg.service(
|
||||
scope("oauth")
|
||||
.service(oauth_client_create)
|
||||
.service(oauth_client_edit)
|
||||
.service(oauth_client_delete)
|
||||
.service(get_client)
|
||||
.service(get_clients)
|
||||
.service(get_user_oauth_authorizations)
|
||||
.service(revoke_oauth_authorization),
|
||||
);
|
||||
}
|
||||
|
||||
#[get("user/{user_id}/oauth_apps")]
|
||||
pub async fn get_user_clients(
|
||||
req: HttpRequest,
|
||||
info: web::Path<String>,
|
||||
pool: web::Data<PgPool>,
|
||||
redis: web::Data<RedisPool>,
|
||||
session_queue: web::Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let current_user = get_user_from_headers(
|
||||
&req,
|
||||
&**pool,
|
||||
&redis,
|
||||
&session_queue,
|
||||
Some(&[Scopes::SESSION_ACCESS]),
|
||||
)
|
||||
.await?
|
||||
.1;
|
||||
|
||||
let target_user = User::get(&info.into_inner(), &**pool, &redis).await?;
|
||||
|
||||
if let Some(target_user) = target_user {
|
||||
let clients = OAuthClient::get_all_user_clients(target_user.id, &**pool).await?;
|
||||
clients
|
||||
.iter()
|
||||
.validate_all_authorized(Some(¤t_user))?;
|
||||
|
||||
let response = clients
|
||||
.into_iter()
|
||||
.map(models::oauth_clients::OAuthClient::from)
|
||||
.collect_vec();
|
||||
|
||||
Ok(HttpResponse::Ok().json(response))
|
||||
} else {
|
||||
Ok(HttpResponse::NotFound().body(""))
|
||||
}
|
||||
}
|
||||
|
||||
#[get("app/{id}")]
|
||||
pub async fn get_client(
|
||||
req: HttpRequest,
|
||||
id: web::Path<ApiOAuthClientId>,
|
||||
pool: web::Data<PgPool>,
|
||||
redis: web::Data<RedisPool>,
|
||||
session_queue: web::Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let clients = get_clients_inner(&[id.into_inner()], req, pool, redis, session_queue).await?;
|
||||
if let Some(client) = clients.into_iter().next() {
|
||||
Ok(HttpResponse::Ok().json(client))
|
||||
} else {
|
||||
Ok(HttpResponse::NotFound().body(""))
|
||||
}
|
||||
}
|
||||
|
||||
#[get("apps")]
|
||||
pub async fn get_clients(
|
||||
req: HttpRequest,
|
||||
info: web::Json<GetOAuthClientsRequest>,
|
||||
pool: web::Data<PgPool>,
|
||||
redis: web::Data<RedisPool>,
|
||||
session_queue: web::Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let clients =
|
||||
get_clients_inner(&info.into_inner().ids, req, pool, redis, session_queue).await?;
|
||||
Ok(HttpResponse::Ok().json(clients))
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Validate)]
|
||||
pub struct NewOAuthApp {
|
||||
#[validate(
|
||||
custom(function = "crate::util::validate::validate_name"),
|
||||
length(min = 3, max = 255)
|
||||
)]
|
||||
pub name: String,
|
||||
|
||||
#[validate(
|
||||
custom(function = "crate::util::validate::validate_url"),
|
||||
length(max = 255)
|
||||
)]
|
||||
pub icon_url: Option<String>,
|
||||
|
||||
#[validate(custom(function = "crate::util::validate::validate_no_restricted_scopes"))]
|
||||
pub max_scopes: Scopes,
|
||||
|
||||
pub redirect_uris: Vec<String>,
|
||||
}
|
||||
|
||||
#[post("app")]
|
||||
pub async fn oauth_client_create<'a>(
|
||||
req: HttpRequest,
|
||||
new_oauth_app: web::Json<NewOAuthApp>,
|
||||
pool: web::Data<PgPool>,
|
||||
redis: web::Data<RedisPool>,
|
||||
session_queue: web::Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, CreateError> {
|
||||
let current_user = get_user_from_headers(
|
||||
&req,
|
||||
&**pool,
|
||||
&redis,
|
||||
&session_queue,
|
||||
Some(&[Scopes::SESSION_ACCESS]),
|
||||
)
|
||||
.await?
|
||||
.1;
|
||||
|
||||
new_oauth_app
|
||||
.validate()
|
||||
.map_err(|e| CreateError::ValidationError(validation_errors_to_string(e, None)))?;
|
||||
|
||||
let mut transaction = pool.begin().await?;
|
||||
|
||||
let client_id = generate_oauth_client_id(&mut transaction).await?;
|
||||
|
||||
let client_secret = generate_oauth_client_secret();
|
||||
let client_secret_hash = DBOAuthClient::hash_secret(&client_secret);
|
||||
|
||||
let redirect_uris =
|
||||
create_redirect_uris(&new_oauth_app.redirect_uris, client_id, &mut transaction).await?;
|
||||
|
||||
let client = OAuthClient {
|
||||
id: client_id,
|
||||
icon_url: new_oauth_app.icon_url.clone(),
|
||||
max_scopes: new_oauth_app.max_scopes,
|
||||
name: new_oauth_app.name.clone(),
|
||||
redirect_uris,
|
||||
created: Utc::now(),
|
||||
created_by: current_user.id.into(),
|
||||
secret_hash: client_secret_hash,
|
||||
};
|
||||
client.clone().insert(&mut transaction).await?;
|
||||
|
||||
transaction.commit().await?;
|
||||
|
||||
let client = models::oauth_clients::OAuthClient::from(client);
|
||||
|
||||
Ok(HttpResponse::Ok().json(OAuthClientCreationResult {
|
||||
client,
|
||||
client_secret,
|
||||
}))
|
||||
}
|
||||
|
||||
#[delete("app/{id}")]
|
||||
pub async fn oauth_client_delete<'a>(
|
||||
req: HttpRequest,
|
||||
client_id: web::Path<ApiOAuthClientId>,
|
||||
pool: web::Data<PgPool>,
|
||||
redis: web::Data<RedisPool>,
|
||||
session_queue: web::Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let current_user = get_user_from_headers(
|
||||
&req,
|
||||
&**pool,
|
||||
&redis,
|
||||
&session_queue,
|
||||
Some(&[Scopes::SESSION_ACCESS]),
|
||||
)
|
||||
.await?
|
||||
.1;
|
||||
|
||||
let client = OAuthClient::get(client_id.into_inner().into(), &**pool).await?;
|
||||
if let Some(client) = client {
|
||||
client.validate_authorized(Some(¤t_user))?;
|
||||
OAuthClient::remove(client.id, &**pool).await?;
|
||||
|
||||
Ok(HttpResponse::NoContent().body(""))
|
||||
} else {
|
||||
Ok(HttpResponse::NotFound().body(""))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Validate)]
|
||||
pub struct OAuthClientEdit {
|
||||
#[validate(
|
||||
custom(function = "crate::util::validate::validate_name"),
|
||||
length(min = 3, max = 255)
|
||||
)]
|
||||
pub name: Option<String>,
|
||||
|
||||
#[validate(
|
||||
custom(function = "crate::util::validate::validate_url"),
|
||||
length(max = 255)
|
||||
)]
|
||||
pub icon_url: Option<Option<String>>,
|
||||
|
||||
pub max_scopes: Option<Scopes>,
|
||||
|
||||
#[validate(length(min = 1))]
|
||||
pub redirect_uris: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[patch("app/{id}")]
|
||||
pub async fn oauth_client_edit(
|
||||
req: HttpRequest,
|
||||
client_id: web::Path<ApiOAuthClientId>,
|
||||
client_updates: web::Json<OAuthClientEdit>,
|
||||
pool: web::Data<PgPool>,
|
||||
redis: web::Data<RedisPool>,
|
||||
session_queue: web::Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let current_user = get_user_from_headers(
|
||||
&req,
|
||||
&**pool,
|
||||
&redis,
|
||||
&session_queue,
|
||||
Some(&[Scopes::SESSION_ACCESS]),
|
||||
)
|
||||
.await?
|
||||
.1;
|
||||
|
||||
client_updates
|
||||
.validate()
|
||||
.map_err(|e| ApiError::Validation(validation_errors_to_string(e, None)))?;
|
||||
|
||||
if client_updates.icon_url.is_none()
|
||||
&& client_updates.name.is_none()
|
||||
&& client_updates.max_scopes.is_none()
|
||||
{
|
||||
return Err(ApiError::InvalidInput("No changes provided".to_string()));
|
||||
}
|
||||
|
||||
if let Some(existing_client) = OAuthClient::get(client_id.into_inner().into(), &**pool).await? {
|
||||
existing_client.validate_authorized(Some(¤t_user))?;
|
||||
|
||||
let mut updated_client = existing_client.clone();
|
||||
let OAuthClientEdit {
|
||||
name,
|
||||
icon_url,
|
||||
max_scopes,
|
||||
redirect_uris,
|
||||
} = client_updates.into_inner();
|
||||
if let Some(name) = name {
|
||||
updated_client.name = name;
|
||||
}
|
||||
|
||||
if let Some(icon_url) = icon_url {
|
||||
updated_client.icon_url = icon_url;
|
||||
}
|
||||
|
||||
if let Some(max_scopes) = max_scopes {
|
||||
updated_client.max_scopes = max_scopes;
|
||||
}
|
||||
|
||||
let mut transaction = pool.begin().await?;
|
||||
updated_client
|
||||
.update_editable_fields(&mut *transaction)
|
||||
.await?;
|
||||
|
||||
if let Some(redirects) = redirect_uris {
|
||||
edit_redirects(redirects, &existing_client, &mut transaction).await?;
|
||||
}
|
||||
|
||||
transaction.commit().await?;
|
||||
|
||||
Ok(HttpResponse::Ok().body(""))
|
||||
} else {
|
||||
Ok(HttpResponse::NotFound().body(""))
|
||||
}
|
||||
}
|
||||
|
||||
#[get("authorizations")]
|
||||
pub async fn get_user_oauth_authorizations(
|
||||
req: HttpRequest,
|
||||
pool: web::Data<PgPool>,
|
||||
redis: web::Data<RedisPool>,
|
||||
session_queue: web::Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let current_user = get_user_from_headers(
|
||||
&req,
|
||||
&**pool,
|
||||
&redis,
|
||||
&session_queue,
|
||||
Some(&[Scopes::SESSION_ACCESS]),
|
||||
)
|
||||
.await?
|
||||
.1;
|
||||
|
||||
let authorizations =
|
||||
OAuthClientAuthorization::get_all_for_user(current_user.id.into(), &**pool).await?;
|
||||
|
||||
let mapped: Vec<models::oauth_clients::OAuthClientAuthorization> =
|
||||
authorizations.into_iter().map(|a| a.into()).collect_vec();
|
||||
|
||||
Ok(HttpResponse::Ok().json(mapped))
|
||||
}
|
||||
|
||||
#[delete("authorizations")]
|
||||
pub async fn revoke_oauth_authorization(
|
||||
req: HttpRequest,
|
||||
info: web::Query<DeleteOAuthClientQueryParam>,
|
||||
pool: web::Data<PgPool>,
|
||||
redis: web::Data<RedisPool>,
|
||||
session_queue: web::Data<AuthQueue>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let current_user = get_user_from_headers(
|
||||
&req,
|
||||
&**pool,
|
||||
&redis,
|
||||
&session_queue,
|
||||
Some(&[Scopes::SESSION_ACCESS]),
|
||||
)
|
||||
.await?
|
||||
.1;
|
||||
|
||||
OAuthClientAuthorization::remove(info.client_id.into(), current_user.id.into(), &**pool)
|
||||
.await?;
|
||||
|
||||
Ok(HttpResponse::Ok().body(""))
|
||||
}
|
||||
|
||||
fn generate_oauth_client_secret() -> String {
|
||||
ChaCha20Rng::from_entropy()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect::<String>()
|
||||
}
|
||||
|
||||
async fn create_redirect_uris(
|
||||
uri_strings: impl IntoIterator<Item = impl Display>,
|
||||
client_id: OAuthClientId,
|
||||
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
|
||||
) -> Result<Vec<OAuthRedirectUri>, DatabaseError> {
|
||||
let mut redirect_uris = vec![];
|
||||
for uri in uri_strings.into_iter() {
|
||||
let id = generate_oauth_redirect_id(transaction).await?;
|
||||
redirect_uris.push(OAuthRedirectUri {
|
||||
id,
|
||||
client_id,
|
||||
uri: uri.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(redirect_uris)
|
||||
}
|
||||
|
||||
async fn edit_redirects(
|
||||
redirects: Vec<String>,
|
||||
existing_client: &OAuthClient,
|
||||
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let updated_redirects: HashSet<String> = redirects.into_iter().collect();
|
||||
let original_redirects: HashSet<String> = existing_client
|
||||
.redirect_uris
|
||||
.iter()
|
||||
.map(|r| r.uri.to_string())
|
||||
.collect();
|
||||
|
||||
let redirects_to_add = create_redirect_uris(
|
||||
updated_redirects.difference(&original_redirects),
|
||||
existing_client.id,
|
||||
&mut *transaction,
|
||||
)
|
||||
.await?;
|
||||
OAuthClient::insert_redirect_uris(&redirects_to_add, &mut **transaction).await?;
|
||||
|
||||
let mut redirects_to_remove = existing_client.redirect_uris.clone();
|
||||
redirects_to_remove.retain(|r| !updated_redirects.contains(&r.uri));
|
||||
OAuthClient::remove_redirect_uris(redirects_to_remove.iter().map(|r| r.id), &mut **transaction)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_clients_inner(
|
||||
ids: &[ApiOAuthClientId],
|
||||
req: HttpRequest,
|
||||
pool: web::Data<PgPool>,
|
||||
redis: web::Data<RedisPool>,
|
||||
session_queue: web::Data<AuthQueue>,
|
||||
) -> Result<Vec<models::oauth_clients::OAuthClient>, ApiError> {
|
||||
let current_user = get_user_from_headers(
|
||||
&req,
|
||||
&**pool,
|
||||
&redis,
|
||||
&session_queue,
|
||||
Some(&[Scopes::SESSION_ACCESS]),
|
||||
)
|
||||
.await?
|
||||
.1;
|
||||
|
||||
let ids: Vec<OAuthClientId> = ids.iter().map(|i| (*i).into()).collect();
|
||||
let clients = OAuthClient::get_many(&ids, &**pool).await?;
|
||||
clients
|
||||
.iter()
|
||||
.validate_all_authorized(Some(¤t_user))?;
|
||||
|
||||
Ok(clients.into_iter().map(|c| c.into()).collect_vec())
|
||||
}
|
||||
@ -3,6 +3,8 @@ use lazy_static::lazy_static;
|
||||
use regex::Regex;
|
||||
use validator::{ValidationErrors, ValidationErrorsKind};
|
||||
|
||||
use crate::models::pats::Scopes;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref RE_URL_SAFE: Regex = Regex::new(r#"^[a-zA-Z0-9!@$()`.+,_"-]*$"#).unwrap();
|
||||
}
|
||||
@ -91,6 +93,16 @@ pub fn validate_url(value: &str) -> Result<(), validator::ValidationError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_no_restricted_scopes(value: &Scopes) -> Result<(), validator::ValidationError> {
|
||||
if value.is_restricted() {
|
||||
return Err(validator::ValidationError::new(
|
||||
"Restricted scopes not allowed",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_name(value: &str) -> Result<(), validator::ValidationError> {
|
||||
if value.trim().is_empty() {
|
||||
return Err(validator::ValidationError::new(
|
||||
|
||||
@ -29,7 +29,7 @@ impl ApiV2 {
|
||||
.set_multipart(creation_data.segment_data)
|
||||
.to_request();
|
||||
let resp = self.call(req).await;
|
||||
assert_status(resp, StatusCode::OK);
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
|
||||
// Approve as a moderator.
|
||||
let req = TestRequest::patch()
|
||||
@ -42,7 +42,7 @@ impl ApiV2 {
|
||||
))
|
||||
.to_request();
|
||||
let resp = self.call(req).await;
|
||||
assert_status(resp, StatusCode::NO_CONTENT);
|
||||
assert_status(&resp, StatusCode::NO_CONTENT);
|
||||
|
||||
let project = self
|
||||
.get_project_deserialized(&creation_data.slug, pat)
|
||||
@ -82,16 +82,20 @@ impl ApiV2 {
|
||||
test::read_body_json(resp).await
|
||||
}
|
||||
|
||||
pub async fn get_user_projects(&self, user_id_or_username: &str, pat: &str) -> ServiceResponse {
|
||||
let req = test::TestRequest::get()
|
||||
.uri(&format!("/v2/user/{}/projects", user_id_or_username))
|
||||
.append_header(("Authorization", pat))
|
||||
.to_request();
|
||||
self.call(req).await
|
||||
}
|
||||
|
||||
pub async fn get_user_projects_deserialized(
|
||||
&self,
|
||||
user_id_or_username: &str,
|
||||
pat: &str,
|
||||
) -> Vec<Project> {
|
||||
let req = test::TestRequest::get()
|
||||
.uri(&format!("/v2/user/{}/projects", user_id_or_username))
|
||||
.append_header(("Authorization", pat))
|
||||
.to_request();
|
||||
let resp = self.call(req).await;
|
||||
let resp = self.get_user_projects(user_id_or_username, pat).await;
|
||||
assert_eq!(resp.status(), 200);
|
||||
test::read_body_json(resp).await
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use actix_http::StatusCode;
|
||||
use actix_web::{dev::ServiceResponse, test};
|
||||
use labrinth::models::{
|
||||
notifications::Notification,
|
||||
@ -5,6 +6,8 @@ use labrinth::models::{
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::common::asserts::assert_status;
|
||||
|
||||
use super::ApiV2;
|
||||
|
||||
impl ApiV2 {
|
||||
@ -114,16 +117,21 @@ impl ApiV2 {
|
||||
self.call(req).await
|
||||
}
|
||||
|
||||
pub async fn get_user_notifications(&self, user_id: &str, pat: &str) -> ServiceResponse {
|
||||
let req = test::TestRequest::get()
|
||||
.uri(&format!("/v2/user/{user_id}/notifications"))
|
||||
.append_header(("Authorization", pat))
|
||||
.to_request();
|
||||
self.call(req).await
|
||||
}
|
||||
|
||||
pub async fn get_user_notifications_deserialized(
|
||||
&self,
|
||||
user_id: &str,
|
||||
pat: &str,
|
||||
) -> Vec<Notification> {
|
||||
let req = test::TestRequest::get()
|
||||
.uri(&format!("/v2/user/{user_id}/notifications"))
|
||||
.append_header(("Authorization", pat))
|
||||
.to_request();
|
||||
let resp = self.call(req).await;
|
||||
let resp = self.get_user_notifications(user_id, pat).await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
test::read_body_json(resp).await
|
||||
}
|
||||
|
||||
|
||||
19
tests/common/api_v3/mod.rs
Normal file
19
tests/common/api_v3/mod.rs
Normal file
@ -0,0 +1,19 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use super::environment::LocalService;
|
||||
use actix_web::dev::ServiceResponse;
|
||||
use std::rc::Rc;
|
||||
|
||||
pub mod oauth;
|
||||
pub mod oauth_clients;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiV3 {
|
||||
pub test_app: Rc<dyn LocalService>,
|
||||
}
|
||||
|
||||
impl ApiV3 {
|
||||
pub async fn call(&self, req: actix_http::Request) -> ServiceResponse {
|
||||
self.test_app.call(req).await.unwrap()
|
||||
}
|
||||
}
|
||||
156
tests/common/api_v3/oauth.rs
Normal file
156
tests/common/api_v3/oauth.rs
Normal file
@ -0,0 +1,156 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use actix_http::StatusCode;
|
||||
use actix_web::{
|
||||
dev::ServiceResponse,
|
||||
test::{self, TestRequest},
|
||||
};
|
||||
use labrinth::auth::oauth::{
|
||||
OAuthClientAccessRequest, RespondToOAuthClientScopes, TokenRequest, TokenResponse,
|
||||
};
|
||||
use reqwest::header::{AUTHORIZATION, LOCATION};
|
||||
|
||||
use crate::common::asserts::assert_status;
|
||||
|
||||
use super::ApiV3;
|
||||
|
||||
impl ApiV3 {
|
||||
pub async fn complete_full_authorize_flow(
|
||||
&self,
|
||||
client_id: &str,
|
||||
client_secret: &str,
|
||||
scope: Option<&str>,
|
||||
redirect_uri: Option<&str>,
|
||||
state: Option<&str>,
|
||||
user_pat: &str,
|
||||
) -> String {
|
||||
let auth_resp = self
|
||||
.oauth_authorize(client_id, scope, redirect_uri, state, user_pat)
|
||||
.await;
|
||||
let flow_id = get_authorize_accept_flow_id(auth_resp).await;
|
||||
let redirect_resp = self.oauth_accept(&flow_id, user_pat).await;
|
||||
let auth_code = get_auth_code_from_redirect_params(&redirect_resp).await;
|
||||
let token_resp = self
|
||||
.oauth_token(auth_code, None, client_id.to_string(), client_secret)
|
||||
.await;
|
||||
get_access_token(token_resp).await
|
||||
}
|
||||
|
||||
pub async fn oauth_authorize(
|
||||
&self,
|
||||
client_id: &str,
|
||||
scope: Option<&str>,
|
||||
redirect_uri: Option<&str>,
|
||||
state: Option<&str>,
|
||||
pat: &str,
|
||||
) -> ServiceResponse {
|
||||
let uri = generate_authorize_uri(client_id, scope, redirect_uri, state);
|
||||
let req = TestRequest::get()
|
||||
.uri(&uri)
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.to_request();
|
||||
self.call(req).await
|
||||
}
|
||||
|
||||
pub async fn oauth_accept(&self, flow: &str, pat: &str) -> ServiceResponse {
|
||||
self.call(
|
||||
TestRequest::post()
|
||||
.uri("/v3/auth/oauth/accept")
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.set_json(RespondToOAuthClientScopes {
|
||||
flow: flow.to_string(),
|
||||
})
|
||||
.to_request(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn oauth_reject(&self, flow: &str, pat: &str) -> ServiceResponse {
|
||||
self.call(
|
||||
TestRequest::post()
|
||||
.uri("/v3/auth/oauth/reject")
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.set_json(RespondToOAuthClientScopes {
|
||||
flow: flow.to_string(),
|
||||
})
|
||||
.to_request(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn oauth_token(
|
||||
&self,
|
||||
auth_code: String,
|
||||
original_redirect_uri: Option<String>,
|
||||
client_id: String,
|
||||
client_secret: &str,
|
||||
) -> ServiceResponse {
|
||||
self.call(
|
||||
TestRequest::post()
|
||||
.uri("/v3/auth/oauth/token")
|
||||
.append_header((AUTHORIZATION, client_secret))
|
||||
.set_form(TokenRequest {
|
||||
grant_type: "authorization_code".to_string(),
|
||||
code: auth_code,
|
||||
redirect_uri: original_redirect_uri,
|
||||
client_id: serde_json::from_str(&format!("\"{}\"", client_id)).unwrap(),
|
||||
})
|
||||
.to_request(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_authorize_uri(
|
||||
client_id: &str,
|
||||
scope: Option<&str>,
|
||||
redirect_uri: Option<&str>,
|
||||
state: Option<&str>,
|
||||
) -> String {
|
||||
format!(
|
||||
"/v3/auth/oauth/authorize?client_id={}{}{}{}",
|
||||
urlencoding::encode(client_id),
|
||||
optional_query_param("redirect_uri", redirect_uri),
|
||||
optional_query_param("scope", scope),
|
||||
optional_query_param("state", state),
|
||||
)
|
||||
.to_string()
|
||||
}
|
||||
|
||||
pub async fn get_authorize_accept_flow_id(response: ServiceResponse) -> String {
|
||||
assert_status(&response, StatusCode::OK);
|
||||
test::read_body_json::<OAuthClientAccessRequest, _>(response)
|
||||
.await
|
||||
.flow_id
|
||||
}
|
||||
|
||||
pub async fn get_auth_code_from_redirect_params(response: &ServiceResponse) -> String {
|
||||
assert_status(response, StatusCode::FOUND);
|
||||
let query_params = get_redirect_location_query_params(response);
|
||||
query_params.get("code").unwrap().to_string()
|
||||
}
|
||||
|
||||
pub async fn get_access_token(response: ServiceResponse) -> String {
|
||||
assert_status(&response, StatusCode::OK);
|
||||
test::read_body_json::<TokenResponse, _>(response)
|
||||
.await
|
||||
.access_token
|
||||
}
|
||||
|
||||
pub fn get_redirect_location_query_params(
|
||||
response: &ServiceResponse,
|
||||
) -> actix_web::web::Query<HashMap<String, String>> {
|
||||
let redirect_location = response.headers().get(LOCATION).unwrap().to_str().unwrap();
|
||||
actix_web::web::Query::<HashMap<String, String>>::from_query(
|
||||
redirect_location.split_once('?').unwrap().1,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn optional_query_param(key: &str, value: Option<&str>) -> String {
|
||||
if let Some(val) = value {
|
||||
format!("&{key}={}", urlencoding::encode(val))
|
||||
} else {
|
||||
"".to_string()
|
||||
}
|
||||
}
|
||||
107
tests/common/api_v3/oauth_clients.rs
Normal file
107
tests/common/api_v3/oauth_clients.rs
Normal file
@ -0,0 +1,107 @@
|
||||
use actix_http::StatusCode;
|
||||
use actix_web::{
|
||||
dev::ServiceResponse,
|
||||
test::{self, TestRequest},
|
||||
};
|
||||
use labrinth::{
|
||||
models::{
|
||||
oauth_clients::{OAuthClient, OAuthClientAuthorization},
|
||||
pats::Scopes,
|
||||
},
|
||||
routes::v3::oauth_clients::OAuthClientEdit,
|
||||
};
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::common::asserts::assert_status;
|
||||
|
||||
use super::ApiV3;
|
||||
|
||||
impl ApiV3 {
|
||||
pub async fn add_oauth_client(
|
||||
&self,
|
||||
name: String,
|
||||
max_scopes: Scopes,
|
||||
redirect_uris: Vec<String>,
|
||||
pat: &str,
|
||||
) -> ServiceResponse {
|
||||
let max_scopes = max_scopes.bits();
|
||||
let req = TestRequest::post()
|
||||
.uri("/v3/oauth/app")
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.set_json(json!({
|
||||
"name": name,
|
||||
"max_scopes": max_scopes,
|
||||
"redirect_uris": redirect_uris
|
||||
}))
|
||||
.to_request();
|
||||
|
||||
self.call(req).await
|
||||
}
|
||||
|
||||
pub async fn get_user_oauth_clients(&self, user_id: &str, pat: &str) -> Vec<OAuthClient> {
|
||||
let req = TestRequest::get()
|
||||
.uri(&format!("/v3/user/{}/oauth_apps", user_id))
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.to_request();
|
||||
let resp = self.call(req).await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
|
||||
test::read_body_json(resp).await
|
||||
}
|
||||
|
||||
pub async fn get_oauth_client(&self, client_id: String, pat: &str) -> ServiceResponse {
|
||||
let req = TestRequest::get()
|
||||
.uri(&format!("/v3/oauth/app/{}", client_id))
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.to_request();
|
||||
|
||||
self.call(req).await
|
||||
}
|
||||
|
||||
pub async fn edit_oauth_client(
|
||||
&self,
|
||||
client_id: &str,
|
||||
edit: OAuthClientEdit,
|
||||
pat: &str,
|
||||
) -> ServiceResponse {
|
||||
let req = TestRequest::patch()
|
||||
.uri(&format!("/v3/oauth/app/{}", urlencoding::encode(client_id)))
|
||||
.set_json(edit)
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.to_request();
|
||||
|
||||
self.call(req).await
|
||||
}
|
||||
|
||||
pub async fn delete_oauth_client(&self, client_id: &str, pat: &str) -> ServiceResponse {
|
||||
let req = TestRequest::delete()
|
||||
.uri(&format!("/v3/oauth/app/{}", client_id))
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.to_request();
|
||||
|
||||
self.call(req).await
|
||||
}
|
||||
|
||||
pub async fn revoke_oauth_authorization(&self, client_id: &str, pat: &str) -> ServiceResponse {
|
||||
let req = TestRequest::delete()
|
||||
.uri(&format!(
|
||||
"/v3/oauth/authorizations?client_id={}",
|
||||
urlencoding::encode(client_id)
|
||||
))
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.to_request();
|
||||
self.call(req).await
|
||||
}
|
||||
|
||||
pub async fn get_user_oauth_authorizations(&self, pat: &str) -> Vec<OAuthClientAuthorization> {
|
||||
let req = TestRequest::get()
|
||||
.uri("/v3/oauth/authorizations")
|
||||
.append_header((AUTHORIZATION, pat))
|
||||
.to_request();
|
||||
let resp = self.call(req).await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
|
||||
test::read_body_json(resp).await
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,12 @@
|
||||
pub fn assert_status(response: actix_web::dev::ServiceResponse, status: actix_http::StatusCode) {
|
||||
#![allow(dead_code)]
|
||||
|
||||
pub fn assert_status(response: &actix_web::dev::ServiceResponse, status: actix_http::StatusCode) {
|
||||
assert_eq!(response.status(), status, "{:#?}", response.response());
|
||||
}
|
||||
|
||||
pub fn assert_any_status_except(
|
||||
response: &actix_web::dev::ServiceResponse,
|
||||
status: actix_http::StatusCode,
|
||||
) {
|
||||
assert_ne!(response.status(), status, "{:#?}", response.response());
|
||||
}
|
||||
|
||||
@ -7,6 +7,8 @@ use url::Url;
|
||||
|
||||
use crate::common::{dummy_data, environment::TestEnvironment};
|
||||
|
||||
use super::dummy_data::DUMMY_DATA_UPDATE;
|
||||
|
||||
// The dummy test database adds a fair bit of 'dummy' data to test with.
|
||||
// Some constants are used to refer to that data, and are described here.
|
||||
// The rest can be accessed in the TestEnvironment 'dummy' field.
|
||||
@ -119,11 +121,7 @@ impl TemporaryDatabase {
|
||||
.await
|
||||
.unwrap();
|
||||
if db_exists.is_none() {
|
||||
let create_db_query = format!("CREATE DATABASE {TEMPLATE_DATABASE_NAME}");
|
||||
sqlx::query(&create_db_query)
|
||||
.execute(&main_pool)
|
||||
.await
|
||||
.expect("Database creation failed");
|
||||
create_template_database(&main_pool).await;
|
||||
}
|
||||
|
||||
// Switch to template
|
||||
@ -135,30 +133,52 @@ impl TemporaryDatabase {
|
||||
.await
|
||||
.expect("Connection to database failed");
|
||||
|
||||
// Run migrations on the template
|
||||
let migrations = sqlx::migrate!("./migrations");
|
||||
migrations.run(&pool).await.expect("Migrations failed");
|
||||
|
||||
// Check if dummy data exists- a fake 'dummy_data' table is created if it does
|
||||
let dummy_data_exists: bool =
|
||||
let mut dummy_data_exists: bool =
|
||||
sqlx::query_scalar("SELECT to_regclass('dummy_data') IS NOT NULL")
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
if dummy_data_exists {
|
||||
// Check if the dummy data needs to be updated
|
||||
let dummy_data_update =
|
||||
sqlx::query_scalar::<_, i64>("SELECT update_id FROM dummy_data")
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
let needs_update = !dummy_data_update.is_some_and(|d| d == DUMMY_DATA_UPDATE);
|
||||
if needs_update {
|
||||
println!("Dummy data updated, so template DB tables will be dropped and re-created");
|
||||
// Drop all tables in the database so they can be re-created and later filled with updated dummy data
|
||||
sqlx::query("DROP SCHEMA public CASCADE;")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
sqlx::query("CREATE SCHEMA public;")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
dummy_data_exists = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Run migrations on the template
|
||||
let migrations = sqlx::migrate!("./migrations");
|
||||
migrations.run(&pool).await.expect("Migrations failed");
|
||||
|
||||
if !dummy_data_exists {
|
||||
// Add dummy data
|
||||
let temporary_test_env = TestEnvironment::build_with_db(TemporaryDatabase {
|
||||
pool: pool.clone(),
|
||||
database_name: TEMPLATE_DATABASE_NAME.to_string(),
|
||||
redis_pool: RedisPool::new(None),
|
||||
redis_pool: RedisPool::new(Some(generate_random_name("test_template_"))),
|
||||
})
|
||||
.await;
|
||||
dummy_data::add_dummy_data(&temporary_test_env).await;
|
||||
temporary_test_env.db.pool.close().await;
|
||||
}
|
||||
pool.close().await;
|
||||
|
||||
// Switch back to main database (as we cant create from template while connected to it)
|
||||
let pool = PgPool::connect(url.as_str()).await.unwrap();
|
||||
drop(pool);
|
||||
|
||||
// Create the temporary database from the template
|
||||
let create_db_query = format!(
|
||||
@ -167,7 +187,7 @@ impl TemporaryDatabase {
|
||||
);
|
||||
|
||||
sqlx::query(&create_db_query)
|
||||
.execute(&pool)
|
||||
.execute(&main_pool)
|
||||
.await
|
||||
.expect("Database creation failed");
|
||||
|
||||
@ -216,6 +236,14 @@ impl TemporaryDatabase {
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_template_database(pool: &sqlx::Pool<sqlx::Postgres>) {
|
||||
let create_db_query = format!("CREATE DATABASE {TEMPLATE_DATABASE_NAME}");
|
||||
sqlx::query(&create_db_query)
|
||||
.execute(pool)
|
||||
.await
|
||||
.expect("Database creation failed");
|
||||
}
|
||||
|
||||
// Appends a random 8-digit number to the end of the str
|
||||
pub fn generate_random_name(str: &str) -> String {
|
||||
let mut str = String::from(str);
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
#![allow(dead_code)]
|
||||
use actix_http::StatusCode;
|
||||
use actix_web::test::{self, TestRequest};
|
||||
use labrinth::{
|
||||
models::projects::Project,
|
||||
models::{organizations::Organization, pats::Scopes, projects::Version},
|
||||
models::{
|
||||
oauth_clients::OAuthClient, organizations::Organization, pats::Scopes, projects::Version,
|
||||
},
|
||||
};
|
||||
use serde_json::json;
|
||||
use sqlx::Executor;
|
||||
@ -11,11 +14,14 @@ use crate::common::{actix::AppendsMultipart, database::USER_USER_PAT};
|
||||
|
||||
use super::{
|
||||
actix::{MultipartSegment, MultipartSegmentData},
|
||||
asserts::assert_status,
|
||||
database::USER_USER_ID,
|
||||
environment::TestEnvironment,
|
||||
get_json_val_str,
|
||||
request_data::get_public_project_creation_data,
|
||||
};
|
||||
|
||||
pub const DUMMY_DATA_UPDATE: i64 = 1;
|
||||
pub const DUMMY_DATA_UPDATE: i64 = 3;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub const DUMMY_CATEGORIES: &[&str] = &[
|
||||
@ -28,6 +34,8 @@ pub const DUMMY_CATEGORIES: &[&str] = &[
|
||||
"optimization",
|
||||
];
|
||||
|
||||
pub const DUMMY_OAUTH_CLIENT_ALPHA_SECRET: &str = "abcdefghijklmnopqrstuvwxyz";
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub enum DummyJarFile {
|
||||
DummyProjectAlpha,
|
||||
@ -43,16 +51,80 @@ pub enum DummyImage {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DummyData {
|
||||
/// Alpha project:
|
||||
/// This is a dummy project created by USER user.
|
||||
/// It's approved, listed, and visible to the public.
|
||||
pub project_alpha: DummyProjectAlpha,
|
||||
|
||||
/// Beta project:
|
||||
/// This is a dummy project created by USER user.
|
||||
/// It's not approved, unlisted, and not visible to the public.
|
||||
pub project_beta: DummyProjectBeta,
|
||||
|
||||
/// Zeta organization:
|
||||
/// This is a dummy organization created by USER user.
|
||||
/// There are no projects in it.
|
||||
pub organization_zeta: DummyOrganizationZeta,
|
||||
|
||||
/// Alpha OAuth Client:
|
||||
/// This is a dummy OAuth client created by USER user.
|
||||
///
|
||||
/// All scopes are included in its max scopes
|
||||
///
|
||||
/// It has one valid redirect URI
|
||||
pub oauth_client_alpha: DummyOAuthClientAlpha,
|
||||
}
|
||||
|
||||
impl DummyData {
|
||||
pub fn new(
|
||||
project_alpha: Project,
|
||||
project_alpha_version: Version,
|
||||
project_beta: Project,
|
||||
project_beta_version: Version,
|
||||
organization_zeta: Organization,
|
||||
oauth_client_alpha: OAuthClient,
|
||||
) -> Self {
|
||||
DummyData {
|
||||
project_alpha: DummyProjectAlpha {
|
||||
team_id: project_alpha.team.to_string(),
|
||||
project_id: project_alpha.id.to_string(),
|
||||
project_slug: project_alpha.slug.unwrap(),
|
||||
version_id: project_alpha_version.id.to_string(),
|
||||
thread_id: project_alpha.thread_id.to_string(),
|
||||
file_hash: project_alpha_version.files[0].hashes["sha1"].clone(),
|
||||
},
|
||||
|
||||
project_beta: DummyProjectBeta {
|
||||
team_id: project_beta.team.to_string(),
|
||||
project_id: project_beta.id.to_string(),
|
||||
project_slug: project_beta.slug.unwrap(),
|
||||
version_id: project_beta_version.id.to_string(),
|
||||
thread_id: project_beta.thread_id.to_string(),
|
||||
file_hash: project_beta_version.files[0].hashes["sha1"].clone(),
|
||||
},
|
||||
|
||||
organization_zeta: DummyOrganizationZeta {
|
||||
organization_id: organization_zeta.id.to_string(),
|
||||
team_id: organization_zeta.team_id.to_string(),
|
||||
organization_title: organization_zeta.title,
|
||||
},
|
||||
|
||||
oauth_client_alpha: DummyOAuthClientAlpha {
|
||||
client_id: get_json_val_str(oauth_client_alpha.id),
|
||||
client_secret: DUMMY_OAUTH_CLIENT_ALPHA_SECRET.to_string(),
|
||||
valid_redirect_uri: oauth_client_alpha
|
||||
.redirect_uris
|
||||
.first()
|
||||
.unwrap()
|
||||
.uri
|
||||
.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DummyProjectAlpha {
|
||||
// Alpha project:
|
||||
// This is a dummy project created by USER user.
|
||||
// It's approved, listed, and visible to the public.
|
||||
pub project_id: String,
|
||||
pub project_slug: String,
|
||||
pub version_id: String,
|
||||
@ -63,9 +135,6 @@ pub struct DummyProjectAlpha {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DummyProjectBeta {
|
||||
// Beta project:
|
||||
// This is a dummy project created by USER user.
|
||||
// It's not approved, unlisted, and not visible to the public.
|
||||
pub project_id: String,
|
||||
pub project_slug: String,
|
||||
pub version_id: String,
|
||||
@ -76,14 +145,18 @@ pub struct DummyProjectBeta {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DummyOrganizationZeta {
|
||||
// Zeta organization:
|
||||
// This is a dummy organization created by USER user.
|
||||
// There are no projects in it.
|
||||
pub organization_id: String,
|
||||
pub organization_title: String,
|
||||
pub team_id: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DummyOAuthClientAlpha {
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub valid_redirect_uri: String,
|
||||
}
|
||||
|
||||
pub async fn add_dummy_data(test_env: &TestEnvironment) -> DummyData {
|
||||
// Adds basic dummy data to the database directly with sql (user, pats)
|
||||
let pool = &test_env.db.pool.clone();
|
||||
@ -101,37 +174,22 @@ pub async fn add_dummy_data(test_env: &TestEnvironment) -> DummyData {
|
||||
|
||||
let zeta_organization = add_organization_zeta(test_env).await;
|
||||
|
||||
let oauth_client_alpha = get_oauth_client_alpha(test_env).await;
|
||||
|
||||
sqlx::query("INSERT INTO dummy_data (update_id) VALUES ($1)")
|
||||
.bind(DUMMY_DATA_UPDATE)
|
||||
.execute(pool)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
DummyData {
|
||||
project_alpha: DummyProjectAlpha {
|
||||
team_id: alpha_project.team.to_string(),
|
||||
project_id: alpha_project.id.to_string(),
|
||||
project_slug: alpha_project.slug.unwrap(),
|
||||
version_id: alpha_version.id.to_string(),
|
||||
thread_id: alpha_project.thread_id.to_string(),
|
||||
file_hash: alpha_version.files[0].hashes["sha1"].clone(),
|
||||
},
|
||||
|
||||
project_beta: DummyProjectBeta {
|
||||
team_id: beta_project.team.to_string(),
|
||||
project_id: beta_project.id.to_string(),
|
||||
project_slug: beta_project.slug.unwrap(),
|
||||
version_id: beta_version.id.to_string(),
|
||||
thread_id: beta_project.thread_id.to_string(),
|
||||
file_hash: beta_version.files[0].hashes["sha1"].clone(),
|
||||
},
|
||||
|
||||
organization_zeta: DummyOrganizationZeta {
|
||||
organization_id: zeta_organization.id.to_string(),
|
||||
team_id: zeta_organization.team_id.to_string(),
|
||||
organization_title: zeta_organization.title,
|
||||
},
|
||||
}
|
||||
DummyData::new(
|
||||
alpha_project,
|
||||
alpha_version,
|
||||
beta_project,
|
||||
beta_version,
|
||||
zeta_organization,
|
||||
oauth_client_alpha,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn get_dummy_data(test_env: &TestEnvironment) -> DummyData {
|
||||
@ -139,31 +197,17 @@ pub async fn get_dummy_data(test_env: &TestEnvironment) -> DummyData {
|
||||
let (beta_project, beta_version) = get_project_beta(test_env).await;
|
||||
|
||||
let zeta_organization = get_organization_zeta(test_env).await;
|
||||
DummyData {
|
||||
project_alpha: DummyProjectAlpha {
|
||||
team_id: alpha_project.team.to_string(),
|
||||
project_id: alpha_project.id.to_string(),
|
||||
project_slug: alpha_project.slug.unwrap(),
|
||||
version_id: alpha_version.id.to_string(),
|
||||
thread_id: alpha_project.thread_id.to_string(),
|
||||
file_hash: alpha_version.files[0].hashes["sha1"].clone(),
|
||||
},
|
||||
|
||||
project_beta: DummyProjectBeta {
|
||||
team_id: beta_project.team.to_string(),
|
||||
project_id: beta_project.id.to_string(),
|
||||
project_slug: beta_project.slug.unwrap(),
|
||||
version_id: beta_version.id.to_string(),
|
||||
thread_id: beta_project.thread_id.to_string(),
|
||||
file_hash: beta_version.files[0].hashes["sha1"].clone(),
|
||||
},
|
||||
let oauth_client_alpha = get_oauth_client_alpha(test_env).await;
|
||||
|
||||
organization_zeta: DummyOrganizationZeta {
|
||||
organization_id: zeta_organization.id.to_string(),
|
||||
team_id: zeta_organization.team_id.to_string(),
|
||||
organization_title: zeta_organization.title,
|
||||
},
|
||||
}
|
||||
DummyData::new(
|
||||
alpha_project,
|
||||
alpha_version,
|
||||
beta_project,
|
||||
beta_version,
|
||||
zeta_organization,
|
||||
oauth_client_alpha,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn add_project_alpha(test_env: &TestEnvironment) -> (Project, Version) {
|
||||
@ -282,6 +326,7 @@ pub async fn get_project_beta(test_env: &TestEnvironment) -> (Project, Version)
|
||||
.append_header(("Authorization", USER_USER_PAT))
|
||||
.to_request();
|
||||
let resp = test_env.call(req).await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
let project: Project = test::read_body_json(resp).await;
|
||||
|
||||
// Get project's versions
|
||||
@ -290,6 +335,7 @@ pub async fn get_project_beta(test_env: &TestEnvironment) -> (Project, Version)
|
||||
.append_header(("Authorization", USER_USER_PAT))
|
||||
.to_request();
|
||||
let resp = test_env.call(req).await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
let versions: Vec<Version> = test::read_body_json(resp).await;
|
||||
let version = versions.into_iter().next().unwrap();
|
||||
|
||||
@ -308,6 +354,14 @@ pub async fn get_organization_zeta(test_env: &TestEnvironment) -> Organization {
|
||||
organization
|
||||
}
|
||||
|
||||
pub async fn get_oauth_client_alpha(test_env: &TestEnvironment) -> OAuthClient {
|
||||
let oauth_clients = test_env
|
||||
.v3
|
||||
.get_user_oauth_clients(USER_USER_ID, USER_USER_PAT)
|
||||
.await;
|
||||
oauth_clients.into_iter().next().unwrap()
|
||||
}
|
||||
|
||||
impl DummyJarFile {
|
||||
pub fn filename(&self) -> String {
|
||||
match self {
|
||||
|
||||
@ -4,6 +4,7 @@ use std::{rc::Rc, sync::Arc};
|
||||
|
||||
use super::{
|
||||
api_v2::ApiV2,
|
||||
api_v3::ApiV3,
|
||||
asserts::assert_status,
|
||||
database::{TemporaryDatabase, FRIEND_USER_ID, USER_USER_PAT},
|
||||
dummy_data,
|
||||
@ -34,6 +35,7 @@ pub struct TestEnvironment {
|
||||
test_app: Rc<dyn LocalService>, // Rc as it's not Send
|
||||
pub db: TemporaryDatabase,
|
||||
pub v2: ApiV2,
|
||||
pub v3: ApiV3,
|
||||
|
||||
pub dummy: Option<Arc<dummy_data::DummyData>>,
|
||||
}
|
||||
@ -56,6 +58,9 @@ impl TestEnvironment {
|
||||
v2: ApiV2 {
|
||||
test_app: test_app.clone(),
|
||||
},
|
||||
v3: ApiV3 {
|
||||
test_app: test_app.clone(),
|
||||
},
|
||||
test_app,
|
||||
db,
|
||||
dummy: None,
|
||||
@ -81,7 +86,27 @@ impl TestEnvironment {
|
||||
USER_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
assert_status(resp, StatusCode::NO_CONTENT);
|
||||
assert_status(&resp, StatusCode::NO_CONTENT);
|
||||
}
|
||||
|
||||
pub async fn assert_read_notifications_status(
|
||||
&self,
|
||||
user_id: &str,
|
||||
pat: &str,
|
||||
status_code: StatusCode,
|
||||
) {
|
||||
let resp = self.v2.get_user_notifications(user_id, pat).await;
|
||||
assert_status(&resp, status_code);
|
||||
}
|
||||
|
||||
pub async fn assert_read_user_projects_status(
|
||||
&self,
|
||||
user_id: &str,
|
||||
pat: &str,
|
||||
status_code: StatusCode,
|
||||
) {
|
||||
let resp = self.v2.get_user_projects(user_id, pat).await;
|
||||
assert_status(&resp, status_code);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ use self::database::TemporaryDatabase;
|
||||
|
||||
pub mod actix;
|
||||
pub mod api_v2;
|
||||
pub mod api_v3;
|
||||
pub mod asserts;
|
||||
pub mod database;
|
||||
pub mod dummy_data;
|
||||
@ -42,3 +43,11 @@ pub async fn setup(db: &TemporaryDatabase) -> LabrinthConfig {
|
||||
maxmind_reader.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_json_val_str(val: impl serde::Serialize) -> String {
|
||||
serde_json::to_value(val)
|
||||
.unwrap()
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
@ -45,6 +45,25 @@ INSERT INTO categories (id, category, project_type) VALUES
|
||||
(106, 'mobs', 2),
|
||||
(107, 'optimization', 2);
|
||||
|
||||
-- Create dummy oauth client, secret_hash is SHA512 hash of full lowercase alphabet
|
||||
INSERT INTO oauth_clients (
|
||||
id,
|
||||
name,
|
||||
icon_url,
|
||||
max_scopes,
|
||||
secret_hash,
|
||||
created_by
|
||||
)
|
||||
VALUES (
|
||||
1,
|
||||
'oauth_client_alpha',
|
||||
NULL,
|
||||
$1,
|
||||
'4dbff86cc2ca1bae1e16468a05cb9881c97f1753bce3619034898faa1aabe429955a1bf8ec483d7421fe3c1646613a59ed5441fb0f321389f77f48a879c7b1f1',
|
||||
3
|
||||
);
|
||||
INSERT INTO oauth_client_redirect_uris (id, client_id, uri) VALUES (1, 1, 'https://modrinth.com/oauth_callback');
|
||||
|
||||
-- Create dummy data table to mark that this file has been run
|
||||
CREATE TABLE dummy_data (
|
||||
update_id bigint PRIMARY KEY
|
||||
|
||||
292
tests/oauth.rs
Normal file
292
tests/oauth.rs
Normal file
@ -0,0 +1,292 @@
|
||||
use crate::common::{
|
||||
api_v3::oauth::get_redirect_location_query_params, database::FRIEND_USER_ID,
|
||||
dummy_data::DummyOAuthClientAlpha,
|
||||
};
|
||||
use actix_http::StatusCode;
|
||||
use actix_web::test::{self};
|
||||
use common::{
|
||||
api_v3::oauth::{get_auth_code_from_redirect_params, get_authorize_accept_flow_id},
|
||||
asserts::{assert_any_status_except, assert_status},
|
||||
database::{FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT},
|
||||
environment::with_test_environment,
|
||||
};
|
||||
use labrinth::auth::oauth::TokenResponse;
|
||||
use reqwest::header::{CACHE_CONTROL, PRAGMA};
|
||||
|
||||
mod common;
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn oauth_flow_happy_path() {
|
||||
with_test_environment(|env| async move {
|
||||
let DummyOAuthClientAlpha {
|
||||
valid_redirect_uri: base_redirect_uri,
|
||||
client_id,
|
||||
client_secret,
|
||||
} = env.dummy.as_ref().unwrap().oauth_client_alpha.clone();
|
||||
|
||||
// Initiate authorization
|
||||
let redirect_uri = format!("{}?foo=bar", base_redirect_uri);
|
||||
let original_state = "1234";
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_authorize(
|
||||
&client_id,
|
||||
Some("USER_READ NOTIFICATION_READ"),
|
||||
Some(&redirect_uri),
|
||||
Some(original_state),
|
||||
FRIEND_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
let flow_id = get_authorize_accept_flow_id(resp).await;
|
||||
|
||||
// Accept the authorization request
|
||||
let resp = env.v3.oauth_accept(&flow_id, FRIEND_USER_PAT).await;
|
||||
assert_status(&resp, StatusCode::FOUND);
|
||||
let query = get_redirect_location_query_params(&resp);
|
||||
|
||||
let auth_code = query.get("code").unwrap();
|
||||
let state = query.get("state").unwrap();
|
||||
let foo_val = query.get("foo").unwrap();
|
||||
assert_eq!(state, original_state);
|
||||
assert_eq!(foo_val, "bar");
|
||||
|
||||
// Get the token
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_token(
|
||||
auth_code.to_string(),
|
||||
Some(redirect_uri.clone()),
|
||||
client_id.to_string(),
|
||||
&client_secret,
|
||||
)
|
||||
.await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
assert_eq!(resp.headers().get(CACHE_CONTROL).unwrap(), "no-store");
|
||||
assert_eq!(resp.headers().get(PRAGMA).unwrap(), "no-cache");
|
||||
let token_resp: TokenResponse = test::read_body_json(resp).await;
|
||||
|
||||
// Validate the token works
|
||||
env.assert_read_notifications_status(
|
||||
FRIEND_USER_ID,
|
||||
&token_resp.access_token,
|
||||
StatusCode::OK,
|
||||
)
|
||||
.await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn oauth_authorize_for_already_authorized_scopes_returns_auth_code() {
|
||||
with_test_environment(|env| async {
|
||||
let DummyOAuthClientAlpha { client_id, .. } = env.dummy.unwrap().oauth_client_alpha.clone();
|
||||
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_authorize(
|
||||
&client_id,
|
||||
Some("USER_READ NOTIFICATION_READ"),
|
||||
None,
|
||||
Some("1234"),
|
||||
USER_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
let flow_id = get_authorize_accept_flow_id(resp).await;
|
||||
env.v3.oauth_accept(&flow_id, USER_USER_PAT).await;
|
||||
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_authorize(
|
||||
&client_id,
|
||||
Some("USER_READ"),
|
||||
None,
|
||||
Some("5678"),
|
||||
USER_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
assert_status(&resp, StatusCode::FOUND);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn get_oauth_token_with_already_used_auth_code_fails() {
|
||||
with_test_environment(|env| async {
|
||||
let DummyOAuthClientAlpha {
|
||||
client_id,
|
||||
client_secret,
|
||||
..
|
||||
} = env.dummy.unwrap().oauth_client_alpha.clone();
|
||||
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_authorize(&client_id, None, None, None, USER_USER_PAT)
|
||||
.await;
|
||||
let flow_id = get_authorize_accept_flow_id(resp).await;
|
||||
|
||||
let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await;
|
||||
let auth_code = get_auth_code_from_redirect_params(&resp).await;
|
||||
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_token(auth_code.clone(), None, client_id.clone(), &client_secret)
|
||||
.await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_token(auth_code, None, client_id, &client_secret)
|
||||
.await;
|
||||
assert_status(&resp, StatusCode::BAD_REQUEST);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn authorize_with_broader_scopes_can_complete_flow() {
|
||||
with_test_environment(|env| async move {
|
||||
let DummyOAuthClientAlpha {
|
||||
client_id,
|
||||
client_secret,
|
||||
..
|
||||
} = env.dummy.as_ref().unwrap().oauth_client_alpha.clone();
|
||||
|
||||
let first_access_token = env
|
||||
.v3
|
||||
.complete_full_authorize_flow(
|
||||
&client_id,
|
||||
&client_secret,
|
||||
Some("PROJECT_READ"),
|
||||
None,
|
||||
None,
|
||||
USER_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
let second_access_token = env
|
||||
.v3
|
||||
.complete_full_authorize_flow(
|
||||
&client_id,
|
||||
&client_secret,
|
||||
Some("PROJECT_READ NOTIFICATION_READ"),
|
||||
None,
|
||||
None,
|
||||
USER_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
|
||||
env.assert_read_notifications_status(
|
||||
USER_USER_ID,
|
||||
&first_access_token,
|
||||
StatusCode::UNAUTHORIZED,
|
||||
)
|
||||
.await;
|
||||
env.assert_read_user_projects_status(USER_USER_ID, &first_access_token, StatusCode::OK)
|
||||
.await;
|
||||
|
||||
env.assert_read_notifications_status(USER_USER_ID, &second_access_token, StatusCode::OK)
|
||||
.await;
|
||||
env.assert_read_user_projects_status(USER_USER_ID, &second_access_token, StatusCode::OK)
|
||||
.await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn oauth_authorize_with_broader_scopes_requires_user_accept() {
|
||||
with_test_environment(|env| async {
|
||||
let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone();
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_authorize(&client_id, Some("USER_READ"), None, None, USER_USER_PAT)
|
||||
.await;
|
||||
let flow_id = get_authorize_accept_flow_id(resp).await;
|
||||
env.v3.oauth_accept(&flow_id, USER_USER_PAT).await;
|
||||
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_authorize(
|
||||
&client_id,
|
||||
Some("USER_READ NOTIFICATION_READ"),
|
||||
None,
|
||||
None,
|
||||
USER_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
get_authorize_accept_flow_id(resp).await; // ensure we can deser this without error to really confirm
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn reject_authorize_ends_authorize_flow() {
|
||||
with_test_environment(|env| async move {
|
||||
let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone();
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_authorize(&client_id, None, None, None, USER_USER_PAT)
|
||||
.await;
|
||||
let flow_id = get_authorize_accept_flow_id(resp).await;
|
||||
|
||||
let resp = env.v3.oauth_reject(&flow_id, USER_USER_PAT).await;
|
||||
assert_status(&resp, StatusCode::FOUND);
|
||||
|
||||
let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await;
|
||||
assert_any_status_except(&resp, StatusCode::FOUND);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn accept_authorize_after_already_accepting_fails() {
|
||||
with_test_environment(|env| async move {
|
||||
let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone();
|
||||
let resp = env
|
||||
.v3
|
||||
.oauth_authorize(&client_id, None, None, None, USER_USER_PAT)
|
||||
.await;
|
||||
let flow_id = get_authorize_accept_flow_id(resp).await;
|
||||
let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await;
|
||||
assert_status(&resp, StatusCode::FOUND);
|
||||
|
||||
let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await;
|
||||
assert_status(&resp, StatusCode::BAD_REQUEST);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn revoke_authorization_after_issuing_token_revokes_token() {
|
||||
with_test_environment(|env| async move {
|
||||
let DummyOAuthClientAlpha {
|
||||
client_id,
|
||||
client_secret,
|
||||
..
|
||||
} = env.dummy.as_ref().unwrap().oauth_client_alpha.clone();
|
||||
let access_token = env
|
||||
.v3
|
||||
.complete_full_authorize_flow(
|
||||
&client_id,
|
||||
&client_secret,
|
||||
Some("NOTIFICATION_READ"),
|
||||
None,
|
||||
None,
|
||||
USER_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
env.assert_read_notifications_status(USER_USER_ID, &access_token, StatusCode::OK)
|
||||
.await;
|
||||
|
||||
let resp = env
|
||||
.v3
|
||||
.revoke_oauth_authorization(&client_id, USER_USER_PAT)
|
||||
.await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
|
||||
env.assert_read_notifications_status(USER_USER_ID, &access_token, StatusCode::UNAUTHORIZED)
|
||||
.await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
193
tests/oauth_clients.rs
Normal file
193
tests/oauth_clients.rs
Normal file
@ -0,0 +1,193 @@
|
||||
use actix_http::StatusCode;
|
||||
use actix_web::test;
|
||||
use common::{
|
||||
database::{FRIEND_USER_ID, FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT},
|
||||
dummy_data::DummyOAuthClientAlpha,
|
||||
environment::with_test_environment,
|
||||
get_json_val_str,
|
||||
};
|
||||
use labrinth::{
|
||||
models::{
|
||||
oauth_clients::{OAuthClient, OAuthClientCreationResult},
|
||||
pats::Scopes,
|
||||
},
|
||||
routes::v3::oauth_clients::OAuthClientEdit,
|
||||
};
|
||||
|
||||
use crate::common::{asserts::assert_status, database::USER_USER_ID_PARSED};
|
||||
|
||||
mod common;
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn can_create_edit_get_oauth_client() {
|
||||
with_test_environment(|env| async move {
|
||||
let client_name = "test_client".to_string();
|
||||
let redirect_uris = vec![
|
||||
"https://modrinth.com".to_string(),
|
||||
"https://modrinth.com/a".to_string(),
|
||||
];
|
||||
let resp = env
|
||||
.v3
|
||||
.add_oauth_client(
|
||||
client_name.clone(),
|
||||
Scopes::all() - Scopes::restricted(),
|
||||
redirect_uris.clone(),
|
||||
FRIEND_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
let creation_result: OAuthClientCreationResult = test::read_body_json(resp).await;
|
||||
let client_id = get_json_val_str(creation_result.client.id);
|
||||
|
||||
let icon_url = Some("https://modrinth.com/icon".to_string());
|
||||
let edited_redirect_uris = vec![
|
||||
redirect_uris[0].clone(),
|
||||
"https://modrinth.com/b".to_string(),
|
||||
];
|
||||
let edit = OAuthClientEdit {
|
||||
name: None,
|
||||
icon_url: Some(icon_url.clone()),
|
||||
max_scopes: None,
|
||||
redirect_uris: Some(edited_redirect_uris.clone()),
|
||||
};
|
||||
let resp = env
|
||||
.v3
|
||||
.edit_oauth_client(&client_id, edit, FRIEND_USER_PAT)
|
||||
.await;
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
|
||||
let clients = env
|
||||
.v3
|
||||
.get_user_oauth_clients(FRIEND_USER_ID, FRIEND_USER_PAT)
|
||||
.await;
|
||||
assert_eq!(1, clients.len());
|
||||
assert_eq!(icon_url, clients[0].icon_url);
|
||||
assert_eq!(client_name, clients[0].name);
|
||||
assert_eq!(2, clients[0].redirect_uris.len());
|
||||
assert_eq!(edited_redirect_uris[0], clients[0].redirect_uris[0].uri);
|
||||
assert_eq!(edited_redirect_uris[1], clients[0].redirect_uris[1].uri);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn create_oauth_client_with_restricted_scopes_fails() {
|
||||
with_test_environment(|env| async move {
|
||||
let resp = env
|
||||
.v3
|
||||
.add_oauth_client(
|
||||
"test_client".to_string(),
|
||||
Scopes::restricted(),
|
||||
vec!["https://modrinth.com".to_string()],
|
||||
FRIEND_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_status(&resp, StatusCode::BAD_REQUEST);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn get_oauth_client_for_client_creator_succeeds() {
|
||||
with_test_environment(|env| async move {
|
||||
let DummyOAuthClientAlpha { client_id, .. } =
|
||||
env.dummy.as_ref().unwrap().oauth_client_alpha.clone();
|
||||
|
||||
let resp = env
|
||||
.v3
|
||||
.get_oauth_client(client_id.clone(), USER_USER_PAT)
|
||||
.await;
|
||||
|
||||
assert_status(&resp, StatusCode::OK);
|
||||
let client: OAuthClient = test::read_body_json(resp).await;
|
||||
assert_eq!(get_json_val_str(client.id), client_id);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn get_oauth_client_for_unrelated_user_fails() {
|
||||
with_test_environment(|env| async move {
|
||||
let DummyOAuthClientAlpha { client_id, .. } =
|
||||
env.dummy.as_ref().unwrap().oauth_client_alpha.clone();
|
||||
|
||||
let resp = env
|
||||
.v3
|
||||
.get_oauth_client(client_id.clone(), FRIEND_USER_PAT)
|
||||
.await;
|
||||
|
||||
assert_status(&resp, StatusCode::UNAUTHORIZED);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn can_delete_oauth_client() {
|
||||
with_test_environment(|env| async move {
|
||||
let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone();
|
||||
let resp = env.v3.delete_oauth_client(&client_id, USER_USER_PAT).await;
|
||||
assert_status(&resp, StatusCode::NO_CONTENT);
|
||||
|
||||
let clients = env
|
||||
.v3
|
||||
.get_user_oauth_clients(USER_USER_ID, USER_USER_PAT)
|
||||
.await;
|
||||
assert_eq!(0, clients.len());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn delete_oauth_client_after_issuing_access_tokens_revokes_tokens() {
|
||||
with_test_environment(|env| async move {
|
||||
let DummyOAuthClientAlpha {
|
||||
client_id,
|
||||
client_secret,
|
||||
..
|
||||
} = env.dummy.as_ref().unwrap().oauth_client_alpha.clone();
|
||||
let access_token = env
|
||||
.v3
|
||||
.complete_full_authorize_flow(
|
||||
&client_id,
|
||||
&client_secret,
|
||||
Some("NOTIFICATION_READ"),
|
||||
None,
|
||||
None,
|
||||
USER_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
|
||||
env.v3.delete_oauth_client(&client_id, USER_USER_PAT).await;
|
||||
|
||||
env.assert_read_notifications_status(USER_USER_ID, &access_token, StatusCode::UNAUTHORIZED)
|
||||
.await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn can_list_user_oauth_authorizations() {
|
||||
with_test_environment(|env| async move {
|
||||
let DummyOAuthClientAlpha {
|
||||
client_id,
|
||||
client_secret,
|
||||
..
|
||||
} = env.dummy.as_ref().unwrap().oauth_client_alpha.clone();
|
||||
env.v3
|
||||
.complete_full_authorize_flow(
|
||||
&client_id,
|
||||
&client_secret,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
USER_USER_PAT,
|
||||
)
|
||||
.await;
|
||||
|
||||
let authorizations = env.v3.get_user_oauth_authorizations(USER_USER_PAT).await;
|
||||
assert_eq!(1, authorizations.len());
|
||||
assert_eq!(USER_USER_ID_PARSED, authorizations[0].user_id.0 as i64);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user