Add an optional custom API base URL ("api_base") for OpenAI-compatible gateways.

  * app/config.py: new optional Settings.ai_api_base field.
  * app/services/ai.py and cli/securelens/ai/__init__.py: call_ai adds "api_base"
    to the request kwargs only when a base URL is set.
  * cli/securelens/cli.py: configure() gains provider option 7 (Agent Router) and
    moves the custom model option to 8; both can record a base URL, which is
    stored on the config and passed along with the API key and model when the
    scan and web commands start the REPL.
  * cli/securelens/config.py: api_base is loaded from and saved to the config
    file, and SECURELENS_API_BASE or AI_API_BASE override the stored value.
  * cli/securelens/repl.py: ReplContext carries api_base and run_repl passes it on.
  * tests/test_cli_api_base.py: checks that both call_ai entry points hand
    api_base to litellm.acompletion.

Review notes:

  * configure(): the old `model_str == "custom"` branch is deleted instead of
    being kept as a trailing elif. "custom" can only be selected through
    option 8, which is handled first, so that elif could never run.
  * NOTE(review): the line structure of this patch was rebuilt from a
    whitespace-collapsed copy. Indentation is inferred from the Python syntax
    and column alignment inside lines is not recoverable, so applying it may
    need `git apply --ignore-whitespace`. The post-image blob id on the
    cli/securelens/cli.py index line predates the configure() change above.

diff --git a/app/config.py b/app/config.py
index 14c0383..5658716 100644
--- a/app/config.py
+++ b/app/config.py
@@ -55,6 +55,9 @@ class Settings(BaseSettings):
     # Leave blank for Ollama (local, no key needed).
     ai_api_key: str | None = None
 
+    # AI_API_BASE: Custom API base URL (e.g. for Agent Router or custom OpenAI-compatible proxies)
+    ai_api_base: str | None = None
+
     # -------------------------------------------------------------------------
     # Legacy Gemini key — kept for backward compatibility.
     # If AI_API_KEY is not set but GEMINI_API_KEY is, we use that automatically.
diff --git a/app/services/ai.py b/app/services/ai.py
index 4609a7d..3167677 100644
--- a/app/services/ai.py
+++ b/app/services/ai.py
@@ -74,6 +74,9 @@ async def call_ai(
         "api_key": api_key,
     }
 
+    if settings.ai_api_base:
+        kwargs["api_base"] = settings.ai_api_base
+
     # JSON mode: supported natively by OpenAI and LiteLLM proxied Gemini.
     # For providers that don't support it, LiteLLM silently ignores the flag.
     if json_mode:
diff --git a/cli/securelens/ai/__init__.py b/cli/securelens/ai/__init__.py
index e148511..498f6a1 100644
--- a/cli/securelens/ai/__init__.py
+++ b/cli/securelens/ai/__init__.py
@@ -21,6 +21,7 @@ async def call_ai(
     temperature: float = 0.3,
     json_mode: bool = False,
     conversation_history: Optional[list] = None,
+    api_base: Optional[str] = None,
 ) -> str:
     """
     Single entry-point for all AI calls in the CLI.
@@ -34,6 +35,7 @@ async def call_ai(
     json_mode : Ask the model to respond with valid JSON only
     conversation_history : Optional list of {"role": ..., "content": ...} dicts
         for multi-turn chat sessions
+    api_base : Optional custom API base URL (e.g. for Agent Router)
     """
     import litellm
 
@@ -49,6 +51,9 @@ async def call_ai(
         "api_key": api_key if api_key else None,
     }
 
+    if api_base:
+        kwargs["api_base"] = api_base
+
     if json_mode:
         kwargs["response_format"] = {"type": "json_object"}
 
diff --git a/cli/securelens/cli.py b/cli/securelens/cli.py
index b4a98bf..90d02b3 100644
--- a/cli/securelens/cli.py
+++ b/cli/securelens/cli.py
@@ -76,7 +76,8 @@ def configure():
         "4": ("gpt-4o", "OpenAI GPT-4o"),
         "5": ("claude-3-5-haiku-20241022","Anthropic Claude 3.5 Haiku"),
         "6": ("ollama/llama3.1", "Ollama (local, no key needed)"),
-        "7": ("custom", "Custom model string"),
+        "7": ("agentrouter", "Agent Router (OpenAI-compatible gateway)"),
+        "8": ("custom", "Custom LiteLLM model / Endpoint"),
     }
     console.print("[bold]Choose AI Provider:[/bold]")
     for k, (_, desc) in providers.items():
@@ -85,11 +86,19 @@ def configure():
 
     choice = Prompt.ask("Select", choices=list(providers.keys()), default="1")
     model_str, _ = providers[choice]
+    api_base = ""
 
-    if model_str == "custom":
-        model_str = Prompt.ask("Enter LiteLLM model string (e.g. openrouter/google/gemini-flash)")
+    if choice == "7":
+        model_str = Prompt.ask("Enter model name (must start with 'openai/')", default="openai/deepseek-chat")
+        if not model_str.startswith("openai/"):
+            model_str = "openai/" + model_str
+        api_base = Prompt.ask("Enter API base URL", default="https://agentrouter.org/v1").strip()
+    elif choice == "8":
+        model_str = Prompt.ask("Enter LiteLLM model string (e.g. openai/my-model-name)")
+        api_base = Prompt.ask("Enter custom API base URL (optional, e.g. https://my-endpoint/v1)", default="").strip()
 
     cfg.default_model = model_str
+    cfg.api_base = api_base
 
     # API key (skip for Ollama)
     if not model_str.startswith("ollama/"):
@@ -108,6 +117,8 @@ def configure():
     save_config(cfg)
     console.print(f"\n[bold green]✓ Config saved to {CONFIG_FILE}[/bold green]")
     console.print(f" Model: [cyan]{cfg.default_model}[/cyan]")
+    if cfg.api_base:
+        console.print(f" Base URL: [cyan]{cfg.api_base}[/cyan]")
     console.print(f" Output: [cyan]{cfg.output_format}[/cyan]\n")
 
 
@@ -306,6 +317,7 @@ async def _scan_async(path, model, output, max_files, ci, fail_on, no_ai, sync):
             target_type="code",
             api_key=cfg.api_key if not no_ai else None,
             model=cfg.default_model,
+            api_base=cfg.api_base if not no_ai else None,
         )
         await run_repl(ctx)
 
@@ -401,6 +413,7 @@ async def _web_async(url, model, output, ci, fail_on, no_ai):
             target_type="web",
             api_key=cfg.api_key if not no_ai else None,
             model=cfg.default_model,
+            api_base=cfg.api_base if not no_ai else None,
         )
         await run_repl(ctx)
 
diff --git a/cli/securelens/config.py b/cli/securelens/config.py
index 075495e..bc99311 100644
--- a/cli/securelens/config.py
+++ b/cli/securelens/config.py
@@ -20,6 +20,7 @@ class CLIConfig:
     # AI backend
     default_model: str = "gemini/gemini-2.0-flash"
     api_key: str = ""
+    api_base: str = ""
 
     # Backend Integration (for sync / auth)
     backend_url: str = "http://localhost:8000"
@@ -61,6 +62,7 @@ def load_config() -> CLIConfig:
             data = yaml.safe_load(f) or {}
         cfg.default_model = data.get("default_model", cfg.default_model)
         cfg.api_key = data.get("api_key", cfg.api_key)
+        cfg.api_base = data.get("api_base", cfg.api_base)
         cfg.backend_url = data.get("backend_url", cfg.backend_url)
         cfg.token = data.get("token", cfg.token)
         cfg.output_format = data.get("output_format", cfg.output_format)
@@ -82,6 +84,11 @@ def load_config() -> CLIConfig:
         or os.environ.get("AI_MODEL")
         or cfg.default_model
     )
+    cfg.api_base = (
+        os.environ.get("SECURELENS_API_BASE")
+        or os.environ.get("AI_API_BASE")
+        or cfg.api_base
+    )
 
     return cfg
 
@@ -92,6 +99,7 @@ def save_config(cfg: CLIConfig) -> None:
     data = {
         "default_model": cfg.default_model,
         "api_key": cfg.api_key,
+        "api_base": cfg.api_base,
         "backend_url": cfg.backend_url,
         "token": cfg.token,
         "output_format": cfg.output_format,
diff --git a/cli/securelens/repl.py b/cli/securelens/repl.py
index 1dcfdec..aa126c0 100644
--- a/cli/securelens/repl.py
+++ b/cli/securelens/repl.py
@@ -57,6 +57,7 @@ class ReplContext:
     target_type: str # "code" | "web" | "github"
     api_key: str
     model: str
+    api_base: Optional[str] = None
     conversation_history: list = field(default_factory=list)
 
 
@@ -109,6 +110,7 @@ async def run_repl(ctx: ReplContext) -> None:
                 model=ctx.model,
                 temperature=0.5,
                 conversation_history=ctx.conversation_history,
+                api_base=ctx.api_base,
             )
 
             if response:
diff --git a/tests/test_cli_api_base.py b/tests/test_cli_api_base.py
new file mode 100644
index 0000000..fe82f94
--- /dev/null
+++ b/tests/test_cli_api_base.py
@@ -0,0 +1,54 @@
+from unittest.mock import AsyncMock, patch
+import pytest
+from securelens.ai import call_ai
+
+@pytest.fixture(autouse=True)
+def setup_db():
+    pass
+
+@pytest.mark.asyncio
+async def test_call_ai_passes_api_base():
+    with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+        mock_acompletion.return_value.choices = [
+            AsyncMock(message=AsyncMock(content="Mock response"))
+        ]
+
+        await call_ai(
+            prompt="Hello",
+            api_key="mock_key",
+            model="openai/deepseek-chat",
+            api_base="https://agentrouter.org/v1"
+        )
+
+        mock_acompletion.assert_called_once()
+        called_kwargs = mock_acompletion.call_args[1]
+        assert called_kwargs["api_base"] == "https://agentrouter.org/v1"
+        assert called_kwargs["model"] == "openai/deepseek-chat"
+        assert called_kwargs["api_key"] == "mock_key"
+
+@pytest.mark.asyncio
+async def test_backend_call_ai_passes_api_base():
+    from app.services.ai import call_ai as backend_call_ai
+    from app.config import settings
+
+    with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
+        mock_acompletion.return_value.choices = [
+            AsyncMock(message=AsyncMock(content="Mock response"))
+        ]
+
+        original_key = settings.ai_api_key
+        original_base = settings.ai_api_base
+
+        try:
+            settings.ai_api_key = "mock_key"
+            settings.ai_api_base = "https://agentrouter.org/v1"
+
+            await backend_call_ai(prompt="Hello")
+
+            mock_acompletion.assert_called_once()
+            called_kwargs = mock_acompletion.call_args[1]
+            assert called_kwargs["api_base"] == "https://agentrouter.org/v1"
+            assert called_kwargs["api_key"] == "mock_key"
+        finally:
+            settings.ai_api_key = original_key
+            settings.ai_api_base = original_base