From b94cd38b616c52643468a6c6fc2090d2301698b0 Mon Sep 17 00:00:00 2001 From: Lorenzo Venerandi Date: Tue, 17 Feb 2026 18:04:09 +0100 Subject: [PATCH] refactor: optimize database access and implement retention task --- src/database.py | 71 +++++++++++++++++++++------------ src/tasks/db_retention.py | 83 +++++++++++++++++++++++++++++++++++++++ src/tracker.py | 22 +++++------ 3 files changed, 139 insertions(+), 37 deletions(-) create mode 100644 src/tasks/db_retention.py diff --git a/src/database.py b/src/database.py index 54d8636..30e2fb5 100644 --- a/src/database.py +++ b/src/database.py @@ -1009,6 +1009,27 @@ class DatabaseManager: finally: self.close_session() + def _public_ip_filter(self, query, ip_column, server_ip: Optional[str] = None): + """Apply SQL-level filters to exclude local/private IPs and server IP.""" + query = query.filter( + ~ip_column.like("10.%"), + ~ip_column.like("172.16.%"), + ~ip_column.like("172.17.%"), + ~ip_column.like("172.18.%"), + ~ip_column.like("172.19.%"), + ~ip_column.like("172.2_.%"), + ~ip_column.like("172.30.%"), + ~ip_column.like("172.31.%"), + ~ip_column.like("192.168.%"), + ~ip_column.like("127.%"), + ~ip_column.like("0.%"), + ~ip_column.like("169.254.%"), + ip_column != "::1", + ) + if server_ip: + query = query.filter(ip_column != server_ip) + return query + def get_dashboard_counts(self) -> Dict[str, int]: """ Get aggregate statistics for the dashboard (excludes local/private IPs and server IP). @@ -1019,43 +1040,43 @@ class DatabaseManager: """ session = self.session try: - # Get server IP to filter it out from config import get_config config = get_config() server_ip = config.get_server_ip() - # Get all accesses first, then filter out local IPs and server IP - all_accesses = session.query(AccessLog).all() - - # Filter out local/private IPs and server IP - public_accesses = [ - log for log in all_accesses if is_valid_public_ip(log.ip, server_ip) - ] - - # Calculate counts from filtered data - total_accesses = len(public_accesses) - unique_ips = len(set(log.ip for log in public_accesses)) - unique_paths = len(set(log.path for log in public_accesses)) - suspicious_accesses = sum(1 for log in public_accesses if log.is_suspicious) - honeypot_triggered = sum( - 1 for log in public_accesses if log.is_honeypot_trigger - ) - honeypot_ips = len( - set(log.ip for log in public_accesses if log.is_honeypot_trigger) + # Single aggregation query instead of loading all rows + base = session.query( + func.count(AccessLog.id).label("total_accesses"), + func.count(distinct(AccessLog.ip)).label("unique_ips"), + func.count(distinct(AccessLog.path)).label("unique_paths"), + func.count(case((AccessLog.is_suspicious == True, 1))).label( + "suspicious_accesses" + ), + func.count(case((AccessLog.is_honeypot_trigger == True, 1))).label( + "honeypot_triggered" + ), ) + base = self._public_ip_filter(base, AccessLog.ip, server_ip) + row = base.one() + + # Honeypot unique IPs (separate query for distinct on filtered subset) + hp_query = session.query( + func.count(distinct(AccessLog.ip)) + ).filter(AccessLog.is_honeypot_trigger == True) + hp_query = self._public_ip_filter(hp_query, AccessLog.ip, server_ip) + honeypot_ips = hp_query.scalar() or 0 - # Count unique attackers from IpStats (matching the "Attackers by Total Requests" table) unique_attackers = ( session.query(IpStats).filter(IpStats.category == "attacker").count() ) return { - "total_accesses": total_accesses, - "unique_ips": unique_ips, - "unique_paths": unique_paths, - "suspicious_accesses": suspicious_accesses, - "honeypot_triggered": honeypot_triggered, + "total_accesses": row.total_accesses or 0, + "unique_ips": row.unique_ips or 0, + "unique_paths": row.unique_paths or 0, + "suspicious_accesses": row.suspicious_accesses or 0, + "honeypot_triggered": row.honeypot_triggered or 0, "honeypot_ips": honeypot_ips, "unique_attackers": unique_attackers, } diff --git a/src/tasks/db_retention.py b/src/tasks/db_retention.py new file mode 100644 index 0000000..bcfe1df --- /dev/null +++ b/src/tasks/db_retention.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 + +""" +Database retention task for Krawl honeypot. +Periodically deletes old records based on configured retention_days. +""" + +from datetime import datetime, timedelta + +from database import get_database +from logger import get_app_logger + +# ---------------------- +# TASK CONFIG +# ---------------------- + +TASK_CONFIG = { + "name": "db-retention", + "cron": "0 3 * * *", # Run daily at 3 AM + "enabled": True, + "run_when_loaded": False, +} + +app_logger = get_app_logger() + + +def main(): + """ + Delete access logs, credential attempts, and attack detections + older than the configured retention period. + """ + try: + from config import get_config + from models import AccessLog, CredentialAttempt, AttackDetection + + config = get_config() + retention_days = config.database_retention_days + + db = get_database() + session = db.session + + cutoff = datetime.now() - timedelta(days=retention_days) + + # Delete attack detections linked to old access logs first (FK constraint) + old_log_ids = session.query(AccessLog.id).filter( + AccessLog.timestamp < cutoff + ) + detections_deleted = ( + session.query(AttackDetection) + .filter(AttackDetection.access_log_id.in_(old_log_ids)) + .delete(synchronize_session=False) + ) + + # Delete old access logs + logs_deleted = ( + session.query(AccessLog) + .filter(AccessLog.timestamp < cutoff) + .delete(synchronize_session=False) + ) + + # Delete old credential attempts + creds_deleted = ( + session.query(CredentialAttempt) + .filter(CredentialAttempt.timestamp < cutoff) + .delete(synchronize_session=False) + ) + + session.commit() + + if logs_deleted or creds_deleted or detections_deleted: + app_logger.info( + f"DB retention: Deleted {logs_deleted} access logs, " + f"{detections_deleted} attack detections, " + f"{creds_deleted} credential attempts older than {retention_days} days" + ) + + except Exception as e: + app_logger.error(f"Error during DB retention cleanup: {e}") + finally: + try: + db.close_session() + except Exception: + pass diff --git a/src/tracker.py b/src/tracker.py index 46965c5..cff9b5a 100644 --- a/src/tracker.py +++ b/src/tracker.py @@ -641,29 +641,27 @@ class AccessTracker: # Clean expired ban entries from ip_page_visits current_time = datetime.now() - ips_to_clean = [] for ip, data in self.ip_page_visits.items(): ban_timestamp = data.get("ban_timestamp") if ban_timestamp is not None: try: ban_time = datetime.fromisoformat(ban_timestamp) time_diff = (current_time - ban_time).total_seconds() - if time_diff > self.ban_duration_seconds: - # Ban expired, reset the entry + effective_duration = self.ban_duration_seconds * data.get("ban_multiplier", 1) + if time_diff > effective_duration: data["count"] = 0 data["ban_timestamp"] = None except (ValueError, TypeError): pass - # Optional: Remove IPs with zero activity (advanced cleanup) - # Comment out to keep indefinite history of zero-activity IPs - # ips_to_remove = [ - # ip - # for ip, data in self.ip_page_visits.items() - # if data.get("count", 0) == 0 and data.get("ban_timestamp") is None - # ] - # for ip in ips_to_remove: - # del self.ip_page_visits[ip] + # Remove IPs with zero activity and no active ban + ips_to_remove = [ + ip + for ip, data in self.ip_page_visits.items() + if data.get("count", 0) == 0 and data.get("ban_timestamp") is None + ] + for ip in ips_to_remove: + del self.ip_page_visits[ip] def get_memory_stats(self) -> Dict[str, int]: """