mirror of
https://github.com/Rarebuffalo/securelens-backend.git
synced 2026-06-19 07:00:30 +00:00
updated the architecture
This commit is contained in:
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
.env
|
||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
venv
|
||||||
|
.venv
|
||||||
|
*.egg-info
|
||||||
|
.pytest_cache
|
||||||
27
.env.example
Normal file
27
.env.example
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# SecureLens AI Configuration
|
||||||
|
|
||||||
|
# Application
|
||||||
|
APP_NAME=SecureLens AI
|
||||||
|
APP_VERSION=1.0.0
|
||||||
|
DEBUG=true
|
||||||
|
|
||||||
|
# Server
|
||||||
|
HOST=0.0.0.0
|
||||||
|
PORT=8000
|
||||||
|
|
||||||
|
# CORS - comma-separated list of allowed origins
|
||||||
|
CORS_ORIGINS=http://localhost:3000,http://localhost:5173
|
||||||
|
|
||||||
|
# Rate Limiting
|
||||||
|
RATE_LIMIT=30/minute
|
||||||
|
|
||||||
|
# Scanner
|
||||||
|
SCAN_TIMEOUT=5
|
||||||
|
PATH_CHECK_TIMEOUT=3
|
||||||
|
|
||||||
|
# Database configuration
|
||||||
|
DATABASE_URL=postgresql+asyncpg://securelens:securelens@localhost:5433/securelens
|
||||||
|
|
||||||
|
# AI Integration
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
|
||||||
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
.eggs/
|
||||||
|
venv/
|
||||||
|
.venv/
|
||||||
|
.env
|
||||||
|
*.log
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
*.db
|
||||||
15
Dockerfile
Normal file
15
Dockerfile
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
FROM python:3.12-slim AS base
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
149
alembic.ini
Normal file
149
alembic.ini
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts.
|
||||||
|
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||||
|
# format, relative to the token %(here)s which refers to the location of this
|
||||||
|
# ini file
|
||||||
|
script_location = %(here)s/migrations
|
||||||
|
|
||||||
|
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||||
|
# Uncomment the line below if you want the files to be prepended with date and time
|
||||||
|
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||||
|
# for all available tokens
|
||||||
|
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||||
|
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
|
||||||
|
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory. for multiple paths, the path separator
|
||||||
|
# is defined by "path_separator" below.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the tzdata library which can be installed by adding
|
||||||
|
# `alembic[tz]` to the pip requirements.
|
||||||
|
# string value is passed to ZoneInfo()
|
||||||
|
# leave blank for localtime
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to <script_location>/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "path_separator"
|
||||||
|
# below.
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||||
|
|
||||||
|
# path_separator; This indicates what character is used to split lists of file
|
||||||
|
# paths, including version_locations and prepend_sys_path within configparser
|
||||||
|
# files such as alembic.ini.
|
||||||
|
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||||
|
# to provide os-dependent path splitting.
|
||||||
|
#
|
||||||
|
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||||
|
# take place if path_separator is not present in alembic.ini. If this
|
||||||
|
# option is omitted entirely, fallback logic is as follows:
|
||||||
|
#
|
||||||
|
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||||
|
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||||
|
# behavior of splitting on spaces and/or commas.
|
||||||
|
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||||
|
# behavior of splitting on spaces, commas, or colons.
|
||||||
|
#
|
||||||
|
# Valid values for path_separator are:
|
||||||
|
#
|
||||||
|
# path_separator = :
|
||||||
|
# path_separator = ;
|
||||||
|
# path_separator = space
|
||||||
|
# path_separator = newline
|
||||||
|
#
|
||||||
|
# Use os.pathsep. Default configuration used for new projects.
|
||||||
|
path_separator = os
|
||||||
|
|
||||||
|
|
||||||
|
# set to 'true' to search source files recursively
|
||||||
|
# in each "version_locations" directory
|
||||||
|
# new in Alembic version 1.10
|
||||||
|
# recursive_version_locations = false
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
# database URL. This is consumed by the user-maintained env.py script only.
|
||||||
|
# other means of configuring database URLs may be customized within the env.py
|
||||||
|
# file.
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = module
|
||||||
|
# ruff.module = ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = exec
|
||||||
|
# ruff.executable = ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration. This is also consumed by the user-maintained
|
||||||
|
# env.py script only.
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARNING
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARNING
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
34
app/config.py
Normal file
34
app/config.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
app_name: str = "SecureLens AI"
|
||||||
|
app_version: str = "1.0.0"
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8000
|
||||||
|
|
||||||
|
cors_origins: str = "http://localhost:3000,http://localhost:5173"
|
||||||
|
|
||||||
|
rate_limit: str = "30/minute"
|
||||||
|
|
||||||
|
scan_timeout: int = 5
|
||||||
|
path_check_timeout: int = 3
|
||||||
|
|
||||||
|
database_url: str = "postgresql+asyncpg://securelens:securelens@localhost:5433/securelens"
|
||||||
|
|
||||||
|
jwt_secret: str = "change-me-in-production-use-a-long-random-string"
|
||||||
|
jwt_algorithm: str = "HS256"
|
||||||
|
jwt_expiry_minutes: int = 1440
|
||||||
|
|
||||||
|
openai_api_key: str | None = None
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cors_origin_list(self) -> list[str]:
|
||||||
|
return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
35
app/database.py
Normal file
35
app/database.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
engine = create_async_engine(settings.database_url, echo=settings.debug)
|
||||||
|
|
||||||
|
AsyncSessionLocal = async_sessionmaker(
|
||||||
|
bind=engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db():
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def init_db():
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
|
||||||
|
async def close_db():
|
||||||
|
await engine.dispose()
|
||||||
65
app/main.py
Normal file
65
app/main.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from slowapi import _rate_limit_exceeded_handler
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from slowapi.middleware import SlowAPIMiddleware
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.database import close_db, init_db
|
||||||
|
from app.middleware.rate_limiter import limiter
|
||||||
|
from app.routers import auth, health, history, scan, apikey, report
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG if settings.debug else logging.INFO,
|
||||||
|
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
import app.models # noqa: F401 — register models with Base.metadata
|
||||||
|
await init_db()
|
||||||
|
logger.info("Database initialized")
|
||||||
|
yield
|
||||||
|
await close_db()
|
||||||
|
logger.info("Database connection closed")
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
application = FastAPI(
|
||||||
|
title=settings.app_name,
|
||||||
|
version=settings.app_version,
|
||||||
|
docs_url="/docs" if settings.debug else None,
|
||||||
|
redoc_url="/redoc" if settings.debug else None,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
application.state.limiter = limiter
|
||||||
|
application.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
application.add_middleware(SlowAPIMiddleware)
|
||||||
|
|
||||||
|
application.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.cors_origin_list,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
application.include_router(health.router)
|
||||||
|
application.include_router(auth.router)
|
||||||
|
application.include_router(scan.router)
|
||||||
|
application.include_router(history.router)
|
||||||
|
application.include_router(apikey.router)
|
||||||
|
application.include_router(report.router)
|
||||||
|
|
||||||
|
logger.info(f"{settings.app_name} v{settings.app_version} initialized")
|
||||||
|
|
||||||
|
return application
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
0
app/middleware/__init__.py
Normal file
0
app/middleware/__init__.py
Normal file
71
app/middleware/auth.py
Normal file
71
app/middleware/auth.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.apikey import ApiKey
|
||||||
|
from app.utils.auth import decode_access_token
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login", auto_error=False)
|
||||||
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: str | None = Depends(oauth2_scheme),
|
||||||
|
api_key: str | None = Depends(api_key_header),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> User:
|
||||||
|
if token:
|
||||||
|
user_id = decode_access_token(token)
|
||||||
|
if user_id:
|
||||||
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
hashed_key = hashlib.sha256(api_key.encode()).hexdigest()
|
||||||
|
result = await db.execute(
|
||||||
|
select(User)
|
||||||
|
.join(ApiKey, User.id == ApiKey.user_id)
|
||||||
|
.where(ApiKey.hashed_key == hashed_key)
|
||||||
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid authentication credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_optional_user(
|
||||||
|
token: str | None = Depends(oauth2_scheme),
|
||||||
|
api_key: str | None = Depends(api_key_header),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> User | None:
|
||||||
|
if token:
|
||||||
|
user_id = decode_access_token(token)
|
||||||
|
if user_id:
|
||||||
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
hashed_key = hashlib.sha256(api_key.encode()).hexdigest()
|
||||||
|
result = await db.execute(
|
||||||
|
select(User)
|
||||||
|
.join(ApiKey, User.id == ApiKey.user_id)
|
||||||
|
.where(ApiKey.hashed_key == hashed_key)
|
||||||
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
|
||||||
|
return None
|
||||||
6
app/middleware/rate_limiter.py
Normal file
6
app/middleware/rate_limiter.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
limiter = Limiter(key_func=get_remote_address, default_limits=[settings.rate_limit])
|
||||||
4
app/models/__init__.py
Normal file
4
app/models/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.models.user import User
|
||||||
|
from app.models.scan import ScanResult
|
||||||
|
|
||||||
|
__all__ = ["User", "ScanResult"]
|
||||||
26
app/models/apikey.py
Normal file
26
app/models/apikey.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import String, DateTime, ForeignKey
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKey(Base):
|
||||||
|
__tablename__ = "api_keys"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
String(36), ForeignKey("users.id"), index=True
|
||||||
|
)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
key_prefix: Mapped[str] = mapped_column(String(10))
|
||||||
|
hashed_key: Mapped[str] = mapped_column(String(255), unique=True, index=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
user = relationship("User", back_populates="api_keys")
|
||||||
27
app/models/scan.py
Normal file
27
app/models/scan.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, ForeignKey, Integer, JSON, String
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class ScanResult(Base):
|
||||||
|
__tablename__ = "scan_results"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
String(36), ForeignKey("users.id"), index=True
|
||||||
|
)
|
||||||
|
url: Mapped[str] = mapped_column(String(2048))
|
||||||
|
security_score: Mapped[int] = mapped_column(Integer)
|
||||||
|
layers: Mapped[dict] = mapped_column(JSON)
|
||||||
|
issues: Mapped[list] = mapped_column(JSON)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
user = relationship("User", back_populates="scans")
|
||||||
25
app/models/user.py
Normal file
25
app/models/user.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import String, DateTime
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||||
|
)
|
||||||
|
email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
|
||||||
|
username: Mapped[str] = mapped_column(String(100), unique=True, index=True)
|
||||||
|
hashed_password: Mapped[str] = mapped_column(String(255))
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
scans = relationship("ScanResult", back_populates="user", lazy="selectin")
|
||||||
|
api_keys = relationship("ApiKey", back_populates="user", lazy="selectin", cascade="all, delete")
|
||||||
|
webhooks = relationship("Webhook", back_populates="user", lazy="selectin", cascade="all, delete")
|
||||||
21
app/models/webhook.py
Normal file
21
app/models/webhook.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Webhook(Base):
|
||||||
|
__tablename__ = "webhooks"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
|
user_id = Column(String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
target_url = Column(String, nullable=False)
|
||||||
|
secret_key = Column(String, nullable=True) # Used for HMAC signing
|
||||||
|
is_active = Column(Boolean, default=True)
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
user = relationship("User", back_populates="webhooks")
|
||||||
0
app/routers/__init__.py
Normal file
0
app/routers/__init__.py
Normal file
76
app/routers/apikey.py
Normal file
76
app/routers/apikey.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import hashlib
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.middleware.auth import get_current_user
|
||||||
|
from app.models.apikey import ApiKey
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.apikey import ApiKeyCreate, ApiKeyCreateResponse, ApiKeyResponse
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api-keys", tags=["apikeys"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ApiKeyCreateResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_api_key(
|
||||||
|
data: ApiKeyCreate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
raw_key = f"sl_{secrets.token_urlsafe(32)}"
|
||||||
|
key_prefix = raw_key[:10]
|
||||||
|
hashed_key = hashlib.sha256(raw_key.encode()).hexdigest()
|
||||||
|
|
||||||
|
api_key = ApiKey(
|
||||||
|
user_id=current_user.id,
|
||||||
|
name=data.name,
|
||||||
|
key_prefix=key_prefix,
|
||||||
|
hashed_key=hashed_key,
|
||||||
|
)
|
||||||
|
db.add(api_key)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(api_key)
|
||||||
|
|
||||||
|
return ApiKeyCreateResponse(
|
||||||
|
id=api_key.id,
|
||||||
|
name=api_key.name,
|
||||||
|
key_prefix=api_key.key_prefix,
|
||||||
|
created_at=api_key.created_at,
|
||||||
|
key=raw_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=list[ApiKeyResponse])
|
||||||
|
async def list_api_keys(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(
|
||||||
|
select(ApiKey)
|
||||||
|
.where(ApiKey.user_id == current_user.id)
|
||||||
|
.order_by(ApiKey.created_at.desc())
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{key_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_api_key(
|
||||||
|
key_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(
|
||||||
|
select(ApiKey).where(ApiKey.id == key_id, ApiKey.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
api_key = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="API Key not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.delete(api_key)
|
||||||
|
await db.commit()
|
||||||
61
app/routers/auth.py
Normal file
61
app/routers/auth.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.middleware.auth import get_current_user
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.auth import (
|
||||||
|
LoginRequest,
|
||||||
|
RegisterRequest,
|
||||||
|
TokenResponse,
|
||||||
|
UserResponse,
|
||||||
|
)
|
||||||
|
from app.utils.auth import create_access_token, hash_password, verify_password
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(data: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||||
|
result = await db.execute(
|
||||||
|
select(User).where((User.email == data.email) | (User.username == data.username))
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing:
|
||||||
|
field = "email" if existing.email == data.email else "username"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"A user with this {field} already exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
email=data.email,
|
||||||
|
username=data.username,
|
||||||
|
hashed_password=hash_password(data.password),
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
token = create_access_token(user.id)
|
||||||
|
return TokenResponse(access_token=token)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=TokenResponse)
|
||||||
|
async def login(data: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||||
|
result = await db.execute(select(User).where(User.email == data.email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user is None or not verify_password(data.password, user.hashed_password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid email or password",
|
||||||
|
)
|
||||||
|
|
||||||
|
token = create_access_token(user.id)
|
||||||
|
return TokenResponse(access_token=token)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def get_me(current_user: User = Depends(get_current_user)):
|
||||||
|
return current_user
|
||||||
19
app/routers/health.py
Normal file
19
app/routers/health.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
router = APIRouter(tags=["health"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"message": f"{settings.app_name} backend running 🚀"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"app": settings.app_name,
|
||||||
|
"version": settings.app_version,
|
||||||
|
}
|
||||||
226
app/routers/history.py
Normal file
226
app/routers/history.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.middleware.auth import get_current_user
|
||||||
|
from app.models.scan import ScanResult
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.scan import (
|
||||||
|
Issue,
|
||||||
|
LayerStatus,
|
||||||
|
ScanHistoryItem,
|
||||||
|
ScanHistoryResponse,
|
||||||
|
ScanResponse,
|
||||||
|
DashboardTrendsResponse,
|
||||||
|
ChatRequest,
|
||||||
|
ChatResponse,
|
||||||
|
ThreatNarrativeResponse,
|
||||||
|
ScanDiffResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.services.ai import chat_with_scan_context, generate_threat_narrative
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/scans", tags=["history"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=ScanHistoryResponse)
|
||||||
|
async def list_scans(
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
per_page: int = Query(20, ge=1, le=100),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
|
||||||
|
count_result = await db.execute(
|
||||||
|
select(func.count()).select_from(ScanResult).where(ScanResult.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(ScanResult)
|
||||||
|
.where(ScanResult.user_id == current_user.id)
|
||||||
|
.order_by(ScanResult.created_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(per_page)
|
||||||
|
)
|
||||||
|
scans = result.scalars().all()
|
||||||
|
|
||||||
|
return ScanHistoryResponse(
|
||||||
|
scans=[ScanHistoryItem.model_validate(s) for s in scans],
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
per_page=per_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/trends", response_model=DashboardTrendsResponse)
|
||||||
|
async def get_trends(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
count_result = await db.execute(
|
||||||
|
select(func.count()).select_from(ScanResult).where(ScanResult.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
total_scans = count_result.scalar_one()
|
||||||
|
|
||||||
|
avg_result = await db.execute(
|
||||||
|
select(func.avg(ScanResult.security_score)).where(ScanResult.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
avg_score = avg_result.scalar_one() or 0.0
|
||||||
|
|
||||||
|
recent_result = await db.execute(
|
||||||
|
select(ScanResult)
|
||||||
|
.where(ScanResult.user_id == current_user.id)
|
||||||
|
.order_by(ScanResult.created_at.desc())
|
||||||
|
.limit(5)
|
||||||
|
)
|
||||||
|
recent_scans = recent_result.scalars().all()
|
||||||
|
|
||||||
|
return DashboardTrendsResponse(
|
||||||
|
total_scans=total_scans,
|
||||||
|
average_score=float(avg_score),
|
||||||
|
recent_scans=[ScanHistoryItem.model_validate(s) for s in recent_scans]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{scan_id}", response_model=ScanResponse)
|
||||||
|
async def get_scan(
|
||||||
|
scan_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(
|
||||||
|
select(ScanResult).where(
|
||||||
|
ScanResult.id == scan_id,
|
||||||
|
ScanResult.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
scan = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if scan is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Scan not found")
|
||||||
|
|
||||||
|
return ScanResponse(
|
||||||
|
id=scan.id,
|
||||||
|
url=scan.url,
|
||||||
|
security_score=scan.security_score,
|
||||||
|
layers={k: LayerStatus(**v) for k, v in scan.layers.items()},
|
||||||
|
issues=[Issue(**i) for i in scan.issues],
|
||||||
|
created_at=scan.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{scan_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_scan(
|
||||||
|
scan_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(
|
||||||
|
select(ScanResult).where(
|
||||||
|
ScanResult.id == scan_id,
|
||||||
|
ScanResult.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
scan = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if scan is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Scan not found")
|
||||||
|
|
||||||
|
await db.delete(scan)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{scan_id}/chat", response_model=ChatResponse)
|
||||||
|
async def chat_about_scan(
|
||||||
|
scan_id: str,
|
||||||
|
data: ChatRequest,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(
|
||||||
|
select(ScanResult).where(
|
||||||
|
ScanResult.id == scan_id,
|
||||||
|
ScanResult.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
scan = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not scan:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Scan not found")
|
||||||
|
|
||||||
|
context_data = {
|
||||||
|
"url": scan.url,
|
||||||
|
"score": scan.security_score,
|
||||||
|
"layers": scan.layers,
|
||||||
|
"issues": scan.issues,
|
||||||
|
}
|
||||||
|
|
||||||
|
reply = await chat_with_scan_context(scan_id, context_data, data.message)
|
||||||
|
return ChatResponse(reply=reply)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{scan_id}/threat-narrative", response_model=ThreatNarrativeResponse)
|
||||||
|
async def get_threat_narrative(
|
||||||
|
scan_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(
|
||||||
|
select(ScanResult).where(
|
||||||
|
ScanResult.id == scan_id,
|
||||||
|
ScanResult.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
scan = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not scan:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Scan not found")
|
||||||
|
|
||||||
|
context_data = {
|
||||||
|
"url": scan.url,
|
||||||
|
"score": scan.security_score,
|
||||||
|
"layers": scan.layers,
|
||||||
|
"issues": scan.issues,
|
||||||
|
}
|
||||||
|
|
||||||
|
narrative = await generate_threat_narrative(context_data)
|
||||||
|
return ThreatNarrativeResponse(narrative=narrative)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{old_id}/diff/{new_id}", response_model=ScanDiffResponse)
|
||||||
|
async def diff_scans(
|
||||||
|
old_id: str,
|
||||||
|
new_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(
|
||||||
|
select(ScanResult).where(
|
||||||
|
ScanResult.id.in_([old_id, new_id]),
|
||||||
|
ScanResult.user_id == current_user.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
scans = result.scalars().all()
|
||||||
|
|
||||||
|
if len(scans) != 2:
|
||||||
|
raise HTTPException(status_code=404, detail="One or both scans not found, or access denied.")
|
||||||
|
|
||||||
|
s_old = scans[0] if scans[0].id == old_id else scans[1]
|
||||||
|
s_new = scans[1] if scans[1].id == new_id else scans[0]
|
||||||
|
|
||||||
|
# Convert to set-like structures using issue names
|
||||||
|
old_map = {i.get("issue"): i for i in s_old.issues}
|
||||||
|
new_map = {i.get("issue"): i for i in s_new.issues}
|
||||||
|
|
||||||
|
resolved = [v for k, v in old_map.items() if k not in new_map]
|
||||||
|
new_issues = [v for k, v in new_map.items() if k not in old_map]
|
||||||
|
persisting = [v for k, v in new_map.items() if k in old_map]
|
||||||
|
|
||||||
|
return ScanDiffResponse(
|
||||||
|
resolved_issues=resolved,
|
||||||
|
new_issues=new_issues,
|
||||||
|
persisting_issues=persisting,
|
||||||
|
score_change=s_new.security_score - s_old.security_score
|
||||||
|
)
|
||||||
116
app/routers/report.py
Normal file
116
app/routers/report.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import csv
|
||||||
|
import io
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from fpdf import FPDF
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.middleware.auth import get_current_user
|
||||||
|
from app.models.scan import ScanResult
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/scans", tags=["report"])
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_csv(scan: ScanResult) -> io.StringIO:
|
||||||
|
output = io.StringIO()
|
||||||
|
writer = csv.writer(output)
|
||||||
|
|
||||||
|
writer.writerow(["SecureLens AI Scan Report"])
|
||||||
|
writer.writerow(["URL", scan.url])
|
||||||
|
writer.writerow(["Date", scan.created_at.strftime("%Y-%m-%d %H:%M:%S")])
|
||||||
|
writer.writerow(["Security Score", scan.security_score])
|
||||||
|
writer.writerow([])
|
||||||
|
|
||||||
|
writer.writerow(["Issue", "Severity", "Layer", "Fix", "Contextual Severity", "Explanation"])
|
||||||
|
for i in scan.issues:
|
||||||
|
writer.writerow([
|
||||||
|
i.get("issue"),
|
||||||
|
i.get("severity"),
|
||||||
|
i.get("layer"),
|
||||||
|
i.get("fix"),
|
||||||
|
i.get("contextual_severity", ""),
|
||||||
|
i.get("explanation", ""),
|
||||||
|
])
|
||||||
|
|
||||||
|
output.seek(0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_pdf(scan: ScanResult) -> io.BytesIO:
|
||||||
|
pdf = FPDF()
|
||||||
|
pdf.add_page()
|
||||||
|
|
||||||
|
pdf.set_font("helvetica", "B", 16)
|
||||||
|
pdf.cell(0, 10, "SecureLens AI Scan Report", new_x="LMARGIN", new_y="NEXT", align="C")
|
||||||
|
|
||||||
|
pdf.set_font("helvetica", "", 12)
|
||||||
|
pdf.cell(0, 10, f"URL: {scan.url}", new_x="LMARGIN", new_y="NEXT")
|
||||||
|
pdf.cell(0, 10, f"Date: {scan.created_at.strftime('%Y-%m-%d %H:%M:%S')}", new_x="LMARGIN", new_y="NEXT")
|
||||||
|
pdf.cell(0, 10, f"Security Score: {scan.security_score}/100", new_x="LMARGIN", new_y="NEXT")
|
||||||
|
|
||||||
|
pdf.ln(5)
|
||||||
|
pdf.set_font("helvetica", "B", 14)
|
||||||
|
pdf.cell(0, 10, "Discovered Issues", new_x="LMARGIN", new_y="NEXT")
|
||||||
|
|
||||||
|
for i in scan.issues:
|
||||||
|
pdf.set_font("helvetica", "B", 12)
|
||||||
|
pdf.cell(0, 8, f"Issue: {i.get('issue')} [{i.get('severity')}]", new_x="LMARGIN", new_y="NEXT")
|
||||||
|
|
||||||
|
pdf.set_font("helvetica", "", 10)
|
||||||
|
pdf.multi_cell(0, 6, f"Layer: {i.get('layer')}")
|
||||||
|
pdf.multi_cell(0, 6, f"Fix: {i.get('fix')}")
|
||||||
|
|
||||||
|
if i.get("explanation"):
|
||||||
|
pdf.multi_cell(0, 6, f"AI Context: {i.get('explanation')}")
|
||||||
|
pdf.ln(4)
|
||||||
|
|
||||||
|
pdf_bytes = pdf.output()
|
||||||
|
return io.BytesIO(pdf_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{scan_id}/export/csv")
|
||||||
|
async def export_csv(
|
||||||
|
scan_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(
|
||||||
|
select(ScanResult).where(ScanResult.id == scan_id, ScanResult.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
scan = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not scan:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Scan not found")
|
||||||
|
|
||||||
|
csv_data = _generate_csv(scan)
|
||||||
|
response = StreamingResponse(iter([csv_data.getvalue()]), media_type="text/csv")
|
||||||
|
response.headers["Content-Disposition"] = f"attachment; filename=scan_{scan_id}.csv"
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{scan_id}/export/pdf")
|
||||||
|
async def export_pdf(
|
||||||
|
scan_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
result = await db.execute(
|
||||||
|
select(ScanResult).where(ScanResult.id == scan_id, ScanResult.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
scan = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not scan:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Scan not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
pdf_data = _generate_pdf(scan)
|
||||||
|
response = StreamingResponse(pdf_data, media_type="application/pdf")
|
||||||
|
response.headers["Content-Disposition"] = f"attachment; filename=scan_{scan_id}.pdf"
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"PDF Generation failed: {str(e)}")
|
||||||
154
app/routers/scan.py
Normal file
154
app/routers/scan.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, Depends, Request, BackgroundTasks
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.database import get_db
|
||||||
|
from app.middleware.auth import get_optional_user
|
||||||
|
from app.middleware.rate_limiter import limiter
|
||||||
|
from app.models.scan import ScanResult
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.webhook import Webhook
|
||||||
|
from app.schemas.scan import ScanRequest, ScanResponse
|
||||||
|
from app.services.scanner.cookies import CookieScanner
|
||||||
|
from app.services.scanner.exposure import ExposureScanner
|
||||||
|
from app.services.scanner.headers import HeaderScanner
|
||||||
|
from app.services.scanner.ssl_checker import SSLScanner
|
||||||
|
from app.services.scanner.transport import TransportScanner
|
||||||
|
from app.services.scanner.dns import DNSScanner
|
||||||
|
from app.services.scanner.ports import PortScanner
|
||||||
|
from app.services.scoring import calculate_layer_statuses, calculate_score
|
||||||
|
from app.services.ai import enhance_security_issues
|
||||||
|
from app.utils.validators import validate_url
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["scan"])
|
||||||
|
|
||||||
|
transport_scanner = TransportScanner()
|
||||||
|
ssl_scanner = SSLScanner()
|
||||||
|
header_scanner = HeaderScanner()
|
||||||
|
cookie_scanner = CookieScanner()
|
||||||
|
exposure_scanner = ExposureScanner()
|
||||||
|
dns_scanner = DNSScanner()
|
||||||
|
port_scanner = PortScanner()
|
||||||
|
|
||||||
|
|
||||||
|
async def dispatch_webhooks(user_id: str, scan_data: dict, db_session):
|
||||||
|
import hmac, hashlib, json
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Webhook).where(Webhook.user_id == user_id, Webhook.is_active == True)
|
||||||
|
)
|
||||||
|
hooks = result.scalars().all()
|
||||||
|
if not hooks:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
payload = json.dumps(scan_data).encode("utf-8")
|
||||||
|
for hook in hooks:
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if hook.secret_key:
|
||||||
|
sig = hmac.new(hook.secret_key.encode(), payload, hashlib.sha256).hexdigest()
|
||||||
|
headers["X-SecureLens-Signature"] = sig
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.post(hook.target_url, content=payload, headers=headers, timeout=5.0)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Webhook {hook.target_url} failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/scan", response_model=ScanResponse)
|
||||||
|
@limiter.limit(settings.rate_limit)
|
||||||
|
async def scan_website(
|
||||||
|
data: ScanRequest,
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User | None = Depends(get_optional_user),
|
||||||
|
):
|
||||||
|
url = validate_url(data.url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
dns_task = asyncio.create_task(dns_scanner.scan(url))
|
||||||
|
port_task = asyncio.create_task(port_scanner.scan(url))
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(settings.scan_timeout),
|
||||||
|
follow_redirects=True,
|
||||||
|
) as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
|
||||||
|
all_issues = []
|
||||||
|
all_issues.extend(await transport_scanner.scan(url, response))
|
||||||
|
all_issues.extend(await ssl_scanner.scan(url, response))
|
||||||
|
all_issues.extend(await header_scanner.scan(url, response))
|
||||||
|
all_issues.extend(await cookie_scanner.scan(url, response))
|
||||||
|
all_issues.extend(await exposure_scanner.scan(url, response))
|
||||||
|
|
||||||
|
# Await infrastructure scans
|
||||||
|
all_issues.extend(await dns_task)
|
||||||
|
all_issues.extend(await port_task)
|
||||||
|
|
||||||
|
score = calculate_score(all_issues)
|
||||||
|
layers = calculate_layer_statuses(all_issues)
|
||||||
|
|
||||||
|
if settings.openai_api_key and all_issues:
|
||||||
|
issues_dict_list = [i.model_dump() for i in all_issues]
|
||||||
|
ai_data = await enhance_security_issues(issues_dict_list)
|
||||||
|
enhanced_list = ai_data.get("enhanced_issues", [])
|
||||||
|
enhancement_map = {e.get("issue"): e for e in enhanced_list}
|
||||||
|
for original in all_issues:
|
||||||
|
enh = enhancement_map.get(original.issue)
|
||||||
|
if enh:
|
||||||
|
original.contextual_severity = enh.get("contextual_severity")
|
||||||
|
original.explanation = enh.get("explanation")
|
||||||
|
original.remediation_snippet = enh.get("remediation_snippet")
|
||||||
|
|
||||||
|
scan_id = None
|
||||||
|
created_at = None
|
||||||
|
|
||||||
|
if current_user is not None:
|
||||||
|
layers_dict = {k: v.model_dump() for k, v in layers.items()}
|
||||||
|
issues_list = [i.model_dump() for i in all_issues]
|
||||||
|
|
||||||
|
scan_record = ScanResult(
|
||||||
|
user_id=current_user.id,
|
||||||
|
url=url,
|
||||||
|
security_score=score,
|
||||||
|
layers=layers_dict,
|
||||||
|
issues=issues_list,
|
||||||
|
)
|
||||||
|
db.add(scan_record)
|
||||||
|
await db.flush()
|
||||||
|
scan_id = scan_record.id
|
||||||
|
created_at = scan_record.created_at
|
||||||
|
|
||||||
|
scan_summary = {
|
||||||
|
"scan_id": scan_id,
|
||||||
|
"url": url,
|
||||||
|
"score": score
|
||||||
|
}
|
||||||
|
background_tasks.add_task(dispatch_webhooks, current_user.id, scan_summary, db)
|
||||||
|
|
||||||
|
return ScanResponse(
|
||||||
|
id=scan_id,
|
||||||
|
url=url,
|
||||||
|
security_score=score,
|
||||||
|
layers=layers,
|
||||||
|
issues=all_issues,
|
||||||
|
created_at=created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.error(f"Scan failed for {url}: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=502,
|
||||||
|
content={"error": f"Could not reach {url}: {str(e)}"},
|
||||||
|
)
|
||||||
56
app/routers/webhook.py
Normal file
56
app/routers/webhook.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import secrets
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.middleware.auth import get_current_user
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.webhook import Webhook
|
||||||
|
from app.schemas.webhook import WebhookCreate, WebhookResponse
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=WebhookResponse)
|
||||||
|
async def create_webhook(
|
||||||
|
hook_in: WebhookCreate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
secret = hook_in.secret_key or secrets.token_hex(16)
|
||||||
|
db_hook = Webhook(
|
||||||
|
user_id=current_user.id,
|
||||||
|
target_url=str(hook_in.target_url),
|
||||||
|
secret_key=secret
|
||||||
|
)
|
||||||
|
db.add(db_hook)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_hook)
|
||||||
|
return db_hook
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=list[WebhookResponse])
|
||||||
|
async def list_webhooks(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
result = await db.execute(select(Webhook).where(Webhook.user_id == current_user.id))
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{hook_id}")
|
||||||
|
async def delete_webhook(
|
||||||
|
hook_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
result = await db.execute(select(Webhook).where(Webhook.id == hook_id, Webhook.user_id == current_user.id))
|
||||||
|
hook = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not hook:
|
||||||
|
raise HTTPException(status_code=404, detail="Webhook not found")
|
||||||
|
|
||||||
|
await db.delete(hook)
|
||||||
|
await db.commit()
|
||||||
|
return {"status": "success", "message": "Webhook deleted"}
|
||||||
0
app/schemas/__init__.py
Normal file
0
app/schemas/__init__.py
Normal file
17
app/schemas/apikey.py
Normal file
17
app/schemas/apikey.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
key_prefix: str
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyCreateResponse(ApiKeyResponse):
|
||||||
|
key: str # The raw API key returned only once upon creation
|
||||||
28
app/schemas/auth.py
Normal file
28
app/schemas/auth.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
username: str = Field(..., min_length=3, max_length=100)
|
||||||
|
password: str = Field(..., min_length=8, max_length=128)
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
username: str
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
72
app/schemas/scan.py
Normal file
72
app/schemas/scan.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ScanRequest(BaseModel):
|
||||||
|
url: str = Field(..., description="The URL of the website to scan")
|
||||||
|
|
||||||
|
|
||||||
|
class Issue(BaseModel):
|
||||||
|
issue: str
|
||||||
|
severity: str
|
||||||
|
layer: str
|
||||||
|
fix: str
|
||||||
|
contextual_severity: str | None = None
|
||||||
|
explanation: str | None = None
|
||||||
|
remediation_snippet: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LayerStatus(BaseModel):
|
||||||
|
issues: int = 0
|
||||||
|
status: str = "green"
|
||||||
|
|
||||||
|
|
||||||
|
class ScanResponse(BaseModel):
|
||||||
|
id: str | None = None
|
||||||
|
url: str
|
||||||
|
security_score: int
|
||||||
|
layers: dict[str, LayerStatus]
|
||||||
|
issues: list[Issue]
|
||||||
|
created_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ScanHistoryItem(BaseModel):
|
||||||
|
id: str
|
||||||
|
url: str
|
||||||
|
security_score: int
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class ScanHistoryResponse(BaseModel):
|
||||||
|
scans: list[ScanHistoryItem]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
per_page: int
|
||||||
|
|
||||||
|
|
||||||
|
class DashboardTrendsResponse(BaseModel):
|
||||||
|
total_scans: int
|
||||||
|
average_score: float
|
||||||
|
recent_scans: list[ScanHistoryItem]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
reply: str
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatNarrativeResponse(BaseModel):
|
||||||
|
narrative: str
|
||||||
|
|
||||||
|
|
||||||
|
class ScanDiffResponse(BaseModel):
|
||||||
|
resolved_issues: list[Issue]
|
||||||
|
new_issues: list[Issue]
|
||||||
|
persisting_issues: list[Issue]
|
||||||
|
score_change: int
|
||||||
17
app/schemas/webhook.py
Normal file
17
app/schemas/webhook.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel, HttpUrl
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookCreate(BaseModel):
|
||||||
|
target_url: HttpUrl
|
||||||
|
secret_key: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
target_url: str
|
||||||
|
is_active: bool
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
0
app/services/__init__.py
Normal file
0
app/services/__init__.py
Normal file
105
app/services/ai.py
Normal file
105
app/services/ai.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
api_key = settings.openai_api_key or "mock-key-for-testing"
|
||||||
|
client = AsyncOpenAI(api_key=api_key)
|
||||||
|
|
||||||
|
async def enhance_security_issues(issues: list[dict]) -> dict:
|
||||||
|
"""
|
||||||
|
Takes a list of basic security issues and uses an LLM to provide:
|
||||||
|
- Contextual severity
|
||||||
|
- Natural language explanations
|
||||||
|
- Auto-generated remediation code snippets
|
||||||
|
"""
|
||||||
|
if not settings.openai_api_key:
|
||||||
|
logger.warning("OPENAI_API_KEY is not set. AI enhancements are skipped.")
|
||||||
|
return {"enhanced_issues": issues}
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
"Analyze the following security vulnerabilities:\n"
|
||||||
|
f"{json.dumps(issues, indent=2)}\n\n"
|
||||||
|
"Return a JSON object with a single key 'enhanced_issues' containing a list of objects. "
|
||||||
|
"Each object MUST correspond to one of the original issues and have the following keys: "
|
||||||
|
"'issue' (exact string of the original issue), "
|
||||||
|
"'contextual_severity' (Low, Medium, High, Critical), "
|
||||||
|
"'explanation' (a 1-2 sentence non-technical explanation), "
|
||||||
|
"'remediation_snippet' (Actionable code snippet, e.g. Nginx config, or 'N/A')."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a senior cybersecurity automation agent. Always respond with valid JSON."},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
temperature=0.2,
|
||||||
|
)
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
if not content:
|
||||||
|
return {"enhanced_issues": issues, "ai_error": "Empty response"}
|
||||||
|
|
||||||
|
return json.loads(content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AI Generation Error: {str(e)}")
|
||||||
|
return {"enhanced_issues": issues, "ai_error": str(e)}
|
||||||
|
|
||||||
|
async def chat_with_scan_context(scan_id: str, context_data: dict, user_message: str) -> str:
|
||||||
|
"""
|
||||||
|
Allows a user to ask a question about a specific scan's results.
|
||||||
|
"""
|
||||||
|
if not settings.openai_api_key:
|
||||||
|
return "AI Chat is disabled because OPENAI_API_KEY is not configured."
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"You are SecureLens AI, an expert cybersecurity assistant. "
|
||||||
|
"You are helping a developer understand a security scan report for their website. "
|
||||||
|
f"Here is the context of the scan: {json.dumps(context_data)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_message}
|
||||||
|
],
|
||||||
|
temperature=0.5,
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content or "No response from AI."
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AI Chat Error: {str(e)}")
|
||||||
|
return "I encountered an error trying to process your request."
|
||||||
|
|
||||||
|
async def generate_threat_narrative(context_data: dict) -> str:
|
||||||
|
"""
|
||||||
|
Weaves multiple scan issues into a cohesive attack sequence.
|
||||||
|
"""
|
||||||
|
if not settings.openai_api_key:
|
||||||
|
return "AI Threat Narrative is disabled because OPENAI_API_KEY is not configured."
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"You are a senior cybersecurity red-teamer. Analyze the following security scan results "
|
||||||
|
"and weave them into a single, cohesive 'Threat Narrative'. Explain how an attacker might "
|
||||||
|
"chain these specific vulnerabilities together to compromise the system. "
|
||||||
|
"Keep it professional, concise (2-3 paragraphs), and actionable."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": json.dumps(context_data)}
|
||||||
|
],
|
||||||
|
temperature=0.7,
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content or "Could not generate threat narrative."
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AI Narrative Error: {str(e)}")
|
||||||
|
return "I encountered an error trying to generate the threat narrative."
|
||||||
0
app/services/scanner/__init__.py
Normal file
0
app/services/scanner/__init__.py
Normal file
11
app/services/scanner/base.py
Normal file
11
app/services/scanner/base.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.schemas.scan import Issue
|
||||||
|
|
||||||
|
|
||||||
|
class BaseScanner(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def scan(self, url: str, response: httpx.Response) -> list[Issue]:
|
||||||
|
pass
|
||||||
70
app/services/scanner/cookies.py
Normal file
70
app/services/scanner/cookies.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import logging
|
||||||
|
from http.cookies import SimpleCookie
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.schemas.scan import Issue
|
||||||
|
from app.services.scanner.base import BaseScanner
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CookieScanner(BaseScanner):
|
||||||
|
async def scan(self, url: str, response: httpx.Response) -> list[Issue]:
|
||||||
|
issues: list[Issue] = []
|
||||||
|
is_https = url.startswith("https")
|
||||||
|
|
||||||
|
raw_cookies = response.headers.multi_items()
|
||||||
|
set_cookie_headers = [
|
||||||
|
value for key, value in raw_cookies if key.lower() == "set-cookie"
|
||||||
|
]
|
||||||
|
|
||||||
|
if not set_cookie_headers:
|
||||||
|
return issues
|
||||||
|
|
||||||
|
for cookie_str in set_cookie_headers:
|
||||||
|
cookie_lower = cookie_str.lower()
|
||||||
|
|
||||||
|
cookie = SimpleCookie()
|
||||||
|
try:
|
||||||
|
cookie.load(cookie_str)
|
||||||
|
except Exception:
|
||||||
|
logger.debug(f"Could not parse cookie: {cookie_str[:80]}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for name, morsel in cookie.items():
|
||||||
|
if "httponly" not in cookie_lower:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"Cookie '{name}' missing HttpOnly flag",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Cookie Security",
|
||||||
|
fix=f"Set the HttpOnly flag on cookie '{name}' to prevent JavaScript access",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "; secure" not in cookie_lower:
|
||||||
|
if is_https:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"Cookie '{name}' missing Secure flag",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Cookie Security",
|
||||||
|
fix=f"Set the Secure flag on cookie '{name}' to ensure it is only sent over HTTPS",
|
||||||
|
))
|
||||||
|
|
||||||
|
samesite_value = morsel.get("samesite", "").lower()
|
||||||
|
if not samesite_value:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"Cookie '{name}' missing SameSite attribute",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Cookie Security",
|
||||||
|
fix=f"Set SameSite=Lax or SameSite=Strict on cookie '{name}' to prevent CSRF attacks",
|
||||||
|
))
|
||||||
|
elif samesite_value == "none":
|
||||||
|
if "; secure" not in cookie_lower:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"Cookie '{name}' has SameSite=None without Secure flag",
|
||||||
|
severity="Critical",
|
||||||
|
layer="Cookie Security",
|
||||||
|
fix=f"Cookies with SameSite=None must also have the Secure flag set",
|
||||||
|
))
|
||||||
|
|
||||||
|
return issues
|
||||||
148
app/services/scanner/dns.py
Normal file
148
app/services/scanner/dns.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import aiodns
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.schemas.scan import Issue
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DNSScanner:
|
||||||
|
def __init__(self):
|
||||||
|
self.resolver = aiodns.DNSResolver(timeout=3.0)
|
||||||
|
|
||||||
|
async def scan(self, url: str) -> list[Issue]:
|
||||||
|
issues = []
|
||||||
|
domain = self._extract_domain(url)
|
||||||
|
if not domain:
|
||||||
|
return issues
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
self._check_spf(domain),
|
||||||
|
self._check_dmarc(domain),
|
||||||
|
self._enumerate_subdomains(domain),
|
||||||
|
]
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
if isinstance(result, list):
|
||||||
|
issues.extend(result)
|
||||||
|
elif isinstance(result, Exception):
|
||||||
|
logger.error(f"DNS scan error: {result}")
|
||||||
|
|
||||||
|
return issues
|
||||||
|
|
||||||
|
def _extract_domain(self, url: str) -> str | None:
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
domain = parsed.netloc.split(":")[0]
|
||||||
|
if domain.startswith("www."):
|
||||||
|
domain = domain[4:]
|
||||||
|
return domain
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _check_spf(self, domain: str) -> list[Issue]:
|
||||||
|
issues = []
|
||||||
|
try:
|
||||||
|
records = await self.resolver.query(domain, "TXT")
|
||||||
|
has_spf = any("v=spf1" in r.text for r in records if r.text)
|
||||||
|
if not has_spf:
|
||||||
|
issues.append(
|
||||||
|
Issue(
|
||||||
|
issue="Missing SPF Record",
|
||||||
|
severity="Medium",
|
||||||
|
layer="DNS",
|
||||||
|
fix="Add a TXT record with SPF rules (e.g., v=spf1 mx -all) to prevent email spoofing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except aiodns.error.DNSError as e:
|
||||||
|
# Code 4 usually means no record of that type, or Code 1 is domain not found
|
||||||
|
if e.args[0] in [1, 4]:
|
||||||
|
issues.append(
|
||||||
|
Issue(
|
||||||
|
issue="Missing SPF Record",
|
||||||
|
severity="Medium",
|
||||||
|
layer="DNS",
|
||||||
|
fix="Add a TXT record with SPF rules (e.g., v=spf1 mx -all) to prevent email spoofing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(f"SPF DNS error for {domain}: {e}")
|
||||||
|
return issues
|
||||||
|
|
||||||
|
async def _check_dmarc(self, domain: str) -> list[Issue]:
|
||||||
|
issues = []
|
||||||
|
dmarc_domain = f"_dmarc.{domain}"
|
||||||
|
try:
|
||||||
|
records = await self.resolver.query(dmarc_domain, "TXT")
|
||||||
|
has_dmarc = any("v=DMARC1" in r.text for r in records if r.text)
|
||||||
|
if not has_dmarc:
|
||||||
|
issues.append(
|
||||||
|
Issue(
|
||||||
|
issue="Missing DMARC Record",
|
||||||
|
severity="Low",
|
||||||
|
layer="DNS",
|
||||||
|
fix="Add a DMARC TXT record at _dmarc to policy control email spoofing failures.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except aiodns.error.DNSError as e:
|
||||||
|
if e.args[0] in [1, 4]:
|
||||||
|
issues.append(
|
||||||
|
Issue(
|
||||||
|
issue="Missing DMARC Record",
|
||||||
|
severity="Low",
|
||||||
|
layer="DNS",
|
||||||
|
fix="Add a DMARC TXT record at _dmarc to policy control email spoofing failures.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(f"DMARC DNS error for {domain}: {e}")
|
||||||
|
return issues
|
||||||
|
|
||||||
|
async def _enumerate_subdomains(self, domain: str) -> list[Issue]:
|
||||||
|
issues = []
|
||||||
|
# Query Certificate Transparency logs via crt.sh
|
||||||
|
url = f"https://crt.sh/?q=%.{domain}&output=json"
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
unique_subs = set()
|
||||||
|
|
||||||
|
# Extract subdomains
|
||||||
|
for entry in data:
|
||||||
|
name = entry.get("name_value", "")
|
||||||
|
# Handle multiple names separated by newlines
|
||||||
|
for sub in name.split("\n"):
|
||||||
|
sub = sub.strip()
|
||||||
|
if "*" not in sub and sub != domain and sub != f"www.{domain}":
|
||||||
|
unique_subs.add(sub)
|
||||||
|
|
||||||
|
# Look for risky subdomains
|
||||||
|
keywords = ["dev", "test", "staging", "qa", "admin", "internal", "api", "dashboard"]
|
||||||
|
dev_envs = [sub for sub in unique_subs if any(kw in sub.lower() for kw in keywords)]
|
||||||
|
|
||||||
|
if dev_envs:
|
||||||
|
env_str = ", ".join(list(dev_envs)[:3])
|
||||||
|
more = len(dev_envs) - 3
|
||||||
|
if more > 0:
|
||||||
|
env_str += f", and {more} more"
|
||||||
|
|
||||||
|
issues.append(
|
||||||
|
Issue(
|
||||||
|
issue="Exposed Subdomains Detected",
|
||||||
|
severity="Info",
|
||||||
|
layer="DNS",
|
||||||
|
fix=f"Subdomains such as {env_str} are exposed in CT logs. Ensure they are protected and not publicly accessible if they afford sensitive access.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Subdomain enumeration failed for {domain}: {str(e)}")
|
||||||
|
|
||||||
|
return issues
|
||||||
135
app/services/scanner/exposure.py
Normal file
135
app/services/scanner/exposure.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.schemas.scan import Issue
|
||||||
|
from app.services.scanner.base import BaseScanner
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SENSITIVE_PATHS = [
|
||||||
|
"/admin",
|
||||||
|
"/.env",
|
||||||
|
"/.git",
|
||||||
|
"/.git/config",
|
||||||
|
"/.git/HEAD",
|
||||||
|
"/backup",
|
||||||
|
"/debug",
|
||||||
|
"/wp-admin",
|
||||||
|
"/wp-login.php",
|
||||||
|
"/phpmyadmin",
|
||||||
|
"/.DS_Store",
|
||||||
|
"/server-status",
|
||||||
|
"/server-info",
|
||||||
|
"/swagger.json",
|
||||||
|
"/openapi.json",
|
||||||
|
"/api/docs",
|
||||||
|
"/.htaccess",
|
||||||
|
"/.htpasswd",
|
||||||
|
"/web.config",
|
||||||
|
"/elmah.axd",
|
||||||
|
"/trace.axd",
|
||||||
|
"/phpinfo.php",
|
||||||
|
"/config.php",
|
||||||
|
"/wp-config.php.bak",
|
||||||
|
"/.well-known/security.txt",
|
||||||
|
]
|
||||||
|
|
||||||
|
ROBOTS_SENSITIVE_KEYWORDS = [
|
||||||
|
"/admin",
|
||||||
|
"/login",
|
||||||
|
"/dashboard",
|
||||||
|
"/secret",
|
||||||
|
"/private",
|
||||||
|
"/backup",
|
||||||
|
"/config",
|
||||||
|
"/database",
|
||||||
|
"/staging",
|
||||||
|
"/internal",
|
||||||
|
"/api/v",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ExposureScanner(BaseScanner):
|
||||||
|
async def scan(self, url: str, response: httpx.Response) -> list[Issue]:
|
||||||
|
issues: list[Issue] = []
|
||||||
|
base_url = url.rstrip("/")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(settings.path_check_timeout),
|
||||||
|
follow_redirects=True,
|
||||||
|
) as client:
|
||||||
|
for path in SENSITIVE_PATHS:
|
||||||
|
try:
|
||||||
|
r = await client.get(base_url + path)
|
||||||
|
if r.status_code == 200:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"Sensitive path exposed: {path}",
|
||||||
|
severity="Critical",
|
||||||
|
layer="Exposure Layer",
|
||||||
|
fix=f"Restrict access to {path} using authentication or firewall rules",
|
||||||
|
))
|
||||||
|
except httpx.HTTPError:
|
||||||
|
logger.debug(f"Could not reach {base_url}{path}")
|
||||||
|
|
||||||
|
issues.extend(await self._check_robots_txt(client, base_url))
|
||||||
|
issues.extend(await self._check_directory_listing(client, base_url))
|
||||||
|
|
||||||
|
return issues
|
||||||
|
|
||||||
|
async def _check_robots_txt(
|
||||||
|
self, client: httpx.AsyncClient, base_url: str
|
||||||
|
) -> list[Issue]:
|
||||||
|
issues: list[Issue] = []
|
||||||
|
try:
|
||||||
|
r = await client.get(base_url + "/robots.txt")
|
||||||
|
if r.status_code == 200:
|
||||||
|
content = r.text.lower()
|
||||||
|
exposed_paths = []
|
||||||
|
for line in content.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if line.startswith("disallow:"):
|
||||||
|
path = line.split(":", 1)[1].strip()
|
||||||
|
if path:
|
||||||
|
for keyword in ROBOTS_SENSITIVE_KEYWORDS:
|
||||||
|
if keyword in path.lower():
|
||||||
|
exposed_paths.append(path)
|
||||||
|
break
|
||||||
|
|
||||||
|
if exposed_paths:
|
||||||
|
paths_str = ", ".join(exposed_paths[:5])
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"robots.txt reveals sensitive paths: {paths_str}",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Exposure Layer",
|
||||||
|
fix="Avoid listing sensitive paths in robots.txt; use authentication instead",
|
||||||
|
))
|
||||||
|
except httpx.HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return issues
|
||||||
|
|
||||||
|
async def _check_directory_listing(
|
||||||
|
self, client: httpx.AsyncClient, base_url: str
|
||||||
|
) -> list[Issue]:
|
||||||
|
issues: list[Issue] = []
|
||||||
|
test_paths = ["/images/", "/assets/", "/static/", "/uploads/"]
|
||||||
|
|
||||||
|
for path in test_paths:
|
||||||
|
try:
|
||||||
|
r = await client.get(base_url + path)
|
||||||
|
if r.status_code == 200:
|
||||||
|
body = r.text.lower()
|
||||||
|
if "index of" in body or "directory listing" in body or ("<pre>" in body and "parent directory" in body):
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"Directory listing enabled at {path}",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Exposure Layer",
|
||||||
|
fix=f"Disable directory listing for {path} in your web server configuration",
|
||||||
|
))
|
||||||
|
break
|
||||||
|
except httpx.HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return issues
|
||||||
140
app/services/scanner/headers.py
Normal file
140
app/services/scanner/headers.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.schemas.scan import Issue
|
||||||
|
from app.services.scanner.base import BaseScanner
|
||||||
|
|
||||||
|
|
||||||
|
class HeaderScanner(BaseScanner):
|
||||||
|
async def scan(self, url: str, response: httpx.Response) -> list[Issue]:
|
||||||
|
issues: list[Issue] = []
|
||||||
|
headers = response.headers
|
||||||
|
|
||||||
|
if "Content-Security-Policy" not in headers:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing Content-Security-Policy header",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add header: Content-Security-Policy: default-src 'self';",
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
csp = headers["Content-Security-Policy"]
|
||||||
|
if "'unsafe-inline'" in csp:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Content-Security-Policy allows 'unsafe-inline'",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Remove 'unsafe-inline' from CSP and use nonces or hashes for inline scripts/styles",
|
||||||
|
))
|
||||||
|
if "'unsafe-eval'" in csp:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Content-Security-Policy allows 'unsafe-eval'",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Remove 'unsafe-eval' from CSP to prevent dynamic code execution via eval()",
|
||||||
|
))
|
||||||
|
if "*" in csp.split():
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Content-Security-Policy uses wildcard (*) source",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Replace wildcard (*) in CSP with specific trusted domains",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "X-Frame-Options" not in headers:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing X-Frame-Options header",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add header: X-Frame-Options: SAMEORIGIN",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "X-Content-Type-Options" not in headers:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing X-Content-Type-Options header",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add header: X-Content-Type-Options: nosniff",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "Referrer-Policy" not in headers:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing Referrer-Policy header",
|
||||||
|
severity="Info",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add header: Referrer-Policy: strict-origin-when-cross-origin",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "Permissions-Policy" not in headers:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing Permissions-Policy header",
|
||||||
|
severity="Info",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add header: Permissions-Policy: geolocation=(), camera=(), microphone=()",
|
||||||
|
))
|
||||||
|
|
||||||
|
if headers.get("Access-Control-Allow-Origin") == "*":
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="CORS allows all origins (*)",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Restrict Access-Control-Allow-Origin to trusted domains",
|
||||||
|
))
|
||||||
|
|
||||||
|
server = headers.get("Server", "")
|
||||||
|
if server:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"Server header discloses technology: {server}",
|
||||||
|
severity="Info",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Remove or obfuscate the Server header to prevent information disclosure",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "X-Powered-By" in headers:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"X-Powered-By header discloses technology: {headers['X-Powered-By']}",
|
||||||
|
severity="Info",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Remove the X-Powered-By header to prevent information disclosure",
|
||||||
|
))
|
||||||
|
|
||||||
|
cache_control = headers.get("Cache-Control", "")
|
||||||
|
if not cache_control:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing Cache-Control header",
|
||||||
|
severity="Info",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add Cache-Control header with appropriate directives (e.g., no-store for sensitive pages)",
|
||||||
|
))
|
||||||
|
elif "no-store" not in cache_control.lower() and "private" not in cache_control.lower():
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Cache-Control does not prevent caching of potentially sensitive content",
|
||||||
|
severity="Info",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add 'no-store' or 'private' to Cache-Control for pages with sensitive data",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "Cross-Origin-Opener-Policy" not in headers:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing Cross-Origin-Opener-Policy (COOP) header",
|
||||||
|
severity="Info",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add header: Cross-Origin-Opener-Policy: same-origin",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "Cross-Origin-Resource-Policy" not in headers:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing Cross-Origin-Resource-Policy (CORP) header",
|
||||||
|
severity="Info",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add header: Cross-Origin-Resource-Policy: same-origin",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "Cross-Origin-Embedder-Policy" not in headers:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing Cross-Origin-Embedder-Policy (COEP) header",
|
||||||
|
severity="Info",
|
||||||
|
layer="Server Config Layer",
|
||||||
|
fix="Add header: Cross-Origin-Embedder-Policy: require-corp",
|
||||||
|
))
|
||||||
|
|
||||||
|
return issues
|
||||||
76
app/services/scanner/ports.py
Normal file
76
app/services/scanner/ports.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from app.schemas.scan import Issue
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# High-risk ports that generally shouldn't be publicly exposed
|
||||||
|
HIGH_RISK_PORTS = {
|
||||||
|
22: "SSH",
|
||||||
|
1433: "MSSQL",
|
||||||
|
3306: "MySQL",
|
||||||
|
5432: "PostgreSQL",
|
||||||
|
27017: "MongoDB",
|
||||||
|
6379: "Redis",
|
||||||
|
11211: "Memcached",
|
||||||
|
9200: "Elasticsearch",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PortScanner:
|
||||||
|
def __init__(self, timeout: float = 2.0):
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
async def scan(self, url: str) -> list[Issue]:
|
||||||
|
issues = []
|
||||||
|
domain = self._extract_domain(url)
|
||||||
|
if not domain:
|
||||||
|
return issues
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
self._check_port(domain, port, service)
|
||||||
|
for port, service in HIGH_RISK_PORTS.items()
|
||||||
|
]
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
if isinstance(result, Issue):
|
||||||
|
issues.append(result)
|
||||||
|
elif isinstance(result, Exception):
|
||||||
|
logger.debug(f"Port scanning exception: {result}")
|
||||||
|
|
||||||
|
return issues
|
||||||
|
|
||||||
|
def _extract_domain(self, url: str) -> str | None:
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
domain = parsed.netloc.split(":")[0]
|
||||||
|
if domain.startswith("www."):
|
||||||
|
domain = domain[4:]
|
||||||
|
return domain
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _check_port(self, domain: str, port: int, service: str) -> Issue | None:
|
||||||
|
try:
|
||||||
|
# Short timeout ensuring minimal scanning latency overhead
|
||||||
|
reader, writer = await asyncio.wait_for(
|
||||||
|
asyncio.open_connection(domain, port), timeout=self.timeout
|
||||||
|
)
|
||||||
|
writer.close()
|
||||||
|
await writer.wait_closed()
|
||||||
|
|
||||||
|
return Issue(
|
||||||
|
issue=f"Exposed Database/Service Port: {port} ({service})",
|
||||||
|
severity="Critical",
|
||||||
|
layer="Network",
|
||||||
|
fix=f"Close port {port} to the public internet. Use a VPN, VPC peering, or strict IP whitelisting to access {service}.",
|
||||||
|
)
|
||||||
|
except (asyncio.TimeoutError, ConnectionRefusedError, OSError):
|
||||||
|
# Normal: port is either closed, filtered, or timing out.
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Unexpected error validating port {port} on {domain}: {e}")
|
||||||
|
return None
|
||||||
136
app/services/scanner/ssl_checker.py
Normal file
136
app/services/scanner/ssl_checker.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
import asyncio
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
import ssl
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.schemas.scan import Issue
|
||||||
|
from app.services.scanner.base import BaseScanner
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
WEAK_TLS_VERSIONS = {"TLSv1", "TLSv1.1"}
|
||||||
|
|
||||||
|
|
||||||
|
def _check_ssl(hostname: str, port: int) -> dict:
|
||||||
|
result: dict = {
|
||||||
|
"error": None,
|
||||||
|
"cert": None,
|
||||||
|
"tls_version": None,
|
||||||
|
"self_signed": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
context = ssl.create_default_context()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with socket.create_connection((hostname, port), timeout=5) as sock:
|
||||||
|
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
|
||||||
|
result["cert"] = ssock.getpeercert()
|
||||||
|
result["tls_version"] = ssock.version()
|
||||||
|
except ssl.SSLCertVerificationError as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
result["error"] = error_msg
|
||||||
|
|
||||||
|
if "self-signed" in error_msg.lower() or "self signed" in error_msg.lower():
|
||||||
|
result["self_signed"] = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
ctx_no_verify = ssl.create_default_context()
|
||||||
|
ctx_no_verify.check_hostname = False
|
||||||
|
ctx_no_verify.verify_mode = ssl.CERT_NONE
|
||||||
|
with socket.create_connection((hostname, port), timeout=5) as sock:
|
||||||
|
with ctx_no_verify.wrap_socket(sock, server_hostname=hostname) as ssock:
|
||||||
|
result["tls_version"] = ssock.version()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except (socket.timeout, socket.gaierror, OSError) as e:
|
||||||
|
result["error"] = str(e)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class SSLScanner(BaseScanner):
|
||||||
|
async def scan(self, url: str, response: httpx.Response) -> list[Issue]:
|
||||||
|
issues: list[Issue] = []
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
if parsed.scheme != "https":
|
||||||
|
return issues
|
||||||
|
|
||||||
|
hostname = parsed.hostname
|
||||||
|
port = parsed.port or 443
|
||||||
|
|
||||||
|
if not hostname:
|
||||||
|
return issues
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await asyncio.to_thread(_check_ssl, hostname, port)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"SSL check failed for {hostname}: {e}")
|
||||||
|
return issues
|
||||||
|
|
||||||
|
if result["self_signed"]:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="SSL certificate is self-signed",
|
||||||
|
severity="Critical",
|
||||||
|
layer="SSL/TLS Layer",
|
||||||
|
fix="Obtain a valid SSL certificate from a trusted Certificate Authority (e.g., Let's Encrypt)",
|
||||||
|
))
|
||||||
|
|
||||||
|
if result["error"] and not result["self_signed"]:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"SSL certificate verification failed: {result['error'][:120]}",
|
||||||
|
severity="Critical",
|
||||||
|
layer="SSL/TLS Layer",
|
||||||
|
fix="Ensure the SSL certificate is valid, not expired, and issued by a trusted CA",
|
||||||
|
))
|
||||||
|
|
||||||
|
cert = result.get("cert")
|
||||||
|
if cert:
|
||||||
|
not_after = cert.get("notAfter")
|
||||||
|
if not_after:
|
||||||
|
try:
|
||||||
|
expiry = datetime.datetime.strptime(not_after, "%b %d %H:%M:%S %Y %Z")
|
||||||
|
now = datetime.datetime.utcnow()
|
||||||
|
|
||||||
|
if expiry < now:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="SSL certificate has expired",
|
||||||
|
severity="Critical",
|
||||||
|
layer="SSL/TLS Layer",
|
||||||
|
fix="Renew the SSL certificate immediately",
|
||||||
|
))
|
||||||
|
elif (expiry - now).days < 30:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"SSL certificate expires in {(expiry - now).days} days",
|
||||||
|
severity="Warning",
|
||||||
|
layer="SSL/TLS Layer",
|
||||||
|
fix="Renew the SSL certificate before it expires",
|
||||||
|
))
|
||||||
|
except ValueError:
|
||||||
|
logger.debug(f"Could not parse cert expiry: {not_after}")
|
||||||
|
|
||||||
|
subject = cert.get("subject", ())
|
||||||
|
issuer = cert.get("issuer", ())
|
||||||
|
if subject and issuer and subject == issuer:
|
||||||
|
if not result["self_signed"]:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="SSL certificate is self-signed",
|
||||||
|
severity="Critical",
|
||||||
|
layer="SSL/TLS Layer",
|
||||||
|
fix="Obtain a valid SSL certificate from a trusted Certificate Authority",
|
||||||
|
))
|
||||||
|
|
||||||
|
tls_version = result.get("tls_version")
|
||||||
|
if tls_version and tls_version in WEAK_TLS_VERSIONS:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"Server supports weak TLS version: {tls_version}",
|
||||||
|
severity="Critical",
|
||||||
|
layer="SSL/TLS Layer",
|
||||||
|
fix="Disable TLS 1.0 and TLS 1.1; enforce TLS 1.2 or higher",
|
||||||
|
))
|
||||||
|
|
||||||
|
return issues
|
||||||
76
app/services/scanner/transport.py
Normal file
76
app/services/scanner/transport.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.schemas.scan import Issue
|
||||||
|
from app.services.scanner.base import BaseScanner
|
||||||
|
|
||||||
|
MIN_HSTS_MAX_AGE = 15768000 # 6 months in seconds
|
||||||
|
|
||||||
|
|
||||||
|
class TransportScanner(BaseScanner):
|
||||||
|
async def scan(self, url: str, response: httpx.Response) -> list[Issue]:
|
||||||
|
issues: list[Issue] = []
|
||||||
|
headers = response.headers
|
||||||
|
|
||||||
|
if not url.startswith("https"):
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Website is not using HTTPS",
|
||||||
|
severity="Critical",
|
||||||
|
layer="Transport Layer",
|
||||||
|
fix="Install SSL certificate and redirect HTTP to HTTPS",
|
||||||
|
))
|
||||||
|
return issues
|
||||||
|
|
||||||
|
hsts = headers.get("Strict-Transport-Security", "")
|
||||||
|
if not hsts:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="Missing HSTS header",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Transport Layer",
|
||||||
|
fix="Add header: Strict-Transport-Security: max-age=31536000; includeSubDomains; preload",
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
hsts_lower = hsts.lower()
|
||||||
|
|
||||||
|
max_age = 0
|
||||||
|
for directive in hsts_lower.split(";"):
|
||||||
|
directive = directive.strip()
|
||||||
|
if directive.startswith("max-age="):
|
||||||
|
try:
|
||||||
|
max_age = int(directive.split("=", 1)[1])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if max_age < MIN_HSTS_MAX_AGE:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue=f"HSTS max-age is too short ({max_age}s, minimum recommended: {MIN_HSTS_MAX_AGE}s)",
|
||||||
|
severity="Warning",
|
||||||
|
layer="Transport Layer",
|
||||||
|
fix="Set HSTS max-age to at least 15768000 (6 months), ideally 31536000 (1 year)",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "includesubdomains" not in hsts_lower:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="HSTS header missing includeSubDomains directive",
|
||||||
|
severity="Info",
|
||||||
|
layer="Transport Layer",
|
||||||
|
fix="Add includeSubDomains to HSTS header to protect all subdomains",
|
||||||
|
))
|
||||||
|
|
||||||
|
if "preload" not in hsts_lower:
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="HSTS header missing preload directive",
|
||||||
|
severity="Info",
|
||||||
|
layer="Transport Layer",
|
||||||
|
fix="Add preload to HSTS header and submit to hstspreload.org for browser preload list",
|
||||||
|
))
|
||||||
|
|
||||||
|
csp = headers.get("Content-Security-Policy", "")
|
||||||
|
if url.startswith("https") and "upgrade-insecure-requests" not in csp.lower():
|
||||||
|
issues.append(Issue(
|
||||||
|
issue="CSP does not include upgrade-insecure-requests directive",
|
||||||
|
severity="Info",
|
||||||
|
layer="Transport Layer",
|
||||||
|
fix="Add 'upgrade-insecure-requests' to Content-Security-Policy to auto-upgrade HTTP resources",
|
||||||
|
))
|
||||||
|
|
||||||
|
return issues
|
||||||
42
app/services/scoring.py
Normal file
42
app/services/scoring.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from app.schemas.scan import Issue, LayerStatus
|
||||||
|
|
||||||
|
SEVERITY_WEIGHTS: dict[str, int] = {
|
||||||
|
"Critical": 15,
|
||||||
|
"Warning": 5,
|
||||||
|
"Info": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
LAYER_NAMES = [
|
||||||
|
"Transport Layer",
|
||||||
|
"SSL/TLS Layer",
|
||||||
|
"Server Config Layer",
|
||||||
|
"Cookie Security",
|
||||||
|
"Exposure Layer",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_score(issues: list[Issue]) -> int:
|
||||||
|
score = 100
|
||||||
|
for issue in issues:
|
||||||
|
score -= SEVERITY_WEIGHTS.get(issue.severity, 0)
|
||||||
|
return max(score, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_layer_statuses(issues: list[Issue]) -> dict[str, LayerStatus]:
|
||||||
|
layers: dict[str, LayerStatus] = {
|
||||||
|
name: LayerStatus(issues=0, status="green") for name in LAYER_NAMES
|
||||||
|
}
|
||||||
|
|
||||||
|
for issue in issues:
|
||||||
|
if issue.layer in layers:
|
||||||
|
layers[issue.layer].issues += 1
|
||||||
|
|
||||||
|
for layer in layers.values():
|
||||||
|
if layer.issues == 0:
|
||||||
|
layer.status = "green"
|
||||||
|
elif layer.issues < 3:
|
||||||
|
layer.status = "yellow"
|
||||||
|
else:
|
||||||
|
layer.status = "red"
|
||||||
|
|
||||||
|
return layers
|
||||||
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
30
app/utils/auth.py
Normal file
30
app/utils/auth.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(user_id: str) -> str:
|
||||||
|
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.jwt_expiry_minutes)
|
||||||
|
payload = {"sub": user_id, "exp": expire}
|
||||||
|
return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_access_token(token: str) -> str | None:
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
|
||||||
|
return payload.get("sub")
|
||||||
|
except JWTError:
|
||||||
|
return None
|
||||||
46
app/utils/validators.py
Normal file
46
app/utils/validators.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import ipaddress
|
||||||
|
import socket
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
PRIVATE_NETWORKS = [
|
||||||
|
ipaddress.ip_network("10.0.0.0/8"),
|
||||||
|
ipaddress.ip_network("172.16.0.0/12"),
|
||||||
|
ipaddress.ip_network("192.168.0.0/16"),
|
||||||
|
ipaddress.ip_network("127.0.0.0/8"),
|
||||||
|
ipaddress.ip_network("169.254.0.0/16"),
|
||||||
|
ipaddress.ip_network("0.0.0.0/8"),
|
||||||
|
ipaddress.ip_network("::1/128"),
|
||||||
|
ipaddress.ip_network("fc00::/7"),
|
||||||
|
ipaddress.ip_network("fe80::/10"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_url(url: str) -> str:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
if parsed.scheme not in ("http", "https"):
|
||||||
|
raise HTTPException(status_code=400, detail="URL must use http or https scheme")
|
||||||
|
|
||||||
|
hostname = parsed.hostname
|
||||||
|
if not hostname:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid URL: no hostname found")
|
||||||
|
|
||||||
|
blocked_hostnames = {"localhost", "0.0.0.0"}
|
||||||
|
if hostname in blocked_hostnames:
|
||||||
|
raise HTTPException(status_code=400, detail="Scanning internal addresses is not allowed")
|
||||||
|
|
||||||
|
try:
|
||||||
|
resolved_ip = socket.gethostbyname(hostname)
|
||||||
|
ip = ipaddress.ip_address(resolved_ip)
|
||||||
|
for network in PRIVATE_NETWORKS:
|
||||||
|
if ip in network:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Scanning internal/private IP addresses is not allowed",
|
||||||
|
)
|
||||||
|
except socket.gaierror:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Could not resolve hostname: {hostname}")
|
||||||
|
|
||||||
|
return url
|
||||||
46
ci/securelens-scan.yml
Normal file
46
ci/securelens-scan.yml
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
name: SecureLens CI/CD Scan
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
security-scan:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout Code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
# Example: Wait for deployment/staging URL to be available
|
||||||
|
- name: Run SecureLens Scan
|
||||||
|
env:
|
||||||
|
SECURELENS_API_URL: "https://your-securelens-instance.com"
|
||||||
|
SECURELENS_API_KEY: ${{ secrets.SECURELENS_API_KEY }}
|
||||||
|
TARGET_URL: "https://staging.your-app.com"
|
||||||
|
MINIMUM_SCORE: 80
|
||||||
|
run: |
|
||||||
|
echo "Initiating SecureLens Scan against $TARGET_URL"
|
||||||
|
|
||||||
|
# Trigger Scan
|
||||||
|
RESPONSE=$(curl -s -X POST "$SECURELENS_API_URL/scans/scan" \
|
||||||
|
-H "X-API-Key: $SECURELENS_API_KEY" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d "{\"url\": \"$TARGET_URL\"}")
|
||||||
|
|
||||||
|
# Extract score using jq
|
||||||
|
SCORE=$(echo $RESPONSE | jq -r '.security_score')
|
||||||
|
SCAN_ID=$(echo $RESPONSE | jq -r '.id')
|
||||||
|
|
||||||
|
echo "Scan completed (ID: $SCAN_ID)"
|
||||||
|
echo "Security Score: $SCORE"
|
||||||
|
|
||||||
|
# Check Threshold
|
||||||
|
if (( $(echo "$SCORE < $MINIMUM_SCORE" | bc -l) )); then
|
||||||
|
echo "::error::Security score ($SCORE) is below the minimum threshold ($MINIMUM_SCORE)"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "Security check passed!"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
34
docker-compose.yml
Normal file
34
docker-compose.yml
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
services:
|
||||||
|
db:
|
||||||
|
image: postgres:16-alpine
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: securelens
|
||||||
|
POSTGRES_PASSWORD: securelens
|
||||||
|
POSTGRES_DB: securelens
|
||||||
|
ports:
|
||||||
|
- "5433:5432"
|
||||||
|
volumes:
|
||||||
|
- pgdata:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U securelens"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
backend:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://securelens:securelens@db:5432/securelens
|
||||||
|
volumes:
|
||||||
|
- .:/app
|
||||||
|
command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
pgdata:
|
||||||
131
main.py
131
main.py
@@ -1,130 +1,3 @@
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from app.main import app
|
||||||
from fastapi import FastAPI
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import requests
|
|
||||||
|
|
||||||
app = FastAPI()
|
__all__ = ["app"]
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
class ScanRequest(BaseModel):
|
|
||||||
url: str
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
def read_root():
|
|
||||||
return {"message": "SecureLens AI backend running 🚀"}
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/scan")
|
|
||||||
def scan_website(data: ScanRequest):
|
|
||||||
url = data.url
|
|
||||||
issues = []
|
|
||||||
score = 100
|
|
||||||
|
|
||||||
layers = {
|
|
||||||
"Transport Layer": {"issues": 0, "status": "green"},
|
|
||||||
"Server Config Layer": {"issues": 0, "status": "green"},
|
|
||||||
"Exposure Layer": {"issues": 0, "status": "green"}
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.get(url, timeout=5)
|
|
||||||
headers = response.headers
|
|
||||||
|
|
||||||
# Transport Layer
|
|
||||||
if not url.startswith("https"):
|
|
||||||
issues.append({
|
|
||||||
"issue": "Website is not using HTTPS",
|
|
||||||
"severity": "Critical",
|
|
||||||
"layer": "Transport Layer",
|
|
||||||
"fix": "Install SSL certificate and redirect HTTP to HTTPS"
|
|
||||||
})
|
|
||||||
score -= 15
|
|
||||||
layers["Transport Layer"]["issues"] += 1
|
|
||||||
|
|
||||||
# Server Config
|
|
||||||
if "Content-Security-Policy" not in headers:
|
|
||||||
issues.append({
|
|
||||||
"issue": "Missing Content-Security-Policy header",
|
|
||||||
"severity": "Warning",
|
|
||||||
"layer": "Server Config Layer",
|
|
||||||
"fix": "Add header: Content-Security-Policy: default-src 'self';"
|
|
||||||
})
|
|
||||||
score -= 5
|
|
||||||
layers["Server Config Layer"]["issues"] += 1
|
|
||||||
|
|
||||||
if "X-Frame-Options" not in headers:
|
|
||||||
issues.append({
|
|
||||||
"issue": "Missing X-Frame-Options header",
|
|
||||||
"severity": "Warning",
|
|
||||||
"layer": "Server Config Layer",
|
|
||||||
"fix": "Add header: X-Frame-Options: SAMEORIGIN"
|
|
||||||
})
|
|
||||||
score -= 5
|
|
||||||
layers["Server Config Layer"]["issues"] += 1
|
|
||||||
|
|
||||||
if "Strict-Transport-Security" not in headers:
|
|
||||||
issues.append({
|
|
||||||
"issue": "Missing HSTS header",
|
|
||||||
"severity": "Warning",
|
|
||||||
"layer": "Server Config Layer",
|
|
||||||
"fix": "Add header: Strict-Transport-Security: max-age=31536000; includeSubDomains"
|
|
||||||
})
|
|
||||||
score -= 5
|
|
||||||
layers["Server Config Layer"]["issues"] += 1
|
|
||||||
|
|
||||||
if headers.get("Access-Control-Allow-Origin") == "*":
|
|
||||||
issues.append({
|
|
||||||
"issue": "CORS allows all origins (*)",
|
|
||||||
"severity": "Warning",
|
|
||||||
"layer": "Server Config Layer",
|
|
||||||
"fix": "Restrict Access-Control-Allow-Origin to trusted domains"
|
|
||||||
})
|
|
||||||
score -= 5
|
|
||||||
layers["Server Config Layer"]["issues"] += 1
|
|
||||||
|
|
||||||
# Exposure
|
|
||||||
sensitive_paths = ["/admin", "/.env", "/backup", "/debug"]
|
|
||||||
|
|
||||||
for path in sensitive_paths:
|
|
||||||
try:
|
|
||||||
test_url = url.rstrip("/") + path
|
|
||||||
r = requests.get(test_url, timeout=3)
|
|
||||||
if r.status_code == 200:
|
|
||||||
issues.append({
|
|
||||||
"issue": f"Sensitive path exposed: {path}",
|
|
||||||
"severity": "Critical",
|
|
||||||
"layer": "Exposure Layer",
|
|
||||||
"fix": f"Restrict access to {path} using authentication or firewall rules"
|
|
||||||
})
|
|
||||||
score -= 15
|
|
||||||
layers["Exposure Layer"]["issues"] += 1
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
# Set layer status
|
|
||||||
for layer in layers:
|
|
||||||
count = layers[layer]["issues"]
|
|
||||||
if count == 0:
|
|
||||||
layers[layer]["status"] = "green"
|
|
||||||
elif count < 3:
|
|
||||||
layers[layer]["status"] = "yellow"
|
|
||||||
else:
|
|
||||||
layers[layer]["status"] = "red"
|
|
||||||
|
|
||||||
return {
|
|
||||||
"url": url,
|
|
||||||
"security_score": max(score, 0),
|
|
||||||
"layers": layers,
|
|
||||||
"issues": issues
|
|
||||||
}
|
|
||||||
|
|||||||
1
migrations/README
Normal file
1
migrations/README
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Generic single-database configuration with an async dbapi.
|
||||||
96
migrations/env.py
Normal file
96
migrations/env.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import asyncio
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import pool
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.database import Base
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.scan import ScanResult
|
||||||
|
from app.models.apikey import ApiKey
|
||||||
|
from app.models.webhook import Webhook
|
||||||
|
|
||||||
|
# this is the Alembic Config object, which provides
|
||||||
|
# access to the values within the .ini file in use.
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
config.set_main_option("sqlalchemy.url", settings.database_url)
|
||||||
|
|
||||||
|
# Interpret the config file for Python logging.
|
||||||
|
# This line sets up loggers basically.
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# add your model's MetaData object here
|
||||||
|
# for 'autogenerate' support
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
# other values from the config, defined by the needs of env.py,
|
||||||
|
# can be acquired:
|
||||||
|
# my_important_option = config.get_main_option("my_important_option")
|
||||||
|
# ... etc.
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
This configures the context with just a URL
|
||||||
|
and not an Engine, though an Engine is acceptable
|
||||||
|
here as well. By skipping the Engine creation
|
||||||
|
we don't even need a DBAPI to be available.
|
||||||
|
|
||||||
|
Calls to context.execute() here emit the given string to the
|
||||||
|
script output.
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection: Connection) -> None:
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_async_migrations() -> None:
|
||||||
|
"""In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
connectable = async_engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode."""
|
||||||
|
|
||||||
|
asyncio.run(run_async_migrations())
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
28
migrations/script.py.mako
Normal file
28
migrations/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
57
migrations/versions/a2ca840d767c_initial_migration.py
Normal file
57
migrations/versions/a2ca840d767c_initial_migration.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Initial migration
|
||||||
|
|
||||||
|
Revision ID: a2ca840d767c
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-24 18:29:43.533353
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'a2ca840d767c'
|
||||||
|
down_revision: Union[str, Sequence[str], None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('users',
|
||||||
|
sa.Column('id', sa.String(length=36), nullable=False),
|
||||||
|
sa.Column('email', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('username', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('hashed_password', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||||
|
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
|
||||||
|
op.create_table('scan_results',
|
||||||
|
sa.Column('id', sa.String(length=36), nullable=False),
|
||||||
|
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||||
|
sa.Column('url', sa.String(length=2048), nullable=False),
|
||||||
|
sa.Column('security_score', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('layers', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('issues', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_scan_results_user_id'), 'scan_results', ['user_id'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_scan_results_user_id'), table_name='scan_results')
|
||||||
|
op.drop_table('scan_results')
|
||||||
|
op.drop_index(op.f('ix_users_username'), table_name='users')
|
||||||
|
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||||
|
op.drop_table('users')
|
||||||
|
# ### end Alembic commands ###
|
||||||
45
migrations/versions/a8253e561192_add_api_key_model.py
Normal file
45
migrations/versions/a8253e561192_add_api_key_model.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""Add API Key model
|
||||||
|
|
||||||
|
Revision ID: a8253e561192
|
||||||
|
Revises: a2ca840d767c
|
||||||
|
Create Date: 2026-03-24 18:31:38.229135
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'a8253e561192'
|
||||||
|
down_revision: Union[str, Sequence[str], None] = 'a2ca840d767c'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('api_keys',
|
||||||
|
sa.Column('id', sa.String(length=36), nullable=False),
|
||||||
|
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('key_prefix', sa.String(length=10), nullable=False),
|
||||||
|
sa.Column('hashed_key', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_api_keys_hashed_key'), 'api_keys', ['hashed_key'], unique=True)
|
||||||
|
op.create_index(op.f('ix_api_keys_user_id'), 'api_keys', ['user_id'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_api_keys_user_id'), table_name='api_keys')
|
||||||
|
op.drop_index(op.f('ix_api_keys_hashed_key'), table_name='api_keys')
|
||||||
|
op.drop_table('api_keys')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -1,4 +1,19 @@
|
|||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
requests
|
httpx
|
||||||
pydantic
|
pydantic
|
||||||
|
pydantic-settings
|
||||||
|
python-dotenv
|
||||||
|
slowapi
|
||||||
|
sqlalchemy[asyncio]
|
||||||
|
aiosqlite
|
||||||
|
asyncpg
|
||||||
|
python-jose[cryptography]
|
||||||
|
passlib[bcrypt]
|
||||||
|
pydantic[email]
|
||||||
|
pytest
|
||||||
|
pytest-asyncio
|
||||||
|
alembic
|
||||||
|
openai
|
||||||
|
aiodns
|
||||||
|
fpdf2
|
||||||
|
|||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
70
tests/conftest.py
Normal file
70
tests/conftest.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from app.database import Base, get_db
|
||||||
|
from app.main import app
|
||||||
|
from app.models.user import User
|
||||||
|
from app.utils.auth import create_access_token, hash_password
|
||||||
|
|
||||||
|
TEST_DB_URL = "sqlite+aiosqlite://"
|
||||||
|
|
||||||
|
test_engine = create_async_engine(TEST_DB_URL, echo=False)
|
||||||
|
TestSessionLocal = async_sessionmaker(
|
||||||
|
bind=test_engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def override_get_db():
|
||||||
|
async with TestSessionLocal() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def setup_db():
|
||||||
|
async with test_engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield
|
||||||
|
async with test_engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def async_client():
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||||
|
yield ac
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def test_user():
|
||||||
|
async with TestSessionLocal() as session:
|
||||||
|
user = User(
|
||||||
|
email="test@example.com",
|
||||||
|
username="testuser",
|
||||||
|
hashed_password=hash_password("testpassword123"),
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def auth_headers(test_user):
|
||||||
|
token = create_access_token(test_user.id)
|
||||||
|
return {"Authorization": f"Bearer {token}"}
|
||||||
106
tests/test_auth.py
Normal file
106
tests/test_auth.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register(async_client):
|
||||||
|
response = await async_client.post("/auth/register", json={
|
||||||
|
"email": "new@example.com",
|
||||||
|
"username": "newuser",
|
||||||
|
"password": "securepass123",
|
||||||
|
})
|
||||||
|
assert response.status_code == 201
|
||||||
|
data = response.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_duplicate_email(async_client, test_user):
|
||||||
|
response = await async_client.post("/auth/register", json={
|
||||||
|
"email": "test@example.com",
|
||||||
|
"username": "different",
|
||||||
|
"password": "securepass123",
|
||||||
|
})
|
||||||
|
assert response.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_duplicate_username(async_client, test_user):
|
||||||
|
response = await async_client.post("/auth/register", json={
|
||||||
|
"email": "different@example.com",
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "securepass123",
|
||||||
|
})
|
||||||
|
assert response.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_short_password(async_client):
|
||||||
|
response = await async_client.post("/auth/register", json={
|
||||||
|
"email": "new@example.com",
|
||||||
|
"username": "newuser",
|
||||||
|
"password": "short",
|
||||||
|
})
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_invalid_email(async_client):
|
||||||
|
response = await async_client.post("/auth/register", json={
|
||||||
|
"email": "not-an-email",
|
||||||
|
"username": "newuser",
|
||||||
|
"password": "securepass123",
|
||||||
|
})
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login(async_client, test_user):
|
||||||
|
response = await async_client.post("/auth/login", json={
|
||||||
|
"email": "test@example.com",
|
||||||
|
"password": "testpassword123",
|
||||||
|
})
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_wrong_password(async_client, test_user):
|
||||||
|
response = await async_client.post("/auth/login", json={
|
||||||
|
"email": "test@example.com",
|
||||||
|
"password": "wrongpassword",
|
||||||
|
})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_nonexistent_email(async_client):
|
||||||
|
response = await async_client.post("/auth/login", json={
|
||||||
|
"email": "nobody@example.com",
|
||||||
|
"password": "testpassword123",
|
||||||
|
})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me(async_client, test_user, auth_headers):
|
||||||
|
response = await async_client.get("/auth/me", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["email"] == "test@example.com"
|
||||||
|
assert data["username"] == "testuser"
|
||||||
|
assert "id" in data
|
||||||
|
assert "created_at" in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_unauthorized(async_client):
|
||||||
|
response = await async_client.get("/auth/me")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_invalid_token(async_client):
|
||||||
|
response = await async_client.get("/auth/me", headers={"Authorization": "Bearer invalid"})
|
||||||
|
assert response.status_code == 401
|
||||||
65
tests/test_cookies.py
Normal file
65
tests/test_cookies.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.scanner.cookies import CookieScanner
|
||||||
|
|
||||||
|
scanner = CookieScanner()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_response(set_cookie_headers: list[str]) -> MagicMock:
|
||||||
|
items = [("content-type", "text/html")]
|
||||||
|
for cookie in set_cookie_headers:
|
||||||
|
items.append(("set-cookie", cookie))
|
||||||
|
response = MagicMock()
|
||||||
|
response.headers.multi_items.return_value = items
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_cookies_returns_empty():
|
||||||
|
response = _make_response([])
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert issues == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_missing_httponly():
|
||||||
|
response = _make_response(["session=abc123; Path=/; Secure; SameSite=Lax"])
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("HttpOnly" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_missing_secure():
|
||||||
|
response = _make_response(["session=abc123; Path=/; HttpOnly; SameSite=Lax"])
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("Secure" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_missing_samesite():
|
||||||
|
response = _make_response(["session=abc123; Path=/; HttpOnly; Secure"])
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("SameSite" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_samesite_none_without_secure():
|
||||||
|
response = _make_response(["session=abc123; Path=/; HttpOnly; SameSite=None"])
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("SameSite=None" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_secure_cookie_passes():
|
||||||
|
response = _make_response(["session=abc123; Path=/; HttpOnly; Secure; SameSite=Lax"])
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert len(issues) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_secure_check_for_http():
|
||||||
|
response = _make_response(["session=abc123; Path=/; HttpOnly; SameSite=Lax"])
|
||||||
|
issues = await scanner.scan("http://example.com", response)
|
||||||
|
assert not any("Secure flag" in i.issue for i in issues)
|
||||||
94
tests/test_headers.py
Normal file
94
tests/test_headers.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.scanner.headers import HeaderScanner
|
||||||
|
|
||||||
|
scanner = HeaderScanner()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_response(headers: dict) -> MagicMock:
|
||||||
|
response = MagicMock()
|
||||||
|
response.headers = headers
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_all_missing_headers():
|
||||||
|
response = _make_response({})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
issue_texts = [i.issue for i in issues]
|
||||||
|
assert any("Content-Security-Policy" in t for t in issue_texts)
|
||||||
|
assert any("X-Frame-Options" in t for t in issue_texts)
|
||||||
|
assert any("X-Content-Type-Options" in t for t in issue_texts)
|
||||||
|
assert any("Referrer-Policy" in t for t in issue_texts)
|
||||||
|
assert any("Permissions-Policy" in t for t in issue_texts)
|
||||||
|
assert any("Cache-Control" in t for t in issue_texts)
|
||||||
|
assert any("COOP" in t for t in issue_texts)
|
||||||
|
assert any("CORP" in t for t in issue_texts)
|
||||||
|
assert any("COEP" in t for t in issue_texts)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_unsafe_inline_csp():
|
||||||
|
response = _make_response({
|
||||||
|
"Content-Security-Policy": "default-src 'self' 'unsafe-inline'",
|
||||||
|
"X-Frame-Options": "DENY",
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
"Referrer-Policy": "strict-origin",
|
||||||
|
"Permissions-Policy": "camera=()",
|
||||||
|
"Cache-Control": "no-store",
|
||||||
|
"Cross-Origin-Opener-Policy": "same-origin",
|
||||||
|
"Cross-Origin-Resource-Policy": "same-origin",
|
||||||
|
"Cross-Origin-Embedder-Policy": "require-corp",
|
||||||
|
})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("unsafe-inline" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_unsafe_eval_csp():
|
||||||
|
response = _make_response({
|
||||||
|
"Content-Security-Policy": "default-src 'self' 'unsafe-eval'",
|
||||||
|
"X-Frame-Options": "DENY",
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
"Referrer-Policy": "strict-origin",
|
||||||
|
"Permissions-Policy": "camera=()",
|
||||||
|
"Cache-Control": "no-store",
|
||||||
|
"Cross-Origin-Opener-Policy": "same-origin",
|
||||||
|
"Cross-Origin-Resource-Policy": "same-origin",
|
||||||
|
"Cross-Origin-Embedder-Policy": "require-corp",
|
||||||
|
})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("unsafe-eval" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_server_disclosure():
|
||||||
|
response = _make_response({"Server": "Apache/2.4.41"})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("Server header" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_x_powered_by():
|
||||||
|
response = _make_response({"X-Powered-By": "Express"})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("X-Powered-By" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_issues_with_all_headers():
|
||||||
|
response = _make_response({
|
||||||
|
"Content-Security-Policy": "default-src 'self'",
|
||||||
|
"X-Frame-Options": "SAMEORIGIN",
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||||
|
"Permissions-Policy": "geolocation=()",
|
||||||
|
"Cache-Control": "no-store",
|
||||||
|
"Cross-Origin-Opener-Policy": "same-origin",
|
||||||
|
"Cross-Origin-Resource-Policy": "same-origin",
|
||||||
|
"Cross-Origin-Embedder-Policy": "require-corp",
|
||||||
|
})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert len(issues) == 0
|
||||||
19
tests/test_health.py
Normal file
19
tests/test_health.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_root(async_client):
|
||||||
|
response = await async_client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "running" in data["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health(async_client):
|
||||||
|
response = await async_client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "healthy"
|
||||||
|
assert "app" in data
|
||||||
|
assert "version" in data
|
||||||
90
tests/test_history.py
Normal file
90
tests/test_history.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import pytest
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.scan import ScanResult
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_scan(user_id: str) -> None:
|
||||||
|
from tests.conftest import TestSessionLocal
|
||||||
|
async with TestSessionLocal() as session:
|
||||||
|
scan = ScanResult(
|
||||||
|
user_id=user_id,
|
||||||
|
url="https://example.com",
|
||||||
|
security_score=85,
|
||||||
|
layers={"Transport Layer": {"issues": 1, "status": "yellow"}},
|
||||||
|
issues=[{"issue": "Missing HSTS", "severity": "Warning", "layer": "Transport Layer", "fix": "Add HSTS"}],
|
||||||
|
)
|
||||||
|
session.add(scan)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(scan)
|
||||||
|
return scan
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_scans_empty(async_client, test_user, auth_headers):
|
||||||
|
response = await async_client.get("/scans", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["scans"] == []
|
||||||
|
assert data["total"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_scans_with_results(async_client, test_user, auth_headers):
|
||||||
|
scan = await _create_scan(test_user.id)
|
||||||
|
|
||||||
|
response = await async_client.get("/scans", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 1
|
||||||
|
assert len(data["scans"]) == 1
|
||||||
|
assert data["scans"][0]["url"] == "https://example.com"
|
||||||
|
assert data["scans"][0]["security_score"] == 85
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_scans_pagination(async_client, test_user, auth_headers):
|
||||||
|
for _ in range(5):
|
||||||
|
await _create_scan(test_user.id)
|
||||||
|
|
||||||
|
response = await async_client.get("/scans?page=1&per_page=2", headers=auth_headers)
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 5
|
||||||
|
assert len(data["scans"]) == 2
|
||||||
|
assert data["page"] == 1
|
||||||
|
assert data["per_page"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_scan_by_id(async_client, test_user, auth_headers):
|
||||||
|
scan = await _create_scan(test_user.id)
|
||||||
|
|
||||||
|
response = await async_client.get(f"/scans/{scan.id}", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["url"] == "https://example.com"
|
||||||
|
assert data["security_score"] == 85
|
||||||
|
assert len(data["issues"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_scan_not_found(async_client, test_user, auth_headers):
|
||||||
|
response = await async_client.get("/scans/nonexistent", headers=auth_headers)
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_scan(async_client, test_user, auth_headers):
|
||||||
|
scan = await _create_scan(test_user.id)
|
||||||
|
|
||||||
|
response = await async_client.delete(f"/scans/{scan.id}", headers=auth_headers)
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
response = await async_client.get(f"/scans/{scan.id}", headers=auth_headers)
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_scans_unauthorized(async_client):
|
||||||
|
response = await async_client.get("/scans")
|
||||||
|
assert response.status_code == 401
|
||||||
54
tests/test_scan.py
Normal file
54
tests/test_scan.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_rejects_invalid_url(async_client):
|
||||||
|
response = await async_client.post("/scan", json={"url": "not-a-url"})
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_rejects_localhost(async_client):
|
||||||
|
response = await async_client.post("/scan", json={"url": "http://localhost:8000"})
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_rejects_private_ip(async_client):
|
||||||
|
response = await async_client.post("/scan", json={"url": "http://192.168.1.1"})
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_valid_url(async_client):
|
||||||
|
response = await async_client.post("/scan", json={"url": "https://example.com"})
|
||||||
|
assert response.status_code in (200, 502)
|
||||||
|
data = response.json()
|
||||||
|
assert "security_score" in data or "error" in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_missing_url(async_client):
|
||||||
|
response = await async_client.post("/scan", json={})
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_saves_when_authenticated(async_client, test_user, auth_headers):
|
||||||
|
response = await async_client.post(
|
||||||
|
"/scan",
|
||||||
|
json={"url": "https://example.com"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] is not None
|
||||||
|
assert data["created_at"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_no_save_when_anonymous(async_client):
|
||||||
|
response = await async_client.post("/scan", json={"url": "https://example.com"})
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] is None
|
||||||
63
tests/test_scoring.py
Normal file
63
tests/test_scoring.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from app.schemas.scan import Issue
|
||||||
|
from app.services.scoring import calculate_layer_statuses, calculate_score
|
||||||
|
|
||||||
|
|
||||||
|
def test_perfect_score_no_issues():
|
||||||
|
assert calculate_score([]) == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_critical_deduction():
|
||||||
|
issues = [Issue(issue="Test", severity="Critical", layer="Transport Layer", fix="Fix")]
|
||||||
|
assert calculate_score(issues) == 85
|
||||||
|
|
||||||
|
|
||||||
|
def test_warning_deduction():
|
||||||
|
issues = [Issue(issue="Test", severity="Warning", layer="Transport Layer", fix="Fix")]
|
||||||
|
assert calculate_score(issues) == 95
|
||||||
|
|
||||||
|
|
||||||
|
def test_info_deduction():
|
||||||
|
issues = [Issue(issue="Test", severity="Info", layer="Transport Layer", fix="Fix")]
|
||||||
|
assert calculate_score(issues) == 98
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_cannot_go_below_zero():
|
||||||
|
issues = [Issue(issue=f"Test {i}", severity="Critical", layer="Transport Layer", fix="Fix") for i in range(10)]
|
||||||
|
assert calculate_score(issues) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_layers_present():
|
||||||
|
statuses = calculate_layer_statuses([])
|
||||||
|
assert "Transport Layer" in statuses
|
||||||
|
assert "SSL/TLS Layer" in statuses
|
||||||
|
assert "Server Config Layer" in statuses
|
||||||
|
assert "Cookie Security" in statuses
|
||||||
|
assert "Exposure Layer" in statuses
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer_status_green_when_no_issues():
|
||||||
|
statuses = calculate_layer_statuses([])
|
||||||
|
for layer in statuses.values():
|
||||||
|
assert layer.status == "green"
|
||||||
|
assert layer.issues == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer_status_yellow_for_few_issues():
|
||||||
|
issues = [
|
||||||
|
Issue(issue="Test 1", severity="Warning", layer="SSL/TLS Layer", fix="Fix"),
|
||||||
|
Issue(issue="Test 2", severity="Warning", layer="SSL/TLS Layer", fix="Fix"),
|
||||||
|
]
|
||||||
|
statuses = calculate_layer_statuses(issues)
|
||||||
|
assert statuses["SSL/TLS Layer"].status == "yellow"
|
||||||
|
assert statuses["SSL/TLS Layer"].issues == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer_status_red_for_many_issues():
|
||||||
|
issues = [
|
||||||
|
Issue(issue="Test 1", severity="Warning", layer="Cookie Security", fix="Fix"),
|
||||||
|
Issue(issue="Test 2", severity="Warning", layer="Cookie Security", fix="Fix"),
|
||||||
|
Issue(issue="Test 3", severity="Critical", layer="Cookie Security", fix="Fix"),
|
||||||
|
]
|
||||||
|
statuses = calculate_layer_statuses(issues)
|
||||||
|
assert statuses["Cookie Security"].status == "red"
|
||||||
|
assert statuses["Cookie Security"].issues == 3
|
||||||
86
tests/test_ssl_checker.py
Normal file
86
tests/test_ssl_checker.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.scanner.ssl_checker import SSLScanner, _check_ssl
|
||||||
|
|
||||||
|
scanner = SSLScanner()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_http_urls():
|
||||||
|
response = MagicMock()
|
||||||
|
issues = await scanner.scan("http://example.com", response)
|
||||||
|
assert issues == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_self_signed():
|
||||||
|
response = MagicMock()
|
||||||
|
mock_result = {
|
||||||
|
"error": "self-signed certificate",
|
||||||
|
"cert": None,
|
||||||
|
"tls_version": "TLSv1.3",
|
||||||
|
"self_signed": True,
|
||||||
|
}
|
||||||
|
with patch("app.services.scanner.ssl_checker.asyncio.to_thread", return_value=mock_result):
|
||||||
|
issues = await scanner.scan("https://self-signed.example.com", response)
|
||||||
|
assert any("self-signed" in i.issue.lower() for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_weak_tls():
|
||||||
|
response = MagicMock()
|
||||||
|
future_date = (datetime.datetime.utcnow() + datetime.timedelta(days=365)).strftime("%b %d %H:%M:%S %Y GMT")
|
||||||
|
mock_result = {
|
||||||
|
"error": None,
|
||||||
|
"cert": {
|
||||||
|
"notAfter": future_date,
|
||||||
|
"subject": ((('commonName', 'example.com'),),),
|
||||||
|
"issuer": ((('commonName', 'CA'),),),
|
||||||
|
},
|
||||||
|
"tls_version": "TLSv1.1",
|
||||||
|
"self_signed": False,
|
||||||
|
}
|
||||||
|
with patch("app.services.scanner.ssl_checker.asyncio.to_thread", return_value=mock_result):
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("weak TLS" in i.issue.lower() or "tls" in i.issue.lower() for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_expiring_soon():
|
||||||
|
response = MagicMock()
|
||||||
|
soon_date = (datetime.datetime.utcnow() + datetime.timedelta(days=15)).strftime("%b %d %H:%M:%S %Y GMT")
|
||||||
|
mock_result = {
|
||||||
|
"error": None,
|
||||||
|
"cert": {
|
||||||
|
"notAfter": soon_date,
|
||||||
|
"subject": ((('commonName', 'example.com'),),),
|
||||||
|
"issuer": ((('commonName', 'CA'),),),
|
||||||
|
},
|
||||||
|
"tls_version": "TLSv1.3",
|
||||||
|
"self_signed": False,
|
||||||
|
}
|
||||||
|
with patch("app.services.scanner.ssl_checker.asyncio.to_thread", return_value=mock_result):
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("expires in" in i.issue.lower() for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_issues_for_valid_cert():
|
||||||
|
response = MagicMock()
|
||||||
|
future_date = (datetime.datetime.utcnow() + datetime.timedelta(days=365)).strftime("%b %d %H:%M:%S %Y GMT")
|
||||||
|
mock_result = {
|
||||||
|
"error": None,
|
||||||
|
"cert": {
|
||||||
|
"notAfter": future_date,
|
||||||
|
"subject": ((('commonName', 'example.com'),),),
|
||||||
|
"issuer": ((('commonName', 'Let\'s Encrypt'),),),
|
||||||
|
},
|
||||||
|
"tls_version": "TLSv1.3",
|
||||||
|
"self_signed": False,
|
||||||
|
}
|
||||||
|
with patch("app.services.scanner.ssl_checker.asyncio.to_thread", return_value=mock_result):
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert len(issues) == 0
|
||||||
75
tests/test_transport.py
Normal file
75
tests/test_transport.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.scanner.transport import TransportScanner
|
||||||
|
|
||||||
|
scanner = TransportScanner()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_response(headers: dict) -> MagicMock:
|
||||||
|
response = MagicMock()
|
||||||
|
response.headers = headers
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_no_https():
|
||||||
|
response = _make_response({})
|
||||||
|
issues = await scanner.scan("http://example.com", response)
|
||||||
|
assert any("HTTPS" in i.issue for i in issues)
|
||||||
|
assert len(issues) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_missing_hsts():
|
||||||
|
response = _make_response({})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("HSTS" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_short_hsts_max_age():
|
||||||
|
response = _make_response({
|
||||||
|
"Strict-Transport-Security": "max-age=3600; includeSubDomains; preload"
|
||||||
|
})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("max-age" in i.issue.lower() for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_missing_includesubdomains():
|
||||||
|
response = _make_response({
|
||||||
|
"Strict-Transport-Security": "max-age=31536000; preload"
|
||||||
|
})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("includeSubDomains" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_missing_preload():
|
||||||
|
response = _make_response({
|
||||||
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains"
|
||||||
|
})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("preload" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_detects_missing_upgrade_insecure_requests():
|
||||||
|
response = _make_response({
|
||||||
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||||
|
"Content-Security-Policy": "default-src 'self'",
|
||||||
|
})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert any("upgrade-insecure-requests" in i.issue for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_good_hsts_no_transport_issues():
|
||||||
|
response = _make_response({
|
||||||
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||||
|
"Content-Security-Policy": "default-src 'self'; upgrade-insecure-requests",
|
||||||
|
})
|
||||||
|
issues = await scanner.scan("https://example.com", response)
|
||||||
|
assert len(issues) == 0
|
||||||
50
tests/test_validators.py
Normal file
50
tests/test_validators.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.utils.validators import validate_url
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_https_url():
|
||||||
|
result = validate_url("https://example.com")
|
||||||
|
assert result == "https://example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_http_url():
|
||||||
|
result = validate_url("http://example.com")
|
||||||
|
assert result == "http://example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_ftp_scheme():
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
validate_url("ftp://example.com")
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_no_scheme():
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
validate_url("example.com")
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_localhost():
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
validate_url("http://localhost")
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_private_ip():
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
validate_url("http://192.168.1.1")
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_loopback():
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
validate_url("http://127.0.0.1")
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_unresolvable_host():
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
validate_url("http://this-domain-does-not-exist-xyz123.com")
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
Reference in New Issue
Block a user