- Four provider adapters: OpenAI (SDK), Anthropic (SDK), Google (google-genai), and Mistral (direct HTTP)
- Core router with automatic failover and exponential backoff
- Flask blueprint exposing the /api/v1/ai/* endpoints
- Authentication via the token-broker verify endpoint
- Database models for ai_providers, ai_model_mapping, and ai_router_log
- /health endpoint (checks all providers in parallel) and /usage statistics endpoint
- 21 unit tests (all passing)
256 lines · 9.7 KiB · Python
"""Unit tests for AI Router — router, providers, models, API."""
|
|
|
|
import atexit
import json
import os
import shutil
import tempfile
from unittest.mock import MagicMock, patch

import pytest
|
|
|
|
# Point the code under test at a throw-away SQLite file and fake provider keys.
# This runs at import time, before any ``ai_router`` module is imported (the
# fixtures and tests below import them lazily), in case those modules read the
# environment at import. The fake keys also override any real keys present in
# the developer's shell.
_TEST_DIR = tempfile.mkdtemp(prefix="ai_router_test_")
# mkdtemp() never deletes its directory; remove it when the test process exits.
atexit.register(shutil.rmtree, _TEST_DIR, ignore_errors=True)
TEST_DB = os.path.join(_TEST_DIR, "test_ai_router.db")

os.environ.update(
    {
        "AI_ROUTER_DB": TEST_DB,
        "OPENAI_API_KEY": "sk-test-openai",
        "ANTHROPIC_API_KEY": "sk-test-anthropic",
        "GOOGLE_API_KEY": "sk-test-google",
        "MISTRAL_API_KEY": "sk-test-mistral",
    }
)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clean_db():
    """Delete the SQLite test database after every test.

    Autouse, so state written to ``TEST_DB`` by one test does not carry over
    into the next one.
    """
    yield  # let the test run; everything below is teardown
    try:
        os.unlink(TEST_DB)
    except OSError:
        # Best-effort cleanup: a test that never created the file (or any
        # other OS-level failure) must not turn into a teardown error.
        pass
|
|
|
|
|
|
@pytest.fixture
def app():
    """Create the Flask application under test with ``TESTING`` enabled."""
    from ai_router_api import create_app as _build_app

    flask_app = _build_app()
    flask_app.config.update(TESTING=True)
    return flask_app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Return a Flask test client bound to the ``app`` fixture."""
    test_client = app.test_client()
    return test_client
|
|
|
|
|
|
@pytest.fixture
def router():
    """Return a new ``AIRouter`` instance (function scope: one per test)."""
    from ai_router.router import AIRouter as _Router

    instance = _Router()
    return instance
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# Provider Base Tests
|
|
# ─────────────────────────────────────────────
|
|
|
|
|
|
class TestProviderInterface:
    """Tests for the provider registry and the ``AIProvider`` base class."""

    @staticmethod
    def _make_stub_provider():
        """Return an instance of a minimal ``AIProvider`` subclass named "openai".

        ``name``, ``chat``, ``models`` and ``check_health`` are overridden with
        trivial bodies (presumably the abstract interface of ``AIProvider`` —
        confirm in ``ai_router.providers.base``); none of them contact a
        real service.
        """
        from ai_router.providers.base import AIProvider

        class _StubProvider(AIProvider):
            @property
            def name(self):
                return "openai"

            def chat(self, messages, model, **kwargs):
                return {}

            def models(self):
                return []

            def check_health(self):
                return {"status": "ok"}

        return _StubProvider()

    def test_provider_map_has_all(self):
        from ai_router.providers import PROVIDER_MAP

        for provider_name in ("openai", "anthropic", "google", "mistral"):
            assert provider_name in PROVIDER_MAP, f"missing provider {provider_name!r}"

    def test_api_key_resolution_env(self):
        # Pin the variable locally so the test does not rely on module-level
        # setup or on other tests leaving os.environ untouched.
        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-openai"}):
            provider = self._make_stub_provider()
            assert provider.get_api_key() == "sk-test-openai"

    def test_api_key_resolution_db_overrides_env(self):
        # An env key is present, so a passing assertion proves the DB-supplied
        # key takes precedence rather than merely being the only one available.
        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-openai"}):
            provider = self._make_stub_provider()
            assert provider.get_api_key({"api_key": "sk-db-key"}) == "sk-db-key"
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# Router Tests
|
|
# ─────────────────────────────────────────────
|
|
|
|
|
|
class TestRouter:
    """Tests for ``AIRouter``: alias resolution, provider ordering, failover
    and health aggregation.

    Tests that reach provider code patch the adapter methods at class level so
    no request is sent to a real provider API.
    """

    @staticmethod
    def _user_message():
        """Return a new single-turn chat history.

        A fresh list per call, so a router that mutates its input cannot leak
        state from one test into another.
        """
        return [{"role": "user", "content": "hi"}]

    @staticmethod
    def _usage():
        """Return the token-usage block used in the mocked adapter responses."""
        return {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}

    def test_resolve_known_model(self, router):
        info = router._resolve_model("gpt-4o")
        assert info is not None
        assert info["provider"] == "openai"
        assert info["real_model"] == "gpt-4o"

    def test_resolve_unknown_model(self, router):
        assert router._resolve_model("nonexistent-model") is None

    def test_list_models_includes_defaults(self, router):
        aliases = {m["alias"] for m in router.list_available_models()}
        for alias in ("gpt-4o", "claude-3-opus", "gemini-pro", "mistral-large"):
            assert alias in aliases, f"default alias {alias!r} not listed"

    def test_prioritized_providers_default_order(self, router):
        names = [entry[0] for entry in router._get_prioritized_providers()]
        assert names == ["openai", "anthropic", "google", "mistral"]

    def test_chat_unknown_model(self, router):
        result = router.chat(self._user_message(), "unknown-model")
        assert result["status"] == "error"

    @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
    def test_chat_success_first_provider(self, mock_chat, router):
        mock_chat.return_value = {
            "content": "Hello!",
            "model": "gpt-4o",
            "provider": "openai",
            "usage": self._usage(),
        }
        result = router.chat(self._user_message(), "gpt-4o")
        assert mock_chat.called, "the patched OpenAI adapter was never invoked"
        assert result["status"] == "success"
        assert result["content"] == "Hello!"
        assert result["provider"] == "openai"

    @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
    @patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat")
    def test_chat_failover_to_second_provider(self, mock_anthropic, mock_openai, router):
        # NOTE(review): "claude-3-sonnet" may resolve straight to Anthropic, in
        # which case the OpenAI failure below is never exercised. Confirm that
        # the router walks providers in priority order for this alias, or
        # request an OpenAI model here to test genuine failover.
        mock_openai.side_effect = Exception("OpenAI down")
        mock_anthropic.return_value = {
            "content": "Hello from Anthropic!",
            "model": "claude-3-sonnet-20240229",
            "provider": "anthropic",
            "usage": self._usage(),
        }
        result = router.chat(self._user_message(), "claude-3-sonnet")
        assert mock_anthropic.called, "the patched Anthropic adapter was never invoked"
        assert result["status"] == "success"
        assert result["provider"] == "anthropic"

    @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
    @patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat")
    @patch("ai_router.providers.google_adapter.GoogleAdapter.chat")
    @patch("ai_router.providers.mistral_adapter.MistralAdapter.chat")
    def test_chat_all_providers_fail(self, mock_mistral, mock_google, mock_anthropic, mock_openai, router):
        for mock in (mock_openai, mock_anthropic, mock_google, mock_mistral):
            mock.side_effect = Exception("Provider unavailable")
        result = router.chat(self._user_message(), "gpt-4o")
        assert result["status"] == "error"
        assert "All providers failed" in result["error"]

    # Health checks are patched so this unit test does not contact the real
    # provider APIs with the fake keys set at module level.
    @patch(
        "ai_router.providers.openai_adapter.OpenAIAdapter.check_health",
        return_value={"status": "ok"},
    )
    @patch(
        "ai_router.providers.anthropic_adapter.AnthropicAdapter.check_health",
        return_value={"status": "ok"},
    )
    @patch(
        "ai_router.providers.google_adapter.GoogleAdapter.check_health",
        return_value={"status": "ok"},
    )
    @patch(
        "ai_router.providers.mistral_adapter.MistralAdapter.check_health",
        return_value={"status": "ok"},
    )
    def test_health_all_providers_returns_dict(
        self, mock_mistral, mock_google, mock_anthropic, mock_openai, router
    ):
        health = router.check_all_providers_health()
        for name in ("openai", "anthropic", "google", "mistral"):
            assert name in health
            assert "status" in health[name]
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# Database Tests
|
|
# ─────────────────────────────────────────────
|
|
|
|
|
|
class TestModels:
    """Tests for the SQLite helpers in ``ai_router.models``.

    Each connection is closed in a ``finally`` block so that a failing
    assertion does not leave the database file open. An open handle can
    prevent the autouse cleanup fixture from deleting it (e.g. on Windows).
    """

    def test_init_db_creates_tables(self):
        from ai_router.models import get_db, init_db

        init_db()
        conn = get_db()
        try:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table'"
            ).fetchall()
            names = {row[0] for row in rows}
            for table in ("ai_providers", "ai_model_mapping", "ai_router_log"):
                assert table in names, f"table {table!r} was not created"
        finally:
            conn.close()

    def test_upsert_and_get_providers(self):
        from ai_router.models import get_providers_from_db, init_db, upsert_provider

        init_db()
        upsert_provider("Test OpenAI", "openai", "sk-test", priority=1)
        providers = get_providers_from_db()
        assert len(providers) > 0
        assert any(p["name"] == "Test OpenAI" for p in providers)

    def test_log_router_attempt(self):
        from ai_router.models import get_db, init_db, log_router_attempt

        init_db()
        log_router_attempt("req-1", 42, "gpt-4o", "openai", 10, 5, 200, "success")
        conn = get_db()
        try:
            rows = conn.execute("SELECT * FROM ai_router_log").fetchall()
            assert len(rows) == 1
            # Column access by name relies on get_db() configuring a
            # name-indexable row factory (presumably sqlite3.Row).
            assert rows[0]["status"] == "success"
        finally:
            conn.close()
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# API Blueprint Tests
|
|
# ─────────────────────────────────────────────
|
|
|
|
|
|
class TestAPI:
    """Tests for the Flask blueprint served under ``/api/v1/ai``."""

    _CHAT_URL = "/api/v1/ai/chat"

    # Provider health checks are patched so the endpoint does not contact the
    # real provider APIs with the fake keys set at module level.
    @patch(
        "ai_router.providers.openai_adapter.OpenAIAdapter.check_health",
        return_value={"status": "ok"},
    )
    @patch(
        "ai_router.providers.anthropic_adapter.AnthropicAdapter.check_health",
        return_value={"status": "ok"},
    )
    @patch(
        "ai_router.providers.google_adapter.GoogleAdapter.check_health",
        return_value={"status": "ok"},
    )
    @patch(
        "ai_router.providers.mistral_adapter.MistralAdapter.check_health",
        return_value={"status": "ok"},
    )
    def test_health_endpoint(self, mock_mistral, mock_google, mock_anthropic, mock_openai, client):
        resp = client.get("/api/v1/ai/health")
        # 503 is still tolerated in case the endpoint reports a degraded state
        # for reasons other than provider health.
        assert resp.status_code in (200, 503)
        data = resp.get_json()
        assert data["service"] == "ai-router"
        assert "providers" in data

    def test_models_endpoint(self, client):
        resp = client.get("/api/v1/ai/models")
        assert resp.status_code == 200
        data = resp.get_json()
        assert "models" in data
        assert len(data["models"]) > 0

    def test_chat_no_auth(self, client):
        resp = client.post(
            self._CHAT_URL,
            json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"},
        )
        assert resp.status_code == 401

    def test_chat_no_messages(self, client):
        # NOTE(review): despite its name, this asserts 401: the unverified
        # X-API-Key is rejected before the missing "messages" field is
        # checked. verify_token_via_broker is not patched here, so the real
        # token broker may be contacted. Consider patching it to a valid
        # result and asserting the request-validation error instead.
        resp = client.post(
            self._CHAT_URL,
            json={"model": "gpt-4o"},
            headers={"X-API-Key": "test-key"},
        )
        assert resp.status_code == 401

    @patch("ai_router.utils.verify_token_via_broker")
    @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
    def test_chat_success_with_auth(self, mock_chat, mock_verify, client):
        mock_verify.return_value = {"valid": True, "user_id": 1, "scopes": ["user"]}
        mock_chat.return_value = {
            "content": "Hello!",
            "model": "gpt-4o",
            "provider": "openai",
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
        }
        resp = client.post(
            self._CHAT_URL,
            json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"},
            headers={"Authorization": "Bearer test-token"},
        )
        data = resp.get_json()
        assert resp.status_code == 200, f"Got {resp.status_code}: {data}"
        assert data["status"] == "success"

    def test_usage_requires_admin(self, client):
        resp = client.get("/api/v1/ai/usage")
        assert resp.status_code == 401
|