refactor: optimize database queries by utilizing IpStats for performance improvements

This commit is contained in:
Lorenzo Venerandi
2026-03-01 15:57:40 +01:00
parent b8f0cc25d0
commit 7401783847

View File

@@ -1384,26 +1384,20 @@ class DatabaseManager:
""" """
session = self.session session = self.session
try: try:
# Get server IP to filter it out
from config import get_config from config import get_config
config = get_config() config = get_config()
server_ip = config.get_server_ip() server_ip = config.get_server_ip()
query = session.query(IpStats.ip, IpStats.total_requests)
query = self._public_ip_filter(query, IpStats.ip, server_ip)
results = ( results = (
session.query(AccessLog.ip, func.count(AccessLog.id).label("count")) query.order_by(IpStats.total_requests.desc())
.group_by(AccessLog.ip) .limit(limit)
.order_by(func.count(AccessLog.id).desc())
.all() .all()
) )
# Filter out local/private IPs and server IP, then limit results return [(row.ip, row.total_requests) for row in results]
filtered = [
(row.ip, row.count)
for row in results
if is_valid_public_ip(row.ip, server_ip)
]
return filtered[:limit]
finally: finally:
self.close_session() self.close_session()
@@ -1470,23 +1464,18 @@ class DatabaseManager:
""" """
session = self.session session = self.session
try: try:
# Get server IP to filter it out
from config import get_config from config import get_config
config = get_config() config = get_config()
server_ip = config.get_server_ip() server_ip = config.get_server_ip()
logs = ( query = (
session.query(AccessLog) session.query(AccessLog)
.filter(AccessLog.is_suspicious == True) .filter(AccessLog.is_suspicious == True)
.order_by(AccessLog.timestamp.desc()) .order_by(AccessLog.timestamp.desc())
.all()
) )
query = self._public_ip_filter(query, AccessLog.ip, server_ip)
# Filter out local/private IPs and server IP logs = query.limit(limit).all()
filtered_logs = [
log for log in logs if is_valid_public_ip(log.ip, server_ip)
]
return [ return [
{ {
@@ -1495,7 +1484,7 @@ class DatabaseManager:
"user_agent": log.user_agent, "user_agent": log.user_agent,
"timestamp": log.timestamp.isoformat(), "timestamp": log.timestamp.isoformat(),
} }
for log in filtered_logs[:limit] for log in logs
] ]
finally: finally:
self.close_session() self.close_session()
@@ -1600,44 +1589,54 @@ class DatabaseManager:
offset = (page - 1) * page_size offset = (page - 1) * page_size
# Get honeypot triggers grouped by IP # Count distinct paths per IP using SQL GROUP BY
results = ( count_col = func.count(distinct(AccessLog.path)).label("path_count")
session.query(AccessLog.ip, AccessLog.path) base_query = (
session.query(AccessLog.ip, count_col)
.filter(AccessLog.is_honeypot_trigger == True) .filter(AccessLog.is_honeypot_trigger == True)
.all()
) )
base_query = self._public_ip_filter(base_query, AccessLog.ip, server_ip)
base_query = base_query.group_by(AccessLog.ip)
# Group paths by IP, filtering out invalid IPs # Get total count of distinct honeypot IPs
ip_paths: Dict[str, List[str]] = {} total_honeypots = base_query.count()
for row in results:
if not is_valid_public_ip(row.ip, server_ip):
continue
if row.ip not in ip_paths:
ip_paths[row.ip] = []
if row.path not in ip_paths[row.ip]:
ip_paths[row.ip].append(row.path)
# Create list and sort
honeypot_list = [
{"ip": ip, "paths": paths, "count": len(paths)}
for ip, paths in ip_paths.items()
]
# Apply sorting
if sort_by == "count": if sort_by == "count":
honeypot_list.sort( order_expr = count_col.desc() if sort_order == "desc" else count_col.asc()
key=lambda x: x["count"], reverse=(sort_order == "desc") else:
) order_expr = AccessLog.ip.desc() if sort_order == "desc" else AccessLog.ip.asc()
else: # sort by ip
honeypot_list.sort(
key=lambda x: x["ip"], reverse=(sort_order == "desc")
)
total_honeypots = len(honeypot_list) ip_rows = base_query.order_by(order_expr).offset(offset).limit(page_size).all()
paginated = honeypot_list[offset : offset + page_size]
total_pages = (total_honeypots + page_size - 1) // page_size # Fetch distinct paths only for the paginated IPs
paginated_ips = [row.ip for row in ip_rows]
honeypot_list = []
if paginated_ips:
path_rows = (
session.query(AccessLog.ip, AccessLog.path)
.filter(
AccessLog.is_honeypot_trigger == True,
AccessLog.ip.in_(paginated_ips),
)
.distinct(AccessLog.ip, AccessLog.path)
.all()
)
ip_paths: Dict[str, List[str]] = {}
for row in path_rows:
ip_paths.setdefault(row.ip, []).append(row.path)
# Preserve the order from the sorted query
for row in ip_rows:
paths = ip_paths.get(row.ip, [])
honeypot_list.append(
{"ip": row.ip, "paths": paths, "count": row.path_count}
)
total_pages = max(1, (total_honeypots + page_size - 1) // page_size)
return { return {
"honeypots": paginated, "honeypots": honeypot_list,
"pagination": { "pagination": {
"page": page, "page": page,
"page_size": page_size, "page_size": page_size,
@@ -1736,6 +1735,9 @@ class DatabaseManager:
""" """
Retrieve paginated list of top IP addresses by access count. Retrieve paginated list of top IP addresses by access count.
Uses the IpStats table (which already stores total_requests per IP)
instead of doing a costly GROUP BY on the large access_logs table.
Args: Args:
page: Page number (1-indexed) page: Page number (1-indexed)
page_size: Number of results per page page_size: Number of results per page
@@ -1754,39 +1756,34 @@ class DatabaseManager:
offset = (page - 1) * page_size offset = (page - 1) * page_size
results = ( base_query = session.query(IpStats)
session.query( base_query = self._public_ip_filter(base_query, IpStats.ip, server_ip)
AccessLog.ip,
func.count(AccessLog.id).label("count"),
IpStats.category,
)
.outerjoin(IpStats, AccessLog.ip == IpStats.ip)
.group_by(AccessLog.ip, IpStats.category)
.all()
)
# Filter out local/private IPs and server IP, then sort total_ips = base_query.count()
filtered = [
{
"ip": row.ip,
"count": row.count,
"category": row.category or "unknown",
}
for row in results
if is_valid_public_ip(row.ip, server_ip)
]
if sort_by == "count": if sort_by == "count":
filtered.sort(key=lambda x: x["count"], reverse=(sort_order == "desc")) order_col = IpStats.total_requests
else: # sort by ip else:
filtered.sort(key=lambda x: x["ip"], reverse=(sort_order == "desc")) order_col = IpStats.ip
total_ips = len(filtered) if sort_order == "desc":
paginated = filtered[offset : offset + page_size] base_query = base_query.order_by(order_col.desc())
total_pages = (total_ips + page_size - 1) // page_size else:
base_query = base_query.order_by(order_col.asc())
results = base_query.offset(offset).limit(page_size).all()
total_pages = max(1, (total_ips + page_size - 1) // page_size)
return { return {
"ips": paginated, "ips": [
{
"ip": row.ip,
"count": row.total_requests,
"category": row.category or "unknown",
}
for row in results
],
"pagination": { "pagination": {
"page": page, "page": page,
"page_size": page_size, "page_size": page_size,
@@ -1820,28 +1817,27 @@ class DatabaseManager:
try: try:
offset = (page - 1) * page_size offset = (page - 1) * page_size
results = ( count_col = func.count(AccessLog.id).label("count")
session.query(AccessLog.path, func.count(AccessLog.id).label("count"))
# Get total number of distinct paths
total_paths = session.query(func.count(distinct(AccessLog.path))).scalar() or 0
# Build query with SQL-level sorting and pagination
query = (
session.query(AccessLog.path, count_col)
.group_by(AccessLog.path) .group_by(AccessLog.path)
.all()
) )
# Create list and sort
paths_list = [{"path": row.path, "count": row.count} for row in results]
if sort_by == "count": if sort_by == "count":
paths_list.sort( order_expr = count_col.desc() if sort_order == "desc" else count_col.asc()
key=lambda x: x["count"], reverse=(sort_order == "desc") else:
) order_expr = AccessLog.path.desc() if sort_order == "desc" else AccessLog.path.asc()
else: # sort by path
paths_list.sort(key=lambda x: x["path"], reverse=(sort_order == "desc"))
total_paths = len(paths_list) results = query.order_by(order_expr).offset(offset).limit(page_size).all()
paginated = paths_list[offset : offset + page_size] total_pages = max(1, (total_paths + page_size - 1) // page_size)
total_pages = (total_paths + page_size - 1) // page_size
return { return {
"paths": paginated, "paths": [{"path": row.path, "count": row.count} for row in results],
"pagination": { "pagination": {
"page": page, "page": page,
"page_size": page_size, "page_size": page_size,
@@ -1875,33 +1871,40 @@ class DatabaseManager:
try: try:
offset = (page - 1) * page_size offset = (page - 1) * page_size
results = ( count_col = func.count(AccessLog.id).label("count")
session.query(
AccessLog.user_agent, func.count(AccessLog.id).label("count") base_filter = [AccessLog.user_agent.isnot(None), AccessLog.user_agent != ""]
)
.filter(AccessLog.user_agent.isnot(None), AccessLog.user_agent != "") # Get total number of distinct user agents
.group_by(AccessLog.user_agent) total_uas = (
.all() session.query(func.count(distinct(AccessLog.user_agent)))
.filter(*base_filter)
.scalar() or 0
) )
# Create list and sort # Build query with SQL-level sorting and pagination
ua_list = [ query = (
{"user_agent": row.user_agent, "count": row.count} for row in results session.query(AccessLog.user_agent, count_col)
] .filter(*base_filter)
.group_by(AccessLog.user_agent)
)
if sort_by == "count": if sort_by == "count":
ua_list.sort(key=lambda x: x["count"], reverse=(sort_order == "desc")) order_expr = count_col.desc() if sort_order == "desc" else count_col.asc()
else: # sort by user_agent else:
ua_list.sort( order_expr = (
key=lambda x: x["user_agent"], reverse=(sort_order == "desc") AccessLog.user_agent.desc() if sort_order == "desc"
else AccessLog.user_agent.asc()
) )
total_uas = len(ua_list) results = query.order_by(order_expr).offset(offset).limit(page_size).all()
paginated = ua_list[offset : offset + page_size] total_pages = max(1, (total_uas + page_size - 1) // page_size)
total_pages = (total_uas + page_size - 1) // page_size
return { return {
"user_agents": paginated, "user_agents": [
{"user_agent": row.user_agent, "count": row.count}
for row in results
],
"pagination": { "pagination": {
"page": page, "page": page,
"page_size": page_size, "page_size": page_size,