fix: sidecar process kill was not awaiting wait() — old server held GPU VRAM
- _kill_llama_server() was sync calling an unawaited coroutine. process.wait() created a discarded coroutine object — the old llama-server was never waited on to release GPU memory before starting a new one, causing OOM on rapid model switches. Fixed with async await + 10s SIGTERM timeout + SIGKILL fallback. - Changed _switch_lock from threading.Lock to asyncio.Lock() to prevent event loop deadlock during long switch operations. - Router proxy: only trigger model switches for POST /v1/chat/completions and /v1/completions. Non-chat endpoints (GET probes, /api/show) no longer trigger unwanted model reloads. - _ollama_show_lookup: return active profile context size when model_name is empty. Previously returned 404, causing Hermes Desktop to default to 256k context. - Always drain_queue() + complete_switch() after switch failure so queued requests don't hang forever waiting on a never-set switching event.
This commit is contained in:
parent
7e9b3f43e1
commit
45dd793b69
64
main.py
64
main.py
@ -251,14 +251,41 @@ async def ollama_show_post(request: Request):
|
||||
|
||||
|
||||
async def _ollama_show_lookup(model_name: str):
|
||||
"""Shared logic for Ollama /api/show model info lookup."""
|
||||
"""Shared logic for Ollama /api/show model info lookup.
|
||||
|
||||
When model_name is empty string (Hermes Desktop probe with no model field),
|
||||
returns the currently-active profile's info so the desktop can determine
|
||||
the correct context size. Previously returned 404, causing Hermes Desktop
|
||||
to default to 256k context.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
try:
|
||||
resp = await client.get(f"{SIDECAR_URL}/models/available")
|
||||
profiles = resp.json()
|
||||
status_resp = await client.get(f"{SIDECAR_URL}/models/status")
|
||||
status = status_resp.json()
|
||||
except Exception:
|
||||
return JSONResponse(status_code=404, content={"error": "model not found"})
|
||||
|
||||
# If no model specified, return the currently-active profile's info
|
||||
active_id = status.get("active_profile")
|
||||
if not model_name and active_id:
|
||||
for p in profiles:
|
||||
if p.get("id") == active_id:
|
||||
flags = p.get("flags", {})
|
||||
ctx_size = str(flags.get("ctx-size", flags.get("n_ctx", "4096")))
|
||||
return {
|
||||
"modelfile": "",
|
||||
"parameters": f"num_ctx {ctx_size}",
|
||||
"template": "",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": p.get("name", "llm"),
|
||||
"parameter_size": ctx_size,
|
||||
},
|
||||
"model_info": {"id": p.get("id", "")},
|
||||
}
|
||||
|
||||
for p in profiles:
|
||||
if p.get("id") == model_name:
|
||||
# Extract actual context size from the profile's flags
|
||||
@ -428,9 +455,17 @@ async def proxy(
|
||||
body_data = json.loads(body) if body else {}
|
||||
requested_model = body_data.get("model")
|
||||
|
||||
# Only trigger model switches for actual chat/completion POST requests.
|
||||
# GET probes, /api/show lookups, and other non-chat endpoints should
|
||||
# never trigger a switch — they just read current state.
|
||||
is_chat_request = (
|
||||
request.method == "POST"
|
||||
and path in ("v1/chat/completions", "v1/completions")
|
||||
)
|
||||
|
||||
if requested_model and sidecar_status.get("active_profile") == requested_model:
|
||||
target_url = f"{MAIN_PC_BASE}/{path}"
|
||||
elif requested_model:
|
||||
elif requested_model and is_chat_request:
|
||||
# Trigger switch for a specific model request
|
||||
# Check if a switch is already in progress
|
||||
current_switch = await wait_for_switch()
|
||||
@ -497,14 +532,25 @@ async def proxy(
|
||||
flush=True,
|
||||
)
|
||||
except Exception as e:
|
||||
circuit_record_failure()
|
||||
error = f"switch_error: {str(e)}"
|
||||
else:
|
||||
# No model in request body (probe/GET/non-chat request) —
|
||||
# route to the currently active backend when available,
|
||||
# or fall through to the fallback chain.
|
||||
if sidecar_status.get("active_profile") and sidecar_status.get("llama_server_running"):
|
||||
target_url = f"{MAIN_PC_BASE}/{path}"
|
||||
circuit_record_failure()
|
||||
print(
|
||||
f"SWITCH EXCEPTION: profile={requested_model}, "
|
||||
f"error={type(e).__name__}: {e}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Always signal queued requests on switch completion/failure
|
||||
# so they don't hang forever waiting for a switch that never finishes.
|
||||
complete_switch()
|
||||
drain_queue()
|
||||
|
||||
else:
|
||||
# No model in request body (probe/GET/non-chat request) —
|
||||
# route to the currently active backend when available,
|
||||
# or fall through to the fallback chain.
|
||||
if sidecar_status.get("active_profile") and sidecar_status.get("llama_server_running"):
|
||||
target_url = f"{MAIN_PC_BASE}/{path}"
|
||||
|
||||
# ── Fallback chain ────────────────────────────────────────────────────
|
||||
if target_url is None:
|
||||
|
||||
@ -5,7 +5,6 @@ Runs on the Main PC, manages llama-server subprocess, serves manifest/profile da
|
||||
import os
|
||||
import asyncio
|
||||
import signal as signal_module
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
@ -26,7 +25,7 @@ LLAMA_STDERR_LOG = os.path.join(
|
||||
# Global state
|
||||
_llama_server_process: Optional[asyncio.subprocess.Process] = None
|
||||
_active_profile: Optional[str] = None
|
||||
_switch_lock = threading.Lock() # Use threading.Lock for compatibility with TestClient
|
||||
_switch_lock = asyncio.Lock() # Use asyncio.Lock to avoid blocking the event loop
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@ -37,7 +36,7 @@ async def lifespan(app: FastAPI):
|
||||
# Cleanup: kill llama-server if running
|
||||
global _llama_server_process
|
||||
if _llama_server_process:
|
||||
_kill_llama_server()
|
||||
await _kill_llama_server()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
@ -55,22 +54,34 @@ def _close_stderr_log():
|
||||
pass
|
||||
|
||||
|
||||
def _kill_llama_server():
|
||||
"""Kill the llama-server subprocess and close its stderr log handle."""
|
||||
global _llama_server_process
|
||||
if _llama_server_process and _llama_server_process.returncode is None:
|
||||
try:
|
||||
_llama_server_process.send_signal(signal_module.SIGTERM)
|
||||
try:
|
||||
_llama_server_process.wait(timeout=5)
|
||||
except asyncio.TimeoutError:
|
||||
_llama_server_process.kill()
|
||||
except Exception:
|
||||
pass
|
||||
_llama_server_process = None
|
||||
async def _kill_llama_server():
|
||||
"""Kill the llama-server subprocess and wait for it to fully terminate.
|
||||
|
||||
# Close stderr log handle if still open
|
||||
_close_stderr_log()
|
||||
This MUST be async because process.wait() is a coroutine. The synchronous
|
||||
version was calling .wait() without await, creating an unawaited coroutine
|
||||
object — the old process was never actually waited on, so it could still
|
||||
hold GPU VRAM when the new server started.
|
||||
"""
|
||||
global _llama_server_process
|
||||
if _llama_server_process is None or _llama_server_process.returncode is not None:
|
||||
_close_stderr_log()
|
||||
return
|
||||
|
||||
try:
|
||||
_llama_server_process.send_signal(signal_module.SIGTERM)
|
||||
try:
|
||||
await asyncio.wait_for(_llama_server_process.wait(), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
_llama_server_process.kill()
|
||||
try:
|
||||
await asyncio.wait_for(_llama_server_process.wait(), timeout=5)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
_llama_server_process = None
|
||||
_close_stderr_log()
|
||||
|
||||
|
||||
def _flag_value(value) -> str:
|
||||
@ -105,7 +116,7 @@ async def _start_llama_server(profile: dict):
|
||||
global _llama_server_process
|
||||
|
||||
# Kill any existing process
|
||||
_kill_llama_server()
|
||||
await _kill_llama_server()
|
||||
|
||||
# Build command from profile flags
|
||||
cmd = ["/home/bigt/AI/llama.cpp/build/bin/llama-server"]
|
||||
@ -200,7 +211,7 @@ async def switch_model(payload: SwitchRequest):
|
||||
"""Stop current llama-server, start new one with the given profile, wait for readiness."""
|
||||
global _active_profile
|
||||
|
||||
with _switch_lock:
|
||||
async with _switch_lock:
|
||||
# Validate profile_id
|
||||
profiles = load_manifest(MANIFEST_PATH)
|
||||
if profiles is None:
|
||||
@ -229,7 +240,7 @@ async def switch_model(payload: SwitchRequest):
|
||||
}
|
||||
|
||||
# Start the new model
|
||||
_kill_llama_server()
|
||||
await _kill_llama_server()
|
||||
_active_profile = None
|
||||
await _start_llama_server(profile)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user