diff --git a/owasp.py b/owasp.py
index 7192186..1780b14 100644
--- a/owasp.py
+++ b/owasp.py
@@ -5,11 +5,17 @@ import json
 import base64
 import hashlib
 import logging
-import requests
+import argparse
 from typing import List, Dict, Optional
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import requests
+from tqdm import tqdm
 
 # Logging setup
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
 
 # Constants
 GITHUB_REPO_URL = "https://api.github.com/repos/coreruleset/coreruleset"
@@ -21,12 +27,21 @@ MAX_RETRIES = 6  # Maximum number of retries
 EXPONENTIAL_BACKOFF = True  # Use exponential backoff for retries
 BACKOFF_MULTIPLIER = 2  # Multiplier for exponential backoff
 GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")  # GitHub token for authentication
+CONNECTION_POOL_SIZE = 20  # Increased connection pool size
 
 
 class GitHubRequestError(Exception):
     """Raised when fetching data from GitHub fails after all retries."""
 
 
+class GitHubRateLimitError(GitHubRequestError):
+    """Raised when GitHub API rate limit is exceeded."""
+
+
+class GitHubBlobFetchError(GitHubRequestError):
+    """Raised when fetching a blob from GitHub fails."""
+
+
 def get_session() -> requests.Session:
     """
     Creates and returns a requests.Session with optional GitHub token authentication.
@@ -34,6 +49,9 @@ def get_session() -> requests.Session:
     session = requests.Session()
     if GITHUB_TOKEN:
         session.headers.update({"Authorization": f"token {GITHUB_TOKEN}"})
+    # Increase connection pool size
+    adapter = requests.adapters.HTTPAdapter(pool_connections=CONNECTION_POOL_SIZE, pool_maxsize=CONNECTION_POOL_SIZE)
+    session.mount("https://", adapter)
     return session
 
 
@@ -49,24 +67,20 @@ def fetch_with_retries(session: requests.Session, url: str) -> requests.Response
             if response.status_code == 403 and "X-RateLimit-Remaining" in response.headers:
                 reset_time = int(response.headers.get("X-RateLimit-Reset", 0))
                 wait_time = max(reset_time - int(time.time()), RATE_LIMIT_DELAY)
-                logging.warning(f"Rate limit exceeded. Retrying in {wait_time} seconds...")
+                logger.warning(f"Rate limit exceeded. Retrying in {wait_time} seconds...")
                 time.sleep(wait_time)
                 continue
 
-            try:
-                response.raise_for_status()
-                return response
-            except requests.HTTPError:
-                # Handle non-200 codes that are not rate-limit related
-                pass
-
-            # Retry logic for other errors
+            response.raise_for_status()
+            return response
+        except requests.HTTPError as e:
+            logger.warning(f"HTTP error fetching {url}: {e}")
             wait_time = (RETRY_DELAY * (BACKOFF_MULTIPLIER ** retries) if EXPONENTIAL_BACKOFF else RETRY_DELAY)
-            logging.warning(f"Retrying {url}... ({retries + 1}/{MAX_RETRIES}) in {wait_time} seconds.")
+            logger.warning(f"Retrying {url}... ({retries + 1}/{MAX_RETRIES}) in {wait_time} seconds.")
             time.sleep(wait_time)
             retries += 1
         except requests.RequestException as e:
-            logging.error(f"Error fetching {url}: {e}")
+            logger.error(f"Error fetching {url}: {e}")
             retries += 1
 
     raise GitHubRequestError(f"Failed to fetch {url} after {MAX_RETRIES} retries.")
@@ -82,18 +96,18 @@ def fetch_latest_tag(session: requests.Session, ref_prefix: str) -> Optional[str
         response = fetch_with_retries(session, ref_url)
         tags = response.json()
         if not tags:
-            logging.warning("No tags found in the repository.")
+            logger.warning("No tags found in the repository.")
             return None
 
         matching = [r["ref"] for r in tags if r["ref"].startswith(f"refs/tags/{ref_prefix}.")]
         matching.sort(reverse=True, key=lambda x: x.split(".")[-1])
         if matching:
             latest_tag = matching[0]
-            logging.info(f"Latest matching tag: {latest_tag}")
+            logger.info(f"Latest matching tag: {latest_tag}")
             return latest_tag
 
-        logging.warning(f"No matching refs found for prefix {ref_prefix}. Falling back to the latest tag.")
+        logger.warning(f"No matching refs found for prefix {ref_prefix}. Falling back to the latest tag.")
         return tags[-1]["ref"]
     except Exception as e:
-        logging.error(f"Failed to fetch tags. Reason: {e}")
+        logger.error(f"Failed to fetch tags. Reason: {e}")
         return None
@@ -109,7 +123,7 @@ def fetch_rule_files(session: requests.Session, ref: str) -> List[Dict[str, str]
         files = response.json()
         return [{"name": f["name"], "sha": f["sha"]} for f in files if f["name"].endswith(".conf")]
     except (GitHubRequestError, requests.RequestException) as e:
-        logging.error(f"Failed to fetch rule files from {rules_url}. Reason: {e}")
+        logger.error(f"Failed to fetch rule files from {rules_url}. Reason: {e}")
         return []
 
 
@@ -123,7 +137,7 @@ def fetch_github_blob(session: requests.Session, sha: str) -> str:
         response = fetch_with_retries(session, blob_url)
         return response.json().get("content", "")
     except (GitHubRequestError, requests.RequestException) as e:
-        logging.error(f"Failed to fetch blob for SHA {sha}. Reason: {e}")
+        logger.error(f"Failed to fetch blob for SHA {sha}. Reason: {e}")
         return ""
 
 
@@ -133,14 +147,11 @@ def verify_blob_sha(file_sha: str, blob_content_b64: str) -> bool:
     Logs a warning if the verification fails but does not block execution.
     """
    decoded_bytes = base64.b64decode(blob_content_b64)
-    # Option 1: Verify Git’s actual blob SHA (header + content)
     blob_header = f"blob {len(decoded_bytes)}\0".encode("utf-8")
     calculated_sha = hashlib.sha1(blob_header + decoded_bytes).hexdigest()
 
     if calculated_sha != file_sha:
-        logging.warning(
-            f"SHA mismatch for file. Expected: {file_sha}, Calculated: {calculated_sha}"
-        )
+        logger.warning(f"SHA mismatch for file. Expected: {file_sha}, Calculated: {calculated_sha}")
         return False
     return True
 
@@ -151,25 +162,32 @@ def fetch_owasp_rules(session: requests.Session, rule_files: List[Dict[str, str]
     """
     Fetches the OWASP rules from the given rule files
     and returns a list of dicts with category and pattern.
""" rules = [] - for file in rule_files: - logging.info(f"Fetching {file['name']}...") - blob_b64 = fetch_github_blob(session, file["sha"]) - if not blob_b64: - logging.warning(f"Skipping file {file['name']} due to empty blob content.") - continue + with ThreadPoolExecutor(max_workers=CONNECTION_POOL_SIZE) as executor: + futures = { + executor.submit(fetch_github_blob, session, file["sha"]): file for file in rule_files + } + for future in tqdm(as_completed(futures), total=len(rule_files), desc="Fetching rule files"): + file = futures[future] + try: + blob_b64 = future.result() + if not blob_b64: + logger.warning(f"Skipping file {file['name']} due to empty blob content.") + continue - # Verify SHA (non-blocking) - verify_blob_sha(file["sha"], blob_b64) + # Verify SHA (non-blocking) + verify_blob_sha(file["sha"], blob_b64) - raw_text = base64.b64decode(blob_b64).decode("utf-8") - sec_rules = re.findall(r'SecRule\s+.*?"((?:[^"\\]|\\.)+?)"', raw_text, re.DOTALL) - category = file["name"].split("-")[-1].replace(".conf", "") - for rule in sec_rules: - pattern = rule.strip().replace("\\", "") - if pattern: - rules.append({"category": category, "pattern": pattern}) + raw_text = base64.b64decode(blob_b64).decode("utf-8") + sec_rules = re.findall(r'SecRule\s+.*?"((?:[^"\\]|\\.)+?)"', raw_text, re.DOTALL) + category = file["name"].split("-")[-1].replace(".conf", "") + for rule in sec_rules: + pattern = rule.strip().replace("\\", "") + if pattern: + rules.append({"category": category, "pattern": pattern}) + except Exception as e: + logger.error(f"Failed to process file {file['name']}. Reason: {e}") - logging.info(f"Fetched {len(rules)} rules.") + logger.info(f"Fetched {len(rules)} rules.") return rules @@ -178,30 +196,47 @@ def save_as_json(rules: List[Dict[str, str]], output_file: str) -> bool: Saves the given list of rules to a JSON file. Returns True if successful, False otherwise. """ try: - output_dir = os.path.dirname(output_file) + output_dir = Path(output_file).parent if output_dir: - os.makedirs(output_dir, exist_ok=True) - with open(output_file, "w", encoding="utf-8") as f: + output_dir.mkdir(parents=True, exist_ok=True) + # Atomic write using a temporary file + temp_file = f"{output_file}.tmp" + with open(temp_file, "w", encoding="utf-8") as f: json.dump(rules, f, indent=4) - logging.info(f"Rules saved to {output_file}.") + # Rename temp file to the final output file + os.replace(temp_file, output_file) + logger.info(f"Rules saved to {output_file}.") return True except IOError as e: - logging.error(f"Failed to save rules to {output_file}. Reason: {e}") + logger.error(f"Failed to save rules to {output_file}. 
Reason: {e}") return False -if __name__ == "__main__": +def main(): + """Main function to fetch and save OWASP rules.""" + parser = argparse.ArgumentParser(description="Fetch OWASP Core Rule Set rules from GitHub.") + parser.add_argument("--output", type=str, default="owasp_rules.json", help="Output JSON file path.") + parser.add_argument("--ref", type=str, default=GITHUB_REF, help="Git reference (e.g., tag or branch).") + parser.add_argument("--dry-run", action="store_true", help="Simulate fetching without saving.") + args = parser.parse_args() + session = get_session() - latest_ref = fetch_latest_tag(session, GITHUB_REF) + latest_ref = fetch_latest_tag(session, args.ref) if latest_ref: rule_files = fetch_rule_files(session, latest_ref) if rule_files: rules = fetch_owasp_rules(session, rule_files) - if rules and save_as_json(rules, "owasp_rules.json"): - logging.info("All rules fetched and saved successfully.") + if args.dry_run: + logger.info("Dry-run mode enabled. Skipping file save.") + elif rules and save_as_json(rules, args.output): + logger.info("All rules fetched and saved successfully.") else: - logging.error("Failed to fetch or save rules.") + logger.error("Failed to fetch or save rules.") else: - logging.error("Failed to fetch rule files.") + logger.error("Failed to fetch rule files.") else: - logging.error("Failed to fetch tags.") + logger.error("Failed to fetch tags.") + + +if __name__ == "__main__": + main()