"""Unit tests for AI Router — router, providers, models, API.""" import json import os import tempfile import pytest from unittest.mock import patch, MagicMock TEST_DB = os.path.join(tempfile.mkdtemp(), "test_ai_router.db") os.environ["AI_ROUTER_DB"] = TEST_DB os.environ["OPENAI_API_KEY"] = "sk-test-openai" os.environ["ANTHROPIC_API_KEY"] = "sk-test-anthropic" os.environ["GOOGLE_API_KEY"] = "sk-test-google" os.environ["MISTRAL_API_KEY"] = "sk-test-mistral" @pytest.fixture(autouse=True) def clean_db(): yield try: os.remove(TEST_DB) except OSError: pass @pytest.fixture def app(): from ai_router_api import create_app app = create_app() app.config["TESTING"] = True return app @pytest.fixture def client(app): return app.test_client() @pytest.fixture def router(): from ai_router.router import AIRouter return AIRouter() # ───────────────────────────────────────────── # Provider Base Tests # ───────────────────────────────────────────── class TestProviderInterface: def test_provider_map_has_all(self): from ai_router.providers import PROVIDER_MAP assert "openai" in PROVIDER_MAP assert "anthropic" in PROVIDER_MAP assert "google" in PROVIDER_MAP assert "mistral" in PROVIDER_MAP def test_api_key_resolution_env(self): from ai_router.providers.base import AIProvider class TestProvider(AIProvider): @property def name(self): return "openai" def chat(self, messages, model, **kwargs): return {} def models(self): return [] def check_health(self): return {"status": "ok"} p = TestProvider() assert p.get_api_key() == "sk-test-openai" def test_api_key_resolution_db_overrides_env(self): from ai_router.providers.base import AIProvider class TestProvider(AIProvider): @property def name(self): return "openai" def chat(self, messages, model, **kwargs): return {} def models(self): return [] def check_health(self): return {"status": "ok"} p = TestProvider() assert p.get_api_key({"api_key": "sk-db-key"}) == "sk-db-key" # ───────────────────────────────────────────── # Router Tests # ───────────────────────────────────────────── class TestRouter: def test_resolve_known_model(self, router): info = router._resolve_model("gpt-4o") assert info is not None assert info["provider"] == "openai" assert info["real_model"] == "gpt-4o" def test_resolve_unknown_model(self, router): info = router._resolve_model("nonexistent-model") assert info is None def test_list_models_includes_defaults(self, router): models = router.list_available_models() aliases = [m["alias"] for m in models] assert "gpt-4o" in aliases assert "claude-3-opus" in aliases assert "gemini-pro" in aliases assert "mistral-large" in aliases def test_prioritized_providers_default_order(self, router): providers = router._get_prioritized_providers() names = [p[0] for p in providers] assert names == ["openai", "anthropic", "google", "mistral"] def test_chat_unknown_model(self, router): result = router.chat([{"role": "user", "content": "hi"}], "unknown-model") assert result["status"] == "error" @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat") def test_chat_success_first_provider(self, mock_chat, router): mock_chat.return_value = { "content": "Hello!", "model": "gpt-4o", "provider": "openai", "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, } result = router.chat([{"role": "user", "content": "hi"}], "gpt-4o") assert result["status"] == "success" assert result["content"] == "Hello!" assert result["provider"] == "openai" @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat") @patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat") def test_chat_failover_to_second_provider(self, mock_anthropic, mock_openai, router): mock_openai.side_effect = Exception("OpenAI down") mock_anthropic.return_value = { "content": "Hello from Anthropic!", "model": "claude-3-sonnet-20240229", "provider": "anthropic", "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, } result = router.chat([{"role": "user", "content": "hi"}], "claude-3-sonnet") assert result["status"] == "success" assert result["provider"] == "anthropic" @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat") @patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat") @patch("ai_router.providers.google_adapter.GoogleAdapter.chat") @patch("ai_router.providers.mistral_adapter.MistralAdapter.chat") def test_chat_all_providers_fail(self, mock_mistral, mock_google, mock_anthropic, mock_openai, router): for mock in (mock_openai, mock_anthropic, mock_google, mock_mistral): mock.side_effect = Exception("Provider unavailable") result = router.chat([{"role": "user", "content": "hi"}], "gpt-4o") assert result["status"] == "error" assert "All providers failed" in result["error"] def test_health_all_providers_returns_dict(self, router): health = router.check_all_providers_health() for name in ("openai", "anthropic", "google", "mistral"): assert name in health assert "status" in health[name] # ───────────────────────────────────────────── # Database Tests # ───────────────────────────────────────────── class TestModels: def test_init_db_creates_tables(self): from ai_router.models import init_db, get_db init_db() conn = get_db() tables = conn.execute( "SELECT name FROM sqlite_master WHERE type='table'" ).fetchall() names = [r[0] for r in tables] assert "ai_providers" in names assert "ai_model_mapping" in names assert "ai_router_log" in names conn.close() def test_upsert_and_get_providers(self): from ai_router.models import init_db, upsert_provider, get_providers_from_db init_db() upsert_provider("Test OpenAI", "openai", "sk-test", priority=1) providers = get_providers_from_db() assert len(providers) > 0 assert any(p["name"] == "Test OpenAI" for p in providers) def test_log_router_attempt(self): from ai_router.models import init_db, log_router_attempt, get_db init_db() log_router_attempt("req-1", 42, "gpt-4o", "openai", 10, 5, 200, "success") conn = get_db() rows = conn.execute("SELECT * FROM ai_router_log").fetchall() assert len(rows) == 1 assert rows[0]["status"] == "success" conn.close() # ───────────────────────────────────────────── # API Blueprint Tests # ───────────────────────────────────────────── class TestAPI: def test_health_endpoint(self, client): resp = client.get("/api/v1/ai/health") assert resp.status_code in (200, 503) data = resp.get_json() assert data["service"] == "ai-router" assert "providers" in data def test_models_endpoint(self, client): resp = client.get("/api/v1/ai/models") assert resp.status_code == 200 data = resp.get_json() assert "models" in data assert len(data["models"]) > 0 def test_chat_no_auth(self, client): resp = client.post( "/api/v1/ai/chat", json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"}, ) assert resp.status_code == 401 def test_chat_no_messages(self, client): resp = client.post( "/api/v1/ai/chat", json={"model": "gpt-4o"}, headers={"X-API-Key": "test-key"}, ) assert resp.status_code == 401 @patch("ai_router.utils.verify_token_via_broker") @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat") def test_chat_success_with_auth(self, mock_chat, mock_verify, client): mock_verify.return_value = {"valid": True, "user_id": 1, "scopes": ["user"]} mock_chat.return_value = { "content": "Hello!", "model": "gpt-4o", "provider": "openai", "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, } resp = client.post( "/api/v1/ai/chat", json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"}, headers={"Authorization": "Bearer test-token"}, ) data = resp.get_json() assert resp.status_code == 200, f"Got {resp.status_code}: {data}" assert data["status"] == "success" def test_usage_requires_admin(self, client): resp = client.get("/api/v1/ai/usage") assert resp.status_code == 401