- 4 provider adapters: OpenAI (SDK), Anthropic (SDK), Google (google-genai), Mistral (direct HTTP) - Core router with automatic failover + exponential backoff - Flask blueprint with /api/v1/ai/* endpoints - Auth via token-broker verify endpoint - DB models for ai_providers, ai_model_mapping, ai_router_log - /health endpoint (parallel provider check), /usage stats - 21 unit tests (all passing)
173 lines
5.6 KiB
Python
173 lines
5.6 KiB
Python
"""Flask Blueprint for AI Router — chat, health, models, admin."""
|
|
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
|
|
from flask import Blueprint, jsonify, request
|
|
|
|
from .router import AIRouter
|
|
from .models import init_db, upsert_provider, upsert_model_mapping
|
|
from .utils import require_auth, admin_required
|
|
|
|
# Module-level logger for the AI router HTTP API, registered under the
# fixed name "ai_router.api".
logger = logging.getLogger("ai_router.api")
|
|
|
|
# Blueprint that owns every route declared in this module; all of them are
# served under the /api/v1/ai URL prefix.
ai_router_bp = Blueprint("ai_router", __name__, url_prefix="/api/v1/ai")
|
|
|
|
|
|
def register_ai_router(app):
    """Attach the AI router blueprint to *app*.

    Args:
        app: The application object; its ``register_blueprint`` method is
            called with the module-level ``ai_router_bp``.

    Returns:
        None. The application is modified in place.
    """
    blueprint = ai_router_bp
    app.register_blueprint(blueprint)
|
|
|
|
|
|
# Single AIRouter instance shared by the route handlers in this module; it is
# constructed once, when the module is imported.
_router = AIRouter()
|
|
|
|
|
|
@ai_router_bp.route("/health", methods=["GET"])
def health():
    """Report the health of every provider known to the router.

    Returns:
        A ``(response, status)`` pair. The JSON body holds the overall
        status, the service name and version, the per-provider results as
        returned by ``_router.check_all_providers_health()`` and a UTC
        ISO-8601 timestamp. The HTTP status is 200 when every provider
        entry has ``status == "ok"`` and 503 otherwise.
    """
    provider_health = _router.check_all_providers_health()

    # One provider that is not "ok" is enough to mark the service degraded.
    healthy = True
    for info in provider_health.values():
        if info["status"] != "ok":
            healthy = False
            break

    if healthy:
        overall, http_status = "ok", 200
    else:
        overall, http_status = "degraded", 503

    payload = {
        "status": overall,
        "service": "ai-router",
        "version": "1.0.0",
        "providers": provider_health,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    return jsonify(payload), http_status
|
|
|
|
|
|
@ai_router_bp.route("/models", methods=["GET"])
def list_models():
    """Return the router's available models as JSON.

    Returns:
        A JSON response whose ``"models"`` key holds whatever
        ``_router.list_available_models()`` returns.
    """
    return jsonify({"models": _router.list_available_models()})
|
|
|
|
|
|
@ai_router_bp.route("/chat", methods=["POST"])
@require_auth
def chat():
    """Forward a chat request to the AI router and relay its result.

    The request body must be a JSON object with:

    * ``messages`` -- required, non-empty list.
    * ``model`` -- optional model alias, defaults to ``"gpt-4o-mini"``.
    * ``temperature``, ``max_tokens``, ``top_p``, ``stream`` -- optional,
      passed through to the router unchanged when present.

    Returns:
        A ``(response, status)`` pair:

        * 400 when the body is not a JSON object, or ``messages`` is
          missing, empty or not a list.
        * 503 when the router reports an error whose text contains
          ``"All providers failed"``.
        * 400 for any other router-reported error.
        * 200 with the router result otherwise.
    """
    data = request.get_json(silent=True) or {}
    # get_json may yield any JSON value (string, list, number); only an
    # object supports the key lookups below, so reject the rest up front
    # instead of failing with an AttributeError.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    messages = data.get("messages", [])
    if not messages:
        return jsonify({"error": "messages field is required"}), 400
    if not isinstance(messages, list):
        return jsonify({"error": "messages must be a list"}), 400

    model = data.get("model", "gpt-4o-mini")

    # NOTE(review): current_user is presumably attached to the request by
    # require_auth -- confirm. getattr keeps the handler from crashing if
    # the attribute was never set.
    current_user = getattr(request, "current_user", None) or {}
    user_id = current_user.get("user_id")

    kwargs = {
        key: data[key]
        for key in ("temperature", "max_tokens", "top_p", "stream")
        if key in data
    }

    result = _router.chat(
        messages=messages, model_alias=model, user_id=user_id, **kwargs
    )

    if result.get("status") == "error":
        # The error value may be missing or null; normalise it to text so
        # the substring test cannot raise.
        error_text = str(result.get("error") or "")
        code = 503 if "All providers failed" in error_text else 400
        return jsonify(result), code

    return jsonify(result), 200
|
|
|
|
|
|
@ai_router_bp.route("/admin/providers", methods=["GET"])
@admin_required
def list_providers():
    """Return the rows of ``ai_providers`` ordered by priority.

    Only the columns named in the query are returned (id, name,
    provider_type, base_url, priority, is_active, created_at, updated_at).

    Returns:
        A JSON response with a ``"providers"`` list, one dict per row.
    """
    from .models import get_db

    query = (
        "SELECT id, name, provider_type, base_url, priority, is_active, created_at, updated_at "
        "FROM ai_providers ORDER BY priority"
    )

    db = get_db()
    try:
        providers = []
        for row in db.execute(query).fetchall():
            providers.append(dict(row))
        return jsonify({"providers": providers})
    finally:
        # Close the connection whether or not the query succeeded.
        db.close()
|
|
|
|
|
|
@ai_router_bp.route("/admin/providers", methods=["POST"])
@admin_required
def upsert_provider_endpoint():
    """Create or update a provider from a JSON request body.

    The body must be a JSON object. Recognised keys: ``name`` (required),
    ``provider_type`` (required; one of ``openai``, ``anthropic``,
    ``google``, ``mistral``), ``api_key``, ``base_url`` and ``priority``
    (defaults to 99).

    Returns:
        A ``(response, status)`` pair: 400 when the body is not a JSON
        object or name/provider_type are invalid, 200 when
        ``upsert_provider`` reports success, 500 otherwise.
    """
    data = request.get_json(silent=True) or {}
    # get_json may yield a JSON string, list or number; those have no
    # .get(), so answer 400 instead of crashing with an AttributeError.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    name = data.get("name", "")
    provider_type = data.get("provider_type", "")
    api_key = data.get("api_key", "")
    base_url = data.get("base_url", "")
    priority = data.get("priority", 99)

    if not name or provider_type not in ("openai", "anthropic", "google", "mistral"):
        return jsonify({"error": "Valid name and provider_type required"}), 400

    ok = upsert_provider(name, provider_type, api_key, base_url, priority=priority)
    if ok:
        return jsonify({"status": "ok", "message": f"Provider {name} saved"}), 200
    return jsonify({"error": "Failed to save provider"}), 500
|
|
|
|
|
|
@ai_router_bp.route("/admin/model-mappings", methods=["POST"])
@admin_required
def upsert_model_mapping_endpoint():
    """Create or update a model-alias mapping from a JSON request body.

    The body must be a JSON object. Recognised keys: ``model_alias``,
    ``provider_id`` and ``real_model_id`` (all required and truthy) and
    ``cost_per_1k_tokens`` (defaults to 0).

    Returns:
        A ``(response, status)`` pair: 400 when the body is not a JSON
        object or a required field is missing, 200 when
        ``upsert_model_mapping`` reports success, 500 otherwise.
    """
    data = request.get_json(silent=True) or {}
    # get_json may yield a JSON string, list or number; those have no
    # .get(), so answer 400 instead of crashing with an AttributeError.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    model_alias = data.get("model_alias", "")
    provider_id = data.get("provider_id")
    real_model_id = data.get("real_model_id", "")
    cost = data.get("cost_per_1k_tokens", 0)

    if not model_alias or not provider_id or not real_model_id:
        return jsonify({"error": "model_alias, provider_id, real_model_id required"}), 400

    ok = upsert_model_mapping(model_alias, provider_id, real_model_id, cost)
    if ok:
        return jsonify({"status": "ok", "message": f"Mapping for {model_alias} saved"}), 200
    return jsonify({"error": "Failed to save model mapping"}), 500
|
|
|
|
|
|
@ai_router_bp.route("/admin/providers/<int:provider_id>", methods=["DELETE"])
@admin_required
def delete_provider(provider_id):
    """Delete a provider and the model mappings that reference it.

    Mappings in ``ai_model_mapping`` are removed first, then the row in
    ``ai_providers``; both deletes are committed together.

    Args:
        provider_id: Integer id taken from the URL.

    Returns:
        A JSON response ``{"status": "deleted", "provider_id": ...}`` on
        success, or ``({"error": <message>}, 500)`` when a statement or the
        commit raises. A missing id is not treated as an error.
    """
    from .models import get_db

    conn = get_db()
    try:
        conn.execute("DELETE FROM ai_model_mapping WHERE provider_id = ?", (provider_id,))
        conn.execute("DELETE FROM ai_providers WHERE id = ?", (provider_id,))
        conn.commit()
        return jsonify({"status": "deleted", "provider_id": provider_id})
    except Exception as e:
        # Record the traceback server-side; previously the failure was only
        # visible in the HTTP response and left no trace in the logs.
        logger.exception("Failed to delete provider %s", provider_id)
        return jsonify({"error": str(e)}), 500
    finally:
        conn.close()
|
|
|
|
|
|
@ai_router_bp.route("/usage", methods=["GET"])
@admin_required
def usage_stats():
    """Return the most recent rows of ``ai_router_log``.

    Query parameters:
        limit: Maximum number of rows to return. Defaults to 50; values
            that do not parse as an integer also fall back to 50 (Flask's
            ``type=int`` behaviour). Negative values fall back to 50 and
            values above 1000 are capped at 1000.

    Returns:
        A JSON response with a ``"usage"`` list, newest row first, one
        dict per row.
    """
    from .models import get_db

    default_limit = 50
    max_limit = 1000

    # Parse and bound the limit before opening a connection.
    limit = request.args.get("limit", default_limit, type=int)
    if limit < 0:
        # NOTE(review): the "?" placeholders suggest SQLite, where a
        # negative LIMIT means "no limit" -- a request with limit=-1 would
        # otherwise return the entire log table.
        limit = default_limit
    limit = min(limit, max_limit)

    conn = get_db()
    try:
        rows = conn.execute(
            "SELECT * FROM ai_router_log ORDER BY created_at DESC LIMIT ?", (limit,)
        ).fetchall()
        return jsonify({"usage": [dict(r) for r in rows]})
    finally:
        conn.close()
|
|
|
|
|
|
@ai_router_bp.route("/usage/summary", methods=["GET"])
@admin_required
def usage_summary():
    """Return aggregate usage figures computed from ``ai_router_log``.

    Returns:
        A JSON response with two keys:

        * ``"by_provider"`` -- one dict per (provider_used, status) pair
          with ``count``, ``total_ms`` and ``total_tokens``.
        * ``"totals"`` -- ``total_requests``, ``success_count`` and
          ``total_tokens`` over the whole table.

        Numeric aggregates are 0 rather than null when there is nothing to
        sum.
    """
    from .models import get_db

    conn = get_db()
    try:
        # SUM() yields NULL when it sees no non-NULL input, and a NULL
        # tokens_in or tokens_out would make the whole per-row addition
        # NULL. COALESCE keeps both cases at 0 so the JSON never carries
        # null counters.
        agg = conn.execute("""
            SELECT provider_used, status, COUNT(*) as count,
                   COALESCE(SUM(duration_ms), 0) as total_ms,
                   COALESCE(SUM(COALESCE(tokens_in, 0) + COALESCE(tokens_out, 0)), 0) as total_tokens
            FROM ai_router_log
            GROUP BY provider_used, status
            ORDER BY provider_used
        """).fetchall()
        totals = conn.execute("""
            SELECT COUNT(*) as total_requests,
                   COALESCE(SUM(CASE WHEN status='success' THEN 1 ELSE 0 END), 0) as success_count,
                   COALESCE(SUM(COALESCE(tokens_in, 0) + COALESCE(tokens_out, 0)), 0) as total_tokens
            FROM ai_router_log
        """).fetchone()
        return jsonify({
            "by_provider": [dict(r) for r in agg],
            "totals": dict(totals) if totals else {},
        })
    finally:
        conn.close()
|