import express from 'express';
|
||
|
|
import { loadModels } from './models.js';
|
||
|
|
|
||
|
|
// Router mounted by the app (e.g. app.use('/api/chat', chatRouter)).
const router = express.Router();

// Base URL for the Replicate HTTP API; model-scoped prediction endpoints hang off it.
const REPLICATE_BASE = 'https://api.replicate.com/v1';
/**
 * Assemble the Replicate `input` payload for one model call.
 *
 * Precedence (later wins): model defaults < injected system prompt
 * < user-supplied extraParams < the prompt itself.
 *
 * @param {object} model - Model record; reads `defaultInput` and `systemPromptParam`.
 * @param {string} prompt - Final user prompt (always set last, so it cannot be overridden).
 * @param {string|null|undefined} systemPrompt - Injected only when the model declares a param for it.
 * @param {object|null|undefined} extraParams - Optional per-request overrides.
 * @returns {object} Fresh payload object; `model.defaultInput` is never mutated.
 */
function buildInput(model, prompt, systemPrompt, extraParams) {
  const payload = { ...model.defaultInput };

  // Only inject the system prompt when the model declares where it goes.
  const { systemPromptParam } = model;
  if (systemPromptParam && systemPrompt) {
    payload[systemPromptParam] = systemPrompt;
  }

  // extraParams may override defaults and the injected system prompt;
  // the prompt key is applied last so nothing can clobber it.
  return { ...payload, ...extraParams, prompt };
}
/**
 * Create a prediction on Replicate's model-scoped endpoint and wait for it.
 *
 * Uses the `Prefer: wait` header so Replicate holds the connection until the
 * prediction settles (or times out server-side), avoiding a polling loop.
 *
 * @param {object} model - Reads `owner` and `name` to build the endpoint URL.
 * @param {object} input - Payload produced by buildInput().
 * @param {string} token - Replicate API token (sent as a Bearer credential).
 * @returns {Promise<object>} Parsed prediction object from the API.
 * @throws {Error} With the API's `detail` message (or `HTTP <status>`) on a non-2xx response.
 */
async function runPrediction(model, input, token) {
  const { owner, name } = model;
  const endpoint = `${REPLICATE_BASE}/models/${owner}/${name}/predictions`;

  const response = await fetch(endpoint, {
    method: 'POST',
    headers: {
      Authorization: `Bearer ${token}`,
      'Content-Type': 'application/json',
      Prefer: 'wait',
    },
    body: JSON.stringify({ input }),
  });

  if (response.ok) {
    return response.json();
  }

  // Error bodies are JSON with a `detail` field; fall back gracefully if not.
  const payload = await response.json().catch(() => ({ detail: 'Unknown error' }));
  throw new Error(payload.detail || `HTTP ${response.status}`);
}
/**
 * Normalize a prediction's `output` field into display text.
 *
 * Replicate models return strings, token arrays, or structured objects;
 * this flattens all of them to one string for the client.
 *
 * @param {object} prediction - Prediction object from the Replicate API.
 * @returns {string} Empty string for falsy output, joined text for arrays,
 *   pretty-printed JSON for objects, the value itself otherwise.
 */
function extractOutput(prediction) {
  const output = prediction.output;
  if (!output) return '';

  // Token-streamed models return an array of string chunks.
  if (Array.isArray(output)) return output.join('');

  // Structured output: render as readable JSON.
  if (typeof output === 'object') return JSON.stringify(output, null, 2);

  return typeof output === 'string' ? output : String(output);
}
// POST /api/chat - run one prompt against a single model.
// Body: { modelId, prompt, systemPrompt?, searchContext?, extraParams? }
// Responds 400 on missing config/params, 404 on unknown model, 500 on upstream failure.
router.post('/', async (req, res) => {
  const token = process.env.REPLICATE_API_TOKEN;
  if (!token) {
    return res.status(400).json({ error: 'REPLICATE_API_TOKEN not configured' });
  }

  const { modelId, prompt, systemPrompt, searchContext, extraParams } = req.body;
  if (!modelId || !prompt) {
    return res.status(400).json({ error: 'modelId and prompt are required' });
  }

  try {
    const models = await loadModels();
    const model = models.find((m) => m.id === modelId);
    if (!model) {
      return res.status(404).json({ error: `Model not found: ${modelId}` });
    }

    // Prepend retrieved search context (if any) ahead of the user prompt.
    const finalPrompt = searchContext ? `${searchContext}\n\n${prompt}` : prompt;
    const input = buildInput(model, finalPrompt, systemPrompt, extraParams);
    const prediction = await runPrediction(model, input, token);

    res.json({
      id: prediction.id,
      modelId: model.id,
      modelTag: model.tag,
      modelName: model.displayName,
      content: extractOutput(prediction),
      status: prediction.status,
      metrics: prediction.metrics,
      urls: prediction.urls,
    });
  } catch (err) {
    res.status(500).json({ error: err.message });
  }
});
|
// POST /api/chat/multi - send the same prompt to multiple models in parallel.
// Body: { modelIds: string[], prompt, systemPrompt?, searchContext?, extraParams? }
// Always responds 200 with { results }: per-model failures are reported inline
// as { modelId, ..., error } rows rather than failing the whole batch.
router.post('/multi', async (req, res) => {
  const token = process.env.REPLICATE_API_TOKEN;
  if (!token) return res.status(400).json({ error: 'REPLICATE_API_TOKEN not configured' });

  const { modelIds, prompt, systemPrompt, searchContext, extraParams } = req.body;
  if (!modelIds?.length || !prompt) return res.status(400).json({ error: 'modelIds and prompt are required' });

  try {
    const models = await loadModels();

    // The final prompt is model-independent, so build it once instead of
    // per model inside the map callback.
    const finalPrompt = searchContext ? `${searchContext}\n\n${prompt}` : prompt;

    const tasks = modelIds.map(async (modelId) => {
      const model = models.find((m) => m.id === modelId);
      if (!model) return { modelId, error: 'Model not found' };

      try {
        const input = buildInput(model, finalPrompt, systemPrompt, extraParams);
        const prediction = await runPrediction(model, input, token);
        return {
          id: prediction.id,
          modelId: model.id,
          modelTag: model.tag,
          modelName: model.displayName,
          content: extractOutput(prediction),
          status: prediction.status,
          metrics: prediction.metrics,
        };
      } catch (err) {
        // Report this model's failure in its result row; other models proceed.
        return { modelId, modelTag: model.tag, modelName: model.displayName, error: err.message };
      }
    });

    // Promise.all is fail-fast, but safe here: every task handles its own
    // rejection and resolves to a result row.
    const results = await Promise.all(tasks);
    res.json({ results });
  } catch (err) {
    res.status(500).json({ error: err.message });
  }
});
// Named export so the app can mount this router, e.g. app.use('/api/chat', chatRouter).
export { router as chatRouter };