diff --git a/src/migrations/runner.py b/src/migrations/runner.py index 7a74267..ebb5e0c 100644 --- a/src/migrations/runner.py +++ b/src/migrations/runner.py @@ -84,6 +84,14 @@ def _migrate_performance_indexes(cursor) -> List[str]: return added +def _migrate_ban_override_column(cursor) -> bool: + """Add ban_override column to ip_stats if missing.""" + if _column_exists(cursor, "ip_stats", "ban_override"): + return False + cursor.execute("ALTER TABLE ip_stats ADD COLUMN ban_override BOOLEAN DEFAULT NULL") + return True + + def run_migrations(database_path: str) -> None: """ Check the database schema and apply any pending migrations. @@ -110,6 +118,9 @@ def run_migrations(database_path: str) -> None: for col in ban_cols: applied.append(f"add {col} column to ip_stats") + if _migrate_ban_override_column(cursor): + applied.append("add ban_override column to ip_stats") + idx_added = _migrate_performance_indexes(cursor) for idx in idx_added: applied.append(f"add index {idx}") diff --git a/src/tasks/top_attacking_ips.py b/src/tasks/top_attacking_ips.py index 69d417b..3e16134 100644 --- a/src/tasks/top_attacking_ips.py +++ b/src/tasks/top_attacking_ips.py @@ -45,9 +45,25 @@ def main(): session = db.session # Query attacker IPs from IpStats (same as dashboard "Attackers by Total Requests") - attackers = ( + # Also include IPs with ban_override=True (force-banned by admin) + # Exclude IPs with ban_override=False (force-unbanned by admin) + from sqlalchemy import or_, and_ + + banned_ips = ( session.query(IpStats) - .filter(IpStats.category == "attacker") + .filter( + or_( + # Automatic: attacker category without explicit unban + and_( + IpStats.category == "attacker", + or_( + IpStats.ban_override.is_(None), IpStats.ban_override == True + ), + ), + # Manual: force-banned by admin regardless of category + IpStats.ban_override == True, + ) + ) .order_by(IpStats.total_requests.desc()) .all() ) @@ -56,9 +72,7 @@ def main(): server_ip = config.get_server_ip() public_ips = [ - attacker.ip - for attacker in attackers - if is_valid_public_ip(attacker.ip, server_ip) + entry.ip for entry in banned_ips if is_valid_public_ip(entry.ip, server_ip) ] # Ensure exports directory exists @@ -81,7 +95,7 @@ def main(): app_logger.info( f"[Background Task] {task_name} exported {len(public_ips)} in {fwname} public IPs" - f"(filtered {len(attackers) - len(public_ips)} local/private IPs) to {output_file}" + f"(filtered {len(banned_ips) - len(public_ips)} local/private IPs) to {output_file}" ) except Exception as e: