add support for custom AI api base url and Agent Router integrations

This commit is contained in:
rarebuffalo
2026-06-15 01:30:24 +05:30
parent eb657ac30a
commit 6f83412d6f
7 changed files with 92 additions and 2 deletions

View File

@@ -55,6 +55,9 @@ class Settings(BaseSettings):
# Leave blank for Ollama (local, no key needed).
ai_api_key: str | None = None
# AI_API_BASE: Custom API base URL (e.g. for Agent Router or custom OpenAI-compatible proxies)
ai_api_base: str | None = None
# -------------------------------------------------------------------------
# Legacy Gemini key — kept for backward compatibility.
# If AI_API_KEY is not set but GEMINI_API_KEY is, we use that automatically.

View File

@@ -74,6 +74,9 @@ async def call_ai(
"api_key": api_key,
}
if settings.ai_api_base:
kwargs["api_base"] = settings.ai_api_base
# JSON mode: supported natively by OpenAI and LiteLLM proxied Gemini.
# For providers that don't support it, LiteLLM silently ignores the flag.
if json_mode:

View File

@@ -21,6 +21,7 @@ async def call_ai(
temperature: float = 0.3,
json_mode: bool = False,
conversation_history: Optional[list] = None,
api_base: Optional[str] = None,
) -> str:
"""
Single entry-point for all AI calls in the CLI.
@@ -34,6 +35,7 @@ async def call_ai(
json_mode : Ask the model to respond with valid JSON only
conversation_history : Optional list of {"role": ..., "content": ...} dicts
for multi-turn chat sessions
api_base : Optional custom API base URL (e.g. for Agent Router)
"""
import litellm
@@ -49,6 +51,9 @@ async def call_ai(
"api_key": api_key if api_key else None,
}
if api_base:
kwargs["api_base"] = api_base
if json_mode:
kwargs["response_format"] = {"type": "json_object"}

View File

@@ -76,7 +76,8 @@ def configure():
"4": ("gpt-4o", "OpenAI GPT-4o"),
"5": ("claude-3-5-haiku-20241022","Anthropic Claude 3.5 Haiku"),
"6": ("ollama/llama3.1", "Ollama (local, no key needed)"),
"7": ("custom", "Custom model string"),
"7": ("agentrouter", "Agent Router (OpenAI-compatible gateway)"),
"8": ("custom", "Custom LiteLLM model / Endpoint"),
}
console.print("[bold]Choose AI Provider:[/bold]")
for k, (_, desc) in providers.items():
@@ -85,11 +86,21 @@ def configure():
choice = Prompt.ask("Select", choices=list(providers.keys()), default="1")
model_str, _ = providers[choice]
api_base = ""
if model_str == "custom":
if choice == "7":
model_str = Prompt.ask("Enter model name (must start with 'openai/')", default="openai/deepseek-chat")
if not model_str.startswith("openai/"):
model_str = "openai/" + model_str
api_base = Prompt.ask("Enter API base URL", default="https://agentrouter.org/v1").strip()
elif choice == "8":
model_str = Prompt.ask("Enter LiteLLM model string (e.g. openai/my-model-name)")
api_base = Prompt.ask("Enter custom API base URL (optional, e.g. https://my-endpoint/v1)", default="").strip()
elif model_str == "custom":
model_str = Prompt.ask("Enter LiteLLM model string (e.g. openrouter/google/gemini-flash)")
cfg.default_model = model_str
cfg.api_base = api_base
# API key (skip for Ollama)
if not model_str.startswith("ollama/"):
@@ -108,6 +119,8 @@ def configure():
save_config(cfg)
console.print(f"\n[bold green]✓ Config saved to {CONFIG_FILE}[/bold green]")
console.print(f" Model: [cyan]{cfg.default_model}[/cyan]")
if cfg.api_base:
console.print(f" Base URL: [cyan]{cfg.api_base}[/cyan]")
console.print(f" Output: [cyan]{cfg.output_format}[/cyan]\n")
@@ -306,6 +319,7 @@ async def _scan_async(path, model, output, max_files, ci, fail_on, no_ai, sync):
target_type="code",
api_key=cfg.api_key if not no_ai else None,
model=cfg.default_model,
api_base=cfg.api_base if not no_ai else None,
)
await run_repl(ctx)
@@ -401,6 +415,7 @@ async def _web_async(url, model, output, ci, fail_on, no_ai):
target_type="web",
api_key=cfg.api_key if not no_ai else None,
model=cfg.default_model,
api_base=cfg.api_base if not no_ai else None,
)
await run_repl(ctx)

View File

@@ -20,6 +20,7 @@ class CLIConfig:
# AI backend
default_model: str = "gemini/gemini-2.0-flash"
api_key: str = ""
api_base: str = ""
# Backend Integration (for sync / auth)
backend_url: str = "http://localhost:8000"
@@ -61,6 +62,7 @@ def load_config() -> CLIConfig:
data = yaml.safe_load(f) or {}
cfg.default_model = data.get("default_model", cfg.default_model)
cfg.api_key = data.get("api_key", cfg.api_key)
cfg.api_base = data.get("api_base", cfg.api_base)
cfg.backend_url = data.get("backend_url", cfg.backend_url)
cfg.token = data.get("token", cfg.token)
cfg.output_format = data.get("output_format", cfg.output_format)
@@ -82,6 +84,11 @@ def load_config() -> CLIConfig:
or os.environ.get("AI_MODEL")
or cfg.default_model
)
cfg.api_base = (
os.environ.get("SECURELENS_API_BASE")
or os.environ.get("AI_API_BASE")
or cfg.api_base
)
return cfg
@@ -92,6 +99,7 @@ def save_config(cfg: CLIConfig) -> None:
data = {
"default_model": cfg.default_model,
"api_key": cfg.api_key,
"api_base": cfg.api_base,
"backend_url": cfg.backend_url,
"token": cfg.token,
"output_format": cfg.output_format,

View File

@@ -57,6 +57,7 @@ class ReplContext:
target_type: str # "code" | "web" | "github"
api_key: str
model: str
api_base: Optional[str] = None
conversation_history: list = field(default_factory=list)
@@ -109,6 +110,7 @@ async def run_repl(ctx: ReplContext) -> None:
model=ctx.model,
temperature=0.5,
conversation_history=ctx.conversation_history,
api_base=ctx.api_base,
)
if response:

View File

@@ -0,0 +1,54 @@
from unittest.mock import AsyncMock, patch
import pytest
from securelens.ai import call_ai
@pytest.fixture(autouse=True)
def setup_db():
pass
@pytest.mark.asyncio
async def test_call_ai_passes_api_base():
with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
mock_acompletion.return_value.choices = [
AsyncMock(message=AsyncMock(content="Mock response"))
]
await call_ai(
prompt="Hello",
api_key="mock_key",
model="openai/deepseek-chat",
api_base="https://agentrouter.org/v1"
)
mock_acompletion.assert_called_once()
called_kwargs = mock_acompletion.call_args[1]
assert called_kwargs["api_base"] == "https://agentrouter.org/v1"
assert called_kwargs["model"] == "openai/deepseek-chat"
assert called_kwargs["api_key"] == "mock_key"
@pytest.mark.asyncio
async def test_backend_call_ai_passes_api_base():
from app.services.ai import call_ai as backend_call_ai
from app.config import settings
with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion:
mock_acompletion.return_value.choices = [
AsyncMock(message=AsyncMock(content="Mock response"))
]
original_key = settings.ai_api_key
original_base = settings.ai_api_base
try:
settings.ai_api_key = "mock_key"
settings.ai_api_base = "https://agentrouter.org/v1"
await backend_call_ai(prompt="Hello")
mock_acompletion.assert_called_once()
called_kwargs = mock_acompletion.call_args[1]
assert called_kwargs["api_base"] == "https://agentrouter.org/v1"
assert called_kwargs["api_key"] == "mock_key"
finally:
settings.ai_api_key = original_key
settings.ai_api_base = original_base