mirror of
https://github.com/Rarebuffalo/securelens-backend.git
synced 2026-06-19 07:00:30 +00:00
Add support for a custom AI API base URL and Agent Router integration
This commit is contained in:
@@ -55,6 +55,9 @@ class Settings(BaseSettings):
|
||||
# Leave blank for Ollama (local, no key needed).
|
||||
ai_api_key: str | None = None
|
||||
|
||||
# AI_API_BASE: Custom API base URL (e.g. for Agent Router or custom OpenAI-compatible proxies)
|
||||
ai_api_base: str | None = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Legacy Gemini key — kept for backward compatibility.
|
||||
# If AI_API_KEY is not set but GEMINI_API_KEY is, we use that automatically.
|
||||
|
||||
@@ -74,6 +74,9 @@ async def call_ai(
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
if settings.ai_api_base:
|
||||
kwargs["api_base"] = settings.ai_api_base
|
||||
|
||||
# JSON mode: supported natively by OpenAI and by LiteLLM-proxied Gemini.
|
||||
# For providers that don't support it, LiteLLM silently ignores the flag.
|
||||
if json_mode:
|
||||
|
||||
@@ -21,6 +21,7 @@ async def call_ai(
|
||||
temperature: float = 0.3,
|
||||
json_mode: bool = False,
|
||||
conversation_history: Optional[list] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Single entry-point for all AI calls in the CLI.
|
||||
@@ -34,6 +35,7 @@ async def call_ai(
|
||||
json_mode : Ask the model to respond with valid JSON only
|
||||
conversation_history : Optional list of {"role": ..., "content": ...} dicts
|
||||
for multi-turn chat sessions
|
||||
api_base : Optional custom API base URL (e.g. for Agent Router)
|
||||
"""
|
||||
import litellm
|
||||
|
||||
@@ -49,6 +51,9 @@ async def call_ai(
|
||||
"api_key": api_key if api_key else None,
|
||||
}
|
||||
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
|
||||
if json_mode:
|
||||
kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
|
||||
@@ -76,7 +76,8 @@ def configure():
|
||||
"4": ("gpt-4o", "OpenAI GPT-4o"),
|
||||
"5": ("claude-3-5-haiku-20241022","Anthropic Claude 3.5 Haiku"),
|
||||
"6": ("ollama/llama3.1", "Ollama (local, no key needed)"),
|
||||
"7": ("custom", "Custom model string"),
|
||||
"7": ("agentrouter", "Agent Router (OpenAI-compatible gateway)"),
|
||||
"8": ("custom", "Custom LiteLLM model / Endpoint"),
|
||||
}
|
||||
console.print("[bold]Choose AI Provider:[/bold]")
|
||||
for k, (_, desc) in providers.items():
|
||||
@@ -85,11 +86,21 @@ def configure():
|
||||
|
||||
choice = Prompt.ask("Select", choices=list(providers.keys()), default="1")
|
||||
model_str, _ = providers[choice]
|
||||
api_base = ""
|
||||
|
||||
if model_str == "custom":
|
||||
if choice == "7":
|
||||
model_str = Prompt.ask("Enter model name (must start with 'openai/')", default="openai/deepseek-chat")
|
||||
if not model_str.startswith("openai/"):
|
||||
model_str = "openai/" + model_str
|
||||
api_base = Prompt.ask("Enter API base URL", default="https://agentrouter.org/v1").strip()
|
||||
elif choice == "8":
|
||||
model_str = Prompt.ask("Enter LiteLLM model string (e.g. openai/my-model-name)")
|
||||
api_base = Prompt.ask("Enter custom API base URL (optional, e.g. https://my-endpoint/v1)", default="").strip()
|
||||
elif model_str == "custom":
|
||||
model_str = Prompt.ask("Enter LiteLLM model string (e.g. openrouter/google/gemini-flash)")
|
||||
|
||||
cfg.default_model = model_str
|
||||
cfg.api_base = api_base
|
||||
|
||||
# API key (skip for Ollama)
|
||||
if not model_str.startswith("ollama/"):
|
||||
@@ -108,6 +119,8 @@ def configure():
|
||||
save_config(cfg)
|
||||
console.print(f"\n[bold green]✓ Config saved to {CONFIG_FILE}[/bold green]")
|
||||
console.print(f" Model: [cyan]{cfg.default_model}[/cyan]")
|
||||
if cfg.api_base:
|
||||
console.print(f" Base URL: [cyan]{cfg.api_base}[/cyan]")
|
||||
console.print(f" Output: [cyan]{cfg.output_format}[/cyan]\n")
|
||||
|
||||
|
||||
@@ -306,6 +319,7 @@ async def _scan_async(path, model, output, max_files, ci, fail_on, no_ai, sync):
|
||||
target_type="code",
|
||||
api_key=cfg.api_key if not no_ai else None,
|
||||
model=cfg.default_model,
|
||||
api_base=cfg.api_base if not no_ai else None,
|
||||
)
|
||||
await run_repl(ctx)
|
||||
|
||||
@@ -401,6 +415,7 @@ async def _web_async(url, model, output, ci, fail_on, no_ai):
|
||||
target_type="web",
|
||||
api_key=cfg.api_key if not no_ai else None,
|
||||
model=cfg.default_model,
|
||||
api_base=cfg.api_base if not no_ai else None,
|
||||
)
|
||||
await run_repl(ctx)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ class CLIConfig:
|
||||
# AI backend
|
||||
default_model: str = "gemini/gemini-2.0-flash"
|
||||
api_key: str = ""
|
||||
api_base: str = ""
|
||||
|
||||
# Backend Integration (for sync / auth)
|
||||
backend_url: str = "http://localhost:8000"
|
||||
@@ -61,6 +62,7 @@ def load_config() -> CLIConfig:
|
||||
data = yaml.safe_load(f) or {}
|
||||
cfg.default_model = data.get("default_model", cfg.default_model)
|
||||
cfg.api_key = data.get("api_key", cfg.api_key)
|
||||
cfg.api_base = data.get("api_base", cfg.api_base)
|
||||
cfg.backend_url = data.get("backend_url", cfg.backend_url)
|
||||
cfg.token = data.get("token", cfg.token)
|
||||
cfg.output_format = data.get("output_format", cfg.output_format)
|
||||
@@ -82,6 +84,11 @@ def load_config() -> CLIConfig:
|
||||
or os.environ.get("AI_MODEL")
|
||||
or cfg.default_model
|
||||
)
|
||||
cfg.api_base = (
|
||||
os.environ.get("SECURELENS_API_BASE")
|
||||
or os.environ.get("AI_API_BASE")
|
||||
or cfg.api_base
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
@@ -92,6 +99,7 @@ def save_config(cfg: CLIConfig) -> None:
|
||||
data = {
|
||||
"default_model": cfg.default_model,
|
||||
"api_key": cfg.api_key,
|
||||
"api_base": cfg.api_base,
|
||||
"backend_url": cfg.backend_url,
|
||||
"token": cfg.token,
|
||||
"output_format": cfg.output_format,
|
||||
|
||||
@@ -57,6 +57,7 @@ class ReplContext:
|
||||
target_type: str # "code" | "web" | "github"
|
||||
api_key: str
|
||||
model: str
|
||||
api_base: Optional[str] = None
|
||||
conversation_history: list = field(default_factory=list)
|
||||
|
||||
|
||||
@@ -109,6 +110,7 @@ async def run_repl(ctx: ReplContext) -> None:
|
||||
model=ctx.model,
|
||||
temperature=0.5,
|
||||
conversation_history=ctx.conversation_history,
|
||||
api_base=ctx.api_base,
|
||||
)
|
||||
|
||||
if response:
|
||||
|
||||
54
tests/test_cli_api_base.py
Normal file
54
tests/test_cli_api_base.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
from securelens.ai import call_ai
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_db():
    """Autouse fixture that performs no setup and yields nothing.

    NOTE(review): presumably shadows a project-wide ``setup_db`` fixture so
    the tests in this module run without a database -- confirm against the
    project's conftest.
    """
|
||||
|
||||
@pytest.mark.asyncio
async def test_call_ai_passes_api_base():
    """The CLI ``call_ai`` forwards ``api_base``, ``model`` and ``api_key`` to LiteLLM."""
    # Keyword arguments that must reach ``litellm.acompletion`` unchanged.
    expected = {
        "api_base": "https://agentrouter.org/v1",
        "model": "openai/deepseek-chat",
        "api_key": "mock_key",
    }
    fake_choice = AsyncMock(message=AsyncMock(content="Mock response"))

    with patch("litellm.acompletion", new_callable=AsyncMock) as completion_mock:
        completion_mock.return_value.choices = [fake_choice]

        await call_ai(prompt="Hello", **expected)

        completion_mock.assert_called_once()
        forwarded = completion_mock.call_args.kwargs
        for name, value in expected.items():
            assert forwarded[name] == value
|
||||
|
||||
@pytest.mark.asyncio
async def test_backend_call_ai_passes_api_base():
    """The backend ``call_ai`` forwards ``settings.ai_api_base`` and the API key to LiteLLM."""
    from app.services.ai import call_ai as backend_call_ai
    from app.config import settings

    # Settings temporarily overridden for the duration of the call.
    overrides = {
        "ai_api_key": "mock_key",
        "ai_api_base": "https://agentrouter.org/v1",
    }

    with patch("litellm.acompletion", new_callable=AsyncMock) as completion_mock:
        completion_mock.return_value.choices = [
            AsyncMock(message=AsyncMock(content="Mock response"))
        ]

        # Snapshot the current values so the shared settings object is
        # restored even if the call or an assertion fails.
        saved = {name: getattr(settings, name) for name in overrides}

        try:
            for name, value in overrides.items():
                setattr(settings, name, value)

            await backend_call_ai(prompt="Hello")

            completion_mock.assert_called_once()
            forwarded = completion_mock.call_args.kwargs
            assert forwarded["api_base"] == "https://agentrouter.org/v1"
            assert forwarded["api_key"] == "mock_key"
        finally:
            for name, value in saved.items():
                setattr(settings, name, value)
|
||||
Reference in New Issue
Block a user