Files
turf_saas/ai_router/router.py
CTO H3R7Tech 4b766cb908 feat(HRT-200): AI Router — Multi-provider LLM routing with failover
- 4 provider adapters: OpenAI (SDK), Anthropic (SDK), Google (google-genai), Mistral (direct HTTP)
- Core router with automatic failover + exponential backoff
- Flask blueprint with /api/v1/ai/* endpoints
- Auth via token-broker verify endpoint
- DB models for ai_providers, ai_model_mapping, ai_router_log
- /health endpoint (parallel provider check), /usage stats
- 21 unit tests (all passing)
2026-05-24 10:21:36 +02:00

175 lines
6.5 KiB
Python

import logging
import time
import uuid
from typing import Optional
from .providers import PROVIDER_MAP, AIProvider
from .models import get_providers_from_db, log_router_attempt
# Module-level logger shared by everything in the router module.
logger = logging.getLogger("ai_router.router")

# Built-in alias table: public model alias -> provider adapter name and the
# provider-side model identifier.  Each row below is
# (alias, provider, real_model); the comprehension expands it into the
# {"provider": ..., "real_model": ...} shape the router reads.
DEFAULT_MODEL_MAP = {
    alias: {"provider": provider, "real_model": real_model}
    for alias, provider, real_model in (
        ("gpt-4o", "openai", "gpt-4o"),
        ("gpt-4o-mini", "openai", "gpt-4o-mini"),
        ("claude-3-opus", "anthropic", "claude-3-opus-20240229"),
        ("claude-3-sonnet", "anthropic", "claude-3-sonnet-20240229"),
        ("claude-3-haiku", "anthropic", "claude-3-haiku-20240307"),
        ("gemini-pro", "google", "gemini-1.5-pro"),
        ("gemini-flash", "google", "gemini-1.5-flash"),
        ("mistral-large", "mistral", "mistral-large-latest"),
        ("mistral-small", "mistral", "mistral-small-latest"),
    )
}
class AIRouter:
    """Route chat requests to LLM provider adapters with ordered failover.

    Model aliases are resolved through ``DEFAULT_MODEL_MAP`` overlaid with
    the rows returned by ``get_providers_from_db()``.  Adapter instances are
    created lazily from ``PROVIDER_MAP`` and cached on the router instance.
    """

    # Priority used for DB rows whose ``priority`` is missing or NULL.
    _DEFAULT_PRIORITY = 99
    # Upper bound, in seconds, for the sleep between failover attempts.
    _MAX_BACKOFF_SECONDS = 30
    # Provider order used when the DB yields no usable provider rows.
    _FALLBACK_PROVIDER_ORDER = ("openai", "anthropic", "google", "mistral")

    def __init__(self):
        """Create a router with an empty adapter cache."""
        self._provider_instances = {}

    def get_provider(self, name: str) -> Optional[AIProvider]:
        """Return the cached adapter for *name*, instantiating it on first use.

        Returns ``None`` when ``PROVIDER_MAP`` has no usable entry for *name*.
        """
        if name not in self._provider_instances:
            provider_cls = PROVIDER_MAP.get(name)
            if not provider_cls:
                return None
            self._provider_instances[name] = provider_cls()
        return self._provider_instances[name]

    def _resolve_model(self, model_alias: str) -> Optional[dict]:
        """Return the mapping entry for *model_alias*, or ``None`` if unknown."""
        return self._load_model_mappings().get(model_alias)

    def _load_model_mappings(self) -> dict:
        """Build the alias table: built-in defaults overridden by DB rows.

        A DB failure is logged and the built-in defaults are used alone.
        Rows without ``model_alias`` are ignored; rows that have an alias but
        no ``provider_type`` are skipped with a warning rather than raising,
        so one malformed row cannot break every request.
        """
        try:
            db_rows = get_providers_from_db() or []
        except Exception as e:
            logger.warning("Could not load model mappings from DB: %s", e)
            db_rows = []

        merged = dict(DEFAULT_MODEL_MAP)
        for row in db_rows:
            alias = row.get("model_alias")
            if not alias:
                continue
            provider_type = row.get("provider_type")
            if not provider_type:
                logger.warning(
                    "Ignoring DB model mapping %s: no provider_type", alias
                )
                continue
            # A key that is present but NULL must fall back just like a
            # missing key; dict.get's default only covers the latter.
            cost = row.get("cost_per_1k_tokens")
            merged[alias] = {
                "provider": provider_type,
                "real_model": row.get("real_model_id") or alias,
                "cost_per_1k": 0 if cost is None else cost,
                "db_config": row,
            }
        return merged

    @classmethod
    def _priority_of(cls, row: dict):
        """Sort key for a DB row; missing or NULL priority sorts as the default."""
        priority = row.get("priority")
        return cls._DEFAULT_PRIORITY if priority is None else priority

    def _get_prioritized_providers(self):
        """Return ``[(provider_name, db_row_or_None), ...]`` in priority order.

        Each provider name appears once (its lowest-priority-value row wins).
        If the DB cannot be read, or yields no usable rows, the built-in
        fallback order is returned with ``None`` in place of the DB row.
        """
        prioritized = []
        try:
            db_rows = get_providers_from_db() or []
            seen_names = set()
            for row in sorted(db_rows, key=self._priority_of):
                name = row.get("provider_type")
                if not name or name in seen_names:
                    continue
                seen_names.add(name)
                prioritized.append((name, row))
        except Exception as e:
            logger.warning("Could not load provider priority from DB: %s", e)
        if not prioritized:
            prioritized = [(name, None) for name in self._FALLBACK_PROVIDER_ORDER]
        return prioritized

    @staticmethod
    def _record_attempt(**fields) -> None:
        """Write one router-log row via ``log_router_attempt``, best effort.

        A logging failure is reported as a warning instead of propagating, so
        it can neither discard a successful completion nor abort failover.
        """
        try:
            log_router_attempt(**fields)
        except Exception as e:
            logger.warning(
                "Could not record router attempt %s: %s",
                fields.get("request_id"),
                e,
            )

    def chat(self, messages: list, model_alias: str, user_id: Optional[int] = None, **kwargs) -> dict:
        """Send *messages* to the provider mapped to *model_alias*, with failover.

        The provider mapped to the alias is tried first, then the remaining
        providers in priority order.  Providers without an adapter or without
        an API key are skipped without delay; after each real failure the
        router sleeps with exponential backoff (1s, 2s, 4s, ... capped at
        ``_MAX_BACKOFF_SECONDS``) before the next attempt.

        Args:
            messages: Chat messages, passed through to the adapter unchanged.
            model_alias: Public alias to resolve via the model mapping.
            user_id: Optional user identifier recorded in the router log.
            **kwargs: Extra options forwarded to the adapter's ``chat``.

        Returns:
            On success, the adapter's result dict with ``request_id`` and
            ``status="success"`` added.  If the alias is unknown, a dict with
            ``error`` and ``status="error"``.  If no provider succeeds, a dict
            with ``error``, ``status="error"``, ``request_id`` and
            ``duration_ms``.
        """
        request_id = str(uuid.uuid4())
        # Monotonic clock: durations must not jump with wall-clock adjustments.
        started = time.monotonic()

        def elapsed_ms() -> int:
            return int((time.monotonic() - started) * 1000)

        model_info = self._resolve_model(model_alias)
        if not model_info:
            return {"error": f"Unknown model: {model_alias}", "status": "error"}

        preferred_provider = model_info["provider"]
        # NOTE(review): this model id belongs to the preferred provider but is
        # also sent to fallback providers -- confirm the adapters translate or
        # reject foreign model ids.
        real_model = model_info["real_model"]

        candidates = self._get_prioritized_providers()
        # Preferred provider first (added without DB config if the priority
        # list does not contain it), everything else in priority order.
        head = [c for c in candidates if c[0] == preferred_provider]
        tail = [c for c in candidates if c[0] != preferred_provider]
        ordered = (head or [(preferred_provider, None)]) + tail

        failures = 0  # providers that were actually called and raised
        last_error = None
        skip_reason = None

        for provider_name, db_config in ordered:
            provider = self.get_provider(provider_name)
            if not provider:
                skip_reason = f"No adapter registered for {provider_name}"
                continue
            try:
                api_key = provider.get_api_key(db_config)
                if not api_key:
                    logger.warning(
                        "No API key configured for %s, skipping", provider_name
                    )
                    skip_reason = f"No API key configured for {provider_name}"
                    continue
                if failures:
                    # Back off only after a real failure, and only once we
                    # know this provider will actually be called.
                    backoff = min(2 ** (failures - 1), self._MAX_BACKOFF_SECONDS)
                    logger.info(
                        "Failover to %s after %ss backoff (attempt %s)",
                        provider_name,
                        backoff,
                        failures,
                    )
                    time.sleep(backoff)
                result = provider.chat(
                    messages, model=real_model, api_key=api_key, **kwargs
                )
                # Kept inside the try so a malformed adapter result still
                # counts as a provider failure and triggers failover.
                usage = result.get("usage") or {}
                tokens_in = usage.get("prompt_tokens", 0)
                tokens_out = usage.get("completion_tokens", 0)
            except Exception as e:
                failures += 1
                last_error = str(e)
                logger.warning("Provider %s failed: %s", provider_name, e)
                self._record_attempt(
                    request_id=request_id,
                    user_id=user_id,
                    model_alias=model_alias,
                    provider_used=provider_name,
                    tokens_in=0,
                    tokens_out=0,
                    duration_ms=elapsed_ms(),
                    status="error",
                    error_message=last_error,
                )
                continue

            # Success path is outside the try: a log-write problem must not be
            # mistaken for a provider failure and re-run the request elsewhere.
            self._record_attempt(
                request_id=request_id,
                user_id=user_id,
                model_alias=model_alias,
                provider_used=provider_name,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                duration_ms=elapsed_ms(),
                status="success",
            )
            result["request_id"] = request_id
            result["status"] = "success"
            return result

        if last_error is None:
            # Nothing was called at all; report why instead of "None".
            last_error = skip_reason or "no provider available"
        return {
            "error": f"All providers failed. Last error: {last_error}",
            "status": "error",
            "request_id": request_id,
            "duration_ms": elapsed_ms(),
        }

    def check_all_providers_health(self) -> dict:
        """Return ``{provider_name: check_health() result}`` for every adapter.

        Providers are checked one after another in ``PROVIDER_MAP`` order.
        NOTE(review): an exception raised by one adapter's ``check_health``
        propagates and aborts the whole check; the result shape is defined by
        the adapters, so no error placeholder is synthesized here.
        """
        results = {}
        for name in PROVIDER_MAP:
            provider = self.get_provider(name)
            results[name] = provider.check_health()
        return results

    def list_available_models(self) -> list:
        """List every known model alias with its provider, real model id and cost."""
        return [
            {
                "alias": alias,
                "provider": info["provider"],
                "real_model": info["real_model"],
                "cost_per_1k_tokens": info.get("cost_per_1k", 0),
            }
            for alias, info in self._load_model_mappings().items()
        ]