mirror of
https://github.com/Rarebuffalo/securelens-backend.git
synced 2026-06-19 07:00:30 +00:00
add support for custom AI api base url and Agent Router integrations
This commit is contained in:
@@ -21,6 +21,7 @@ async def call_ai(
|
||||
temperature: float = 0.3,
|
||||
json_mode: bool = False,
|
||||
conversation_history: Optional[list] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Single entry-point for all AI calls in the CLI.
|
||||
@@ -34,6 +35,7 @@ async def call_ai(
|
||||
json_mode : Ask the model to respond with valid JSON only
|
||||
conversation_history : Optional list of {"role": ..., "content": ...} dicts
|
||||
for multi-turn chat sessions
|
||||
api_base : Optional custom API base URL (e.g. for Agent Router)
|
||||
"""
|
||||
import litellm
|
||||
|
||||
@@ -49,6 +51,9 @@ async def call_ai(
|
||||
"api_key": api_key if api_key else None,
|
||||
}
|
||||
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
|
||||
if json_mode:
|
||||
kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
|
||||
@@ -76,7 +76,8 @@ def configure():
|
||||
"4": ("gpt-4o", "OpenAI GPT-4o"),
|
||||
"5": ("claude-3-5-haiku-20241022","Anthropic Claude 3.5 Haiku"),
|
||||
"6": ("ollama/llama3.1", "Ollama (local, no key needed)"),
|
||||
"7": ("custom", "Custom model string"),
|
||||
"7": ("agentrouter", "Agent Router (OpenAI-compatible gateway)"),
|
||||
"8": ("custom", "Custom LiteLLM model / Endpoint"),
|
||||
}
|
||||
console.print("[bold]Choose AI Provider:[/bold]")
|
||||
for k, (_, desc) in providers.items():
|
||||
@@ -85,11 +86,21 @@ def configure():
|
||||
|
||||
choice = Prompt.ask("Select", choices=list(providers.keys()), default="1")
|
||||
model_str, _ = providers[choice]
|
||||
api_base = ""
|
||||
|
||||
if model_str == "custom":
|
||||
if choice == "7":
|
||||
model_str = Prompt.ask("Enter model name (must start with 'openai/')", default="openai/deepseek-chat")
|
||||
if not model_str.startswith("openai/"):
|
||||
model_str = "openai/" + model_str
|
||||
api_base = Prompt.ask("Enter API base URL", default="https://agentrouter.org/v1").strip()
|
||||
elif choice == "8":
|
||||
model_str = Prompt.ask("Enter LiteLLM model string (e.g. openai/my-model-name)")
|
||||
api_base = Prompt.ask("Enter custom API base URL (optional, e.g. https://my-endpoint/v1)", default="").strip()
|
||||
elif model_str == "custom":
|
||||
model_str = Prompt.ask("Enter LiteLLM model string (e.g. openrouter/google/gemini-flash)")
|
||||
|
||||
cfg.default_model = model_str
|
||||
cfg.api_base = api_base
|
||||
|
||||
# API key (skip for Ollama)
|
||||
if not model_str.startswith("ollama/"):
|
||||
@@ -108,6 +119,8 @@ def configure():
|
||||
save_config(cfg)
|
||||
console.print(f"\n[bold green]✓ Config saved to {CONFIG_FILE}[/bold green]")
|
||||
console.print(f" Model: [cyan]{cfg.default_model}[/cyan]")
|
||||
if cfg.api_base:
|
||||
console.print(f" Base URL: [cyan]{cfg.api_base}[/cyan]")
|
||||
console.print(f" Output: [cyan]{cfg.output_format}[/cyan]\n")
|
||||
|
||||
|
||||
@@ -306,6 +319,7 @@ async def _scan_async(path, model, output, max_files, ci, fail_on, no_ai, sync):
|
||||
target_type="code",
|
||||
api_key=cfg.api_key if not no_ai else None,
|
||||
model=cfg.default_model,
|
||||
api_base=cfg.api_base if not no_ai else None,
|
||||
)
|
||||
await run_repl(ctx)
|
||||
|
||||
@@ -401,6 +415,7 @@ async def _web_async(url, model, output, ci, fail_on, no_ai):
|
||||
target_type="web",
|
||||
api_key=cfg.api_key if not no_ai else None,
|
||||
model=cfg.default_model,
|
||||
api_base=cfg.api_base if not no_ai else None,
|
||||
)
|
||||
await run_repl(ctx)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ class CLIConfig:
|
||||
# AI backend
|
||||
default_model: str = "gemini/gemini-2.0-flash"
|
||||
api_key: str = ""
|
||||
api_base: str = ""
|
||||
|
||||
# Backend Integration (for sync / auth)
|
||||
backend_url: str = "http://localhost:8000"
|
||||
@@ -61,6 +62,7 @@ def load_config() -> CLIConfig:
|
||||
data = yaml.safe_load(f) or {}
|
||||
cfg.default_model = data.get("default_model", cfg.default_model)
|
||||
cfg.api_key = data.get("api_key", cfg.api_key)
|
||||
cfg.api_base = data.get("api_base", cfg.api_base)
|
||||
cfg.backend_url = data.get("backend_url", cfg.backend_url)
|
||||
cfg.token = data.get("token", cfg.token)
|
||||
cfg.output_format = data.get("output_format", cfg.output_format)
|
||||
@@ -82,6 +84,11 @@ def load_config() -> CLIConfig:
|
||||
or os.environ.get("AI_MODEL")
|
||||
or cfg.default_model
|
||||
)
|
||||
cfg.api_base = (
|
||||
os.environ.get("SECURELENS_API_BASE")
|
||||
or os.environ.get("AI_API_BASE")
|
||||
or cfg.api_base
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
@@ -92,6 +99,7 @@ def save_config(cfg: CLIConfig) -> None:
|
||||
data = {
|
||||
"default_model": cfg.default_model,
|
||||
"api_key": cfg.api_key,
|
||||
"api_base": cfg.api_base,
|
||||
"backend_url": cfg.backend_url,
|
||||
"token": cfg.token,
|
||||
"output_format": cfg.output_format,
|
||||
|
||||
@@ -57,6 +57,7 @@ class ReplContext:
|
||||
target_type: str # "code" | "web" | "github"
|
||||
api_key: str
|
||||
model: str
|
||||
api_base: Optional[str] = None
|
||||
conversation_history: list = field(default_factory=list)
|
||||
|
||||
|
||||
@@ -109,6 +110,7 @@ async def run_repl(ctx: ReplContext) -> None:
|
||||
model=ctx.model,
|
||||
temperature=0.5,
|
||||
conversation_history=ctx.conversation_history,
|
||||
api_base=ctx.api_base,
|
||||
)
|
||||
|
||||
if response:
|
||||
|
||||
Reference in New Issue
Block a user