diff --git a/main.py b/main.py index 2f2101e..099c32e 100644 --- a/main.py +++ b/main.py @@ -141,6 +141,49 @@ def complete_switch(): _switching_event.set() +async def _background_switch(requested_model: str): + """Run a model switch in the background. + + The sidecar POST is awaited but the caller gets an immediate SSE stream + so Hermes Desktop doesn't timeout waiting for the first response. + + Called via asyncio.create_task() so it runs concurrently with the + SSE stream being sent to the client. + """ + try: + async with httpx.AsyncClient(timeout=120.0) as client: + switch_resp = await client.post( + f"{SIDECAR_URL}/models/switch", + json={"profile_id": requested_model}, + ) + switch_result = switch_resp.json() + if switch_result.get("status") == "ready": + print( + f"SWITCH SUCCESS: profile={requested_model}", + flush=True, + ) + else: + circuit_record_failure() + print( + f"SWITCH FAILED: profile={requested_model}, " + f"status={switch_result.get('status')}, " + f"message={switch_result.get('message', '(no message)')}", + flush=True, + ) + except Exception as e: + circuit_record_failure() + print( + f"SWITCH EXCEPTION: profile={requested_model}, " + f"error={type(e).__name__}: {e}", + flush=True, + ) + finally: + # Signal all queued requests so they can proceed (and fall + # through to the fallback chain if the switch failed). + complete_switch() + drain_queue() + + # ─── App ───────────────────────────────────────────────────────────────────── @asynccontextmanager async def lifespan(app: FastAPI): @@ -466,84 +509,93 @@ async def proxy( if requested_model and sidecar_status.get("active_profile") == requested_model: target_url = f"{MAIN_PC_BASE}/{path}" elif requested_model and is_chat_request: - # Trigger switch for a specific model request - # Check if a switch is already in progress + # All requests during a model switch get an immediate SSE streaming + # response so clients (Hermes Desktop) don't timeout while waiting + # for the model to load (10-30s). The switch runs in a background + # task; the SSE stream yields progress events, then pipes through + # the actual response once the backend model is ready. current_switch = await wait_for_switch() + if current_switch is None: + # No switch in progress — start one in the background + await start_switch() + asyncio.create_task(_background_switch(requested_model)) - if current_switch is not None and not current_switch.is_set(): - # Another request started the switch — queue this one + # Queue this request — signals when switch completes + try: + wait_evt = await queue_request() + except HTTPException as he: + raise + + # Build request headers once + req_headers = dict(request.headers) + req_headers.pop("host", None) + + async def stream_with_sse(): + sse_gen = sse_progress_stream(wait_evt) try: - wait_evt = await queue_request() - except HTTPException as he: - raise - - # SSE progress while waiting - async def stream_with_sse(): - sse_gen = sse_progress_stream(wait_evt) + await wait_evt.wait() + async for sse_chunk in sse_gen: + yield sse_chunk + # Send actual request to Main PC + async with httpx.AsyncClient(timeout=60.0) as c: + async with c.stream( + request.method, + f"{MAIN_PC_BASE}/{path}", + content=body, + headers=req_headers, + ) as resp: + async for chunk in resp.aiter_bytes(): + yield chunk + except Exception: + # Main PC unreachable (switch failed or server died) — + # try fallback chain + yield _sse_format( + "error", + {"message": "Backend unreachable, trying fallback..."}, + ) + # Try OpenRouter + if OPENROUTER_API_KEY: + try: + fb_headers = dict(req_headers) + fb_headers["Authorization"] = f"Bearer {OPENROUTER_API_KEY}" + async with httpx.AsyncClient(timeout=60.0) as c: + async with c.stream( + request.method, + f"{OPENROUTER_BASE}/{path}", + content=body, + headers=fb_headers, + ) as resp: + async for chunk in resp.aiter_bytes(): + yield chunk + return + except Exception: + pass + # Fallback to LXC SLM try: - await wait_evt.wait() - async for sse_chunk in sse_gen: - yield sse_chunk - complete_switch() - drain_queue() async with httpx.AsyncClient(timeout=60.0) as c: - req_headers = dict(request.headers) - req_headers.pop("host", None) async with c.stream( request.method, - f"{MAIN_PC_BASE}/{path}", + f"{FALLBACK_SLM_URL}/{path}", content=body, headers=req_headers, ) as resp: async for chunk in resp.aiter_bytes(): yield chunk - finally: - # Clean up sse_gen - try: - await sse_gen.aclose() - except Exception: - pass - - return StreamingResponse( - stream_with_sse(), - media_type="text/event-stream", - ) - - # First request triggers the switch - await start_switch() # Create event for tracking - try: - async with httpx.AsyncClient(timeout=120.0) as client: - switch_resp = await client.post( - f"{SIDECAR_URL}/models/switch", - json={"profile_id": requested_model}, - ) - switch_result = switch_resp.json() - if switch_result.get("status") == "ready": - complete_switch() - drain_queue() - target_url = f"{MAIN_PC_BASE}/{path}" - else: - error = "switch_failed" - circuit_record_failure() - print( - f"SWITCH FAILED: profile={requested_model}, " - f"sidecar_status={switch_result.get('status')}, " - f"message={switch_result.get('message', '(no message)')}", - flush=True, + except Exception: + yield _sse_format( + "error", + {"message": "All backends unavailable"}, ) - except Exception as e: - error = f"switch_error: {str(e)}" - circuit_record_failure() - print( - f"SWITCH EXCEPTION: profile={requested_model}, " - f"error={type(e).__name__}: {e}", - flush=True, - ) + finally: + try: + await sse_gen.aclose() + except Exception: + pass - # Always signal queued requests on switch completion/failure - # so they don't hang forever waiting for a switch that never finishes. - complete_switch() - drain_queue() + return StreamingResponse( + stream_with_sse(), + media_type="text/event-stream", + ) else: # No model in request body (probe/GET/non-chat request) —