feat: implement IP ban and rate-limiting logic in database with migration scripts

This commit is contained in:
Lorenzo Venerandi
2026-02-22 16:23:52 +01:00
parent db848e7ecb
commit 8ce8b6b40a
5 changed files with 191 additions and 209 deletions

View File

@@ -344,6 +344,153 @@ class DatabaseManager:
) )
session.add(ip_stats) session.add(ip_stats)
def increment_page_visit(self, ip: str, max_pages_limit: int) -> int:
    """
    Increment the page visit counter for an IP and apply a ban if the limit is reached.

    A missing ``IpStats`` row is created with a zero counter and then goes
    through the same increment-and-check path as an existing row, so the
    threshold is also enforced on the very first recorded visit.

    When the new count reaches ``max_pages_limit``, ``total_violations`` is
    incremented, ``ban_multiplier`` is set to ``2 ** (total_violations - 1)``
    and ``ban_timestamp`` is set to the current ``datetime.now()``.

    Args:
        ip: Client IP address
        max_pages_limit: Page visit threshold before banning

    Returns:
        The updated page visit count, or 0 if the update failed
        (the error is logged and the transaction rolled back).
    """
    session = self.session
    try:
        sanitized_ip = sanitize_ip(ip)
        ip_stats = (
            session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
        )
        now = datetime.now()

        if not ip_stats:
            # Start at zero and fall through: the shared path below performs
            # the increment and the threshold check for new rows too.
            ip_stats = IpStats(
                ip=sanitized_ip,
                total_requests=0,
                first_seen=now,
                last_seen=now,
                page_visit_count=0,
            )
            session.add(ip_stats)

        # Columns are nullable, so treat a missing value as zero.
        new_count = (ip_stats.page_visit_count or 0) + 1
        ip_stats.page_visit_count = new_count

        if new_count >= max_pages_limit:
            violations = (ip_stats.total_violations or 0) + 1
            ip_stats.total_violations = violations
            # Exponential back-off: 1x, 2x, 4x, ... per recorded violation.
            ip_stats.ban_multiplier = 2 ** (violations - 1)
            ip_stats.ban_timestamp = now

        session.commit()
        # Return the locally computed value instead of re-reading the
        # attribute from the ORM object after the commit.
        return new_count
    except Exception as e:
        session.rollback()
        applogger.error(f"Error incrementing page visit for {ip}: {e}")
        return 0
    finally:
        self.close_session()
def is_banned_ip(self, ip: str, ban_duration_seconds: int) -> bool:
    """
    Check if an IP is currently banned.

    The effective ban length is ``ban_duration_seconds`` multiplied by the
    row's ``ban_multiplier`` (1 when unset). When the ban has expired, the
    row's ``page_visit_count`` is reset to 0 and ``ban_timestamp`` is cleared
    and committed; ``total_violations`` and ``ban_multiplier`` are left as is.

    Args:
        ip: Client IP address
        ban_duration_seconds: Base ban duration in seconds

    Returns:
        True if the IP is currently banned. False if it is not banned, has no
        stats row, or the lookup failed (the error is logged).
    """
    session = self.session
    try:
        sanitized_ip = sanitize_ip(ip)
        ip_stats = (
            session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
        )
        if not ip_stats or ip_stats.ban_timestamp is None:
            return False

        effective_duration = ban_duration_seconds * (ip_stats.ban_multiplier or 1)
        elapsed = (datetime.now() - ip_stats.ban_timestamp).total_seconds()
        if elapsed > effective_duration:
            # Ban expired — reset count for next cycle
            ip_stats.page_visit_count = 0
            ip_stats.ban_timestamp = None
            session.commit()
            return False
        return True
    except Exception as e:
        # Roll back so a failed commit of the expiry reset does not leave the
        # transaction in a failed state (mirrors increment_page_visit).
        session.rollback()
        applogger.error(f"Error checking ban status for {ip}: {e}")
        return False
    finally:
        self.close_session()
def get_ban_info(self, ip: str, ban_duration_seconds: int) -> dict:
    """
    Report the ban state stored for an IP without modifying it.

    Args:
        ip: Client IP address
        ban_duration_seconds: Base ban duration in seconds

    Returns:
        Dictionary with ``is_banned``, ``violations``, ``ban_multiplier`` and
        ``remaining_ban_seconds``. When a ban timestamp is set, the result
        additionally carries ``effective_ban_duration_seconds``. On error the
        "not banned" defaults are returned and the error is logged.
    """

    def _not_banned(violation_count: int = 0, factor: int = 1) -> dict:
        # Shape shared by every "no active ban" answer.
        return {
            "is_banned": False,
            "violations": violation_count,
            "ban_multiplier": factor,
            "remaining_ban_seconds": 0,
        }

    session = self.session
    try:
        clean_ip = sanitize_ip(ip)
        record = session.query(IpStats).filter(IpStats.ip == clean_ip).first()
        if not record:
            return _not_banned()

        violation_count = record.total_violations or 0
        factor = record.ban_multiplier or 1
        banned_at = record.ban_timestamp
        if banned_at is None:
            return _not_banned(violation_count, factor)

        # Scale the base duration by the stored multiplier and compare it
        # with the time elapsed since the ban was set.
        total_ban_seconds = ban_duration_seconds * factor
        seconds_since_ban = (datetime.now() - banned_at).total_seconds()
        seconds_left = max(0, total_ban_seconds - seconds_since_ban)
        return {
            "is_banned": seconds_left > 0,
            "violations": violation_count,
            "ban_multiplier": factor,
            "effective_ban_duration_seconds": total_ban_seconds,
            "remaining_ban_seconds": seconds_left,
        }
    except Exception as e:
        applogger.error(f"Error getting ban info for {ip}: {e}")
        return _not_banned()
    finally:
        self.close_session()
def update_ip_stats_analysis( def update_ip_stats_analysis(
self, self,
ip: str, ip: str,

View File

@@ -48,6 +48,24 @@ def _migrate_need_reevaluation_column(cursor) -> bool:
return True return True
def _migrate_ban_state_columns(cursor) -> List[str]:
    """Ensure the ban/rate-limit columns exist on ip_stats.

    Each column that ``_column_exists`` reports as missing is added with an
    ``ALTER TABLE ... ADD COLUMN`` statement executed on ``cursor``.

    Returns:
        Names of the columns that were added by this call, in definition order.
    """
    # (column name, SQL type/default clause) pairs, in the order they are applied.
    ban_columns = (
        ("page_visit_count", "INTEGER DEFAULT 0"),
        ("ban_timestamp", "DATETIME"),
        ("total_violations", "INTEGER DEFAULT 0"),
        ("ban_multiplier", "INTEGER DEFAULT 1"),
    )

    newly_added: List[str] = []
    for name, ddl_type in ban_columns:
        if _column_exists(cursor, "ip_stats", name):
            continue
        cursor.execute(f"ALTER TABLE ip_stats ADD COLUMN {name} {ddl_type}")
        newly_added.append(name)
    return newly_added
def _migrate_performance_indexes(cursor) -> List[str]: def _migrate_performance_indexes(cursor) -> List[str]:
"""Add performance indexes to attack_detections if missing.""" """Add performance indexes to attack_detections if missing."""
added = [] added = []
@@ -90,6 +108,10 @@ def run_migrations(database_path: str) -> None:
if _migrate_need_reevaluation_column(cursor): if _migrate_need_reevaluation_column(cursor):
applied.append("add need_reevaluation column to ip_stats") applied.append("add need_reevaluation column to ip_stats")
ban_cols = _migrate_ban_state_columns(cursor)
for col in ban_cols:
applied.append(f"add {col} column to ip_stats")
idx_added = _migrate_performance_indexes(cursor) idx_added = _migrate_performance_indexes(cursor)
for idx in idx_added: for idx in idx_added:
applied.append(f"add index {idx}") applied.append(f"add index {idx}")

View File

@@ -204,6 +204,12 @@ class IpStats(Base):
Boolean, default=False, nullable=True Boolean, default=False, nullable=True
) )
# Ban/rate-limit state (moved from in-memory tracker to DB)
page_visit_count: Mapped[int] = mapped_column(Integer, default=0, nullable=True)
ban_timestamp: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
total_violations: Mapped[int] = mapped_column(Integer, default=0, nullable=True)
ban_multiplier: Mapped[int] = mapped_column(Integer, default=1, nullable=True)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<IpStats(ip='{self.ip}', total_requests={self.total_requests})>" return f"<IpStats(ip='{self.ip}', total_requests={self.total_requests})>"

View File

@@ -2,7 +2,10 @@
""" """
Memory cleanup task for Krawl honeypot. Memory cleanup task for Krawl honeypot.
Periodically cleans expired bans and stale entries from ip_page_visits.
NOTE: This task is no longer needed. Ban/rate-limit state has been moved from
in-memory ip_page_visits dict to the ip_stats DB table, eliminating unbounded
memory growth. Kept disabled for reference.
""" """
from logger import get_app_logger from logger import get_app_logger
@@ -13,8 +16,8 @@ from logger import get_app_logger
TASK_CONFIG = { TASK_CONFIG = {
"name": "memory-cleanup", "name": "memory-cleanup",
"cron": "*/5 * * * *", # Run every 5 minutes "cron": "*/5 * * * *",
"enabled": True, "enabled": False,
"run_when_loaded": False, "run_when_loaded": False,
} }
@@ -22,35 +25,4 @@ app_logger = get_app_logger()
def main(): def main():
""" app_logger.debug("memory-cleanup task is disabled (ban state now in DB)")
Clean up in-memory structures in the tracker.
Called periodically to prevent unbounded memory growth.
"""
try:
from tracker import get_tracker
tracker = get_tracker()
if not tracker:
app_logger.warning("Tracker not initialized, skipping memory cleanup")
return
stats_before = tracker.get_memory_stats()
tracker.cleanup_memory()
stats_after = tracker.get_memory_stats()
visits_reduced = stats_before["ip_page_visits"] - stats_after["ip_page_visits"]
if visits_reduced > 0:
app_logger.info(
f"Memory cleanup: Removed {visits_reduced} stale ip_page_visits entries"
)
app_logger.debug(
f"Memory stats after cleanup: "
f"ip_page_visits={stats_after['ip_page_visits']}"
)
except Exception as e:
app_logger.error(f"Error during memory cleanup: {e}")

View File

@@ -1,9 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from typing import Dict, Tuple, Optional from typing import Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime
from zoneinfo import ZoneInfo
import re import re
import urllib.parse import urllib.parse
@@ -49,9 +46,6 @@ class AccessTracker:
self.max_pages_limit = max_pages_limit self.max_pages_limit = max_pages_limit
self.ban_duration_seconds = ban_duration_seconds self.ban_duration_seconds = ban_duration_seconds
# Track pages visited by each IP (for good crawler limiting)
self.ip_page_visits: Dict[str, Dict[str, object]] = defaultdict(dict)
# Load suspicious patterns from wordlists # Load suspicious patterns from wordlists
wl = get_wordlists() wl = get_wordlists()
self.suspicious_patterns = wl.suspicious_patterns self.suspicious_patterns = wl.suspicious_patterns
@@ -372,14 +366,7 @@ class AccessTracker:
def increment_page_visit(self, client_ip: str) -> int: def increment_page_visit(self, client_ip: str) -> int:
""" """
Increment page visit counter for an IP and return the new count. Increment page visit counter for an IP via DB and return the new count.
Implements incremental bans: each violation increases ban duration exponentially.
Ban duration formula: base_duration * (2 ^ violation_count)
- 1st violation: base_duration (e.g., 60 seconds)
- 2nd violation: base_duration * 2 (120 seconds)
- 3rd violation: base_duration * 4 (240 seconds)
- Nth violation: base_duration * 2^(N-1)
Args: Args:
client_ip: The client IP address client_ip: The client IP address
@@ -387,7 +374,6 @@ class AccessTracker:
Returns: Returns:
The updated page visit count for this IP The updated page visit count for this IP
""" """
# Skip if this is the server's own IP
from config import get_config from config import get_config
config = get_config() config = get_config()
@@ -395,85 +381,24 @@ class AccessTracker:
if server_ip and client_ip == server_ip: if server_ip and client_ip == server_ip:
return 0 return 0
try: if not self.db:
# Initialize if not exists
if client_ip not in self.ip_page_visits:
self.ip_page_visits[client_ip] = {
"count": 0,
"ban_timestamp": None,
"total_violations": 0,
"ban_multiplier": 1,
}
# Increment count
self.ip_page_visits[client_ip]["count"] += 1
# Set ban if reached limit
if self.ip_page_visits[client_ip]["count"] >= self.max_pages_limit:
# Increment violation counter
self.ip_page_visits[client_ip]["total_violations"] += 1
violations = self.ip_page_visits[client_ip]["total_violations"]
# Calculate exponential ban multiplier: 2^(violations - 1)
# Violation 1: 2^0 = 1x
# Violation 2: 2^1 = 2x
# Violation 3: 2^2 = 4x
# Violation 4: 2^3 = 8x, etc.
self.ip_page_visits[client_ip]["ban_multiplier"] = 2 ** (violations - 1)
# Set ban timestamp
self.ip_page_visits[client_ip][
"ban_timestamp"
] = datetime.now().isoformat()
return self.ip_page_visits[client_ip]["count"]
except Exception:
return 0 return 0
return self.db.increment_page_visit(client_ip, self.max_pages_limit)
def is_banned_ip(self, client_ip: str) -> bool: def is_banned_ip(self, client_ip: str) -> bool:
""" """
Check if an IP is currently banned due to exceeding page visit limits. Check if an IP is currently banned.
Uses incremental ban duration based on violation count.
Ban duration = base_duration * (2 ^ (violations - 1))
Each time an IP is banned again, duration doubles.
Args: Args:
client_ip: The client IP address client_ip: The client IP address
Returns: Returns:
True if the IP is banned, False otherwise True if the IP is banned, False otherwise
""" """
try: if not self.db:
if client_ip in self.ip_page_visits:
ban_timestamp = self.ip_page_visits[client_ip].get("ban_timestamp")
if ban_timestamp is not None:
# Get the ban multiplier for this violation
ban_multiplier = self.ip_page_visits[client_ip].get(
"ban_multiplier", 1
)
# Calculate effective ban duration based on violations
effective_ban_duration = self.ban_duration_seconds * ban_multiplier
# Check if ban period has expired
ban_time = datetime.fromisoformat(ban_timestamp)
time_diff = datetime.now() - ban_time
if time_diff.total_seconds() > effective_ban_duration:
# Ban expired, reset for next cycle
# Keep violation count for next offense
self.ip_page_visits[client_ip]["count"] = 0
self.ip_page_visits[client_ip]["ban_timestamp"] = None
return False
else:
# Still banned
return True
return False return False
except Exception: return self.db.is_banned_ip(client_ip, self.ban_duration_seconds)
return False
def get_ban_info(self, client_ip: str) -> dict: def get_ban_info(self, client_ip: str) -> dict:
""" """
@@ -482,64 +407,15 @@ class AccessTracker:
Returns: Returns:
Dictionary with ban status, violations, and remaining ban time Dictionary with ban status, violations, and remaining ban time
""" """
try: if not self.db:
if client_ip not in self.ip_page_visits:
return {
"is_banned": False,
"violations": 0,
"ban_multiplier": 1,
"remaining_ban_seconds": 0,
}
ip_data = self.ip_page_visits[client_ip]
ban_timestamp = ip_data.get("ban_timestamp")
if ban_timestamp is None:
return {
"is_banned": False,
"violations": ip_data.get("total_violations", 0),
"ban_multiplier": ip_data.get("ban_multiplier", 1),
"remaining_ban_seconds": 0,
}
# Ban is active, calculate remaining time
ban_multiplier = ip_data.get("ban_multiplier", 1)
effective_ban_duration = self.ban_duration_seconds * ban_multiplier
ban_time = datetime.fromisoformat(ban_timestamp)
time_diff = datetime.now() - ban_time
remaining_seconds = max(
0, effective_ban_duration - time_diff.total_seconds()
)
return {
"is_banned": remaining_seconds > 0,
"violations": ip_data.get("total_violations", 0),
"ban_multiplier": ban_multiplier,
"effective_ban_duration_seconds": effective_ban_duration,
"remaining_ban_seconds": remaining_seconds,
}
except Exception:
return { return {
"is_banned": False, "is_banned": False,
"violations": 0, "violations": 0,
"ban_multiplier": 1, "ban_multiplier": 1,
"remaining_ban_seconds": 0, "remaining_ban_seconds": 0,
} }
"""
Get the current page visit count for an IP.
Args: return self.db.get_ban_info(client_ip, self.ban_duration_seconds)
client_ip: The client IP address
Returns:
The page visit count for this IP
"""
try:
return self.ip_page_visits.get(client_ip, 0)
except Exception:
return 0
def get_stats(self) -> Dict: def get_stats(self) -> Dict:
"""Get statistics summary from database.""" """Get statistics summary from database."""
@@ -560,44 +436,3 @@ class AccessTracker:
return stats return stats
def cleanup_memory(self) -> None:
    """
    Clean up in-memory structures to prevent unbounded growth.

    Two passes over ``self.ip_page_visits``:
      1. For every entry with a ban timestamp, clear the ban (count reset to 0,
         timestamp set to None) once more than
         ``ban_duration_seconds * ban_multiplier`` seconds have elapsed.
      2. Drop every entry that has a zero count and no ban timestamp.

    Should be called periodically (e.g., every 5 minutes).
    """
    current_time = datetime.now()

    # Pass 1: expire finished bans. Only the per-IP data is needed here,
    # so iterate the values rather than key/value pairs.
    for data in self.ip_page_visits.values():
        ban_timestamp = data.get("ban_timestamp")
        if ban_timestamp is None:
            continue
        try:
            ban_time = datetime.fromisoformat(ban_timestamp)
        except (ValueError, TypeError):
            # Unparseable timestamp: leave the entry untouched (best effort).
            continue
        effective_duration = self.ban_duration_seconds * data.get(
            "ban_multiplier", 1
        )
        if (current_time - ban_time).total_seconds() > effective_duration:
            data["count"] = 0
            data["ban_timestamp"] = None

    # Pass 2: collect the idle keys first, then delete, so the dict is never
    # mutated while it is being iterated.
    ips_to_remove = [
        ip
        for ip, data in self.ip_page_visits.items()
        if data.get("count", 0) == 0 and data.get("ban_timestamp") is None
    ]
    for ip in ips_to_remove:
        del self.ip_page_visits[ip]
def get_memory_stats(self) -> Dict[str, int]:
    """
    Report how many entries the in-memory structures currently hold.

    Returns:
        Dictionary mapping ``"ip_page_visits"`` to the number of entries
        in ``self.ip_page_visits``.
    """
    tracked_ip_count = len(self.ip_page_visits)
    return {"ip_page_visits": tracked_ip_count}