diff --git a/src/config.py b/src/config.py index 71cef0e..97b4d70 100644 --- a/src/config.py +++ b/src/config.py @@ -43,6 +43,7 @@ class Config: # Database settings database_path: str = "data/krawl.db" database_retention_days: int = 30 + exports_path: str = "data/exports" # Analyzer settings http_risky_methods_threshold: float = None @@ -153,6 +154,7 @@ class Config: canary = data.get("canary", {}) dashboard = data.get("dashboard", {}) api = data.get("api", {}) + exports = data.get("exports", {}) database = data.get("database", {}) behavior = data.get("behavior", {}) analyzer = data.get("analyzer") or {} @@ -191,6 +193,7 @@ class Config: api_server_port=api.get("server_port", 8080), api_server_path=api.get("server_path", "/api/v2/users"), probability_error_codes=behavior.get("probability_error_codes", 0), + exports_path = exports.get("path"), database_path=database.get("path", "data/krawl.db"), database_retention_days=database.get("retention_days", 30), http_risky_methods_threshold=analyzer.get( diff --git a/src/handler.py b/src/handler.py index b3c76e7..ab1f715 100644 --- a/src/handler.py +++ b/src/handler.py @@ -7,8 +7,11 @@ from datetime import datetime from http.server import BaseHTTPRequestHandler from typing import Optional, List from urllib.parse import urlparse, parse_qs +import json +import os -from config import Config +from database import get_database +from config import Config,get_config from tracker import AccessTracker from analyzer import Analyzer from templates import html_templates @@ -26,6 +29,9 @@ from wordlists import get_wordlists from sql_errors import generate_sql_error_response, get_sql_response_with_data from xss_detector import detect_xss_pattern, generate_xss_response from server_errors import generate_server_error +from models import AccessLog +from ip_utils import is_valid_public_ip +from sqlalchemy import distinct class Handler(BaseHTTPRequestHandler): @@ -58,10 +64,6 @@ class Handler(BaseHTTPRequestHandler): # Fallback to direct 
connection IP return self.client_address[0] - def _get_user_agent(self) -> str: - """Extract user agent from request""" - return self.headers.get("User-Agent", "") - def _get_category_by_ip(self, client_ip: str) -> str: """Get the category of an IP from the database""" return self.tracker.get_category_by_ip(client_ip) @@ -92,10 +94,6 @@ class Handler(BaseHTTPRequestHandler): error_codes = [400, 401, 403, 404, 500, 502, 503] return random.choice(error_codes) - def _parse_query_string(self) -> str: - """Extract query string from the request path""" - parsed = urlparse(self.path) - return parsed.query def _handle_sql_endpoint(self, path: str) -> bool: """ @@ -111,21 +109,20 @@ class Handler(BaseHTTPRequestHandler): try: # Get query parameters - query_string = self._parse_query_string() # Log SQL injection attempt client_ip = self._get_client_ip() - user_agent = self._get_user_agent() + user_agent = self.headers.get("User-Agent", "") # Always check for SQL injection patterns error_msg, content_type, status_code = generate_sql_error_response( - query_string or "" + request_query or "" ) if error_msg: # SQL injection detected - log and return error self.access_logger.warning( - f"[SQL INJECTION DETECTED] {client_ip} - {base_path} - Query: {query_string[:100] if query_string else 'empty'}" + f"[SQL INJECTION DETECTED] {client_ip} - {base_path} - Query: {request_query[:100] if request_query else 'empty'}" ) self.send_response(status_code) self.send_header("Content-type", content_type) @@ -134,13 +131,13 @@ class Handler(BaseHTTPRequestHandler): else: # No injection detected - return fake data self.access_logger.info( - f"[SQL ENDPOINT] {client_ip} - {base_path} - Query: {query_string[:100] if query_string else 'empty'}" + f"[SQL ENDPOINT] {client_ip} - {base_path} - Query: {request_query[:100] if request_query else 'empty'}" ) self.send_response(200) self.send_header("Content-type", "application/json") self.end_headers() response_data = get_sql_response_with_data( - 
base_path, query_string or "" + base_path, request_query or "" ) self.wfile.write(response_data.encode()) @@ -239,10 +236,9 @@ class Handler(BaseHTTPRequestHandler): def do_POST(self): """Handle POST requests (mainly login attempts)""" client_ip = self._get_client_ip() - user_agent = self._get_user_agent() + user_agent = self.headers.get("User-Agent", "") post_data = "" - from urllib.parse import urlparse base_path = urlparse(self.path).path @@ -293,7 +289,6 @@ class Handler(BaseHTTPRequestHandler): for pair in post_data.split("&"): if "=" in pair: key, value = pair.split("=", 1) - from urllib.parse import unquote_plus parsed_data[unquote_plus(key)] = unquote_plus(value) @@ -486,12 +481,25 @@ class Handler(BaseHTTPRequestHandler): def do_GET(self): """Responds to webpage requests""" + client_ip = self._get_client_ip() + + # respond with HTTP error code if client is banned if self.tracker.is_banned_ip(client_ip): self.send_response(500) self.end_headers() return - user_agent = self._get_user_agent() + + # get request data + user_agent = self.headers.get("User-Agent", "") + request_path = urlparse(self.path).path + self.app_logger.info(f"request_query: {request_path}") + query_params = parse_qs(urlparse(self.path).query) + self.app_logger.info(f"query_params: {query_params}") + + # get database reference + db = get_database() + session = db.session if ( self.config.dashboard_secret_path @@ -502,8 +510,7 @@ class Handler(BaseHTTPRequestHandler): self.end_headers() try: stats = self.tracker.get_stats() - dashboard_path = self.config.dashboard_secret_path - self.wfile.write(generate_dashboard(stats, dashboard_path).encode()) + self.wfile.write(generate_dashboard(stats, self.config.dashboard_secret_path).encode()) except BrokenPipeError: pass except Exception as e: @@ -525,10 +532,7 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - db = get_database() ip_stats_list = 
db.get_ip_stats(limit=500) self.wfile.write(json.dumps({"ips": ip_stats_list}).encode()) except BrokenPipeError: @@ -552,15 +556,8 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - from urllib.parse import urlparse, parse_qs - db = get_database() - # Parse query parameters - parsed_url = urlparse(self.path) - query_params = parse_qs(parsed_url.query) page = int(query_params.get("page", ["1"])[0]) page_size = int(query_params.get("page_size", ["25"])[0]) sort_by = query_params.get("sort_by", ["total_requests"])[0] @@ -598,11 +595,7 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - from urllib.parse import urlparse, parse_qs - db = get_database() # Parse query parameters parsed_url = urlparse(self.path) @@ -648,10 +641,7 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - db = get_database() ip_stats = db.get_ip_stats_by_ip(ip_address) if ip_stats: self.wfile.write(json.dumps(ip_stats).encode()) @@ -678,11 +668,7 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - from urllib.parse import urlparse, parse_qs - db = get_database() parsed_url = urlparse(self.path) query_params = parse_qs(parsed_url.query) page = int(query_params.get("page", ["1"])[0]) @@ -721,11 +707,7 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - from urllib.parse import urlparse, parse_qs - db = get_database() parsed_url = urlparse(self.path) query_params = parse_qs(parsed_url.query) page = int(query_params.get("page", ["1"])[0]) @@ -764,11 +746,7 @@ class Handler(BaseHTTPRequestHandler): 
self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - from urllib.parse import urlparse, parse_qs - db = get_database() parsed_url = urlparse(self.path) query_params = parse_qs(parsed_url.query) page = int(query_params.get("page", ["1"])[0]) @@ -782,7 +760,7 @@ class Handler(BaseHTTPRequestHandler): result = db.get_top_ips_paginated( page=page, page_size=page_size, - sort_by=sort_by, + sort_by=sort_by, sort_order=sort_order, ) self.wfile.write(json.dumps(result).encode()) @@ -807,11 +785,7 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - from urllib.parse import urlparse, parse_qs - db = get_database() parsed_url = urlparse(self.path) query_params = parse_qs(parsed_url.query) page = int(query_params.get("page", ["1"])[0]) @@ -850,11 +824,7 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - from urllib.parse import urlparse, parse_qs - db = get_database() parsed_url = urlparse(self.path) query_params = parse_qs(parsed_url.query) page = int(query_params.get("page", ["1"])[0]) @@ -893,11 +863,7 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Expires", "0") self.end_headers() try: - from database import get_database - import json - from urllib.parse import urlparse, parse_qs - db = get_database() parsed_url = urlparse(self.path) query_params = parse_qs(parsed_url.query) page = int(query_params.get("page", ["1"])[0]) @@ -922,13 +888,35 @@ class Handler(BaseHTTPRequestHandler): self.wfile.write(json.dumps({"error": str(e)}).encode()) return + # API endpoint for downloading malicious IPs blocklist file + if ( + self.config.dashboard_secret_path and + request_path == f"{self.config.dashboard_secret_path}/api/get_banlist" + ): + + + fwtype = query_params.get("fwtype",["iptables"])[0] + # Query distinct
suspicious IPs + results = ( + session.query(distinct(AccessLog.ip)) + .filter(AccessLog.is_suspicious == True) + .all() + ) + + # Filter out local/private IPs and the server's own IP + config = get_config() + server_ip = config.get_server_ip() + + public_ips = [ip for (ip,) in results if is_valid_public_ip(ip, server_ip)] + self.wfile.write(f"asdasdd {fwtype} {public_ips}".encode()) + return + # API endpoint for downloading malicious IPs file if ( self.config.dashboard_secret_path and self.path == f"{self.config.dashboard_secret_path}/api/download/malicious_ips.txt" ): - import os file_path = os.path.join( os.path.dirname(__file__), "exports", "malicious_ips.txt"