diff --git a/src/app.py b/src/app.py index 6364705..2884162 100644 --- a/src/app.py +++ b/src/app.py @@ -129,9 +129,17 @@ def create_app() -> FastAPI: # Access log middleware (outermost — logs every request with real client IP) @application.middleware("http") async def access_log_middleware(request: Request, call_next): - response: Response = await call_next(request) from dependencies import get_client_ip + try: + response: Response = await call_next(request) + except ConnectionResetError: + client_ip = get_client_ip(request) + path = request.url.path + method = request.method + get_access_logger().info(f"[BANNED] [{method}] {client_ip} - {path}") + raise + client_ip = get_client_ip(request) path = request.url.path method = request.method diff --git a/src/middleware/ban_check.py b/src/middleware/ban_check.py index a3be689..c4b2e80 100644 --- a/src/middleware/ban_check.py +++ b/src/middleware/ban_check.py @@ -2,6 +2,7 @@ """ Middleware for checking if client IP is banned. +Resets the connection for banned IPs instead of sending a response. """ from starlette.middleware.base import BaseHTTPMiddleware @@ -11,6 +12,13 @@ from starlette.responses import Response from dependencies import get_client_ip +class ConnectionResetResponse(Response): + """Response that abruptly closes the connection without sending data.""" + + async def __call__(self, scope, receive, send): + raise ConnectionResetError() + + class BanCheckMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Skip ban check for dashboard routes @@ -23,7 +31,7 @@ class BanCheckMiddleware(BaseHTTPMiddleware): tracker = request.app.state.tracker if tracker.is_banned_ip(client_ip): - return Response(status_code=500) + return ConnectionResetResponse() response = await call_next(request) return response