fix: poll Replicate for DeepSeek-R1 async predictions (202 Accepted)
DeepSeek-R1 is too slow for a synchronous Replicate wait; it returns 202 with a prediction URL instead of the completed output.

Added polling loop:
- POST with `Prefer: wait=60`
- If the response is 202, or its status is `starting`/`processing`, poll `urls.get` every 2 s, up to 90 times (~3-minute ceiling)
- On `succeeded`, use the final response data as normal
- On `failed`/`canceled`/timeout, log and return `[]`

Also guards against `output=None` before calling `str.join()`.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -221,7 +221,11 @@ def _parse_classify_output(raw: str) -> list[dict]:
|
||||
|
||||
|
||||
async def classify_with_deepseek(live_items: list[dict]) -> list[dict]:
|
||||
"""Single DeepSeek call → list of {domain, niche, type}."""
|
||||
"""Single DeepSeek call → list of {domain, niche, type}.
|
||||
|
||||
Replicate may return 202 (async) for slow models like DeepSeek-R1.
|
||||
We poll the prediction URL until it succeeds or times out.
|
||||
"""
|
||||
if not live_items:
|
||||
return []
|
||||
payload = {
|
||||
@@ -231,23 +235,49 @@ async def classify_with_deepseek(live_items: list[dict]) -> list[dict]:
|
||||
"temperature": 0.1,
|
||||
}
|
||||
}
|
||||
auth_headers = {
|
||||
"Authorization": f"Bearer {REPLICATE_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
resp = await client.post(
|
||||
DEEPSEEK_MODEL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {REPLICATE_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
"Prefer": "wait",
|
||||
},
|
||||
headers={**auth_headers, "Prefer": "wait=60"},
|
||||
json=payload,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
output = data.get("output", "")
|
||||
# ── Poll if Replicate accepted async (202 or status starting/processing) ──
|
||||
if resp.status_code == 202 or data.get("status") in ("starting", "processing"):
|
||||
poll_url = (data.get("urls") or {}).get("get")
|
||||
if not poll_url:
|
||||
logger.error("DeepSeek: 202 but no poll URL in response")
|
||||
return []
|
||||
logger.info("DeepSeek: async prediction, polling %s", poll_url)
|
||||
for attempt in range(90): # up to ~3 minutes
|
||||
await asyncio.sleep(2)
|
||||
pr = await client.get(
|
||||
poll_url,
|
||||
headers={"Authorization": f"Bearer {REPLICATE_TOKEN}"},
|
||||
)
|
||||
pdata = pr.json()
|
||||
status = pdata.get("status")
|
||||
logger.debug("DeepSeek poll #%d status=%s", attempt + 1, status)
|
||||
if status == "succeeded":
|
||||
data = pdata
|
||||
break
|
||||
if status in ("failed", "canceled"):
|
||||
logger.error("DeepSeek prediction %s: %s", status, pdata.get("error"))
|
||||
return []
|
||||
else:
|
||||
logger.error("DeepSeek: prediction timed out after polling 90×2s")
|
||||
return []
|
||||
|
||||
output = data.get("output") or ""
|
||||
if isinstance(output, list):
|
||||
output = "".join(output)
|
||||
output = "".join(str(t) for t in output if t is not None)
|
||||
|
||||
logger.info("DeepSeek raw output (first 500 chars): %.500s", output)
|
||||
result = _parse_classify_output(output)
|
||||
|
||||
Reference in New Issue
Block a user