// iqAI/backend/routes/chat.js — chat routes that proxy prompts to Replicate model predictions.
import express from 'express';
import { loadModels } from './models.js';
// Router mounted at /api/chat by the app; handlers below call the Replicate REST API.
const router = express.Router();
// Replicate REST API root; model-scoped prediction URLs are built from this.
const REPLICATE_BASE = 'https://api.replicate.com/v1';
/**
 * Build the input payload for a Replicate prediction.
 *
 * Precedence, lowest to highest: model.defaultInput, injected system
 * prompt, user-supplied extraParams, then prompt (always wins).
 *
 * @param {{ defaultInput?: object, systemPromptParam?: string }} model - model descriptor from loadModels()
 * @param {string} prompt - final user prompt (search context already prepended by the caller)
 * @param {string} [systemPrompt] - injected only when the model declares a systemPromptParam
 * @param {object} [extraParams] - untrusted per-request overrides from the client body
 * @returns {object} input object to POST to the Replicate predictions endpoint
 */
function buildInput(model, prompt, systemPrompt, extraParams) {
  const base = { ...model.defaultInput };
  // Inject system prompt if model supports it
  if (model.systemPromptParam && systemPrompt) {
    base[model.systemPromptParam] = systemPrompt;
  }
  // Apply extra params from user. Copy key-by-key and skip '__proto__':
  // Object.assign (and bracket assignment) use [[Set]], so a hostile
  // '__proto__' key in the JSON body would silently replace this object's
  // prototype instead of adding a parameter.
  if (extraParams) {
    for (const [key, value] of Object.entries(extraParams)) {
      if (key !== '__proto__') {
        base[key] = value;
      }
    }
  }
  base.prompt = prompt;
  return base;
}
/**
 * POST a prediction to Replicate's model-scoped endpoint and return the
 * parsed JSON response.
 *
 * @param {{ owner: string, name: string }} model - Replicate owner/name pair
 * @param {object} input - payload built by buildInput()
 * @param {string} token - Replicate API token
 * @returns {Promise<object>} the prediction object from Replicate
 * @throws {Error} with Replicate's `detail` message, or `HTTP <status>` when no detail is available
 */
async function runPrediction(model, input, token) {
  const endpoint = `${REPLICATE_BASE}/models/${model.owner}/${model.name}/predictions`;
  const headers = {
    'Authorization': `Bearer ${token}`,
    'Content-Type': 'application/json',
    // Hold the connection open until the prediction settles (sync mode).
    'Prefer': 'wait',
  };
  const resp = await fetch(endpoint, {
    method: 'POST',
    headers,
    body: JSON.stringify({ input }),
  });
  if (resp.ok) {
    return resp.json();
  }
  // Surface Replicate's error detail when the body parses; otherwise fall back.
  const err = await resp.json().catch(() => ({ detail: 'Unknown error' }));
  throw new Error(err.detail || `HTTP ${resp.status}`);
}
/**
 * Flatten a Replicate prediction's `output` into display text.
 * String passes through, token arrays are concatenated, objects are
 * pretty-printed as JSON, anything else is stringified; falsy output
 * yields the empty string.
 *
 * @param {{ output?: * }} prediction
 * @returns {string}
 */
function extractOutput(prediction) {
  const output = prediction.output;
  if (!output) {
    return '';
  }
  if (Array.isArray(output)) {
    return output.join('');
  }
  if (typeof output === 'string') {
    return output;
  }
  return typeof output === 'object' ? JSON.stringify(output, null, 2) : String(output);
}
// POST /api/chat - single model call
router.post('/', async (req, res) => {
  const token = process.env.REPLICATE_API_TOKEN;
  if (!token) {
    return res.status(400).json({ error: 'REPLICATE_API_TOKEN not configured' });
  }

  const { modelId, prompt, systemPrompt, searchContext, extraParams } = req.body;
  if (!modelId || !prompt) {
    return res.status(400).json({ error: 'modelId and prompt are required' });
  }

  try {
    const models = await loadModels();
    const model = models.find((m) => m.id === modelId);
    if (!model) {
      return res.status(404).json({ error: `Model not found: ${modelId}` });
    }

    // Prepend web-search context (when present) to the user prompt.
    const finalPrompt = searchContext ? [searchContext, prompt].join('\n\n') : prompt;
    const prediction = await runPrediction(
      model,
      buildInput(model, finalPrompt, systemPrompt, extraParams),
      token
    );

    res.json({
      id: prediction.id,
      modelId: model.id,
      modelTag: model.tag,
      modelName: model.displayName,
      content: extractOutput(prediction),
      status: prediction.status,
      metrics: prediction.metrics,
      urls: prediction.urls,
    });
  } catch (err) {
    // Replicate/network failures surface as a generic 500 with the message.
    res.status(500).json({ error: err.message });
  }
});
// POST /api/chat/multi - send to multiple models in parallel
router.post('/multi', async (req, res) => {
const token = process.env.REPLICATE_API_TOKEN;
if (!token) return res.status(400).json({ error: 'REPLICATE_API_TOKEN not configured' });
const { modelIds, prompt, systemPrompt, searchContext, extraParams } = req.body;
if (!modelIds?.length || !prompt) return res.status(400).json({ error: 'modelIds and prompt are required' });
try {
const models = await loadModels();
const tasks = modelIds.map(async (modelId) => {
const model = models.find(m => m.id === modelId);
if (!model) return { modelId, error: 'Model not found' };
try {
const finalPrompt = searchContext ? `${searchContext}\n\n${prompt}` : prompt;
const input = buildInput(model, finalPrompt, systemPrompt, extraParams);
const prediction = await runPrediction(model, input, token);
const content = extractOutput(prediction);
return {
id: prediction.id,
modelId: model.id,
modelTag: model.tag,
modelName: model.displayName,
content,
status: prediction.status,
metrics: prediction.metrics
};
} catch (err) {
return { modelId, modelTag: model.tag, modelName: model.displayName, error: err.message };
}
});
const results = await Promise.all(tasks);
res.json({ results });
} catch (err) {
res.status(500).json({ error: err.message });
}
});
export { router as chatRouter };