Files
turf_saas/tests/test_ai_router.py
CTO H3R7Tech 4b766cb908 feat(HRT-200): AI Router — Multi-provider LLM routing with failover
- 4 provider adapters: OpenAI (SDK), Anthropic (SDK), Google (google-genai), Mistral (direct HTTP)
- Core router with automatic failover + exponential backoff
- Flask blueprint with /api/v1/ai/* endpoints
- Auth via token-broker verify endpoint
- DB models for ai_providers, ai_model_mapping, ai_router_log
- /health endpoint (parallel provider check), /usage stats
- 21 unit tests (all passing)
2026-05-24 10:21:36 +02:00

256 lines
9.7 KiB
Python

"""Unit tests for AI Router — router, providers, models, API."""
import atexit
import json
import os
import shutil
import tempfile
from unittest.mock import patch, MagicMock

import pytest
# All tests share one throwaway SQLite file that lives in a private temp directory.
_TEST_DB_DIR = tempfile.mkdtemp()
# mkdtemp() leaves deletion to the caller; without this hook every test run
# would leave an orphaned directory behind in the system temp location.
atexit.register(shutil.rmtree, _TEST_DB_DIR, ignore_errors=True)
TEST_DB = os.path.join(_TEST_DB_DIR, "test_ai_router.db")

# The environment is configured at module import time. Every ai_router import
# in this file is deferred into a fixture or test body, so these values are in
# place before any of that code is loaded.
os.environ["AI_ROUTER_DB"] = TEST_DB
# Dummy credentials, one per provider; they are placeholders, not real keys.
os.environ["OPENAI_API_KEY"] = "sk-test-openai"
os.environ["ANTHROPIC_API_KEY"] = "sk-test-anthropic"
os.environ["GOOGLE_API_KEY"] = "sk-test-google"
os.environ["MISTRAL_API_KEY"] = "sk-test-mistral"
@pytest.fixture(autouse=True)
def clean_db():
    """Remove the shared test database file once each test has finished.

    The fixture is autouse, so it wraps every test in this module. Nothing is
    done before the test runs; the cleanup happens after the ``yield``.
    """
    yield
    try:
        os.unlink(TEST_DB)
    except OSError:
        # Best effort: a test that never touched the database leaves no file.
        pass
@pytest.fixture
def app():
    """Return a fresh application from ``create_app`` with ``TESTING`` enabled."""
    # Imported lazily so the module-level environment setup runs first.
    from ai_router_api import create_app

    application = create_app()
    application.config["TESTING"] = True
    return application
@pytest.fixture
def client(app):
    """Return a test client bound to the ``app`` fixture."""
    test_client = app.test_client()
    return test_client
@pytest.fixture
def router():
    """Return a new ``AIRouter`` constructed with no arguments."""
    # Imported lazily so the module-level environment setup runs first.
    from ai_router.router import AIRouter

    instance = AIRouter()
    return instance
# ─────────────────────────────────────────────
# Provider Base Tests
# ─────────────────────────────────────────────
class TestProviderInterface:
    """Checks on the provider registry and on ``AIProvider.get_api_key``."""

    @staticmethod
    def _stub_provider():
        """Build a minimal concrete ``AIProvider`` whose ``name`` is "openai".

        The remaining methods return fixed placeholder values; the tests here
        only exercise ``get_api_key``.
        """
        from ai_router.providers.base import AIProvider

        class _StubProvider(AIProvider):
            @property
            def name(self):
                return "openai"

            def chat(self, messages, model, **kwargs):
                return {}

            def models(self):
                return []

            def check_health(self):
                return {"status": "ok"}

        return _StubProvider()

    def test_provider_map_has_all(self):
        """``PROVIDER_MAP`` has an entry for each of the four providers."""
        from ai_router.providers import PROVIDER_MAP

        for provider_name in ("openai", "anthropic", "google", "mistral"):
            assert provider_name in PROVIDER_MAP

    def test_api_key_resolution_env(self):
        """With no argument, ``get_api_key`` yields the module-level test key.

        The expected value matches what this file stores in ``OPENAI_API_KEY``;
        presumably the base class derives that variable from ``name`` - verify
        against ``AIProvider``.
        """
        provider = self._stub_provider()
        assert provider.get_api_key() == "sk-test-openai"

    def test_api_key_resolution_db_overrides_env(self):
        """An ``api_key`` entry in the passed dict wins over the environment."""
        db_config = {"api_key": "sk-db-key"}
        provider = self._stub_provider()
        assert provider.get_api_key(db_config) == "sk-db-key"
# ─────────────────────────────────────────────
# Router Tests
# ─────────────────────────────────────────────
class TestRouter:
    """Tests for ``AIRouter``: alias resolution, provider order, chat routing, health."""

    @staticmethod
    def _greeting():
        """Return a fresh single-message conversation for each call."""
        return [{"role": "user", "content": "hi"}]

    @staticmethod
    def _completion(content, model, provider):
        """Build the dict a patched adapter ``chat`` returns, with fixed usage."""
        return {
            "content": content,
            "model": model,
            "provider": provider,
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
        }

    def test_resolve_known_model(self, router):
        """A known alias resolves to its provider and real model name."""
        resolved = router._resolve_model("gpt-4o")
        assert resolved is not None
        assert resolved["provider"] == "openai"
        assert resolved["real_model"] == "gpt-4o"

    def test_resolve_unknown_model(self, router):
        """An unknown alias resolves to ``None``."""
        assert router._resolve_model("nonexistent-model") is None

    def test_list_models_includes_defaults(self, router):
        """The model listing contains one default alias per provider."""
        aliases = [entry["alias"] for entry in router.list_available_models()]
        for expected in ("gpt-4o", "claude-3-opus", "gemini-pro", "mistral-large"):
            assert expected in aliases

    def test_prioritized_providers_default_order(self, router):
        """The first element of each prioritized entry gives the default order."""
        ordered = [entry[0] for entry in router._get_prioritized_providers()]
        assert ordered == ["openai", "anthropic", "google", "mistral"]

    def test_chat_unknown_model(self, router):
        """Chatting with an unknown model yields an error status."""
        outcome = router.chat(self._greeting(), "unknown-model")
        assert outcome["status"] == "error"

    @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
    def test_chat_success_first_provider(self, mock_chat, router):
        """A successful OpenAI reply is passed through as a success result."""
        mock_chat.return_value = self._completion("Hello!", "gpt-4o", "openai")

        outcome = router.chat(self._greeting(), "gpt-4o")

        assert outcome["status"] == "success"
        assert outcome["content"] == "Hello!"
        assert outcome["provider"] == "openai"

    # Patch decorators apply bottom-up, so the lowest one supplies the first mock.
    @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
    @patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat")
    def test_chat_failover_to_second_provider(self, mock_anthropic, mock_openai, router):
        """With OpenAI raising, the request is answered by Anthropic."""
        mock_openai.side_effect = Exception("OpenAI down")
        mock_anthropic.return_value = self._completion(
            "Hello from Anthropic!", "claude-3-sonnet-20240229", "anthropic"
        )

        outcome = router.chat(self._greeting(), "claude-3-sonnet")

        assert outcome["status"] == "success"
        assert outcome["provider"] == "anthropic"

    @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
    @patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat")
    @patch("ai_router.providers.google_adapter.GoogleAdapter.chat")
    @patch("ai_router.providers.mistral_adapter.MistralAdapter.chat")
    def test_chat_all_providers_fail(self, mock_mistral, mock_google, mock_anthropic, mock_openai, router):
        """When every adapter raises, the result is the aggregate error."""
        for adapter_mock in (mock_openai, mock_anthropic, mock_google, mock_mistral):
            adapter_mock.side_effect = Exception("Provider unavailable")

        outcome = router.chat(self._greeting(), "gpt-4o")

        assert outcome["status"] == "error"
        assert "All providers failed" in outcome["error"]

    def test_health_all_providers_returns_dict(self, router):
        """The health report has a ``status`` entry for each provider.

        NOTE(review): no adapter is patched in this test, so the health checks
        may run for real with the placeholder keys - confirm this is intended.
        """
        report = router.check_all_providers_health()
        for provider_name in ("openai", "anthropic", "google", "mistral"):
            assert provider_name in report
            assert "status" in report[provider_name]
# ─────────────────────────────────────────────
# Database Tests
# ─────────────────────────────────────────────
class TestModels:
    """Tests for the persistence helpers in ``ai_router.models``.

    Connections obtained from ``get_db`` are closed in ``finally`` blocks
    before any assertion runs. A failing assertion therefore cannot leave a
    handle open on the database file that the ``clean_db`` fixture deletes
    after each test.
    """

    def test_init_db_creates_tables(self):
        """``init_db`` creates the three AI-router tables."""
        from ai_router.models import init_db, get_db

        init_db()
        conn = get_db()
        try:
            tables = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table'"
            ).fetchall()
        finally:
            conn.close()

        names = [r[0] for r in tables]
        assert "ai_providers" in names
        assert "ai_model_mapping" in names
        assert "ai_router_log" in names

    def test_upsert_and_get_providers(self):
        """A provider written by ``upsert_provider`` is returned by the getter."""
        from ai_router.models import init_db, upsert_provider, get_providers_from_db

        init_db()
        upsert_provider("Test OpenAI", "openai", "sk-test", priority=1)
        providers = get_providers_from_db()
        assert len(providers) > 0
        assert any(p["name"] == "Test OpenAI" for p in providers)

    def test_log_router_attempt(self):
        """One logged attempt produces exactly one row with its status."""
        from ai_router.models import init_db, log_router_attempt, get_db

        init_db()
        log_router_attempt("req-1", 42, "gpt-4o", "openai", 10, 5, 200, "success")
        conn = get_db()
        try:
            # fetchall() materialises the rows, so they stay readable after close.
            rows = conn.execute("SELECT * FROM ai_router_log").fetchall()
        finally:
            conn.close()

        assert len(rows) == 1
        assert rows[0]["status"] == "success"
# ─────────────────────────────────────────────
# API Blueprint Tests
# ─────────────────────────────────────────────
class TestAPI:
    """HTTP-level tests for the ``/api/v1/ai/*`` endpoints via the test client."""

    @staticmethod
    def _greeting():
        """Return a fresh single-message conversation for each call."""
        return [{"role": "user", "content": "hi"}]

    def test_health_endpoint(self, client):
        """The health route answers 200 or 503 and names the service."""
        response = client.get("/api/v1/ai/health")
        assert response.status_code in (200, 503)

        payload = response.get_json()
        assert payload["service"] == "ai-router"
        assert "providers" in payload

    def test_models_endpoint(self, client):
        """The models route returns a non-empty ``models`` list."""
        response = client.get("/api/v1/ai/models")
        assert response.status_code == 200

        payload = response.get_json()
        assert "models" in payload
        assert len(payload["models"]) > 0

    def test_chat_no_auth(self, client):
        """A chat request without credentials is rejected with 401."""
        body = {"messages": self._greeting(), "model": "gpt-4o"}
        response = client.post("/api/v1/ai/chat", json=body)
        assert response.status_code == 401

    def test_chat_no_messages(self, client):
        """A request with an ``X-API-Key`` header and no messages gets 401.

        NOTE(review): token verification is not patched here, so the 401
        presumably comes from authentication rather than from the missing
        ``messages`` field - confirm whether this test covers what its name says.
        """
        response = client.post(
            "/api/v1/ai/chat",
            json={"model": "gpt-4o"},
            headers={"X-API-Key": "test-key"},
        )
        assert response.status_code == 401

    # Patch decorators apply bottom-up, so the lowest one supplies the first mock.
    @patch("ai_router.utils.verify_token_via_broker")
    @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
    def test_chat_success_with_auth(self, mock_chat, mock_verify, client):
        """With verification and the adapter patched, chat returns 200/success."""
        mock_verify.return_value = {"valid": True, "user_id": 1, "scopes": ["user"]}
        mock_chat.return_value = {
            "content": "Hello!",
            "model": "gpt-4o",
            "provider": "openai",
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
        }
        body = {"messages": self._greeting(), "model": "gpt-4o"}
        auth = {"Authorization": "Bearer test-token"}

        resp = client.post("/api/v1/ai/chat", json=body, headers=auth)
        data = resp.get_json()

        # The response body goes in the failure message to ease debugging.
        assert resp.status_code == 200, f"Got {resp.status_code}: {data}"
        assert data["status"] == "success"

    def test_usage_requires_admin(self, client):
        """The usage route rejects an unauthenticated request with 401."""
        response = client.get("/api/v1/ai/usage")
        assert response.status_code == 401