diff --git a/apps/labrinth/src/database/redis.rs b/apps/labrinth/src/database/redis.rs index 9a0409071..ee0268fd2 100644 --- a/apps/labrinth/src/database/redis.rs +++ b/apps/labrinth/src/database/redis.rs @@ -224,8 +224,6 @@ impl RedisPool { + Serialize, S: Display + Clone + DeserializeOwned + Serialize + Debug, { - let connection = self.connect().await?.connection; - let ids = keys .iter() .map(|x| (x.to_string(), x.clone())) @@ -235,49 +233,21 @@ impl RedisPool { return Ok(HashMap::new()); } - let get_cached_values = - |ids: DashMap, - mut connection: deadpool_redis::Connection| async move { - let slug_ids = if let Some(slug_namespace) = slug_namespace { - cmd("MGET") - .arg( - ids.iter() - .map(|x| { - format!( - "{}_{slug_namespace}:{}", - self.meta_namespace, - if case_sensitive { - x.value().to_string() - } else { - x.value().to_string().to_lowercase() - } - ) - }) - .collect::>(), - ) - .query_async::>>(&mut connection) - .await? - .into_iter() - .flatten() - .collect::>() - } else { - Vec::new() - }; - - let cached_values = cmd("MGET") + let get_cached_values = |ids: DashMap| async move { + let slug_ids = if let Some(slug_namespace) = slug_namespace { + let mut connection = self.pool.get().await?; + cmd("MGET") .arg( ids.iter() - .map(|x| x.value().to_string()) - .chain(ids.iter().filter_map(|x| { - parse_base62(&x.value().to_string()) - .ok() - .map(|x| x.to_string()) - })) - .chain(slug_ids) .map(|x| { format!( - "{}_{namespace}:{x}", - self.meta_namespace + "{}_{slug_namespace}:{}", + self.meta_namespace, + if case_sensitive { + x.value().to_string() + } else { + x.value().to_string().to_lowercase() + } ) }) .collect::>(), @@ -285,23 +255,46 @@ impl RedisPool { .query_async::>>(&mut connection) .await? .into_iter() - .filter_map(|x| { - x.and_then(|val| { - serde_json::from_str::>(&val) - .ok() - }) - .map(|val| (val.key.clone(), val)) - }) - .collect::>(); - - Ok::<_, DatabaseError>((cached_values, connection, ids)) + .flatten() + .collect::>() + } else { + Vec::new() }; + let mut connection = self.pool.get().await?; + let cached_values = cmd("MGET") + .arg( + ids.iter() + .map(|x| x.value().to_string()) + .chain(ids.iter().filter_map(|x| { + parse_base62(&x.value().to_string()) + .ok() + .map(|x| x.to_string()) + })) + .chain(slug_ids) + .map(|x| { + format!("{}_{namespace}:{x}", self.meta_namespace) + }) + .collect::>(), + ) + .query_async::>>(&mut connection) + .await? + .into_iter() + .filter_map(|x| { + x.and_then(|val| { + serde_json::from_str::>(&val).ok() + }) + .map(|val| (val.key.clone(), val)) + }) + .collect::>(); + + Ok::<_, DatabaseError>((cached_values, ids)) + }; + let current_time = Utc::now(); let mut expired_values = HashMap::new(); - let (cached_values_raw, mut connection, ids) = - get_cached_values(ids, connection).await?; + let (cached_values_raw, ids) = get_cached_values(ids).await?; let mut cached_values = cached_values_raw .into_iter() .filter_map(|(key, val)| { @@ -352,9 +345,12 @@ impl RedisPool { .with_expiration(SetExpiry::EX(60)), ); }); - let results = pipe - .query_async::>>(&mut connection) - .await?; + let results = { + let mut connection = self.pool.get().await?; + + pipe.query_async::>>(&mut connection) + .await? + }; for (idx, key) in fetch_ids.into_iter().enumerate() { if let Some(locked) = results.get(idx) { @@ -487,6 +483,7 @@ impl RedisPool { )); } + let mut connection = self.pool.get().await?; pipe.query_async::<()>(&mut connection).await?; Ok(return_values) @@ -495,28 +492,29 @@ impl RedisPool { if !subscribe_ids.is_empty() { fetch_tasks.push(Box::pin(async { - let mut connection = self.pool.get().await?; - let mut interval = tokio::time::interval(Duration::from_millis(100)); let start = Utc::now(); loop { - let results = cmd("MGET") - .arg( - subscribe_ids - .iter() - .map(|x| { - format!( - "{}_{namespace}:{}/lock", - self.meta_namespace, - // We lowercase key because locks are stored in lowercase - x.key().to_lowercase() - ) - }) - .collect::>(), - ) - .query_async::>>(&mut connection) - .await?; + let results = { + let mut connection = self.pool.get().await?; + cmd("MGET") + .arg( + subscribe_ids + .iter() + .map(|x| { + format!( + "{}_{namespace}:{}/lock", + self.meta_namespace, + // We lowercase key because locks are stored in lowercase + x.key().to_lowercase() + ) + }) + .collect::>(), + ) + .query_async::>>(&mut connection) + .await? + }; if results.into_iter().all(|x| x.is_none()) { break; @@ -529,8 +527,8 @@ impl RedisPool { interval.tick().await; } - let (return_values, _, _) = - get_cached_values(subscribe_ids, connection).await?; + let (return_values, _) = + get_cached_values(subscribe_ids).await?; Ok(return_values) }));