Merge pull request #91 from BlessedRebuS/feat-migration-scripts-init

Feat: migration scripts at startup, small variable refactoring
This commit is contained in:
Patrick Di Fazio
2026-02-22 16:29:32 +01:00
committed by GitHub
9 changed files with 339 additions and 221 deletions

View File

@@ -2,8 +2,8 @@ apiVersion: v2
name: krawl-chart name: krawl-chart
description: A Helm chart for Krawl honeypot server description: A Helm chart for Krawl honeypot server
type: application type: application
version: 1.0.5 version: 1.0.6
appVersion: 1.0.5 appVersion: 1.0.6
keywords: keywords:
- honeypot - honeypot
- security - security

View File

@@ -29,10 +29,11 @@ async def lifespan(app: FastAPI):
initialize_logging() initialize_logging()
app_logger = get_app_logger() app_logger = get_app_logger()
# Initialize database # Initialize database and run pending migrations before accepting traffic
try: try:
app_logger.info(f"Initializing database at: {config.database_path}")
initialize_database(config.database_path) initialize_database(config.database_path)
app_logger.info(f"Database initialized at: {config.database_path}") app_logger.info("Database ready")
except Exception as e: except Exception as e:
app_logger.warning( app_logger.warning(
f"Database initialization failed: {e}. Continuing with in-memory only." f"Database initialization failed: {e}. Continuing with in-memory only."

View File

@@ -97,6 +97,11 @@ class DatabaseManager:
# Run automatic migrations for backward compatibility # Run automatic migrations for backward compatibility
self._run_migrations(database_path) self._run_migrations(database_path)
# Run schema migrations (columns & indexes on existing tables)
from migrations.runner import run_migrations
run_migrations(database_path)
# Set restrictive file permissions (owner read/write only) # Set restrictive file permissions (owner read/write only)
if os.path.exists(database_path): if os.path.exists(database_path):
try: try:
@@ -256,7 +261,7 @@ class DatabaseManager:
session.add(detection) session.add(detection)
# Update IP stats # Update IP stats
self._update_ip_stats(session, ip) self._update_ip_stats(session, ip, is_suspicious)
session.commit() session.commit()
return access_log.id return access_log.id
@@ -308,13 +313,16 @@ class DatabaseManager:
finally: finally:
self.close_session() self.close_session()
def _update_ip_stats(self, session: Session, ip: str) -> None: def _update_ip_stats(
self, session: Session, ip: str, is_suspicious: bool = False
) -> None:
""" """
Update IP statistics (upsert pattern). Update IP statistics (upsert pattern).
Args: Args:
session: Active database session session: Active database session
ip: IP address to update ip: IP address to update
is_suspicious: Whether the request was flagged as suspicious
""" """
sanitized_ip = sanitize_ip(ip) sanitized_ip = sanitize_ip(ip)
now = datetime.now() now = datetime.now()
@@ -324,12 +332,159 @@ class DatabaseManager:
if ip_stats: if ip_stats:
ip_stats.total_requests += 1 ip_stats.total_requests += 1
ip_stats.last_seen = now ip_stats.last_seen = now
if is_suspicious:
ip_stats.need_reevaluation = True
else: else:
ip_stats = IpStats( ip_stats = IpStats(
ip=sanitized_ip, total_requests=1, first_seen=now, last_seen=now ip=sanitized_ip,
total_requests=1,
first_seen=now,
last_seen=now,
need_reevaluation=is_suspicious,
) )
session.add(ip_stats) session.add(ip_stats)
def increment_page_visit(self, ip: str, max_pages_limit: int) -> int:
    """
    Increment the page visit counter for an IP and apply ban if limit reached.

    Bans escalate exponentially with repeat violations:
    ban_multiplier = 2 ** (total_violations - 1).

    Args:
        ip: Client IP address
        max_pages_limit: Page visit threshold before banning

    Returns:
        The updated page visit count, or 0 on error
    """
    session = self.session
    try:
        sanitized_ip = sanitize_ip(ip)
        ip_stats = session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
        if not ip_stats:
            # First sighting of this IP: create the row with a zero counter and
            # fall through to the shared increment/ban logic below, so that a
            # max_pages_limit of 1 bans on the very first visit too (the
            # original early-return path skipped the limit check entirely).
            now = datetime.now()
            ip_stats = IpStats(
                ip=sanitized_ip,
                total_requests=0,
                first_seen=now,
                last_seen=now,
                page_visit_count=0,
            )
            session.add(ip_stats)

        ip_stats.page_visit_count = (ip_stats.page_visit_count or 0) + 1
        if ip_stats.page_visit_count >= max_pages_limit:
            ip_stats.total_violations = (ip_stats.total_violations or 0) + 1
            # Exponential escalation: 1x, 2x, 4x, 8x, ...
            ip_stats.ban_multiplier = 2 ** (ip_stats.total_violations - 1)
            ip_stats.ban_timestamp = datetime.now()
        session.commit()
        return ip_stats.page_visit_count
    except Exception as e:
        session.rollback()
        # BUG FIX: original referenced undefined name `applogger` (NameError).
        # Use the shared "krawl" logger, matching migrations/runner.py.
        import logging

        logging.getLogger("krawl").error(f"Error incrementing page visit for {ip}: {e}")
        return 0
    finally:
        self.close_session()
def is_banned_ip(self, ip: str, ban_duration_seconds: int) -> bool:
    """
    Check if an IP is currently banned.

    The effective ban duration is ban_duration_seconds * ban_multiplier.
    An expired ban resets page_visit_count for the next cycle while keeping
    total_violations, so the next ban escalates.

    Args:
        ip: Client IP address
        ban_duration_seconds: Base ban duration in seconds

    Returns:
        True if the IP is currently banned
    """
    session = self.session
    try:
        sanitized_ip = sanitize_ip(ip)
        ip_stats = session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
        if not ip_stats or ip_stats.ban_timestamp is None:
            return False
        effective_duration = ban_duration_seconds * (ip_stats.ban_multiplier or 1)
        elapsed = (datetime.now() - ip_stats.ban_timestamp).total_seconds()
        if elapsed > effective_duration:
            # Ban expired — reset count for next cycle
            ip_stats.page_visit_count = 0
            ip_stats.ban_timestamp = None
            session.commit()
            return False
        return True
    except Exception as e:
        # The commit above may have failed mid-transaction; leave the session clean.
        session.rollback()
        # BUG FIX: original referenced undefined name `applogger` (NameError).
        # Use the shared "krawl" logger, matching migrations/runner.py.
        import logging

        logging.getLogger("krawl").error(f"Error checking ban status for {ip}: {e}")
        return False
    finally:
        self.close_session()
def get_ban_info(self, ip: str, ban_duration_seconds: int) -> dict:
    """
    Get detailed ban information for an IP.

    Args:
        ip: Client IP address
        ban_duration_seconds: Base ban duration in seconds

    Returns:
        Dictionary with ban status details: is_banned, violations,
        ban_multiplier, remaining_ban_seconds (and, when a ban timestamp
        exists, effective_ban_duration_seconds)
    """
    # Shared "no active ban" payload for the unknown-IP and error paths.
    not_banned = {
        "is_banned": False,
        "violations": 0,
        "ban_multiplier": 1,
        "remaining_ban_seconds": 0,
    }
    session = self.session
    try:
        sanitized_ip = sanitize_ip(ip)
        ip_stats = session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
        if not ip_stats:
            return dict(not_banned)

        violations = ip_stats.total_violations or 0
        multiplier = ip_stats.ban_multiplier or 1
        if ip_stats.ban_timestamp is None:
            return {
                "is_banned": False,
                "violations": violations,
                "ban_multiplier": multiplier,
                "remaining_ban_seconds": 0,
            }

        effective_duration = ban_duration_seconds * multiplier
        elapsed = (datetime.now() - ip_stats.ban_timestamp).total_seconds()
        remaining = max(0, effective_duration - elapsed)
        return {
            "is_banned": remaining > 0,
            "violations": violations,
            "ban_multiplier": multiplier,
            "effective_ban_duration_seconds": effective_duration,
            "remaining_ban_seconds": remaining,
        }
    except Exception as e:
        # BUG FIX: original referenced undefined name `applogger` (NameError).
        # Use the shared "krawl" logger, matching migrations/runner.py.
        import logging

        logging.getLogger("krawl").error(f"Error getting ban info for {ip}: {e}")
        return dict(not_banned)
    finally:
        self.close_session()
def update_ip_stats_analysis( def update_ip_stats_analysis(
self, self,
ip: str, ip: str,
@@ -380,6 +535,7 @@ class DatabaseManager:
ip_stats.category = category ip_stats.category = category
ip_stats.category_scores = category_scores ip_stats.category_scores = category_scores
ip_stats.last_analysis = last_analysis ip_stats.last_analysis = last_analysis
ip_stats.need_reevaluation = False
try: try:
session.commit() session.commit()
@@ -632,6 +788,24 @@ class DatabaseManager:
finally: finally:
self.close_session() self.close_session()
def get_ips_needing_reevaluation(self) -> List[str]:
    """
    Return every IP address currently flagged for reevaluation.

    Returns:
        List of IP addresses whose need_reevaluation flag is set
    """
    session = self.session
    try:
        # SQLAlchemy overloads `==`; `== True` is required here (not `is True`).
        rows = (
            session.query(IpStats.ip)
            .filter(IpStats.need_reevaluation == True)  # noqa: E712
            .all()
        )
        return [row[0] for row in rows]
    finally:
        self.close_session()
def get_access_logs( def get_access_logs(
self, self,
limit: int = 100, limit: int = 100,

View File

127
src/migrations/runner.py Normal file
View File

@@ -0,0 +1,127 @@
"""
Migration runner for Krawl.
Checks the database schema and applies any pending migrations at startup.
All checks are idempotent — safe to run on every boot.
Note: table creation (e.g. category_history) is already handled by
Base.metadata.create_all() in DatabaseManager.initialize() and is NOT
duplicated here. This runner only covers ALTER-level changes that
create_all() cannot apply to existing tables (new columns, new indexes).
"""
import sqlite3
import logging
from typing import List
logger = logging.getLogger("krawl")
def _column_exists(cursor, table_name: str, column_name: str) -> bool:
cursor.execute(f"PRAGMA table_info({table_name})")
columns = [row[1] for row in cursor.fetchall()]
return column_name in columns
def _index_exists(cursor, index_name: str) -> bool:
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='index' AND name=?",
(index_name,),
)
return cursor.fetchone() is not None
def _migrate_raw_request_column(cursor) -> bool:
    """Ensure access_logs has a raw_request column; return True if added."""
    if not _column_exists(cursor, "access_logs", "raw_request"):
        cursor.execute("ALTER TABLE access_logs ADD COLUMN raw_request TEXT")
        return True
    return False
def _migrate_need_reevaluation_column(cursor) -> bool:
    """Ensure ip_stats has a need_reevaluation column; return True if added."""
    if not _column_exists(cursor, "ip_stats", "need_reevaluation"):
        cursor.execute(
            "ALTER TABLE ip_stats ADD COLUMN need_reevaluation BOOLEAN DEFAULT 0"
        )
        return True
    return False
def _migrate_ban_state_columns(cursor) -> List[str]:
    """Ensure ip_stats has the ban/rate-limit columns; return the names added."""
    required = (
        ("page_visit_count", "INTEGER DEFAULT 0"),
        ("ban_timestamp", "DATETIME"),
        ("total_violations", "INTEGER DEFAULT 0"),
        ("ban_multiplier", "INTEGER DEFAULT 1"),
    )
    created: List[str] = []
    for name, ddl_type in required:
        if _column_exists(cursor, "ip_stats", name):
            continue
        cursor.execute(f"ALTER TABLE ip_stats ADD COLUMN {name} {ddl_type}")
        created.append(name)
    return created
def _migrate_performance_indexes(cursor) -> List[str]:
    """Ensure attack_detections has its query indexes; return the names added."""
    specs = (
        (
            "ix_attack_detections_attack_type",
            "CREATE INDEX ix_attack_detections_attack_type "
            "ON attack_detections(attack_type)",
        ),
        (
            "ix_attack_detections_type_log",
            "CREATE INDEX ix_attack_detections_type_log "
            "ON attack_detections(attack_type, access_log_id)",
        ),
    )
    created: List[str] = []
    for name, ddl in specs:
        if not _index_exists(cursor, name):
            cursor.execute(ddl)
            created.append(name)
    return created
def run_migrations(database_path: str) -> None:
    """
    Check the database schema and apply any pending migrations.

    Only handles ALTER-level changes (columns, indexes) that
    Base.metadata.create_all() cannot apply to existing tables.
    All checks are idempotent, so this is safe to call on every startup.

    Args:
        database_path: Path to the SQLite database file.
    """
    applied: List[str] = []
    conn = None
    try:
        conn = sqlite3.connect(database_path)
        cursor = conn.cursor()

        if _migrate_raw_request_column(cursor):
            applied.append("add raw_request column to access_logs")
        if _migrate_need_reevaluation_column(cursor):
            applied.append("add need_reevaluation column to ip_stats")
        for col in _migrate_ban_state_columns(cursor):
            applied.append(f"add {col} column to ip_stats")
        for idx in _migrate_performance_indexes(cursor):
            applied.append(f"add index {idx}")

        conn.commit()
    except sqlite3.Error as e:
        logger.error(f"Migration error: {e}")
        # Nothing was committed — don't report pending migrations as applied
        # (the original fell through and logged them anyway).
        return
    finally:
        # BUG FIX: the original only closed the connection on success,
        # leaking it whenever a migration raised sqlite3.Error.
        if conn is not None:
            conn.close()

    if applied:
        for m in applied:
            logger.info(f"Migration applied: {m}")
        logger.info(f"All migrations complete ({len(applied)} applied)")
    else:
        logger.info("Database schema is up to date — no migrations needed")

View File

@@ -200,6 +200,15 @@ class IpStats(Base):
category_scores: Mapped[Dict[str, int]] = mapped_column(JSON, nullable=True) category_scores: Mapped[Dict[str, int]] = mapped_column(JSON, nullable=True)
manual_category: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True) manual_category: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True)
last_analysis: Mapped[datetime] = mapped_column(DateTime, nullable=True) last_analysis: Mapped[datetime] = mapped_column(DateTime, nullable=True)
need_reevaluation: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=True
)
# Ban/rate-limit state (moved from in-memory tracker to DB)
page_visit_count: Mapped[int] = mapped_column(Integer, default=0, nullable=True)
ban_timestamp: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
total_violations: Mapped[int] = mapped_column(Integer, default=0, nullable=True)
ban_multiplier: Mapped[int] = mapped_column(Integer, default=1, nullable=True)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<IpStats(ip='{self.ip}', total_requests={self.total_requests})>" return f"<IpStats(ip='{self.ip}', total_requests={self.total_requests})>"

View File

@@ -94,12 +94,13 @@ def main():
"attack_url": 0, "attack_url": 0,
}, },
} }
# Get IPs with recent activity (last minute to match cron schedule) # Get IPs flagged for reevaluation (set when a suspicious request arrives)
recent_accesses = db_manager.get_access_logs(limit=999999999, since_minutes=1) ips_to_analyze = set(db_manager.get_ips_needing_reevaluation())
ips_to_analyze = {item["ip"] for item in recent_accesses}
if not ips_to_analyze: if not ips_to_analyze:
app_logger.debug("[Background Task] analyze-ips: No recent activity, skipping") app_logger.debug(
"[Background Task] analyze-ips: No IPs need reevaluation, skipping"
)
return return
for ip in ips_to_analyze: for ip in ips_to_analyze:

View File

@@ -2,7 +2,10 @@
""" """
Memory cleanup task for Krawl honeypot. Memory cleanup task for Krawl honeypot.
Periodically cleans expired bans and stale entries from ip_page_visits.
NOTE: This task is no longer needed. Ban/rate-limit state has been moved from
in-memory ip_page_visits dict to the ip_stats DB table, eliminating unbounded
memory growth. Kept disabled for reference.
""" """
from logger import get_app_logger from logger import get_app_logger
@@ -13,8 +16,8 @@ from logger import get_app_logger
TASK_CONFIG = { TASK_CONFIG = {
"name": "memory-cleanup", "name": "memory-cleanup",
"cron": "*/5 * * * *", # Run every 5 minutes "cron": "*/5 * * * *",
"enabled": True, "enabled": False,
"run_when_loaded": False, "run_when_loaded": False,
} }
@@ -22,35 +25,4 @@ app_logger = get_app_logger()
def main(): def main():
""" app_logger.debug("memory-cleanup task is disabled (ban state now in DB)")
Clean up in-memory structures in the tracker.
Called periodically to prevent unbounded memory growth.
"""
try:
from tracker import get_tracker
tracker = get_tracker()
if not tracker:
app_logger.warning("Tracker not initialized, skipping memory cleanup")
return
stats_before = tracker.get_memory_stats()
tracker.cleanup_memory()
stats_after = tracker.get_memory_stats()
visits_reduced = stats_before["ip_page_visits"] - stats_after["ip_page_visits"]
if visits_reduced > 0:
app_logger.info(
f"Memory cleanup: Removed {visits_reduced} stale ip_page_visits entries"
)
app_logger.debug(
f"Memory stats after cleanup: "
f"ip_page_visits={stats_after['ip_page_visits']}"
)
except Exception as e:
app_logger.error(f"Error during memory cleanup: {e}")

View File

@@ -1,9 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from typing import Dict, Tuple, Optional from typing import Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime
from zoneinfo import ZoneInfo
import re import re
import urllib.parse import urllib.parse
@@ -49,9 +46,6 @@ class AccessTracker:
self.max_pages_limit = max_pages_limit self.max_pages_limit = max_pages_limit
self.ban_duration_seconds = ban_duration_seconds self.ban_duration_seconds = ban_duration_seconds
# Track pages visited by each IP (for good crawler limiting)
self.ip_page_visits: Dict[str, Dict[str, object]] = defaultdict(dict)
# Load suspicious patterns from wordlists # Load suspicious patterns from wordlists
wl = get_wordlists() wl = get_wordlists()
self.suspicious_patterns = wl.suspicious_patterns self.suspicious_patterns = wl.suspicious_patterns
@@ -372,14 +366,7 @@ class AccessTracker:
def increment_page_visit(self, client_ip: str) -> int: def increment_page_visit(self, client_ip: str) -> int:
""" """
Increment page visit counter for an IP and return the new count. Increment page visit counter for an IP via DB and return the new count.
Implements incremental bans: each violation increases ban duration exponentially.
Ban duration formula: base_duration * (2 ^ violation_count)
- 1st violation: base_duration (e.g., 60 seconds)
- 2nd violation: base_duration * 2 (120 seconds)
- 3rd violation: base_duration * 4 (240 seconds)
- Nth violation: base_duration * 2^(N-1)
Args: Args:
client_ip: The client IP address client_ip: The client IP address
@@ -387,7 +374,6 @@ class AccessTracker:
Returns: Returns:
The updated page visit count for this IP The updated page visit count for this IP
""" """
# Skip if this is the server's own IP
from config import get_config from config import get_config
config = get_config() config = get_config()
@@ -395,85 +381,24 @@ class AccessTracker:
if server_ip and client_ip == server_ip: if server_ip and client_ip == server_ip:
return 0 return 0
try: if not self.db:
# Initialize if not exists
if client_ip not in self.ip_page_visits:
self.ip_page_visits[client_ip] = {
"count": 0,
"ban_timestamp": None,
"total_violations": 0,
"ban_multiplier": 1,
}
# Increment count
self.ip_page_visits[client_ip]["count"] += 1
# Set ban if reached limit
if self.ip_page_visits[client_ip]["count"] >= self.max_pages_limit:
# Increment violation counter
self.ip_page_visits[client_ip]["total_violations"] += 1
violations = self.ip_page_visits[client_ip]["total_violations"]
# Calculate exponential ban multiplier: 2^(violations - 1)
# Violation 1: 2^0 = 1x
# Violation 2: 2^1 = 2x
# Violation 3: 2^2 = 4x
# Violation 4: 2^3 = 8x, etc.
self.ip_page_visits[client_ip]["ban_multiplier"] = 2 ** (violations - 1)
# Set ban timestamp
self.ip_page_visits[client_ip][
"ban_timestamp"
] = datetime.now().isoformat()
return self.ip_page_visits[client_ip]["count"]
except Exception:
return 0 return 0
return self.db.increment_page_visit(client_ip, self.max_pages_limit)
def is_banned_ip(self, client_ip: str) -> bool: def is_banned_ip(self, client_ip: str) -> bool:
""" """
Check if an IP is currently banned due to exceeding page visit limits. Check if an IP is currently banned.
Uses incremental ban duration based on violation count.
Ban duration = base_duration * (2 ^ (violations - 1))
Each time an IP is banned again, duration doubles.
Args: Args:
client_ip: The client IP address client_ip: The client IP address
Returns: Returns:
True if the IP is banned, False otherwise True if the IP is banned, False otherwise
""" """
try: if not self.db:
if client_ip in self.ip_page_visits:
ban_timestamp = self.ip_page_visits[client_ip].get("ban_timestamp")
if ban_timestamp is not None:
# Get the ban multiplier for this violation
ban_multiplier = self.ip_page_visits[client_ip].get(
"ban_multiplier", 1
)
# Calculate effective ban duration based on violations
effective_ban_duration = self.ban_duration_seconds * ban_multiplier
# Check if ban period has expired
ban_time = datetime.fromisoformat(ban_timestamp)
time_diff = datetime.now() - ban_time
if time_diff.total_seconds() > effective_ban_duration:
# Ban expired, reset for next cycle
# Keep violation count for next offense
self.ip_page_visits[client_ip]["count"] = 0
self.ip_page_visits[client_ip]["ban_timestamp"] = None
return False
else:
# Still banned
return True
return False return False
except Exception: return self.db.is_banned_ip(client_ip, self.ban_duration_seconds)
return False
def get_ban_info(self, client_ip: str) -> dict: def get_ban_info(self, client_ip: str) -> dict:
""" """
@@ -482,64 +407,15 @@ class AccessTracker:
Returns: Returns:
Dictionary with ban status, violations, and remaining ban time Dictionary with ban status, violations, and remaining ban time
""" """
try: if not self.db:
if client_ip not in self.ip_page_visits:
return {
"is_banned": False,
"violations": 0,
"ban_multiplier": 1,
"remaining_ban_seconds": 0,
}
ip_data = self.ip_page_visits[client_ip]
ban_timestamp = ip_data.get("ban_timestamp")
if ban_timestamp is None:
return {
"is_banned": False,
"violations": ip_data.get("total_violations", 0),
"ban_multiplier": ip_data.get("ban_multiplier", 1),
"remaining_ban_seconds": 0,
}
# Ban is active, calculate remaining time
ban_multiplier = ip_data.get("ban_multiplier", 1)
effective_ban_duration = self.ban_duration_seconds * ban_multiplier
ban_time = datetime.fromisoformat(ban_timestamp)
time_diff = datetime.now() - ban_time
remaining_seconds = max(
0, effective_ban_duration - time_diff.total_seconds()
)
return {
"is_banned": remaining_seconds > 0,
"violations": ip_data.get("total_violations", 0),
"ban_multiplier": ban_multiplier,
"effective_ban_duration_seconds": effective_ban_duration,
"remaining_ban_seconds": remaining_seconds,
}
except Exception:
return { return {
"is_banned": False, "is_banned": False,
"violations": 0, "violations": 0,
"ban_multiplier": 1, "ban_multiplier": 1,
"remaining_ban_seconds": 0, "remaining_ban_seconds": 0,
} }
"""
Get the current page visit count for an IP.
Args: return self.db.get_ban_info(client_ip, self.ban_duration_seconds)
client_ip: The client IP address
Returns:
The page visit count for this IP
"""
try:
return self.ip_page_visits.get(client_ip, 0)
except Exception:
return 0
def get_stats(self) -> Dict: def get_stats(self) -> Dict:
"""Get statistics summary from database.""" """Get statistics summary from database."""
@@ -559,45 +435,3 @@ class AccessTracker:
stats["credential_attempts"] = self.db.get_credential_attempts(limit=50) stats["credential_attempts"] = self.db.get_credential_attempts(limit=50)
return stats return stats
def cleanup_memory(self) -> None:
"""
Clean up in-memory structures to prevent unbounded growth.
Should be called periodically (e.g., every 5 minutes).
"""
# Clean expired ban entries from ip_page_visits
current_time = datetime.now()
for ip, data in self.ip_page_visits.items():
ban_timestamp = data.get("ban_timestamp")
if ban_timestamp is not None:
try:
ban_time = datetime.fromisoformat(ban_timestamp)
time_diff = (current_time - ban_time).total_seconds()
effective_duration = self.ban_duration_seconds * data.get(
"ban_multiplier", 1
)
if time_diff > effective_duration:
data["count"] = 0
data["ban_timestamp"] = None
except (ValueError, TypeError):
pass
# Remove IPs with zero activity and no active ban
ips_to_remove = [
ip
for ip, data in self.ip_page_visits.items()
if data.get("count", 0) == 0 and data.get("ban_timestamp") is None
]
for ip in ips_to_remove:
del self.ip_page_visits[ip]
def get_memory_stats(self) -> Dict[str, int]:
"""
Get current memory usage statistics for monitoring.
Returns:
Dictionary with counts of in-memory items
"""
return {
"ip_page_visits": len(self.ip_page_visits),
}