fix: sidecar process kill was not awaiting wait() — old server held GPU VRAM
- _kill_llama_server() was sync calling an unawaited coroutine. process.wait() created a discarded coroutine object — the old llama-server was never waited on to release GPU memory before starting a new one, causing OOM on rapid model switches. Fixed with async await + 10s SIGTERM timeout + SIGKILL fallback. - Changed _switch_lock from threading.Lock to asyncio.Lock() to prevent event loop deadlock during long switch operations. - Router proxy: only trigger model switches for POST /v1/chat/completions and /v1/completions. Non-chat endpoints (GET probes, /api/show) no longer trigger unwanted model reloads. - _ollama_show_lookup: return active profile context size when model_name is empty. Previously returned 404, causing Hermes Desktop to default to 256k context. - Always drain_queue() + complete_switch() after switch failure so queued requests don't hang forever waiting on a never-set switching event.
This commit is contained in:
parent
7e9b3f43e1
commit
45dd793b69
64
main.py
64
main.py
@ -251,14 +251,41 @@ async def ollama_show_post(request: Request):
|
|||||||
|
|
||||||
|
|
||||||
async def _ollama_show_lookup(model_name: str):
|
async def _ollama_show_lookup(model_name: str):
|
||||||
"""Shared logic for Ollama /api/show model info lookup."""
|
"""Shared logic for Ollama /api/show model info lookup.
|
||||||
|
|
||||||
|
When model_name is empty string (Hermes Desktop probe with no model field),
|
||||||
|
returns the currently-active profile's info so the desktop can determine
|
||||||
|
the correct context size. Previously returned 404, causing Hermes Desktop
|
||||||
|
to default to 256k context.
|
||||||
|
"""
|
||||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||||
try:
|
try:
|
||||||
resp = await client.get(f"{SIDECAR_URL}/models/available")
|
resp = await client.get(f"{SIDECAR_URL}/models/available")
|
||||||
profiles = resp.json()
|
profiles = resp.json()
|
||||||
|
status_resp = await client.get(f"{SIDECAR_URL}/models/status")
|
||||||
|
status = status_resp.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
return JSONResponse(status_code=404, content={"error": "model not found"})
|
return JSONResponse(status_code=404, content={"error": "model not found"})
|
||||||
|
|
||||||
|
# If no model specified, return the currently-active profile's info
|
||||||
|
active_id = status.get("active_profile")
|
||||||
|
if not model_name and active_id:
|
||||||
|
for p in profiles:
|
||||||
|
if p.get("id") == active_id:
|
||||||
|
flags = p.get("flags", {})
|
||||||
|
ctx_size = str(flags.get("ctx-size", flags.get("n_ctx", "4096")))
|
||||||
|
return {
|
||||||
|
"modelfile": "",
|
||||||
|
"parameters": f"num_ctx {ctx_size}",
|
||||||
|
"template": "",
|
||||||
|
"details": {
|
||||||
|
"format": "gguf",
|
||||||
|
"family": p.get("name", "llm"),
|
||||||
|
"parameter_size": ctx_size,
|
||||||
|
},
|
||||||
|
"model_info": {"id": p.get("id", "")},
|
||||||
|
}
|
||||||
|
|
||||||
for p in profiles:
|
for p in profiles:
|
||||||
if p.get("id") == model_name:
|
if p.get("id") == model_name:
|
||||||
# Extract actual context size from the profile's flags
|
# Extract actual context size from the profile's flags
|
||||||
@ -428,9 +455,17 @@ async def proxy(
|
|||||||
body_data = json.loads(body) if body else {}
|
body_data = json.loads(body) if body else {}
|
||||||
requested_model = body_data.get("model")
|
requested_model = body_data.get("model")
|
||||||
|
|
||||||
|
# Only trigger model switches for actual chat/completion POST requests.
|
||||||
|
# GET probes, /api/show lookups, and other non-chat endpoints should
|
||||||
|
# never trigger a switch — they just read current state.
|
||||||
|
is_chat_request = (
|
||||||
|
request.method == "POST"
|
||||||
|
and path in ("v1/chat/completions", "v1/completions")
|
||||||
|
)
|
||||||
|
|
||||||
if requested_model and sidecar_status.get("active_profile") == requested_model:
|
if requested_model and sidecar_status.get("active_profile") == requested_model:
|
||||||
target_url = f"{MAIN_PC_BASE}/{path}"
|
target_url = f"{MAIN_PC_BASE}/{path}"
|
||||||
elif requested_model:
|
elif requested_model and is_chat_request:
|
||||||
# Trigger switch for a specific model request
|
# Trigger switch for a specific model request
|
||||||
# Check if a switch is already in progress
|
# Check if a switch is already in progress
|
||||||
current_switch = await wait_for_switch()
|
current_switch = await wait_for_switch()
|
||||||
@ -497,14 +532,25 @@ async def proxy(
|
|||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
circuit_record_failure()
|
|
||||||
error = f"switch_error: {str(e)}"
|
error = f"switch_error: {str(e)}"
|
||||||
else:
|
circuit_record_failure()
|
||||||
# No model in request body (probe/GET/non-chat request) —
|
print(
|
||||||
# route to the currently active backend when available,
|
f"SWITCH EXCEPTION: profile={requested_model}, "
|
||||||
# or fall through to the fallback chain.
|
f"error={type(e).__name__}: {e}",
|
||||||
if sidecar_status.get("active_profile") and sidecar_status.get("llama_server_running"):
|
flush=True,
|
||||||
target_url = f"{MAIN_PC_BASE}/{path}"
|
)
|
||||||
|
|
||||||
|
# Always signal queued requests on switch completion/failure
|
||||||
|
# so they don't hang forever waiting for a switch that never finishes.
|
||||||
|
complete_switch()
|
||||||
|
drain_queue()
|
||||||
|
|
||||||
|
else:
|
||||||
|
# No model in request body (probe/GET/non-chat request) —
|
||||||
|
# route to the currently active backend when available,
|
||||||
|
# or fall through to the fallback chain.
|
||||||
|
if sidecar_status.get("active_profile") and sidecar_status.get("llama_server_running"):
|
||||||
|
target_url = f"{MAIN_PC_BASE}/{path}"
|
||||||
|
|
||||||
# ── Fallback chain ────────────────────────────────────────────────────
|
# ── Fallback chain ────────────────────────────────────────────────────
|
||||||
if target_url is None:
|
if target_url is None:
|
||||||
|
|||||||
@ -5,7 +5,6 @@ Runs on the Main PC, manages llama-server subprocess, serves manifest/profile da
|
|||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import signal as signal_module
|
import signal as signal_module
|
||||||
import threading
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -26,7 +25,7 @@ LLAMA_STDERR_LOG = os.path.join(
|
|||||||
# Global state
|
# Global state
|
||||||
_llama_server_process: Optional[asyncio.subprocess.Process] = None
|
_llama_server_process: Optional[asyncio.subprocess.Process] = None
|
||||||
_active_profile: Optional[str] = None
|
_active_profile: Optional[str] = None
|
||||||
_switch_lock = threading.Lock() # Use threading.Lock for compatibility with TestClient
|
_switch_lock = asyncio.Lock() # Use asyncio.Lock to avoid blocking the event loop
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@ -37,7 +36,7 @@ async def lifespan(app: FastAPI):
|
|||||||
# Cleanup: kill llama-server if running
|
# Cleanup: kill llama-server if running
|
||||||
global _llama_server_process
|
global _llama_server_process
|
||||||
if _llama_server_process:
|
if _llama_server_process:
|
||||||
_kill_llama_server()
|
await _kill_llama_server()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
@ -55,22 +54,34 @@ def _close_stderr_log():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _kill_llama_server():
|
async def _kill_llama_server():
|
||||||
"""Kill the llama-server subprocess and close its stderr log handle."""
|
"""Kill the llama-server subprocess and wait for it to fully terminate.
|
||||||
global _llama_server_process
|
|
||||||
if _llama_server_process and _llama_server_process.returncode is None:
|
|
||||||
try:
|
|
||||||
_llama_server_process.send_signal(signal_module.SIGTERM)
|
|
||||||
try:
|
|
||||||
_llama_server_process.wait(timeout=5)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
_llama_server_process.kill()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
_llama_server_process = None
|
|
||||||
|
|
||||||
# Close stderr log handle if still open
|
This MUST be async because process.wait() is a coroutine. The synchronous
|
||||||
_close_stderr_log()
|
version was calling .wait() without await, creating an unawaited coroutine
|
||||||
|
object — the old process was never actually waited on, so it could still
|
||||||
|
hold GPU VRAM when the new server started.
|
||||||
|
"""
|
||||||
|
global _llama_server_process
|
||||||
|
if _llama_server_process is None or _llama_server_process.returncode is not None:
|
||||||
|
_close_stderr_log()
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
_llama_server_process.send_signal(signal_module.SIGTERM)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(_llama_server_process.wait(), timeout=10)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
_llama_server_process.kill()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(_llama_server_process.wait(), timeout=5)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
_llama_server_process = None
|
||||||
|
_close_stderr_log()
|
||||||
|
|
||||||
|
|
||||||
def _flag_value(value) -> str:
|
def _flag_value(value) -> str:
|
||||||
@ -105,7 +116,7 @@ async def _start_llama_server(profile: dict):
|
|||||||
global _llama_server_process
|
global _llama_server_process
|
||||||
|
|
||||||
# Kill any existing process
|
# Kill any existing process
|
||||||
_kill_llama_server()
|
await _kill_llama_server()
|
||||||
|
|
||||||
# Build command from profile flags
|
# Build command from profile flags
|
||||||
cmd = ["/home/bigt/AI/llama.cpp/build/bin/llama-server"]
|
cmd = ["/home/bigt/AI/llama.cpp/build/bin/llama-server"]
|
||||||
@ -200,7 +211,7 @@ async def switch_model(payload: SwitchRequest):
|
|||||||
"""Stop current llama-server, start new one with the given profile, wait for readiness."""
|
"""Stop current llama-server, start new one with the given profile, wait for readiness."""
|
||||||
global _active_profile
|
global _active_profile
|
||||||
|
|
||||||
with _switch_lock:
|
async with _switch_lock:
|
||||||
# Validate profile_id
|
# Validate profile_id
|
||||||
profiles = load_manifest(MANIFEST_PATH)
|
profiles = load_manifest(MANIFEST_PATH)
|
||||||
if profiles is None:
|
if profiles is None:
|
||||||
@ -229,7 +240,7 @@ async def switch_model(payload: SwitchRequest):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Start the new model
|
# Start the new model
|
||||||
_kill_llama_server()
|
await _kill_llama_server()
|
||||||
_active_profile = None
|
_active_profile = None
|
||||||
await _start_llama_server(profile)
|
await _start_llama_server(profile)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user