diff --git a/app/prescreener.py b/app/prescreener.py index 50250bd..51a32d5 100644 --- a/app/prescreener.py +++ b/app/prescreener.py @@ -221,7 +221,11 @@ def _parse_classify_output(raw: str) -> list[dict]: async def classify_with_deepseek(live_items: list[dict]) -> list[dict]: - """Single DeepSeek call → list of {domain, niche, type}.""" + """Single DeepSeek call → list of {domain, niche, type}. + + Replicate may return 202 (async) for slow models like DeepSeek-R1. + We poll the prediction URL until it succeeds or times out. + """ if not live_items: return [] payload = { @@ -231,23 +235,49 @@ async def classify_with_deepseek(live_items: list[dict]) -> list[dict]: "temperature": 0.1, } } + auth_headers = { + "Authorization": f"Bearer {REPLICATE_TOKEN}", + "Content-Type": "application/json", + } try: - async with httpx.AsyncClient(timeout=120) as client: + async with httpx.AsyncClient(timeout=300) as client: resp = await client.post( DEEPSEEK_MODEL, - headers={ - "Authorization": f"Bearer {REPLICATE_TOKEN}", - "Content-Type": "application/json", - "Prefer": "wait", - }, + headers={**auth_headers, "Prefer": "wait=60"}, json=payload, ) resp.raise_for_status() data = resp.json() - output = data.get("output", "") + # ── Poll if Replicate accepted async (202 or status starting/processing) ── + if resp.status_code == 202 or data.get("status") in ("starting", "processing"): + poll_url = (data.get("urls") or {}).get("get") + if not poll_url: + logger.error("DeepSeek: 202 but no poll URL in response") + return [] + logger.info("DeepSeek: async prediction, polling %s", poll_url) + for attempt in range(90): # up to ~3 minutes + await asyncio.sleep(2) + pr = await client.get( + poll_url, + headers={"Authorization": f"Bearer {REPLICATE_TOKEN}"}, + ) + pdata = pr.json() + status = pdata.get("status") + logger.debug("DeepSeek poll #%d status=%s", attempt + 1, status) + if status == "succeeded": + data = pdata + break + if status in ("failed", "canceled"): + logger.error("DeepSeek prediction %s: %s", status, pdata.get("error")) + return [] + else: + logger.error("DeepSeek: prediction timed out after polling 90×2s") + return [] + + output = data.get("output") or "" if isinstance(output, list): - output = "".join(output) + output = "".join(str(t) for t in output if t is not None) logger.info("DeepSeek raw output (first 500 chars): %.500s", output) result = _parse_classify_output(output)