feat: add ban override management to IP statistics model and database manager

This commit is contained in:
Lorenzo Venerandi
2026-03-08 12:26:02 +01:00
parent 53c23d2d18
commit 8a651b00f9
2 changed files with 102 additions and 0 deletions

View File

@@ -2231,6 +2231,103 @@ class DatabaseManager:
finally: finally:
self.close_session() self.close_session()
# ── Ban Override Management ──────────────────────────────────────────
def set_ban_override(self, ip: str, override: Optional[bool]) -> bool:
"""
Set ban override for an IP.
override=True: force into banlist
override=False: force remove from banlist
override=None: reset to automatic (category-based)
Returns True if the IP exists and was updated.
"""
session = self.session
sanitized_ip = sanitize_ip(ip)
ip_stats = session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
if not ip_stats:
return False
ip_stats.ban_override = override
try:
session.commit()
return True
except Exception as e:
session.rollback()
applogger.error(f"Error setting ban override for {sanitized_ip}: {e}")
return False
def force_ban_ip(self, ip: str) -> bool:
"""
Force-ban an IP that may not exist in ip_stats yet.
Creates a minimal entry if needed.
"""
session = self.session
sanitized_ip = sanitize_ip(ip)
ip_stats = session.query(IpStats).filter(IpStats.ip == sanitized_ip).first()
if not ip_stats:
ip_stats = IpStats(
ip=sanitized_ip,
total_requests=0,
first_seen=datetime.utcnow(),
last_seen=datetime.utcnow(),
)
session.add(ip_stats)
ip_stats.ban_override = True
try:
session.commit()
return True
except Exception as e:
session.rollback()
applogger.error(f"Error force-banning {sanitized_ip}: {e}")
return False
def get_ban_overrides_paginated(
self,
page: int = 1,
page_size: int = 25,
) -> Dict[str, Any]:
"""Get all IPs with a non-null ban_override, paginated."""
session = self.session
try:
base_query = session.query(IpStats).filter(IpStats.ban_override.isnot(None))
total = base_query.count()
total_pages = max(1, (total + page_size - 1) // page_size)
results = (
base_query.order_by(IpStats.last_seen.desc())
.offset((page - 1) * page_size)
.limit(page_size)
.all()
)
overrides = []
for r in results:
overrides.append(
{
"ip": r.ip,
"ban_override": r.ban_override,
"category": r.category,
"total_requests": r.total_requests,
"country_code": r.country_code,
"city": r.city,
"last_seen": r.last_seen.isoformat() if r.last_seen else None,
}
)
return {
"overrides": overrides,
"pagination": {
"page": page,
"page_size": page_size,
"total": total,
"total_pages": total_pages,
},
}
finally:
self.close_session()
# Module-level singleton instance # Module-level singleton instance
_db_manager = DatabaseManager() _db_manager = DatabaseManager()

View File

@@ -210,6 +210,11 @@ class IpStats(Base):
total_violations: Mapped[int] = mapped_column(Integer, default=0, nullable=True) total_violations: Mapped[int] = mapped_column(Integer, default=0, nullable=True)
ban_multiplier: Mapped[int] = mapped_column(Integer, default=1, nullable=True) ban_multiplier: Mapped[int] = mapped_column(Integer, default=1, nullable=True)
# Admin ban override: True = force ban, False = force unban, None = automatic
ban_override: Mapped[Optional[bool]] = mapped_column(
Boolean, nullable=True, default=None
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<IpStats(ip='{self.ip}', total_requests={self.total_requests})>" return f"<IpStats(ip='{self.ip}', total_requests={self.total_requests})>"