Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b766cb908 | ||
|
|
837cddb406 |
4
ai_router/__init__.py
Normal file
4
ai_router/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .router import AIRouter
|
||||||
|
from .api import ai_router_bp, register_ai_router
|
||||||
|
|
||||||
|
__all__ = ["AIRouter", "ai_router_bp", "register_ai_router"]
|
||||||
172
ai_router/api.py
Normal file
172
ai_router/api.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""Flask Blueprint for AI Router — chat, health, models, admin."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from flask import Blueprint, jsonify, request
|
||||||
|
|
||||||
|
from .router import AIRouter
|
||||||
|
from .models import init_db, upsert_provider, upsert_model_mapping
|
||||||
|
from .utils import require_auth, admin_required
|
||||||
|
|
||||||
|
logger = logging.getLogger("ai_router.api")
|
||||||
|
|
||||||
|
ai_router_bp = Blueprint("ai_router", __name__, url_prefix="/api/v1/ai")
|
||||||
|
|
||||||
|
|
||||||
|
def register_ai_router(app):
|
||||||
|
app.register_blueprint(ai_router_bp)
|
||||||
|
|
||||||
|
|
||||||
|
_router = AIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@ai_router_bp.route("/health", methods=["GET"])
|
||||||
|
def health():
|
||||||
|
health_data = _router.check_all_providers_health()
|
||||||
|
all_ok = all(v["status"] == "ok" for v in health_data.values())
|
||||||
|
return jsonify({
|
||||||
|
"status": "ok" if all_ok else "degraded",
|
||||||
|
"service": "ai-router",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"providers": health_data,
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}), 200 if all_ok else 503
|
||||||
|
|
||||||
|
|
||||||
|
@ai_router_bp.route("/models", methods=["GET"])
|
||||||
|
def list_models():
|
||||||
|
models = _router.list_available_models()
|
||||||
|
return jsonify({"models": models})
|
||||||
|
|
||||||
|
|
||||||
|
@ai_router_bp.route("/chat", methods=["POST"])
|
||||||
|
@require_auth
|
||||||
|
def chat():
|
||||||
|
data = request.get_json(silent=True) or {}
|
||||||
|
messages = data.get("messages", [])
|
||||||
|
model = data.get("model", "gpt-4o-mini")
|
||||||
|
user_id = (request.current_user or {}).get("user_id")
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
return jsonify({"error": "messages field is required"}), 400
|
||||||
|
|
||||||
|
kwargs = {k: data[k] for k in ("temperature", "max_tokens", "top_p", "stream") if k in data}
|
||||||
|
|
||||||
|
result = _router.chat(messages=messages, model_alias=model, user_id=user_id, **kwargs)
|
||||||
|
|
||||||
|
if result.get("status") == "error":
|
||||||
|
code = 503 if "All providers failed" in result.get("error", "") else 400
|
||||||
|
return jsonify(result), code
|
||||||
|
|
||||||
|
return jsonify(result), 200
|
||||||
|
|
||||||
|
|
||||||
|
@ai_router_bp.route("/admin/providers", methods=["GET"])
|
||||||
|
@admin_required
|
||||||
|
def list_providers():
|
||||||
|
from .models import get_db
|
||||||
|
conn = get_db()
|
||||||
|
try:
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT id, name, provider_type, base_url, priority, is_active, created_at, updated_at "
|
||||||
|
"FROM ai_providers ORDER BY priority"
|
||||||
|
).fetchall()
|
||||||
|
return jsonify({"providers": [dict(r) for r in rows]})
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ai_router_bp.route("/admin/providers", methods=["POST"])
|
||||||
|
@admin_required
|
||||||
|
def upsert_provider_endpoint():
|
||||||
|
data = request.get_json(silent=True) or {}
|
||||||
|
name = data.get("name", "")
|
||||||
|
provider_type = data.get("provider_type", "")
|
||||||
|
api_key = data.get("api_key", "")
|
||||||
|
base_url = data.get("base_url", "")
|
||||||
|
priority = data.get("priority", 99)
|
||||||
|
|
||||||
|
if not name or provider_type not in ("openai", "anthropic", "google", "mistral"):
|
||||||
|
return jsonify({"error": "Valid name and provider_type required"}), 400
|
||||||
|
|
||||||
|
ok = upsert_provider(name, provider_type, api_key, base_url, priority=priority)
|
||||||
|
if ok:
|
||||||
|
return jsonify({"status": "ok", "message": f"Provider {name} saved"}), 200
|
||||||
|
return jsonify({"error": "Failed to save provider"}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@ai_router_bp.route("/admin/model-mappings", methods=["POST"])
|
||||||
|
@admin_required
|
||||||
|
def upsert_model_mapping_endpoint():
|
||||||
|
data = request.get_json(silent=True) or {}
|
||||||
|
model_alias = data.get("model_alias", "")
|
||||||
|
provider_id = data.get("provider_id")
|
||||||
|
real_model_id = data.get("real_model_id", "")
|
||||||
|
cost = data.get("cost_per_1k_tokens", 0)
|
||||||
|
|
||||||
|
if not model_alias or not provider_id or not real_model_id:
|
||||||
|
return jsonify({"error": "model_alias, provider_id, real_model_id required"}), 400
|
||||||
|
|
||||||
|
ok = upsert_model_mapping(model_alias, provider_id, real_model_id, cost)
|
||||||
|
if ok:
|
||||||
|
return jsonify({"status": "ok", "message": f"Mapping for {model_alias} saved"}), 200
|
||||||
|
return jsonify({"error": "Failed to save model mapping"}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@ai_router_bp.route("/admin/providers/<int:provider_id>", methods=["DELETE"])
|
||||||
|
@admin_required
|
||||||
|
def delete_provider(provider_id):
|
||||||
|
from .models import get_db
|
||||||
|
conn = get_db()
|
||||||
|
try:
|
||||||
|
conn.execute("DELETE FROM ai_model_mapping WHERE provider_id = ?", (provider_id,))
|
||||||
|
conn.execute("DELETE FROM ai_providers WHERE id = ?", (provider_id,))
|
||||||
|
conn.commit()
|
||||||
|
return jsonify({"status": "deleted", "provider_id": provider_id})
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({"error": str(e)}), 500
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ai_router_bp.route("/usage", methods=["GET"])
|
||||||
|
@admin_required
|
||||||
|
def usage_stats():
|
||||||
|
from .models import get_db
|
||||||
|
conn = get_db()
|
||||||
|
try:
|
||||||
|
limit = request.args.get("limit", 50, type=int)
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT * FROM ai_router_log ORDER BY created_at DESC LIMIT ?", (limit,)
|
||||||
|
).fetchall()
|
||||||
|
return jsonify({"usage": [dict(r) for r in rows]})
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ai_router_bp.route("/usage/summary", methods=["GET"])
|
||||||
|
@admin_required
|
||||||
|
def usage_summary():
|
||||||
|
from .models import get_db
|
||||||
|
conn = get_db()
|
||||||
|
try:
|
||||||
|
agg = conn.execute("""
|
||||||
|
SELECT provider_used, status, COUNT(*) as count,
|
||||||
|
SUM(duration_ms) as total_ms, SUM(tokens_in + tokens_out) as total_tokens
|
||||||
|
FROM ai_router_log
|
||||||
|
GROUP BY provider_used, status
|
||||||
|
ORDER BY provider_used
|
||||||
|
""").fetchall()
|
||||||
|
totals = conn.execute("""
|
||||||
|
SELECT COUNT(*) as total_requests,
|
||||||
|
SUM(CASE WHEN status='success' THEN 1 ELSE 0 END) as success_count,
|
||||||
|
SUM(tokens_in + tokens_out) as total_tokens
|
||||||
|
FROM ai_router_log
|
||||||
|
""").fetchone()
|
||||||
|
return jsonify({
|
||||||
|
"by_provider": [dict(r) for r in agg],
|
||||||
|
"totals": dict(totals) if totals else {},
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
167
ai_router/models.py
Normal file
167
ai_router/models.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
logger = logging.getLogger("ai_router.models")
|
||||||
|
|
||||||
|
DB_PATH = os.environ.get("AI_ROUTER_DB", "/home/h3r7/turf_saas/ai_router.db")
|
||||||
|
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
conn = sqlite3.connect(DB_PATH)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA foreign_keys=ON")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def init_db():
|
||||||
|
conn = get_db()
|
||||||
|
try:
|
||||||
|
conn.executescript("""
|
||||||
|
CREATE TABLE IF NOT EXISTS ai_providers (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
provider_type TEXT NOT NULL CHECK(provider_type IN ('openai','anthropic','google','mistral')),
|
||||||
|
api_key TEXT NOT NULL DEFAULT '',
|
||||||
|
base_url TEXT DEFAULT '',
|
||||||
|
config TEXT DEFAULT '{}',
|
||||||
|
priority INTEGER NOT NULL DEFAULT 99,
|
||||||
|
is_active INTEGER NOT NULL DEFAULT 1,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||||
|
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS ai_model_mapping (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
model_alias TEXT NOT NULL UNIQUE,
|
||||||
|
provider_id INTEGER NOT NULL REFERENCES ai_providers(id),
|
||||||
|
real_model_id TEXT NOT NULL,
|
||||||
|
cost_per_1k_tokens REAL NOT NULL DEFAULT 0,
|
||||||
|
is_active INTEGER NOT NULL DEFAULT 1,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS ai_router_log (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
request_id TEXT NOT NULL,
|
||||||
|
user_id INTEGER,
|
||||||
|
model_alias TEXT NOT NULL,
|
||||||
|
provider_used TEXT NOT NULL,
|
||||||
|
tokens_in INTEGER NOT NULL DEFAULT 0,
|
||||||
|
tokens_out INTEGER NOT NULL DEFAULT 0,
|
||||||
|
duration_ms INTEGER NOT NULL DEFAULT 0,
|
||||||
|
status TEXT NOT NULL CHECK(status IN ('success','error')),
|
||||||
|
error_message TEXT DEFAULT '',
|
||||||
|
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_ai_router_log_request_id ON ai_router_log(request_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_ai_router_log_created_at ON ai_router_log(created_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_ai_model_mapping_alias ON ai_model_mapping(model_alias);
|
||||||
|
""")
|
||||||
|
conn.commit()
|
||||||
|
logger.info("AI Router database tables initialized")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize AI Router DB: {e}")
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_providers_from_db():
|
||||||
|
conn = get_db()
|
||||||
|
try:
|
||||||
|
rows = conn.execute("""
|
||||||
|
SELECT p.id, p.name, p.provider_type, p.api_key, p.base_url, p.config,
|
||||||
|
p.priority, p.is_active, m.model_alias, m.real_model_id, m.cost_per_1k_tokens
|
||||||
|
FROM ai_providers p
|
||||||
|
LEFT JOIN ai_model_mapping m ON m.provider_id = p.id
|
||||||
|
WHERE p.is_active = 1
|
||||||
|
""").fetchall()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not query providers: {e}")
|
||||||
|
return []
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def log_router_attempt(request_id, user_id, model_alias, provider_used,
|
||||||
|
tokens_in, tokens_out, duration_ms, status,
|
||||||
|
error_message=""):
|
||||||
|
conn = get_db()
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
"""INSERT INTO ai_router_log
|
||||||
|
(request_id, user_id, model_alias, provider_used,
|
||||||
|
tokens_in, tokens_out, duration_ms, status, error_message)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||||
|
(request_id, user_id, model_alias, provider_used,
|
||||||
|
tokens_in, tokens_out, duration_ms, status, error_message),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to log router attempt: {e}")
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_provider(name, provider_type, api_key="", base_url="",
|
||||||
|
config=None, priority=99, is_active=1):
|
||||||
|
conn = get_db()
|
||||||
|
try:
|
||||||
|
existing = conn.execute(
|
||||||
|
"SELECT id FROM ai_providers WHERE name = ?", (name,)
|
||||||
|
).fetchone()
|
||||||
|
if existing:
|
||||||
|
conn.execute(
|
||||||
|
"""UPDATE ai_providers SET provider_type=?, api_key=?, base_url=?,
|
||||||
|
config=?, priority=?, is_active=?, updated_at=datetime('now')
|
||||||
|
WHERE name=?""",
|
||||||
|
(provider_type, api_key, base_url,
|
||||||
|
config or "{}", priority, is_active, name),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
conn.execute(
|
||||||
|
"""INSERT INTO ai_providers
|
||||||
|
(name, provider_type, api_key, base_url, config, priority, is_active)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
||||||
|
(name, provider_type, api_key, base_url,
|
||||||
|
config or "{}", priority, is_active),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to upsert provider: {e}")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_model_mapping(model_alias, provider_id, real_model_id, cost_per_1k=0):
|
||||||
|
conn = get_db()
|
||||||
|
try:
|
||||||
|
existing = conn.execute(
|
||||||
|
"SELECT id FROM ai_model_mapping WHERE model_alias = ?", (model_alias,)
|
||||||
|
).fetchone()
|
||||||
|
if existing:
|
||||||
|
conn.execute(
|
||||||
|
"""UPDATE ai_model_mapping SET provider_id=?, real_model_id=?,
|
||||||
|
cost_per_1k_tokens=? WHERE model_alias=?""",
|
||||||
|
(provider_id, real_model_id, cost_per_1k, model_alias),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
conn.execute(
|
||||||
|
"""INSERT INTO ai_model_mapping
|
||||||
|
(model_alias, provider_id, real_model_id, cost_per_1k_tokens)
|
||||||
|
VALUES (?, ?, ?, ?)""",
|
||||||
|
(model_alias, provider_id, real_model_id, cost_per_1k),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to upsert model mapping: {e}")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
14
ai_router/providers/__init__.py
Normal file
14
ai_router/providers/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from .base import AIProvider
|
||||||
|
from .openai_adapter import OpenAIAdapter
|
||||||
|
from .anthropic_adapter import AnthropicAdapter
|
||||||
|
from .google_adapter import GoogleAdapter
|
||||||
|
from .mistral_adapter import MistralAdapter
|
||||||
|
|
||||||
|
PROVIDER_MAP = {
|
||||||
|
"openai": OpenAIAdapter,
|
||||||
|
"anthropic": AnthropicAdapter,
|
||||||
|
"google": GoogleAdapter,
|
||||||
|
"mistral": MistralAdapter,
|
||||||
|
}
|
||||||
|
|
||||||
|
__all__ = ["AIProvider", "PROVIDER_MAP", "OpenAIAdapter", "AnthropicAdapter", "GoogleAdapter", "MistralAdapter"]
|
||||||
57
ai_router/providers/anthropic_adapter.py
Normal file
57
ai_router/providers/anthropic_adapter.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .base import AIProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger("ai_router.anthropic")
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicAdapter(AIProvider):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "anthropic"
|
||||||
|
|
||||||
|
def chat(self, messages: list, model: str, api_key: Optional[str] = None, **kwargs) -> dict:
|
||||||
|
from anthropic import Anthropic
|
||||||
|
|
||||||
|
key = api_key or self.get_api_key()
|
||||||
|
client = Anthropic(api_key=key)
|
||||||
|
|
||||||
|
system_msg = None
|
||||||
|
chat_messages = messages
|
||||||
|
if messages and messages[0].get("role") == "system":
|
||||||
|
system_msg = messages[0]["content"]
|
||||||
|
chat_messages = messages[1:]
|
||||||
|
|
||||||
|
resp = client.messages.create(
|
||||||
|
model=model,
|
||||||
|
system=system_msg,
|
||||||
|
messages=[{"role": m["role"], "content": m["content"]} for m in chat_messages],
|
||||||
|
**{k: v for k, v in kwargs.items() if k in ("temperature", "max_tokens", "top_p")},
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"content": resp.content[0].text if resp.content else "",
|
||||||
|
"model": resp.model,
|
||||||
|
"provider": self.name,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": resp.usage.input_tokens if resp.usage else 0,
|
||||||
|
"completion_tokens": resp.usage.output_tokens if resp.usage else 0,
|
||||||
|
"total_tokens": (resp.usage.input_tokens + resp.usage.output_tokens) if resp.usage else 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def models(self) -> list:
|
||||||
|
from anthropic import Anthropic
|
||||||
|
|
||||||
|
client = Anthropic(api_key=self.get_api_key())
|
||||||
|
return [m.id for m in client.models.list()]
|
||||||
|
|
||||||
|
def check_health(self) -> dict:
|
||||||
|
try:
|
||||||
|
from anthropic import Anthropic
|
||||||
|
client = Anthropic(api_key=self.get_api_key())
|
||||||
|
client.models.list()
|
||||||
|
return {"status": "ok", "details": "API reachable"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Anthropic health check failed: {e}")
|
||||||
|
return {"status": "error", "details": str(e)}
|
||||||
44
ai_router/providers/base.py
Normal file
44
ai_router/providers/base.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class AIProvider(ABC):
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Provider identifier (openai, anthropic, google, mistral)."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def chat(self, messages: list, model: str, **kwargs) -> dict:
|
||||||
|
"""Send a chat completion request. Returns dict with at least:
|
||||||
|
{
|
||||||
|
"content": str,
|
||||||
|
"model": str,
|
||||||
|
"provider": self.name,
|
||||||
|
"usage": {"prompt_tokens": int, "completion_tokens": int, "total_tokens": int}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def models(self) -> list:
|
||||||
|
"""Return list of available models from this provider."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def check_health(self) -> dict:
|
||||||
|
"""Check provider connectivity. Returns {"status": "ok"|"error", "details": str}"""
|
||||||
|
|
||||||
|
def get_api_key(self, db_config: Optional[dict] = None) -> Optional[str]:
|
||||||
|
"""Resolve API key: DB override > env var."""
|
||||||
|
provider_env_map = {
|
||||||
|
"openai": "OPENAI_API_KEY",
|
||||||
|
"anthropic": "ANTHROPIC_API_KEY",
|
||||||
|
"google": "GOOGLE_API_KEY",
|
||||||
|
"mistral": "MISTRAL_API_KEY",
|
||||||
|
}
|
||||||
|
if db_config and db_config.get("api_key"):
|
||||||
|
return db_config["api_key"]
|
||||||
|
import os
|
||||||
|
env_var = provider_env_map.get(self.name)
|
||||||
|
if env_var:
|
||||||
|
return os.environ.get(env_var)
|
||||||
|
return None
|
||||||
57
ai_router/providers/google_adapter.py
Normal file
57
ai_router/providers/google_adapter.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .base import AIProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger("ai_router.google")
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleAdapter(AIProvider):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "google"
|
||||||
|
|
||||||
|
def chat(self, messages: list, model: str, api_key: Optional[str] = None, **kwargs) -> dict:
|
||||||
|
from google import genai
|
||||||
|
|
||||||
|
key = api_key or self.get_api_key()
|
||||||
|
client = genai.Client(api_key=key)
|
||||||
|
|
||||||
|
system_instruction = None
|
||||||
|
chat_messages = messages
|
||||||
|
if messages and messages[0].get("role") == "system":
|
||||||
|
system_instruction = messages[0]["content"]
|
||||||
|
chat_messages = messages[1:]
|
||||||
|
|
||||||
|
contents = []
|
||||||
|
for m in chat_messages:
|
||||||
|
role = "user" if m["role"] in ("user", "system") else "model"
|
||||||
|
contents.append({"role": role, "parts": [{"text": m["content"]}]})
|
||||||
|
|
||||||
|
resp = client.models.generate_content(
|
||||||
|
model=model,
|
||||||
|
contents=contents,
|
||||||
|
config={"system_instruction": system_instruction} if system_instruction else None,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"content": resp.text or "",
|
||||||
|
"model": model,
|
||||||
|
"provider": self.name,
|
||||||
|
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
def models(self) -> list:
|
||||||
|
from google import genai
|
||||||
|
|
||||||
|
client = genai.Client(api_key=self.get_api_key())
|
||||||
|
return [m.name for m in client.models.list()]
|
||||||
|
|
||||||
|
def check_health(self) -> dict:
|
||||||
|
try:
|
||||||
|
from google import genai
|
||||||
|
client = genai.Client(api_key=self.get_api_key())
|
||||||
|
client.models.list()
|
||||||
|
return {"status": "ok", "details": "API reachable"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Google health check failed: {e}")
|
||||||
|
return {"status": "error", "details": str(e)}
|
||||||
70
ai_router/providers/mistral_adapter.py
Normal file
70
ai_router/providers/mistral_adapter.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from .base import AIProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger("ai_router.mistral")
|
||||||
|
|
||||||
|
MISTRAL_API_BASE = "https://api.mistral.ai/v1"
|
||||||
|
|
||||||
|
|
||||||
|
class MistralAdapter(AIProvider):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "mistral"
|
||||||
|
|
||||||
|
def _headers(self, api_key: Optional[str] = None) -> dict:
|
||||||
|
key = api_key or self.get_api_key()
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
def chat(self, messages: list, model: str, api_key: Optional[str] = None, **kwargs) -> dict:
|
||||||
|
key = api_key or self.get_api_key()
|
||||||
|
headers = self._headers(key)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
for k in ("temperature", "max_tokens", "top_p", "stream"):
|
||||||
|
if k in kwargs:
|
||||||
|
payload[k] = kwargs[k]
|
||||||
|
|
||||||
|
resp = requests.post(
|
||||||
|
f"{MISTRAL_API_BASE}/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
choice = data["choices"][0]
|
||||||
|
return {
|
||||||
|
"content": choice["message"]["content"] or "",
|
||||||
|
"model": data.get("model", model),
|
||||||
|
"provider": self.name,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": data.get("usage", {}).get("prompt_tokens", 0),
|
||||||
|
"completion_tokens": data.get("usage", {}).get("completion_tokens", 0),
|
||||||
|
"total_tokens": data.get("usage", {}).get("total_tokens", 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def models(self) -> list:
|
||||||
|
resp = requests.get(f"{MISTRAL_API_BASE}/models", headers=self._headers(), timeout=30)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return [m["id"] for m in resp.json().get("data", [])]
|
||||||
|
|
||||||
|
def check_health(self) -> dict:
|
||||||
|
try:
|
||||||
|
resp = requests.get(f"{MISTRAL_API_BASE}/models", headers=self._headers(), timeout=10)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return {"status": "ok", "details": "API reachable"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Mistral health check failed: {e}")
|
||||||
|
return {"status": "error", "details": str(e)}
|
||||||
50
ai_router/providers/openai_adapter.py
Normal file
50
ai_router/providers/openai_adapter.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .base import AIProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger("ai_router.openai")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIAdapter(AIProvider):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "openai"
|
||||||
|
|
||||||
|
def chat(self, messages: list, model: str, api_key: Optional[str] = None, **kwargs) -> dict:
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
key = api_key or self.get_api_key()
|
||||||
|
client = OpenAI(api_key=key)
|
||||||
|
resp = client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
**{k: v for k, v in kwargs.items() if k in ("temperature", "max_tokens", "top_p", "stream")},
|
||||||
|
)
|
||||||
|
choice = resp.choices[0]
|
||||||
|
return {
|
||||||
|
"content": choice.message.content or "",
|
||||||
|
"model": resp.model,
|
||||||
|
"provider": self.name,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": resp.usage.prompt_tokens if resp.usage else 0,
|
||||||
|
"completion_tokens": resp.usage.completion_tokens if resp.usage else 0,
|
||||||
|
"total_tokens": resp.usage.total_tokens if resp.usage else 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def models(self) -> list:
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(api_key=self.get_api_key())
|
||||||
|
return [m.id for m in client.models.list()]
|
||||||
|
|
||||||
|
def check_health(self) -> dict:
|
||||||
|
try:
|
||||||
|
from openai import OpenAI
|
||||||
|
client = OpenAI(api_key=self.get_api_key())
|
||||||
|
client.models.list()
|
||||||
|
return {"status": "ok", "details": "API reachable"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"OpenAI health check failed: {e}")
|
||||||
|
return {"status": "error", "details": str(e)}
|
||||||
174
ai_router/router.py
Normal file
174
ai_router/router.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .providers import PROVIDER_MAP, AIProvider
|
||||||
|
from .models import get_providers_from_db, log_router_attempt
|
||||||
|
|
||||||
|
logger = logging.getLogger("ai_router.router")
|
||||||
|
|
||||||
|
DEFAULT_MODEL_MAP = {
|
||||||
|
"gpt-4o": {"provider": "openai", "real_model": "gpt-4o"},
|
||||||
|
"gpt-4o-mini": {"provider": "openai", "real_model": "gpt-4o-mini"},
|
||||||
|
"claude-3-opus": {"provider": "anthropic", "real_model": "claude-3-opus-20240229"},
|
||||||
|
"claude-3-sonnet": {"provider": "anthropic", "real_model": "claude-3-sonnet-20240229"},
|
||||||
|
"claude-3-haiku": {"provider": "anthropic", "real_model": "claude-3-haiku-20240307"},
|
||||||
|
"gemini-pro": {"provider": "google", "real_model": "gemini-1.5-pro"},
|
||||||
|
"gemini-flash": {"provider": "google", "real_model": "gemini-1.5-flash"},
|
||||||
|
"mistral-large": {"provider": "mistral", "real_model": "mistral-large-latest"},
|
||||||
|
"mistral-small": {"provider": "mistral", "real_model": "mistral-small-latest"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AIRouter:
|
||||||
|
def __init__(self):
|
||||||
|
self._provider_instances = {}
|
||||||
|
|
||||||
|
def get_provider(self, name: str) -> Optional[AIProvider]:
|
||||||
|
if name not in self._provider_instances:
|
||||||
|
cls = PROVIDER_MAP.get(name)
|
||||||
|
if not cls:
|
||||||
|
return None
|
||||||
|
self._provider_instances[name] = cls()
|
||||||
|
return self._provider_instances[name]
|
||||||
|
|
||||||
|
def _resolve_model(self, model_alias: str) -> Optional[dict]:
|
||||||
|
mapping = self._load_model_mappings()
|
||||||
|
return mapping.get(model_alias)
|
||||||
|
|
||||||
|
def _load_model_mappings(self) -> dict:
|
||||||
|
db_mappings = []
|
||||||
|
try:
|
||||||
|
db_mappings = get_providers_from_db()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not load model mappings from DB: {e}")
|
||||||
|
|
||||||
|
merged = dict(DEFAULT_MODEL_MAP)
|
||||||
|
for entry in db_mappings:
|
||||||
|
alias = entry.get("model_alias")
|
||||||
|
if alias:
|
||||||
|
merged[alias] = {
|
||||||
|
"provider": entry["provider_type"],
|
||||||
|
"real_model": entry.get("real_model_id", alias),
|
||||||
|
"cost_per_1k": entry.get("cost_per_1k_tokens", 0),
|
||||||
|
"db_config": entry,
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
|
||||||
|
def _get_prioritized_providers(self):
|
||||||
|
providers = []
|
||||||
|
try:
|
||||||
|
db_providers = get_providers_from_db()
|
||||||
|
seen_names = set()
|
||||||
|
for p in sorted(db_providers, key=lambda x: x.get("priority", 99)):
|
||||||
|
name = p["provider_type"]
|
||||||
|
if name not in seen_names:
|
||||||
|
seen_names.add(name)
|
||||||
|
providers.append((name, p))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not load provider priority from DB: {e}")
|
||||||
|
|
||||||
|
if not providers:
|
||||||
|
default_order = ["openai", "anthropic", "google", "mistral"]
|
||||||
|
providers = [(n, None) for n in default_order]
|
||||||
|
return providers
|
||||||
|
|
||||||
|
def chat(self, messages: list, model_alias: str, user_id: Optional[int] = None, **kwargs) -> dict:
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
model_info = self._resolve_model(model_alias)
|
||||||
|
if not model_info:
|
||||||
|
return {"error": f"Unknown model: {model_alias}", "status": "error"}
|
||||||
|
|
||||||
|
provider_order = self._get_prioritized_providers()
|
||||||
|
preferred_provider = model_info["provider"]
|
||||||
|
real_model = model_info["real_model"]
|
||||||
|
|
||||||
|
ordered = []
|
||||||
|
for name, db_config in provider_order:
|
||||||
|
if name == preferred_provider:
|
||||||
|
ordered.insert(0, (name, db_config))
|
||||||
|
else:
|
||||||
|
ordered.append((name, db_config))
|
||||||
|
|
||||||
|
if preferred_provider not in [p[0] for p in ordered]:
|
||||||
|
ordered.insert(0, (preferred_provider, None))
|
||||||
|
|
||||||
|
last_error = None
|
||||||
|
for attempt, (provider_name, db_config) in enumerate(ordered):
|
||||||
|
provider = self.get_provider(provider_name)
|
||||||
|
if not provider:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if attempt > 0:
|
||||||
|
backoff = min(2 ** (attempt - 1), 30)
|
||||||
|
logger.info(f"Failover to {provider_name} after {backoff}s backoff (attempt {attempt})")
|
||||||
|
time.sleep(backoff)
|
||||||
|
|
||||||
|
api_key = provider.get_api_key(db_config)
|
||||||
|
if not api_key:
|
||||||
|
logger.warning(f"No API key configured for {provider_name}, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
result = provider.chat(messages, model=real_model, api_key=api_key, **kwargs)
|
||||||
|
elapsed = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
log_router_attempt(
|
||||||
|
request_id=request_id,
|
||||||
|
user_id=user_id,
|
||||||
|
model_alias=model_alias,
|
||||||
|
provider_used=provider_name,
|
||||||
|
tokens_in=result.get("usage", {}).get("prompt_tokens", 0),
|
||||||
|
tokens_out=result.get("usage", {}).get("completion_tokens", 0),
|
||||||
|
duration_ms=elapsed,
|
||||||
|
status="success",
|
||||||
|
)
|
||||||
|
result["request_id"] = request_id
|
||||||
|
result["status"] = "success"
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
elapsed = int((time.time() - start_time) * 1000)
|
||||||
|
logger.warning(f"Provider {provider_name} failed: {e}")
|
||||||
|
log_router_attempt(
|
||||||
|
request_id=request_id,
|
||||||
|
user_id=user_id,
|
||||||
|
model_alias=model_alias,
|
||||||
|
provider_used=provider_name,
|
||||||
|
tokens_in=0,
|
||||||
|
tokens_out=0,
|
||||||
|
duration_ms=elapsed,
|
||||||
|
status="error",
|
||||||
|
error_message=last_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = int((time.time() - start_time) * 1000)
|
||||||
|
return {
|
||||||
|
"error": f"All providers failed. Last error: {last_error}",
|
||||||
|
"status": "error",
|
||||||
|
"request_id": request_id,
|
||||||
|
"duration_ms": elapsed,
|
||||||
|
}
|
||||||
|
|
||||||
|
def check_all_providers_health(self) -> dict:
|
||||||
|
results = {}
|
||||||
|
for name in PROVIDER_MAP:
|
||||||
|
provider = self.get_provider(name)
|
||||||
|
results[name] = provider.check_health()
|
||||||
|
return results
|
||||||
|
|
||||||
|
def list_available_models(self) -> list:
|
||||||
|
model_map = self._load_model_mappings()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"alias": alias,
|
||||||
|
"provider": info["provider"],
|
||||||
|
"real_model": info["real_model"],
|
||||||
|
"cost_per_1k_tokens": info.get("cost_per_1k", 0),
|
||||||
|
}
|
||||||
|
for alias, info in model_map.items()
|
||||||
|
]
|
||||||
93
ai_router/utils.py
Normal file
93
ai_router/utils.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from flask import request, jsonify
|
||||||
|
|
||||||
|
logger = logging.getLogger("ai_router")
|
||||||
|
|
||||||
|
TOKEN_BROKER_URL = os.environ.get(
|
||||||
|
"TOKEN_BROKER_URL", "http://localhost:8783"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_token_via_broker(token: str) -> dict:
|
||||||
|
"""Verify an API token via the token-broker /verify endpoint."""
|
||||||
|
import requests
|
||||||
|
try:
|
||||||
|
resp = requests.post(
|
||||||
|
f"{TOKEN_BROKER_URL}/api/v1/tokens/verify",
|
||||||
|
json={"token": token},
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
data = resp.json()
|
||||||
|
if data.get("valid"):
|
||||||
|
return data
|
||||||
|
return {}
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.warning(f"Token broker unreachable: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(f):
|
||||||
|
"""Decorator: validate Bearer or X-API-Key via token-broker."""
|
||||||
|
@wraps(f)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
api_key = request.headers.get("X-API-Key", "")
|
||||||
|
|
||||||
|
raw_token = ""
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
raw_token = auth_header.split(" ", 1)[1]
|
||||||
|
elif api_key:
|
||||||
|
raw_token = api_key
|
||||||
|
|
||||||
|
if not raw_token:
|
||||||
|
return jsonify({"error": "Authentication required"}), 401
|
||||||
|
|
||||||
|
payload = verify_token_via_broker(raw_token)
|
||||||
|
if not payload or not payload.get("valid"):
|
||||||
|
return jsonify({"error": "Invalid or expired token"}), 401
|
||||||
|
|
||||||
|
request.current_user = {
|
||||||
|
"user_id": payload.get("user_id"),
|
||||||
|
"token_id": payload.get("token_id"),
|
||||||
|
"scopes": payload.get("scopes", []),
|
||||||
|
}
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
def admin_required(f):
|
||||||
|
"""Decorator: require admin scope on the authenticated token."""
|
||||||
|
@wraps(f)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
api_key = request.headers.get("X-API-Key", "")
|
||||||
|
|
||||||
|
raw_token = ""
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
raw_token = auth_header.split(" ", 1)[1]
|
||||||
|
elif api_key:
|
||||||
|
raw_token = api_key
|
||||||
|
|
||||||
|
if not raw_token:
|
||||||
|
return jsonify({"error": "Authentication required"}), 401
|
||||||
|
|
||||||
|
payload = verify_token_via_broker(raw_token)
|
||||||
|
if not payload or not payload.get("valid"):
|
||||||
|
return jsonify({"error": "Invalid or expired token"}), 401
|
||||||
|
|
||||||
|
scopes = payload.get("scopes", [])
|
||||||
|
if "admin" not in scopes and "ai_router_admin" not in scopes:
|
||||||
|
return jsonify({"error": "Admin access required"}), 403
|
||||||
|
|
||||||
|
request.current_user = {
|
||||||
|
"user_id": payload.get("user_id"),
|
||||||
|
"token_id": payload.get("token_id"),
|
||||||
|
"scopes": scopes,
|
||||||
|
}
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return decorated
|
||||||
77
ai_router_api.py
Normal file
77
ai_router_api.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
AI Router API — Multi-provider LLM routing with failover
|
||||||
|
Port: 8783 | DB: SQLite ai_router.db
|
||||||
|
HRT-200 — AI Router (Multi-provider + failover)
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
GET /api/v1/ai/health — Check all providers health
|
||||||
|
GET /api/v1/ai/models — List available models
|
||||||
|
POST /api/v1/ai/chat — Chat completion with auto-failover
|
||||||
|
GET /api/v1/ai/admin/providers — List configured providers
|
||||||
|
POST /api/v1/ai/admin/providers — Upsert a provider
|
||||||
|
POST /api/v1/ai/admin/model-mappings— Upsert a model mapping
|
||||||
|
DELETE /api/v1/ai/admin/providers/:id — Remove a provider
|
||||||
|
GET /api/v1/ai/usage — Usage logs
|
||||||
|
GET /api/v1/ai/usage/summary — Aggregated usage stats
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import logging.handlers
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from flask import Flask, jsonify
|
||||||
|
from flask_cors import CORS
|
||||||
|
|
||||||
|
LOG_DIR = os.path.join(os.path.dirname(__file__), "ai_router", "logs")
|
||||||
|
os.makedirs(LOG_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s [%(levelname)s] ai-router: %(name)s: %(message)s",
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(sys.stdout),
|
||||||
|
logging.handlers.RotatingFileHandler(
|
||||||
|
os.path.join(LOG_DIR, "ai_router.log"),
|
||||||
|
maxBytes=5 * 1024 * 1024,
|
||||||
|
backupCount=3,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("ai_router")
|
||||||
|
|
||||||
|
PORT = int(os.environ.get("AI_ROUTER_PORT", "8783"))
|
||||||
|
|
||||||
|
|
||||||
|
def create_app():
|
||||||
|
app = Flask(__name__)
|
||||||
|
CORS(app)
|
||||||
|
|
||||||
|
from ai_router.api import register_ai_router
|
||||||
|
from ai_router.models import init_db
|
||||||
|
|
||||||
|
init_db()
|
||||||
|
register_ai_router(app)
|
||||||
|
|
||||||
|
@app.errorhandler(404)
|
||||||
|
def not_found(e):
|
||||||
|
return jsonify({"error": "not_found", "message": "Route not found"}), 404
|
||||||
|
|
||||||
|
@app.errorhandler(500)
|
||||||
|
def internal_error(e):
|
||||||
|
logger.error(f"Internal error: {e}")
|
||||||
|
return jsonify({"error": "internal_error", "message": "Internal server error"}), 500
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("AI Router API starting...")
|
||||||
|
logger.info(f"Port: {PORT}")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
|
debug = os.environ.get("FLASK_ENV", "production") == "development"
|
||||||
|
app.run(host="0.0.0.0", port=PORT, debug=debug)
|
||||||
@@ -41,6 +41,7 @@ from .routes.user_tokens import user_tokens_bp
|
|||||||
from .routes.history import history_bp
|
from .routes.history import history_bp
|
||||||
from .routes.org import org_bp
|
from .routes.org import org_bp
|
||||||
from .routes.ml_feedback import ml_feedback_bp
|
from .routes.ml_feedback import ml_feedback_bp
|
||||||
|
from .routes.admin import admin_bp
|
||||||
|
|
||||||
# Master blueprint that aggregates all sub-routes under /api/v1
|
# Master blueprint that aggregates all sub-routes under /api/v1
|
||||||
api_v1_bp = Blueprint("api_v1", __name__, url_prefix="/api/v1")
|
api_v1_bp = Blueprint("api_v1", __name__, url_prefix="/api/v1")
|
||||||
@@ -61,3 +62,4 @@ def register_api_v1(app):
|
|||||||
app.register_blueprint(history_bp)
|
app.register_blueprint(history_bp)
|
||||||
app.register_blueprint(org_bp)
|
app.register_blueprint(org_bp)
|
||||||
app.register_blueprint(ml_feedback_bp)
|
app.register_blueprint(ml_feedback_bp)
|
||||||
|
app.register_blueprint(admin_bp)
|
||||||
|
|||||||
587
api_v1/routes/admin.py
Normal file
587
api_v1/routes/admin.py
Normal file
@@ -0,0 +1,587 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Admin Blueprint — Client CRUD + Subscription management
|
||||||
|
HRT-199 — Foundation (Client CRUD + Auth + Subscription)
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
POST /api/v1/admin/setup — init first admin (no auth, 1 call only)
|
||||||
|
GET /api/v1/admin/clients — list all clients (paginated, filterable)
|
||||||
|
GET /api/v1/admin/clients/<id> — client detail + subscription
|
||||||
|
PUT /api/v1/admin/clients/<id> — update client (plan, name, email)
|
||||||
|
DELETE /api/v1/admin/clients/<id> — delete client + tokens + subscription
|
||||||
|
POST /api/v1/admin/clients/<id>/suspend — suspend client (set plan=suspended)
|
||||||
|
POST /api/v1/admin/clients/<id>/activate — reactivate client (restore plan)
|
||||||
|
GET /api/v1/admin/stats — client stats (total, by plan, new/30d)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from flask import Blueprint, jsonify, request
|
||||||
|
|
||||||
|
from saas_auth import require_auth
|
||||||
|
from api_v1.utils import get_db, paginate_query, get_pagination_params, not_found, bad_request, internal_error
|
||||||
|
|
||||||
|
logger = logging.getLogger("turf_saas.admin")
|
||||||
|
|
||||||
|
admin_bp = Blueprint("admin", __name__, url_prefix="/api/v1/admin")
|
||||||
|
|
||||||
|
DB_PATH = os.environ.get("TURF_SAAS_DB", "/home/h3r7/turf_saas/turf_saas.db")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_saas_db():
|
||||||
|
conn = sqlite3.connect(DB_PATH)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_admin_tables():
|
||||||
|
"""Idempotent: create admin_users table."""
|
||||||
|
conn = _get_saas_db()
|
||||||
|
conn.executescript("""
|
||||||
|
CREATE TABLE IF NOT EXISTS admin_users (
|
||||||
|
user_id TEXT PRIMARY KEY REFERENCES saas_users(id),
|
||||||
|
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||||
|
created_by TEXT
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
migrate_admin_tables()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("admin DB init warning: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_admin(user_id: str, db=None) -> bool:
|
||||||
|
if not user_id:
|
||||||
|
return False
|
||||||
|
close = False
|
||||||
|
if db is None:
|
||||||
|
db = _get_saas_db()
|
||||||
|
close = True
|
||||||
|
try:
|
||||||
|
row = db.execute(
|
||||||
|
"SELECT 1 FROM admin_users WHERE user_id = ?", (user_id,)
|
||||||
|
).fetchone()
|
||||||
|
return row is not None
|
||||||
|
finally:
|
||||||
|
if close:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def require_admin(f):
|
||||||
|
@wraps(f)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
user = getattr(request, "current_user", None)
|
||||||
|
if not user:
|
||||||
|
return jsonify({"error": "Non authentifié"}), 401
|
||||||
|
if not _is_admin(user["id"]):
|
||||||
|
return jsonify({"error": "Accès administrateur requis"}), 403
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
def _user_to_client(row) -> dict:
|
||||||
|
return {
|
||||||
|
"id": row["id"],
|
||||||
|
"email": row["email"],
|
||||||
|
"firstname": row.get("firstname", ""),
|
||||||
|
"lastname": row.get("lastname", ""),
|
||||||
|
"plan": row.get("plan", "free"),
|
||||||
|
"telegram_chat_id": row.get("telegram_chat_id"),
|
||||||
|
"alert_value_bets": bool(row.get("alert_value_bets", 1)),
|
||||||
|
"alert_top1": bool(row.get("alert_top1", 1)),
|
||||||
|
"alert_quinte_only": bool(row.get("alert_quinte_only", 0)),
|
||||||
|
"created_at": row.get("created_at"),
|
||||||
|
"updated_at": row.get("updated_at"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_subscription(db, user_id: str):
|
||||||
|
return db.execute(
|
||||||
|
"""SELECT * FROM saas_subscriptions
|
||||||
|
WHERE user_id = ? ORDER BY start_date DESC LIMIT 1""",
|
||||||
|
(user_id,),
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── POST /api/v1/admin/setup ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/setup", methods=["POST"])
|
||||||
|
def admin_setup():
|
||||||
|
"""Init first admin (no auth). Only works once — when admin_users is empty."""
|
||||||
|
data = request.get_json(silent=True) or {}
|
||||||
|
email = (data.get("email") or "").strip().lower()
|
||||||
|
if not email or "@" not in email:
|
||||||
|
return jsonify({"error": "Email valide requis"}), 400
|
||||||
|
|
||||||
|
db = _get_saas_db()
|
||||||
|
try:
|
||||||
|
existing = db.execute("SELECT 1 FROM admin_users LIMIT 1").fetchone()
|
||||||
|
if existing:
|
||||||
|
return jsonify({"error": "Admin déjà configuré"}), 409
|
||||||
|
|
||||||
|
user = db.execute(
|
||||||
|
"SELECT id, email FROM saas_users WHERE email = ?", (email,)
|
||||||
|
).fetchone()
|
||||||
|
if not user:
|
||||||
|
return jsonify({"error": "Utilisateur introuvable avec cet email"}), 404
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO admin_users (user_id, created_by) VALUES (?, 'setup')",
|
||||||
|
(user["id"],),
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
logger.info("Admin setup: user %s (%s) promoted to admin", user["id"], email)
|
||||||
|
return jsonify({"ok": True, "user_id": user["id"], "email": email}), 201
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
logger.error("admin_setup error: %s", e)
|
||||||
|
return jsonify({"error": "Erreur interne"}), 500
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── GET /api/v1/admin/clients ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/clients", methods=["GET"])
|
||||||
|
@require_auth
|
||||||
|
@require_admin
|
||||||
|
def list_clients():
|
||||||
|
"""List all clients with pagination and filters.
|
||||||
|
---
|
||||||
|
tags:
|
||||||
|
- Admin
|
||||||
|
security:
|
||||||
|
- Bearer: []
|
||||||
|
parameters:
|
||||||
|
- in: query
|
||||||
|
name: page
|
||||||
|
type: integer
|
||||||
|
- in: query
|
||||||
|
name: per_page
|
||||||
|
type: integer
|
||||||
|
- in: query
|
||||||
|
name: search
|
||||||
|
type: string
|
||||||
|
description: Search by email or name
|
||||||
|
- in: query
|
||||||
|
name: plan
|
||||||
|
type: string
|
||||||
|
description: Filter by plan (free, premium, pro, suspended)
|
||||||
|
- in: query
|
||||||
|
name: sort_by
|
||||||
|
type: string
|
||||||
|
enum: [created_at, email, plan, updated_at]
|
||||||
|
- in: query
|
||||||
|
name: sort_order
|
||||||
|
type: string
|
||||||
|
enum: [asc, desc]
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Paginated client list
|
||||||
|
403:
|
||||||
|
description: Admin access required
|
||||||
|
"""
|
||||||
|
page = request.args.get("page", 1, type=int)
|
||||||
|
per_page = request.args.get("per_page", 20, type=int)
|
||||||
|
search = request.args.get("search", "").strip()
|
||||||
|
plan_filter = request.args.get("plan", "").strip()
|
||||||
|
sort_by = request.args.get("sort_by", "created_at").strip()
|
||||||
|
sort_order = request.args.get("sort_order", "desc").strip()
|
||||||
|
|
||||||
|
if sort_by not in ("created_at", "email", "plan", "updated_at"):
|
||||||
|
sort_by = "created_at"
|
||||||
|
if sort_order not in ("asc", "desc"):
|
||||||
|
sort_order = "desc"
|
||||||
|
if per_page < 1 or per_page > 100:
|
||||||
|
per_page = 20
|
||||||
|
if page < 1:
|
||||||
|
page = 1
|
||||||
|
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
|
||||||
|
db = _get_saas_db()
|
||||||
|
try:
|
||||||
|
conditions = []
|
||||||
|
params = []
|
||||||
|
if search:
|
||||||
|
conditions.append("(email LIKE ? OR firstname LIKE ? OR lastname LIKE ?)")
|
||||||
|
p = f"%{search}%"
|
||||||
|
params.extend([p, p, p])
|
||||||
|
if plan_filter:
|
||||||
|
conditions.append("plan = ?")
|
||||||
|
params.append(plan_filter)
|
||||||
|
|
||||||
|
where = (" WHERE " + " AND ".join(conditions)) if conditions else ""
|
||||||
|
|
||||||
|
total = db.execute(
|
||||||
|
f"SELECT COUNT(*) FROM saas_users{where}", params
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
f"SELECT * FROM saas_users{where} ORDER BY {sort_by} {sort_order} LIMIT ? OFFSET ?",
|
||||||
|
params + [per_page, offset],
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for row in rows:
|
||||||
|
client = _user_to_client(row)
|
||||||
|
sub = _fetch_subscription(db, row["id"])
|
||||||
|
if sub:
|
||||||
|
client["subscription"] = {
|
||||||
|
"plan": sub["plan"],
|
||||||
|
"status": sub["status"],
|
||||||
|
"start_date": sub["start_date"],
|
||||||
|
"current_period_end": sub["current_period_end"],
|
||||||
|
"stripe_customer_id": sub["stripe_customer_id"],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
client["subscription"] = None
|
||||||
|
result.append(client)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"clients": result,
|
||||||
|
"pagination": {
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
"total": total,
|
||||||
|
"total_pages": (total + per_page - 1) // per_page,
|
||||||
|
},
|
||||||
|
}), 200
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("list_clients error: %s", e)
|
||||||
|
return jsonify({"error": "Erreur interne"}), 500
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── GET /api/v1/admin/clients/<id> ────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/clients/<string:client_id>", methods=["GET"])
|
||||||
|
@require_auth
|
||||||
|
@require_admin
|
||||||
|
def get_client(client_id: str):
|
||||||
|
"""Get client details with subscription info.
|
||||||
|
---
|
||||||
|
tags:
|
||||||
|
- Admin
|
||||||
|
security:
|
||||||
|
- Bearer: []
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
name: id
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Client details
|
||||||
|
404:
|
||||||
|
description: Client not found
|
||||||
|
"""
|
||||||
|
db = _get_saas_db()
|
||||||
|
try:
|
||||||
|
row = db.execute(
|
||||||
|
"SELECT * FROM saas_users WHERE id = ?", (client_id,)
|
||||||
|
).fetchone()
|
||||||
|
if not row:
|
||||||
|
return jsonify({"error": "Client introuvable"}), 404
|
||||||
|
|
||||||
|
client = _user_to_client(row)
|
||||||
|
sub = _fetch_subscription(db, client_id)
|
||||||
|
if sub:
|
||||||
|
client["subscription"] = {
|
||||||
|
"id": sub["id"],
|
||||||
|
"plan": sub["plan"],
|
||||||
|
"status": sub["status"],
|
||||||
|
"start_date": sub["start_date"],
|
||||||
|
"end_date": sub["end_date"],
|
||||||
|
"current_period_end": sub["current_period_end"],
|
||||||
|
"grace_period_end": sub["grace_period_end"],
|
||||||
|
"stripe_customer_id": sub["stripe_customer_id"],
|
||||||
|
"stripe_subscription_id": sub["stripe_subscription_id"],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
client["subscription"] = None
|
||||||
|
|
||||||
|
return jsonify({"client": client}), 200
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("get_client error: %s", e)
|
||||||
|
return jsonify({"error": "Erreur interne"}), 500
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── PUT /api/v1/admin/clients/<id> ────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/clients/<string:client_id>", methods=["PUT"])
|
||||||
|
@require_auth
|
||||||
|
@require_admin
|
||||||
|
def update_client(client_id: str):
|
||||||
|
"""Update client fields (plan, firstname, lastname, email).
|
||||||
|
---
|
||||||
|
tags:
|
||||||
|
- Admin
|
||||||
|
security:
|
||||||
|
- Bearer: []
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
name: id
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
firstname: { type: string }
|
||||||
|
lastname: { type: string }
|
||||||
|
email: { type: string }
|
||||||
|
plan: { type: string, enum: [free, premium, pro, suspended] }
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Client updated
|
||||||
|
400:
|
||||||
|
description: Invalid parameters
|
||||||
|
404:
|
||||||
|
description: Client not found
|
||||||
|
"""
|
||||||
|
data = request.get_json(silent=True) or {}
|
||||||
|
if not data:
|
||||||
|
return jsonify({"error": "Corps JSON requis"}), 400
|
||||||
|
|
||||||
|
db = _get_saas_db()
|
||||||
|
try:
|
||||||
|
existing = db.execute(
|
||||||
|
"SELECT id FROM saas_users WHERE id = ?", (client_id,)
|
||||||
|
).fetchone()
|
||||||
|
if not existing:
|
||||||
|
return jsonify({"error": "Client introuvable"}), 404
|
||||||
|
|
||||||
|
fields = {}
|
||||||
|
if "firstname" in data:
|
||||||
|
fields["firstname"] = data["firstname"].strip()
|
||||||
|
if "lastname" in data:
|
||||||
|
fields["lastname"] = data["lastname"].strip()
|
||||||
|
if "email" in data:
|
||||||
|
email = data["email"].strip().lower()
|
||||||
|
if "@" not in email:
|
||||||
|
return jsonify({"error": "Email invalide"}), 400
|
||||||
|
fields["email"] = email
|
||||||
|
if "plan" in data:
|
||||||
|
plan = data["plan"].strip().lower()
|
||||||
|
if plan not in ("free", "premium", "pro", "suspended"):
|
||||||
|
return jsonify({"error": "Plan invalide. free|premium|pro|suspended"}), 400
|
||||||
|
fields["plan"] = plan
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
return jsonify({"ok": True}), 200
|
||||||
|
|
||||||
|
set_clause = ", ".join(f"{k}=?" for k in fields)
|
||||||
|
values = list(fields.values()) + [datetime.now(timezone.utc).isoformat(), client_id]
|
||||||
|
db.execute(
|
||||||
|
f"UPDATE saas_users SET {set_clause}, updated_at=? WHERE id=?", values
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.info("Admin %s updated client %s: %s",
|
||||||
|
request.current_user["id"], client_id, fields)
|
||||||
|
return jsonify({"ok": True, "updated": list(fields.keys())}), 200
|
||||||
|
|
||||||
|
except sqlite3.IntegrityError:
|
||||||
|
return jsonify({"error": "Cet email est déjà utilisé"}), 409
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
logger.error("update_client error: %s", e)
|
||||||
|
return jsonify({"error": "Erreur interne"}), 500
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── DELETE /api/v1/admin/clients/<id> ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/clients/<string:client_id>", methods=["DELETE"])
|
||||||
|
@require_auth
|
||||||
|
@require_admin
|
||||||
|
def delete_client(client_id: str):
|
||||||
|
"""Delete client and all associated data (tokens, subscriptions).
|
||||||
|
---
|
||||||
|
tags:
|
||||||
|
- Admin
|
||||||
|
security:
|
||||||
|
- Bearer: []
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
name: id
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Client deleted
|
||||||
|
404:
|
||||||
|
description: Client not found
|
||||||
|
"""
|
||||||
|
admin_id = request.current_user["id"]
|
||||||
|
if client_id == admin_id:
|
||||||
|
return jsonify({"error": "Impossible de supprimer votre propre compte"}), 400
|
||||||
|
|
||||||
|
db = _get_saas_db()
|
||||||
|
try:
|
||||||
|
existing = db.execute(
|
||||||
|
"SELECT id FROM saas_users WHERE id = ?", (client_id,)
|
||||||
|
).fetchone()
|
||||||
|
if not existing:
|
||||||
|
return jsonify({"error": "Client introuvable"}), 404
|
||||||
|
|
||||||
|
db.execute("DELETE FROM saas_tokens WHERE user_id = ?", (client_id,))
|
||||||
|
db.execute("DELETE FROM saas_subscriptions WHERE user_id = ?", (client_id,))
|
||||||
|
db.execute("DELETE FROM admin_users WHERE user_id = ?", (client_id,))
|
||||||
|
db.execute("DELETE FROM saas_users WHERE id = ?", (client_id,))
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.info("Admin %s deleted client %s", admin_id, client_id)
|
||||||
|
return jsonify({"ok": True, "deleted_id": client_id}), 200
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
logger.error("delete_client error: %s", e)
|
||||||
|
return jsonify({"error": "Erreur interne"}), 500
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── POST /api/v1/admin/clients/<id>/suspend ───────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/clients/<string:client_id>/suspend", methods=["POST"])
|
||||||
|
@require_auth
|
||||||
|
@require_admin
|
||||||
|
def suspend_client(client_id: str):
|
||||||
|
"""Suspend a client by setting plan to 'suspended'.
|
||||||
|
---
|
||||||
|
tags:
|
||||||
|
- Admin
|
||||||
|
security:
|
||||||
|
- Bearer: []
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Client suspended
|
||||||
|
404:
|
||||||
|
description: Client not found
|
||||||
|
"""
|
||||||
|
return _set_client_plan(client_id, "suspended")
|
||||||
|
|
||||||
|
|
||||||
|
# ─── POST /api/v1/admin/clients/<id>/activate ──────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/clients/<string:client_id>/activate", methods=["POST"])
|
||||||
|
@require_auth
|
||||||
|
@require_admin
|
||||||
|
def activate_client(client_id: str):
|
||||||
|
"""Reactivate a suspended client to 'free' plan.
|
||||||
|
---
|
||||||
|
tags:
|
||||||
|
- Admin
|
||||||
|
security:
|
||||||
|
- Bearer: []
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Client activated
|
||||||
|
404:
|
||||||
|
description: Client not found
|
||||||
|
"""
|
||||||
|
return _set_client_plan(client_id, "free")
|
||||||
|
|
||||||
|
|
||||||
|
def _set_client_plan(client_id: str, plan: str):
|
||||||
|
db = _get_saas_db()
|
||||||
|
try:
|
||||||
|
existing = db.execute(
|
||||||
|
"SELECT id, plan FROM saas_users WHERE id = ?", (client_id,)
|
||||||
|
).fetchone()
|
||||||
|
if not existing:
|
||||||
|
return jsonify({"error": "Client introuvable"}), 404
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"UPDATE saas_users SET plan=?, updated_at=? WHERE id=?",
|
||||||
|
(plan, datetime.now(timezone.utc).isoformat(), client_id),
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
action = "suspendu" if plan == "suspended" else "réactivé"
|
||||||
|
logger.info("Client %s %s par admin %s", client_id, action,
|
||||||
|
request.current_user["id"])
|
||||||
|
return jsonify({"ok": True, "client_id": client_id, "plan": plan, "action": action}), 200
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
logger.error("_set_client_plan error: %s", e)
|
||||||
|
return jsonify({"error": "Erreur interne"}), 500
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── GET /api/v1/admin/stats ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/stats", methods=["GET"])
|
||||||
|
@require_auth
|
||||||
|
@require_admin
|
||||||
|
def admin_stats():
|
||||||
|
"""Client stats: totals by plan, new this month/30d.
|
||||||
|
---
|
||||||
|
tags:
|
||||||
|
- Admin
|
||||||
|
security:
|
||||||
|
- Bearer: []
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Admin stats
|
||||||
|
"""
|
||||||
|
db = _get_saas_db()
|
||||||
|
try:
|
||||||
|
total = db.execute("SELECT COUNT(*) FROM saas_users").fetchone()[0]
|
||||||
|
|
||||||
|
by_plan = {}
|
||||||
|
for row in db.execute(
|
||||||
|
"SELECT plan, COUNT(*) AS cnt FROM saas_users GROUP BY plan"
|
||||||
|
).fetchall():
|
||||||
|
by_plan[row["plan"]] = row["cnt"]
|
||||||
|
|
||||||
|
new_30d = db.execute(
|
||||||
|
"SELECT COUNT(*) FROM saas_users WHERE created_at >= datetime('now', '-30 days')"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
new_7d = db.execute(
|
||||||
|
"SELECT COUNT(*) FROM saas_users WHERE created_at >= datetime('now', '-7 days')"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
active_subs = db.execute(
|
||||||
|
"SELECT COUNT(DISTINCT user_id) FROM saas_subscriptions WHERE status = 'active'"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"total_clients": total,
|
||||||
|
"clients_by_plan": by_plan,
|
||||||
|
"new_last_30d": new_30d,
|
||||||
|
"new_last_7d": new_7d,
|
||||||
|
"active_subscriptions": active_subs,
|
||||||
|
}), 200
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("admin_stats error: %s", e)
|
||||||
|
return jsonify({"error": "Erreur interne"}), 500
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
@@ -21,10 +21,7 @@ from api_v1.utils import (
|
|||||||
paginate_query,
|
paginate_query,
|
||||||
)
|
)
|
||||||
# Auth: try flask_jwt_extended (app_v1) first, fall back to saas_auth (portal_server)
|
# Auth: try flask_jwt_extended (app_v1) first, fall back to saas_auth (portal_server)
|
||||||
try:
|
from saas_auth import require_auth as jwt_required_middleware
|
||||||
from auth import jwt_required_middleware
|
|
||||||
except ImportError:
|
|
||||||
from saas_auth import require_auth as jwt_required_middleware
|
|
||||||
|
|
||||||
history_bp = Blueprint("v1_history", __name__, url_prefix="/api/v1/history")
|
history_bp = Blueprint("v1_history", __name__, url_prefix="/api/v1/history")
|
||||||
|
|
||||||
|
|||||||
255
tests/test_ai_router.py
Normal file
255
tests/test_ai_router.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
"""Unit tests for AI Router — router, providers, models, API."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
TEST_DB = os.path.join(tempfile.mkdtemp(), "test_ai_router.db")
|
||||||
|
os.environ["AI_ROUTER_DB"] = TEST_DB
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-test-openai"
|
||||||
|
os.environ["ANTHROPIC_API_KEY"] = "sk-test-anthropic"
|
||||||
|
os.environ["GOOGLE_API_KEY"] = "sk-test-google"
|
||||||
|
os.environ["MISTRAL_API_KEY"] = "sk-test-mistral"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clean_db():
|
||||||
|
yield
|
||||||
|
try:
|
||||||
|
os.remove(TEST_DB)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app():
|
||||||
|
from ai_router_api import create_app
|
||||||
|
app = create_app()
|
||||||
|
app.config["TESTING"] = True
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app):
|
||||||
|
return app.test_client()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def router():
|
||||||
|
from ai_router.router import AIRouter
|
||||||
|
return AIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────
|
||||||
|
# Provider Base Tests
|
||||||
|
# ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderInterface:
|
||||||
|
def test_provider_map_has_all(self):
|
||||||
|
from ai_router.providers import PROVIDER_MAP
|
||||||
|
assert "openai" in PROVIDER_MAP
|
||||||
|
assert "anthropic" in PROVIDER_MAP
|
||||||
|
assert "google" in PROVIDER_MAP
|
||||||
|
assert "mistral" in PROVIDER_MAP
|
||||||
|
|
||||||
|
def test_api_key_resolution_env(self):
|
||||||
|
from ai_router.providers.base import AIProvider
|
||||||
|
class TestProvider(AIProvider):
|
||||||
|
@property
|
||||||
|
def name(self): return "openai"
|
||||||
|
def chat(self, messages, model, **kwargs): return {}
|
||||||
|
def models(self): return []
|
||||||
|
def check_health(self): return {"status": "ok"}
|
||||||
|
|
||||||
|
p = TestProvider()
|
||||||
|
assert p.get_api_key() == "sk-test-openai"
|
||||||
|
|
||||||
|
def test_api_key_resolution_db_overrides_env(self):
|
||||||
|
from ai_router.providers.base import AIProvider
|
||||||
|
class TestProvider(AIProvider):
|
||||||
|
@property
|
||||||
|
def name(self): return "openai"
|
||||||
|
def chat(self, messages, model, **kwargs): return {}
|
||||||
|
def models(self): return []
|
||||||
|
def check_health(self): return {"status": "ok"}
|
||||||
|
|
||||||
|
p = TestProvider()
|
||||||
|
assert p.get_api_key({"api_key": "sk-db-key"}) == "sk-db-key"
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────
|
||||||
|
# Router Tests
|
||||||
|
# ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRouter:
|
||||||
|
def test_resolve_known_model(self, router):
|
||||||
|
info = router._resolve_model("gpt-4o")
|
||||||
|
assert info is not None
|
||||||
|
assert info["provider"] == "openai"
|
||||||
|
assert info["real_model"] == "gpt-4o"
|
||||||
|
|
||||||
|
def test_resolve_unknown_model(self, router):
|
||||||
|
info = router._resolve_model("nonexistent-model")
|
||||||
|
assert info is None
|
||||||
|
|
||||||
|
def test_list_models_includes_defaults(self, router):
|
||||||
|
models = router.list_available_models()
|
||||||
|
aliases = [m["alias"] for m in models]
|
||||||
|
assert "gpt-4o" in aliases
|
||||||
|
assert "claude-3-opus" in aliases
|
||||||
|
assert "gemini-pro" in aliases
|
||||||
|
assert "mistral-large" in aliases
|
||||||
|
|
||||||
|
def test_prioritized_providers_default_order(self, router):
|
||||||
|
providers = router._get_prioritized_providers()
|
||||||
|
names = [p[0] for p in providers]
|
||||||
|
assert names == ["openai", "anthropic", "google", "mistral"]
|
||||||
|
|
||||||
|
def test_chat_unknown_model(self, router):
|
||||||
|
result = router.chat([{"role": "user", "content": "hi"}], "unknown-model")
|
||||||
|
assert result["status"] == "error"
|
||||||
|
|
||||||
|
@patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
|
||||||
|
def test_chat_success_first_provider(self, mock_chat, router):
|
||||||
|
mock_chat.return_value = {
|
||||||
|
"content": "Hello!",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"provider": "openai",
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||||
|
}
|
||||||
|
result = router.chat([{"role": "user", "content": "hi"}], "gpt-4o")
|
||||||
|
assert result["status"] == "success"
|
||||||
|
assert result["content"] == "Hello!"
|
||||||
|
assert result["provider"] == "openai"
|
||||||
|
|
||||||
|
@patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
|
||||||
|
@patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat")
|
||||||
|
def test_chat_failover_to_second_provider(self, mock_anthropic, mock_openai, router):
|
||||||
|
mock_openai.side_effect = Exception("OpenAI down")
|
||||||
|
mock_anthropic.return_value = {
|
||||||
|
"content": "Hello from Anthropic!",
|
||||||
|
"model": "claude-3-sonnet-20240229",
|
||||||
|
"provider": "anthropic",
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||||
|
}
|
||||||
|
result = router.chat([{"role": "user", "content": "hi"}], "claude-3-sonnet")
|
||||||
|
assert result["status"] == "success"
|
||||||
|
assert result["provider"] == "anthropic"
|
||||||
|
|
||||||
|
@patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
|
||||||
|
@patch("ai_router.providers.anthropic_adapter.AnthropicAdapter.chat")
|
||||||
|
@patch("ai_router.providers.google_adapter.GoogleAdapter.chat")
|
||||||
|
@patch("ai_router.providers.mistral_adapter.MistralAdapter.chat")
|
||||||
|
def test_chat_all_providers_fail(self, mock_mistral, mock_google, mock_anthropic, mock_openai, router):
|
||||||
|
for mock in (mock_openai, mock_anthropic, mock_google, mock_mistral):
|
||||||
|
mock.side_effect = Exception("Provider unavailable")
|
||||||
|
result = router.chat([{"role": "user", "content": "hi"}], "gpt-4o")
|
||||||
|
assert result["status"] == "error"
|
||||||
|
assert "All providers failed" in result["error"]
|
||||||
|
|
||||||
|
def test_health_all_providers_returns_dict(self, router):
|
||||||
|
health = router.check_all_providers_health()
|
||||||
|
for name in ("openai", "anthropic", "google", "mistral"):
|
||||||
|
assert name in health
|
||||||
|
assert "status" in health[name]
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────
|
||||||
|
# Database Tests
|
||||||
|
# ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestModels:
|
||||||
|
def test_init_db_creates_tables(self):
|
||||||
|
from ai_router.models import init_db, get_db
|
||||||
|
init_db()
|
||||||
|
conn = get_db()
|
||||||
|
tables = conn.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||||
|
).fetchall()
|
||||||
|
names = [r[0] for r in tables]
|
||||||
|
assert "ai_providers" in names
|
||||||
|
assert "ai_model_mapping" in names
|
||||||
|
assert "ai_router_log" in names
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def test_upsert_and_get_providers(self):
|
||||||
|
from ai_router.models import init_db, upsert_provider, get_providers_from_db
|
||||||
|
init_db()
|
||||||
|
upsert_provider("Test OpenAI", "openai", "sk-test", priority=1)
|
||||||
|
providers = get_providers_from_db()
|
||||||
|
assert len(providers) > 0
|
||||||
|
assert any(p["name"] == "Test OpenAI" for p in providers)
|
||||||
|
|
||||||
|
def test_log_router_attempt(self):
|
||||||
|
from ai_router.models import init_db, log_router_attempt, get_db
|
||||||
|
init_db()
|
||||||
|
log_router_attempt("req-1", 42, "gpt-4o", "openai", 10, 5, 200, "success")
|
||||||
|
conn = get_db()
|
||||||
|
rows = conn.execute("SELECT * FROM ai_router_log").fetchall()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0]["status"] == "success"
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────
|
||||||
|
# API Blueprint Tests
|
||||||
|
# ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPI:
|
||||||
|
def test_health_endpoint(self, client):
|
||||||
|
resp = client.get("/api/v1/ai/health")
|
||||||
|
assert resp.status_code in (200, 503)
|
||||||
|
data = resp.get_json()
|
||||||
|
assert data["service"] == "ai-router"
|
||||||
|
assert "providers" in data
|
||||||
|
|
||||||
|
def test_models_endpoint(self, client):
|
||||||
|
resp = client.get("/api/v1/ai/models")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.get_json()
|
||||||
|
assert "models" in data
|
||||||
|
assert len(data["models"]) > 0
|
||||||
|
|
||||||
|
def test_chat_no_auth(self, client):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/ai/chat",
|
||||||
|
json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_chat_no_messages(self, client):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/ai/chat",
|
||||||
|
json={"model": "gpt-4o"},
|
||||||
|
headers={"X-API-Key": "test-key"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
@patch("ai_router.utils.verify_token_via_broker")
|
||||||
|
@patch("ai_router.providers.openai_adapter.OpenAIAdapter.chat")
|
||||||
|
def test_chat_success_with_auth(self, mock_chat, mock_verify, client):
|
||||||
|
mock_verify.return_value = {"valid": True, "user_id": 1, "scopes": ["user"]}
|
||||||
|
mock_chat.return_value = {
|
||||||
|
"content": "Hello!",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"provider": "openai",
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||||
|
}
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/ai/chat",
|
||||||
|
json={"messages": [{"role": "user", "content": "hi"}], "model": "gpt-4o"},
|
||||||
|
headers={"Authorization": "Bearer test-token"},
|
||||||
|
)
|
||||||
|
data = resp.get_json()
|
||||||
|
assert resp.status_code == 200, f"Got {resp.status_code}: {data}"
|
||||||
|
assert data["status"] == "success"
|
||||||
|
|
||||||
|
def test_usage_requires_admin(self, client):
|
||||||
|
resp = client.get("/api/v1/ai/usage")
|
||||||
|
assert resp.status_code == 401
|
||||||
Reference in New Issue
Block a user