From 8a651b00f94b618118f58bd9ae5b707638d724af Mon Sep 17 00:00:00 2001 From: Lorenzo Venerandi Date: Sun, 8 Mar 2026 12:26:02 +0100 Subject: [PATCH] feat: add ban override management to IP statistics model and database manager --- src/database.py | 97 +++++++++++++++++++++++++++++++++++++++++++++++++ src/models.py | 5 +++ 2 files changed, 102 insertions(+) diff --git a/src/database.py b/src/database.py index 803e7e7..f7b0a59 100644 --- a/src/database.py +++ b/src/database.py @@ -2231,6 +2231,103 @@ class DatabaseManager: finally: self.close_session() + # ── Ban Override Management ────────────────────────────────────────── + + def set_ban_override(self, ip: str, override: Optional[bool]) -> bool: + """ + Set ban override for an IP. + override=True: force into banlist + override=False: force remove from banlist + override=None: reset to automatic (category-based) + + Returns True if the IP exists and was updated. + """ + session = self.session + sanitized_ip = sanitize_ip(ip) + ip_stats = session.query(IpStats).filter(IpStats.ip == sanitized_ip).first() + if not ip_stats: + return False + + ip_stats.ban_override = override + try: + session.commit() + return True + except Exception as e: + session.rollback() + applogger.error(f"Error setting ban override for {sanitized_ip}: {e}") + return False + + def force_ban_ip(self, ip: str) -> bool: + """ + Force-ban an IP that may not exist in ip_stats yet. + Creates a minimal entry if needed. + """ + session = self.session + sanitized_ip = sanitize_ip(ip) + ip_stats = session.query(IpStats).filter(IpStats.ip == sanitized_ip).first() + if not ip_stats: + ip_stats = IpStats( + ip=sanitized_ip, + total_requests=0, + first_seen=datetime.utcnow(), + last_seen=datetime.utcnow(), + ) + session.add(ip_stats) + + ip_stats.ban_override = True + try: + session.commit() + return True + except Exception as e: + session.rollback() + applogger.error(f"Error force-banning {sanitized_ip}: {e}") + return False + + def get_ban_overrides_paginated( + self, + page: int = 1, + page_size: int = 25, + ) -> Dict[str, Any]: + """Get all IPs with a non-null ban_override, paginated.""" + session = self.session + try: + base_query = session.query(IpStats).filter(IpStats.ban_override.isnot(None)) + total = base_query.count() + total_pages = max(1, (total + page_size - 1) // page_size) + + results = ( + base_query.order_by(IpStats.last_seen.desc()) + .offset((page - 1) * page_size) + .limit(page_size) + .all() + ) + + overrides = [] + for r in results: + overrides.append( + { + "ip": r.ip, + "ban_override": r.ban_override, + "category": r.category, + "total_requests": r.total_requests, + "country_code": r.country_code, + "city": r.city, + "last_seen": r.last_seen.isoformat() if r.last_seen else None, + } + ) + + return { + "overrides": overrides, + "pagination": { + "page": page, + "page_size": page_size, + "total": total, + "total_pages": total_pages, + }, + } + finally: + self.close_session() + # Module-level singleton instance _db_manager = DatabaseManager() diff --git a/src/models.py b/src/models.py index 8fb6e26..d759e52 100644 --- a/src/models.py +++ b/src/models.py @@ -210,6 +210,11 @@ class IpStats(Base): total_violations: Mapped[int] = mapped_column(Integer, default=0, nullable=True) ban_multiplier: Mapped[int] = mapped_column(Integer, default=1, nullable=True) + # Admin ban override: True = force ban, False = force unban, None = automatic + ban_override: Mapped[Optional[bool]] = mapped_column( + Boolean, nullable=True, default=None + ) + def __repr__(self) -> str: return f""