From 8ce8b6b40aa17e5a4856d512dbe7a2d131bf9f74 Mon Sep 17 00:00:00 2001 From: Lorenzo Venerandi Date: Sun, 22 Feb 2026 16:23:52 +0100 Subject: [PATCH] feat: implement IP ban and rate-limiting logic in database with migration scripts --- src/database.py | 147 +++++++++++++++++++++++++++++ src/migrations/runner.py | 22 +++++ src/models.py | 6 ++ src/tasks/memory_cleanup.py | 42 ++------- src/tracker.py | 183 ++---------------------------------- 5 files changed, 191 insertions(+), 209 deletions(-) diff --git a/src/database.py b/src/database.py index c59245d..140660d 100644 --- a/src/database.py +++ b/src/database.py @@ -344,6 +344,153 @@ class DatabaseManager: ) session.add(ip_stats) + def increment_page_visit(self, ip: str, max_pages_limit: int) -> int: + """ + Increment the page visit counter for an IP and apply ban if limit reached. + + Args: + ip: Client IP address + max_pages_limit: Page visit threshold before banning + + Returns: + The updated page visit count + """ + session = self.session + try: + sanitized_ip = sanitize_ip(ip) + ip_stats = ( + session.query(IpStats).filter(IpStats.ip == sanitized_ip).first() + ) + + if not ip_stats: + now = datetime.now() + ip_stats = IpStats( + ip=sanitized_ip, + total_requests=0, + first_seen=now, + last_seen=now, + page_visit_count=1, + ) + session.add(ip_stats) + session.commit() + return 1 + + ip_stats.page_visit_count = (ip_stats.page_visit_count or 0) + 1 + + if ip_stats.page_visit_count >= max_pages_limit: + ip_stats.total_violations = (ip_stats.total_violations or 0) + 1 + ip_stats.ban_multiplier = 2 ** (ip_stats.total_violations - 1) + ip_stats.ban_timestamp = datetime.now() + + session.commit() + return ip_stats.page_visit_count + + except Exception as e: + session.rollback() + applogger.error(f"Error incrementing page visit for {ip}: {e}") + return 0 + finally: + self.close_session() + + def is_banned_ip(self, ip: str, ban_duration_seconds: int) -> bool: + """ + Check if an IP is currently banned. + + Args: + ip: Client IP address + ban_duration_seconds: Base ban duration in seconds + + Returns: + True if the IP is currently banned + """ + session = self.session + try: + sanitized_ip = sanitize_ip(ip) + ip_stats = ( + session.query(IpStats).filter(IpStats.ip == sanitized_ip).first() + ) + + if not ip_stats or ip_stats.ban_timestamp is None: + return False + + effective_duration = ban_duration_seconds * (ip_stats.ban_multiplier or 1) + elapsed = (datetime.now() - ip_stats.ban_timestamp).total_seconds() + + if elapsed > effective_duration: + # Ban expired — reset count for next cycle + ip_stats.page_visit_count = 0 + ip_stats.ban_timestamp = None + session.commit() + return False + + return True + + except Exception as e: + applogger.error(f"Error checking ban status for {ip}: {e}") + return False + finally: + self.close_session() + + def get_ban_info(self, ip: str, ban_duration_seconds: int) -> dict: + """ + Get detailed ban information for an IP. + + Args: + ip: Client IP address + ban_duration_seconds: Base ban duration in seconds + + Returns: + Dictionary with ban status details + """ + session = self.session + try: + sanitized_ip = sanitize_ip(ip) + ip_stats = ( + session.query(IpStats).filter(IpStats.ip == sanitized_ip).first() + ) + + if not ip_stats: + return { + "is_banned": False, + "violations": 0, + "ban_multiplier": 1, + "remaining_ban_seconds": 0, + } + + violations = ip_stats.total_violations or 0 + multiplier = ip_stats.ban_multiplier or 1 + + if ip_stats.ban_timestamp is None: + return { + "is_banned": False, + "violations": violations, + "ban_multiplier": multiplier, + "remaining_ban_seconds": 0, + } + + effective_duration = ban_duration_seconds * multiplier + elapsed = (datetime.now() - ip_stats.ban_timestamp).total_seconds() + remaining = max(0, effective_duration - elapsed) + + return { + "is_banned": remaining > 0, + "violations": violations, + "ban_multiplier": multiplier, + "effective_ban_duration_seconds": effective_duration, + "remaining_ban_seconds": remaining, + } + + except Exception as e: + applogger.error(f"Error getting ban info for {ip}: {e}") + return { + "is_banned": False, + "violations": 0, + "ban_multiplier": 1, + "remaining_ban_seconds": 0, + } + finally: + self.close_session() + def update_ip_stats_analysis( self, ip: str, diff --git a/src/migrations/runner.py b/src/migrations/runner.py index 396fbc9..0c5c67d 100644 --- a/src/migrations/runner.py +++ b/src/migrations/runner.py @@ -48,6 +48,24 @@ def _migrate_need_reevaluation_column(cursor) -> bool: return True +def _migrate_ban_state_columns(cursor) -> List[str]: + """Add ban/rate-limit columns to ip_stats if missing.""" + added = [] + columns = { + "page_visit_count": "INTEGER DEFAULT 0", + "ban_timestamp": "DATETIME", + "total_violations": "INTEGER DEFAULT 0", + "ban_multiplier": "INTEGER DEFAULT 1", + } + for col_name, col_type in columns.items(): + if not _column_exists(cursor, "ip_stats", col_name): + cursor.execute( + f"ALTER TABLE ip_stats ADD COLUMN {col_name} {col_type}" + ) + added.append(col_name) + return added + + def _migrate_performance_indexes(cursor) -> List[str]: """Add performance indexes to attack_detections if missing.""" added = [] @@ -90,6 +108,10 @@ def run_migrations(database_path: str) -> None: if _migrate_need_reevaluation_column(cursor): applied.append("add need_reevaluation column to ip_stats") + ban_cols = _migrate_ban_state_columns(cursor) + for col in ban_cols: + applied.append(f"add {col} column to ip_stats") + idx_added = _migrate_performance_indexes(cursor) for idx in idx_added: applied.append(f"add index {idx}") diff --git a/src/models.py b/src/models.py index c9a190f..8fb6e26 100644 --- a/src/models.py +++ b/src/models.py @@ -204,6 +204,12 @@ class IpStats(Base): Boolean, default=False, nullable=True ) + # Ban/rate-limit state (moved from in-memory tracker to DB) + page_visit_count: Mapped[int] = mapped_column(Integer, default=0, nullable=True) + ban_timestamp: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + total_violations: Mapped[int] = mapped_column(Integer, default=0, nullable=True) + ban_multiplier: Mapped[int] = mapped_column(Integer, default=1, nullable=True) + def __repr__(self) -> str: return f"" diff --git a/src/tasks/memory_cleanup.py b/src/tasks/memory_cleanup.py index ac9af92..dc230fd 100644 --- a/src/tasks/memory_cleanup.py +++ b/src/tasks/memory_cleanup.py @@ -2,7 +2,10 @@ """ Memory cleanup task for Krawl honeypot. -Periodically cleans expired bans and stale entries from ip_page_visits. + +NOTE: This task is no longer needed. Ban/rate-limit state has been moved from +in-memory ip_page_visits dict to the ip_stats DB table, eliminating unbounded +memory growth. Kept disabled for reference. """ from logger import get_app_logger @@ -13,8 +16,8 @@ from logger import get_app_logger TASK_CONFIG = { "name": "memory-cleanup", - "cron": "*/5 * * * *", # Run every 5 minutes - "enabled": True, + "cron": "*/5 * * * *", + "enabled": False, "run_when_loaded": False, } @@ -22,35 +25,4 @@ app_logger = get_app_logger() def main(): - """ - Clean up in-memory structures in the tracker. - Called periodically to prevent unbounded memory growth. - """ - try: - from tracker import get_tracker - - tracker = get_tracker() - if not tracker: - app_logger.warning("Tracker not initialized, skipping memory cleanup") - return - - stats_before = tracker.get_memory_stats() - - tracker.cleanup_memory() - - stats_after = tracker.get_memory_stats() - - visits_reduced = stats_before["ip_page_visits"] - stats_after["ip_page_visits"] - - if visits_reduced > 0: - app_logger.info( - f"Memory cleanup: Removed {visits_reduced} stale ip_page_visits entries" - ) - - app_logger.debug( - f"Memory stats after cleanup: " - f"ip_page_visits={stats_after['ip_page_visits']}" - ) - - except Exception as e: - app_logger.error(f"Error during memory cleanup: {e}") + app_logger.debug("memory-cleanup task is disabled (ban state now in DB)") diff --git a/src/tracker.py b/src/tracker.py index 292ebba..bd0c0ee 100644 --- a/src/tracker.py +++ b/src/tracker.py @@ -1,9 +1,6 @@ #!/usr/bin/env python3 from typing import Dict, Tuple, Optional -from collections import defaultdict -from datetime import datetime -from zoneinfo import ZoneInfo import re import urllib.parse @@ -49,9 +46,6 @@ class AccessTracker: self.max_pages_limit = max_pages_limit self.ban_duration_seconds = ban_duration_seconds - # Track pages visited by each IP (for good crawler limiting) - self.ip_page_visits: Dict[str, Dict[str, object]] = defaultdict(dict) - # Load suspicious patterns from wordlists wl = get_wordlists() self.suspicious_patterns = wl.suspicious_patterns @@ -372,14 +366,7 @@ class AccessTracker: def increment_page_visit(self, client_ip: str) -> int: """ - Increment page visit counter for an IP and return the new count. - Implements incremental bans: each violation increases ban duration exponentially. - - Ban duration formula: base_duration * (2 ^ violation_count) - - 1st violation: base_duration (e.g., 60 seconds) - - 2nd violation: base_duration * 2 (120 seconds) - - 3rd violation: base_duration * 4 (240 seconds) - - Nth violation: base_duration * 2^(N-1) + Increment page visit counter for an IP via DB and return the new count. Args: client_ip: The client IP address @@ -387,7 +374,6 @@ class AccessTracker: Returns: The updated page visit count for this IP """ - # Skip if this is the server's own IP from config import get_config config = get_config() @@ -395,85 +381,24 @@ class AccessTracker: if server_ip and client_ip == server_ip: return 0 - try: - # Initialize if not exists - if client_ip not in self.ip_page_visits: - self.ip_page_visits[client_ip] = { - "count": 0, - "ban_timestamp": None, - "total_violations": 0, - "ban_multiplier": 1, - } - - # Increment count - self.ip_page_visits[client_ip]["count"] += 1 - - # Set ban if reached limit - if self.ip_page_visits[client_ip]["count"] >= self.max_pages_limit: - # Increment violation counter - self.ip_page_visits[client_ip]["total_violations"] += 1 - violations = self.ip_page_visits[client_ip]["total_violations"] - - # Calculate exponential ban multiplier: 2^(violations - 1) - # Violation 1: 2^0 = 1x - # Violation 2: 2^1 = 2x - # Violation 3: 2^2 = 4x - # Violation 4: 2^3 = 8x, etc. - self.ip_page_visits[client_ip]["ban_multiplier"] = 2 ** (violations - 1) - - # Set ban timestamp - self.ip_page_visits[client_ip][ - "ban_timestamp" - ] = datetime.now().isoformat() - - return self.ip_page_visits[client_ip]["count"] - - except Exception: + if not self.db: return 0 + return self.db.increment_page_visit(client_ip, self.max_pages_limit) + def is_banned_ip(self, client_ip: str) -> bool: """ - Check if an IP is currently banned due to exceeding page visit limits. - Uses incremental ban duration based on violation count. - - Ban duration = base_duration * (2 ^ (violations - 1)) - Each time an IP is banned again, duration doubles. + Check if an IP is currently banned. Args: client_ip: The client IP address Returns: True if the IP is banned, False otherwise """ - try: - if client_ip in self.ip_page_visits: - ban_timestamp = self.ip_page_visits[client_ip].get("ban_timestamp") - if ban_timestamp is not None: - # Get the ban multiplier for this violation - ban_multiplier = self.ip_page_visits[client_ip].get( - "ban_multiplier", 1 - ) - - # Calculate effective ban duration based on violations - effective_ban_duration = self.ban_duration_seconds * ban_multiplier - - # Check if ban period has expired - ban_time = datetime.fromisoformat(ban_timestamp) - time_diff = datetime.now() - ban_time - - if time_diff.total_seconds() > effective_ban_duration: - # Ban expired, reset for next cycle - # Keep violation count for next offense - self.ip_page_visits[client_ip]["count"] = 0 - self.ip_page_visits[client_ip]["ban_timestamp"] = None - return False - else: - # Still banned - return True - + if not self.db: return False - except Exception: - return False + return self.db.is_banned_ip(client_ip, self.ban_duration_seconds) def get_ban_info(self, client_ip: str) -> dict: """ @@ -482,64 +407,15 @@ class AccessTracker: Returns: Dictionary with ban status, violations, and remaining ban time """ - try: - if client_ip not in self.ip_page_visits: - return { - "is_banned": False, - "violations": 0, - "ban_multiplier": 1, - "remaining_ban_seconds": 0, - } - - ip_data = self.ip_page_visits[client_ip] - ban_timestamp = ip_data.get("ban_timestamp") - - if ban_timestamp is None: - return { - "is_banned": False, - "violations": ip_data.get("total_violations", 0), - "ban_multiplier": ip_data.get("ban_multiplier", 1), - "remaining_ban_seconds": 0, - } - - # Ban is active, calculate remaining time - ban_multiplier = ip_data.get("ban_multiplier", 1) - effective_ban_duration = self.ban_duration_seconds * ban_multiplier - - ban_time = datetime.fromisoformat(ban_timestamp) - time_diff = datetime.now() - ban_time - remaining_seconds = max( - 0, effective_ban_duration - time_diff.total_seconds() - ) - - return { - "is_banned": remaining_seconds > 0, - "violations": ip_data.get("total_violations", 0), - "ban_multiplier": ban_multiplier, - "effective_ban_duration_seconds": effective_ban_duration, - "remaining_ban_seconds": remaining_seconds, - } - - except Exception: + if not self.db: return { "is_banned": False, "violations": 0, "ban_multiplier": 1, "remaining_ban_seconds": 0, } - """ - Get the current page visit count for an IP. - Args: - client_ip: The client IP address - - Returns: - The page visit count for this IP - """ - try: - return self.ip_page_visits.get(client_ip, 0) - except Exception: - return 0 + return self.db.get_ban_info(client_ip, self.ban_duration_seconds) def get_stats(self) -> Dict: """Get statistics summary from database.""" @@ -560,44 +436,3 @@ class AccessTracker: return stats - def cleanup_memory(self) -> None: - """ - Clean up in-memory structures to prevent unbounded growth. - Should be called periodically (e.g., every 5 minutes). - """ - # Clean expired ban entries from ip_page_visits - current_time = datetime.now() - for ip, data in self.ip_page_visits.items(): - ban_timestamp = data.get("ban_timestamp") - if ban_timestamp is not None: - try: - ban_time = datetime.fromisoformat(ban_timestamp) - time_diff = (current_time - ban_time).total_seconds() - effective_duration = self.ban_duration_seconds * data.get( - "ban_multiplier", 1 - ) - if time_diff > effective_duration: - data["count"] = 0 - data["ban_timestamp"] = None - except (ValueError, TypeError): - pass - - # Remove IPs with zero activity and no active ban - ips_to_remove = [ - ip - for ip, data in self.ip_page_visits.items() - if data.get("count", 0) == 0 and data.get("ban_timestamp") is None - ] - for ip in ips_to_remove: - del self.ip_page_visits[ip] - - def get_memory_stats(self) -> Dict[str, int]: - """ - Get current memory usage statistics for monitoring. - - Returns: - Dictionary with counts of in-memory items - """ - return { - "ip_page_visits": len(self.ip_page_visits), - }