feat(HRT-200): AI Router — Multi-provider LLM routing with failover
- 4 provider adapters: OpenAI (SDK), Anthropic (SDK), Google (google-genai), Mistral (direct HTTP) - Core router with automatic failover + exponential backoff - Flask blueprint with /api/v1/ai/* endpoints - Auth via token-broker verify endpoint - DB models for ai_providers, ai_model_mapping, ai_router_log - /health endpoint (parallel provider check), /usage stats - 21 unit tests (all passing)
This commit is contained in:
255
tests/test_ai_router.py
Normal file
255
tests/test_ai_router.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Unit tests for AI Router — router, providers, models, API."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
TEST_DB = os.path.join(tempfile.mkdtemp(), "test_ai_router.db")
|
||||
os.environ["AI_ROUTER_DB"] = TEST_DB
|
||||
os.environ["OPENAI_API_KEY"] = "sk-test-openai"
|
||||
os.environ["ANTHROPIC_API_KEY"] = "sk-test-anthropic"
|
||||
os.environ["GOOGLE_API_KEY"] = "sk-test-google"
|
||||
os.environ["MISTRAL_API_KEY"] = "sk-test-mistral"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_db():
|
||||
yield
|
||||
try:
|
||||
os.remove(TEST_DB)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
from ai_router_api import create_app
|
||||
app = create_app()
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def router():
|
||||
from ai_router.router import AIRouter
|
||||
return AIRouter()
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Provider Base Tests
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestProviderInterface:
|
||||
def test_provider_map_has_all(self):
|
||||
from ai_router.providers import PROVIDER_MAP
|
||||
assert "openai" in PROVIDER_MAP
|
||||
assert "anthropic" in PROVIDER_MAP
|
||||
assert "google" in PROVIDER_MAP
|
||||
assert "mistral" in PROVIDER_MAP
|
||||
|
||||
def test_api_key_resolution_env(self):
|
||||
from ai_router.providers.base import AIProvider
|
||||
class TestProvider(AIProvider):
|
||||
@property
|
||||
def name(self): return "openai"
|
||||
def chat(self, messages, model, **kwargs): return {}
|
||||
def models(self): return []
|
||||
def check_health(self): return {"status": "ok"}
|
||||
|
||||
p = TestProvider()
|
||||
assert p.get_api_key() == "sk-test-openai"
|
||||
|
||||
def test_api_key_resolution_db_overrides_env(self):
|
||||
from ai_router.providers.base import AIProvider
|
||||
class TestProvider(AIProvider):
|
||||
@property
|
||||
def name(self): return "openai"
|
||||
def chat(self, messages, model, **kwargs): return {}
|
||||
def models(self): return []
|
||||
def check_health(self): return {"status": "ok"}
|
||||
|
||||
p = TestProvider()
|
||||
assert p.get_api_key({"api_key": "sk-db-key"}) == "sk-db-key"
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Router Tests
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRouter:
|
||||
def test_resolve_known_model(self, router):
|
||||
info = router._resolve_model("gpt-4o")
|
||||
assert info is not None
|
||||
assert info["provider"] == "openai"
|
||||
assert info["real_model"] == "gpt-4o"
|
||||
|
||||
def test_resolve_unknown_model(self, router):
|
||||
info = router._resolve_model("nonexistent-model")
|
||||
assert info is None
|
||||
|
||||
def test_list_models_includes_defaults(self, router):
|
||||
models = router.list_available_models()
|
||||
aliases = [m["alias"] for m in models]
|
||||
assert "gpt-4o" in aliases
|
||||
assert "claude-3-opus" in aliases
|
||||
assert "gemini-pro" in aliases
|
||||
assert "mistral-large" in aliases
|
||||
|
||||
def test_prioritized_providers_default_order(self, router):
|
||||
providers = router._get_prioritized_providers()
|
||||
names = [p[0] for p in providers]
|
||||
assert names == ["openai", "anthropic", "google", "mistral"]
|
||||
|
||||
def test_chat_unknown_model(self, router):
|
||||
result = router.chat([{"role": "user", "content": "hi"}], "unknown-model")
|
||||
assert result["status"] == "error"
|
||||
|
||||
@patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
|
||||
def test_chat_success_first_provider(self, mock_chat, router):
|
||||
mock_chat.return_value = {
|
||||
"content": "Hello!",
|
||||
"model": "gpt-4o",
|
||||
"provider": "openai",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
result = router.chat([{"role": "user", "content": "hi"}], "gpt-4o")
|
||||
assert result["status"] == "success"
|
||||
assert result["content"] == "Hello!"
|
||||
assert result["provider"] == "openai"
|
||||
|
||||
@patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
|
||||
@patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat")
|
||||
def test_chat_failover_to_second_provider(self, mock_anthropic, mock_openai, router):
|
||||
mock_openai.side_effect = Exception("OpenAI down")
|
||||
mock_anthropic.return_value = {
|
||||
"content": "Hello from Anthropic!",
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"provider": "anthropic",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
result = router.chat([{"role": "user", "content": "hi"}], "claude-3-sonnet")
|
||||
assert result["status"] == "success"
|
||||
assert result["provider"] == "anthropic"
|
||||
|
||||
@patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
|
||||
@patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat")
|
||||
@patch("ai_router.providers.google_adapter.GoogleAdapter.chat")
|
||||
@patch("ai_router.providers.mistral_adapter.MistralAdapter.chat")
|
||||
def test_chat_all_providers_fail(self, mock_mistral, mock_google, mock_anthropic, mock_openai, router):
|
||||
for mock in (mock_openai, mock_anthropic, mock_google, mock_mistral):
|
||||
mock.side_effect = Exception("Provider unavailable")
|
||||
result = router.chat([{"role": "user", "content": "hi"}], "gpt-4o")
|
||||
assert result["status"] == "error"
|
||||
assert "All providers failed" in result["error"]
|
||||
|
||||
def test_health_all_providers_returns_dict(self, router):
|
||||
health = router.check_all_providers_health()
|
||||
for name in ("openai", "anthropic", "google", "mistral"):
|
||||
assert name in health
|
||||
assert "status" in health[name]
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Database Tests
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestModels:
|
||||
def test_init_db_creates_tables(self):
|
||||
from ai_router.models import init_db, get_db
|
||||
init_db()
|
||||
conn = get_db()
|
||||
tables = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
).fetchall()
|
||||
names = [r[0] for r in tables]
|
||||
assert "ai_providers" in names
|
||||
assert "ai_model_mapping" in names
|
||||
assert "ai_router_log" in names
|
||||
conn.close()
|
||||
|
||||
def test_upsert_and_get_providers(self):
|
||||
from ai_router.models import init_db, upsert_provider, get_providers_from_db
|
||||
init_db()
|
||||
upsert_provider("Test OpenAI", "openai", "sk-test", priority=1)
|
||||
providers = get_providers_from_db()
|
||||
assert len(providers) > 0
|
||||
assert any(p["name"] == "Test OpenAI" for p in providers)
|
||||
|
||||
def test_log_router_attempt(self):
|
||||
from ai_router.models import init_db, log_router_attempt, get_db
|
||||
init_db()
|
||||
log_router_attempt("req-1", 42, "gpt-4o", "openai", 10, 5, 200, "success")
|
||||
conn = get_db()
|
||||
rows = conn.execute("SELECT * FROM ai_router_log").fetchall()
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["status"] == "success"
|
||||
conn.close()
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# API Blueprint Tests
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAPI:
|
||||
def test_health_endpoint(self, client):
|
||||
resp = client.get("/api/v1/ai/health")
|
||||
assert resp.status_code in (200, 503)
|
||||
data = resp.get_json()
|
||||
assert data["service"] == "ai-router"
|
||||
assert "providers" in data
|
||||
|
||||
def test_models_endpoint(self, client):
|
||||
resp = client.get("/api/v1/ai/models")
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert "models" in data
|
||||
assert len(data["models"]) > 0
|
||||
|
||||
def test_chat_no_auth(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/ai/chat",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_chat_no_messages(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/ai/chat",
|
||||
json={"model": "gpt-4o"},
|
||||
headers={"X-API-Key": "test-key"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@patch("ai_router.utils.verify_token_via_broker")
|
||||
@patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
|
||||
def test_chat_success_with_auth(self, mock_chat, mock_verify, client):
|
||||
mock_verify.return_value = {"valid": True, "user_id": 1, "scopes": ["user"]}
|
||||
mock_chat.return_value = {
|
||||
"content": "Hello!",
|
||||
"model": "gpt-4o",
|
||||
"provider": "openai",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
resp = client.post(
|
||||
"/api/v1/ai/chat",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"},
|
||||
headers={"Authorization": "Bearer test-token"},
|
||||
)
|
||||
data = resp.get_json()
|
||||
assert resp.status_code == 200, f"Got {resp.status_code}: {data}"
|
||||
assert data["status"] == "success"
|
||||
|
||||
def test_usage_requires_admin(self, client):
|
||||
resp = client.get("/api/v1/ai/usage")
|
||||
assert resp.status_code == 401
|
||||
Reference in New Issue
Block a user