feat: implement IP ban and rate-limiting logic in database with migration scripts

This commit is contained in:
Lorenzo Venerandi
2026-02-22 16:23:52 +01:00
parent db848e7ecb
commit 8ce8b6b40a
5 changed files with 191 additions and 209 deletions

View File

@@ -344,6 +344,153 @@ class DatabaseManager:
)
session.add(ip_stats)
def increment_page_visit(self, ip: str, max_pages_limit: int) -> int:
"""
Increment the page visit counter for an IP and apply ban if limit reached.
Args:
ip: Client IP address
max_pages_limit: Page visit threshold before banning
Returns:
The updated page visit count
"""
session = self.session
try:
sanitized_ip = sanitize_ip(ip)
ip_stats = (
session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
)
if not ip_stats:
now = datetime.now()
ip_stats = IpStats(
ip=sanitized_ip,
total_requests=0,
first_seen=now,
last_seen=now,
page_visit_count=1,
)
session.add(ip_stats)
session.commit()
return 1
ip_stats.page_visit_count = (ip_stats.page_visit_count or 0) + 1
if ip_stats.page_visit_count >= max_pages_limit:
ip_stats.total_violations = (ip_stats.total_violations or 0) + 1
ip_stats.ban_multiplier = 2 ** (ip_stats.total_violations - 1)
ip_stats.ban_timestamp = datetime.now()
session.commit()
return ip_stats.page_visit_count
except Exception as e:
session.rollback()
applogger.error(f"Error incrementing page visit for {ip}: {e}")
return 0
finally:
self.close_session()
def is_banned_ip(self, ip: str, ban_duration_seconds: int) -> bool:
"""
Check if an IP is currently banned.
Args:
ip: Client IP address
ban_duration_seconds: Base ban duration in seconds
Returns:
True if the IP is currently banned
"""
session = self.session
try:
sanitized_ip = sanitize_ip(ip)
ip_stats = (
session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
)
if not ip_stats or ip_stats.ban_timestamp is None:
return False
effective_duration = ban_duration_seconds * (ip_stats.ban_multiplier or 1)
elapsed = (datetime.now() - ip_stats.ban_timestamp).total_seconds()
if elapsed > effective_duration:
# Ban expired — reset count for next cycle
ip_stats.page_visit_count = 0
ip_stats.ban_timestamp = None
session.commit()
return False
return True
except Exception as e:
applogger.error(f"Error checking ban status for {ip}: {e}")
return False
finally:
self.close_session()
def get_ban_info(self, ip: str, ban_duration_seconds: int) -> dict:
"""
Get detailed ban information for an IP.
Args:
ip: Client IP address
ban_duration_seconds: Base ban duration in seconds
Returns:
Dictionary with ban status details
"""
session = self.session
try:
sanitized_ip = sanitize_ip(ip)
ip_stats = (
session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
)
if not ip_stats:
return {
"is_banned": False,
"violations": 0,
"ban_multiplier": 1,
"remaining_ban_seconds": 0,
}
violations = ip_stats.total_violations or 0
multiplier = ip_stats.ban_multiplier or 1
if ip_stats.ban_timestamp is None:
return {
"is_banned": False,
"violations": violations,
"ban_multiplier": multiplier,
"remaining_ban_seconds": 0,
}
effective_duration = ban_duration_seconds * multiplier
elapsed = (datetime.now() - ip_stats.ban_timestamp).total_seconds()
remaining = max(0, effective_duration - elapsed)
return {
"is_banned": remaining > 0,
"violations": violations,
"ban_multiplier": multiplier,
"effective_ban_duration_seconds": effective_duration,
"remaining_ban_seconds": remaining,
}
except Exception as e:
applogger.error(f"Error getting ban info for {ip}: {e}")
return {
"is_banned": False,
"violations": 0,
"ban_multiplier": 1,
"remaining_ban_seconds": 0,
}
finally:
self.close_session()
def update_ip_stats_analysis(
self,
ip: str,

View File

@@ -48,6 +48,24 @@ def _migrate_need_reevaluation_column(cursor) -> bool:
return True
def _migrate_ban_state_columns(cursor) -> List[str]:
"""Add ban/rate-limit columns to ip_stats if missing."""
added = []
columns = {
"page_visit_count": "INTEGER DEFAULT 0",
"ban_timestamp": "DATETIME",
"total_violations": "INTEGER DEFAULT 0",
"ban_multiplier": "INTEGER DEFAULT 1",
}
for col_name, col_type in columns.items():
if not _column_exists(cursor, "ip_stats", col_name):
cursor.execute(
f"ALTER TABLE ip_stats ADD COLUMN {col_name} {col_type}"
)
added.append(col_name)
return added
def _migrate_performance_indexes(cursor) -> List[str]:
"""Add performance indexes to attack_detections if missing."""
added = []
@@ -90,6 +108,10 @@ def run_migrations(database_path: str) -> None:
if _migrate_need_reevaluation_column(cursor):
applied.append("add need_reevaluation column to ip_stats")
ban_cols = _migrate_ban_state_columns(cursor)
for col in ban_cols:
applied.append(f"add {col} column to ip_stats")
idx_added = _migrate_performance_indexes(cursor)
for idx in idx_added:
applied.append(f"add index {idx}")

View File

@@ -204,6 +204,12 @@ class IpStats(Base):
Boolean, default=False, nullable=True
)
# Ban/rate-limit state (moved from in-memory tracker to DB)
page_visit_count: Mapped[int] = mapped_column(Integer, default=0, nullable=True)
ban_timestamp: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
total_violations: Mapped[int] = mapped_column(Integer, default=0, nullable=True)
ban_multiplier: Mapped[int] = mapped_column(Integer, default=1, nullable=True)
def __repr__(self) -> str:
return f"<IpStats(ip='{self.ip}', total_requests={self.total_requests})>"

View File

@@ -2,7 +2,10 @@
"""
Memory cleanup task for Krawl honeypot.
Periodically cleans expired bans and stale entries from ip_page_visits.
NOTE: This task is no longer needed. Ban/rate-limit state has been moved from
in-memory ip_page_visits dict to the ip_stats DB table, eliminating unbounded
memory growth. Kept disabled for reference.
"""
from logger import get_app_logger
@@ -13,8 +16,8 @@ from logger import get_app_logger
TASK_CONFIG = {
"name": "memory-cleanup",
"cron": "*/5 * * * *", # Run every 5 minutes
"enabled": True,
"cron": "*/5 * * * *",
"enabled": False,
"run_when_loaded": False,
}
@@ -22,35 +25,4 @@ app_logger = get_app_logger()
def main():
"""
Clean up in-memory structures in the tracker.
Called periodically to prevent unbounded memory growth.
"""
try:
from tracker import get_tracker
tracker = get_tracker()
if not tracker:
app_logger.warning("Tracker not initialized, skipping memory cleanup")
return
stats_before = tracker.get_memory_stats()
tracker.cleanup_memory()
stats_after = tracker.get_memory_stats()
visits_reduced = stats_before["ip_page_visits"] - stats_after["ip_page_visits"]
if visits_reduced > 0:
app_logger.info(
f"Memory cleanup: Removed {visits_reduced} stale ip_page_visits entries"
)
app_logger.debug(
f"Memory stats after cleanup: "
f"ip_page_visits={stats_after['ip_page_visits']}"
)
except Exception as e:
app_logger.error(f"Error during memory cleanup: {e}")
app_logger.debug("memory-cleanup task is disabled (ban state now in DB)")

View File

@@ -1,9 +1,6 @@
#!/usr/bin/env python3
from typing import Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime
from zoneinfo import ZoneInfo
import re
import urllib.parse
@@ -49,9 +46,6 @@ class AccessTracker:
self.max_pages_limit = max_pages_limit
self.ban_duration_seconds = ban_duration_seconds
# Track pages visited by each IP (for good crawler limiting)
self.ip_page_visits: Dict[str, Dict[str, object]] = defaultdict(dict)
# Load suspicious patterns from wordlists
wl = get_wordlists()
self.suspicious_patterns = wl.suspicious_patterns
@@ -372,14 +366,7 @@ class AccessTracker:
def increment_page_visit(self, client_ip: str) -> int:
"""
Increment page visit counter for an IP and return the new count.
Implements incremental bans: each violation increases ban duration exponentially.
Ban duration formula: base_duration * (2 ^ violation_count)
- 1st violation: base_duration (e.g., 60 seconds)
- 2nd violation: base_duration * 2 (120 seconds)
- 3rd violation: base_duration * 4 (240 seconds)
- Nth violation: base_duration * 2^(N-1)
Increment page visit counter for an IP via DB and return the new count.
Args:
client_ip: The client IP address
@@ -387,7 +374,6 @@ class AccessTracker:
Returns:
The updated page visit count for this IP
"""
# Skip if this is the server's own IP
from config import get_config
config = get_config()
@@ -395,85 +381,24 @@ class AccessTracker:
if server_ip and client_ip == server_ip:
return 0
try:
# Initialize if not exists
if client_ip not in self.ip_page_visits:
self.ip_page_visits[client_ip] = {
"count": 0,
"ban_timestamp": None,
"total_violations": 0,
"ban_multiplier": 1,
}
# Increment count
self.ip_page_visits[client_ip]["count"] += 1
# Set ban if reached limit
if self.ip_page_visits[client_ip]["count"] >= self.max_pages_limit:
# Increment violation counter
self.ip_page_visits[client_ip]["total_violations"] += 1
violations = self.ip_page_visits[client_ip]["total_violations"]
# Calculate exponential ban multiplier: 2^(violations - 1)
# Violation 1: 2^0 = 1x
# Violation 2: 2^1 = 2x
# Violation 3: 2^2 = 4x
# Violation 4: 2^3 = 8x, etc.
self.ip_page_visits[client_ip]["ban_multiplier"] = 2 ** (violations - 1)
# Set ban timestamp
self.ip_page_visits[client_ip][
"ban_timestamp"
] = datetime.now().isoformat()
return self.ip_page_visits[client_ip]["count"]
except Exception:
if not self.db:
return 0
return self.db.increment_page_visit(client_ip, self.max_pages_limit)
def is_banned_ip(self, client_ip: str) -> bool:
"""
Check if an IP is currently banned due to exceeding page visit limits.
Uses incremental ban duration based on violation count.
Ban duration = base_duration * (2 ^ (violations - 1))
Each time an IP is banned again, duration doubles.
Check if an IP is currently banned.
Args:
client_ip: The client IP address
Returns:
True if the IP is banned, False otherwise
"""
try:
if client_ip in self.ip_page_visits:
ban_timestamp = self.ip_page_visits[client_ip].get("ban_timestamp")
if ban_timestamp is not None:
# Get the ban multiplier for this violation
ban_multiplier = self.ip_page_visits[client_ip].get(
"ban_multiplier", 1
)
# Calculate effective ban duration based on violations
effective_ban_duration = self.ban_duration_seconds * ban_multiplier
# Check if ban period has expired
ban_time = datetime.fromisoformat(ban_timestamp)
time_diff = datetime.now() - ban_time
if time_diff.total_seconds() > effective_ban_duration:
# Ban expired, reset for next cycle
# Keep violation count for next offense
self.ip_page_visits[client_ip]["count"] = 0
self.ip_page_visits[client_ip]["ban_timestamp"] = None
return False
else:
# Still banned
return True
if not self.db:
return False
except Exception:
return False
return self.db.is_banned_ip(client_ip, self.ban_duration_seconds)
def get_ban_info(self, client_ip: str) -> dict:
"""
@@ -482,64 +407,15 @@ class AccessTracker:
Returns:
Dictionary with ban status, violations, and remaining ban time
"""
try:
if client_ip not in self.ip_page_visits:
return {
"is_banned": False,
"violations": 0,
"ban_multiplier": 1,
"remaining_ban_seconds": 0,
}
ip_data = self.ip_page_visits[client_ip]
ban_timestamp = ip_data.get("ban_timestamp")
if ban_timestamp is None:
return {
"is_banned": False,
"violations": ip_data.get("total_violations", 0),
"ban_multiplier": ip_data.get("ban_multiplier", 1),
"remaining_ban_seconds": 0,
}
# Ban is active, calculate remaining time
ban_multiplier = ip_data.get("ban_multiplier", 1)
effective_ban_duration = self.ban_duration_seconds * ban_multiplier
ban_time = datetime.fromisoformat(ban_timestamp)
time_diff = datetime.now() - ban_time
remaining_seconds = max(
0, effective_ban_duration - time_diff.total_seconds()
)
return {
"is_banned": remaining_seconds > 0,
"violations": ip_data.get("total_violations", 0),
"ban_multiplier": ban_multiplier,
"effective_ban_duration_seconds": effective_ban_duration,
"remaining_ban_seconds": remaining_seconds,
}
except Exception:
if not self.db:
return {
"is_banned": False,
"violations": 0,
"ban_multiplier": 1,
"remaining_ban_seconds": 0,
}
"""
Get the current page visit count for an IP.
Args:
client_ip: The client IP address
Returns:
The page visit count for this IP
"""
try:
return self.ip_page_visits.get(client_ip, 0)
except Exception:
return 0
return self.db.get_ban_info(client_ip, self.ban_duration_seconds)
def get_stats(self) -> Dict:
"""Get statistics summary from database."""
@@ -560,44 +436,3 @@ class AccessTracker:
return stats
def cleanup_memory(self) -> None:
"""
Clean up in-memory structures to prevent unbounded growth.
Should be called periodically (e.g., every 5 minutes).
"""
# Clean expired ban entries from ip_page_visits
current_time = datetime.now()
for ip, data in self.ip_page_visits.items():
ban_timestamp = data.get("ban_timestamp")
if ban_timestamp is not None:
try:
ban_time = datetime.fromisoformat(ban_timestamp)
time_diff = (current_time - ban_time).total_seconds()
effective_duration = self.ban_duration_seconds * data.get(
"ban_multiplier", 1
)
if time_diff > effective_duration:
data["count"] = 0
data["ban_timestamp"] = None
except (ValueError, TypeError):
pass
# Remove IPs with zero activity and no active ban
ips_to_remove = [
ip
for ip, data in self.ip_page_visits.items()
if data.get("count", 0) == 0 and data.get("ban_timestamp") is None
]
for ip in ips_to_remove:
del self.ip_page_visits[ip]
def get_memory_stats(self) -> Dict[str, int]:
"""
Get current memory usage statistics for monitoring.
Returns:
Dictionary with counts of in-memory items
"""
return {
"ip_page_visits": len(self.ip_page_visits),
}