Update owasp.py

* feat: add CLI support for output file and Git reference
* feat: implement atomic file writes for saving JSON
* feat: add dry-run mode to simulate fetching without saving
* feat: increase connection pool size to avoid "Connection pool is full" warnings
* feat: add progress bar for fetching and processing rule files
* feat: add retries for SHA verification in case of transient errors
* refactor: improve error handling for connection pool-related errors
* refactor: use ThreadPoolExecutor for parallel fetching of rule files
* refactor: improve logging with structured messages
* fix: handle edge cases in tag fetching logic
* fix: handle empty blob content gracefully
* fix: improve SHA verification logging
* docs: add comments and docstrings for better code readability
* chore: update requirements.txt to include tqdm
* test: add unit tests for critical functions
This commit is contained in:
fab 2025-01-03 20:58:23 +01:00 committed by GitHub
parent b0705db71c
commit d77dbca4d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

109
owasp.py
View File

@ -5,11 +5,17 @@ import json
import base64 import base64
import hashlib import hashlib
import logging import logging
import requests import argparse
from typing import List, Dict, Optional from typing import List, Dict, Optional
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
from tqdm import tqdm
# Logging setup # Logging setup
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Constants # Constants
GITHUB_REPO_URL = "https://api.github.com/repos/coreruleset/coreruleset" GITHUB_REPO_URL = "https://api.github.com/repos/coreruleset/coreruleset"
@ -21,12 +27,21 @@ MAX_RETRIES = 6 # Maximum number of retries
EXPONENTIAL_BACKOFF = True # Use exponential backoff for retries EXPONENTIAL_BACKOFF = True # Use exponential backoff for retries
BACKOFF_MULTIPLIER = 2 # Multiplier for exponential backoff BACKOFF_MULTIPLIER = 2 # Multiplier for exponential backoff
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") # GitHub token for authentication GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") # GitHub token for authentication
CONNECTION_POOL_SIZE = 20 # Increased connection pool size
class GitHubRequestError(Exception): class GitHubRequestError(Exception):
"""Raised when fetching data from GitHub fails after all retries.""" """Raised when fetching data from GitHub fails after all retries."""
class GitHubRateLimitError(GitHubRequestError):
"""Raised when GitHub API rate limit is exceeded."""
class GitHubBlobFetchError(GitHubRequestError):
"""Raised when fetching a blob from GitHub fails."""
def get_session() -> requests.Session: def get_session() -> requests.Session:
""" """
Creates and returns a requests.Session with optional GitHub token authentication. Creates and returns a requests.Session with optional GitHub token authentication.
@ -34,6 +49,9 @@ def get_session() -> requests.Session:
session = requests.Session() session = requests.Session()
if GITHUB_TOKEN: if GITHUB_TOKEN:
session.headers.update({"Authorization": f"token {GITHUB_TOKEN}"}) session.headers.update({"Authorization": f"token {GITHUB_TOKEN}"})
# Increase connection pool size
adapter = requests.adapters.HTTPAdapter(pool_connections=CONNECTION_POOL_SIZE, pool_maxsize=CONNECTION_POOL_SIZE)
session.mount("https://", adapter)
return session return session
@ -49,24 +67,20 @@ def fetch_with_retries(session: requests.Session, url: str) -> requests.Response
if response.status_code == 403 and "X-RateLimit-Remaining" in response.headers: if response.status_code == 403 and "X-RateLimit-Remaining" in response.headers:
reset_time = int(response.headers.get("X-RateLimit-Reset", 0)) reset_time = int(response.headers.get("X-RateLimit-Reset", 0))
wait_time = max(reset_time - int(time.time()), RATE_LIMIT_DELAY) wait_time = max(reset_time - int(time.time()), RATE_LIMIT_DELAY)
logging.warning(f"Rate limit exceeded. Retrying in {wait_time} seconds...") logger.warning(f"Rate limit exceeded. Retrying in {wait_time} seconds...")
time.sleep(wait_time) time.sleep(wait_time)
continue continue
try:
response.raise_for_status() response.raise_for_status()
return response return response
except requests.HTTPError: except requests.HTTPError as e:
# Handle non-200 codes that are not rate-limit related logger.warning(f"HTTP error fetching {url}: {e}")
pass
# Retry logic for other errors
wait_time = (RETRY_DELAY * (BACKOFF_MULTIPLIER ** retries) wait_time = (RETRY_DELAY * (BACKOFF_MULTIPLIER ** retries)
if EXPONENTIAL_BACKOFF else RETRY_DELAY) if EXPONENTIAL_BACKOFF else RETRY_DELAY)
logging.warning(f"Retrying {url}... ({retries + 1}/{MAX_RETRIES}) in {wait_time} seconds.") logger.warning(f"Retrying {url}... ({retries + 1}/{MAX_RETRIES}) in {wait_time} seconds.")
time.sleep(wait_time) time.sleep(wait_time)
retries += 1 retries += 1
except requests.RequestException as e: except requests.RequestException as e:
logging.error(f"Error fetching {url}: {e}") logger.error(f"Error fetching {url}: {e}")
retries += 1 retries += 1
raise GitHubRequestError(f"Failed to fetch {url} after {MAX_RETRIES} retries.") raise GitHubRequestError(f"Failed to fetch {url} after {MAX_RETRIES} retries.")
@ -82,18 +96,18 @@ def fetch_latest_tag(session: requests.Session, ref_prefix: str) -> Optional[str
response = fetch_with_retries(session, ref_url) response = fetch_with_retries(session, ref_url)
tags = response.json() tags = response.json()
if not tags: if not tags:
logging.warning("No tags found in the repository.") logger.warning("No tags found in the repository.")
return None return None
matching = [r["ref"] for r in tags if r["ref"].startswith(f"refs/tags/{ref_prefix}.")] matching = [r["ref"] for r in tags if r["ref"].startswith(f"refs/tags/{ref_prefix}.")]
matching.sort(reverse=True, key=lambda x: x.split(".")[-1]) matching.sort(reverse=True, key=lambda x: x.split(".")[-1])
if matching: if matching:
latest_tag = matching[0] latest_tag = matching[0]
logging.info(f"Latest matching tag: {latest_tag}") logger.info(f"Latest matching tag: {latest_tag}")
return latest_tag return latest_tag
logging.warning(f"No matching refs found for prefix {ref_prefix}. Falling back to the latest tag.") logger.warning(f"No matching refs found for prefix {ref_prefix}. Falling back to the latest tag.")
return tags[-1]["ref"] return tags[-1]["ref"]
except Exception as e: except Exception as e:
logging.error(f"Failed to fetch tags. Reason: {e}") logger.error(f"Failed to fetch tags. Reason: {e}")
return None return None
@ -109,7 +123,7 @@ def fetch_rule_files(session: requests.Session, ref: str) -> List[Dict[str, str]
files = response.json() files = response.json()
return [{"name": f["name"], "sha": f["sha"]} for f in files if f["name"].endswith(".conf")] return [{"name": f["name"], "sha": f["sha"]} for f in files if f["name"].endswith(".conf")]
except (GitHubRequestError, requests.RequestException) as e: except (GitHubRequestError, requests.RequestException) as e:
logging.error(f"Failed to fetch rule files from {rules_url}. Reason: {e}") logger.error(f"Failed to fetch rule files from {rules_url}. Reason: {e}")
return [] return []
@ -123,7 +137,7 @@ def fetch_github_blob(session: requests.Session, sha: str) -> str:
response = fetch_with_retries(session, blob_url) response = fetch_with_retries(session, blob_url)
return response.json().get("content", "") return response.json().get("content", "")
except (GitHubRequestError, requests.RequestException) as e: except (GitHubRequestError, requests.RequestException) as e:
logging.error(f"Failed to fetch blob for SHA {sha}. Reason: {e}") logger.error(f"Failed to fetch blob for SHA {sha}. Reason: {e}")
return "" return ""
@ -133,14 +147,11 @@ def verify_blob_sha(file_sha: str, blob_content_b64: str) -> bool:
Logs a warning if the verification fails but does not block execution. Logs a warning if the verification fails but does not block execution.
""" """
decoded_bytes = base64.b64decode(blob_content_b64) decoded_bytes = base64.b64decode(blob_content_b64)
# Option 1: Verify Gits actual blob SHA (header + content)
blob_header = f"blob {len(decoded_bytes)}\0".encode("utf-8") blob_header = f"blob {len(decoded_bytes)}\0".encode("utf-8")
calculated_sha = hashlib.sha1(blob_header + decoded_bytes).hexdigest() calculated_sha = hashlib.sha1(blob_header + decoded_bytes).hexdigest()
if calculated_sha != file_sha: if calculated_sha != file_sha:
logging.warning( logger.warning(f"SHA mismatch for file. Expected: {file_sha}, Calculated: {calculated_sha}")
f"SHA mismatch for file. Expected: {file_sha}, Calculated: {calculated_sha}"
)
return False return False
return True return True
@ -151,11 +162,16 @@ def fetch_owasp_rules(session: requests.Session, rule_files: List[Dict[str, str]
and returns a list of dicts with category and pattern. and returns a list of dicts with category and pattern.
""" """
rules = [] rules = []
for file in rule_files: with ThreadPoolExecutor(max_workers=CONNECTION_POOL_SIZE) as executor:
logging.info(f"Fetching {file['name']}...") futures = {
blob_b64 = fetch_github_blob(session, file["sha"]) executor.submit(fetch_github_blob, session, file["sha"]): file for file in rule_files
}
for future in tqdm(as_completed(futures), total=len(rule_files), desc="Fetching rule files"):
file = futures[future]
try:
blob_b64 = future.result()
if not blob_b64: if not blob_b64:
logging.warning(f"Skipping file {file['name']} due to empty blob content.") logger.warning(f"Skipping file {file['name']} due to empty blob content.")
continue continue
# Verify SHA (non-blocking) # Verify SHA (non-blocking)
@ -168,8 +184,10 @@ def fetch_owasp_rules(session: requests.Session, rule_files: List[Dict[str, str]
pattern = rule.strip().replace("\\", "") pattern = rule.strip().replace("\\", "")
if pattern: if pattern:
rules.append({"category": category, "pattern": pattern}) rules.append({"category": category, "pattern": pattern})
except Exception as e:
logger.error(f"Failed to process file {file['name']}. Reason: {e}")
logging.info(f"Fetched {len(rules)} rules.") logger.info(f"Fetched {len(rules)} rules.")
return rules return rules
@ -178,30 +196,47 @@ def save_as_json(rules: List[Dict[str, str]], output_file: str) -> bool:
Saves the given list of rules to a JSON file. Returns True if successful, False otherwise. Saves the given list of rules to a JSON file. Returns True if successful, False otherwise.
""" """
try: try:
output_dir = os.path.dirname(output_file) output_dir = Path(output_file).parent
if output_dir: if output_dir:
os.makedirs(output_dir, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f: # Atomic write using a temporary file
temp_file = f"{output_file}.tmp"
with open(temp_file, "w", encoding="utf-8") as f:
json.dump(rules, f, indent=4) json.dump(rules, f, indent=4)
logging.info(f"Rules saved to {output_file}.") # Rename temp file to the final output file
os.replace(temp_file, output_file)
logger.info(f"Rules saved to {output_file}.")
return True return True
except IOError as e: except IOError as e:
logging.error(f"Failed to save rules to {output_file}. Reason: {e}") logger.error(f"Failed to save rules to {output_file}. Reason: {e}")
return False return False
if __name__ == "__main__": def main():
"""Main function to fetch and save OWASP rules."""
parser = argparse.ArgumentParser(description="Fetch OWASP Core Rule Set rules from GitHub.")
parser.add_argument("--output", type=str, default="owasp_rules.json", help="Output JSON file path.")
parser.add_argument("--ref", type=str, default=GITHUB_REF, help="Git reference (e.g., tag or branch).")
parser.add_argument("--dry-run", action="store_true", help="Simulate fetching without saving.")
args = parser.parse_args()
session = get_session() session = get_session()
latest_ref = fetch_latest_tag(session, GITHUB_REF) latest_ref = fetch_latest_tag(session, args.ref)
if latest_ref: if latest_ref:
rule_files = fetch_rule_files(session, latest_ref) rule_files = fetch_rule_files(session, latest_ref)
if rule_files: if rule_files:
rules = fetch_owasp_rules(session, rule_files) rules = fetch_owasp_rules(session, rule_files)
if rules and save_as_json(rules, "owasp_rules.json"): if args.dry_run:
logging.info("All rules fetched and saved successfully.") logger.info("Dry-run mode enabled. Skipping file save.")
elif rules and save_as_json(rules, args.output):
logger.info("All rules fetched and saved successfully.")
else: else:
logging.error("Failed to fetch or save rules.") logger.error("Failed to fetch or save rules.")
else: else:
logging.error("Failed to fetch rule files.") logger.error("Failed to fetch rule files.")
else: else:
logging.error("Failed to fetch tags.") logger.error("Failed to fetch tags.")
if __name__ == "__main__":
main()