add support for custom AI api base url and Agent Router integrations

This commit is contained in:
rarebuffalo
2026-06-15 01:30:24 +05:30
parent eb657ac30a
commit 6f83412d6f
7 changed files with 92 additions and 2 deletions

View File

@@ -21,6 +21,7 @@ async def call_ai(
temperature: float = 0.3,
json_mode: bool = False,
conversation_history: Optional[list] = None,
api_base: Optional[str] = None,
) -> str:
"""
Single entry-point for all AI calls in the CLI.
@@ -34,6 +35,7 @@ async def call_ai(
json_mode : Ask the model to respond with valid JSON only
conversation_history : Optional list of {"role": ..., "content": ...} dicts
for multi-turn chat sessions
api_base : Optional custom API base URL (e.g. for Agent Router)
"""
import litellm
@@ -49,6 +51,9 @@ async def call_ai(
"api_key": api_key if api_key else None,
}
if api_base:
kwargs["api_base"] = api_base
if json_mode:
kwargs["response_format"] = {"type": "json_object"}

View File

@@ -76,7 +76,8 @@ def configure():
"4": ("gpt-4o", "OpenAI GPT-4o"),
"5": ("claude-3-5-haiku-20241022","Anthropic Claude 3.5 Haiku"),
"6": ("ollama/llama3.1", "Ollama (local, no key needed)"),
"7": ("custom", "Custom model string"),
"7": ("agentrouter", "Agent Router (OpenAI-compatible gateway)"),
"8": ("custom", "Custom LiteLLM model / Endpoint"),
}
console.print("[bold]Choose AI Provider:[/bold]")
for k, (_, desc) in providers.items():
@@ -85,11 +86,21 @@ def configure():
choice = Prompt.ask("Select", choices=list(providers.keys()), default="1")
model_str, _ = providers[choice]
api_base = ""
if model_str == "custom":
if choice == "7":
model_str = Prompt.ask("Enter model name (must start with 'openai/')", default="openai/deepseek-chat")
if not model_str.startswith("openai/"):
model_str = "openai/" + model_str
api_base = Prompt.ask("Enter API base URL", default="https://agentrouter.org/v1").strip()
elif choice == "8":
model_str = Prompt.ask("Enter LiteLLM model string (e.g. openai/my-model-name)")
api_base = Prompt.ask("Enter custom API base URL (optional, e.g. https://my-endpoint/v1)", default="").strip()
elif model_str == "custom":
model_str = Prompt.ask("Enter LiteLLM model string (e.g. openrouter/google/gemini-flash)")
cfg.default_model = model_str
cfg.api_base = api_base
# API key (skip for Ollama)
if not model_str.startswith("ollama/"):
@@ -108,6 +119,8 @@ def configure():
save_config(cfg)
console.print(f"\n[bold green]✓ Config saved to {CONFIG_FILE}[/bold green]")
console.print(f" Model: [cyan]{cfg.default_model}[/cyan]")
if cfg.api_base:
console.print(f" Base URL: [cyan]{cfg.api_base}[/cyan]")
console.print(f" Output: [cyan]{cfg.output_format}[/cyan]\n")
@@ -306,6 +319,7 @@ async def _scan_async(path, model, output, max_files, ci, fail_on, no_ai, sync):
target_type="code",
api_key=cfg.api_key if not no_ai else None,
model=cfg.default_model,
api_base=cfg.api_base if not no_ai else None,
)
await run_repl(ctx)
@@ -401,6 +415,7 @@ async def _web_async(url, model, output, ci, fail_on, no_ai):
target_type="web",
api_key=cfg.api_key if not no_ai else None,
model=cfg.default_model,
api_base=cfg.api_base if not no_ai else None,
)
await run_repl(ctx)

View File

@@ -20,6 +20,7 @@ class CLIConfig:
# AI backend
default_model: str = "gemini/gemini-2.0-flash"
api_key: str = ""
api_base: str = ""
# Backend Integration (for sync / auth)
backend_url: str = "http://localhost:8000"
@@ -61,6 +62,7 @@ def load_config() -> CLIConfig:
data = yaml.safe_load(f) or {}
cfg.default_model = data.get("default_model", cfg.default_model)
cfg.api_key = data.get("api_key", cfg.api_key)
cfg.api_base = data.get("api_base", cfg.api_base)
cfg.backend_url = data.get("backend_url", cfg.backend_url)
cfg.token = data.get("token", cfg.token)
cfg.output_format = data.get("output_format", cfg.output_format)
@@ -82,6 +84,11 @@ def load_config() -> CLIConfig:
or os.environ.get("AI_MODEL")
or cfg.default_model
)
cfg.api_base = (
os.environ.get("SECURELENS_API_BASE")
or os.environ.get("AI_API_BASE")
or cfg.api_base
)
return cfg
@@ -92,6 +99,7 @@ def save_config(cfg: CLIConfig) -> None:
data = {
"default_model": cfg.default_model,
"api_key": cfg.api_key,
"api_base": cfg.api_base,
"backend_url": cfg.backend_url,
"token": cfg.token,
"output_format": cfg.output_format,

View File

@@ -57,6 +57,7 @@ class ReplContext:
target_type: str # "code" | "web" | "github"
api_key: str
model: str
api_base: Optional[str] = None
conversation_history: list = field(default_factory=list)
@@ -109,6 +110,7 @@ async def run_repl(ctx: ReplContext) -> None:
model=ctx.model,
temperature=0.5,
conversation_history=ctx.conversation_history,
api_base=ctx.api_base,
)
if response: