Limit the number of current request for a certain amount of time for some endpoints.
authorGreg Burri <greg.burri@gmail.com>
Mon, 3 Mar 2025 09:10:55 +0000 (10:10 +0100)
committerGreg Burri <greg.burri@gmail.com>
Mon, 3 Mar 2025 09:10:55 +0000 (10:10 +0100)
Cargo.lock
backend/Cargo.toml
backend/src/consts.rs
backend/src/main.rs

index 56b1295..10e5ceb 100644 (file)
@@ -2872,6 +2872,7 @@ dependencies = [
  "pin-project-lite",
  "sync_wrapper",
  "tokio",
+ "tokio-util",
  "tower-layer",
  "tower-service",
  "tracing",
index 7c9fda6..16390b4 100644 (file)
@@ -10,7 +10,7 @@ common = { path = "../common" }
 axum = { version = "0.8", features = ["macros"] }
 axum-extra = { version = "0.10", features = ["cookie", "query"] }
 tokio = { version = "1", features = ["full"] }
-tower = { version = "0.5", features = ["util"] }
+tower = { version = "0.5", features = ["util", "limit", "buffer"] }
 tower-http = { version = "0.6", features = ["fs", "trace"] }
 
 tracing = "0.1"
index bbe1cdd..fdc306b 100644 (file)
@@ -18,6 +18,9 @@ pub const TOKEN_SIZE: usize = 32;
 
 pub const SEND_EMAIL_TIMEOUT: Duration = Duration::from_secs(60);
 
+pub const NUMBER_OF_CONCURRENT_HTTP_REQUEST_FOR_RATE_LIMIT: u64 = 5;
+pub const DURATION_FOR_RATE_LIMIT: Duration = Duration::from_secs(5);
+
 // HTTP headers, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers.
 // Common headers can be found in 'axum::http::header' (which is a re-export of the create 'http').
 pub const REVERSE_PROXY_IP_HTTP_FIELD: &str = "x-real-ip"; // Set by the reverse proxy (Nginx).
index 199b933..afc7656 100644 (file)
@@ -1,7 +1,8 @@
 use std::{net::SocketAddr, path::Path};
 
 use axum::{
-    Router,
+    BoxError, Router,
+    error_handling::HandleErrorLayer,
     extract::{ConnectInfo, Extension, FromRef, Request, State},
     http::StatusCode,
     middleware::{self, Next},
@@ -13,6 +14,7 @@ use chrono::prelude::*;
 use clap::Parser;
 use config::Config;
 use itertools::Itertools;
+use tower::{ServiceBuilder, buffer::BufferLayer, limit::RateLimitLayer};
 use tower_http::{
     services::{ServeDir, ServeFile},
     trace::TraceLayer,
@@ -226,23 +228,38 @@ async fn main() {
         get(services::fragments::recipes_list_fragments),
     );
 
-    let html_routes = Router::new()
-        .route("/", get(services::home_page))
+    let html_routes_with_rate_limit = Router::new()
+        .route("/signin", post(services::user::sign_in_post))
+        .route("/signup", post(services::user::sign_up_post))
         .route(
-            "/signup",
-            get(services::user::sign_up_get).post(services::user::sign_up_post),
+            "/ask_reset_password",
+            post(services::user::ask_reset_password_post),
         )
+        .layer(
+            ServiceBuilder::new()
+                .layer(HandleErrorLayer::new(|err: BoxError| async move {
+                    (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        format!("Unhandled error: {}", err),
+                    )
+                }))
+                .layer(BufferLayer::new(1024))
+                .layer(RateLimitLayer::new(
+                    consts::NUMBER_OF_CONCURRENT_HTTP_REQUEST_FOR_RATE_LIMIT,
+                    consts::DURATION_FOR_RATE_LIMIT,
+                )),
+        );
+
+    let html_routes = Router::new()
+        .route("/", get(services::home_page))
+        .route("/signup", get(services::user::sign_up_get))
         .route("/validation", get(services::user::sign_up_validation))
         .route("/revalidation", get(services::user::email_revalidation))
-        .route(
-            "/signin",
-            get(services::user::sign_in_get).post(services::user::sign_in_post),
-        )
+        .route("/signin", get(services::user::sign_in_get))
         .route("/signout", get(services::user::sign_out))
         .route(
             "/ask_reset_password",
-            get(services::user::ask_reset_password_get)
-                .post(services::user::ask_reset_password_post),
+            get(services::user::ask_reset_password_get),
         )
         .route(
             "/reset_password",
@@ -257,6 +274,7 @@ async fn main() {
             "/user/edit",
             get(services::user::edit_user_get).post(services::user::edit_user_post),
         )
+        .merge(html_routes_with_rate_limit)
         .nest("/fragments", fragments_routes)
         .route_layer(middleware::from_fn(services::ron_error_to_html));