- Four provider adapters: OpenAI (SDK), Anthropic (SDK), Google (google-genai), Mistral (direct HTTP)
- Core router with automatic failover and exponential backoff
- Flask blueprint exposing the /api/v1/ai/* endpoints
- Authentication via the token-broker verify endpoint
- DB models for ai_providers, ai_model_mapping, and ai_router_log
- /health endpoint (parallel provider check) and /usage statistics
- 21 unit tests (all passing)
175 lines
6.5 KiB
Python
175 lines
6.5 KiB
Python
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Optional
|
|
|
|
from .providers import PROVIDER_MAP, AIProvider
|
|
from .models import get_providers_from_db, log_router_attempt
|
|
|
|
logger = logging.getLogger("ai_router.router")

# Built-in alias table, used as the base that database mappings are layered on.
# Each row is (public alias, provider key, provider-side model id); the map keeps
# the row order and gives every alias its own independent dict.
DEFAULT_MODEL_MAP = {
    alias: {"provider": provider_key, "real_model": vendor_model}
    for alias, provider_key, vendor_model in (
        ("gpt-4o", "openai", "gpt-4o"),
        ("gpt-4o-mini", "openai", "gpt-4o-mini"),
        ("claude-3-opus", "anthropic", "claude-3-opus-20240229"),
        ("claude-3-sonnet", "anthropic", "claude-3-sonnet-20240229"),
        ("claude-3-haiku", "anthropic", "claude-3-haiku-20240307"),
        ("gemini-pro", "google", "gemini-1.5-pro"),
        ("gemini-flash", "google", "gemini-1.5-flash"),
        ("mistral-large", "mistral", "mistral-large-latest"),
        ("mistral-small", "mistral", "mistral-small-latest"),
    )
}
|
|
|
|
|
|
class AIRouter:
    """Route chat requests to AI providers with ordered failover.

    Model aliases are resolved against ``DEFAULT_MODEL_MAP`` overlaid with rows
    returned by ``get_providers_from_db()``. Provider instances are built lazily
    from ``PROVIDER_MAP`` and cached on the router instance.
    """

    def __init__(self):
        # provider name -> lazily constructed provider instance
        self._provider_instances = {}

    def get_provider(self, name: str) -> Optional[AIProvider]:
        """Return the cached provider instance for ``name``.

        The instance is created with no arguments on first use. Returns ``None``
        when ``PROVIDER_MAP`` has no (truthy) entry for ``name``.
        """
        if name not in self._provider_instances:
            cls = PROVIDER_MAP.get(name)
            if not cls:
                return None
            self._provider_instances[name] = cls()
        return self._provider_instances[name]

    def _resolve_model(self, model_alias: str) -> Optional[dict]:
        """Return the mapping entry for ``model_alias``, or ``None`` if unknown."""
        return self._load_model_mappings().get(model_alias)

    def _load_model_mappings(self) -> dict:
        """Return ``DEFAULT_MODEL_MAP`` overlaid with model mappings from the DB.

        Each DB row carrying a ``model_alias`` adds or replaces that alias with
        ``provider``, ``real_model``, ``cost_per_1k`` and the raw row under
        ``db_config``. If the DB cannot be read, a warning is logged and only the
        defaults are returned; rows without ``provider_type`` are skipped with a
        warning instead of aborting the whole lookup.
        """
        db_mappings = []
        try:
            db_mappings = get_providers_from_db() or []
        except Exception as e:
            logger.warning(f"Could not load model mappings from DB: {e}")

        merged = dict(DEFAULT_MODEL_MAP)
        for entry in db_mappings:
            alias = entry.get("model_alias")
            if not alias:
                continue
            provider_type = entry.get("provider_type")
            if not provider_type:
                logger.warning(f"Skipping model mapping {alias!r}: missing provider_type")
                continue
            merged[alias] = {
                "provider": provider_type,
                # `or` rather than a .get() default so a NULL column also falls back.
                "real_model": entry.get("real_model_id") or alias,
                "cost_per_1k": entry.get("cost_per_1k_tokens") or 0,
                "db_config": entry,
            }
        return merged

    @staticmethod
    def _priority_key(row: dict):
        """Sort key for DB provider rows: ``priority`` ascending, missing/NULL treated as 99."""
        priority = row.get("priority")
        return 99 if priority is None else priority

    def _get_prioritized_providers(self) -> list:
        """Return ``(provider_name, db_row_or_None)`` pairs in priority order.

        Uses the highest-priority DB row per ``provider_type``. Falls back to the
        fixed order openai, anthropic, google, mistral (with no DB config) when
        the DB yields nothing or cannot be read.
        """
        providers = []
        try:
            db_providers = get_providers_from_db() or []
            seen_names = set()
            for p in sorted(db_providers, key=self._priority_key):
                name = p["provider_type"]
                if name not in seen_names:
                    seen_names.add(name)
                    providers.append((name, p))
        except Exception as e:
            logger.warning(f"Could not load provider priority from DB: {e}")

        if not providers:
            default_order = ["openai", "anthropic", "google", "mistral"]
            providers = [(n, None) for n in default_order]
        return providers

    def _order_providers(self, preferred_provider: str) -> list:
        """Return the prioritized provider list with ``preferred_provider`` moved to the front.

        If the preferred provider is not in the priority list, it is prepended
        with no DB config; the relative order of the others is kept.
        """
        provider_order = self._get_prioritized_providers()
        preferred = [entry for entry in provider_order if entry[0] == preferred_provider]
        others = [entry for entry in provider_order if entry[0] != preferred_provider]
        return (preferred or [(preferred_provider, None)]) + others

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Milliseconds elapsed since ``start`` (a ``time.monotonic()`` reading)."""
        return int((time.monotonic() - start) * 1000)

    @staticmethod
    def _log_attempt(**fields) -> None:
        """Record one router attempt via ``log_router_attempt``; never raises.

        Logging is best-effort so that a failing log write neither turns a
        successful provider call into a failover nor aborts the failover loop.
        """
        try:
            log_router_attempt(**fields)
        except Exception as e:
            logger.warning(f"Could not write router log entry: {e}")

    def chat(self, messages: list, model_alias: str, user_id: Optional[int] = None, **kwargs) -> dict:
        """Send ``messages`` to the provider mapped to ``model_alias``, failing over on error.

        Providers are tried in priority order with the alias's own provider
        first. Providers with no implementation or no API key are skipped
        without delay; after the n-th failed call the next call waits
        ``min(2 ** (n - 1), 30)`` seconds. Every call made is recorded through
        ``log_router_attempt`` (best-effort); its ``duration_ms`` is measured
        from the start of this request, so it includes earlier attempts and
        backoff.

        Args:
            messages: chat messages, forwarded unchanged to ``provider.chat``.
            model_alias: public model name such as ``"gpt-4o"``.
            user_id: optional caller id recorded in the router log.
            **kwargs: extra options forwarded to ``provider.chat``.

        Returns:
            On success, the provider's result dict with ``request_id`` and
            ``status="success"`` added. For an unknown alias,
            ``{"error", "status"}``. When no provider succeeds,
            ``{"error", "status", "request_id", "duration_ms"}``.
        """
        request_id = str(uuid.uuid4())
        # Monotonic clock: elapsed times are immune to wall-clock adjustments.
        start_time = time.monotonic()

        model_info = self._resolve_model(model_alias)
        if not model_info:
            return {"error": f"Unknown model: {model_alias}", "status": "error"}

        preferred_provider = model_info["provider"]
        real_model = model_info["real_model"]
        ordered = self._order_providers(preferred_provider)

        # NOTE(review): the same provider-specific `real_model` id is sent to every
        # provider in `ordered`, so failing over from e.g. "gpt-4o" to another
        # vendor will most likely be rejected by that vendor. Consider choosing a
        # per-provider fallback model before relying on cross-provider failover.
        last_error = None
        failures = 0  # failed calls so far; drives backoff (skipped providers don't count)
        for provider_name, db_config in ordered:
            provider = self.get_provider(provider_name)
            if not provider:
                continue

            # Only the provider interaction lives in this try: an exception here
            # is a provider failure and moves on to the next provider.
            try:
                api_key = provider.get_api_key(db_config)
                if not api_key:
                    logger.warning(f"No API key configured for {provider_name}, skipping")
                    continue
                if failures:
                    backoff = min(2 ** (failures - 1), 30)
                    logger.info(f"Failover to {provider_name} after {backoff}s backoff (attempt {failures})")
                    time.sleep(backoff)
                result = provider.chat(messages, model=real_model, api_key=api_key, **kwargs)
            except Exception as e:
                failures += 1
                last_error = str(e)
                logger.warning(f"Provider {provider_name} failed: {e}")
                self._log_attempt(
                    request_id=request_id,
                    user_id=user_id,
                    model_alias=model_alias,
                    provider_used=provider_name,
                    tokens_in=0,
                    tokens_out=0,
                    duration_ms=self._elapsed_ms(start_time),
                    status="error",
                    error_message=last_error,
                )
                continue

            # Success path stays outside the try: bookkeeping problems here must
            # not re-send an already-answered request to another provider.
            usage = result.get("usage") or {}
            self._log_attempt(
                request_id=request_id,
                user_id=user_id,
                model_alias=model_alias,
                provider_used=provider_name,
                tokens_in=usage.get("prompt_tokens", 0),
                tokens_out=usage.get("completion_tokens", 0),
                duration_ms=self._elapsed_ms(start_time),
                status="success",
            )
            result["request_id"] = request_id
            result["status"] = "success"
            return result

        if last_error is None:
            # No provider was actually called: each was unknown or had no API key.
            last_error = "no provider could be called (unknown provider or missing API key)"
        return {
            "error": f"All providers failed. Last error: {last_error}",
            "status": "error",
            "request_id": request_id,
            "duration_ms": self._elapsed_ms(start_time),
        }

    def check_all_providers_health(self) -> dict:
        """Return ``{provider_name: provider.check_health()}`` for every key of ``PROVIDER_MAP``.

        Checks run one after another in this method. NOTE(review): an exception
        raised by any provider's ``check_health`` propagates to the caller.
        """
        results = {}
        for name in PROVIDER_MAP:
            provider = self.get_provider(name)
            results[name] = provider.check_health()
        return results

    def list_available_models(self) -> list:
        """Return one summary dict per known alias (defaults plus DB mappings).

        Each item has ``alias``, ``provider``, ``real_model`` and
        ``cost_per_1k_tokens`` (0 when no cost is recorded).
        """
        model_map = self._load_model_mappings()
        return [
            {
                "alias": alias,
                "provider": info["provider"],
                "real_model": info["real_model"],
                "cost_per_1k_tokens": info.get("cost_per_1k", 0),
            }
            for alias, info in model_map.items()
        ]
|