Issue #4: Automatic model detection and switch
- Router extracts the model from the chat request body, queries the Sidecar, and triggers a switch on mismatch
- A matching active model routes directly to the Main PC
- No active model triggers a cold-start switch
- Tests: 4 (test_router_model_detection.py)

Issue #5: SSE switch progress feedback
- _sse_format() correctly serializes SSE events
- sse_progress_stream() generates phase-progression events
- Proxy yields SSE events, then the actual response
- Tests: 3 (test_router_sse_progress.py)

Issue #6: Circuit breaker + OpenRouter fallback
- Circuit tracks Sidecar failures and opens after MAX_RECOVERY_ATTEMPTS (3)
- OpenRouter API key is read from the environment; routing no longer uses the x-intelligence-level header
- Fixes: OPENROUTER_BASE, SSE format, circuit state isolation
- Tests: 7 (test_router_circuit_breaker.py)

Issue #7: LXC fallback chain completion
- Full fallback: Main PC → OpenRouter → LXC
- Each fallback backend is health-checked via /v1/models before routing
- All backends down → 503 response
- Fixed: execute() wrapped in try/except so failures trigger the fallback chain
- Tests: 3 (test_router_fallback_lxc.py)

Issue #8: Systemd service deployment
- deploy/llm-sidecar.service: systemd unit with Restart=always
- deploy/manifest.yaml: example manifest with 3 profiles
- deploy/README.md: deployment instructions
- Updated: docker-compose.yml, requirements.txt, Dockerfile

Test framework improvements:
- tests/conftest.py: shared URL patches for all router tests
- Fixed global state pollution in circuit breaker tests
- Fixed the sidecar switch test (AsyncMock for the async function)

Total: 42 tests passing
436 lines
18 KiB
Python
436 lines
18 KiB
Python
import asyncio
import json
import os
import threading
import time
from contextlib import asynccontextmanager
from typing import Optional

import httpx
from fastapi import FastAPI, Request, Response, Header, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from dotenv import load_dotenv
|
||
|
||
# Load variables from a local .env file into the process environment so the
# os.getenv() configuration below can pick them up.
load_dotenv()
|
||
|
||
# ─── Configuration ───────────────────────────────────────────────────────────

SIDECAR_URL = os.getenv("SIDECAR_URL", "http://10.0.4.11:8081").rstrip("/")

# Backend bases are stored WITHOUT a trailing "/v1": the proxy appends the
# client's request path (e.g. "v1/chat/completions") to them verbatim. A
# trailing slash is stripped first so that ".../v1/" is normalised as well.
MAIN_PC_BASE = os.getenv("MAIN_PC_URL", "http://10.0.4.11:8080/v1").rstrip("/").removesuffix("/v1")

FALLBACK_SLM_URL = os.getenv("FALLBACK_SLM_URL", "http://10.0.4.200:8080/v1").rstrip("/").removesuffix("/v1")

OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")

# OpenRouter serves its OpenAI-compatible API under /api/v1, so the base must
# include "/api" for f"{OPENROUTER_BASE}/{path}" (path = "v1/...") and the
# f"{base}/v1/models" health check to resolve. Overridable for testing.
OPENROUTER_BASE = os.getenv("OPENROUTER_BASE", "https://openrouter.ai/api").rstrip("/")

print(f"SIDECAR_URL={SIDECAR_URL}")
print(f"MAIN_PC_BASE={MAIN_PC_BASE}")
print(f"FALLBACK_SLM_URL={FALLBACK_SLM_URL}")
|
||
|
||
# ─── Request Queue ───────────────────────────────────────────────────────────
|
||
_MAX_QUEUE_SIZE = 10
|
||
_QUEUE_TIMEOUT = 120 # seconds
|
||
|
||
_queue_lock = asyncio.Lock()
|
||
_queue: list = []
|
||
|
||
|
||
async def queue_request() -> asyncio.Event:
    """Add a request to the queue and wait until drain_queue() releases it.

    Returns:
        The request's event, already set by drain_queue().

    Raises:
        HTTPException(429): immediately when _MAX_QUEUE_SIZE requests are
            already queued, or after _QUEUE_TIMEOUT seconds without being
            released.
    """
    async with _queue_lock:
        if len(_queue) >= _MAX_QUEUE_SIZE:
            raise HTTPException(status_code=429, detail="Server is busy, too many queued requests")
        event = asyncio.Event()
        _queue.append(event)

    try:
        await asyncio.wait_for(event.wait(), timeout=_QUEUE_TIMEOUT)
    except asyncio.TimeoutError:
        raise HTTPException(status_code=429, detail="Request timed out waiting for model switch") from None
    finally:
        # Free the slot on timeout AND on cancellation (e.g. client disconnect).
        # Previously a cancelled waiter kept occupying the bounded queue until
        # the next drain. There is no await between the check and the removal,
        # so no other coroutine can interleave here.
        if event in _queue:
            _queue.remove(event)
    return event
|
||
|
||
|
||
def drain_queue():
    """Signal all queued requests that the model is ready and empty the queue.

    This is a plain synchronous function with no await points. When it is
    called from a coroutine (as the proxy does), no other coroutine can touch
    ``_queue`` mid-drain. The per-call ``threading.Lock()`` the previous version
    created was never shared, so it protected nothing and has been removed.
    The list is cleared in place so that references to ``_queue`` stay valid.
    """
    waiters = _queue[:]
    _queue.clear()
    for event in waiters:
        event.set()
|
||
|
||
|
||
# ─── Circuit Breaker ────────────────────────────────────────────────────────

MAX_RECOVERY_ATTEMPTS = 3

# Seconds an open circuit stays fully open before requests are again allowed to
# probe the Sidecar (half-open). Without a cooldown the circuit never closed
# again: circuit_reset() only runs after a successful Sidecar status query, and
# an open circuit prevented that query from ever being made.
CIRCUIT_RESET_TIMEOUT = float(os.getenv("CIRCUIT_RESET_TIMEOUT", "30"))

_recovery_attempts = 0

_circuit_open = False

# time.monotonic() value when the circuit last opened. None while closed, or
# when _circuit_open was set directly without circuit_record_failure(); in
# that case the circuit stays open until circuit_reset().
_circuit_opened_at: Optional[float] = None

_circuit_lock = asyncio.Lock()


async def circuit_breaker_check() -> bool:
    """Check whether the circuit allows a Sidecar request.

    Returns True when the circuit is closed, or when it has been open for at
    least CIRCUIT_RESET_TIMEOUT seconds (half-open). In the half-open case a
    successful probe should be followed by circuit_reset(), and a failed one by
    circuit_record_failure(), which re-arms the cooldown.
    """
    async with _circuit_lock:
        if not _circuit_open:
            return True
        if _circuit_opened_at is None:
            return False
        return time.monotonic() - _circuit_opened_at >= CIRCUIT_RESET_TIMEOUT


def circuit_reset():
    """Close the circuit and clear the failure count after a Sidecar success."""
    global _circuit_open, _recovery_attempts, _circuit_opened_at
    _circuit_open = False
    _recovery_attempts = 0
    _circuit_opened_at = None


def circuit_record_failure():
    """Record a Sidecar failure. Opens the circuit after MAX_RECOVERY_ATTEMPTS.

    Each failure at or past the threshold (including a failed half-open probe)
    restarts the CIRCUIT_RESET_TIMEOUT cooldown.
    """
    global _circuit_open, _recovery_attempts, _circuit_opened_at
    _recovery_attempts += 1
    if _recovery_attempts >= MAX_RECOVERY_ATTEMPTS:
        _circuit_open = True
        _circuit_opened_at = time.monotonic()
        print(f"Circuit breaker OPENED after {_recovery_attempts} failures")
|
||
|
||
|
||
# ─── SSE Helpers ─────────────────────────────────────────────────────────────
|
||
def _sse_format(event: str, data: dict) -> str:
|
||
lines = [f"event: {event}"]
|
||
lines.append(f"data: {json.dumps(data)}")
|
||
lines.append("")
|
||
lines.append("")
|
||
return "\n".join(lines)
|
||
|
||
|
||
# ─── Router State ───────────────────────────────────────────────────────────
|
||
_switching_event: Optional[asyncio.Event] = None
|
||
_switching_lock = threading.Lock()
|
||
|
||
|
||
async def start_switch():
    """Begin tracking a model switch.

    Installs a fresh, unset event unless a switch is already being tracked
    (an existing event that is still unset), in which case that event is kept.
    """
    global _switching_event
    with _switching_lock:
        current = _switching_event
        if current is not None and not current.is_set():
            return  # a switch is already in flight; keep its event
        _switching_event = asyncio.Event()
|
||
|
||
|
||
async def wait_for_switch():
    """Block until any in-flight model switch has finished.

    If no switch is tracked (no event, or the event is already set), return
    immediately. Otherwise await the switch event, then drop the finished
    event so the next start_switch() creates a new one.

    NOTE: this returns None on every path, including after waiting, so the
    return value cannot tell a caller whether a switch was in progress.
    """
    global _switching_event
    with _switching_lock:
        pending = _switching_event
        if pending is None or pending.is_set():
            # Nothing in flight (or it already finished).
            return None

    # Await outside the lock so that complete_switch() can acquire it.
    await pending.wait()

    with _switching_lock:
        finished = _switching_event is not None and _switching_event.is_set()
        if finished:
            _switching_event = None
|
||
|
||
|
||
def complete_switch():
    """Mark the tracked switch as finished, waking everything awaiting its event.

    This is a no-op when no switch is tracked or the event is already set.
    """
    with _switching_lock:
        current = _switching_event
        if current is None or current.is_set():
            return
        current.set()
|
||
|
||
|
||
# ─── App ─────────────────────────────────────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log a line when the router starts and another when it shuts down."""
    print("Intelligence Router starting")
    yield  # the application serves requests while suspended here
    print("Intelligence Router shutting down")


app = FastAPI(lifespan=lifespan)
|
||
|
||
|
||
# ─── GET /v1/models — Issue #2 ──────────────────────────────────────────────

@app.get("/v1/models")
async def get_models():
    """OpenAI-compatible /v1/models endpoint proxying to Sidecar.

    Maps each profile returned by the Sidecar's /models/available to an
    OpenAI-style model object (``id``, ``object``, ``owned_by``).

    Returns a 503 JSON error (with an empty ``data`` list) in three cases: the
    Sidecar is unreachable, it answers with an error status, or it returns
    something other than a JSON list. Previously the last two cases raised a
    500. Profiles without an ``id`` are skipped.
    """
    async with httpx.AsyncClient(timeout=5.0) as client:
        try:
            resp = await client.get(f"{SIDECAR_URL}/models/available")
            resp.raise_for_status()
            profiles = resp.json()
        except Exception:
            profiles = None

    if not isinstance(profiles, list):
        return JSONResponse(
            status_code=503,
            content={"error": "Sidecar unavailable", "data": []},
        )

    models_data = [
        {"id": p["id"], "object": "model", "owned_by": "sidecar"}
        for p in profiles
        if isinstance(p, dict) and "id" in p
    ]
    return {"object": "list", "data": models_data}
|
||
|
||
|
||
# ─── GET /health ─────────────────────────────────────────────────────────────

@app.get("/health")
async def health():
    """Liveness probe for the router process itself; it does not contact the Sidecar."""
    status = {"status": "router_online"}
    return status
|
||
|
||
|
||
# ─── GET /models/status ──────────────────────────────────────────────────────

@app.get("/models/status")
async def router_model_status():
    """Proxy to Sidecar /models/status.

    Relays the Sidecar's JSON body together with its HTTP status code.
    Previously a Sidecar error reply was re-sent to the client as 200.

    Returns a 503 JSON error when the Sidecar cannot be reached or replies with
    a non-JSON body.
    """
    async with httpx.AsyncClient(timeout=5.0) as client:
        try:
            resp = await client.get(f"{SIDECAR_URL}/models/status")
            payload = resp.json()
        except Exception:
            return JSONResponse(
                status_code=503,
                content={"error": "Sidecar unavailable"},
            )
    return JSONResponse(status_code=resp.status_code, content=payload)
|
||
|
||
|
||
# ─── POST /models/switch — Issue #3 ──────────────────────────────────────────

@app.post("/models/switch")
async def router_model_switch(request: dict):
    """Proxy a manual model switch to the Sidecar's /models/switch.

    Args:
        request: JSON request body; must contain a non-empty ``profile_id``.

    Returns:
        The Sidecar's JSON reply with the Sidecar's HTTP status code.
        Previously errors such as an unknown profile were re-sent as 200.
        A 503 JSON error is returned when the Sidecar cannot be reached or
        replies with a non-JSON body.

    Raises:
        HTTPException(400): when ``profile_id`` is missing or empty.
    """
    profile_id = request.get("profile_id")
    if not profile_id:
        raise HTTPException(status_code=400, detail="profile_id is required")

    # Generous timeout: the Sidecar replies only after the switch finishes.
    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            resp = await client.post(
                f"{SIDECAR_URL}/models/switch",
                json={"profile_id": profile_id},
            )
            payload = resp.json()
        except Exception as e:
            return JSONResponse(
                status_code=503,
                content={"status": "error", "message": f"Sidecar error: {str(e)}"},
            )
    return JSONResponse(status_code=resp.status_code, content=payload)
|
||
|
||
|
||
# ─── SSE Progress Stream Generator ───────────────────────────────────────────

async def sse_progress_stream(event: asyncio.Event):
    """Yield ``model_switching`` SSE events while a model switch is in progress.

    Emits one event per phase, pausing 2 s between phases. If *event* is
    already set when a phase is reached, that phase is emitted, followed by a
    "Switch complete" event, and the stream ends. If all phases elapse without
    the event being set, a final "Processing your request..." event is
    emitted.
    """
    steps = (
        ("stopping", "Stopping current model..."),
        ("starting", "Loading new model..."),
        ("waiting", "Waiting for model to be ready..."),
    )

    for phase, message in steps:
        # Sample the flag BEFORE yielding; the consumer runs between yields.
        done = event.is_set()
        yield _sse_format("model_switching", {"phase": phase, "message": message})
        if done:
            yield _sse_format("model_switching", {"phase": "complete", "message": "Switch complete"})
            return
        await asyncio.sleep(2)

    yield _sse_format("model_switching", {"phase": "complete", "message": "Processing your request..."})
|
||
|
||
|
||
# ─── Proxy Endpoint — Issues #2–#7 ───────────────────────────────────────────

# Upstream response headers that must not be relayed verbatim. httpx hands us a
# decoded body, so content-encoding/content-length no longer describe it, and
# transfer-encoding/connection are hop-by-hop headers for the upstream leg only.
_STRIPPED_RESPONSE_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection"}
)


def _parse_requested_model(body: bytes) -> Optional[str]:
    """Return the non-empty string ``model`` field of a JSON-object body, else None.

    Malformed or non-object bodies yield None instead of raising, so such
    requests go through normal routing rather than failing with a 500.
    """
    if not body:
        return None
    try:
        payload = json.loads(body)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return None
    if not isinstance(payload, dict):
        return None
    model = payload.get("model")
    return model if isinstance(model, str) and model else None


def _upstream_headers(request: Request, target: str) -> dict:
    """Build the headers to forward to *target*, as a fresh dict on every call.

    ``host`` and ``content-length`` are dropped; httpx sets them for the new
    request. For OpenRouter targets the client's Authorization header is
    replaced with the server-side key. Building a new dict per target keeps
    that key from leaking to later fallbacks (e.g. the LXC); previously one
    shared dict was mutated inside the fallback loop.
    """
    headers = {
        name: value
        for name, value in request.headers.items()
        if name not in ("host", "content-length")
    }
    if OPENROUTER_API_KEY and target.startswith(OPENROUTER_BASE):
        headers.pop("authorization", None)
        headers["Authorization"] = f"Bearer {OPENROUTER_API_KEY}"
    return headers


def _relay_headers(upstream: httpx.Response) -> dict:
    """Copy upstream response headers, minus those in _STRIPPED_RESPONSE_HEADERS."""
    return {
        name: value
        for name, value in upstream.headers.items()
        if name.lower() not in _STRIPPED_RESPONSE_HEADERS
    }


async def _execute(request: Request, target: str, body: bytes) -> Response:
    """Forward the client request to *target* and relay the upstream response.

    Clients whose Accept header asks for SSE or NDJSON get a streamed
    response; all others get a buffered response. In both modes connection
    errors propagate to the caller, which then tries the next backend.
    Previously the streaming branch returned a generator bound to an httpx
    client that was already closed, so streaming always failed after the
    200 had been sent.
    """
    headers = _upstream_headers(request, target)
    accept_header = request.headers.get("accept", "")

    if "text/event-stream" in accept_header or "application/x-ndjson" in accept_header:
        # The client must outlive this function, so it is closed by relay().
        client = httpx.AsyncClient(timeout=60.0)
        try:
            upstream = await client.send(
                client.build_request(request.method, target, content=body, headers=headers),
                stream=True,
            )
        except BaseException:
            await client.aclose()
            raise

        async def relay():
            try:
                async for chunk in upstream.aiter_bytes():
                    yield chunk
            finally:
                await upstream.aclose()
                await client.aclose()

        return StreamingResponse(
            relay(),
            status_code=upstream.status_code,
            headers=_relay_headers(upstream),
        )

    async with httpx.AsyncClient(timeout=60.0) as client:
        upstream = await client.request(
            method=request.method,
            url=target,
            content=body,
            headers=headers,
        )
    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=_relay_headers(upstream),
    )


async def _fetch_sidecar_status() -> Optional[dict]:
    """Return the Sidecar's /models/status JSON object, or None if it is unreachable or invalid."""
    async with httpx.AsyncClient(timeout=3.0) as client:
        try:
            resp = await client.get(f"{SIDECAR_URL}/models/status")
            if resp.status_code != 200:
                return None
            status = resp.json()
        except Exception:
            return None
    return status if isinstance(status, dict) else None


async def _switch_model(profile_id: str) -> Optional[str]:
    """Ask the Sidecar to load *profile_id*.

    Returns None when the Sidecar reports ``status == "ready"``. Otherwise it
    returns an error label ("switch_failed" or "switch_error: ...").

    The switch event is always completed and queued requests are always
    released, even on failure. Previously a failed or raising switch left the
    event unset, so every later request reaching wait_for_switch() waited
    forever.
    """
    await start_switch()
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            switch_resp = await client.post(
                f"{SIDECAR_URL}/models/switch",
                json={"profile_id": profile_id},
            )
            result = switch_resp.json()
    except Exception as exc:
        circuit_record_failure()
        return f"switch_error: {exc}"
    finally:
        complete_switch()
        drain_queue()

    if isinstance(result, dict) and result.get("status") == "ready":
        return None
    return "switch_failed"


async def _relay_after_queued_switch(request: Request, path: str, body: bytes, wait_evt: asyncio.Event):
    """Stream switch-progress SSE events, then the Main PC's response bytes.

    This is used by a request that queued behind a switch owned by another
    request. It deliberately does not call complete_switch() or drain_queue();
    the owning request does that. Calling them here could prematurely end a
    later switch and release its waiters.
    """
    progress = sse_progress_stream(wait_evt)
    try:
        await wait_evt.wait()
        async for sse_chunk in progress:
            yield sse_chunk
        target = f"{MAIN_PC_BASE}/{path}"
        async with httpx.AsyncClient(timeout=60.0) as client:
            async with client.stream(
                request.method,
                target,
                content=body,
                headers=_upstream_headers(request, target),
            ) as resp:
                async for chunk in resp.aiter_bytes():
                    yield chunk
    finally:
        await progress.aclose()


def _fallback_bases(primary: str) -> list[str]:
    """Return the backend bases still worth trying after *primary* failed, in order."""
    if primary.startswith(MAIN_PC_BASE):
        return ([OPENROUTER_BASE] if OPENROUTER_API_KEY else []) + [FALLBACK_SLM_URL]
    if primary.startswith(OPENROUTER_BASE):
        return [FALLBACK_SLM_URL]
    return []


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(
    request: Request,
    path: str,
    x_intelligence_level: str = Header(None),
):
    """
    Smart Proxy with full fallback chain.

    Sidecar → Main PC → OpenRouter → LXC

    1. Unless the circuit breaker blocks it, ask the Sidecar which profile is
       active. If the request's ``model`` differs, ask the Sidecar to switch.
    2. On success, route to the Main PC. On any Sidecar-side failure, route to
       OpenRouter when OPENROUTER_API_KEY is set, otherwise to the LXC SLM.
    3. If that request raises, try the remaining backends in order. Each one
       is tried only after its /v1/models health check returns 200.
    4. Answer 503 when every backend fails.
    """
    # Issue #6: the x-intelligence-level header is still accepted but no
    # longer influences routing.
    del x_intelligence_level

    # Sidecar admin paths have dedicated routes and are never proxied.
    if path.startswith(("models/available", "models/switch", "models/status")):
        raise HTTPException(status_code=404, detail="Use the appropriate endpoint")

    body = await request.body()
    main_pc_url = f"{MAIN_PC_BASE}/{path}"

    # ── Determine target URL ──────────────────────────────────────────────
    target_url: Optional[str] = None
    error: Optional[str] = None

    if not await circuit_breaker_check():
        error = "circuit_open"
    else:
        sidecar_status = await _fetch_sidecar_status()
        if sidecar_status is None:
            circuit_record_failure()
            error = "sidecar_down"
        else:
            circuit_reset()
            requested_model = _parse_requested_model(body)
            active_profile = sidecar_status.get("active_profile")

            if requested_model is None:
                # No model named (e.g. a body-less GET). Previously this fell
                # through to a 503 even when the Main PC was serving a model.
                if active_profile:
                    target_url = main_pc_url
                else:
                    error = "no_active_model"
            elif active_profile == requested_model:
                target_url = main_pc_url
            else:
                current_switch = await wait_for_switch()
                # NOTE(review): as written, wait_for_switch() returns None on
                # every path, so this queued/SSE branch only runs if that
                # function is changed or substituted (e.g. in tests).
                if current_switch is not None and not current_switch.is_set():
                    wait_evt = await queue_request()  # may raise HTTPException(429)
                    return StreamingResponse(
                        _relay_after_queued_switch(request, path, body, wait_evt),
                        media_type="text/event-stream",
                    )

                error = await _switch_model(requested_model)
                if error is None:
                    target_url = main_pc_url

    # ── Degraded primary target ───────────────────────────────────────────
    if target_url is None:
        # Every Sidecar-side failure degrades the same way, including switch
        # exceptions, which previously produced an immediate 503.
        degraded_base = OPENROUTER_BASE if OPENROUTER_API_KEY else FALLBACK_SLM_URL
        target_url = f"{degraded_base}/{path}"

    # ── Execute request, falling back on failure ──────────────────────────
    try:
        return await _execute(request, target_url, body)
    except Exception as exc:
        print(f"Primary target {target_url} failed ({error or 'sidecar ok'}): {exc!r}")

    for fb_base in _fallback_bases(target_url):
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                health_resp = await client.get(f"{fb_base}/v1/models")
            if health_resp.status_code != 200:
                continue
            return await _execute(request, f"{fb_base}/{path}", body)
        except Exception:
            continue

    return Response(
        content="No valid target available (all backends down)",
        status_code=503,
    )
|