104 lines
4.0 KiB
Python
104 lines
4.0 KiB
Python
|
|
"""Tests for automatic model detection — Issue #4.
|
||
|
|
|
||
|
|
Router extracts model from chat body, queries sidecar, triggers switch on mismatch.
|
||
|
|
"""
|
||
|
|
import asyncio
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import patch
|
||
|
|
from httpx import Response, ASGITransport, AsyncClient
|
||
|
|
|
||
|
|
from main import app as router_app
|
||
|
|
|
||
|
|
# Sidecar base URL: the tests mock its /models/status and /models/switch
# routes, and the setup fixture patches it into main.SIDECAR_URL.
SIDECAR_URL = "http://localhost:8081"
|
||
|
|
# Main PC base URL: the tests mock its /v1/chat/completions route, and the
# setup fixture patches it into main.MAIN_PC_BASE.
MAIN_PC_URL = "http://localhost:8080"
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(autouse=True)
def setup():
    """Prepare the router module for a single test.

    Resets ``main._circuit_open`` and ``main._recovery_attempts`` and keeps
    the router's endpoint/API-key settings patched to the local test values
    for the duration of the test body.
    """
    import main

    # Start every test from the same module-level state.
    main._recovery_attempts = 0
    main._circuit_open = False

    # Settings swapped in while the test runs; restored when the fixture exits.
    overrides = {
        "SIDECAR_URL": SIDECAR_URL,
        "MAIN_PC_BASE": MAIN_PC_URL,
        "FALLBACK_SLM_URL": "http://localhost:9999",
        "OPENROUTER_API_KEY": "",
    }
    with patch.multiple("main", **overrides):
        yield
|
||
|
|
|
||
|
|
|
||
|
|
def test_active_model_match_routes_directly():
    """Requested model equals the active profile → Main PC is used, no switch."""
    import respx

    status_payload = {"active_profile": "qwen-3-8b", "llama_server_running": True}
    completion_payload = {"choices": [{"message": {"content": "Hello"}}]}
    chat_request = {
        "model": "qwen-3-8b",
        "messages": [{"role": "user", "content": "hi"}],
    }

    async def run_test():
        with respx.mock:
            # Sidecar reports the same profile the client is about to ask for.
            respx.get(f"{SIDECAR_URL}/models/status").mock(
                return_value=Response(200, json=status_payload)
            )
            respx.post(f"{MAIN_PC_URL}/v1/chat/completions").mock(
                return_value=Response(200, json=completion_payload)
            )

            async with AsyncClient(
                transport=ASGITransport(app=router_app),
                base_url="http://test",
            ) as client:
                resp = await client.post("/v1/chat/completions", json=chat_request)
                assert resp.status_code == 200

                # No recorded outgoing request may target a "switch" path.
                hit_switch = any(
                    "switch" in call.request.url.path for call in respx.calls
                )
                assert not hit_switch

    asyncio.run(run_test())
|
||
|
|
|
||
|
|
|
||
|
|
def test_mismatch_triggers_switch():
    """Mismatching model → triggers switch via sidecar.

    The sidecar reports ``llama-4-maverick`` as active while the client asks
    for ``qwen-3-8b``. The request must succeed and the sidecar's
    ``/models/switch`` endpoint must have been called.
    """
    import respx

    async def run_test():
        with respx.mock:
            respx.get(f"{SIDECAR_URL}/models/status").mock(
                return_value=Response(
                    200,
                    json={"active_profile": "llama-4-maverick", "llama_server_running": True},
                )
            )
            # Keep the route handle so the switch call can be verified below.
            switch_route = respx.post(f"{SIDECAR_URL}/models/switch").mock(
                return_value=Response(
                    200, json={"status": "ready", "active_profile": "qwen-3-8b"}
                )
            )
            respx.post(f"{MAIN_PC_URL}/v1/chat/completions").mock(
                return_value=Response(
                    200, json={"choices": [{"message": {"content": "Hello"}}]}
                )
            )

            transport = ASGITransport(app=router_app)
            async with AsyncClient(transport=transport, base_url="http://test") as ac:
                resp = await ac.post(
                    "/v1/chat/completions",
                    json={"model": "qwen-3-8b", "messages": [{"role": "user", "content": "hi"}]},
                )
                assert resp.status_code == 200
                # Check the switch explicitly so the test fails if the router
                # forwards the request without switching models first.
                assert switch_route.called, "expected a POST to the sidecar /models/switch"

    asyncio.run(run_test())
|
||
|
|
|
||
|
|
|
||
|
|
def test_no_active_model_triggers_cold_start():
    """No active model → triggers cold start switch.

    The sidecar reports no active profile and no running llama server. The
    request must succeed and the sidecar's ``/models/switch`` endpoint must
    have been called.
    """
    import respx

    async def run_test():
        with respx.mock:
            respx.get(f"{SIDECAR_URL}/models/status").mock(
                return_value=Response(
                    200, json={"active_profile": None, "llama_server_running": False}
                )
            )
            # Keep the route handle so the cold-start switch can be verified below.
            switch_route = respx.post(f"{SIDECAR_URL}/models/switch").mock(
                return_value=Response(
                    200, json={"status": "ready", "active_profile": "qwen-3-8b"}
                )
            )
            respx.post(f"{MAIN_PC_URL}/v1/chat/completions").mock(
                return_value=Response(
                    200, json={"choices": [{"message": {"content": "Hello"}}]}
                )
            )

            transport = ASGITransport(app=router_app)
            async with AsyncClient(transport=transport, base_url="http://test") as ac:
                resp = await ac.post(
                    "/v1/chat/completions",
                    json={"model": "qwen-3-8b", "messages": [{"role": "user", "content": "hi"}]},
                )
                assert resp.status_code == 200
                # Check the switch explicitly so the test fails if the router
                # answers without loading a model first.
                assert switch_route.called, "expected a POST to the sidecar /models/switch"

    asyncio.run(run_test())
|