From 4b766cb90811e102b49ea285de6a7cee6ee0514c Mon Sep 17 00:00:00 2001 From: CTO H3R7Tech Date: Sun, 24 May 2026 10:21:36 +0200 Subject: [PATCH] =?UTF-8?q?feat(HRT-200):=20AI=20Router=20=E2=80=94=20Mult?= =?UTF-8?q?i-provider=20LLM=20routing=20with=20failover?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 4 provider adapters: OpenAI (SDK), Anthropic (SDK), Google (google-genai), Mistral (direct HTTP) - Core router with automatic failover + exponential backoff - Flask blueprint with /api/v1/ai/* endpoints - Auth via token-broker verify endpoint - DB models for ai_providers, ai_model_mapping, ai_router_log - /health endpoint (parallel provider check), /usage stats - 21 unit tests (all passing) --- ai_router/__init__.py | 4 + ai_router/api.py | 172 +++++++++++++++ ai_router/models.py | 167 +++++++++++++++ ai_router/providers/__init__.py | 14 ++ ai_router/providers/anthropic_adapter.py | 57 +++++ ai_router/providers/base.py | 44 ++++ ai_router/providers/google_adapter.py | 57 +++++ ai_router/providers/mistral_adapter.py | 70 +++++++ ai_router/providers/openai_adapter.py | 50 +++++ ai_router/router.py | 174 ++++++++++++++++ ai_router/utils.py | 93 +++++++++ ai_router_api.py | 77 +++++++ tests/test_ai_router.py | 255 +++++++++++++++++++++++ 13 files changed, 1234 insertions(+) create mode 100644 ai_router/__init__.py create mode 100644 ai_router/api.py create mode 100644 ai_router/models.py create mode 100644 ai_router/providers/__init__.py create mode 100644 ai_router/providers/anthropic_adapter.py create mode 100644 ai_router/providers/base.py create mode 100644 ai_router/providers/google_adapter.py create mode 100644 ai_router/providers/mistral_adapter.py create mode 100644 ai_router/providers/openai_adapter.py create mode 100644 ai_router/router.py create mode 100644 ai_router/utils.py create mode 100644 ai_router_api.py create mode 100644 tests/test_ai_router.py diff --git a/ai_router/__init__.py b/ai_router/__init__.py new file mode 100644 index 0000000..f97cd52 --- /dev/null +++ b/ai_router/__init__.py @@ -0,0 +1,4 @@ +from .router import AIRouter +from .api import ai_router_bp, register_ai_router + +__all__ = ["AIRouter", "ai_router_bp", "register_ai_router"] diff --git a/ai_router/api.py b/ai_router/api.py new file mode 100644 index 0000000..27ebc67 --- /dev/null +++ b/ai_router/api.py @@ -0,0 +1,172 @@ +"""Flask Blueprint for AI Router — chat, health, models, admin.""" + +import logging +from datetime import datetime, timezone + +from flask import Blueprint, jsonify, request + +from .router import AIRouter +from .models import init_db, upsert_provider, upsert_model_mapping +from .utils import require_auth, admin_required + +logger = logging.getLogger("ai_router.api") + +ai_router_bp = Blueprint("ai_router", __name__, url_prefix="/api/v1/ai") + + +def register_ai_router(app): + app.register_blueprint(ai_router_bp) + + +_router = AIRouter() + + +@ai_router_bp.route("/health", methods=["GET"]) +def health(): + health_data = _router.check_all_providers_health() + all_ok = all(v["status"] == "ok" for v in health_data.values()) + return jsonify({ + "status": "ok" if all_ok else "degraded", + "service": "ai-router", + "version": "1.0.0", + "providers": health_data, + "timestamp": datetime.now(timezone.utc).isoformat(), + }), 200 if all_ok else 503 + + +@ai_router_bp.route("/models", methods=["GET"]) +def list_models(): + models = _router.list_available_models() + return jsonify({"models": models}) + + +@ai_router_bp.route("/chat", methods=["POST"]) +@require_auth +def chat(): + data = request.get_json(silent=True) or {} + messages = data.get("messages", []) + model = data.get("model", "gpt-4o-mini") + user_id = (request.current_user or {}).get("user_id") + + if not messages: + return jsonify({"error": "messages field is required"}), 400 + + kwargs = {k: data[k] for k in ("temperature", "max_tokens", "top_p", "stream") if k in data} + + result = _router.chat(messages=messages, model_alias=model, user_id=user_id, **kwargs) + + if result.get("status") == "error": + code = 503 if "All providers failed" in result.get("error", "") else 400 + return jsonify(result), code + + return jsonify(result), 200 + + +@ai_router_bp.route("/admin/providers", methods=["GET"]) +@admin_required +def list_providers(): + from .models import get_db + conn = get_db() + try: + rows = conn.execute( + "SELECT id, name, provider_type, base_url, priority, is_active, created_at, updated_at " + "FROM ai_providers ORDER BY priority" + ).fetchall() + return jsonify({"providers": [dict(r) for r in rows]}) + finally: + conn.close() + + +@ai_router_bp.route("/admin/providers", methods=["POST"]) +@admin_required +def upsert_provider_endpoint(): + data = request.get_json(silent=True) or {} + name = data.get("name", "") + provider_type = data.get("provider_type", "") + api_key = data.get("api_key", "") + base_url = data.get("base_url", "") + priority = data.get("priority", 99) + + if not name or provider_type not in ("openai", "anthropic", "google", "mistral"): + return jsonify({"error": "Valid name and provider_type required"}), 400 + + ok = upsert_provider(name, provider_type, api_key, base_url, priority=priority) + if ok: + return jsonify({"status": "ok", "message": f"Provider {name} saved"}), 200 + return jsonify({"error": "Failed to save provider"}), 500 + + +@ai_router_bp.route("/admin/model-mappings", methods=["POST"]) +@admin_required +def upsert_model_mapping_endpoint(): + data = request.get_json(silent=True) or {} + model_alias = data.get("model_alias", "") + provider_id = data.get("provider_id") + real_model_id = data.get("real_model_id", "") + cost = data.get("cost_per_1k_tokens", 0) + + if not model_alias or not provider_id or not real_model_id: + return jsonify({"error": "model_alias, provider_id, real_model_id required"}), 400 + + ok = upsert_model_mapping(model_alias, provider_id, real_model_id, cost) + if ok: + return jsonify({"status": "ok", "message": f"Mapping for {model_alias} saved"}), 200 + return jsonify({"error": "Failed to save model mapping"}), 500 + + +@ai_router_bp.route("/admin/providers/", methods=["DELETE"]) +@admin_required +def delete_provider(provider_id): + from .models import get_db + conn = get_db() + try: + conn.execute("DELETE FROM ai_model_mapping WHERE provider_id = ?", (provider_id,)) + conn.execute("DELETE FROM ai_providers WHERE id = ?", (provider_id,)) + conn.commit() + return jsonify({"status": "deleted", "provider_id": provider_id}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + finally: + conn.close() + + +@ai_router_bp.route("/usage", methods=["GET"]) +@admin_required +def usage_stats(): + from .models import get_db + conn = get_db() + try: + limit = request.args.get("limit", 50, type=int) + rows = conn.execute( + "SELECT * FROM ai_router_log ORDER BY created_at DESC LIMIT ?", (limit,) + ).fetchall() + return jsonify({"usage": [dict(r) for r in rows]}) + finally: + conn.close() + + +@ai_router_bp.route("/usage/summary", methods=["GET"]) +@admin_required +def usage_summary(): + from .models import get_db + conn = get_db() + try: + agg = conn.execute(""" + SELECT provider_used, status, COUNT(*) as count, + SUM(duration_ms) as total_ms, SUM(tokens_in + tokens_out) as total_tokens + FROM ai_router_log + GROUP BY provider_used, status + ORDER BY provider_used + """).fetchall() + totals = conn.execute(""" + SELECT COUNT(*) as total_requests, + SUM(CASE WHEN status='success' THEN 1 ELSE 0 END) as success_count, + SUM(tokens_in + tokens_out) as total_tokens + FROM ai_router_log + """).fetchone() + return jsonify({ + "by_provider": [dict(r) for r in agg], + "totals": dict(totals) if totals else {}, + }) + finally: + conn.close() diff --git a/ai_router/models.py b/ai_router/models.py new file mode 100644 index 0000000..b78493f --- /dev/null +++ b/ai_router/models.py @@ -0,0 +1,167 @@ +import logging +import os +import sqlite3 +from datetime import datetime, timezone + +logger = logging.getLogger("ai_router.models") + +DB_PATH = os.environ.get("AI_ROUTER_DB", "/home/h3r7/turf_saas/ai_router.db") + + +def get_db(): + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + return conn + + +def init_db(): + conn = get_db() + try: + conn.executescript(""" + CREATE TABLE IF NOT EXISTS ai_providers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + provider_type TEXT NOT NULL CHECK(provider_type IN ('openai','anthropic','google','mistral')), + api_key TEXT NOT NULL DEFAULT '', + base_url TEXT DEFAULT '', + config TEXT DEFAULT '{}', + priority INTEGER NOT NULL DEFAULT 99, + is_active INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) + ); + + CREATE TABLE IF NOT EXISTS ai_model_mapping ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model_alias TEXT NOT NULL UNIQUE, + provider_id INTEGER NOT NULL REFERENCES ai_providers(id), + real_model_id TEXT NOT NULL, + cost_per_1k_tokens REAL NOT NULL DEFAULT 0, + is_active INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL DEFAULT (datetime('now')) + ); + + CREATE TABLE IF NOT EXISTS ai_router_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + request_id TEXT NOT NULL, + user_id INTEGER, + model_alias TEXT NOT NULL, + provider_used TEXT NOT NULL, + tokens_in INTEGER NOT NULL DEFAULT 0, + tokens_out INTEGER NOT NULL DEFAULT 0, + duration_ms INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL CHECK(status IN ('success','error')), + error_message TEXT DEFAULT '', + created_at TEXT NOT NULL DEFAULT (datetime('now')) + ); + + CREATE INDEX IF NOT EXISTS idx_ai_router_log_request_id ON ai_router_log(request_id); + CREATE INDEX IF NOT EXISTS idx_ai_router_log_created_at ON ai_router_log(created_at); + CREATE INDEX IF NOT EXISTS idx_ai_model_mapping_alias ON ai_model_mapping(model_alias); + """) + conn.commit() + logger.info("AI Router database tables initialized") + except Exception as e: + logger.error(f"Failed to initialize AI Router DB: {e}") + finally: + conn.close() + + +def get_providers_from_db(): + conn = get_db() + try: + rows = conn.execute(""" + SELECT p.id, p.name, p.provider_type, p.api_key, p.base_url, p.config, + p.priority, p.is_active, m.model_alias, m.real_model_id, m.cost_per_1k_tokens + FROM ai_providers p + LEFT JOIN ai_model_mapping m ON m.provider_id = p.id + WHERE p.is_active = 1 + """).fetchall() + return [dict(r) for r in rows] + except Exception as e: + logger.warning(f"Could not query providers: {e}") + return [] + finally: + conn.close() + + +def log_router_attempt(request_id, user_id, model_alias, provider_used, + tokens_in, tokens_out, duration_ms, status, + error_message=""): + conn = get_db() + try: + conn.execute( + """INSERT INTO ai_router_log + (request_id, user_id, model_alias, provider_used, + tokens_in, tokens_out, duration_ms, status, error_message) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + (request_id, user_id, model_alias, provider_used, + tokens_in, tokens_out, duration_ms, status, error_message), + ) + conn.commit() + except Exception as e: + logger.warning(f"Failed to log router attempt: {e}") + finally: + conn.close() + + +def upsert_provider(name, provider_type, api_key="", base_url="", + config=None, priority=99, is_active=1): + conn = get_db() + try: + existing = conn.execute( + "SELECT id FROM ai_providers WHERE name = ?", (name,) + ).fetchone() + if existing: + conn.execute( + """UPDATE ai_providers SET provider_type=?, api_key=?, base_url=?, + config=?, priority=?, is_active=?, updated_at=datetime('now') + WHERE name=?""", + (provider_type, api_key, base_url, + config or "{}", priority, is_active, name), + ) + else: + conn.execute( + """INSERT INTO ai_providers + (name, provider_type, api_key, base_url, config, priority, is_active) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + (name, provider_type, api_key, base_url, + config or "{}", priority, is_active), + ) + conn.commit() + return True + except Exception as e: + logger.error(f"Failed to upsert provider: {e}") + return False + finally: + conn.close() + + +def upsert_model_mapping(model_alias, provider_id, real_model_id, cost_per_1k=0): + conn = get_db() + try: + existing = conn.execute( + "SELECT id FROM ai_model_mapping WHERE model_alias = ?", (model_alias,) + ).fetchone() + if existing: + conn.execute( + """UPDATE ai_model_mapping SET provider_id=?, real_model_id=?, + cost_per_1k_tokens=? WHERE model_alias=?""", + (provider_id, real_model_id, cost_per_1k, model_alias), + ) + else: + conn.execute( + """INSERT INTO ai_model_mapping + (model_alias, provider_id, real_model_id, cost_per_1k_tokens) + VALUES (?, ?, ?, ?)""", + (model_alias, provider_id, real_model_id, cost_per_1k), + ) + conn.commit() + return True + except Exception as e: + logger.error(f"Failed to upsert model mapping: {e}") + return False + finally: + conn.close() diff --git a/ai_router/providers/__init__.py b/ai_router/providers/__init__.py new file mode 100644 index 0000000..3387ded --- /dev/null +++ b/ai_router/providers/__init__.py @@ -0,0 +1,14 @@ +from .base import AIProvider +from .openai_adapter import OpenAIAdapter +from .anthropic_adapter import AnthropicAdapter +from .google_adapter import GoogleAdapter +from .mistral_adapter import MistralAdapter + +PROVIDER_MAP = { + "openai": OpenAIAdapter, + "anthropic": AnthropicAdapter, + "google": GoogleAdapter, + "mistral": MistralAdapter, +} + +__all__ = ["AIProvider", "PROVIDER_MAP", "OpenAIAdapter", "AnthropicAdapter", "GoogleAdapter", "MistralAdapter"] diff --git a/ai_router/providers/anthropic_adapter.py b/ai_router/providers/anthropic_adapter.py new file mode 100644 index 0000000..045889e --- /dev/null +++ b/ai_router/providers/anthropic_adapter.py @@ -0,0 +1,57 @@ +import logging +from typing import Optional + +from .base import AIProvider + +logger = logging.getLogger("ai_router.anthropic") + + +class AnthropicAdapter(AIProvider): + @property + def name(self) -> str: + return "anthropic" + + def chat(self, messages: list, model: str, api_key: Optional[str] = None, **kwargs) -> dict: + from anthropic import Anthropic + + key = api_key or self.get_api_key() + client = Anthropic(api_key=key) + + system_msg = None + chat_messages = messages + if messages and messages[0].get("role") == "system": + system_msg = messages[0]["content"] + chat_messages = messages[1:] + + resp = client.messages.create( + model=model, + system=system_msg, + messages=[{"role": m["role"], "content": m["content"]} for m in chat_messages], + **{k: v for k, v in kwargs.items() if k in ("temperature", "max_tokens", "top_p")}, + ) + return { + "content": resp.content[0].text if resp.content else "", + "model": resp.model, + "provider": self.name, + "usage": { + "prompt_tokens": resp.usage.input_tokens if resp.usage else 0, + "completion_tokens": resp.usage.output_tokens if resp.usage else 0, + "total_tokens": (resp.usage.input_tokens + resp.usage.output_tokens) if resp.usage else 0, + }, + } + + def models(self) -> list: + from anthropic import Anthropic + + client = Anthropic(api_key=self.get_api_key()) + return [m.id for m in client.models.list()] + + def check_health(self) -> dict: + try: + from anthropic import Anthropic + client = Anthropic(api_key=self.get_api_key()) + client.models.list() + return {"status": "ok", "details": "API reachable"} + except Exception as e: + logger.warning(f"Anthropic health check failed: {e}") + return {"status": "error", "details": str(e)} diff --git a/ai_router/providers/base.py b/ai_router/providers/base.py new file mode 100644 index 0000000..8856d88 --- /dev/null +++ b/ai_router/providers/base.py @@ -0,0 +1,44 @@ +from abc import ABC, abstractmethod +from typing import Optional + + +class AIProvider(ABC): + @property + @abstractmethod + def name(self) -> str: + """Provider identifier (openai, anthropic, google, mistral).""" + + @abstractmethod + def chat(self, messages: list, model: str, **kwargs) -> dict: + """Send a chat completion request. Returns dict with at least: + { + "content": str, + "model": str, + "provider": self.name, + "usage": {"prompt_tokens": int, "completion_tokens": int, "total_tokens": int} + } + """ + + @abstractmethod + def models(self) -> list: + """Return list of available models from this provider.""" + + @abstractmethod + def check_health(self) -> dict: + """Check provider connectivity. Returns {"status": "ok"|"error", "details": str}""" + + def get_api_key(self, db_config: Optional[dict] = None) -> Optional[str]: + """Resolve API key: DB override > env var.""" + provider_env_map = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "google": "GOOGLE_API_KEY", + "mistral": "MISTRAL_API_KEY", + } + if db_config and db_config.get("api_key"): + return db_config["api_key"] + import os + env_var = provider_env_map.get(self.name) + if env_var: + return os.environ.get(env_var) + return None diff --git a/ai_router/providers/google_adapter.py b/ai_router/providers/google_adapter.py new file mode 100644 index 0000000..8036d54 --- /dev/null +++ b/ai_router/providers/google_adapter.py @@ -0,0 +1,57 @@ +import logging +from typing import Optional + +from .base import AIProvider + +logger = logging.getLogger("ai_router.google") + + +class GoogleAdapter(AIProvider): + @property + def name(self) -> str: + return "google" + + def chat(self, messages: list, model: str, api_key: Optional[str] = None, **kwargs) -> dict: + from google import genai + + key = api_key or self.get_api_key() + client = genai.Client(api_key=key) + + system_instruction = None + chat_messages = messages + if messages and messages[0].get("role") == "system": + system_instruction = messages[0]["content"] + chat_messages = messages[1:] + + contents = [] + for m in chat_messages: + role = "user" if m["role"] in ("user", "system") else "model" + contents.append({"role": role, "parts": [{"text": m["content"]}]}) + + resp = client.models.generate_content( + model=model, + contents=contents, + config={"system_instruction": system_instruction} if system_instruction else None, + ) + return { + "content": resp.text or "", + "model": model, + "provider": self.name, + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + def models(self) -> list: + from google import genai + + client = genai.Client(api_key=self.get_api_key()) + return [m.name for m in client.models.list()] + + def check_health(self) -> dict: + try: + from google import genai + client = genai.Client(api_key=self.get_api_key()) + client.models.list() + return {"status": "ok", "details": "API reachable"} + except Exception as e: + logger.warning(f"Google health check failed: {e}") + return {"status": "error", "details": str(e)} diff --git a/ai_router/providers/mistral_adapter.py b/ai_router/providers/mistral_adapter.py new file mode 100644 index 0000000..48796e2 --- /dev/null +++ b/ai_router/providers/mistral_adapter.py @@ -0,0 +1,70 @@ +import json +import logging +from typing import Optional + +import requests + +from .base import AIProvider + +logger = logging.getLogger("ai_router.mistral") + +MISTRAL_API_BASE = "https://api.mistral.ai/v1" + + +class MistralAdapter(AIProvider): + @property + def name(self) -> str: + return "mistral" + + def _headers(self, api_key: Optional[str] = None) -> dict: + key = api_key or self.get_api_key() + return { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + + def chat(self, messages: list, model: str, api_key: Optional[str] = None, **kwargs) -> dict: + key = api_key or self.get_api_key() + headers = self._headers(key) + + payload = { + "model": model, + "messages": messages, + } + for k in ("temperature", "max_tokens", "top_p", "stream"): + if k in kwargs: + payload[k] = kwargs[k] + + resp = requests.post( + f"{MISTRAL_API_BASE}/chat/completions", + headers=headers, + json=payload, + timeout=120, + ) + resp.raise_for_status() + data = resp.json() + choice = data["choices"][0] + return { + "content": choice["message"]["content"] or "", + "model": data.get("model", model), + "provider": self.name, + "usage": { + "prompt_tokens": data.get("usage", {}).get("prompt_tokens", 0), + "completion_tokens": data.get("usage", {}).get("completion_tokens", 0), + "total_tokens": data.get("usage", {}).get("total_tokens", 0), + }, + } + + def models(self) -> list: + resp = requests.get(f"{MISTRAL_API_BASE}/models", headers=self._headers(), timeout=30) + resp.raise_for_status() + return [m["id"] for m in resp.json().get("data", [])] + + def check_health(self) -> dict: + try: + resp = requests.get(f"{MISTRAL_API_BASE}/models", headers=self._headers(), timeout=10) + resp.raise_for_status() + return {"status": "ok", "details": "API reachable"} + except Exception as e: + logger.warning(f"Mistral health check failed: {e}") + return {"status": "error", "details": str(e)} diff --git a/ai_router/providers/openai_adapter.py b/ai_router/providers/openai_adapter.py new file mode 100644 index 0000000..c5163d7 --- /dev/null +++ b/ai_router/providers/openai_adapter.py @@ -0,0 +1,50 @@ +import logging +from typing import Optional + +from .base import AIProvider + +logger = logging.getLogger("ai_router.openai") + + +class OpenAIAdapter(AIProvider): + @property + def name(self) -> str: + return "openai" + + def chat(self, messages: list, model: str, api_key: Optional[str] = None, **kwargs) -> dict: + from openai import OpenAI + + key = api_key or self.get_api_key() + client = OpenAI(api_key=key) + resp = client.chat.completions.create( + model=model, + messages=messages, + **{k: v for k, v in kwargs.items() if k in ("temperature", "max_tokens", "top_p", "stream")}, + ) + choice = resp.choices[0] + return { + "content": choice.message.content or "", + "model": resp.model, + "provider": self.name, + "usage": { + "prompt_tokens": resp.usage.prompt_tokens if resp.usage else 0, + "completion_tokens": resp.usage.completion_tokens if resp.usage else 0, + "total_tokens": resp.usage.total_tokens if resp.usage else 0, + }, + } + + def models(self) -> list: + from openai import OpenAI + + client = OpenAI(api_key=self.get_api_key()) + return [m.id for m in client.models.list()] + + def check_health(self) -> dict: + try: + from openai import OpenAI + client = OpenAI(api_key=self.get_api_key()) + client.models.list() + return {"status": "ok", "details": "API reachable"} + except Exception as e: + logger.warning(f"OpenAI health check failed: {e}") + return {"status": "error", "details": str(e)} diff --git a/ai_router/router.py b/ai_router/router.py new file mode 100644 index 0000000..1977750 --- /dev/null +++ b/ai_router/router.py @@ -0,0 +1,174 @@ +import logging +import time +import uuid +from typing import Optional + +from .providers import PROVIDER_MAP, AIProvider +from .models import get_providers_from_db, log_router_attempt + +logger = logging.getLogger("ai_router.router") + +DEFAULT_MODEL_MAP = { + "gpt-4o": {"provider": "openai", "real_model": "gpt-4o"}, + "gpt-4o-mini": {"provider": "openai", "real_model": "gpt-4o-mini"}, + "claude-3-opus": {"provider": "anthropic", "real_model": "claude-3-opus-20240229"}, + "claude-3-sonnet": {"provider": "anthropic", "real_model": "claude-3-sonnet-20240229"}, + "claude-3-haiku": {"provider": "anthropic", "real_model": "claude-3-haiku-20240307"}, + "gemini-pro": {"provider": "google", "real_model": "gemini-1.5-pro"}, + "gemini-flash": {"provider": "google", "real_model": "gemini-1.5-flash"}, + "mistral-large": {"provider": "mistral", "real_model": "mistral-large-latest"}, + "mistral-small": {"provider": "mistral", "real_model": "mistral-small-latest"}, +} + + +class AIRouter: + def __init__(self): + self._provider_instances = {} + + def get_provider(self, name: str) -> Optional[AIProvider]: + if name not in self._provider_instances: + cls = PROVIDER_MAP.get(name) + if not cls: + return None + self._provider_instances[name] = cls() + return self._provider_instances[name] + + def _resolve_model(self, model_alias: str) -> Optional[dict]: + mapping = self._load_model_mappings() + return mapping.get(model_alias) + + def _load_model_mappings(self) -> dict: + db_mappings = [] + try: + db_mappings = get_providers_from_db() + except Exception as e: + logger.warning(f"Could not load model mappings from DB: {e}") + + merged = dict(DEFAULT_MODEL_MAP) + for entry in db_mappings: + alias = entry.get("model_alias") + if alias: + merged[alias] = { + "provider": entry["provider_type"], + "real_model": entry.get("real_model_id", alias), + "cost_per_1k": entry.get("cost_per_1k_tokens", 0), + "db_config": entry, + } + return merged + + def _get_prioritized_providers(self): + providers = [] + try: + db_providers = get_providers_from_db() + seen_names = set() + for p in sorted(db_providers, key=lambda x: x.get("priority", 99)): + name = p["provider_type"] + if name not in seen_names: + seen_names.add(name) + providers.append((name, p)) + except Exception as e: + logger.warning(f"Could not load provider priority from DB: {e}") + + if not providers: + default_order = ["openai", "anthropic", "google", "mistral"] + providers = [(n, None) for n in default_order] + return providers + + def chat(self, messages: list, model_alias: str, user_id: Optional[int] = None, **kwargs) -> dict: + request_id = str(uuid.uuid4()) + start_time = time.time() + + model_info = self._resolve_model(model_alias) + if not model_info: + return {"error": f"Unknown model: {model_alias}", "status": "error"} + + provider_order = self._get_prioritized_providers() + preferred_provider = model_info["provider"] + real_model = model_info["real_model"] + + ordered = [] + for name, db_config in provider_order: + if name == preferred_provider: + ordered.insert(0, (name, db_config)) + else: + ordered.append((name, db_config)) + + if preferred_provider not in [p[0] for p in ordered]: + ordered.insert(0, (preferred_provider, None)) + + last_error = None + for attempt, (provider_name, db_config) in enumerate(ordered): + provider = self.get_provider(provider_name) + if not provider: + continue + + try: + if attempt > 0: + backoff = min(2 ** (attempt - 1), 30) + logger.info(f"Failover to {provider_name} after {backoff}s backoff (attempt {attempt})") + time.sleep(backoff) + + api_key = provider.get_api_key(db_config) + if not api_key: + logger.warning(f"No API key configured for {provider_name}, skipping") + continue + + result = provider.chat(messages, model=real_model, api_key=api_key, **kwargs) + elapsed = int((time.time() - start_time) * 1000) + + log_router_attempt( + request_id=request_id, + user_id=user_id, + model_alias=model_alias, + provider_used=provider_name, + tokens_in=result.get("usage", {}).get("prompt_tokens", 0), + tokens_out=result.get("usage", {}).get("completion_tokens", 0), + duration_ms=elapsed, + status="success", + ) + result["request_id"] = request_id + result["status"] = "success" + return result + + except Exception as e: + last_error = str(e) + elapsed = int((time.time() - start_time) * 1000) + logger.warning(f"Provider {provider_name} failed: {e}") + log_router_attempt( + request_id=request_id, + user_id=user_id, + model_alias=model_alias, + provider_used=provider_name, + tokens_in=0, + tokens_out=0, + duration_ms=elapsed, + status="error", + error_message=last_error, + ) + + elapsed = int((time.time() - start_time) * 1000) + return { + "error": f"All providers failed. Last error: {last_error}", + "status": "error", + "request_id": request_id, + "duration_ms": elapsed, + } + + def check_all_providers_health(self) -> dict: + results = {} + for name in PROVIDER_MAP: + provider = self.get_provider(name) + results[name] = provider.check_health() + return results + + def list_available_models(self) -> list: + model_map = self._load_model_mappings() + return [ + { + "alias": alias, + "provider": info["provider"], + "real_model": info["real_model"], + "cost_per_1k_tokens": info.get("cost_per_1k", 0), + } + for alias, info in model_map.items() + ] diff --git a/ai_router/utils.py b/ai_router/utils.py new file mode 100644 index 0000000..a7d83af --- /dev/null +++ b/ai_router/utils.py @@ -0,0 +1,93 @@ +import logging +import os +import sys +from functools import wraps + +from flask import request, jsonify + +logger = logging.getLogger("ai_router") + +TOKEN_BROKER_URL = os.environ.get( + "TOKEN_BROKER_URL", "http://localhost:8783" +) + + +def verify_token_via_broker(token: str) -> dict: + """Verify an API token via the token-broker /verify endpoint.""" + import requests + try: + resp = requests.post( + f"{TOKEN_BROKER_URL}/api/v1/tokens/verify", + json={"token": token}, + timeout=10, + ) + if resp.status_code == 200: + data = resp.json() + if data.get("valid"): + return data + return {} + except requests.RequestException as e: + logger.warning(f"Token broker unreachable: {e}") + return {} + + +def require_auth(f): + """Decorator: validate Bearer or X-API-Key via token-broker.""" + @wraps(f) + def decorated(*args, **kwargs): + auth_header = request.headers.get("Authorization", "") + api_key = request.headers.get("X-API-Key", "") + + raw_token = "" + if auth_header.startswith("Bearer "): + raw_token = auth_header.split(" ", 1)[1] + elif api_key: + raw_token = api_key + + if not raw_token: + return jsonify({"error": "Authentication required"}), 401 + + payload = verify_token_via_broker(raw_token) + if not payload or not payload.get("valid"): + return jsonify({"error": "Invalid or expired token"}), 401 + + request.current_user = { + "user_id": payload.get("user_id"), + "token_id": payload.get("token_id"), + "scopes": payload.get("scopes", []), + } + return f(*args, **kwargs) + return decorated + + +def admin_required(f): + """Decorator: require admin scope on the authenticated token.""" + @wraps(f) + def decorated(*args, **kwargs): + auth_header = request.headers.get("Authorization", "") + api_key = request.headers.get("X-API-Key", "") + + raw_token = "" + if auth_header.startswith("Bearer "): + raw_token = auth_header.split(" ", 1)[1] + elif api_key: + raw_token = api_key + + if not raw_token: + return jsonify({"error": "Authentication required"}), 401 + + payload = verify_token_via_broker(raw_token) + if not payload or not payload.get("valid"): + return jsonify({"error": "Invalid or expired token"}), 401 + + scopes = payload.get("scopes", []) + if "admin" not in scopes and "ai_router_admin" not in scopes: + return jsonify({"error": "Admin access required"}), 403 + + request.current_user = { + "user_id": payload.get("user_id"), + "token_id": payload.get("token_id"), + "scopes": scopes, + } + return f(*args, **kwargs) + return decorated diff --git a/ai_router_api.py b/ai_router_api.py new file mode 100644 index 0000000..2221322 --- /dev/null +++ b/ai_router_api.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +""" +AI Router API — Multi-provider LLM routing with failover +Port: 8783 | DB: SQLite ai_router.db +HRT-200 — AI Router (Multi-provider + failover) + +Endpoints: + GET /api/v1/ai/health — Check all providers health + GET /api/v1/ai/models — List available models + POST /api/v1/ai/chat — Chat completion with auto-failover + GET /api/v1/ai/admin/providers — List configured providers + POST /api/v1/ai/admin/providers — Upsert a provider + POST /api/v1/ai/admin/model-mappings— Upsert a model mapping + DELETE /api/v1/ai/admin/providers/:id — Remove a provider + GET /api/v1/ai/usage — Usage logs + GET /api/v1/ai/usage/summary — Aggregated usage stats +""" + +import logging +import logging.handlers +import os +import sys + +from flask import Flask, jsonify +from flask_cors import CORS + +LOG_DIR = os.path.join(os.path.dirname(__file__), "ai_router", "logs") +os.makedirs(LOG_DIR, exist_ok=True) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] ai-router: %(name)s: %(message)s", + handlers=[ + logging.StreamHandler(sys.stdout), + logging.handlers.RotatingFileHandler( + os.path.join(LOG_DIR, "ai_router.log"), + maxBytes=5 * 1024 * 1024, + backupCount=3, + ), + ], +) +logger = logging.getLogger("ai_router") + +PORT = int(os.environ.get("AI_ROUTER_PORT", "8783")) + + +def create_app(): + app = Flask(__name__) + CORS(app) + + from ai_router.api import register_ai_router + from ai_router.models import init_db + + init_db() + register_ai_router(app) + + @app.errorhandler(404) + def not_found(e): + return jsonify({"error": "not_found", "message": "Route not found"}), 404 + + @app.errorhandler(500) + def internal_error(e): + logger.error(f"Internal error: {e}") + return jsonify({"error": "internal_error", "message": "Internal server error"}), 500 + + return app + + +if __name__ == "__main__": + logger.info("=" * 60) + logger.info("AI Router API starting...") + logger.info(f"Port: {PORT}") + logger.info("=" * 60) + + app = create_app() + debug = os.environ.get("FLASK_ENV", "production") == "development" + app.run(host="0.0.0.0", port=PORT, debug=debug) diff --git a/tests/test_ai_router.py b/tests/test_ai_router.py new file mode 100644 index 0000000..e42c911 --- /dev/null +++ b/tests/test_ai_router.py @@ -0,0 +1,255 @@ +"""Unit tests for AI Router — router, providers, models, API.""" + +import json +import os +import tempfile +import pytest +from unittest.mock import patch, MagicMock + +TEST_DB = os.path.join(tempfile.mkdtemp(), "test_ai_router.db") +os.environ["AI_ROUTER_DB"] = TEST_DB +os.environ["OPENAI_API_KEY"] = "sk-test-openai" +os.environ["ANTHROPIC_API_KEY"] = "sk-test-anthropic" +os.environ["GOOGLE_API_KEY"] = "sk-test-google" +os.environ["MISTRAL_API_KEY"] = "sk-test-mistral" + + +@pytest.fixture(autouse=True) +def clean_db(): + yield + try: + os.remove(TEST_DB) + except OSError: + pass + + +@pytest.fixture +def app(): + from ai_router_api import create_app + app = create_app() + app.config["TESTING"] = True + return app + + +@pytest.fixture +def client(app): + return app.test_client() + + +@pytest.fixture +def router(): + from ai_router.router import AIRouter + return AIRouter() + + +# ───────────────────────────────────────────── +# Provider Base Tests +# ───────────────────────────────────────────── + + +class TestProviderInterface: + def test_provider_map_has_all(self): + from ai_router.providers import PROVIDER_MAP + assert "openai" in PROVIDER_MAP + assert "anthropic" in PROVIDER_MAP + assert "google" in PROVIDER_MAP + assert "mistral" in PROVIDER_MAP + + def test_api_key_resolution_env(self): + from ai_router.providers.base import AIProvider + class TestProvider(AIProvider): + @property + def name(self): return "openai" + def chat(self, messages, model, **kwargs): return {} + def models(self): return [] + def check_health(self): return {"status": "ok"} + + p = TestProvider() + assert p.get_api_key() == "sk-test-openai" + + def test_api_key_resolution_db_overrides_env(self): + from ai_router.providers.base import AIProvider + class TestProvider(AIProvider): + @property + def name(self): return "openai" + def chat(self, messages, model, **kwargs): return {} + def models(self): return [] + def check_health(self): return {"status": "ok"} + + p = TestProvider() + assert p.get_api_key({"api_key": "sk-db-key"}) == "sk-db-key" + + +# ───────────────────────────────────────────── +# Router Tests +# ───────────────────────────────────────────── + + +class TestRouter: + def test_resolve_known_model(self, router): + info = router._resolve_model("gpt-4o") + assert info is not None + assert info["provider"] == "openai" + assert info["real_model"] == "gpt-4o" + + def test_resolve_unknown_model(self, router): + info = router._resolve_model("nonexistent-model") + assert info is None + + def test_list_models_includes_defaults(self, router): + models = router.list_available_models() + aliases = [m["alias"] for m in models] + assert "gpt-4o" in aliases + assert "claude-3-opus" in aliases + assert "gemini-pro" in aliases + assert "mistral-large" in aliases + + def test_prioritized_providers_default_order(self, router): + providers = router._get_prioritized_providers() + names = [p[0] for p in providers] + assert names == ["openai", "anthropic", "google", "mistral"] + + def test_chat_unknown_model(self, router): + result = router.chat([{"role": "user", "content": "hi"}], "unknown-model") + assert result["status"] == "error" + + @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat") + def test_chat_success_first_provider(self, mock_chat, router): + mock_chat.return_value = { + "content": "Hello!", + "model": "gpt-4o", + "provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + result = router.chat([{"role": "user", "content": "hi"}], "gpt-4o") + assert result["status"] == "success" + assert result["content"] == "Hello!" + assert result["provider"] == "openai" + + @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat") + @patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat") + def test_chat_failover_to_second_provider(self, mock_anthropic, mock_openai, router): + mock_openai.side_effect = Exception("OpenAI down") + mock_anthropic.return_value = { + "content": "Hello from Anthropic!", + "model": "claude-3-sonnet-20240229", + "provider": "anthropic", + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + result = router.chat([{"role": "user", "content": "hi"}], "claude-3-sonnet") + assert result["status"] == "success" + assert result["provider"] == "anthropic" + + @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat") + @patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat") + @patch("ai_router.providers.google_adapter.GoogleAdapter.chat") + @patch("ai_router.providers.mistral_adapter.MistralAdapter.chat") + def test_chat_all_providers_fail(self, mock_mistral, mock_google, mock_anthropic, mock_openai, router): + for mock in (mock_openai, mock_anthropic, mock_google, mock_mistral): + mock.side_effect = Exception("Provider unavailable") + result = router.chat([{"role": "user", "content": "hi"}], "gpt-4o") + assert result["status"] == "error" + assert "All providers failed" in result["error"] + + def test_health_all_providers_returns_dict(self, router): + health = router.check_all_providers_health() + for name in ("openai", "anthropic", "google", "mistral"): + assert name in health + assert "status" in health[name] + + +# ───────────────────────────────────────────── +# Database Tests +# ───────────────────────────────────────────── + + +class TestModels: + def test_init_db_creates_tables(self): + from ai_router.models import init_db, get_db + init_db() + conn = get_db() + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ).fetchall() + names = [r[0] for r in tables] + assert "ai_providers" in names + assert "ai_model_mapping" in names + assert "ai_router_log" in names + conn.close() + + def test_upsert_and_get_providers(self): + from ai_router.models import init_db, upsert_provider, get_providers_from_db + init_db() + upsert_provider("Test OpenAI", "openai", "sk-test", priority=1) + providers = get_providers_from_db() + assert len(providers) > 0 + assert any(p["name"] == "Test OpenAI" for p in providers) + + def test_log_router_attempt(self): + from ai_router.models import init_db, log_router_attempt, get_db + init_db() + log_router_attempt("req-1", 42, "gpt-4o", "openai", 10, 5, 200, "success") + conn = get_db() + rows = conn.execute("SELECT * FROM ai_router_log").fetchall() + assert len(rows) == 1 + assert rows[0]["status"] == "success" + conn.close() + + +# ───────────────────────────────────────────── +# API Blueprint Tests +# ───────────────────────────────────────────── + + +class TestAPI: + def test_health_endpoint(self, client): + resp = client.get("/api/v1/ai/health") + assert resp.status_code in (200, 503) + data = resp.get_json() + assert data["service"] == "ai-router" + assert "providers" in data + + def test_models_endpoint(self, client): + resp = client.get("/api/v1/ai/models") + assert resp.status_code == 200 + data = resp.get_json() + assert "models" in data + assert len(data["models"]) > 0 + + def test_chat_no_auth(self, client): + resp = client.post( + "/api/v1/ai/chat", + json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"}, + ) + assert resp.status_code == 401 + + def test_chat_no_messages(self, client): + resp = client.post( + "/api/v1/ai/chat", + json={"model": "gpt-4o"}, + headers={"X-API-Key": "test-key"}, + ) + assert resp.status_code == 401 + + @patch("ai_router.utils.verify_token_via_broker") + @patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat") + def test_chat_success_with_auth(self, mock_chat, mock_verify, client): + mock_verify.return_value = {"valid": True, "user_id": 1, "scopes": ["user"]} + mock_chat.return_value = { + "content": "Hello!", + "model": "gpt-4o", + "provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + resp = client.post( + "/api/v1/ai/chat", + json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"}, + headers={"Authorization": "Bearer test-token"}, + ) + data = resp.get_json() + assert resp.status_code == 200, f"Got {resp.status_code}: {data}" + assert data["status"] == "success" + + def test_usage_requires_admin(self, client): + resp = client.get("/api/v1/ai/usage") + assert resp.status_code == 401