fix: sidecar process kill was not awaiting wait() — old server held GPU VRAM

- _kill_llama_server() was sync calling an unawaited coroutine. process.wait() created
  a discarded coroutine object — the old llama-server was never waited on to release
  GPU memory before starting a new one, causing OOM on rapid model switches.
  Fixed with async await + 10s SIGTERM timeout + SIGKILL fallback.

- Changed _switch_lock from threading.Lock to asyncio.Lock() to prevent event loop
  deadlock during long switch operations.

- Router proxy: only trigger model switches for POST /v1/chat/completions and
  /v1/completions. Non-chat endpoints (GET probes, /api/show) no longer trigger
  unwanted model reloads.

- _ollama_show_lookup: return active profile context size when model_name is empty.
  Previously returned 404, causing Hermes Desktop to default to 256k context.

- Always drain_queue() + complete_switch() after switch failure so queued requests
  don't hang forever waiting on a never-set switching event.
This commit is contained in:
root 2026-06-17 23:49:57 +00:00
parent 7e9b3f43e1
commit 45dd793b69
2 changed files with 87 additions and 30 deletions

64
main.py
View File

@ -251,14 +251,41 @@ async def ollama_show_post(request: Request):
async def _ollama_show_lookup(model_name: str):
"""Shared logic for Ollama /api/show model info lookup."""
"""Shared logic for Ollama /api/show model info lookup.
When model_name is empty string (Hermes Desktop probe with no model field),
returns the currently-active profile's info so the desktop can determine
the correct context size. Previously returned 404, causing Hermes Desktop
to default to 256k context.
"""
async with httpx.AsyncClient(timeout=5.0) as client:
try:
resp = await client.get(f"{SIDECAR_URL}/models/available")
profiles = resp.json()
status_resp = await client.get(f"{SIDECAR_URL}/models/status")
status = status_resp.json()
except Exception:
return JSONResponse(status_code=404, content={"error": "model not found"})
# If no model specified, return the currently-active profile's info
active_id = status.get("active_profile")
if not model_name and active_id:
for p in profiles:
if p.get("id") == active_id:
flags = p.get("flags", {})
ctx_size = str(flags.get("ctx-size", flags.get("n_ctx", "4096")))
return {
"modelfile": "",
"parameters": f"num_ctx {ctx_size}",
"template": "",
"details": {
"format": "gguf",
"family": p.get("name", "llm"),
"parameter_size": ctx_size,
},
"model_info": {"id": p.get("id", "")},
}
for p in profiles:
if p.get("id") == model_name:
# Extract actual context size from the profile's flags
@ -428,9 +455,17 @@ async def proxy(
body_data = json.loads(body) if body else {}
requested_model = body_data.get("model")
# Only trigger model switches for actual chat/completion POST requests.
# GET probes, /api/show lookups, and other non-chat endpoints should
# never trigger a switch — they just read current state.
is_chat_request = (
request.method == "POST"
and path in ("v1/chat/completions", "v1/completions")
)
if requested_model and sidecar_status.get("active_profile") == requested_model:
target_url = f"{MAIN_PC_BASE}/{path}"
elif requested_model:
elif requested_model and is_chat_request:
# Trigger switch for a specific model request
# Check if a switch is already in progress
current_switch = await wait_for_switch()
@ -497,14 +532,25 @@ async def proxy(
flush=True,
)
except Exception as e:
circuit_record_failure()
error = f"switch_error: {str(e)}"
else:
# No model in request body (probe/GET/non-chat request) —
# route to the currently active backend when available,
# or fall through to the fallback chain.
if sidecar_status.get("active_profile") and sidecar_status.get("llama_server_running"):
target_url = f"{MAIN_PC_BASE}/{path}"
circuit_record_failure()
print(
f"SWITCH EXCEPTION: profile={requested_model}, "
f"error={type(e).__name__}: {e}",
flush=True,
)
# Always signal queued requests on switch completion/failure
# so they don't hang forever waiting for a switch that never finishes.
complete_switch()
drain_queue()
else:
# No model in request body (probe/GET/non-chat request) —
# route to the currently active backend when available,
# or fall through to the fallback chain.
if sidecar_status.get("active_profile") and sidecar_status.get("llama_server_running"):
target_url = f"{MAIN_PC_BASE}/{path}"
# ── Fallback chain ────────────────────────────────────────────────────
if target_url is None:

View File

@ -5,7 +5,6 @@ Runs on the Main PC, manages llama-server subprocess, serves manifest/profile da
import os
import asyncio
import signal as signal_module
import threading
from contextlib import asynccontextmanager
from typing import Optional
@ -26,7 +25,7 @@ LLAMA_STDERR_LOG = os.path.join(
# Global state
_llama_server_process: Optional[asyncio.subprocess.Process] = None
_active_profile: Optional[str] = None
_switch_lock = threading.Lock() # Use threading.Lock for compatibility with TestClient
_switch_lock = asyncio.Lock() # Use asyncio.Lock to avoid blocking the event loop
@asynccontextmanager
@ -37,7 +36,7 @@ async def lifespan(app: FastAPI):
# Cleanup: kill llama-server if running
global _llama_server_process
if _llama_server_process:
_kill_llama_server()
await _kill_llama_server()
app = FastAPI(lifespan=lifespan)
@ -55,22 +54,34 @@ def _close_stderr_log():
pass
def _kill_llama_server():
"""Kill the llama-server subprocess and close its stderr log handle."""
global _llama_server_process
if _llama_server_process and _llama_server_process.returncode is None:
try:
_llama_server_process.send_signal(signal_module.SIGTERM)
try:
_llama_server_process.wait(timeout=5)
except asyncio.TimeoutError:
_llama_server_process.kill()
except Exception:
pass
_llama_server_process = None
async def _kill_llama_server():
"""Kill the llama-server subprocess and wait for it to fully terminate.
# Close stderr log handle if still open
_close_stderr_log()
This MUST be async because process.wait() is a coroutine. The synchronous
version was calling .wait() without await, creating an unawaited coroutine
object the old process was never actually waited on, so it could still
hold GPU VRAM when the new server started.
"""
global _llama_server_process
if _llama_server_process is None or _llama_server_process.returncode is not None:
_close_stderr_log()
return
try:
_llama_server_process.send_signal(signal_module.SIGTERM)
try:
await asyncio.wait_for(_llama_server_process.wait(), timeout=10)
except asyncio.TimeoutError:
_llama_server_process.kill()
try:
await asyncio.wait_for(_llama_server_process.wait(), timeout=5)
except asyncio.TimeoutError:
pass
except Exception:
pass
finally:
_llama_server_process = None
_close_stderr_log()
def _flag_value(value) -> str:
@ -105,7 +116,7 @@ async def _start_llama_server(profile: dict):
global _llama_server_process
# Kill any existing process
_kill_llama_server()
await _kill_llama_server()
# Build command from profile flags
cmd = ["/home/bigt/AI/llama.cpp/build/bin/llama-server"]
@ -200,7 +211,7 @@ async def switch_model(payload: SwitchRequest):
"""Stop current llama-server, start new one with the given profile, wait for readiness."""
global _active_profile
with _switch_lock:
async with _switch_lock:
# Validate profile_id
profiles = load_manifest(MANIFEST_PATH)
if profiles is None:
@ -229,7 +240,7 @@ async def switch_model(payload: SwitchRequest):
}
# Start the new model
_kill_llama_server()
await _kill_llama_server()
_active_profile = None
await _start_llama_server(profile)