Compare commits
4 Commits
feature/ml
...
feature/ap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce0ee150ec | ||
|
|
b8ef1ed35d | ||
|
|
c8f1bfd478 | ||
|
|
5a23692ad1 |
132
API_AUTH.md
Normal file
132
API_AUTH.md
Normal file
@@ -0,0 +1,132 @@
|
||||
# API Auth JWT — Documentation
|
||||
## Sprint 2-3 (HRT-28)
|
||||
|
||||
Base URL: `http://localhost:8792`
|
||||
|
||||
---
|
||||
|
||||
## Endpoints d'authentification
|
||||
|
||||
### `POST /api/v1/auth/register`
|
||||
Inscription d'un nouvel utilisateur (plan free par défaut).
|
||||
|
||||
**Body JSON:**
|
||||
```json
|
||||
{ "email": "user@example.com", "password": "motdepasse123" }
|
||||
```
|
||||
**Réponse 201:**
|
||||
```json
|
||||
{ "message": "Compte créé avec succès", "user_id": 1 }
|
||||
```
|
||||
**Erreurs:** `400` (email invalide / mot de passe < 8 car.), `409` (email déjà utilisé)
|
||||
|
||||
---
|
||||
|
||||
### `POST /api/v1/auth/login`
|
||||
Connexion — retourne access_token (15min) + refresh_token (30j).
|
||||
|
||||
**Body JSON:**
|
||||
```json
|
||||
{ "email": "user@example.com", "password": "motdepasse123" }
|
||||
```
|
||||
**Réponse 200:**
|
||||
```json
|
||||
{
|
||||
"access_token": "<JWT>",
|
||||
"refresh_token": "<refresh_JWT>",
|
||||
"token_type": "Bearer",
|
||||
"plan": "free"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### `POST /api/v1/auth/refresh`
|
||||
Rotation du refresh token — invalide l'ancien, émet un nouveau.
|
||||
|
||||
**Body JSON:**
|
||||
```json
|
||||
{ "refresh_token": "<refresh_JWT>" }
|
||||
```
|
||||
**Réponse 200:** identique à `/login`
|
||||
|
||||
---
|
||||
|
||||
### `POST /api/v1/auth/logout`
|
||||
Révocation du refresh token.
|
||||
|
||||
**Body JSON:**
|
||||
```json
|
||||
{ "refresh_token": "<refresh_JWT>" }
|
||||
```
|
||||
**Réponse 200:**
|
||||
```json
|
||||
{ "message": "Déconnexion réussie" }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Routes protégées
|
||||
|
||||
Toutes les routes protégées nécessitent le header:
|
||||
```
|
||||
Authorization: Bearer <access_token>
|
||||
```
|
||||
|
||||
### `GET /api/v1/predictions`
|
||||
| Plan | Accès |
|
||||
|---------|---------------------------------------------|
|
||||
| free | Top 3 uniquement, 1 course/jour |
|
||||
| premium | Toutes les courses + alertes Telegram |
|
||||
| pro | API complète + lien export CSV |
|
||||
|
||||
### `GET /api/v1/predictions/export`
|
||||
Export CSV — **plan pro uniquement** (`403` pour free/premium).
|
||||
|
||||
### `GET /api/v1/subscription/upgrade`
|
||||
Infos sur les plans disponibles et plan courant de l'utilisateur.
|
||||
|
||||
### `GET /api/v1/health`
|
||||
Vérification d'état du service (pas d'auth requise).
|
||||
|
||||
---
|
||||
|
||||
## Sécurité
|
||||
|
||||
- **Passwords:** hashés avec bcrypt (saltRounds=12)
|
||||
- **JWT access:** expiration 15 minutes (HS256)
|
||||
- **JWT refresh:** expiration 30 jours, stocké hashé (SHA-256) en DB, rotation à chaque usage
|
||||
- **Rate limiting:** 100 requêtes/min par IP — header `X-RateLimit-Remaining`
|
||||
- **CORS:** configuré pour `https://turf-ia.h3r7.tech` + localhost dev
|
||||
- **Logs d'accès:** horodatés ISO 8601 dans `logs/saas_api.log`
|
||||
|
||||
---
|
||||
|
||||
## Lancement
|
||||
|
||||
```bash
|
||||
JWT_SECRET_KEY="votre_cle_secrete" \
|
||||
CORS_ORIGINS="https://turf-ia.h3r7.tech" \
|
||||
./venv/bin/python saas_api.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tests
|
||||
|
||||
```bash
|
||||
./venv/bin/pytest tests/test_auth.py -v
|
||||
# Avec couverture:
|
||||
./venv/bin/pytest tests/test_auth.py --cov=auth --cov=auth_db --cov=middleware --cov=saas_api --cov-report=term-missing
|
||||
# Résultat: 27 tests OK, couverture globale 83%
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Structure des tables DB
|
||||
|
||||
```sql
|
||||
-- users: id, email, password_hash, plan(free/premium/pro), created_at, is_active, daily_usage, last_usage_date
|
||||
-- subscriptions: id, user_id, plan, start_date, end_date, stripe_customer_id
|
||||
-- refresh_tokens: id, user_id, token_hash, created_at, expires_at, revoked
|
||||
```
|
||||
156
README_API_V1.md
Normal file
156
README_API_V1.md
Normal file
@@ -0,0 +1,156 @@
|
||||
# Turf SaaS — API v1 Reference
|
||||
|
||||
Sprint 3-4 · HRT-29 — Refacto API /v1/
|
||||
|
||||
## Base URL
|
||||
|
||||
```
|
||||
http://<host>:8792
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
All endpoints (except `/api/v1/health` and `/api/v1/auth/*`) require a **Bearer JWT** token.
|
||||
|
||||
```
|
||||
Authorization: Bearer <access_token>
|
||||
```
|
||||
|
||||
### Get a token
|
||||
|
||||
```bash
|
||||
# Register
|
||||
curl -X POST http://localhost:8792/api/v1/auth/register \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"email": "user@example.com", "password": "mypassword"}'
|
||||
|
||||
# Login → returns access_token + refresh_token
|
||||
curl -X POST http://localhost:8792/api/v1/auth/login \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"email": "user@example.com", "password": "mypassword"}'
|
||||
```
|
||||
|
||||
## Plans & Access Control
|
||||
|
||||
| Plan | Inclus |
|
||||
|-----------|----------------------------------------------------|
|
||||
| `free` | health, auth, courses/today, predictions/top3 (1/j)|
|
||||
| `premium` | + predictions/all, valuebets, metrics |
|
||||
| `pro` | + backtest, export/csv |
|
||||
|
||||
## Endpoints
|
||||
|
||||
### System
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|--------|------------------|------|----------------------|
|
||||
| GET | `/api/v1/health` | Non | Healthcheck public |
|
||||
| GET | `/api/v1/docs` | Non | Swagger UI |
|
||||
|
||||
### Auth
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|---------------------------|--------------------------------|
|
||||
| POST | `/api/v1/auth/register` | Créer un compte (plan=free) |
|
||||
| POST | `/api/v1/auth/login` | Login → JWT tokens |
|
||||
| POST | `/api/v1/auth/refresh` | Renouveler l'access token |
|
||||
| POST | `/api/v1/auth/logout` | Révoquer le refresh token |
|
||||
|
||||
### Courses
|
||||
|
||||
| Method | Path | Plan | Description |
|
||||
|--------|---------------------------------------|---------|------------------------------------|
|
||||
| GET | `/api/v1/courses/today` | free+ | Courses du jour (paginé) |
|
||||
| GET | `/api/v1/courses/{id}/predictions` | free+ | Prédictions ML pour une course |
|
||||
|
||||
Query params `courses/today`: `filter=[all|quinte|trot|plat]`, `limit`, `offset`
|
||||
|
||||
`{id}` format: `{num_reunion}-{num_course}` ex: `1-3`
|
||||
|
||||
### Prédictions
|
||||
|
||||
| Method | Path | Plan | Description |
|
||||
|--------|---------------------------|-----------|------------------------------|
|
||||
| GET | `/api/v1/predictions/top3`| free+ | Top 3 chevaux du jour |
|
||||
| GET | `/api/v1/predictions/all` | premium+ | Toutes les prédictions ML |
|
||||
|
||||
Query params: `date=YYYY-MM-DD`, `limit`, `offset`
|
||||
|
||||
### Value Bets
|
||||
|
||||
| Method | Path | Plan | Description |
|
||||
|--------|---------------------|-----------|--------------------------|
|
||||
| GET | `/api/v1/valuebets` | premium+ | Value bets du jour |
|
||||
|
||||
Query params: `date`, `min_odds` (défaut 2.0), `limit`, `offset`
|
||||
|
||||
### Backtest
|
||||
|
||||
| Method | Path | Plan | Description |
|
||||
|--------|---------------------|------|----------------------------------|
|
||||
| GET | `/api/v1/backtest` | pro | Résultats historiques des paris |
|
||||
|
||||
Query params: `start`, `end` (YYYY-MM-DD), `limit`, `offset`
|
||||
|
||||
### Export
|
||||
|
||||
| Method | Path | Plan | Description |
|
||||
|--------|-------------------------|------|----------------------|
|
||||
| GET | `/api/v1/export/csv` | pro | Export CSV |
|
||||
|
||||
Query params: `type=[predictions|bets]`, `date`, `start`, `end`
|
||||
|
||||
### Métriques
|
||||
|
||||
| Method | Path | Plan | Description |
|
||||
|--------|---------------------|----------|-----------------------|
|
||||
| GET | `/api/v1/metrics` | premium+ | Métriques ML et paris |
|
||||
|
||||
Query params: `days` (int, défaut 30)
|
||||
|
||||
## Réponse uniforme
|
||||
|
||||
Toutes les erreurs retournent :
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Description de l'erreur",
|
||||
"code": 400
|
||||
}
|
||||
```
|
||||
|
||||
Les listes paginées incluent :
|
||||
|
||||
```json
|
||||
{
|
||||
"pagination": {
|
||||
"total": 150,
|
||||
"limit": 20,
|
||||
"offset": 0,
|
||||
"has_more": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Démarrage
|
||||
|
||||
```bash
|
||||
cd /home/h3r7/turf_saas
|
||||
source venv/bin/activate
|
||||
python app_v1.py
|
||||
# ou
|
||||
gunicorn -w 2 -b 0.0.0.0:8792 app_v1:app
|
||||
```
|
||||
|
||||
## Tests
|
||||
|
||||
```bash
|
||||
cd /home/h3r7/turf_saas
|
||||
source venv/bin/activate
|
||||
python -m pytest tests/test_api_v1.py -v
|
||||
```
|
||||
|
||||
## Documentation Swagger
|
||||
|
||||
Accessible sur : `http://localhost:8792/api/v1/docs`
|
||||
43
api_v1/__init__.py
Normal file
43
api_v1/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
API v1 Blueprint package — Turf SaaS
|
||||
Sprint 3-4: HRT-29 — Refacto API /v1/
|
||||
Sprint 5-6: HRT-31 — Billing Stripe
|
||||
|
||||
Registers sub-blueprints:
|
||||
/api/v1/health — public health-check
|
||||
/api/v1/courses/ — courses du jour
|
||||
/api/v1/predictions/— predictions ML
|
||||
/api/v1/valuebets — value bets (premium+)
|
||||
/api/v1/backtest — backtest historique (pro)
|
||||
/api/v1/export/ — export CSV (pro)
|
||||
/api/v1/metrics — métriques perf ML (premium+)
|
||||
/api/v1/billing/ — Stripe checkout, portal, webhook, status
|
||||
/api/v1/docs — Swagger UI (via flasgger, registered on app)
|
||||
"""
|
||||
|
||||
from flask import Blueprint
|
||||
|
||||
from .routes.health import health_bp
|
||||
from .routes.courses import courses_bp
|
||||
from .routes.predictions import predictions_bp
|
||||
from .routes.valuebets import valuebets_bp
|
||||
from .routes.backtest import backtest_bp
|
||||
from .routes.export import export_bp
|
||||
from .routes.metrics import metrics_bp
|
||||
from .routes.billing import billing_bp
|
||||
|
||||
# Master blueprint that aggregates all sub-routes under /api/v1
|
||||
api_v1_bp = Blueprint("api_v1", __name__, url_prefix="/api/v1")
|
||||
|
||||
|
||||
def register_api_v1(app):
|
||||
"""Register all API v1 blueprints onto the Flask app."""
|
||||
app.register_blueprint(health_bp)
|
||||
app.register_blueprint(courses_bp)
|
||||
app.register_blueprint(predictions_bp)
|
||||
app.register_blueprint(valuebets_bp)
|
||||
app.register_blueprint(backtest_bp)
|
||||
app.register_blueprint(export_bp)
|
||||
app.register_blueprint(metrics_bp)
|
||||
app.register_blueprint(billing_bp)
|
||||
195
api_v1/routes/backtest.py
Normal file
195
api_v1/routes/backtest.py
Normal file
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Backtest route for API v1.
|
||||
|
||||
GET /api/v1/backtest — Résultats backtest historiques (pro)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from api_v1.utils import (
|
||||
get_db,
|
||||
table_exists,
|
||||
internal_error,
|
||||
bad_request,
|
||||
get_pagination_params,
|
||||
paginate_query,
|
||||
)
|
||||
from auth import jwt_required_middleware, plan_required
|
||||
|
||||
backtest_bp = Blueprint("v1_backtest", __name__, url_prefix="/api/v1")
|
||||
|
||||
|
||||
@backtest_bp.route("/backtest", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
@plan_required("pro")
|
||||
def backtest():
|
||||
"""
|
||||
Backtest historique
|
||||
---
|
||||
tags:
|
||||
- Backtest
|
||||
summary: Résultats backtest historiques des paris simulés — accès pro uniquement
|
||||
security:
|
||||
- Bearer: []
|
||||
parameters:
|
||||
- name: start
|
||||
in: query
|
||||
type: string
|
||||
format: date
|
||||
description: Date de début (YYYY-MM-DD), défaut = -30j
|
||||
- name: end
|
||||
in: query
|
||||
type: string
|
||||
format: date
|
||||
description: Date de fin (YYYY-MM-DD), défaut = aujourd'hui
|
||||
- name: limit
|
||||
in: query
|
||||
type: integer
|
||||
default: 50
|
||||
- name: offset
|
||||
in: query
|
||||
type: integer
|
||||
default: 0
|
||||
responses:
|
||||
200:
|
||||
description: Résultats backtest
|
||||
401:
|
||||
description: Token invalide
|
||||
403:
|
||||
description: Plan insuffisant (pro requis)
|
||||
"""
|
||||
start = request.args.get("start")
|
||||
end = request.args.get("end")
|
||||
|
||||
# Validate date formats
|
||||
for label, val in [("start", start), ("end", end)]:
|
||||
if val:
|
||||
try:
|
||||
datetime.strptime(val, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
return bad_request(
|
||||
f"Paramètre '{label}' invalide, format attendu: YYYY-MM-DD"
|
||||
)
|
||||
|
||||
if not start:
|
||||
start = (datetime.now() - timedelta(days=30)).strftime("%Y-%m-%d")
|
||||
if not end:
|
||||
end = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
limit, offset = get_pagination_params(default_limit=50, max_limit=200)
|
||||
|
||||
conn = get_db()
|
||||
try:
|
||||
if not table_exists(conn, "bet_results"):
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"period": {"start": start, "end": end},
|
||||
"summary": {
|
||||
"total_bets": 0,
|
||||
"message": "Aucune donnée bet_results",
|
||||
},
|
||||
"by_type": {},
|
||||
"details": [],
|
||||
"pagination": {
|
||||
"total": 0,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": False,
|
||||
},
|
||||
}
|
||||
), 200
|
||||
|
||||
# Summary
|
||||
summary_row = conn.execute(
|
||||
"""SELECT
|
||||
COUNT(*) AS total,
|
||||
SUM(CASE WHEN resultat='GAGNE' THEN 1 ELSE 0 END) AS gagne,
|
||||
SUM(mise) AS mise,
|
||||
SUM(gain) AS gain
|
||||
FROM bet_results
|
||||
WHERE date BETWEEN ? AND ?""",
|
||||
(start, end),
|
||||
).fetchone()
|
||||
|
||||
total_bets = summary_row["total"] or 0
|
||||
gagne = summary_row["gagne"] or 0
|
||||
mise = float(summary_row["mise"] or 0)
|
||||
gain = float(summary_row["gain"] or 0)
|
||||
roi = round((gain - mise) / mise * 100, 1) if mise > 0 else 0.0
|
||||
precision = round(gagne / total_bets * 100, 1) if total_bets > 0 else 0.0
|
||||
|
||||
# By type
|
||||
by_type_rows = conn.execute(
|
||||
"""SELECT
|
||||
type_pari,
|
||||
COUNT(*) AS total,
|
||||
SUM(CASE WHEN resultat='GAGNE' THEN 1 ELSE 0 END) AS gagne,
|
||||
SUM(mise) AS mise,
|
||||
SUM(gain) AS gain
|
||||
FROM bet_results
|
||||
WHERE date BETWEEN ? AND ?
|
||||
GROUP BY type_pari""",
|
||||
(start, end),
|
||||
).fetchall()
|
||||
|
||||
by_type = {}
|
||||
for row in by_type_rows:
|
||||
t = row["total"] or 0
|
||||
g = row["gagne"] or 0
|
||||
m = float(row["mise"] or 0)
|
||||
gn = float(row["gain"] or 0)
|
||||
by_type[row["type_pari"]] = {
|
||||
"count": t,
|
||||
"gagne": g,
|
||||
"mise": round(m, 2),
|
||||
"gain": round(gn, 2),
|
||||
"roi": round((gn - m) / m * 100, 1) if m > 0 else 0.0,
|
||||
"precision": round(g / t * 100, 1) if t > 0 else 0.0,
|
||||
}
|
||||
|
||||
# Paginated details
|
||||
count_row = conn.execute(
|
||||
"SELECT COUNT(*) AS cnt FROM bet_results WHERE date BETWEEN ? AND ?",
|
||||
(start, end),
|
||||
).fetchone()
|
||||
detail_total = count_row["cnt"] if count_row else 0
|
||||
|
||||
detail_rows = conn.execute(
|
||||
"""SELECT date, race_name, type_pari, horse_name, horse_number,
|
||||
COALESCE(cote, 0) AS cote, mise, resultat, gain
|
||||
FROM bet_results
|
||||
WHERE date BETWEEN ? AND ?
|
||||
ORDER BY date DESC, id DESC
|
||||
LIMIT ? OFFSET ?""",
|
||||
(start, end, limit, offset),
|
||||
).fetchall()
|
||||
|
||||
details = [dict(r) for r in detail_rows]
|
||||
pagination = paginate_query(details, detail_total, limit, offset)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"period": {"start": start, "end": end},
|
||||
"summary": {
|
||||
"total_bets": total_bets,
|
||||
"gagne": gagne,
|
||||
"perdu": total_bets - gagne,
|
||||
"precision": precision,
|
||||
"mise_totale": round(mise, 2),
|
||||
"gain_total": round(gain, 2),
|
||||
"roi": roi,
|
||||
},
|
||||
"by_type": by_type,
|
||||
"details": details,
|
||||
**pagination,
|
||||
}
|
||||
), 200
|
||||
|
||||
except Exception as e:
|
||||
return internal_error(str(e))
|
||||
finally:
|
||||
conn.close()
|
||||
664
api_v1/routes/billing.py
Normal file
664
api_v1/routes/billing.py
Normal file
@@ -0,0 +1,664 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Billing Blueprint — Stripe integration
|
||||
Sprint 5-6: HRT-31
|
||||
|
||||
Endpoints:
|
||||
POST /api/v1/billing/checkout — create Stripe Checkout session (auth required)
|
||||
POST /api/v1/billing/portal — create Stripe Customer Portal session (auth required)
|
||||
POST /api/v1/billing/webhook — Stripe webhook handler (public, signature-verified)
|
||||
GET /api/v1/billing/status — current subscription status (auth required)
|
||||
|
||||
Environment variables required:
|
||||
STRIPE_SECRET_KEY — Stripe secret key (sk_live_... or sk_test_...)
|
||||
STRIPE_PUBLISHABLE_KEY — Stripe publishable key (pk_...)
|
||||
STRIPE_WEBHOOK_SECRET — webhook signing secret (whsec_...)
|
||||
STRIPE_PRICE_PREMIUM — Stripe Price ID for Premium plan (price_...)
|
||||
STRIPE_PRICE_PRO — Stripe Price ID for Pro plan (price_...)
|
||||
APP_BASE_URL — e.g. https://turf-ia.h3r7.tech (default http://localhost:8793)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import stripe
|
||||
from flask import Blueprint, g, jsonify, request
|
||||
|
||||
from auth import jwt_required_middleware
|
||||
from billing_db import get_db, migrate_billing_tables
|
||||
|
||||
logger = logging.getLogger("turf_saas.billing")
|
||||
|
||||
billing_bp = Blueprint("billing", __name__, url_prefix="/api/v1/billing")
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Stripe configuration
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
stripe.api_key = os.environ.get("STRIPE_SECRET_KEY", "")
|
||||
STRIPE_WEBHOOK_SECRET = os.environ.get("STRIPE_WEBHOOK_SECRET", "")
|
||||
STRIPE_PUBLISHABLE_KEY = os.environ.get("STRIPE_PUBLISHABLE_KEY", "")
|
||||
APP_BASE_URL = os.environ.get("APP_BASE_URL", "http://localhost:8793")
|
||||
|
||||
# Plan → Stripe Price ID mapping
|
||||
PLAN_PRICE_IDS = {
|
||||
"premium": os.environ.get("STRIPE_PRICE_PREMIUM", ""),
|
||||
"pro": os.environ.get("STRIPE_PRICE_PRO", ""),
|
||||
}
|
||||
|
||||
# Plan display names
|
||||
PLAN_NAMES = {
|
||||
"free": "Free",
|
||||
"premium": "Premium",
|
||||
"pro": "Pro",
|
||||
}
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# DB helpers
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _sget(obj, key, default=None):
|
||||
"""Safely get a value from a dict OR a Stripe StripeObject.
|
||||
|
||||
Stripe v7+ uses attribute-style access; plain dicts use [] / .get().
|
||||
"""
|
||||
try:
|
||||
# StripeObject supports [] but not .get(); dict supports both
|
||||
val = obj[key]
|
||||
return val if val is not None else default
|
||||
except (KeyError, TypeError):
|
||||
return default
|
||||
|
||||
|
||||
def _get_active_subscription(db, user_id: int):
|
||||
"""Return the most recent active subscription row for a user."""
|
||||
return db.execute(
|
||||
"""SELECT * FROM subscriptions
|
||||
WHERE user_id = ?
|
||||
ORDER BY start_date DESC
|
||||
LIMIT 1""",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
|
||||
|
||||
def _upsert_subscription(db, user_id: int, **fields):
|
||||
"""
|
||||
Update existing subscription or insert a new one.
|
||||
fields: plan, stripe_customer_id, stripe_subscription_id,
|
||||
status, current_period_end, grace_period_end, end_date
|
||||
"""
|
||||
existing = _get_active_subscription(db, user_id)
|
||||
if existing:
|
||||
# Build SET clause dynamically from provided fields
|
||||
set_parts = ", ".join(f"{k} = ?" for k in fields)
|
||||
values = list(fields.values()) + [existing["id"]]
|
||||
db.execute(f"UPDATE subscriptions SET {set_parts} WHERE id = ?", values)
|
||||
else:
|
||||
cols = ", ".join(["user_id"] + list(fields.keys()))
|
||||
placeholders = ", ".join(["?"] * (1 + len(fields)))
|
||||
values = [user_id] + list(fields.values())
|
||||
db.execute(
|
||||
f"INSERT INTO subscriptions ({cols}) VALUES ({placeholders})", values
|
||||
)
|
||||
|
||||
|
||||
def _update_user_plan(db, user_id: int, plan: str):
|
||||
"""Sync users.plan field to match active subscription."""
|
||||
db.execute("UPDATE users SET plan = ? WHERE id = ?", (plan, user_id))
|
||||
|
||||
|
||||
def _get_or_create_stripe_customer(user, db) -> str:
|
||||
"""Return existing stripe_customer_id or create a new Stripe Customer."""
|
||||
sub = _get_active_subscription(db, user["id"])
|
||||
if sub and sub["stripe_customer_id"]:
|
||||
return sub["stripe_customer_id"]
|
||||
|
||||
# Create new customer in Stripe
|
||||
customer = stripe.Customer.create(
|
||||
email=user["email"],
|
||||
metadata={"user_id": str(user["id"])},
|
||||
)
|
||||
return customer["id"]
|
||||
|
||||
|
||||
def _record_billing_event(
|
||||
db, stripe_event_id: str, event_type: str, user_id=None, payload=None
|
||||
):
|
||||
"""Insert a billing_events audit row (idempotent on stripe_event_id)."""
|
||||
try:
|
||||
db.execute(
|
||||
"""INSERT OR IGNORE INTO billing_events
|
||||
(stripe_event_id, event_type, user_id, payload)
|
||||
VALUES (?, ?, ?, ?)""",
|
||||
(
|
||||
stripe_event_id,
|
||||
event_type,
|
||||
user_id,
|
||||
json.dumps(payload) if payload else None,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Could not record billing event %s: %s", stripe_event_id, e)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# POST /api/v1/billing/checkout
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@billing_bp.route("/checkout", methods=["POST"])
|
||||
@jwt_required_middleware
|
||||
def create_checkout():
|
||||
"""
|
||||
Create a Stripe Checkout session for upgrading to Premium or Pro.
|
||||
---
|
||||
tags:
|
||||
- Billing
|
||||
security:
|
||||
- Bearer: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required: [plan]
|
||||
properties:
|
||||
plan:
|
||||
type: string
|
||||
enum: [premium, pro]
|
||||
responses:
|
||||
200:
|
||||
description: Checkout session URL
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
checkout_url:
|
||||
type: string
|
||||
session_id:
|
||||
type: string
|
||||
400:
|
||||
description: Invalid plan or Stripe not configured
|
||||
503:
|
||||
description: Stripe API error
|
||||
"""
|
||||
if not stripe.api_key:
|
||||
return jsonify({"error": "Stripe non configuré"}), 503
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
plan = body.get("plan", "").lower()
|
||||
|
||||
if plan not in ("premium", "pro"):
|
||||
return jsonify({"error": "Plan invalide. Choisir 'premium' ou 'pro'"}), 400
|
||||
|
||||
price_id = PLAN_PRICE_IDS.get(plan)
|
||||
if not price_id:
|
||||
return jsonify({"error": f"Prix Stripe non configuré pour le plan {plan}"}), 503
|
||||
|
||||
user = g.current_user
|
||||
if user["plan"] == plan:
|
||||
return jsonify({"error": f"Vous êtes déjà sur le plan {plan}"}), 400
|
||||
|
||||
db = get_db()
|
||||
try:
|
||||
customer_id = _get_or_create_stripe_customer(user, db)
|
||||
# Persist customer_id early to prevent duplicates
|
||||
_upsert_subscription(
|
||||
db, user["id"], stripe_customer_id=customer_id, plan=user["plan"]
|
||||
)
|
||||
db.commit()
|
||||
|
||||
session = stripe.checkout.Session.create(
|
||||
customer=customer_id,
|
||||
payment_method_types=["card"],
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
mode="subscription",
|
||||
success_url=f"{APP_BASE_URL}/billing/success?session_id={{CHECKOUT_SESSION_ID}}",
|
||||
cancel_url=f"{APP_BASE_URL}/billing/cancel",
|
||||
metadata={"user_id": str(user["id"]), "plan": plan},
|
||||
subscription_data={"metadata": {"user_id": str(user["id"]), "plan": plan}},
|
||||
)
|
||||
except stripe.StripeError as e:
|
||||
logger.error("Stripe checkout error for user %s: %s", user["id"], e)
|
||||
return jsonify({"error": "Erreur Stripe", "detail": str(e)}), 503
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"checkout_url": session.url,
|
||||
"session_id": session.id,
|
||||
"plan": plan,
|
||||
"publishable_key": STRIPE_PUBLISHABLE_KEY,
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# POST /api/v1/billing/portal
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@billing_bp.route("/portal", methods=["POST"])
|
||||
@jwt_required_middleware
|
||||
def create_portal():
|
||||
"""
|
||||
Create a Stripe Customer Portal session for managing subscription.
|
||||
---
|
||||
tags:
|
||||
- Billing
|
||||
security:
|
||||
- Bearer: []
|
||||
responses:
|
||||
200:
|
||||
description: Portal session URL
|
||||
400:
|
||||
description: No Stripe customer found
|
||||
503:
|
||||
description: Stripe not configured or API error
|
||||
"""
|
||||
if not stripe.api_key:
|
||||
return jsonify({"error": "Stripe non configuré"}), 503
|
||||
|
||||
user = g.current_user
|
||||
db = get_db()
|
||||
try:
|
||||
sub = _get_active_subscription(db, user["id"])
|
||||
customer_id = sub["stripe_customer_id"] if sub else None
|
||||
|
||||
if not customer_id:
|
||||
return jsonify(
|
||||
{
|
||||
"error": "Aucun abonnement Stripe trouvé. "
|
||||
"Souscrivez d'abord à un plan payant."
|
||||
}
|
||||
), 400
|
||||
|
||||
session = stripe.billing_portal.Session.create(
|
||||
customer=customer_id,
|
||||
return_url=f"{APP_BASE_URL}/account",
|
||||
)
|
||||
except stripe.StripeError as e:
|
||||
logger.error("Stripe portal error for user %s: %s", user["id"], e)
|
||||
return jsonify({"error": "Erreur Stripe", "detail": str(e)}), 503
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return jsonify({"portal_url": session.url}), 200
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# GET /api/v1/billing/status
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@billing_bp.route("/status", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
def billing_status():
|
||||
"""
|
||||
Return current subscription status for the authenticated user.
|
||||
---
|
||||
tags:
|
||||
- Billing
|
||||
security:
|
||||
- Bearer: []
|
||||
responses:
|
||||
200:
|
||||
description: Subscription status
|
||||
"""
|
||||
user = g.current_user
|
||||
db = get_db()
|
||||
try:
|
||||
sub = _get_active_subscription(db, user["id"])
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if not sub:
|
||||
return jsonify(
|
||||
{
|
||||
"plan": "free",
|
||||
"status": "active",
|
||||
"stripe_customer_id": None,
|
||||
"stripe_subscription_id": None,
|
||||
"current_period_end": None,
|
||||
"grace_period_end": None,
|
||||
}
|
||||
), 200
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"plan": sub["plan"],
|
||||
"status": sub["status"] or "active",
|
||||
"stripe_customer_id": sub["stripe_customer_id"],
|
||||
"stripe_subscription_id": sub["stripe_subscription_id"],
|
||||
"start_date": sub["start_date"],
|
||||
"end_date": sub["end_date"],
|
||||
"current_period_end": sub["current_period_end"],
|
||||
"grace_period_end": sub["grace_period_end"],
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# POST /api/v1/billing/webhook
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@billing_bp.route("/webhook", methods=["POST"])
|
||||
def stripe_webhook():
|
||||
"""
|
||||
Stripe webhook handler — no auth, signature-verified.
|
||||
|
||||
Handled events:
|
||||
checkout.session.completed → activate subscription
|
||||
customer.subscription.updated → sync plan/status
|
||||
customer.subscription.deleted → downgrade to free
|
||||
invoice.payment_failed → set past_due + 3-day grace period
|
||||
invoice.payment_succeeded → clear grace period
|
||||
"""
|
||||
payload = request.get_data()
|
||||
sig_header = request.headers.get("Stripe-Signature", "")
|
||||
|
||||
# Verify webhook signature (required in production)
|
||||
if STRIPE_WEBHOOK_SECRET:
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except stripe.SignatureVerificationError as e:
|
||||
logger.warning("Stripe webhook signature invalid: %s", e)
|
||||
return jsonify({"error": "Signature invalide"}), 400
|
||||
except ValueError as e:
|
||||
logger.warning("Stripe webhook payload invalid: %s", e)
|
||||
return jsonify({"error": "Payload invalide"}), 400
|
||||
else:
|
||||
# Dev/test: accept without verification (log a warning)
|
||||
logger.warning("STRIPE_WEBHOOK_SECRET not set — skipping signature check!")
|
||||
try:
|
||||
event = stripe.Event.construct_from(json.loads(payload), stripe.api_key)
|
||||
except Exception as e:
|
||||
return jsonify({"error": "Payload invalide", "detail": str(e)}), 400
|
||||
|
||||
event_type = event["type"]
|
||||
event_id = event["id"]
|
||||
logger.info("Stripe webhook received: %s (%s)", event_type, event_id)
|
||||
|
||||
db = get_db()
|
||||
try:
|
||||
if event_type == "checkout.session.completed":
|
||||
_handle_checkout_completed(db, event)
|
||||
|
||||
elif event_type in (
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.created",
|
||||
):
|
||||
_handle_subscription_updated(db, event)
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
_handle_subscription_deleted(db, event)
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
_handle_payment_failed(db, event)
|
||||
|
||||
elif event_type == "invoice.payment_succeeded":
|
||||
_handle_payment_succeeded(db, event)
|
||||
|
||||
else:
|
||||
logger.debug("Unhandled Stripe event type: %s", event_type)
|
||||
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error("Error processing Stripe webhook %s: %s", event_id, e)
|
||||
return jsonify({"error": "Erreur interne"}), 500
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return jsonify({"status": "ok"}), 200
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Webhook handlers
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _resolve_user_from_customer(db, customer_id: str):
|
||||
"""Look up user_id via subscriptions.stripe_customer_id."""
|
||||
row = db.execute(
|
||||
"SELECT user_id FROM subscriptions WHERE stripe_customer_id = ? LIMIT 1",
|
||||
(customer_id,),
|
||||
).fetchone()
|
||||
if row:
|
||||
return row["user_id"]
|
||||
|
||||
# Fallback: query Stripe for user_id metadata
|
||||
try:
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
meta = _sget(customer, "metadata") or {}
|
||||
uid = _sget(meta, "user_id")
|
||||
if uid:
|
||||
return int(uid)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_plan_from_price(price_id: str) -> str:
|
||||
"""Map Stripe price ID to internal plan name."""
|
||||
for plan, pid in PLAN_PRICE_IDS.items():
|
||||
if pid and pid == price_id:
|
||||
return plan
|
||||
# Unknown price — default to premium (safer than pro)
|
||||
return "premium"
|
||||
|
||||
|
||||
def _handle_checkout_completed(db, event):
|
||||
"""checkout.session.completed → activate subscription for the user."""
|
||||
session = event["data"]["object"]
|
||||
customer_id = _sget(session, "customer")
|
||||
subscription_id = _sget(session, "subscription")
|
||||
metadata = _sget(session, "metadata") or {}
|
||||
plan = _sget(metadata, "plan") or "premium"
|
||||
user_id = _sget(metadata, "user_id")
|
||||
|
||||
if user_id:
|
||||
user_id = int(user_id)
|
||||
else:
|
||||
user_id = _resolve_user_from_customer(db, customer_id)
|
||||
|
||||
if not user_id:
|
||||
logger.error(
|
||||
"checkout.session.completed: cannot resolve user for customer %s",
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Fetch subscription details from Stripe
|
||||
current_period_end = None
|
||||
if subscription_id:
|
||||
try:
|
||||
sub = stripe.Subscription.retrieve(subscription_id)
|
||||
current_period_end = datetime.fromtimestamp(
|
||||
sub["current_period_end"], tz=timezone.utc
|
||||
).isoformat()
|
||||
# Sync plan from price if metadata plan is missing
|
||||
if sub["items"]["data"]:
|
||||
price_id = sub["items"]["data"][0]["price"]["id"]
|
||||
plan = _resolve_plan_from_price(price_id)
|
||||
except Exception as e:
|
||||
logger.warning("Could not fetch subscription %s: %s", subscription_id, e)
|
||||
|
||||
_upsert_subscription(
|
||||
db,
|
||||
user_id,
|
||||
plan=plan,
|
||||
stripe_customer_id=customer_id,
|
||||
stripe_subscription_id=subscription_id,
|
||||
status="active",
|
||||
current_period_end=current_period_end,
|
||||
grace_period_end=None,
|
||||
)
|
||||
_update_user_plan(db, user_id, plan)
|
||||
_record_billing_event(db, event["id"], event["type"], user_id=user_id)
|
||||
logger.info("checkout.session.completed: user %s upgraded to %s", user_id, plan)
|
||||
|
||||
|
||||
def _handle_subscription_updated(db, event):
|
||||
"""customer.subscription.updated → sync status and plan."""
|
||||
sub_obj = event["data"]["object"]
|
||||
customer_id = _sget(sub_obj, "customer")
|
||||
subscription_id = _sget(sub_obj, "id")
|
||||
stripe_status = _sget(sub_obj, "status") or "active"
|
||||
current_period_end = None
|
||||
|
||||
cpe = _sget(sub_obj, "current_period_end")
|
||||
if cpe:
|
||||
current_period_end = datetime.fromtimestamp(cpe, tz=timezone.utc).isoformat()
|
||||
|
||||
# Resolve plan from price
|
||||
plan = "premium"
|
||||
items_data = _sget(_sget(sub_obj, "items") or {}, "data")
|
||||
if items_data:
|
||||
price_id = items_data[0]["price"]["id"]
|
||||
plan = _resolve_plan_from_price(price_id)
|
||||
|
||||
user_id = _resolve_user_from_customer(db, customer_id)
|
||||
if not user_id:
|
||||
# Try metadata
|
||||
meta = _sget(sub_obj, "metadata") or {}
|
||||
meta_uid = _sget(meta, "user_id")
|
||||
if meta_uid:
|
||||
user_id = int(meta_uid)
|
||||
|
||||
if not user_id:
|
||||
logger.error(
|
||||
"subscription.updated: cannot resolve user for customer %s", customer_id
|
||||
)
|
||||
return
|
||||
|
||||
_upsert_subscription(
|
||||
db,
|
||||
user_id,
|
||||
plan=plan,
|
||||
stripe_customer_id=customer_id,
|
||||
stripe_subscription_id=subscription_id,
|
||||
status=stripe_status,
|
||||
current_period_end=current_period_end,
|
||||
)
|
||||
_update_user_plan(db, user_id, plan)
|
||||
_record_billing_event(db, event["id"], event["type"], user_id=user_id)
|
||||
logger.info(
|
||||
"subscription.updated: user %s plan=%s status=%s", user_id, plan, stripe_status
|
||||
)
|
||||
|
||||
|
||||
def _handle_subscription_deleted(db, event):
|
||||
"""customer.subscription.deleted → downgrade to free."""
|
||||
sub_obj = event["data"]["object"]
|
||||
customer_id = _sget(sub_obj, "customer")
|
||||
|
||||
user_id = _resolve_user_from_customer(db, customer_id)
|
||||
if not user_id:
|
||||
meta = _sget(sub_obj, "metadata") or {}
|
||||
meta_uid = _sget(meta, "user_id")
|
||||
if meta_uid:
|
||||
user_id = int(meta_uid)
|
||||
|
||||
if not user_id:
|
||||
logger.error(
|
||||
"subscription.deleted: cannot resolve user for customer %s", customer_id
|
||||
)
|
||||
return
|
||||
|
||||
_upsert_subscription(
|
||||
db,
|
||||
user_id,
|
||||
plan="free",
|
||||
stripe_subscription_id=None,
|
||||
status="canceled",
|
||||
end_date=datetime.now(timezone.utc).isoformat(),
|
||||
current_period_end=None,
|
||||
grace_period_end=None,
|
||||
)
|
||||
_update_user_plan(db, user_id, "free")
|
||||
_record_billing_event(db, event["id"], event["type"], user_id=user_id)
|
||||
logger.info("subscription.deleted: user %s downgraded to free", user_id)
|
||||
|
||||
|
||||
def _handle_payment_failed(db, event):
|
||||
"""invoice.payment_failed → mark past_due + 3-day grace period."""
|
||||
invoice = event["data"]["object"]
|
||||
customer_id = _sget(invoice, "customer")
|
||||
subscription_id = _sget(invoice, "subscription")
|
||||
|
||||
user_id = _resolve_user_from_customer(db, customer_id)
|
||||
if not user_id:
|
||||
logger.error(
|
||||
"invoice.payment_failed: cannot resolve user for customer %s", customer_id
|
||||
)
|
||||
return
|
||||
|
||||
grace_end = (datetime.now(timezone.utc) + timedelta(days=3)).isoformat()
|
||||
|
||||
_upsert_subscription(db, user_id, status="past_due", grace_period_end=grace_end)
|
||||
_record_billing_event(
|
||||
db,
|
||||
event["id"],
|
||||
event["type"],
|
||||
user_id=user_id,
|
||||
payload={"subscription_id": subscription_id},
|
||||
)
|
||||
|
||||
# TODO: send notification email via /api/notifications
|
||||
logger.warning(
|
||||
"invoice.payment_failed: user %s past_due, grace period until %s",
|
||||
user_id,
|
||||
grace_end,
|
||||
)
|
||||
|
||||
|
||||
def _handle_payment_succeeded(db, event):
|
||||
"""invoice.payment_succeeded → clear past_due / grace period."""
|
||||
invoice = event["data"]["object"]
|
||||
customer_id = _sget(invoice, "customer")
|
||||
|
||||
user_id = _resolve_user_from_customer(db, customer_id)
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
# Refresh subscription period end
|
||||
current_period_end = None
|
||||
lines = _sget(invoice, "lines") or {}
|
||||
lines_data = _sget(lines, "data") or []
|
||||
if lines_data:
|
||||
period = lines_data[0].get("period") or {}
|
||||
period_end = (
|
||||
period.get("end") if isinstance(period, dict) else _sget(period, "end")
|
||||
)
|
||||
if period_end:
|
||||
current_period_end = datetime.fromtimestamp(
|
||||
period_end, tz=timezone.utc
|
||||
).isoformat()
|
||||
|
||||
_upsert_subscription(
|
||||
db,
|
||||
user_id,
|
||||
status="active",
|
||||
grace_period_end=None,
|
||||
current_period_end=current_period_end,
|
||||
)
|
||||
_record_billing_event(db, event["id"], event["type"], user_id=user_id)
|
||||
logger.info("invoice.payment_succeeded: user %s payment cleared", user_id)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# On-import: ensure DB migration ran
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
try:
|
||||
migrate_billing_tables()
|
||||
except Exception as _e:
|
||||
logger.warning("billing_db migration skipped (test env?): %s", _e)
|
||||
277
api_v1/routes/courses.py
Normal file
277
api_v1/routes/courses.py
Normal file
@@ -0,0 +1,277 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Courses routes for API v1.
|
||||
|
||||
GET /api/v1/courses/today — liste des courses du jour (public, paginated)
|
||||
GET /api/v1/courses/{id}/predictions — prédictions ML pour une course (free tier, 1/day limit)
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from flask import Blueprint, jsonify, request, g
|
||||
|
||||
from api_v1.utils import (
|
||||
get_db,
|
||||
table_exists,
|
||||
error_response,
|
||||
bad_request,
|
||||
not_found,
|
||||
internal_error,
|
||||
get_pagination_params,
|
||||
paginate_query,
|
||||
)
|
||||
from auth import jwt_required_middleware, free_daily_limit_check
|
||||
|
||||
courses_bp = Blueprint("v1_courses", __name__, url_prefix="/api/v1/courses")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# GET /api/v1/courses/today
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@courses_bp.route("/today", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
def courses_today():
|
||||
"""
|
||||
Courses du jour
|
||||
---
|
||||
tags:
|
||||
- Courses
|
||||
summary: Liste toutes les courses du jour avec info course
|
||||
security:
|
||||
- Bearer: []
|
||||
parameters:
|
||||
- name: filter
|
||||
in: query
|
||||
type: string
|
||||
enum: [all, quinte, trot, plat]
|
||||
default: all
|
||||
description: Filtre par type de course
|
||||
- name: limit
|
||||
in: query
|
||||
type: integer
|
||||
default: 20
|
||||
- name: offset
|
||||
in: query
|
||||
type: integer
|
||||
default: 0
|
||||
responses:
|
||||
200:
|
||||
description: Liste des courses du jour
|
||||
401:
|
||||
description: Token manquant ou invalide
|
||||
"""
|
||||
race_filter = request.args.get("filter", "all").lower()
|
||||
limit, offset = get_pagination_params(default_limit=50, max_limit=200)
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Build SQL condition
|
||||
if race_filter == "quinte":
|
||||
cond = "AND (c.libelle LIKE '%Quinté%' OR c.libelle LIKE '%Quinte%')"
|
||||
elif race_filter == "trot":
|
||||
cond = "AND c.discipline LIKE '%Trot%'"
|
||||
elif race_filter == "plat":
|
||||
cond = "AND c.discipline LIKE '%Plat%'"
|
||||
else:
|
||||
cond = ""
|
||||
|
||||
conn = get_db()
|
||||
try:
|
||||
# Graceful handling if pmu_courses table doesn't exist yet
|
||||
if not table_exists(conn, "pmu_courses"):
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"date": today,
|
||||
"filter": race_filter,
|
||||
"courses": [],
|
||||
"pagination": {
|
||||
"total": 0,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": False,
|
||||
},
|
||||
}
|
||||
), 200
|
||||
|
||||
# Count total
|
||||
count_row = conn.execute(
|
||||
f"""SELECT COUNT(*) as cnt
|
||||
FROM pmu_courses c
|
||||
WHERE c.date_programme = ? {cond}""",
|
||||
(today,),
|
||||
).fetchone()
|
||||
total = count_row["cnt"] if count_row else 0
|
||||
|
||||
rows = conn.execute(
|
||||
f"""SELECT
|
||||
c.date_programme,
|
||||
c.num_reunion,
|
||||
c.num_course,
|
||||
c.libelle,
|
||||
c.discipline,
|
||||
c.distance,
|
||||
c.hippodrome,
|
||||
c.px_type,
|
||||
COUNT(p.id_cheval) as nb_partants
|
||||
FROM pmu_courses c
|
||||
LEFT JOIN pmu_partants p
|
||||
ON p.date_programme = c.date_programme
|
||||
AND p.num_reunion = c.num_reunion
|
||||
AND p.num_course = c.num_course
|
||||
WHERE c.date_programme = ? {cond}
|
||||
GROUP BY c.date_programme, c.num_reunion, c.num_course
|
||||
ORDER BY c.num_reunion ASC, c.num_course ASC
|
||||
LIMIT ? OFFSET ?""",
|
||||
(today, limit, offset),
|
||||
).fetchall()
|
||||
|
||||
courses = []
|
||||
for r in rows:
|
||||
course_id = f"{r['num_reunion']}-{r['num_course']}"
|
||||
courses.append(
|
||||
{
|
||||
"id": course_id,
|
||||
"date": r["date_programme"],
|
||||
"num_reunion": r["num_reunion"],
|
||||
"num_course": r["num_course"],
|
||||
"libelle": r["libelle"],
|
||||
"discipline": r["discipline"],
|
||||
"distance": r["distance"],
|
||||
"hippodrome": r["hippodrome"],
|
||||
"type_pari": r["px_type"],
|
||||
"nb_partants": r["nb_partants"],
|
||||
}
|
||||
)
|
||||
|
||||
pagination = paginate_query(courses, total, limit, offset)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"date": today,
|
||||
"filter": race_filter,
|
||||
"courses": courses,
|
||||
**pagination,
|
||||
}
|
||||
), 200
|
||||
|
||||
except Exception as e:
|
||||
return internal_error(str(e))
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# GET /api/v1/courses/<course_id>/predictions
|
||||
# course_id format: "{num_reunion}-{num_course}" e.g. "1-3"
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@courses_bp.route("/<course_id>/predictions", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
@free_daily_limit_check
|
||||
def course_predictions(course_id):
|
||||
"""
|
||||
Prédictions pour une course
|
||||
---
|
||||
tags:
|
||||
- Courses
|
||||
summary: Prédictions ML pour une course identifiée par {num_reunion}-{num_course}
|
||||
security:
|
||||
- Bearer: []
|
||||
parameters:
|
||||
- name: course_id
|
||||
in: path
|
||||
type: string
|
||||
required: true
|
||||
description: Identifiant de la course (format num_reunion-num_course, ex "1-3")
|
||||
- name: date
|
||||
in: query
|
||||
type: string
|
||||
format: date
|
||||
description: Date de la course (YYYY-MM-DD), défaut = aujourd'hui
|
||||
responses:
|
||||
200:
|
||||
description: Prédictions ML pour la course
|
||||
400:
|
||||
description: Paramètres invalides
|
||||
404:
|
||||
description: Course introuvable
|
||||
429:
|
||||
description: Limite quotidienne free tier atteinte
|
||||
"""
|
||||
# Parse course_id
|
||||
parts = course_id.split("-")
|
||||
if len(parts) != 2:
|
||||
return bad_request(
|
||||
"course_id doit être au format {num_reunion}-{num_course}, ex: 1-3"
|
||||
)
|
||||
|
||||
try:
|
||||
num_reunion = int(parts[0])
|
||||
num_course = int(parts[1])
|
||||
except ValueError:
|
||||
return bad_request("num_reunion et num_course doivent être des entiers")
|
||||
|
||||
date_param = request.args.get("date", datetime.now().strftime("%Y-%m-%d"))
|
||||
|
||||
conn = get_db()
|
||||
try:
|
||||
# Fetch course info
|
||||
course_row = conn.execute(
|
||||
"""SELECT libelle, discipline, distance, hippodrome, px_type
|
||||
FROM pmu_courses
|
||||
WHERE date_programme = ? AND num_reunion = ? AND num_course = ?""",
|
||||
(date_param, num_reunion, num_course),
|
||||
).fetchone()
|
||||
|
||||
if not course_row:
|
||||
return not_found(
|
||||
f"Course R{num_reunion}C{num_course} introuvable pour le {date_param}"
|
||||
)
|
||||
|
||||
# Fetch ML predictions from cache
|
||||
preds = []
|
||||
if table_exists(conn, "ml_predictions_cache"):
|
||||
preds = conn.execute(
|
||||
"""SELECT horse_name, horse_number, odds, prob_top1, prob_top3,
|
||||
ml_score, recommendation, is_value_bet, risque_label, risque_score
|
||||
FROM ml_predictions_cache
|
||||
WHERE date = ? AND num_reunion = ? AND num_course = ?
|
||||
ORDER BY ml_score DESC""",
|
||||
(date_param, num_reunion, num_course),
|
||||
).fetchall()
|
||||
|
||||
# Fetch partants
|
||||
partants = conn.execute(
|
||||
"""SELECT nom, num_pmu, cote_direct, cote_reference, tendance_cote, favoris,
|
||||
tx_victoire, tx_place, forme_recente, driver, entraineur, musique
|
||||
FROM pmu_partants
|
||||
WHERE date_programme = ? AND num_reunion = ? AND num_course = ?
|
||||
ORDER BY num_pmu ASC""",
|
||||
(date_param, num_reunion, num_course),
|
||||
).fetchall()
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"date": date_param,
|
||||
"course": {
|
||||
"id": course_id,
|
||||
"libelle": course_row["libelle"],
|
||||
"discipline": course_row["discipline"],
|
||||
"distance": course_row["distance"],
|
||||
"hippodrome": course_row["hippodrome"],
|
||||
"type_pari": course_row["px_type"],
|
||||
},
|
||||
"predictions": [dict(p) for p in preds],
|
||||
"partants": [dict(p) for p in partants],
|
||||
}
|
||||
), 200
|
||||
|
||||
except Exception as e:
|
||||
return internal_error(str(e))
|
||||
finally:
|
||||
conn.close()
|
||||
185
api_v1/routes/export.py
Normal file
185
api_v1/routes/export.py
Normal file
@@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Export route for API v1.
|
||||
|
||||
GET /api/v1/export/csv — Export CSV des prédictions ou paris (pro)
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime, timedelta
|
||||
from flask import Blueprint, jsonify, request, Response
|
||||
|
||||
from api_v1.utils import (
|
||||
get_db,
|
||||
table_exists,
|
||||
internal_error,
|
||||
bad_request,
|
||||
forbidden,
|
||||
)
|
||||
from auth import jwt_required_middleware, plan_required
|
||||
|
||||
export_bp = Blueprint("v1_export", __name__, url_prefix="/api/v1/export")
|
||||
|
||||
# Maximum rows exportable in one request
|
||||
EXPORT_MAX_ROWS = 5000
|
||||
|
||||
|
||||
@export_bp.route("/csv", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
@plan_required("pro")
|
||||
def export_csv():
|
||||
"""
|
||||
Export CSV
|
||||
---
|
||||
tags:
|
||||
- Export
|
||||
summary: Export CSV des prédictions ML ou des paris historiques — accès pro uniquement
|
||||
security:
|
||||
- Bearer: []
|
||||
parameters:
|
||||
- name: type
|
||||
in: query
|
||||
type: string
|
||||
enum: [predictions, bets]
|
||||
default: predictions
|
||||
description: Type de données à exporter
|
||||
- name: start
|
||||
in: query
|
||||
type: string
|
||||
format: date
|
||||
description: Date de début (YYYY-MM-DD)
|
||||
- name: end
|
||||
in: query
|
||||
type: string
|
||||
format: date
|
||||
description: Date de fin (YYYY-MM-DD)
|
||||
- name: date
|
||||
in: query
|
||||
type: string
|
||||
format: date
|
||||
description: Date unique (YYYY-MM-DD), ignoré si start/end fournis
|
||||
responses:
|
||||
200:
|
||||
description: Fichier CSV
|
||||
content:
|
||||
text/csv:
|
||||
schema:
|
||||
type: string
|
||||
400:
|
||||
description: Paramètre invalide
|
||||
401:
|
||||
description: Token invalide
|
||||
403:
|
||||
description: Plan insuffisant (pro requis)
|
||||
"""
|
||||
export_type = request.args.get("type", "predictions").lower()
|
||||
if export_type not in ("predictions", "bets"):
|
||||
return bad_request(
|
||||
"Paramètre 'type' invalide. Valeurs acceptées: predictions, bets"
|
||||
)
|
||||
|
||||
start = request.args.get("start")
|
||||
end = request.args.get("end")
|
||||
date = request.args.get("date", datetime.now().strftime("%Y-%m-%d"))
|
||||
|
||||
for label, val in [("start", start), ("end", end), ("date", date)]:
|
||||
if val:
|
||||
try:
|
||||
datetime.strptime(val, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
return bad_request(
|
||||
f"Paramètre '{label}' invalide, format attendu: YYYY-MM-DD"
|
||||
)
|
||||
|
||||
# Build date range
|
||||
if start and end:
|
||||
date_cond = "date BETWEEN ? AND ?"
|
||||
date_params = [start, end]
|
||||
elif start:
|
||||
date_cond = "date >= ?"
|
||||
date_params = [start]
|
||||
else:
|
||||
date_cond = "date = ?"
|
||||
date_params = [date]
|
||||
|
||||
conn = get_db()
|
||||
try:
|
||||
output = io.StringIO()
|
||||
|
||||
if export_type == "predictions":
|
||||
if not table_exists(conn, "ml_predictions_cache"):
|
||||
return bad_request("Table ml_predictions_cache introuvable")
|
||||
|
||||
rows = conn.execute(
|
||||
f"""SELECT date, race_label, hippodrome, discipline, distance, heure,
|
||||
horse_name, horse_number, odds, prob_top1, prob_top3,
|
||||
ml_score, recommendation, is_value_bet, risque_label
|
||||
FROM ml_predictions_cache
|
||||
WHERE {date_cond}
|
||||
ORDER BY date DESC, ml_score DESC
|
||||
LIMIT {EXPORT_MAX_ROWS}""",
|
||||
date_params,
|
||||
).fetchall()
|
||||
|
||||
fieldnames = [
|
||||
"date",
|
||||
"race_label",
|
||||
"hippodrome",
|
||||
"discipline",
|
||||
"distance",
|
||||
"heure",
|
||||
"horse_name",
|
||||
"horse_number",
|
||||
"odds",
|
||||
"prob_top1",
|
||||
"prob_top3",
|
||||
"ml_score",
|
||||
"recommendation",
|
||||
"is_value_bet",
|
||||
"risque_label",
|
||||
]
|
||||
|
||||
else: # bets
|
||||
if not table_exists(conn, "bet_results"):
|
||||
return bad_request("Table bet_results introuvable")
|
||||
|
||||
rows = conn.execute(
|
||||
f"""SELECT date, race_name, type_pari, horse_name, horse_number,
|
||||
COALESCE(cote, 0) AS cote, mise, resultat, gain
|
||||
FROM bet_results
|
||||
WHERE {date_cond}
|
||||
ORDER BY date DESC
|
||||
LIMIT {EXPORT_MAX_ROWS}""",
|
||||
date_params,
|
||||
).fetchall()
|
||||
|
||||
fieldnames = [
|
||||
"date",
|
||||
"race_name",
|
||||
"type_pari",
|
||||
"horse_name",
|
||||
"horse_number",
|
||||
"cote",
|
||||
"mise",
|
||||
"resultat",
|
||||
"gain",
|
||||
]
|
||||
|
||||
writer = csv.DictWriter(output, fieldnames=fieldnames, extrasaction="ignore")
|
||||
writer.writeheader()
|
||||
for row in rows:
|
||||
writer.writerow(dict(row))
|
||||
|
||||
filename = f"turf_{export_type}_{date_params[0]}.csv"
|
||||
return Response(
|
||||
output.getvalue(),
|
||||
status=200,
|
||||
mimetype="text/csv",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return internal_error(str(e))
|
||||
finally:
|
||||
conn.close()
|
||||
44
api_v1/routes/health.py
Normal file
44
api_v1/routes/health.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
GET /api/v1/health — public healthcheck endpoint.
|
||||
No authentication required.
|
||||
"""
|
||||
|
||||
from flask import Blueprint, jsonify
|
||||
from datetime import datetime, timezone
|
||||
|
||||
health_bp = Blueprint("v1_health", __name__, url_prefix="/api/v1")
|
||||
|
||||
|
||||
@health_bp.route("/health", methods=["GET"])
|
||||
def health():
|
||||
"""
|
||||
Health check
|
||||
---
|
||||
tags:
|
||||
- System
|
||||
summary: Public healthcheck — returns API status and timestamp
|
||||
responses:
|
||||
200:
|
||||
description: API is healthy
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
status:
|
||||
type: string
|
||||
example: ok
|
||||
version:
|
||||
type: string
|
||||
example: "1.0"
|
||||
timestamp:
|
||||
type: string
|
||||
format: date-time
|
||||
"""
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"version": "1.0",
|
||||
"api": "Turf SaaS API v1",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
), 200
|
||||
144
api_v1/routes/metrics.py
Normal file
144
api_v1/routes/metrics.py
Normal file
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Metrics route for API v1.
|
||||
|
||||
GET /api/v1/metrics — Métriques performances ML (premium+)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from api_v1.utils import (
|
||||
get_db,
|
||||
table_exists,
|
||||
internal_error,
|
||||
bad_request,
|
||||
)
|
||||
from auth import jwt_required_middleware, plan_required
|
||||
|
||||
metrics_bp = Blueprint("v1_metrics", __name__, url_prefix="/api/v1")
|
||||
|
||||
|
||||
@metrics_bp.route("/metrics", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
@plan_required("premium", "pro")
|
||||
def metrics():
|
||||
"""
|
||||
Métriques ML
|
||||
---
|
||||
tags:
|
||||
- Métriques
|
||||
summary: Métriques de performance du modèle ML (precision, ROI, top-3 rate) — premium+
|
||||
security:
|
||||
- Bearer: []
|
||||
parameters:
|
||||
- name: days
|
||||
in: query
|
||||
type: integer
|
||||
default: 30
|
||||
description: Nombre de jours à analyser (max 365)
|
||||
responses:
|
||||
200:
|
||||
description: Métriques de performance ML
|
||||
401:
|
||||
description: Token invalide
|
||||
403:
|
||||
description: Plan insuffisant (premium ou pro requis)
|
||||
"""
|
||||
try:
|
||||
days = int(request.args.get("days", 30))
|
||||
except (ValueError, TypeError):
|
||||
return bad_request("Paramètre 'days' doit être un entier")
|
||||
|
||||
days = max(1, min(days, 365))
|
||||
end_date = datetime.now().strftime("%Y-%m-%d")
|
||||
start_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
|
||||
conn = get_db()
|
||||
try:
|
||||
# ── Bet-level metrics from bet_results ──
|
||||
bet_metrics = {
|
||||
"available": False,
|
||||
"period": {"start": start_date, "end": end_date, "days": days},
|
||||
}
|
||||
ml_metrics = {"available": False}
|
||||
daily_stats = []
|
||||
|
||||
if table_exists(conn, "bet_results"):
|
||||
row = conn.execute(
|
||||
"""SELECT
|
||||
COUNT(*) AS total,
|
||||
SUM(CASE WHEN resultat='GAGNE' THEN 1 ELSE 0 END) AS gagne,
|
||||
SUM(mise) AS mise,
|
||||
SUM(gain) AS gain
|
||||
FROM bet_results
|
||||
WHERE date BETWEEN ? AND ?""",
|
||||
(start_date, end_date),
|
||||
).fetchone()
|
||||
|
||||
total = row["total"] or 0
|
||||
gagne = row["gagne"] or 0
|
||||
mise = float(row["mise"] or 0)
|
||||
gain = float(row["gain"] or 0)
|
||||
|
||||
bet_metrics = {
|
||||
"available": True,
|
||||
"period": {"start": start_date, "end": end_date, "days": days},
|
||||
"total_bets": total,
|
||||
"precision_pct": round(gagne / total * 100, 2) if total > 0 else 0.0,
|
||||
"roi_pct": round((gain - mise) / mise * 100, 2) if mise > 0 else 0.0,
|
||||
"mise_totale": round(mise, 2),
|
||||
"gain_total": round(gain, 2),
|
||||
}
|
||||
|
||||
# ── ML predictions cache metrics ──
|
||||
if table_exists(conn, "ml_predictions_cache"):
|
||||
cache_row = conn.execute(
|
||||
"""SELECT
|
||||
COUNT(*) AS total,
|
||||
SUM(is_value_bet) AS value_bets,
|
||||
AVG(prob_top1) AS avg_prob_top1,
|
||||
AVG(prob_top3) AS avg_prob_top3,
|
||||
AVG(ml_score) AS avg_ml_score
|
||||
FROM ml_predictions_cache
|
||||
WHERE date BETWEEN ? AND ?""",
|
||||
(start_date, end_date),
|
||||
).fetchone()
|
||||
|
||||
if cache_row and cache_row["total"]:
|
||||
ml_metrics = {
|
||||
"available": True,
|
||||
"total_predictions": cache_row["total"],
|
||||
"value_bets": cache_row["value_bets"] or 0,
|
||||
"avg_prob_top1": round(float(cache_row["avg_prob_top1"] or 0), 4),
|
||||
"avg_prob_top3": round(float(cache_row["avg_prob_top3"] or 0), 4),
|
||||
"avg_ml_score": round(float(cache_row["avg_ml_score"] or 0), 4),
|
||||
}
|
||||
|
||||
# ── Daily breakdown ──
|
||||
if table_exists(conn, "daily_stats"):
|
||||
daily_rows = conn.execute(
|
||||
"""SELECT date, total_bets, bets_gagne, precision_pct, roi_pct,
|
||||
mise_totale, gain_total
|
||||
FROM daily_stats
|
||||
WHERE date BETWEEN ? AND ?
|
||||
ORDER BY date DESC
|
||||
LIMIT 60""",
|
||||
(start_date, end_date),
|
||||
).fetchall()
|
||||
daily_stats = [dict(r) for r in daily_rows]
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"period": {"start": start_date, "end": end_date, "days": days},
|
||||
"bet_metrics": bet_metrics,
|
||||
"ml_metrics": ml_metrics,
|
||||
"daily": daily_stats,
|
||||
}
|
||||
), 200
|
||||
|
||||
except Exception as e:
|
||||
return internal_error(str(e))
|
||||
finally:
|
||||
conn.close()
|
||||
163
api_v1/routes/predictions.py
Normal file
163
api_v1/routes/predictions.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Predictions routes for API v1.
|
||||
|
||||
GET /api/v1/predictions/top3 — Top 3 global du jour (free tier, 1/day limit)
|
||||
GET /api/v1/predictions/all — Toutes prédictions (premium+)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from api_v1.utils import (
|
||||
get_db,
|
||||
table_exists,
|
||||
internal_error,
|
||||
not_found,
|
||||
get_pagination_params,
|
||||
paginate_query,
|
||||
)
|
||||
from auth import jwt_required_middleware, plan_required, free_daily_limit_check
|
||||
|
||||
predictions_bp = Blueprint("v1_predictions", __name__, url_prefix="/api/v1/predictions")
|
||||
|
||||
|
||||
def _fetch_ml_predictions(conn, date: str, limit: int = None, offset: int = 0):
|
||||
"""Shared helper — returns rows from ml_predictions_cache."""
|
||||
if not table_exists(conn, "ml_predictions_cache"):
|
||||
return [], 0
|
||||
|
||||
count_row = conn.execute(
|
||||
"SELECT COUNT(*) as cnt FROM ml_predictions_cache WHERE date = ?",
|
||||
(date,),
|
||||
).fetchone()
|
||||
total = count_row["cnt"] if count_row else 0
|
||||
|
||||
sql = """SELECT
|
||||
race_label, hippodrome, discipline, distance, heure,
|
||||
horse_name, horse_number, odds, prob_top1, prob_top3,
|
||||
ml_score, recommendation, is_value_bet, risque_label, risque_score
|
||||
FROM ml_predictions_cache
|
||||
WHERE date = ?
|
||||
ORDER BY ml_score DESC"""
|
||||
params = [date]
|
||||
|
||||
if limit is not None:
|
||||
sql += " LIMIT ? OFFSET ?"
|
||||
params += [limit, offset]
|
||||
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
return [dict(r) for r in rows], total
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# GET /api/v1/predictions/top3
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@predictions_bp.route("/top3", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
@free_daily_limit_check
|
||||
def predictions_top3():
|
||||
"""
|
||||
Top 3 prédictions du jour
|
||||
---
|
||||
tags:
|
||||
- Prédictions
|
||||
summary: Top 3 chevaux avec le meilleur score ML du jour (free tier inclus)
|
||||
security:
|
||||
- Bearer: []
|
||||
parameters:
|
||||
- name: date
|
||||
in: query
|
||||
type: string
|
||||
format: date
|
||||
description: Date au format YYYY-MM-DD (défaut aujourd'hui)
|
||||
responses:
|
||||
200:
|
||||
description: Top 3 prédictions ML du jour
|
||||
401:
|
||||
description: Token invalide
|
||||
429:
|
||||
description: Limite quotidienne free tier atteinte
|
||||
"""
|
||||
date_param = request.args.get("date", datetime.now().strftime("%Y-%m-%d"))
|
||||
|
||||
conn = get_db()
|
||||
try:
|
||||
predictions, _ = _fetch_ml_predictions(conn, date_param, limit=3, offset=0)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"date": date_param,
|
||||
"top3": predictions,
|
||||
}
|
||||
), 200
|
||||
except Exception as e:
|
||||
return internal_error(str(e))
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# GET /api/v1/predictions/all
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@predictions_bp.route("/all", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
@plan_required("premium", "pro")
|
||||
def predictions_all():
|
||||
"""
|
||||
Toutes les prédictions du jour
|
||||
---
|
||||
tags:
|
||||
- Prédictions
|
||||
summary: Toutes les prédictions ML du jour — accès premium et pro uniquement
|
||||
security:
|
||||
- Bearer: []
|
||||
parameters:
|
||||
- name: date
|
||||
in: query
|
||||
type: string
|
||||
format: date
|
||||
description: Date au format YYYY-MM-DD (défaut aujourd'hui)
|
||||
- name: limit
|
||||
in: query
|
||||
type: integer
|
||||
default: 20
|
||||
- name: offset
|
||||
in: query
|
||||
type: integer
|
||||
default: 0
|
||||
responses:
|
||||
200:
|
||||
description: Toutes les prédictions ML
|
||||
401:
|
||||
description: Token invalide
|
||||
403:
|
||||
description: Plan insuffisant (premium ou pro requis)
|
||||
"""
|
||||
date_param = request.args.get("date", datetime.now().strftime("%Y-%m-%d"))
|
||||
limit, offset = get_pagination_params(default_limit=50, max_limit=500)
|
||||
|
||||
conn = get_db()
|
||||
try:
|
||||
predictions, total = _fetch_ml_predictions(
|
||||
conn, date_param, limit=limit, offset=offset
|
||||
)
|
||||
pagination = paginate_query(predictions, total, limit, offset)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"date": date_param,
|
||||
"predictions": predictions,
|
||||
**pagination,
|
||||
}
|
||||
), 200
|
||||
except Exception as e:
|
||||
return internal_error(str(e))
|
||||
finally:
|
||||
conn.close()
|
||||
111
api_v1/routes/valuebets.py
Normal file
111
api_v1/routes/valuebets.py
Normal file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Value bets route for API v1.
|
||||
|
||||
GET /api/v1/valuebets — Value bets du jour (premium+)
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from api_v1.utils import (
|
||||
get_db,
|
||||
table_exists,
|
||||
internal_error,
|
||||
get_pagination_params,
|
||||
paginate_query,
|
||||
)
|
||||
from auth import jwt_required_middleware, plan_required
|
||||
|
||||
valuebets_bp = Blueprint("v1_valuebets", __name__, url_prefix="/api/v1")
|
||||
|
||||
|
||||
@valuebets_bp.route("/valuebets", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
@plan_required("premium", "pro")
|
||||
def valuebets():
|
||||
"""
|
||||
Value bets du jour
|
||||
---
|
||||
tags:
|
||||
- Value Bets
|
||||
summary: Value bets du jour — chevaux à cote surévaluée par le marché (premium+)
|
||||
security:
|
||||
- Bearer: []
|
||||
parameters:
|
||||
- name: date
|
||||
in: query
|
||||
type: string
|
||||
format: date
|
||||
description: Date YYYY-MM-DD (défaut aujourd'hui)
|
||||
- name: min_odds
|
||||
in: query
|
||||
type: number
|
||||
default: 2.0
|
||||
description: Cote minimale pour filtrer les value bets
|
||||
- name: limit
|
||||
in: query
|
||||
type: integer
|
||||
default: 20
|
||||
- name: offset
|
||||
in: query
|
||||
type: integer
|
||||
default: 0
|
||||
responses:
|
||||
200:
|
||||
description: Value bets du jour
|
||||
401:
|
||||
description: Token invalide
|
||||
403:
|
||||
description: Plan insuffisant (premium ou pro requis)
|
||||
"""
|
||||
date_param = request.args.get("date", datetime.now().strftime("%Y-%m-%d"))
|
||||
limit, offset = get_pagination_params(default_limit=20, max_limit=100)
|
||||
|
||||
try:
|
||||
min_odds = float(request.args.get("min_odds", 2.0))
|
||||
except (ValueError, TypeError):
|
||||
min_odds = 2.0
|
||||
|
||||
conn = get_db()
|
||||
try:
|
||||
rows = []
|
||||
total = 0
|
||||
|
||||
if table_exists(conn, "ml_predictions_cache"):
|
||||
count_row = conn.execute(
|
||||
"""SELECT COUNT(*) as cnt
|
||||
FROM ml_predictions_cache
|
||||
WHERE date = ? AND is_value_bet = 1 AND odds >= ?""",
|
||||
(date_param, min_odds),
|
||||
).fetchone()
|
||||
total = count_row["cnt"] if count_row else 0
|
||||
|
||||
rows = conn.execute(
|
||||
"""SELECT race_label, hippodrome, discipline, distance, heure,
|
||||
horse_name, horse_number, odds, prob_top1, prob_top3,
|
||||
ml_score, recommendation, risque_label, risque_score
|
||||
FROM ml_predictions_cache
|
||||
WHERE date = ? AND is_value_bet = 1 AND odds >= ?
|
||||
ORDER BY ml_score DESC
|
||||
LIMIT ? OFFSET ?""",
|
||||
(date_param, min_odds, limit, offset),
|
||||
).fetchall()
|
||||
|
||||
valuebets_list = [dict(r) for r in rows]
|
||||
pagination = paginate_query(valuebets_list, total, limit, offset)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"status": "ok",
|
||||
"date": date_param,
|
||||
"min_odds": min_odds,
|
||||
"valuebets": valuebets_list,
|
||||
**pagination,
|
||||
}
|
||||
), 200
|
||||
|
||||
except Exception as e:
|
||||
return internal_error(str(e))
|
||||
finally:
|
||||
conn.close()
|
||||
98
api_v1/utils.py
Normal file
98
api_v1/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Shared utilities for API v1 — error helpers, pagination, DB access.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import os
|
||||
from flask import jsonify, request
|
||||
|
||||
DB_PATH = os.environ.get("TURF_SAAS_DB", "/home/h3r7/turf_saas/turf_saas.db")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Database
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Return a SQLite connection with Row factory."""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
def table_exists(conn, table_name: str) -> bool:
|
||||
row = conn.execute(
|
||||
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", (table_name,)
|
||||
).fetchone()
|
||||
return row is not None
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Uniform error responses
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def error_response(message: str, code: int, status: str = "error"):
|
||||
"""Return a JSON error envelope consistent with the API contract.
|
||||
|
||||
Shape: {"status": "error", "message": "...", "code": 400}
|
||||
"""
|
||||
return jsonify({"status": status, "message": message, "code": code}), code
|
||||
|
||||
|
||||
def not_found(message: str = "Resource not found"):
|
||||
return error_response(message, 404)
|
||||
|
||||
|
||||
def bad_request(message: str = "Bad request"):
|
||||
return error_response(message, 400)
|
||||
|
||||
|
||||
def forbidden(message: str = "Forbidden", required_plans=None, current_plan=None):
|
||||
payload = {"status": "error", "message": message, "code": 403}
|
||||
if required_plans:
|
||||
payload["required_plans"] = required_plans
|
||||
if current_plan:
|
||||
payload["current_plan"] = current_plan
|
||||
payload["upgrade_url"] = "/api/v1/subscription/upgrade"
|
||||
return jsonify(payload), 403
|
||||
|
||||
|
||||
def internal_error(message: str = "Internal server error"):
|
||||
return error_response(message, 500)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Pagination helpers
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_pagination_params(default_limit: int = 20, max_limit: int = 100):
|
||||
"""Extract and validate limit/offset from query-string."""
|
||||
try:
|
||||
limit = int(request.args.get("limit", default_limit))
|
||||
except (ValueError, TypeError):
|
||||
limit = default_limit
|
||||
|
||||
try:
|
||||
offset = int(request.args.get("offset", 0))
|
||||
except (ValueError, TypeError):
|
||||
offset = 0
|
||||
|
||||
limit = max(1, min(limit, max_limit))
|
||||
offset = max(0, offset)
|
||||
return limit, offset
|
||||
|
||||
|
||||
def paginate_query(rows, total: int, limit: int, offset: int):
|
||||
"""Wrap a list of rows in a pagination envelope."""
|
||||
return {
|
||||
"pagination": {
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": (offset + limit) < total,
|
||||
}
|
||||
}
|
||||
138
app_v1.py
Normal file
138
app_v1.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
app_v1.py — Turf SaaS Flask application with versioned API /v1/
|
||||
|
||||
This module creates the Flask app, registers:
|
||||
- Auth JWT (from Sprint 2-3)
|
||||
- API v1 blueprints
|
||||
- Swagger/OpenAPI documentation at /api/v1/docs
|
||||
|
||||
Usage:
|
||||
python app_v1.py
|
||||
# or via gunicorn:
|
||||
gunicorn -w 2 -b 0.0.0.0:8792 app_v1:app
|
||||
|
||||
Sprint 3-4: HRT-29 — Refacto API /v1/
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from flask import Flask, jsonify
|
||||
from flask_cors import CORS
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flasgger import Swagger
|
||||
|
||||
from auth_db import init_auth_tables
|
||||
from auth import auth_bp
|
||||
from api_v1 import register_api_v1
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("turf_saas.app_v1")
|
||||
|
||||
|
||||
def create_app() -> Flask:
|
||||
"""Application factory."""
|
||||
app = Flask(__name__)
|
||||
|
||||
# ── CORS ──
|
||||
CORS(app, origins=["*"], methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
|
||||
# ── JWT config ──
|
||||
app.config["JWT_SECRET_KEY"] = os.environ.get(
|
||||
"JWT_SECRET_KEY", "change-me-in-production-use-strong-random-secret"
|
||||
)
|
||||
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
|
||||
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)
|
||||
JWTManager(app)
|
||||
|
||||
# ── Swagger / OpenAPI ──
|
||||
swagger_config = {
|
||||
"headers": [],
|
||||
"specs": [
|
||||
{
|
||||
"endpoint": "apispec_v1",
|
||||
"route": "/api/v1/apispec.json",
|
||||
"rule_filter": lambda rule: str(rule).startswith("/api/v1"),
|
||||
"model_filter": lambda tag: True,
|
||||
}
|
||||
],
|
||||
"static_url_path": "/flasgger_static",
|
||||
"swagger_ui": True,
|
||||
"specs_route": "/api/v1/docs",
|
||||
}
|
||||
|
||||
swagger_template = {
|
||||
"swagger": "2.0",
|
||||
"info": {
|
||||
"title": "Turf SaaS API",
|
||||
"description": (
|
||||
"API v1 — Prédictions turf IA, value bets, backtest & métriques.\n\n"
|
||||
"**Plans:** `free` | `premium` | `pro`\n\n"
|
||||
"**Auth:** Bearer JWT — obtenir un token via `POST /api/v1/auth/login`"
|
||||
),
|
||||
"version": "1.0.0",
|
||||
"contact": {"name": "H3R7 Tech"},
|
||||
},
|
||||
"basePath": "/",
|
||||
"schemes": ["http", "https"],
|
||||
"securityDefinitions": {
|
||||
"Bearer": {
|
||||
"type": "apiKey",
|
||||
"name": "Authorization",
|
||||
"in": "header",
|
||||
"description": "Entrer: **Bearer <token>**",
|
||||
}
|
||||
},
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
}
|
||||
|
||||
Swagger(app, config=swagger_config, template=swagger_template)
|
||||
|
||||
# ── Auth DB init ──
|
||||
with app.app_context():
|
||||
try:
|
||||
init_auth_tables()
|
||||
except Exception as e:
|
||||
logger.warning("init_auth_tables warning: %s", e)
|
||||
|
||||
# ── Register auth blueprint ──
|
||||
app.register_blueprint(auth_bp)
|
||||
|
||||
# ── Register API v1 blueprints ──
|
||||
register_api_v1(app)
|
||||
|
||||
# ── Global error handlers ──
|
||||
@app.errorhandler(404)
|
||||
def not_found_handler(e):
|
||||
return jsonify(
|
||||
{"status": "error", "message": "Route introuvable", "code": 404}
|
||||
), 404
|
||||
|
||||
@app.errorhandler(405)
|
||||
def method_not_allowed_handler(e):
|
||||
return jsonify(
|
||||
{"status": "error", "message": "Méthode non autorisée", "code": 405}
|
||||
), 405
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error_handler(e):
|
||||
logger.exception("Unhandled 500 error")
|
||||
return jsonify(
|
||||
{"status": "error", "message": "Erreur serveur interne", "code": 500}
|
||||
), 500
|
||||
|
||||
logger.info("Turf SaaS API v1 ready — docs at /api/v1/docs")
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(os.environ.get("PORT", 8792))
|
||||
app.run(host="0.0.0.0", port=port, debug=False)
|
||||
362
auth.py
Normal file
362
auth.py
Normal file
@@ -0,0 +1,362 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auth Blueprint — JWT authentication + multi-tenant plan enforcement
|
||||
Sprint 2-3: HRT-28
|
||||
|
||||
Endpoints:
|
||||
POST /api/v1/auth/register — email/password registration
|
||||
POST /api/v1/auth/login — returns access_token (15min) + refresh_token (30d)
|
||||
POST /api/v1/auth/refresh — rotate refresh token, issue new access_token
|
||||
POST /api/v1/auth/logout — revoke refresh token
|
||||
|
||||
Middleware exposed:
|
||||
jwt_required_middleware() — decorator: valid access JWT required
|
||||
plan_required(plans) — decorator: user plan must be in given list
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import secrets
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import wraps
|
||||
|
||||
import bcrypt
|
||||
from flask import Blueprint, request, jsonify, g, current_app
|
||||
from flask_jwt_extended import (
|
||||
JWTManager,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
get_jwt_identity,
|
||||
verify_jwt_in_request,
|
||||
)
|
||||
from flask_jwt_extended.exceptions import JWTExtendedException
|
||||
from jwt.exceptions import PyJWTError
|
||||
|
||||
from auth_db import get_db
|
||||
|
||||
logger = logging.getLogger("turf_saas.auth")
|
||||
|
||||
auth_bp = Blueprint("auth", __name__, url_prefix="/api/v1/auth")
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Helpers
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _hash_token(raw_token: str) -> str:
|
||||
"""SHA-256 hash of a token string for secure DB storage."""
|
||||
return hashlib.sha256(raw_token.encode()).hexdigest()
|
||||
|
||||
|
||||
def _get_user_by_email(email: str):
|
||||
db = get_db()
|
||||
user = db.execute(
|
||||
"SELECT * FROM users WHERE email = ? AND is_active = 1", (email.lower(),)
|
||||
).fetchone()
|
||||
db.close()
|
||||
return user
|
||||
|
||||
|
||||
def _get_user_by_id(user_id: int):
|
||||
db = get_db()
|
||||
user = db.execute(
|
||||
"SELECT * FROM users WHERE id = ? AND is_active = 1", (user_id,)
|
||||
).fetchone()
|
||||
db.close()
|
||||
return user
|
||||
|
||||
|
||||
def _store_refresh_token(user_id: int, raw_token: str, expires_at: datetime):
|
||||
token_hash = _hash_token(raw_token)
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"INSERT INTO refresh_tokens (user_id, token_hash, expires_at) VALUES (?,?,?)",
|
||||
(user_id, token_hash, expires_at.isoformat()),
|
||||
)
|
||||
db.commit()
|
||||
db.close()
|
||||
|
||||
|
||||
def _revoke_refresh_token(raw_token: str):
|
||||
token_hash = _hash_token(raw_token)
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"UPDATE refresh_tokens SET revoked = 1 WHERE token_hash = ?", (token_hash,)
|
||||
)
|
||||
db.commit()
|
||||
db.close()
|
||||
|
||||
|
||||
def _is_refresh_token_valid(raw_token: str, user_id: int) -> bool:
|
||||
token_hash = _hash_token(raw_token)
|
||||
db = get_db()
|
||||
row = db.execute(
|
||||
"""SELECT id FROM refresh_tokens
|
||||
WHERE token_hash = ? AND user_id = ? AND revoked = 0
|
||||
AND expires_at > datetime('now')""",
|
||||
(token_hash, user_id),
|
||||
).fetchone()
|
||||
db.close()
|
||||
return row is not None
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Auth endpoints
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@auth_bp.route("/register", methods=["POST"])
|
||||
def register():
|
||||
"""POST /api/v1/auth/register — create a new user account (plan=free)."""
|
||||
data = request.get_json(silent=True) or {}
|
||||
email = (data.get("email") or "").strip().lower()
|
||||
password = data.get("password") or ""
|
||||
|
||||
if not email or "@" not in email:
|
||||
return jsonify({"error": "Email invalide"}), 400
|
||||
if len(password) < 8:
|
||||
return jsonify({"error": "Mot de passe trop court (min 8 caractères)"}), 400
|
||||
|
||||
# Check uniqueness
|
||||
existing = _get_user_by_email(email)
|
||||
if existing:
|
||||
return jsonify({"error": "Email déjà enregistré"}), 409
|
||||
|
||||
password_hash = bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
db = get_db()
|
||||
try:
|
||||
cursor = db.execute(
|
||||
"INSERT INTO users (email, password_hash, plan) VALUES (?,?,?)",
|
||||
(email, password_hash, "free"),
|
||||
)
|
||||
user_id = cursor.lastrowid
|
||||
# Create initial subscription record
|
||||
db.execute(
|
||||
"INSERT INTO subscriptions (user_id, plan) VALUES (?,?)",
|
||||
(user_id, "free"),
|
||||
)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error("register error: %s", e)
|
||||
return jsonify({"error": "Erreur interne"}), 500
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info("New user registered: %s (id=%s)", email, user_id)
|
||||
return jsonify({"message": "Compte créé avec succès", "user_id": user_id}), 201
|
||||
|
||||
|
||||
@auth_bp.route("/login", methods=["POST"])
|
||||
def login():
|
||||
"""POST /api/v1/auth/login — returns JWT access_token + refresh_token."""
|
||||
data = request.get_json(silent=True) or {}
|
||||
email = (data.get("email") or "").strip().lower()
|
||||
password = data.get("password") or ""
|
||||
|
||||
if not email or not password:
|
||||
return jsonify({"error": "Email et mot de passe requis"}), 400
|
||||
|
||||
user = _get_user_by_email(email)
|
||||
if not user:
|
||||
return jsonify({"error": "Identifiants invalides"}), 401
|
||||
|
||||
if not bcrypt.checkpw(password.encode(), user["password_hash"].encode()):
|
||||
logger.warning("Failed login attempt for %s", email)
|
||||
return jsonify({"error": "Identifiants invalides"}), 401
|
||||
|
||||
# Create tokens
|
||||
identity = str(user["id"])
|
||||
additional_claims = {"plan": user["plan"], "email": user["email"]}
|
||||
|
||||
access_token = create_access_token(
|
||||
identity=identity,
|
||||
additional_claims=additional_claims,
|
||||
)
|
||||
raw_refresh = create_refresh_token(identity=identity)
|
||||
|
||||
refresh_expires = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
_store_refresh_token(user["id"], raw_refresh, refresh_expires)
|
||||
|
||||
logger.info("User %s logged in (plan=%s)", email, user["plan"])
|
||||
return jsonify(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"refresh_token": raw_refresh,
|
||||
"token_type": "Bearer",
|
||||
"plan": user["plan"],
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@auth_bp.route("/refresh", methods=["POST"])
|
||||
def refresh():
|
||||
"""POST /api/v1/auth/refresh — rotate refresh token, issue new access_token."""
|
||||
data = request.get_json(silent=True) or {}
|
||||
raw_refresh = (data.get("refresh_token") or "").strip()
|
||||
|
||||
if not raw_refresh:
|
||||
return jsonify({"error": "refresh_token manquant"}), 400
|
||||
|
||||
# Decode without verifying in DB first (to get user_id)
|
||||
try:
|
||||
decoded = decode_token(raw_refresh)
|
||||
except Exception:
|
||||
return jsonify({"error": "Refresh token invalide ou expiré"}), 401
|
||||
|
||||
user_id = int(decoded.get("sub", 0))
|
||||
|
||||
if not _is_refresh_token_valid(raw_refresh, user_id):
|
||||
return jsonify({"error": "Refresh token invalide, révoqué ou expiré"}), 401
|
||||
|
||||
user = _get_user_by_id(user_id)
|
||||
if not user:
|
||||
return jsonify({"error": "Utilisateur introuvable"}), 401
|
||||
|
||||
# Revoke old refresh token (rotation)
|
||||
_revoke_refresh_token(raw_refresh)
|
||||
|
||||
# Issue new tokens
|
||||
identity = str(user["id"])
|
||||
additional_claims = {"plan": user["plan"], "email": user["email"]}
|
||||
new_access = create_access_token(
|
||||
identity=identity, additional_claims=additional_claims
|
||||
)
|
||||
new_refresh = create_refresh_token(identity=identity)
|
||||
|
||||
refresh_expires = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
_store_refresh_token(user["id"], new_refresh, refresh_expires)
|
||||
|
||||
logger.info("Token refreshed for user_id=%s", user_id)
|
||||
return jsonify(
|
||||
{
|
||||
"access_token": new_access,
|
||||
"refresh_token": new_refresh,
|
||||
"token_type": "Bearer",
|
||||
"plan": user["plan"],
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@auth_bp.route("/logout", methods=["POST"])
|
||||
def logout():
|
||||
"""POST /api/v1/auth/logout — revoke refresh token."""
|
||||
data = request.get_json(silent=True) or {}
|
||||
raw_refresh = (data.get("refresh_token") or "").strip()
|
||||
|
||||
if raw_refresh:
|
||||
_revoke_refresh_token(raw_refresh)
|
||||
|
||||
return jsonify({"message": "Déconnexion réussie"}), 200
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# JWT-protected middleware
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def jwt_required_middleware(fn):
|
||||
"""Decorator: require a valid Bearer JWT access token."""
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
verify_jwt_in_request()
|
||||
user_id = int(get_jwt_identity())
|
||||
user = _get_user_by_id(user_id)
|
||||
if not user:
|
||||
return jsonify({"error": "Utilisateur introuvable"}), 401
|
||||
g.current_user = dict(user)
|
||||
g.current_user_id = user_id
|
||||
except (JWTExtendedException, PyJWTError) as e:
|
||||
logger.debug("JWT auth failed: %s", e)
|
||||
return jsonify({"error": "Token invalide ou expiré", "detail": str(e)}), 401
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def plan_required(*allowed_plans):
|
||||
"""
|
||||
Decorator factory: user's plan must be in allowed_plans.
|
||||
Must be applied AFTER @jwt_required_middleware.
|
||||
|
||||
Example:
|
||||
@app.route("/api/v1/predictions")
|
||||
@jwt_required_middleware
|
||||
@plan_required("premium", "pro")
|
||||
def premium_predictions():
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
user = getattr(g, "current_user", None)
|
||||
if not user:
|
||||
return jsonify({"error": "Non authentifié"}), 401
|
||||
if user["plan"] not in allowed_plans:
|
||||
return jsonify(
|
||||
{
|
||||
"error": "Plan insuffisant",
|
||||
"required": list(allowed_plans),
|
||||
"current_plan": user["plan"],
|
||||
"upgrade_url": "/api/v1/subscription/upgrade",
|
||||
}
|
||||
), 403
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def free_daily_limit_check(fn):
|
||||
"""
|
||||
Decorator: enforce free plan daily limit (1 course/jour).
|
||||
Must be applied AFTER @jwt_required_middleware.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
user = getattr(g, "current_user", None)
|
||||
if not user or user["plan"] != "free":
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
today = datetime.now(timezone.utc).date().isoformat()
|
||||
db = get_db()
|
||||
row = db.execute(
|
||||
"SELECT daily_usage, last_usage_date FROM users WHERE id = ?",
|
||||
(user["id"],),
|
||||
).fetchone()
|
||||
db.close()
|
||||
|
||||
if row and row["last_usage_date"] == today and row["daily_usage"] >= 1:
|
||||
return jsonify(
|
||||
{
|
||||
"error": "Limite quotidienne atteinte (plan free: 1 course/jour)",
|
||||
"upgrade_url": "/api/v1/subscription/upgrade",
|
||||
}
|
||||
), 429
|
||||
|
||||
# Increment usage
|
||||
db = get_db()
|
||||
if row and row["last_usage_date"] == today:
|
||||
db.execute(
|
||||
"UPDATE users SET daily_usage = daily_usage + 1 WHERE id = ?",
|
||||
(user["id"],),
|
||||
)
|
||||
else:
|
||||
db.execute(
|
||||
"UPDATE users SET daily_usage = 1, last_usage_date = ? WHERE id = ?",
|
||||
(today, user["id"]),
|
||||
)
|
||||
db.commit()
|
||||
db.close()
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
68
auth_db.py
Normal file
68
auth_db.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auth DB — users and subscriptions schema for turf_saas.db
|
||||
Sprint 2-3: Auth JWT + Multi-tenant (HRT-28)
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import os
|
||||
|
||||
DB_PATH = os.environ.get("TURF_SAAS_DB", "/home/h3r7/turf_saas/turf_saas.db")
|
||||
|
||||
|
||||
def get_db():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
def init_auth_tables():
|
||||
"""Create users and subscriptions tables if they don't exist."""
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
|
||||
c.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
plan TEXT NOT NULL DEFAULT 'free'
|
||||
CHECK(plan IN ('free','premium','pro')),
|
||||
created_at DATETIME NOT NULL DEFAULT (datetime('now')),
|
||||
is_active INTEGER NOT NULL DEFAULT 1,
|
||||
daily_usage INTEGER NOT NULL DEFAULT 0,
|
||||
last_usage_date TEXT DEFAULT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS subscriptions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
plan TEXT NOT NULL CHECK(plan IN ('free','premium','pro')),
|
||||
start_date DATETIME NOT NULL DEFAULT (datetime('now')),
|
||||
end_date DATETIME,
|
||||
stripe_customer_id TEXT,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
token_hash TEXT NOT NULL UNIQUE,
|
||||
created_at DATETIME NOT NULL DEFAULT (datetime('now')),
|
||||
expires_at DATETIME NOT NULL,
|
||||
revoked INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user ON refresh_tokens(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash);
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print("[auth_db] Tables users, subscriptions, refresh_tokens created/verified.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_auth_tables()
|
||||
127
billing_db.py
Normal file
127
billing_db.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DB Migration — Billing Stripe
|
||||
Sprint 5-6: HRT-31
|
||||
|
||||
Adds stripe_subscription_id and status columns to subscriptions table,
|
||||
and an invoices / grace-period tracking table.
|
||||
|
||||
Run once:
|
||||
./venv/bin/python billing_db.py
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import os
|
||||
import logging
|
||||
|
||||
DB_PATH = os.environ.get("TURF_SAAS_DB", "/home/h3r7/turf_saas/turf_saas.db")
|
||||
logger = logging.getLogger("turf_saas.billing_db")
|
||||
|
||||
|
||||
def get_db():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
def migrate_billing_tables():
|
||||
"""Idempotent migration: add billing columns and billing_events table.
|
||||
|
||||
Requires auth tables (users, subscriptions) to exist first.
|
||||
Calls init_auth_tables() automatically if subscriptions is absent.
|
||||
"""
|
||||
from auth_db import init_auth_tables as _init_auth
|
||||
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
|
||||
# Ensure base auth tables exist
|
||||
tables = {
|
||||
row[0] for row in c.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
}
|
||||
conn.close()
|
||||
|
||||
if "subscriptions" not in tables:
|
||||
_init_auth()
|
||||
|
||||
conn = get_db()
|
||||
c = conn.cursor()
|
||||
|
||||
# Add stripe_subscription_id if missing
|
||||
columns = {row[1] for row in c.execute("PRAGMA table_info(subscriptions)")}
|
||||
|
||||
if "stripe_subscription_id" not in columns:
|
||||
c.execute("ALTER TABLE subscriptions ADD COLUMN stripe_subscription_id TEXT")
|
||||
logger.info("[billing_db] Added stripe_subscription_id column to subscriptions")
|
||||
|
||||
if "status" not in columns:
|
||||
c.execute(
|
||||
"ALTER TABLE subscriptions ADD COLUMN "
|
||||
"status TEXT NOT NULL DEFAULT 'active' "
|
||||
"CHECK(status IN ('active','past_due','canceled','trialing','incomplete'))"
|
||||
)
|
||||
logger.info("[billing_db] Added status column to subscriptions")
|
||||
|
||||
if "grace_period_end" not in columns:
|
||||
c.execute("ALTER TABLE subscriptions ADD COLUMN grace_period_end DATETIME")
|
||||
logger.info("[billing_db] Added grace_period_end column to subscriptions")
|
||||
|
||||
if "current_period_end" not in columns:
|
||||
c.execute("ALTER TABLE subscriptions ADD COLUMN current_period_end DATETIME")
|
||||
logger.info("[billing_db] Added current_period_end column to subscriptions")
|
||||
|
||||
# billing_events table — audit trail for all webhook events
|
||||
c.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS billing_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
stripe_event_id TEXT NOT NULL UNIQUE,
|
||||
event_type TEXT NOT NULL,
|
||||
user_id INTEGER REFERENCES users(id),
|
||||
payload TEXT,
|
||||
processed_at DATETIME NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_events_user ON billing_events(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_events_type ON billing_events(event_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriptions_stripe ON subscriptions(stripe_subscription_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriptions_customer ON subscriptions(stripe_customer_id);
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print(
|
||||
"[billing_db] Migration complete: subscriptions + billing_events tables ready."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
migrate_billing_tables()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Re-exported helpers for test usage
|
||||
# (primary implementations live in api_v1/routes/billing.py)
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _upsert_subscription(db, user_id: int, **fields):
|
||||
"""
|
||||
Update existing subscription row or insert a new one.
|
||||
Convenience re-export for test helpers.
|
||||
"""
|
||||
existing = db.execute(
|
||||
"SELECT id FROM subscriptions WHERE user_id = ? ORDER BY start_date DESC LIMIT 1",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
if existing:
|
||||
set_parts = ", ".join(f"{k} = ?" for k in fields)
|
||||
values = list(fields.values()) + [existing["id"]]
|
||||
db.execute(f"UPDATE subscriptions SET {set_parts} WHERE id = ?", values)
|
||||
else:
|
||||
cols = ", ".join(["user_id"] + list(fields.keys()))
|
||||
placeholders = ", ".join(["?"] * (1 + len(fields)))
|
||||
values = [user_id] + list(fields.values())
|
||||
db.execute(
|
||||
f"INSERT INTO subscriptions ({cols}) VALUES ({placeholders})", values
|
||||
)
|
||||
241
combined_api.py
241
combined_api.py
@@ -3519,6 +3519,7 @@ def brave_search():
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
|
||||
@app.route("/turf/api/predictions_analysis", methods=["GET"])
|
||||
def api_predictions_analysis():
|
||||
"""Analyse des predictions vs resultats reels"""
|
||||
@@ -3532,25 +3533,13 @@ def api_predictions_analysis():
|
||||
cursor = conn.cursor()
|
||||
|
||||
stats = {
|
||||
"canalturf": {
|
||||
"total": 0,
|
||||
"top1_pct": 0,
|
||||
"top3_pct": 0,
|
||||
"top5_pct": 0,
|
||||
"ze2_pct": 0,
|
||||
},
|
||||
"scoring": {
|
||||
"total": 0,
|
||||
"top1_pct": 0,
|
||||
"top3_pct": 0,
|
||||
"top5_pct": 0,
|
||||
"ze2_pct": 0,
|
||||
},
|
||||
"canalturf": {"total": 0, "top1_pct": 0, "top3_pct": 0, "top5_pct": 0, "ze2_pct": 0},
|
||||
"scoring": {"total": 0, "top1_pct": 0, "top3_pct": 0, "top5_pct": 0, "ze2_pct": 0},
|
||||
}
|
||||
|
||||
for source in ["canalturf", "scoring"]:
|
||||
pred_table = "predictions" if source == "canalturf" else "scoring"
|
||||
pred_col = "predicted_1" if source == "canalturf" else "horse_number"
|
||||
pred_col = "predicted_1" if source == "canalturf" else "horse_number"
|
||||
try:
|
||||
cursor.execute(
|
||||
f"""
|
||||
@@ -3577,16 +3566,16 @@ def api_predictions_analysis():
|
||||
top1_hit = top3_hit = 0
|
||||
total = len(races)
|
||||
for race, data in races.items():
|
||||
actual = set(data["actual"][:3])
|
||||
pred_top1 = data["predicted"][0] if data["predicted"] else None
|
||||
actual_top1 = data["actual"][0] if data["actual"] else None
|
||||
actual = set(data["actual"][:3])
|
||||
pred_top1 = data["predicted"][0] if data["predicted"] else None
|
||||
actual_top1 = data["actual"][0] if data["actual"] else None
|
||||
if pred_top1 and actual_top1 and pred_top1 == actual_top1:
|
||||
top1_hit += 1
|
||||
if len(set(data["predicted"][:3]) & actual) >= 1:
|
||||
top3_hit += 1
|
||||
|
||||
if total > 0:
|
||||
stats[source]["total"] = total
|
||||
stats[source]["total"] = total
|
||||
stats[source]["top1_pct"] = round(top1_hit / total * 100, 1)
|
||||
stats[source]["top3_pct"] = round(top3_hit / total * 100, 1)
|
||||
except Exception as e:
|
||||
@@ -3596,219 +3585,5 @@ def api_predictions_analysis():
|
||||
return jsonify({"stats": stats, "period": {"start": start_date, "end": end_date}})
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# /api/v1/predictions — Ensemble model endpoint (Sprint 6-7 ML Upgrade)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
_predict_v2 = None
|
||||
|
||||
|
||||
def _load_predict_v2():
|
||||
"""Lazy import of predict_v2 module (ensemble model)."""
|
||||
global _predict_v2
|
||||
if _predict_v2 is None:
|
||||
try:
|
||||
import importlib.util, sys
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"predict_v2", "/home/h3r7/turf_saas/predict_v2.py"
|
||||
)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
_predict_v2 = mod
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.error(f"[v1/predictions] predict_v2 import failed: {e}")
|
||||
return _predict_v2
|
||||
|
||||
|
||||
@app.route("/api/v1/predictions", methods=["GET"])
|
||||
@app.route("/turf/api/v1/predictions", methods=["GET"])
|
||||
def api_v1_predictions():
|
||||
"""
|
||||
Ensemble ML predictions using XGBoost + LightGBM + MLP (Optuna-tuned).
|
||||
Query params:
|
||||
- date: YYYY-MM-DD (default: today / latest available)
|
||||
- reunion: int (default: all)
|
||||
- course: int (default: all)
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
t0 = _time.perf_counter()
|
||||
|
||||
mod = _load_predict_v2()
|
||||
if mod is None:
|
||||
# Graceful fallback: redirect to legacy ml_predictions
|
||||
return jsonify(
|
||||
{
|
||||
"error": "Ensemble model not available yet",
|
||||
"fallback": "/api/ml_predictions",
|
||||
"message": "Model is still training. Use /api/ml_predictions for legacy XGBoost predictions.",
|
||||
}
|
||||
), 503
|
||||
|
||||
ensemble = mod.load_ensemble()
|
||||
if ensemble is None:
|
||||
return jsonify(
|
||||
{
|
||||
"error": "Ensemble model file not found",
|
||||
"model_path": str(mod.ENSEMBLE_PATH),
|
||||
"message": "Run train_ensemble.py to generate the model.",
|
||||
"fallback": "/api/ml_predictions",
|
||||
}
|
||||
), 503
|
||||
|
||||
date_param = request.args.get("date", None)
|
||||
reunion_param = request.args.get("reunion", None)
|
||||
course_param = request.args.get("course", None)
|
||||
|
||||
conn = sqlite3.connect("/home/h3r7/turf_saas/turf.db")
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
# Determine date to use
|
||||
if date_param:
|
||||
date_used = date_param
|
||||
else:
|
||||
row = conn.execute(
|
||||
"SELECT MAX(date_programme) as d FROM pmu_partants"
|
||||
).fetchone()
|
||||
date_used = (
|
||||
row["d"] if row and row["d"] else datetime.now().strftime("%Y-%m-%d")
|
||||
)
|
||||
|
||||
# Build query
|
||||
where_clauses = ["p.date_programme = ?"]
|
||||
params = [date_used]
|
||||
if reunion_param:
|
||||
where_clauses.append("p.num_reunion = ?")
|
||||
params.append(int(reunion_param))
|
||||
if course_param:
|
||||
where_clauses.append("p.num_course = ?")
|
||||
params.append(int(course_param))
|
||||
|
||||
query = f"""
|
||||
SELECT p.*, c.distance, c.discipline, c.specialite,
|
||||
c.nb_declares_partants, c.montant_prix, c.penetrometre_intitule,
|
||||
c.libelle as course_libelle, c.libelle_court as hippodrome,
|
||||
c.heure_depart_str, c.parcours
|
||||
FROM pmu_partants p
|
||||
LEFT JOIN pmu_courses c ON p.date_programme = c.date_programme
|
||||
AND p.num_reunion = c.num_reunion AND p.num_course = c.num_course
|
||||
WHERE {" AND ".join(where_clauses)}
|
||||
ORDER BY p.num_reunion, p.num_course, p.num_pmu
|
||||
"""
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
return jsonify(
|
||||
{
|
||||
"date": date_used,
|
||||
"model_version": mod.get_model_version(),
|
||||
"predictions": [],
|
||||
"message": f"No partants found for date {date_used}",
|
||||
}
|
||||
)
|
||||
|
||||
# Convert to list of dicts
|
||||
partants = [dict(r) for r in rows]
|
||||
|
||||
# Run ensemble prediction
|
||||
preds = mod.predict_top3(partants, model=ensemble)
|
||||
|
||||
# Group by race
|
||||
races = {}
|
||||
for pred in preds:
|
||||
key = f"R{pred.get('num_reunion', 0)}C{pred.get('num_course', 0)}"
|
||||
if key not in races:
|
||||
# Find race metadata from partants
|
||||
for p in partants:
|
||||
if p.get("num_reunion") == pred.get("num_reunion") and p.get(
|
||||
"num_course"
|
||||
) == pred.get("num_course"):
|
||||
races[key] = {
|
||||
"reunion": pred.get("num_reunion"),
|
||||
"course": pred.get("num_course"),
|
||||
"label": key,
|
||||
"race_name": p.get("course_libelle", ""),
|
||||
"hippodrome": p.get("hippodrome", ""),
|
||||
"heure": p.get("heure_depart_str", ""),
|
||||
"discipline": p.get("discipline", ""),
|
||||
"distance": p.get("distance", 0),
|
||||
"horses": [],
|
||||
}
|
||||
break
|
||||
if key in races:
|
||||
races[key]["horses"].append(pred)
|
||||
|
||||
latency_ms = (_time.perf_counter() - t0) * 1000
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"date": date_used,
|
||||
"model_version": mod.get_model_version(),
|
||||
"latency_ms": round(latency_ms, 1),
|
||||
"total_horses": len(preds),
|
||||
"races": list(races.values()),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.route("/api/v1/model/invalidate-cache", methods=["POST"])
|
||||
@app.route("/turf/api/v1/model/invalidate-cache", methods=["POST"])
|
||||
def api_v1_invalidate_cache():
|
||||
"""Force reload of ensemble model on next prediction call."""
|
||||
mod = _load_predict_v2()
|
||||
if mod:
|
||||
mod.invalidate_model_cache()
|
||||
return jsonify({"status": "ok", "message": "Model cache invalidated"})
|
||||
return jsonify({"status": "error", "message": "predict_v2 module not loaded"}), 500
|
||||
|
||||
|
||||
@app.route("/api/v1/model/status", methods=["GET"])
|
||||
@app.route("/turf/api/v1/model/status", methods=["GET"])
|
||||
def api_v1_model_status():
|
||||
"""Return ensemble model status and version."""
|
||||
import os as _os
|
||||
from pathlib import Path as _Path
|
||||
|
||||
ensemble_path = _Path("/home/h3r7/turf_saas/models/ensemble_top3.pkl")
|
||||
benchmark_path = _Path("/home/h3r7/turf_saas/models/benchmark_report.json")
|
||||
|
||||
status = {
|
||||
"ensemble_available": ensemble_path.exists(),
|
||||
"ensemble_path": str(ensemble_path),
|
||||
}
|
||||
if ensemble_path.exists():
|
||||
mtime = _os.path.getmtime(str(ensemble_path))
|
||||
status["last_trained"] = datetime.fromtimestamp(mtime).isoformat()
|
||||
|
||||
if benchmark_path.exists():
|
||||
try:
|
||||
with open(benchmark_path) as f:
|
||||
import json as _json
|
||||
|
||||
report = _json.load(f)
|
||||
status["benchmark"] = {
|
||||
"baseline_precision_at3": report.get("baseline", {}).get(
|
||||
"precision_at3"
|
||||
),
|
||||
"ensemble_precision_at3": report.get("ensemble", {}).get(
|
||||
"precision_at3"
|
||||
),
|
||||
"delta": report.get("delta_precision_at3"),
|
||||
"deployed": report.get("deploy"),
|
||||
"run_date": report.get("run_date"),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
mod = _load_predict_v2()
|
||||
if mod and ensemble_path.exists():
|
||||
status["model_version"] = mod.get_model_version()
|
||||
|
||||
return jsonify(status)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=8790, debug=False)
|
||||
|
||||
90
middleware.py
Normal file
90
middleware.py
Normal file
@@ -0,0 +1,90 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Middleware — rate limiting, CORS, and access logging
|
||||
Sprint 2-3: HRT-28
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from threading import Lock
|
||||
|
||||
from flask import request, jsonify, g
|
||||
|
||||
logger = logging.getLogger("turf_saas.middleware")
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# In-memory rate limiter (100 req/min per IP)
|
||||
# For production: replace with Redis-backed counter
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
_rate_store: dict = defaultdict(lambda: {"count": 0, "window_start": 0.0})
|
||||
_rate_lock = Lock()
|
||||
|
||||
RATE_LIMIT = 100 # max requests
|
||||
RATE_WINDOW = 60 # seconds
|
||||
|
||||
|
||||
def rate_limit_middleware(app):
|
||||
"""Register before_request rate limiting on the Flask app."""
|
||||
|
||||
@app.before_request
|
||||
def check_rate_limit():
|
||||
ip = request.remote_addr or "unknown"
|
||||
now = time.time()
|
||||
|
||||
with _rate_lock:
|
||||
bucket = _rate_store[ip]
|
||||
if now - bucket["window_start"] >= RATE_WINDOW:
|
||||
bucket["count"] = 0
|
||||
bucket["window_start"] = now
|
||||
bucket["count"] += 1
|
||||
count = bucket["count"]
|
||||
remaining = max(0, RATE_LIMIT - count)
|
||||
|
||||
if count > RATE_LIMIT:
|
||||
logger.warning("Rate limit exceeded for IP %s", ip)
|
||||
resp = jsonify({"error": "Trop de requêtes. Limite: 100/min par IP."})
|
||||
resp.status_code = 429
|
||||
resp.headers["X-RateLimit-Limit"] = str(RATE_LIMIT)
|
||||
resp.headers["X-RateLimit-Remaining"] = "0"
|
||||
resp.headers["Retry-After"] = str(RATE_WINDOW)
|
||||
return resp
|
||||
|
||||
# Attach headers on all responses via after_request
|
||||
g.rl_remaining = remaining
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Access logs (timestamped)
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
access_log = logging.getLogger("turf_saas.access")
|
||||
|
||||
|
||||
def access_log_middleware(app):
|
||||
"""Register after_request access logging on the Flask app."""
|
||||
|
||||
@app.after_request
|
||||
def log_access(response):
|
||||
ip = request.remote_addr or "unknown"
|
||||
ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
user_id = getattr(g, "current_user_id", "-")
|
||||
access_log.info(
|
||||
'%s %s %s "%s %s" %s %s',
|
||||
ts,
|
||||
ip,
|
||||
user_id,
|
||||
request.method,
|
||||
request.path,
|
||||
response.status_code,
|
||||
response.content_length or 0,
|
||||
)
|
||||
# Attach rate-limit headers
|
||||
remaining = getattr(g, "rl_remaining", None)
|
||||
if remaining is not None:
|
||||
response.headers["X-RateLimit-Limit"] = str(RATE_LIMIT)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
@@ -1,174 +0,0 @@
|
||||
{
|
||||
"run_date": "2026-04-25T19:09:46.629142",
|
||||
"dataset": {
|
||||
"db_path": "/home/h3r7/turf_saas/turf.db",
|
||||
"total_rows": 10899,
|
||||
"train_rows": 8719,
|
||||
"holdout_rows": 2180,
|
||||
"train_date_range": [
|
||||
"2026-03-31",
|
||||
"2026-04-19"
|
||||
],
|
||||
"holdout_date_range": [
|
||||
"2026-04-19",
|
||||
"2026-04-24"
|
||||
]
|
||||
},
|
||||
"baseline": {
|
||||
"model": "XGBoost (baseline)",
|
||||
"precision_at3": 0.5286821705426358,
|
||||
"auc": 0.7254057665061495
|
||||
},
|
||||
"individual_models": {
|
||||
"xgboost": {
|
||||
"model": "xgboost",
|
||||
"auc": 0.7856,
|
||||
"accuracy": 0.6917,
|
||||
"precision": 0.4865,
|
||||
"recall": 0.7229,
|
||||
"precision_at3": 0.5783,
|
||||
"latency_ms_per_row": 0.0112
|
||||
},
|
||||
"lightgbm": {
|
||||
"model": "lightgbm",
|
||||
"auc": 0.7833,
|
||||
"accuracy": 0.6995,
|
||||
"precision": 0.4951,
|
||||
"recall": 0.709,
|
||||
"precision_at3": 0.5736,
|
||||
"latency_ms_per_row": 0.0041
|
||||
},
|
||||
"mlp": {
|
||||
"model": "mlp",
|
||||
"auc": 0.7743,
|
||||
"accuracy": 0.7445,
|
||||
"precision": 0.5743,
|
||||
"recall": 0.5325,
|
||||
"precision_at3": 0.5643,
|
||||
"latency_ms_per_row": 0.0052
|
||||
}
|
||||
},
|
||||
"ensemble": {
|
||||
"model": "ensemble",
|
||||
"auc": 0.784,
|
||||
"accuracy": 0.7147,
|
||||
"precision": 0.5142,
|
||||
"recall": 0.6718,
|
||||
"precision_at3": 0.5814,
|
||||
"latency_ms_per_row": 0.0208
|
||||
},
|
||||
"delta_precision_at3": 0.0527,
|
||||
"deploy": true,
|
||||
"optuna": {
|
||||
"n_trials": 100,
|
||||
"xgboost_best_params": {
|
||||
"n_estimators": 141,
|
||||
"max_depth": 5,
|
||||
"learning_rate": 0.016298172447266404,
|
||||
"subsample": 0.7660470794373848,
|
||||
"colsample_bytree": 0.471124415020467,
|
||||
"min_child_weight": 14,
|
||||
"reg_alpha": 1.9364166463791586,
|
||||
"reg_lambda": 6.018030083488602,
|
||||
"gamma": 4.614943551368141
|
||||
},
|
||||
"lightgbm_best_params": {
|
||||
"n_estimators": 186,
|
||||
"max_depth": 4,
|
||||
"learning_rate": 0.012915117465216954,
|
||||
"num_leaves": 141,
|
||||
"subsample": 0.6193119116922561,
|
||||
"colsample_bytree": 0.539310022549326,
|
||||
"min_child_samples": 9,
|
||||
"reg_alpha": 0.6864583098112754,
|
||||
"reg_lambda": 0.0549259590914184
|
||||
}
|
||||
},
|
||||
"features": {
|
||||
"total": 43,
|
||||
"selected_by_shap": 31,
|
||||
"feature_list": [
|
||||
"age",
|
||||
"sexe_enc",
|
||||
"nombre_courses",
|
||||
"nombre_victoires",
|
||||
"nombre_places",
|
||||
"tx_victoire",
|
||||
"tx_place",
|
||||
"forme_recente",
|
||||
"tendance_num",
|
||||
"gains_annee_en_cours",
|
||||
"cote_direct",
|
||||
"cote_reference",
|
||||
"distance",
|
||||
"nb_partants",
|
||||
"discipline_enc",
|
||||
"specialite_enc",
|
||||
"oeilleres_enc",
|
||||
"tendance_cote_enc",
|
||||
"penetrometre_intitule_enc",
|
||||
"form_1",
|
||||
"form_2",
|
||||
"form_3",
|
||||
"form_4",
|
||||
"form_5",
|
||||
"form_weighted",
|
||||
"form_avg",
|
||||
"form_best",
|
||||
"form_worst",
|
||||
"win_ratio",
|
||||
"place_ratio",
|
||||
"implied_prob",
|
||||
"win_rate_adj",
|
||||
"place_rate_adj",
|
||||
"earnings_per_race",
|
||||
"cote_diff",
|
||||
"cote_ratio",
|
||||
"rang_cote",
|
||||
"ratio_cote_field",
|
||||
"distance_cat",
|
||||
"age_win_interact",
|
||||
"is_favorite",
|
||||
"poids",
|
||||
"prize_norm"
|
||||
],
|
||||
"shap_selected": [
|
||||
"rang_cote",
|
||||
"implied_prob",
|
||||
"cote_direct",
|
||||
"ratio_cote_field",
|
||||
"nb_partants",
|
||||
"cote_diff",
|
||||
"cote_ratio",
|
||||
"specialite_enc",
|
||||
"earnings_per_race",
|
||||
"nombre_courses",
|
||||
"cote_reference",
|
||||
"distance",
|
||||
"discipline_enc",
|
||||
"is_favorite",
|
||||
"prize_norm",
|
||||
"win_ratio",
|
||||
"place_rate_adj",
|
||||
"gains_annee_en_cours",
|
||||
"poids",
|
||||
"tx_place",
|
||||
"penetrometre_intitule_enc",
|
||||
"age_win_interact",
|
||||
"nombre_places",
|
||||
"tendance_num",
|
||||
"age",
|
||||
"form_avg",
|
||||
"form_weighted",
|
||||
"place_ratio",
|
||||
"form_3",
|
||||
"oeilleres_enc",
|
||||
"form_5"
|
||||
]
|
||||
},
|
||||
"ensemble_weights": {
|
||||
"xgboost": 0.23161801824035544,
|
||||
"lightgbm": 0.23415467282905,
|
||||
"mlp": 0.21290370528252356
|
||||
}
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
# Benchmark ML Ensemble — Turf Prédictions
|
||||
|
||||
**Date:** 2026-04-25
|
||||
**Dataset:** 10,899 partants
|
||||
**Holdout:** 2,180 lignes (2026-04-19 → 2026-04-24)
|
||||
|
||||
## Résultats
|
||||
|
||||
| Modèle | Precision@3 | AUC | Latence/prédiction |
|
||||
|--------|-------------|-----|-------------------|
|
||||
| XGBoost (baseline) | 0.5287 | 0.7254 | — |
|
||||
| xgboost | 0.5783 | 0.7856 | 0.01 ms |
|
||||
| lightgbm | 0.5736 | 0.7833 | 0.00 ms |
|
||||
| mlp | 0.5643 | 0.7743 | 0.01 ms |
|
||||
| **Ensemble** | **0.5814** | **0.7840** | **0.02 ms** |
|
||||
|
||||
## Décision de déploiement
|
||||
|
||||
- Delta Precision@3 : **+0.0527** (+5.3%)
|
||||
- Seuil requis : **+5%**
|
||||
- Résultat : **✅ DEPLOIEMENT RECOMMANDE**
|
||||
|
||||
## Optimisation Optuna
|
||||
|
||||
- Trials XGBoost : 100
|
||||
- Trials LightGBM : 100
|
||||
- Pruning : MedianPruner
|
||||
|
||||
### Meilleurs hyperparamètres XGBoost
|
||||
```json
|
||||
{
|
||||
"n_estimators": 141,
|
||||
"max_depth": 5,
|
||||
"learning_rate": 0.016298172447266404,
|
||||
"subsample": 0.7660470794373848,
|
||||
"colsample_bytree": 0.471124415020467,
|
||||
"min_child_weight": 14,
|
||||
"reg_alpha": 1.9364166463791586,
|
||||
"reg_lambda": 6.018030083488602,
|
||||
"gamma": 4.614943551368141
|
||||
}
|
||||
```
|
||||
|
||||
### Meilleurs hyperparamètres LightGBM
|
||||
```json
|
||||
{
|
||||
"n_estimators": 186,
|
||||
"max_depth": 4,
|
||||
"learning_rate": 0.012915117465216954,
|
||||
"num_leaves": 141,
|
||||
"subsample": 0.6193119116922561,
|
||||
"colsample_bytree": 0.539310022549326,
|
||||
"min_child_samples": 9,
|
||||
"reg_alpha": 0.6864583098112754,
|
||||
"reg_lambda": 0.0549259590914184
|
||||
}
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- Total features : 43
|
||||
- Retenues par SHAP : 31
|
||||
|
||||
## Poids de l'ensemble
|
||||
|
||||
- xgboost : 0.2316
|
||||
- lightgbm : 0.2342
|
||||
- mlp : 0.2129
|
||||
387
predict_v2.py
387
predict_v2.py
@@ -1,387 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Ensemble prediction module for /api/v1/predictions.
|
||||
|
||||
Loads the trained ensemble model and provides a high-level predict_top3()
|
||||
function compatible with the existing combined_api.py interface.
|
||||
|
||||
Cache: model is loaded once at import time (or on first call).
|
||||
Invalidation: reload if models/ensemble_top3.pkl mtime changes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODELS_DIR = Path("/home/h3r7/turf_saas/models")
|
||||
ENSEMBLE_PATH = MODELS_DIR / "ensemble_top3.pkl"
|
||||
|
||||
# ── Cache ─────────────────────────────────────────────────────────────────────
|
||||
_model_cache = {
|
||||
"ensemble": None,
|
||||
"mtime": None,
|
||||
"lock": threading.Lock(),
|
||||
}
|
||||
|
||||
# ── Feature list (must match train_ensemble.py FEATURE_COLS) ─────────────────
|
||||
FEATURE_COLS = [
|
||||
"age",
|
||||
"sexe_enc",
|
||||
"nombre_courses",
|
||||
"nombre_victoires",
|
||||
"nombre_places",
|
||||
"tx_victoire",
|
||||
"tx_place",
|
||||
"forme_recente",
|
||||
"tendance_num",
|
||||
"gains_annee_en_cours",
|
||||
"cote_direct",
|
||||
"cote_reference",
|
||||
"distance",
|
||||
"nb_partants",
|
||||
"discipline_enc",
|
||||
"specialite_enc",
|
||||
"oeilleres_enc",
|
||||
"tendance_cote_enc",
|
||||
"penetrometre_intitule_enc",
|
||||
"form_1",
|
||||
"form_2",
|
||||
"form_3",
|
||||
"form_4",
|
||||
"form_5",
|
||||
"form_weighted",
|
||||
"form_avg",
|
||||
"form_best",
|
||||
"form_worst",
|
||||
"win_ratio",
|
||||
"place_ratio",
|
||||
"implied_prob",
|
||||
"win_rate_adj",
|
||||
"place_rate_adj",
|
||||
"earnings_per_race",
|
||||
"cote_diff",
|
||||
"cote_ratio",
|
||||
"rang_cote",
|
||||
"ratio_cote_field",
|
||||
"distance_cat",
|
||||
"age_win_interact",
|
||||
"is_favorite",
|
||||
"poids",
|
||||
"prize_norm",
|
||||
]
|
||||
|
||||
|
||||
# ── Encoders (built per-prediction batch for live data) ──────────────────────
|
||||
def _fit_encoder(values, default):
|
||||
le = LabelEncoder()
|
||||
unique = list(set(str(v) if v else default for v in values)) + [default]
|
||||
le.fit(unique)
|
||||
return le
|
||||
|
||||
|
||||
def _safe_transform(le: LabelEncoder, value, default: str):
|
||||
v = str(value) if value else default
|
||||
if v not in le.classes_:
|
||||
v = default
|
||||
return int(le.transform([v])[0])
|
||||
|
||||
|
||||
# ── Model loading with auto-invalidation ─────────────────────────────────────
|
||||
def load_ensemble(force: bool = False) -> Optional[object]:
|
||||
"""Load ensemble model, reload if file changed."""
|
||||
with _model_cache["lock"]:
|
||||
if not ENSEMBLE_PATH.exists():
|
||||
return None
|
||||
mtime = ENSEMBLE_PATH.stat().st_mtime
|
||||
if force or _model_cache["ensemble"] is None or mtime != _model_cache["mtime"]:
|
||||
try:
|
||||
with open(ENSEMBLE_PATH, "rb") as f:
|
||||
_model_cache["ensemble"] = pickle.load(f)
|
||||
_model_cache["mtime"] = mtime
|
||||
logger.info(f"[predict_v2] Loaded ensemble model from {ENSEMBLE_PATH}")
|
||||
except Exception as e:
|
||||
logger.error(f"[predict_v2] Failed to load ensemble: {e}")
|
||||
return None
|
||||
return _model_cache["ensemble"]
|
||||
|
||||
|
||||
def invalidate_model_cache():
|
||||
"""Force reload on next prediction call."""
|
||||
with _model_cache["lock"]:
|
||||
_model_cache["mtime"] = None
|
||||
|
||||
|
||||
# ── Feature engineering for live pmu_partants rows ───────────────────────────
|
||||
def _parse_musique(musique) -> list:
|
||||
if not musique or pd.isna(str(musique)):
|
||||
return [0, 0, 0, 0, 0]
|
||||
try:
|
||||
clean = re.sub(r"\(\d+\)", "", str(musique))
|
||||
numbers = re.findall(r"\d+", clean)
|
||||
result = [int(n) for n in numbers[:5]]
|
||||
result += [0] * (5 - len(result))
|
||||
return result[:5]
|
||||
except Exception:
|
||||
return [0, 0, 0, 0, 0]
|
||||
|
||||
|
||||
def build_feature_df(partants: list) -> pd.DataFrame:
|
||||
"""
|
||||
Convert a list of pmu_partants dicts to a feature DataFrame.
|
||||
|
||||
Expected keys (same as pmu_partants columns):
|
||||
date_programme, num_reunion, num_course, num_pmu,
|
||||
age, sexe, musique, nombre_courses, nombre_victoires, nombre_places,
|
||||
gains_annee_en_cours, handicap_poids, oeilleres, cote_direct,
|
||||
cote_reference, tendance_cote, favoris, tx_victoire, tx_place,
|
||||
forme_recente, tendance_forme, indicateur_inedit,
|
||||
distance, discipline, specialite, nb_declares_partants,
|
||||
montant_prix, penetrometre_intitule
|
||||
"""
|
||||
if not partants:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame(partants)
|
||||
|
||||
# ── Categorical encoders fitted on this batch ─────────────────────────────
|
||||
le_sexe = _fit_encoder(df.get("sexe", ["U"]), "U")
|
||||
le_oeilleres = _fit_encoder(df.get("oeilleres", ["SANS"]), "SANS")
|
||||
le_discipline = _fit_encoder(df.get("discipline", ["UNKNOWN"]), "UNKNOWN")
|
||||
le_specialite = _fit_encoder(df.get("specialite", ["UNKNOWN"]), "UNKNOWN")
|
||||
le_tendance = _fit_encoder(df.get("tendance_cote", ["STABLE"]), "STABLE")
|
||||
le_penet = _fit_encoder(df.get("penetrometre_intitule", ["BON"]), "BON")
|
||||
|
||||
df["sexe_enc"] = df["sexe"].apply(lambda v: _safe_transform(le_sexe, v, "U"))
|
||||
df["oeilleres_enc"] = df["oeilleres"].apply(
|
||||
lambda v: _safe_transform(le_oeilleres, v, "SANS")
|
||||
)
|
||||
df["discipline_enc"] = df.get("discipline", pd.Series(["UNKNOWN"] * len(df))).apply(
|
||||
lambda v: _safe_transform(le_discipline, v, "UNKNOWN")
|
||||
)
|
||||
df["specialite_enc"] = df.get("specialite", pd.Series(["UNKNOWN"] * len(df))).apply(
|
||||
lambda v: _safe_transform(le_specialite, v, "UNKNOWN")
|
||||
)
|
||||
df["tendance_cote_enc"] = df.get(
|
||||
"tendance_cote", pd.Series(["STABLE"] * len(df))
|
||||
).apply(lambda v: _safe_transform(le_tendance, v, "STABLE"))
|
||||
df["penetrometre_intitule_enc"] = df.get(
|
||||
"penetrometre_intitule", pd.Series(["BON"] * len(df))
|
||||
).apply(lambda v: _safe_transform(le_penet, v, "BON"))
|
||||
|
||||
# ── Musique ────────────────────────────────────────────────────────────────
|
||||
music_parsed = df["musique"].apply(_parse_musique)
|
||||
for i in range(5):
|
||||
df[f"form_{i + 1}"] = music_parsed.apply(lambda x: x[i])
|
||||
weights = np.array([0.4, 0.25, 0.15, 0.12, 0.08])
|
||||
df["form_weighted"] = music_parsed.apply(
|
||||
lambda x: sum(w * v for w, v in zip(weights, x))
|
||||
)
|
||||
df["form_avg"] = music_parsed.apply(np.mean)
|
||||
df["form_best"] = music_parsed.apply(min)
|
||||
df["form_worst"] = music_parsed.apply(max)
|
||||
|
||||
# ── Numeric features ───────────────────────────────────────────────────────
|
||||
for col in [
|
||||
"nombre_courses",
|
||||
"nombre_victoires",
|
||||
"nombre_places",
|
||||
"tx_victoire",
|
||||
"tx_place",
|
||||
"forme_recente",
|
||||
"tendance_forme",
|
||||
"gains_annee_en_cours",
|
||||
"cote_direct",
|
||||
"cote_reference",
|
||||
"distance",
|
||||
"handicap_poids",
|
||||
"age",
|
||||
"montant_prix",
|
||||
"nb_declares_partants",
|
||||
]:
|
||||
if col not in df.columns:
|
||||
df[col] = 0.0
|
||||
df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0)
|
||||
|
||||
df["tendance_num"] = df["tendance_forme"].fillna(0)
|
||||
df["win_ratio"] = df["nombre_victoires"] / df["nombre_courses"].replace(0, 1)
|
||||
df["place_ratio"] = df["nombre_places"] / df["nombre_courses"].replace(0, 1)
|
||||
df["implied_prob"] = 1.0 / df["cote_direct"].replace(0, np.nan)
|
||||
df["win_rate_adj"] = df["tx_victoire"] * np.log1p(df["nombre_courses"])
|
||||
df["place_rate_adj"] = df["tx_place"] * np.log1p(df["nombre_courses"])
|
||||
df["earnings_per_race"] = df["gains_annee_en_cours"] / df["nombre_courses"].replace(
|
||||
0, 1
|
||||
)
|
||||
df["cote_diff"] = (df["cote_direct"] - df["cote_reference"]).fillna(0)
|
||||
df["cote_ratio"] = (
|
||||
df["cote_direct"] / df["cote_reference"].replace(0, np.nan)
|
||||
).fillna(1)
|
||||
|
||||
# ── Per-race rank features ─────────────────────────────────────────────────
|
||||
if "num_reunion" in df.columns and "num_course" in df.columns:
|
||||
grp = ["date_programme", "num_reunion", "num_course"]
|
||||
# Some fields may be missing
|
||||
for g in grp:
|
||||
if g not in df.columns:
|
||||
df[g] = 0
|
||||
df["rang_cote"] = df.groupby(grp)["cote_direct"].rank(
|
||||
method="min", na_option="bottom"
|
||||
)
|
||||
race_mean = df.groupby(grp)["cote_direct"].transform("mean")
|
||||
df["ratio_cote_field"] = df["cote_direct"] / race_mean.replace(0, np.nan)
|
||||
df["nb_partants"] = df.groupby(grp)["cote_direct"].transform("count")
|
||||
else:
|
||||
df["rang_cote"] = 1.0
|
||||
df["ratio_cote_field"] = 1.0
|
||||
df["nb_partants"] = df.get("nb_declares_partants", pd.Series([10] * len(df)))
|
||||
|
||||
df["distance_cat"] = pd.cut(
|
||||
df["distance"].fillna(1600),
|
||||
bins=[0, 1400, 1800, 2200, 2600, 10000],
|
||||
labels=[1, 2, 3, 4, 5],
|
||||
).astype(float)
|
||||
df["age_win_interact"] = df["age"] * df["tx_victoire"]
|
||||
df["is_favorite"] = (
|
||||
df.get("favoris", pd.Series([0] * len(df))).fillna(0).astype(int)
|
||||
)
|
||||
df["poids"] = df["handicap_poids"].fillna(60)
|
||||
df["prize_norm"] = np.log1p(df["montant_prix"].fillna(0))
|
||||
|
||||
return df
|
||||
|
||||
|
||||
# ── Main prediction function ───────────────────────────────────────────────────
|
||||
def predict_top3(partants: list, model=None) -> list:
|
||||
"""
|
||||
Given a list of partant dicts (from pmu_partants), return predictions.
|
||||
|
||||
Returns list of {horse_name, num_pmu, prob_top3, prob_top1_approx, ...}
|
||||
sorted by prob_top3 descending.
|
||||
|
||||
Falls back to empty list if model not available.
|
||||
"""
|
||||
t_start = time.perf_counter()
|
||||
|
||||
if model is None:
|
||||
model = load_ensemble()
|
||||
if model is None:
|
||||
logger.warning("[predict_v2] Ensemble model not available — no predictions")
|
||||
return []
|
||||
|
||||
df = build_feature_df(partants)
|
||||
if df.empty:
|
||||
return []
|
||||
|
||||
available = [c for c in FEATURE_COLS if c in df.columns]
|
||||
X = df[available].fillna(0)
|
||||
|
||||
try:
|
||||
proba = model.predict_proba(X)[:, 1]
|
||||
except Exception as e:
|
||||
logger.error(f"[predict_v2] predict_proba failed: {e}")
|
||||
return []
|
||||
|
||||
latency_ms = (time.perf_counter() - t_start) * 1000
|
||||
|
||||
results = []
|
||||
for i, (p, row) in enumerate(zip(proba, partants)):
|
||||
results.append(
|
||||
{
|
||||
"horse_name": row.get("nom", row.get("horse_name", f"H{i}")),
|
||||
"num_pmu": row.get("num_pmu", i + 1),
|
||||
"num_reunion": row.get("num_reunion"),
|
||||
"num_course": row.get("num_course"),
|
||||
"prob_top3": round(float(p) * 100, 1),
|
||||
# approx top1 from top3 score (divide by ~2.5 empirically)
|
||||
"prob_top1": round(float(p) / 2.5 * 100, 1),
|
||||
"ml_score": round(float(p) * 100, 1),
|
||||
"recommendation": "top3"
|
||||
if p >= 0.40
|
||||
else ("watch" if p >= 0.28 else "pass"),
|
||||
"is_value_bet": int(
|
||||
p >= 0.35 and float(row.get("cote_direct", 0) or 0) > 10
|
||||
),
|
||||
"model_version": getattr(model, "version", "ensemble_v1"),
|
||||
}
|
||||
)
|
||||
|
||||
results.sort(key=lambda x: x["prob_top3"], reverse=True)
|
||||
|
||||
# Mark top-3 predicted
|
||||
for i, r in enumerate(results[:3]):
|
||||
r["predicted_rank"] = i + 1
|
||||
|
||||
if results:
|
||||
logger.info(
|
||||
f"[predict_v2] {len(results)} horses predicted in {latency_ms:.1f} ms "
|
||||
f"({latency_ms / len(results):.2f} ms/horse)"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ── API-compatible wrapper keeping model_version & structure ──────────────────
|
||||
def get_model_version() -> str:
|
||||
m = load_ensemble()
|
||||
if m is None:
|
||||
return "ensemble_v1_not_loaded"
|
||||
return getattr(m, "version", "ensemble_v1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Quick self-test
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect("/home/h3r7/turf_saas/turf.db")
|
||||
rows = conn.execute(
|
||||
"""SELECT p.*, c.distance, c.discipline, c.specialite,
|
||||
c.nb_declares_partants, c.montant_prix, c.penetrometre_intitule
|
||||
FROM pmu_partants p
|
||||
LEFT JOIN pmu_courses c ON p.date_programme=c.date_programme
|
||||
AND p.num_reunion=c.num_reunion AND p.num_course=c.num_course
|
||||
WHERE p.date_programme=(SELECT MAX(date_programme) FROM pmu_partants)
|
||||
AND p.num_reunion=1 AND p.num_course=1
|
||||
LIMIT 20"""
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
print("No data found for self-test")
|
||||
else:
|
||||
cols = [d[0] for d in conn.description] if hasattr(conn, "description") else []
|
||||
# Fallback column list
|
||||
import sqlite3 as sq3
|
||||
|
||||
conn2 = sq3.connect("/home/h3r7/turf_saas/turf.db")
|
||||
cur = conn2.execute(
|
||||
"""SELECT p.*, c.distance, c.discipline, c.specialite,
|
||||
c.nb_declares_partants, c.montant_prix, c.penetrometre_intitule
|
||||
FROM pmu_partants p
|
||||
LEFT JOIN pmu_courses c ON p.date_programme=c.date_programme
|
||||
AND p.num_reunion=c.num_reunion AND p.num_course=c.num_course
|
||||
WHERE p.date_programme=(SELECT MAX(date_programme) FROM pmu_partants)
|
||||
AND p.num_reunion=1 AND p.num_course=1
|
||||
LIMIT 20"""
|
||||
)
|
||||
cols = [d[0] for d in cur.description]
|
||||
rows2 = cur.fetchall()
|
||||
conn2.close()
|
||||
|
||||
partants = [dict(zip(cols, row)) for row in rows2]
|
||||
preds = predict_top3(partants)
|
||||
print(f"Self-test: {len(preds)} predictions")
|
||||
for p in preds[:5]:
|
||||
print(
|
||||
f" {p['horse_name']:20s} prob_top3={p['prob_top3']}% rec={p['recommendation']}"
|
||||
)
|
||||
12
pytest.ini
12
pytest.ini
@@ -1,12 +0,0 @@
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts = --tb=short -v
|
||||
markers =
|
||||
e2e: Tests End-to-End Playwright
|
||||
load: Tests de charge Locust
|
||||
security: Tests de sécurité
|
||||
smoke: Tests rapides de smoke
|
||||
@@ -1,182 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Rebuild ensemble using known best Optuna params (from completed study).
|
||||
Skips the 100-trial Optuna search and goes straight to training + pickling.
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, '/home/h3r7/turf_saas')
|
||||
|
||||
from train_ensemble import (
|
||||
load_data, engineer_features, temporal_split, get_features_and_target,
|
||||
evaluate_baseline, train_xgboost, train_lightgbm, train_mlp,
|
||||
shap_feature_selection, compute_ensemble_weights,
|
||||
evaluate_model, compute_precision_at3, TurfEnsemble,
|
||||
MODELS_DIR, DEPLOY_THRESHOLD, _write_markdown_report
|
||||
)
|
||||
import json, pickle, numpy as np
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
DB_PATH = '/home/h3r7/turf_saas/turf.db'
|
||||
|
||||
# Best params from the 100-trial Optuna run
|
||||
XGB_BEST = {
|
||||
'n_estimators': 141, 'max_depth': 5,
|
||||
'learning_rate': 0.016298172447266404,
|
||||
'subsample': 0.7660470794373848,
|
||||
'colsample_bytree': 0.471124415020467,
|
||||
'min_child_weight': 14,
|
||||
'reg_alpha': 1.9364166463791586,
|
||||
'reg_lambda': 6.018030083488602,
|
||||
'gamma': 4.614943551368141,
|
||||
}
|
||||
LGB_BEST = {
|
||||
'n_estimators': 186, 'max_depth': 4,
|
||||
'learning_rate': 0.012915117465216954,
|
||||
'num_leaves': 141,
|
||||
'subsample': 0.6193119116922561,
|
||||
'colsample_bytree': 0.539310022549326,
|
||||
'min_child_samples': 9,
|
||||
'reg_alpha': 0.6864583098112754,
|
||||
'reg_lambda': 0.0549259590914184,
|
||||
}
|
||||
|
||||
print("=" * 65)
|
||||
print("TURF ENSEMBLE REBUILD (using pre-computed Optuna params)")
|
||||
print("=" * 65)
|
||||
|
||||
print("\n[1/7] Loading data...")
|
||||
df = load_data(DB_PATH)
|
||||
df = engineer_features(df)
|
||||
|
||||
print("\n[2/7] Temporal split...")
|
||||
train_df, holdout_df = temporal_split(df)
|
||||
X_train, y_train, feat_cols = get_features_and_target(train_df)
|
||||
X_holdout, y_holdout, _ = get_features_and_target(holdout_df)
|
||||
|
||||
n = len(X_train); n_val = int(n * 0.15)
|
||||
X_tr = X_train.iloc[:n-n_val]; y_tr = y_train.iloc[:n-n_val]
|
||||
X_val = X_train.iloc[n-n_val:]; y_val = y_train.iloc[n-n_val:]
|
||||
|
||||
print("\n[3/7] Evaluating baseline XGBoost...")
|
||||
baseline = evaluate_baseline(holdout_df, '/home/h3r7/turf_saas/xgboost_models.pkl')
|
||||
print(f" Baseline P@3={baseline['precision_at3']:.4f} AUC={baseline['auc']:.4f}")
|
||||
|
||||
print("\n[4/7] Training models with best params...")
|
||||
print(" XGBoost...")
|
||||
xgb_model = train_xgboost(X_tr, y_tr, XGB_BEST)
|
||||
print(" LightGBM...")
|
||||
lgb_model = train_lightgbm(X_tr, y_tr, LGB_BEST)
|
||||
print(" MLP...")
|
||||
mlp_model = train_mlp(X_tr.values, y_tr)
|
||||
|
||||
print("\n[5/7] SHAP analysis...")
|
||||
selected_features, shap_df = shap_feature_selection(xgb_model, X_tr)
|
||||
|
||||
print("\n[6/7] Computing ensemble weights...")
|
||||
class WrappedMLP:
|
||||
def __init__(self, pipeline, cols):
|
||||
self.pipeline = pipeline
|
||||
self.feature_cols = cols
|
||||
def predict_proba(self, X):
|
||||
import pandas as pd
|
||||
available = [c for c in self.feature_cols if c in X.columns]
|
||||
return self.pipeline.predict_proba(X[available].values)
|
||||
|
||||
class WrappedTree:
|
||||
def __init__(self, model, cols):
|
||||
self.model = model
|
||||
self.feature_cols = cols
|
||||
def predict_proba(self, X):
|
||||
available = [c for c in self.feature_cols if c in X.columns]
|
||||
return self.model.predict_proba(X[available])
|
||||
|
||||
wrapped_xgb = WrappedTree(xgb_model, feat_cols)
|
||||
wrapped_lgb = WrappedTree(lgb_model, feat_cols)
|
||||
wrapped_mlp = WrappedMLP(mlp_model, feat_cols)
|
||||
model_dict = {'xgboost': wrapped_xgb, 'lightgbm': wrapped_lgb, 'mlp': wrapped_mlp}
|
||||
|
||||
weights = compute_ensemble_weights(model_dict, X_val, y_val, feat_cols)
|
||||
print(" Weights:", weights)
|
||||
|
||||
print("\n[7/7] Evaluating + saving ensemble...")
|
||||
ensemble = TurfEnsemble(xgb_model, lgb_model, mlp_model, weights, feat_cols)
|
||||
|
||||
results = {}
|
||||
for name, wrapped in model_dict.items():
|
||||
res = evaluate_model(wrapped, X_holdout, y_holdout, holdout_df, name)
|
||||
results[name] = res
|
||||
print(f" {name:12s} P@3={res['precision_at3']:.4f} AUC={res['auc']:.4f}")
|
||||
|
||||
ens_res = evaluate_model(ensemble, X_holdout, y_holdout, holdout_df, "ensemble")
|
||||
results["ensemble"] = ens_res
|
||||
print(f" {'ensemble':12s} P@3={ens_res['precision_at3']:.4f} AUC={ens_res['auc']:.4f}")
|
||||
|
||||
delta = ens_res['precision_at3'] - baseline['precision_at3']
|
||||
deploy = delta >= DEPLOY_THRESHOLD
|
||||
print(f"\n Delta: {delta:+.4f} ({delta*100:+.1f}%) Deploy={'YES' if deploy else 'NO'}")
|
||||
|
||||
# Save ensemble
|
||||
ensemble_path = MODELS_DIR / "ensemble_top3.pkl"
|
||||
with open(ensemble_path, "wb") as f:
|
||||
pickle.dump(ensemble, f)
|
||||
print(f"\n ✅ ensemble_top3.pkl saved ({ensemble_path.stat().st_size//1024} KB)")
|
||||
|
||||
# Save individual models
|
||||
for name, model in [("xgboost_optimized", xgb_model), ("lightgbm", lgb_model)]:
|
||||
path = MODELS_DIR / f"{name}_top3.pkl"
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump({"model": model, "feature_cols": feat_cols}, f)
|
||||
print(f" ✅ {name}_top3.pkl saved")
|
||||
|
||||
mlp_path = MODELS_DIR / "mlp_top3.pkl"
|
||||
with open(mlp_path, "wb") as f:
|
||||
pickle.dump({"pipeline": mlp_model, "feature_cols": feat_cols}, f)
|
||||
print(f" ✅ mlp_top3.pkl saved")
|
||||
|
||||
# Benchmark report
|
||||
report = {
|
||||
"run_date": datetime.now().isoformat(),
|
||||
"dataset": {
|
||||
"db_path": DB_PATH,
|
||||
"total_rows": len(df),
|
||||
"train_rows": len(X_train),
|
||||
"holdout_rows": len(X_holdout),
|
||||
"train_date_range": [str(train_df["date_programme"].min()), str(train_df["date_programme"].max())],
|
||||
"holdout_date_range": [str(holdout_df["date_programme"].min()), str(holdout_df["date_programme"].max())],
|
||||
},
|
||||
"baseline": baseline,
|
||||
"individual_models": {k: v for k, v in results.items() if k != "ensemble"},
|
||||
"ensemble": ens_res,
|
||||
"delta_precision_at3": round(delta, 4),
|
||||
"deploy": deploy,
|
||||
"optuna": {
|
||||
"n_trials": 100,
|
||||
"xgboost_best_params": XGB_BEST,
|
||||
"lightgbm_best_params": LGB_BEST,
|
||||
},
|
||||
"features": {
|
||||
"total": len(feat_cols),
|
||||
"selected_by_shap": len(selected_features),
|
||||
"feature_list": feat_cols,
|
||||
"shap_selected": selected_features,
|
||||
},
|
||||
"ensemble_weights": weights,
|
||||
}
|
||||
|
||||
report_path = MODELS_DIR / "benchmark_report.json"
|
||||
with open(report_path, "w") as f:
|
||||
json.dump(report, f, indent=2)
|
||||
print(f" ✅ benchmark_report.json saved")
|
||||
|
||||
md_path = MODELS_DIR / "benchmark_report.md"
|
||||
_write_markdown_report(report, md_path)
|
||||
print(f" ✅ benchmark_report.md saved")
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print("DONE")
|
||||
print(f" Baseline P@3: {baseline['precision_at3']:.4f}")
|
||||
print(f" Ensemble P@3: {ens_res['precision_at3']:.4f}")
|
||||
print(f" Delta: {delta:+.4f} ({delta*100:+.1f}%)")
|
||||
print(f" Deploy: {'✅ YES' if deploy else '❌ NO'}")
|
||||
print("=" * 65)
|
||||
247
saas_api.py
Normal file
247
saas_api.py
Normal file
@@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Turf SaaS API v1 — Auth JWT + Multi-tenant
|
||||
Sprint 2-3: HRT-28
|
||||
|
||||
Run:
|
||||
FLASK_ENV=development ./venv/bin/python saas_api.py
|
||||
|
||||
Ports (isolated from production):
|
||||
Portal: 8793
|
||||
SaaS API: 8792 ← this file
|
||||
Dashboard: 8791
|
||||
Combined API: 8790
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import logging.handlers
|
||||
import sys
|
||||
|
||||
from flask import Flask, jsonify, g, request
|
||||
from flask_cors import CORS
|
||||
from flask_jwt_extended import JWTManager, get_jwt
|
||||
|
||||
from auth_db import init_auth_tables
|
||||
from auth import (
|
||||
auth_bp,
|
||||
jwt_required_middleware,
|
||||
plan_required,
|
||||
free_daily_limit_check,
|
||||
_get_user_by_id,
|
||||
)
|
||||
from middleware import rate_limit_middleware, access_log_middleware
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Logging setup
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
LOG_DIR = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
logging.handlers.RotatingFileHandler(
|
||||
os.path.join(LOG_DIR, "saas_api.log"),
|
||||
maxBytes=5 * 1024 * 1024,
|
||||
backupCount=3,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# App factory
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_app(test_config=None):
|
||||
app = Flask(__name__)
|
||||
|
||||
# JWT config
|
||||
app.config["JWT_SECRET_KEY"] = os.environ.get(
|
||||
"JWT_SECRET_KEY", "CHANGE_ME_IN_PRODUCTION_" + os.urandom(24).hex()
|
||||
)
|
||||
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = 900 # 15 minutes
|
||||
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = 2592000 # 30 days
|
||||
|
||||
if test_config:
|
||||
app.config.update(test_config)
|
||||
|
||||
# CORS — SaaS domain + localhost for dev
|
||||
CORS(
|
||||
app,
|
||||
origins=os.environ.get(
|
||||
"CORS_ORIGINS",
|
||||
"http://localhost:8793,http://127.0.0.1:8793,https://turf-ia.h3r7.tech",
|
||||
).split(","),
|
||||
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
supports_credentials=True,
|
||||
)
|
||||
|
||||
# JWT
|
||||
jwt = JWTManager(app)
|
||||
|
||||
# ── JWT error handlers ────────────────────────────────────
|
||||
@jwt.expired_token_loader
|
||||
def expired_token(_jwt_header, _jwt_payload):
|
||||
return jsonify({"error": "Token expiré"}), 401
|
||||
|
||||
@jwt.invalid_token_loader
|
||||
def invalid_token(reason):
|
||||
return jsonify({"error": "Token invalide", "detail": reason}), 422
|
||||
|
||||
@jwt.unauthorized_loader
|
||||
def unauthorized(reason):
|
||||
return jsonify({"error": "Token manquant ou invalide", "detail": reason}), 401
|
||||
|
||||
# ── Register middleware ───────────────────────────────────
|
||||
rate_limit_middleware(app)
|
||||
access_log_middleware(app)
|
||||
|
||||
# ── Blueprints ────────────────────────────────────────────
|
||||
app.register_blueprint(auth_bp)
|
||||
|
||||
# ── Predictions routes (multi-tenant plan check) ──────────
|
||||
|
||||
@app.route("/api/v1/predictions", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
@free_daily_limit_check
|
||||
def predictions():
|
||||
"""
|
||||
GET /api/v1/predictions
|
||||
- free: Top 3 uniquement (déjà filtrées par le moteur ML)
|
||||
- premium: toutes courses + alertes Telegram
|
||||
- pro: API complète + export CSV disponible
|
||||
"""
|
||||
user = g.current_user
|
||||
plan = user["plan"]
|
||||
|
||||
# Forward to combined_api for actual predictions
|
||||
import requests as req
|
||||
|
||||
try:
|
||||
params = dict(request.args)
|
||||
resp = req.get(
|
||||
"http://localhost:8790/api/predictions",
|
||||
params=params,
|
||||
timeout=10,
|
||||
)
|
||||
data = resp.json()
|
||||
except Exception as e:
|
||||
return jsonify(
|
||||
{"error": "Service prédictions indisponible", "detail": str(e)}
|
||||
), 503
|
||||
|
||||
# Plan filtering
|
||||
if plan == "free":
|
||||
# Top 3 only
|
||||
if isinstance(data, list):
|
||||
data = [
|
||||
{k: v for k, v in p.items() if k not in ("score_detaille",)}
|
||||
for p in data[:3]
|
||||
]
|
||||
return jsonify({"plan": plan, "predictions": data, "limit": "Top 3"}), 200
|
||||
|
||||
elif plan == "premium":
|
||||
# All courses, but no CSV export
|
||||
return jsonify(
|
||||
{"plan": plan, "predictions": data, "telegram_alerts": True}
|
||||
), 200
|
||||
|
||||
else: # pro
|
||||
return jsonify(
|
||||
{
|
||||
"plan": plan,
|
||||
"predictions": data,
|
||||
"telegram_alerts": True,
|
||||
"csv_export_url": "/api/v1/predictions/export",
|
||||
}
|
||||
), 200
|
||||
|
||||
@app.route("/api/v1/predictions/export", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
@plan_required("pro")
|
||||
def predictions_export():
|
||||
"""CSV export — pro plan only."""
|
||||
import requests as req
|
||||
import io
|
||||
|
||||
try:
|
||||
resp = req.get(
|
||||
"http://localhost:8790/api/predictions/export",
|
||||
params=dict(request.args),
|
||||
timeout=15,
|
||||
)
|
||||
from flask import Response
|
||||
|
||||
return Response(
|
||||
resp.content,
|
||||
mimetype="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=predictions.csv"},
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonify({"error": "Export indisponible", "detail": str(e)}), 503
|
||||
|
||||
@app.route("/api/v1/subscription/upgrade", methods=["GET"])
|
||||
@jwt_required_middleware
|
||||
def subscription_info():
|
||||
"""Return available plans and current user plan."""
|
||||
user = g.current_user
|
||||
return jsonify(
|
||||
{
|
||||
"current_plan": user["plan"],
|
||||
"plans": {
|
||||
"free": {
|
||||
"price": "0€/mois",
|
||||
"features": ["Top 3 prédictions", "1 course/jour"],
|
||||
},
|
||||
"premium": {
|
||||
"price": "9.99€/mois",
|
||||
"features": [
|
||||
"Toutes les courses",
|
||||
"Alertes Telegram",
|
||||
"Historique 30j",
|
||||
],
|
||||
},
|
||||
"pro": {
|
||||
"price": "29.99€/mois",
|
||||
"features": [
|
||||
"API complète",
|
||||
"Export CSV",
|
||||
"Alertes Telegram",
|
||||
"Historique illimité",
|
||||
"Support prioritaire",
|
||||
],
|
||||
},
|
||||
},
|
||||
"upgrade_contact": "contact@h3r7.tech",
|
||||
}
|
||||
), 200
|
||||
|
||||
# ── Health check ──────────────────────────────────────────
|
||||
|
||||
@app.route("/api/v1/health", methods=["GET"])
|
||||
def health():
|
||||
return jsonify(
|
||||
{"status": "ok", "service": "turf-saas-api", "version": "2.3.0"}
|
||||
), 200
|
||||
|
||||
# Init DB tables on startup
|
||||
with app.app_context():
|
||||
init_auth_tables()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Entrypoint
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
port = int(os.environ.get("SAAS_API_PORT", 8792))
|
||||
app.run(host="0.0.0.0", port=port, debug=False)
|
||||
@@ -1,448 +0,0 @@
|
||||
"""
|
||||
Beta Monitoring — SaaS Turf Prédictions IA
|
||||
Sprint 8 — QA, Beta Fermee, Go/No-Go
|
||||
Ticket: HRT-34
|
||||
|
||||
Ce module :
|
||||
- Collecte les feedbacks beta via l'API in-app
|
||||
- Envoie des alertes Telegram en cas d'erreur détectée pendant la beta
|
||||
- Génère le rapport beta final (bugs, UX, NPS)
|
||||
|
||||
Usage :
|
||||
# Démarrer le monitoring beta
|
||||
python tests/beta_monitor.py --watch --interval 60
|
||||
|
||||
# Générer le rapport beta final
|
||||
python tests/beta_monitor.py --report
|
||||
|
||||
# Test d'envoi Telegram
|
||||
python tests/beta_monitor.py --test-telegram
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import sqlite3
|
||||
import requests
|
||||
import argparse
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
BASE_URL = os.environ.get("APP_URL", "http://localhost:8792")
|
||||
TELEGRAM_TOKEN = os.environ.get(
|
||||
"TELEGRAM_TOKEN", "8649773134:AAFqzZVtSHfPPFDadcte1B-1h23nZ8DmdYE"
|
||||
)
|
||||
TELEGRAM_CHAT_ID = os.environ.get("TELEGRAM_CHAT_ID", "") # À configurer
|
||||
|
||||
BETA_DB_PATH = os.environ.get("BETA_DB_PATH", "/home/h3r7/turf_saas/turf_saas.db")
|
||||
REPORTS_DIR = Path("tests/reports")
|
||||
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Seuils d'alerte
|
||||
ERROR_RATE_THRESHOLD = 0.01 # 1% d'erreurs → alerte
|
||||
LATENCY_P95_THRESHOLD_MS = 500 # p95 > 500ms → alerte
|
||||
BETA_MIN_USERS = 10 # Minimum d'utilisateurs beta requis
|
||||
NPS_TARGET = 7.0 # NPS cible (sur 10)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Alertes Telegram
|
||||
# ============================================================
|
||||
|
||||
|
||||
def send_telegram(message: str, parse_mode: str = "Markdown") -> bool:
|
||||
"""Envoie un message Telegram d'alerte."""
|
||||
if not TELEGRAM_TOKEN or not TELEGRAM_CHAT_ID:
|
||||
print(f"⚠️ Telegram non configuré. Message: {message[:100]}")
|
||||
return False
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendMessage",
|
||||
json={
|
||||
"chat_id": TELEGRAM_CHAT_ID,
|
||||
"text": message,
|
||||
"parse_mode": parse_mode,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
print(f"✅ Alerte Telegram envoyée")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ Telegram erreur: {resp.status_code} — {resp.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Telegram exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def alert_error(endpoint: str, status_code: int, message: str):
|
||||
"""Alerte Telegram sur erreur critique."""
|
||||
text = (
|
||||
f"🚨 *ALERTE BETA — SaaS Turf IA*\n\n"
|
||||
f"Erreur détectée sur `{endpoint}`\n"
|
||||
f"Status: `{status_code}`\n"
|
||||
f"Message: {message[:200]}\n"
|
||||
f"Heure: {datetime.now().strftime('%H:%M:%S')}\n\n"
|
||||
f"_Ticket: HRT-34_"
|
||||
)
|
||||
send_telegram(text)
|
||||
|
||||
|
||||
def alert_performance(p95_ms: float, error_rate: float):
|
||||
"""Alerte Telegram sur dégradation de performance."""
|
||||
text = (
|
||||
f"⚠️ *ALERTE PERFORMANCE — SaaS Turf IA*\n\n"
|
||||
f"p95 latence: `{p95_ms:.0f}ms` (seuil: {LATENCY_P95_THRESHOLD_MS}ms)\n"
|
||||
f"Error rate: `{error_rate * 100:.2f}%` (seuil: {ERROR_RATE_THRESHOLD * 100:.1f}%)\n"
|
||||
f"Heure: {datetime.now().strftime('%H:%M:%S')}\n\n"
|
||||
f"_Ticket: HRT-34_"
|
||||
)
|
||||
send_telegram(text)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Collecte de métriques
|
||||
# ============================================================
|
||||
|
||||
|
||||
class BetaMonitor:
|
||||
"""Moniteur actif pendant la beta fermée."""
|
||||
|
||||
ENDPOINTS_TO_CHECK = [
|
||||
"/api",
|
||||
"/api/races",
|
||||
"/api/scoring",
|
||||
"/dashboard",
|
||||
"/",
|
||||
]
|
||||
|
||||
def __init__(self, base_url: str = BASE_URL):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.errors: list[dict] = []
|
||||
self.latencies: list[float] = []
|
||||
self.check_count = 0
|
||||
|
||||
def check_endpoint(self, path: str) -> dict:
|
||||
"""Vérifie un endpoint et retourne le résultat."""
|
||||
start = time.time()
|
||||
try:
|
||||
resp = requests.get(f"{self.base_url}{path}", timeout=10)
|
||||
latency_ms = (time.time() - start) * 1000
|
||||
return {
|
||||
"path": path,
|
||||
"status": resp.status_code,
|
||||
"latency_ms": latency_ms,
|
||||
"ok": resp.status_code < 500,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
return {
|
||||
"path": path,
|
||||
"status": 0,
|
||||
"latency_ms": 0,
|
||||
"ok": False,
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"path": path,
|
||||
"status": 0,
|
||||
"latency_ms": 0,
|
||||
"ok": False,
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
def run_checks(self) -> dict:
|
||||
"""Exécute tous les checks et retourne un résumé."""
|
||||
results = [self.check_endpoint(p) for p in self.ENDPOINTS_TO_CHECK]
|
||||
self.check_count += 1
|
||||
|
||||
failures = [r for r in results if not r["ok"]]
|
||||
latencies = [r["latency_ms"] for r in results if r["latency_ms"] > 0]
|
||||
|
||||
p95 = (
|
||||
sorted(latencies)[int(len(latencies) * 0.95)]
|
||||
if len(latencies) >= 2
|
||||
else (latencies[0] if latencies else 0)
|
||||
)
|
||||
error_rate = len(failures) / len(results) if results else 0
|
||||
|
||||
# Stocker pour rapport
|
||||
self.latencies.extend(latencies)
|
||||
self.errors.extend(failures)
|
||||
|
||||
return {
|
||||
"check_number": self.check_count,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_checks": len(results),
|
||||
"failures": len(failures),
|
||||
"error_rate": error_rate,
|
||||
"p95_ms": p95,
|
||||
"results": results,
|
||||
}
|
||||
|
||||
def watch(self, interval_seconds: int = 60):
|
||||
"""Surveillance continue avec alertes Telegram."""
|
||||
print(f"🔍 Beta monitoring démarré — {self.base_url}")
|
||||
print(f" Intervalle: {interval_seconds}s")
|
||||
print(f" Endpoints: {len(self.ENDPOINTS_TO_CHECK)}")
|
||||
print(f" Ctrl+C pour arrêter\n")
|
||||
|
||||
consecutive_errors = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
summary = self.run_checks()
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
status_icon = "✅" if summary["error_rate"] == 0 else "❌"
|
||||
print(
|
||||
f"[{timestamp}] {status_icon} "
|
||||
f"Check #{summary['check_number']} — "
|
||||
f"p95={summary['p95_ms']:.0f}ms, "
|
||||
f"errors={summary['failures']}/{summary['total_checks']}"
|
||||
)
|
||||
|
||||
# Alertes
|
||||
if summary["error_rate"] > ERROR_RATE_THRESHOLD:
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= 2: # 2 checks consécutifs en erreur
|
||||
for failure in summary["results"]:
|
||||
if not failure["ok"]:
|
||||
alert_error(
|
||||
failure["path"],
|
||||
failure.get("status", 0),
|
||||
failure.get("error", "Non-2xx response"),
|
||||
)
|
||||
else:
|
||||
consecutive_errors = 0
|
||||
|
||||
if summary["p95_ms"] > LATENCY_P95_THRESHOLD_MS:
|
||||
print(f"⚠️ Latence p95 élevée: {summary['p95_ms']:.0f}ms")
|
||||
if summary["p95_ms"] > LATENCY_P95_THRESHOLD_MS * 2:
|
||||
alert_performance(summary["p95_ms"], summary["error_rate"])
|
||||
|
||||
# Sauvegarder les résultats
|
||||
log_file = REPORTS_DIR / "beta_monitor_log.jsonl"
|
||||
with open(log_file, "a") as f:
|
||||
f.write(json.dumps(summary) + "\n")
|
||||
|
||||
time.sleep(interval_seconds)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n⏹️ Monitoring arrêté après {self.check_count} checks")
|
||||
self.generate_report()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Rapport beta final
|
||||
# ============================================================
|
||||
|
||||
|
||||
class BetaReport:
|
||||
"""Générateur de rapport beta fermée."""
|
||||
|
||||
def __init__(self, base_url: str = BASE_URL):
|
||||
self.base_url = base_url
|
||||
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
def collect_feedback_from_db(self) -> list[dict]:
|
||||
"""Collecte les feedbacks depuis la BDD (table beta_feedback si elle existe)."""
|
||||
try:
|
||||
conn = sqlite3.connect(BETA_DB_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='beta_feedback'"
|
||||
)
|
||||
if not c.fetchone():
|
||||
conn.close()
|
||||
return []
|
||||
c.execute("SELECT * FROM beta_feedback ORDER BY created_at DESC")
|
||||
rows = c.fetchall()
|
||||
conn.close()
|
||||
return [dict(zip([col[0] for col in c.description], row)) for row in rows]
|
||||
except Exception as e:
|
||||
print(f"⚠️ Impossible de lire beta_feedback: {e}")
|
||||
return []
|
||||
|
||||
def collect_monitor_logs(self) -> list[dict]:
|
||||
"""Lit les logs du monitoring beta."""
|
||||
log_file = REPORTS_DIR / "beta_monitor_log.jsonl"
|
||||
if not log_file.exists():
|
||||
return []
|
||||
entries = []
|
||||
with open(log_file) as f:
|
||||
for line in f:
|
||||
try:
|
||||
entries.append(json.loads(line))
|
||||
except Exception:
|
||||
pass
|
||||
return entries
|
||||
|
||||
def generate(self) -> str:
|
||||
"""Génère le rapport complet et le sauvegarde."""
|
||||
feedbacks = self.collect_feedback_from_db()
|
||||
monitor_logs = self.collect_monitor_logs()
|
||||
|
||||
# Calculer NPS depuis les feedbacks
|
||||
nps_scores = [
|
||||
f.get("nps_score") for f in feedbacks if f.get("nps_score") is not None
|
||||
]
|
||||
avg_nps = sum(nps_scores) / len(nps_scores) if nps_scores else None
|
||||
|
||||
# Statistiques monitoring
|
||||
if monitor_logs:
|
||||
all_latencies = []
|
||||
total_errors = 0
|
||||
total_checks = 0
|
||||
for entry in monitor_logs:
|
||||
all_latencies.extend(
|
||||
[
|
||||
r["latency_ms"]
|
||||
for r in entry.get("results", [])
|
||||
if r.get("latency_ms", 0) > 0
|
||||
]
|
||||
)
|
||||
total_errors += entry.get("failures", 0)
|
||||
total_checks += entry.get("total_checks", 0)
|
||||
avg_latency = (
|
||||
sum(all_latencies) / len(all_latencies) if all_latencies else 0
|
||||
)
|
||||
overall_error_rate = total_errors / total_checks if total_checks > 0 else 0
|
||||
else:
|
||||
avg_latency = 0
|
||||
overall_error_rate = 0
|
||||
total_checks = 0
|
||||
|
||||
# Construire le rapport
|
||||
report = []
|
||||
report.append("=" * 60)
|
||||
report.append("RAPPORT BETA FERMÉE — SaaS Turf Prédictions IA")
|
||||
report.append(f"Généré le : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report.append(f"Ticket : HRT-34")
|
||||
report.append("=" * 60)
|
||||
report.append("")
|
||||
report.append("## 1. PARTICIPANTS BETA")
|
||||
report.append(f" Feedbacks reçus : {len(feedbacks)}")
|
||||
report.append(
|
||||
f" NPS moyen : {avg_nps:.1f}/10"
|
||||
if avg_nps
|
||||
else " NPS moyen : (en attente feedbacks)"
|
||||
)
|
||||
report.append(f" Cible NPS : ≥ {NPS_TARGET}/10")
|
||||
nps_ok = avg_nps is not None and avg_nps >= NPS_TARGET
|
||||
report.append(
|
||||
f" Statut NPS : {'✅ OBJECTIF ATTEINT' if nps_ok else '⏳ En attente' if avg_nps is None else '❌ OBJECTIF NON ATTEINT'}"
|
||||
)
|
||||
report.append("")
|
||||
report.append("## 2. BUGS SIGNALÉS")
|
||||
bugs = [f for f in feedbacks if f.get("type") == "bug"]
|
||||
critical_bugs = [b for b in bugs if b.get("severity") in ("critical", "high")]
|
||||
report.append(f" Total bugs : {len(bugs)}")
|
||||
report.append(f" Critiques/High : {len(critical_bugs)}")
|
||||
report.append(
|
||||
f" Statut : {'✅ 0 bug critique' if len(critical_bugs) == 0 else f'❌ {len(critical_bugs)} bug(s) critique(s)'}"
|
||||
)
|
||||
report.append("")
|
||||
report.append("## 3. PERFORMANCE RÉELLE (monitoring)")
|
||||
report.append(f" Checks effectués: {total_checks}")
|
||||
report.append(f" Latence moyenne : {avg_latency:.1f}ms")
|
||||
report.append(f" Error rate : {overall_error_rate * 100:.2f}%")
|
||||
report.append(f" Seuil latence : {LATENCY_P95_THRESHOLD_MS}ms")
|
||||
perf_ok = (
|
||||
avg_latency < LATENCY_P95_THRESHOLD_MS
|
||||
and overall_error_rate < ERROR_RATE_THRESHOLD
|
||||
)
|
||||
report.append(
|
||||
f" Statut : {'✅ OBJECTIF ATTEINT' if perf_ok else '⏳ Données insuffisantes' if total_checks == 0 else '❌ OBJECTIF NON ATTEINT'}"
|
||||
)
|
||||
report.append("")
|
||||
report.append("## 4. FEEDBACKS UX")
|
||||
ux_feedbacks = [f for f in feedbacks if f.get("type") == "ux"]
|
||||
report.append(f" Retours UX : {len(ux_feedbacks)}")
|
||||
if ux_feedbacks:
|
||||
for fb in ux_feedbacks[:5]: # Top 5
|
||||
report.append(f" - {fb.get('comment', '')[:100]}")
|
||||
report.append("")
|
||||
report.append("## 5. VERDICT BETA FERMÉE")
|
||||
users_ok = len(feedbacks) >= 5 # Au moins 5 feedbacks = 5 users satisfaits
|
||||
verdict = all([users_ok, nps_ok, len(critical_bugs) == 0])
|
||||
report.append(
|
||||
f" Participants suffisants (≥5) : {'✅' if users_ok else '❌'}"
|
||||
)
|
||||
report.append(f" NPS ≥ 7/10 : {'✅' if nps_ok else '❌'}")
|
||||
report.append(
|
||||
f" 0 bug critique : {'✅' if len(critical_bugs) == 0 else '❌'}"
|
||||
)
|
||||
report.append("")
|
||||
report.append(
|
||||
f" VERDICT GLOBAL : {'✅ GO — Beta réussie' if verdict else '❌ NO-GO — Conditions non remplies'}"
|
||||
)
|
||||
report.append("=" * 60)
|
||||
|
||||
report_text = "\n".join(report)
|
||||
|
||||
# Sauvegarder
|
||||
report_file = REPORTS_DIR / f"beta_report_{self.timestamp}.txt"
|
||||
with open(report_file, "w") as f:
|
||||
f.write(report_text)
|
||||
|
||||
print(report_text)
|
||||
print(f"\nRapport sauvegardé : {report_file}")
|
||||
|
||||
return report_text
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI
|
||||
# ============================================================
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Beta Monitor — SaaS Turf IA")
|
||||
parser.add_argument("--watch", action="store_true", help="Surveillance continue")
|
||||
parser.add_argument(
|
||||
"--interval", type=int, default=60, help="Intervalle en secondes (défaut: 60)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report", action="store_true", help="Générer le rapport beta final"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-telegram", action="store_true", help="Tester l'envoi Telegram"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--url", default=BASE_URL, help=f"URL de l'app (défaut: {BASE_URL})"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.test_telegram:
|
||||
print("Test d'envoi Telegram...")
|
||||
ok = send_telegram(
|
||||
"✅ *Test alerte Beta* — SaaS Turf IA\n_Ceci est un test du système d'alertes QA_\nTicket: HRT-34"
|
||||
)
|
||||
sys.exit(0 if ok else 1)
|
||||
|
||||
if args.report:
|
||||
reporter = BetaReport(args.url)
|
||||
reporter.generate()
|
||||
sys.exit(0)
|
||||
|
||||
if args.watch:
|
||||
monitor = BetaMonitor(args.url)
|
||||
monitor.watch(interval_seconds=args.interval)
|
||||
sys.exit(0)
|
||||
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,124 +0,0 @@
|
||||
"""
|
||||
conftest.py — Configuration pytest globale
|
||||
SaaS Turf Prédictions IA — Sprint 8 QA
|
||||
Ticket: HRT-34
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# ============================================================
|
||||
# Répertoires de sortie
|
||||
# ============================================================
|
||||
|
||||
REPORTS_DIR = Path("tests/reports")
|
||||
SCREENSHOTS_DIR = Path("tests/screenshots")
|
||||
|
||||
for d in [REPORTS_DIR, SCREENSHOTS_DIR]:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Variables d'environnement
|
||||
# ============================================================
|
||||
|
||||
BASE_URL = os.environ.get("APP_URL", "http://localhost:8792")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Fixtures globales
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def base_url():
|
||||
return BASE_URL
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Event loop partagé pour les tests async de la session."""
|
||||
policy = asyncio.get_event_loop_policy()
|
||||
loop = policy.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def reports_dir():
|
||||
return REPORTS_DIR
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def screenshots_dir():
|
||||
return SCREENSHOTS_DIR
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Hook : screenshot automatique sur échec
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||
def pytest_runtest_makereport(item, call):
|
||||
"""Capture screenshot automatiquement sur tout test E2E en échec."""
|
||||
outcome = yield
|
||||
report = outcome.get_result()
|
||||
|
||||
if report.when == "call" and report.failed:
|
||||
# Récupérer la page Playwright si disponible dans les fixtures
|
||||
page = None
|
||||
for fixture_name in ("page", "context_page"):
|
||||
if fixture_name in item.funcargs:
|
||||
val = item.funcargs[fixture_name]
|
||||
if isinstance(val, tuple):
|
||||
page = val[0] # (page, browser_name)
|
||||
else:
|
||||
page = val
|
||||
break
|
||||
|
||||
if page is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
test_name = item.name.replace("/", "_").replace(":", "_")
|
||||
screenshot_path = SCREENSHOTS_DIR / f"FAIL_{test_name}_{timestamp}.png"
|
||||
try:
|
||||
# Playwright page.screenshot est synchrone dans les fixtures sync
|
||||
# Pour les fixtures async, on force la capture
|
||||
import asyncio as _asyncio
|
||||
|
||||
if _asyncio.iscoroutinefunction(page.screenshot):
|
||||
loop = _asyncio.get_event_loop()
|
||||
loop.run_until_complete(page.screenshot(path=str(screenshot_path)))
|
||||
else:
|
||||
page.screenshot(path=str(screenshot_path))
|
||||
report.sections.append(
|
||||
("Screenshot", f"Sauvegardé : {screenshot_path}")
|
||||
)
|
||||
except Exception as e:
|
||||
report.sections.append(
|
||||
("Screenshot Error", f"Impossible de capturer : {e}")
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Marqueurs personnalisés
|
||||
# ============================================================
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "e2e: Tests End-to-End Playwright")
|
||||
config.addinivalue_line("markers", "load: Tests de charge Locust")
|
||||
config.addinivalue_line("markers", "security: Tests de sécurité")
|
||||
config.addinivalue_line(
|
||||
"markers", "smoke: Tests rapides de smoke (sans infra complète)"
|
||||
)
|
||||
config.addinivalue_line("markers", "beta: Tests spécifiques beta fermée")
|
||||
config.addinivalue_line(
|
||||
"markers", "requires_billing: Nécessite HRT-31 (Billing Stripe)"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "requires_infra: Nécessite HRT-33 (infra staging)"
|
||||
)
|
||||
473
tests/test_api_v1.py
Normal file
473
tests/test_api_v1.py
Normal file
@@ -0,0 +1,473 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for API v1 — HRT-29
|
||||
Sprint 3-4: Refacto API /v1/
|
||||
|
||||
Run with:
|
||||
cd /home/h3r7/turf_saas
|
||||
source venv/bin/activate
|
||||
python -m pytest tests/test_api_v1.py -v
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
# Ensure local modules are importable
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Use a temp file DB for tests (in-memory fails with multiple connections)
|
||||
_tmp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
_tmp_db.close()
|
||||
os.environ["TURF_SAAS_DB"] = _tmp_db.name
|
||||
os.environ["JWT_SECRET_KEY"] = "test-secret-key"
|
||||
|
||||
from app_v1 import create_app
|
||||
from auth_db import init_auth_tables
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Fixtures
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app():
|
||||
application = create_app()
|
||||
application.config["TESTING"] = True
|
||||
application.config["JWT_SECRET_KEY"] = "test-secret-key"
|
||||
yield application
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(app):
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def auth_tokens(client):
|
||||
"""Register a user and return tokens for each plan."""
|
||||
tokens = {}
|
||||
plans = {
|
||||
"free": ("free@test.com", "password123"),
|
||||
"premium": ("premium@test.com", "password123"),
|
||||
"pro": ("pro@test.com", "password123"),
|
||||
}
|
||||
|
||||
# Register users
|
||||
for plan, (email, pw) in plans.items():
|
||||
r = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": email, "password": pw},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert r.status_code in (201, 409), f"register failed for {plan}: {r.data}"
|
||||
|
||||
# Manually set plans in DB using direct sqlite (bypass app context issues)
|
||||
import sqlite3
|
||||
|
||||
db_path = os.environ.get("TURF_SAAS_DB", "/tmp/test_turf.db")
|
||||
conn = sqlite3.connect(db_path)
|
||||
for plan, (email, _) in plans.items():
|
||||
conn.execute("UPDATE users SET plan = ? WHERE email = ?", (plan, email))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Login and collect tokens
|
||||
for plan, (email, pw) in plans.items():
|
||||
r = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": pw},
|
||||
content_type="application/json",
|
||||
)
|
||||
assert r.status_code == 200, f"login failed for {plan}: {r.data}"
|
||||
data = r.get_json()
|
||||
tokens[plan] = data["access_token"]
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def auth_header(token: str) -> dict:
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Health
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHealth:
|
||||
def test_health_public(self, client):
|
||||
"""GET /api/v1/health — no auth required"""
|
||||
r = client.get("/api/v1/health")
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["version"] == "1.0"
|
||||
assert "timestamp" in data
|
||||
|
||||
def test_health_returns_json(self, client):
|
||||
r = client.get("/api/v1/health")
|
||||
assert r.content_type.startswith("application/json")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Auth
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAuth:
|
||||
def test_register_new_user(self, client):
|
||||
r = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "new_test@example.com", "password": "strongpass123"},
|
||||
)
|
||||
assert r.status_code in (201, 409)
|
||||
|
||||
def test_register_short_password(self, client):
|
||||
r = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "bad@example.com", "password": "123"},
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
def test_register_invalid_email(self, client):
|
||||
r = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "notemail", "password": "password123"},
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
def test_login_valid(self, client, auth_tokens):
|
||||
assert "free" in auth_tokens
|
||||
|
||||
def test_login_wrong_password(self, client):
|
||||
r = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "free@test.com", "password": "wrongpassword"},
|
||||
)
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_protected_without_token(self, client):
|
||||
r = client.get("/api/v1/courses/today")
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Courses
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCourses:
|
||||
def test_today_requires_auth(self, client):
|
||||
r = client.get("/api/v1/courses/today")
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_today_with_auth(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/courses/today",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert "courses" in data
|
||||
assert "pagination" in data
|
||||
assert "date" in data
|
||||
|
||||
def test_today_pagination(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/courses/today?limit=5&offset=0",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["pagination"]["limit"] == 5
|
||||
assert data["pagination"]["offset"] == 0
|
||||
|
||||
def test_today_filter_all(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/courses/today?filter=all",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_course_predictions_requires_auth(self, client):
|
||||
r = client.get("/api/v1/courses/1-1/predictions")
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_course_predictions_invalid_id(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/courses/invalid/predictions",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
def test_course_predictions_not_found(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/courses/99-99/predictions",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
# 404 expected since DB is empty; 429 if free daily limit already reached in this session
|
||||
assert r.status_code in (404, 200, 429) # 200 if gracefully returns empty
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Predictions
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPredictions:
|
||||
def test_top3_requires_auth(self, client):
|
||||
r = client.get("/api/v1/predictions/top3")
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_top3_free_allowed(self, client, auth_tokens):
|
||||
# Reset daily usage for free user before testing rate-limited endpoint
|
||||
import sqlite3
|
||||
|
||||
db_path = os.environ.get("TURF_SAAS_DB", "/tmp/test_turf.db")
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute(
|
||||
"UPDATE users SET daily_usage=0, last_usage_date=NULL WHERE email='free@test.com'"
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
r = client.get(
|
||||
"/api/v1/predictions/top3",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert "top3" in data
|
||||
|
||||
def test_all_requires_premium(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/predictions/all",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
def test_all_premium_allowed(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/predictions/all",
|
||||
headers=auth_header(auth_tokens["premium"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert "predictions" in data
|
||||
assert "pagination" in data
|
||||
|
||||
def test_all_pro_allowed(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/predictions/all",
|
||||
headers=auth_header(auth_tokens["pro"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Value Bets
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestValueBets:
|
||||
def test_requires_auth(self, client):
|
||||
r = client.get("/api/v1/valuebets")
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_free_forbidden(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/valuebets",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
def test_premium_allowed(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/valuebets",
|
||||
headers=auth_header(auth_tokens["premium"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert "valuebets" in data
|
||||
assert "pagination" in data
|
||||
|
||||
def test_min_odds_filter(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/valuebets?min_odds=3.0",
|
||||
headers=auth_header(auth_tokens["premium"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["min_odds"] == 3.0
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Backtest
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBacktest:
|
||||
def test_requires_auth(self, client):
|
||||
r = client.get("/api/v1/backtest")
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_premium_forbidden(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/backtest",
|
||||
headers=auth_header(auth_tokens["premium"]),
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
def test_pro_allowed(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/backtest",
|
||||
headers=auth_header(auth_tokens["pro"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert "summary" in data
|
||||
assert "period" in data
|
||||
|
||||
def test_invalid_date_format(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/backtest?start=31-12-2025",
|
||||
headers=auth_header(auth_tokens["pro"]),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Export
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestExport:
|
||||
def test_requires_auth(self, client):
|
||||
r = client.get("/api/v1/export/csv")
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_free_forbidden(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/export/csv",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
def test_premium_forbidden(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/export/csv",
|
||||
headers=auth_header(auth_tokens["premium"]),
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
def test_pro_allowed_predictions(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/export/csv?type=predictions",
|
||||
headers=auth_header(auth_tokens["pro"]),
|
||||
)
|
||||
# 200 (CSV) or 400 if table doesn't exist in test DB
|
||||
assert r.status_code in (200, 400)
|
||||
if r.status_code == 200:
|
||||
assert "text/csv" in r.content_type
|
||||
|
||||
def test_invalid_type(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/export/csv?type=invalid",
|
||||
headers=auth_header(auth_tokens["pro"]),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Metrics
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMetrics:
|
||||
def test_requires_auth(self, client):
|
||||
r = client.get("/api/v1/metrics")
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_free_forbidden(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/metrics",
|
||||
headers=auth_header(auth_tokens["free"]),
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
def test_premium_allowed(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/metrics",
|
||||
headers=auth_header(auth_tokens["premium"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert "bet_metrics" in data
|
||||
assert "ml_metrics" in data
|
||||
assert "period" in data
|
||||
|
||||
def test_days_parameter(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/metrics?days=7",
|
||||
headers=auth_header(auth_tokens["premium"]),
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.get_json()
|
||||
assert data["period"]["days"] == 7
|
||||
|
||||
def test_invalid_days(self, client, auth_tokens):
|
||||
r = client.get(
|
||||
"/api/v1/metrics?days=abc",
|
||||
headers=auth_header(auth_tokens["premium"]),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Global error handlers
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestErrorHandlers:
|
||||
def test_404_returns_json(self, client):
|
||||
r = client.get("/api/v1/this-does-not-exist")
|
||||
assert r.status_code == 404
|
||||
data = r.get_json()
|
||||
assert data["code"] == 404
|
||||
|
||||
def test_uniform_error_shape(self, client):
|
||||
"""All error responses must have {status, message, code}."""
|
||||
r = client.get("/api/v1/this-does-not-exist")
|
||||
data = r.get_json()
|
||||
assert "status" in data
|
||||
assert "message" in data
|
||||
assert "code" in data
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Swagger docs
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDocs:
|
||||
def test_docs_accessible(self, client):
|
||||
r = client.get("/api/v1/docs")
|
||||
# flasgger returns a redirect or the UI page
|
||||
assert r.status_code in (200, 301, 302)
|
||||
|
||||
def test_apispec_json(self, client):
|
||||
r = client.get("/api/v1/apispec.json")
|
||||
assert r.status_code == 200
|
||||
spec = r.get_json()
|
||||
assert spec["swagger"] == "2.0"
|
||||
assert "paths" in spec
|
||||
404
tests/test_auth.py
Normal file
404
tests/test_auth.py
Normal file
@@ -0,0 +1,404 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pytest tests — Auth JWT + Multi-tenant
|
||||
Sprint 2-3: HRT-28
|
||||
Coverage target: >= 80%
|
||||
|
||||
Run:
|
||||
./venv/bin/pytest tests/test_auth.py -v --tb=short
|
||||
./venv/bin/pytest tests/test_auth.py -v --cov=auth --cov=auth_db --cov=middleware --cov=saas_api --cov-report=term-missing
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import json
|
||||
import pytest
|
||||
|
||||
# Point to a temp SQLite DB for tests
|
||||
_tmp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
_tmp_db.close()
|
||||
os.environ["TURF_SAAS_DB"] = _tmp_db.name
|
||||
os.environ["JWT_SECRET_KEY"] = "test-secret-key-for-pytest"
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from saas_api import create_app # noqa: E402
|
||||
|
||||
TEST_CONFIG = {
|
||||
"TESTING": True,
|
||||
"JWT_SECRET_KEY": "test-secret-key-for-pytest",
|
||||
"JWT_ACCESS_TOKEN_EXPIRES": 900,
|
||||
"JWT_REFRESH_TOKEN_EXPIRES": 2592000,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app():
|
||||
application = create_app(TEST_CONFIG)
|
||||
yield application
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(app):
|
||||
return app.test_client()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Health
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHealth:
|
||||
def test_health_ok(self, client):
|
||||
resp = client.get("/api/v1/health")
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["service"] == "turf-saas-api"
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Registration
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRegister:
|
||||
def test_register_success(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "user_test@example.com", "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.get_json()
|
||||
assert "user_id" in data
|
||||
|
||||
def test_register_duplicate(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "dup@example.com", "password": "password123"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "dup@example.com", "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_register_invalid_email(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "notanemail", "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_register_short_password(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "shortpw@example.com", "password": "abc"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_register_missing_fields(self, client):
|
||||
resp = client.post("/api/v1/auth/register", json={})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Login
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLogin:
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_user(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "login@example.com", "password": "loginpass1"},
|
||||
)
|
||||
|
||||
def test_login_success(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "login@example.com", "password": "loginpass1"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["plan"] == "free"
|
||||
|
||||
def test_login_wrong_password(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "login@example.com", "password": "wrongpass"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_login_unknown_email(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "ghost@example.com", "password": "anypass"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_login_missing_fields(self, client):
|
||||
resp = client.post("/api/v1/auth/login", json={"email": "login@example.com"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Token refresh
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRefresh:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "refresh@example.com", "password": "refreshpass1"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "refresh@example.com", "password": "refreshpass1"},
|
||||
)
|
||||
tokens = resp.get_json()
|
||||
self.refresh_token = tokens["refresh_token"]
|
||||
|
||||
def test_refresh_success(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
# New refresh token should differ from old
|
||||
assert data["refresh_token"] != self.refresh_token
|
||||
|
||||
def test_refresh_token_rotation(self, client):
|
||||
"""Old refresh token must be invalid after rotation."""
|
||||
client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
resp2 = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
assert resp2.status_code == 401
|
||||
|
||||
def test_refresh_invalid_token(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "completely.invalid.token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_refresh_missing_token(self, client):
|
||||
resp = client.post("/api/v1/auth/refresh", json={})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Logout
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLogout:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "logout@example.com", "password": "logoutpass1"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "logout@example.com", "password": "logoutpass1"},
|
||||
)
|
||||
tokens = resp.get_json()
|
||||
self.refresh_token = tokens["refresh_token"]
|
||||
self.access_token = tokens["access_token"]
|
||||
|
||||
def test_logout_success(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_refresh_after_logout_fails(self, client):
|
||||
client.post("/api/v1/auth/logout", json={"refresh_token": self.refresh_token})
|
||||
resp = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_logout_no_token(self, client):
|
||||
resp = client.post("/api/v1/auth/logout", json={})
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# JWT middleware — protected routes
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestJWTMiddleware:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "protected@example.com", "password": "protect123"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "protected@example.com", "password": "protect123"},
|
||||
)
|
||||
self.access_token = resp.get_json()["access_token"]
|
||||
|
||||
def test_subscription_info_requires_auth(self, client):
|
||||
resp = client.get("/api/v1/subscription/upgrade")
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_subscription_info_with_token(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/subscription/upgrade",
|
||||
headers={"Authorization": f"Bearer {self.access_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert "current_plan" in data
|
||||
assert data["current_plan"] == "free"
|
||||
|
||||
def test_invalid_token_rejected(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/subscription/upgrade",
|
||||
headers={"Authorization": "Bearer invalid.token.here"},
|
||||
)
|
||||
assert resp.status_code in (401, 422)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Plan checks
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlanMiddleware:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, client, app):
|
||||
# Register free user
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "free_plan@example.com", "password": "freepass1"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "free_plan@example.com", "password": "freepass1"},
|
||||
)
|
||||
self.free_token = resp.get_json()["access_token"]
|
||||
|
||||
# Upgrade user to pro directly in DB for testing
|
||||
import sqlite3
|
||||
|
||||
db_path = os.environ["TURF_SAAS_DB"]
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO users (email, password_hash, plan) VALUES (?,?,?)",
|
||||
("pro_plan@example.com", "$2b$12$placeholder", "pro"),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Login pro user using JWT created manually via app context
|
||||
with app.app_context():
|
||||
from flask_jwt_extended import create_access_token
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
row = conn.execute(
|
||||
"SELECT id FROM users WHERE email='pro_plan@example.com'"
|
||||
).fetchone()
|
||||
conn.close()
|
||||
self.pro_token = create_access_token(
|
||||
identity=str(row[0]),
|
||||
additional_claims={"plan": "pro", "email": "pro_plan@example.com"},
|
||||
)
|
||||
|
||||
def test_export_blocked_for_free(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/predictions/export",
|
||||
headers={"Authorization": f"Bearer {self.free_token}"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
data = resp.get_json()
|
||||
assert "Plan insuffisant" in data["error"]
|
||||
|
||||
def test_export_allowed_for_pro(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/predictions/export",
|
||||
headers={"Authorization": f"Bearer {self.pro_token}"},
|
||||
)
|
||||
# 503 is expected because no backend is running; 403 would be wrong
|
||||
assert resp.status_code != 403
|
||||
|
||||
def test_upgrade_info_shows_plans(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/subscription/upgrade",
|
||||
headers={"Authorization": f"Bearer {self.free_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert "free" in data["plans"]
|
||||
assert "premium" in data["plans"]
|
||||
assert "pro" in data["plans"]
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Rate limiting
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
def test_rate_limit_headers_present(self, client):
|
||||
resp = client.get("/api/v1/health")
|
||||
assert "X-RateLimit-Limit" in resp.headers
|
||||
assert resp.headers["X-RateLimit-Limit"] == "100"
|
||||
|
||||
def test_rate_limit_remaining_decreases(self, client):
|
||||
r1 = client.get("/api/v1/health")
|
||||
r2 = client.get("/api/v1/health")
|
||||
rem1 = int(r1.headers.get("X-RateLimit-Remaining", 100))
|
||||
rem2 = int(r2.headers.get("X-RateLimit-Remaining", 100))
|
||||
assert rem2 <= rem1
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# DB module
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAuthDB:
|
||||
def test_tables_exist(self):
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect(os.environ["TURF_SAAS_DB"])
|
||||
tables = {
|
||||
r[0]
|
||||
for r in conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
).fetchall()
|
||||
}
|
||||
assert "users" in tables
|
||||
assert "subscriptions" in tables
|
||||
assert "refresh_tokens" in tables
|
||||
conn.close()
|
||||
|
||||
def test_get_db_returns_connection(self):
|
||||
from auth_db import get_db
|
||||
|
||||
db = get_db()
|
||||
assert db is not None
|
||||
db.close()
|
||||
@@ -1,333 +0,0 @@
|
||||
"""
|
||||
Tests ML Ensemble — HRT-32 Sprint 6-7
|
||||
Tests de régression, benchmark et latence pour le nouveau modèle ensemble.
|
||||
|
||||
Usage:
|
||||
pytest tests/test_ml_ensemble.py -v
|
||||
pytest tests/test_ml_ensemble.py -v -m regression
|
||||
pytest tests/test_ml_ensemble.py -v -m latency
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
BASE_URL = os.environ.get("APP_URL", "http://localhost:8790")
|
||||
DB_PATH = os.environ.get("DB_PATH", "/home/h3r7/turf_saas/turf.db")
|
||||
MODELS_DIR = Path("/home/h3r7/turf_saas/models")
|
||||
ENSEMBLE_PATH = MODELS_DIR / "ensemble_top3.pkl"
|
||||
BENCHMARK_PATH = MODELS_DIR / "benchmark_report.json"
|
||||
|
||||
|
||||
# ─── Fixtures ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def ensemble_model():
|
||||
"""Load ensemble model (skip tests if not yet trained)."""
|
||||
if not ENSEMBLE_PATH.exists():
|
||||
pytest.skip(
|
||||
f"Ensemble model not found at {ENSEMBLE_PATH}. Run train_ensemble.py first."
|
||||
)
|
||||
with open(ENSEMBLE_PATH, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def benchmark_report():
|
||||
"""Load benchmark report (skip if not generated)."""
|
||||
if not BENCHMARK_PATH.exists():
|
||||
pytest.skip(f"Benchmark report not found at {BENCHMARK_PATH}.")
|
||||
with open(BENCHMARK_PATH) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def holdout_data():
|
||||
"""Load holdout slice (last 20% temporal) for regression tests."""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
df = pd.read_sql_query(
|
||||
"""
|
||||
SELECT p.*, c.distance, c.discipline, c.specialite,
|
||||
c.nb_declares_partants, c.montant_prix, c.penetrometre_intitule
|
||||
FROM pmu_partants p
|
||||
LEFT JOIN pmu_courses c ON p.date_programme=c.date_programme
|
||||
AND p.num_reunion=c.num_reunion AND p.num_course=c.num_course
|
||||
WHERE p.ordre_arrivee > 0
|
||||
ORDER BY p.date_programme, p.num_reunion, p.num_course, p.num_pmu
|
||||
""",
|
||||
conn,
|
||||
)
|
||||
conn.close()
|
||||
n = len(df)
|
||||
cutoff = int(n * 0.80)
|
||||
return df.iloc[cutoff:].copy()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def predict_v2():
|
||||
"""Import predict_v2 module."""
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"predict_v2", "/home/h3r7/turf_saas/predict_v2.py"
|
||||
)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
# ─── Model Existence Tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestModelFiles:
|
||||
"""Verify all expected model files exist."""
|
||||
|
||||
def test_ensemble_model_exists(self):
|
||||
assert ENSEMBLE_PATH.exists(), f"Ensemble model missing: {ENSEMBLE_PATH}"
|
||||
|
||||
def test_benchmark_report_exists(self):
|
||||
assert BENCHMARK_PATH.exists(), f"Benchmark report missing: {BENCHMARK_PATH}"
|
||||
|
||||
def test_models_dir_contains_expected_files(self):
|
||||
expected = ["ensemble_top3.pkl", "benchmark_report.json", "benchmark_report.md"]
|
||||
for fname in expected:
|
||||
assert (MODELS_DIR / fname).exists(), f"Missing: {MODELS_DIR / fname}"
|
||||
|
||||
|
||||
# ─── Benchmark Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBenchmark:
|
||||
"""Validate benchmark metrics from the training report."""
|
||||
|
||||
@pytest.mark.regression
|
||||
def test_ensemble_beats_baseline_or_meets_threshold(self, benchmark_report):
|
||||
"""Ensemble Precision@3 must be >= baseline XGBoost."""
|
||||
baseline = benchmark_report["baseline"]["precision_at3"]
|
||||
ensemble = benchmark_report["ensemble"]["precision_at3"]
|
||||
assert ensemble >= baseline, (
|
||||
f"Ensemble Precision@3 {ensemble:.4f} < baseline {baseline:.4f}"
|
||||
)
|
||||
|
||||
@pytest.mark.regression
|
||||
def test_ensemble_auc_above_random(self, benchmark_report):
|
||||
"""Ensemble AUC must be > 0.60 (significantly above random 0.50)."""
|
||||
auc = benchmark_report["ensemble"]["auc"]
|
||||
assert auc > 0.60, f"Ensemble AUC {auc:.4f} <= 0.60"
|
||||
|
||||
@pytest.mark.regression
|
||||
def test_optuna_ran_minimum_trials(self, benchmark_report):
|
||||
"""Optuna must have run at least 100 trials per model."""
|
||||
n_trials = benchmark_report["optuna"]["n_trials"]
|
||||
assert n_trials >= 100, f"Only {n_trials} Optuna trials (minimum 100 required)"
|
||||
|
||||
@pytest.mark.regression
|
||||
def test_no_precision_regression(self, benchmark_report):
|
||||
"""Ensemble Precision@3 must not be below naive random baseline (~30%)."""
|
||||
ensemble_p3 = benchmark_report["ensemble"]["precision_at3"]
|
||||
assert ensemble_p3 >= 0.30, (
|
||||
f"Precision@3 {ensemble_p3:.4f} is below random baseline (~0.30)"
|
||||
)
|
||||
|
||||
def test_benchmark_has_all_required_models(self, benchmark_report):
|
||||
"""Benchmark must include results for all 3 models."""
|
||||
required = {"xgboost", "lightgbm", "mlp"}
|
||||
found = set(benchmark_report.get("individual_models", {}).keys())
|
||||
missing = required - found
|
||||
assert not missing, f"Missing model benchmarks: {missing}"
|
||||
|
||||
|
||||
# ─── Regression Tests ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPrecisionRegression:
|
||||
"""Holdout regression: ensure precision doesn't degrade."""
|
||||
|
||||
@pytest.mark.regression
|
||||
def test_precision_at3_on_holdout(self, ensemble_model, holdout_data):
|
||||
"""Precision@3 on holdout must be above naive baseline."""
|
||||
from predict_v2 import build_feature_df, FEATURE_COLS
|
||||
|
||||
df = holdout_data.copy()
|
||||
df["top3"] = (df["ordre_arrivee"] <= 3).astype(int)
|
||||
|
||||
partants = df.to_dict("records")
|
||||
feature_df = build_feature_df(partants)
|
||||
available = [c for c in FEATURE_COLS if c in feature_df.columns]
|
||||
X = feature_df[available].fillna(0)
|
||||
|
||||
proba = ensemble_model.predict_proba(X)[:, 1]
|
||||
|
||||
# Per-race Precision@3
|
||||
tmp = df[["date_programme", "num_reunion", "num_course"]].copy()
|
||||
tmp["proba"] = proba
|
||||
tmp["actual"] = df["top3"].values
|
||||
|
||||
precisions = []
|
||||
for _, group in tmp.groupby(["date_programme", "num_reunion", "num_course"]):
|
||||
if len(group) >= 3:
|
||||
top3_pred = group.nlargest(3, "proba")
|
||||
precisions.append(top3_pred["actual"].sum() / 3.0)
|
||||
|
||||
p_at3 = float(np.mean(precisions)) if precisions else 0.0
|
||||
print(f"\n Holdout Precision@3: {p_at3:.4f} over {len(precisions)} races")
|
||||
|
||||
# Must beat random baseline (30%)
|
||||
assert p_at3 >= 0.30, f"Holdout Precision@3 {p_at3:.4f} < 0.30"
|
||||
|
||||
@pytest.mark.regression
|
||||
def test_no_all_zero_predictions(self, ensemble_model, holdout_data):
|
||||
"""Ensemble must not predict 0 probability for all horses."""
|
||||
from predict_v2 import build_feature_df, FEATURE_COLS
|
||||
|
||||
partants = holdout_data.head(50).to_dict("records")
|
||||
feature_df = build_feature_df(partants)
|
||||
available = [c for c in FEATURE_COLS if c in feature_df.columns]
|
||||
X = feature_df[available].fillna(0)
|
||||
|
||||
proba = ensemble_model.predict_proba(X)[:, 1]
|
||||
assert proba.max() > 0.01, "All predictions are near 0 — model appears broken"
|
||||
assert proba.std() > 0.01, (
|
||||
"All predictions have identical probability — no discrimination"
|
||||
)
|
||||
|
||||
|
||||
# ─── Latency Tests ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPredictionLatency:
|
||||
"""Prediction latency must be < 200ms per race."""
|
||||
|
||||
@pytest.mark.latency
|
||||
def test_single_race_latency(self, ensemble_model, holdout_data):
|
||||
"""Prediction for a single race (<=20 horses) must be < 200ms."""
|
||||
from predict_v2 import build_feature_df, FEATURE_COLS
|
||||
|
||||
# Take one race
|
||||
first_race = (
|
||||
holdout_data.groupby(["date_programme", "num_reunion", "num_course"])
|
||||
.first()
|
||||
.reset_index()
|
||||
.iloc[0]
|
||||
)
|
||||
mask = (
|
||||
(holdout_data["date_programme"] == first_race["date_programme"])
|
||||
& (holdout_data["num_reunion"] == first_race["num_reunion"])
|
||||
& (holdout_data["num_course"] == first_race["num_course"])
|
||||
)
|
||||
race_df = holdout_data[mask]
|
||||
partants = race_df.to_dict("records")
|
||||
|
||||
# Warm-up
|
||||
feature_df = build_feature_df(partants)
|
||||
available = [c for c in FEATURE_COLS if c in feature_df.columns]
|
||||
X = feature_df[available].fillna(0)
|
||||
ensemble_model.predict_proba(X)
|
||||
|
||||
# Timed run
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(10):
|
||||
ensemble_model.predict_proba(X)
|
||||
elapsed_ms = (time.perf_counter() - t0) / 10 * 1000
|
||||
|
||||
print(f"\n Single-race latency: {elapsed_ms:.2f} ms ({len(partants)} horses)")
|
||||
assert elapsed_ms < 200, (
|
||||
f"Prediction latency {elapsed_ms:.1f} ms exceeds 200 ms limit"
|
||||
)
|
||||
|
||||
@pytest.mark.latency
|
||||
def test_full_day_latency(self, ensemble_model, holdout_data):
|
||||
"""Prediction for a full day (all races) must complete < 5 seconds."""
|
||||
from predict_v2 import build_feature_df, FEATURE_COLS
|
||||
|
||||
# Take one day
|
||||
day = holdout_data["date_programme"].iloc[0]
|
||||
day_df = holdout_data[holdout_data["date_programme"] == day]
|
||||
partants = day_df.to_dict("records")
|
||||
|
||||
feature_df = build_feature_df(partants)
|
||||
available = [c for c in FEATURE_COLS if c in feature_df.columns]
|
||||
X = feature_df[available].fillna(0)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
proba = ensemble_model.predict_proba(X)
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000
|
||||
|
||||
print(
|
||||
f"\n Full day latency: {elapsed_ms:.2f} ms ({len(partants)} horses, {day})"
|
||||
)
|
||||
assert elapsed_ms < 5000, (
|
||||
f"Full-day prediction {elapsed_ms:.0f} ms exceeds 5s limit"
|
||||
)
|
||||
|
||||
|
||||
# ─── API Endpoint Tests ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestV1PredictionsAPI:
|
||||
"""Tests for the new /api/v1/predictions endpoint."""
|
||||
|
||||
def _api_available(self):
|
||||
try:
|
||||
requests.get(f"{BASE_URL}/api/v1/model/status", timeout=3)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@pytest.mark.api
|
||||
def test_model_status_endpoint(self):
|
||||
"""GET /api/v1/model/status returns valid JSON."""
|
||||
if not self._api_available():
|
||||
pytest.skip("API server not running")
|
||||
resp = requests.get(f"{BASE_URL}/api/v1/model/status", timeout=10)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "ensemble_available" in data
|
||||
|
||||
@pytest.mark.api
|
||||
def test_v1_predictions_no_500(self):
|
||||
"""GET /api/v1/predictions must not return 5xx."""
|
||||
if not self._api_available():
|
||||
pytest.skip("API server not running")
|
||||
resp = requests.get(f"{BASE_URL}/api/v1/predictions", timeout=30)
|
||||
assert resp.status_code < 500, (
|
||||
f"Server error: {resp.status_code}\n{resp.text[:200]}"
|
||||
)
|
||||
|
||||
@pytest.mark.api
|
||||
def test_v1_predictions_returns_json(self):
|
||||
"""GET /api/v1/predictions returns valid JSON with expected keys."""
|
||||
if not self._api_available():
|
||||
pytest.skip("API server not running")
|
||||
resp = requests.get(f"{BASE_URL}/api/v1/predictions", timeout=30)
|
||||
if resp.status_code == 503:
|
||||
pytest.skip("Ensemble model not yet deployed")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "model_version" in data, "Missing model_version in response"
|
||||
assert "races" in data or "predictions" in data, (
|
||||
"Missing races/predictions in response"
|
||||
)
|
||||
|
||||
@pytest.mark.api
|
||||
def test_v1_predictions_latency(self):
|
||||
"""GET /api/v1/predictions must respond in < 3 seconds."""
|
||||
if not self._api_available():
|
||||
pytest.skip("API server not running")
|
||||
resp = requests.get(f"{BASE_URL}/api/v1/predictions", timeout=30)
|
||||
if resp.status_code == 503:
|
||||
pytest.skip("Ensemble model not yet deployed")
|
||||
# Check API-reported latency
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
latency = data.get("latency_ms", 0)
|
||||
assert latency < 3000, f"API latency {latency:.0f} ms > 3000 ms"
|
||||
@@ -1,205 +0,0 @@
|
||||
"""
|
||||
Tests de smoke — SaaS Turf Prédictions IA
|
||||
Sprint 8 — QA, Beta Fermee, Go/No-Go
|
||||
Ticket: HRT-34
|
||||
|
||||
Vérifications rapides sur l'état de l'application :
|
||||
- Routes de base accessibles
|
||||
- API répond en JSON valide
|
||||
- Base de données accessible
|
||||
- Pas d'erreurs 5xx sur les routes principales
|
||||
|
||||
Ces tests peuvent tourner SANS infra complète (pas besoin de HRT-31/33).
|
||||
Exécuter sur l'app actuelle en staging ou localhost.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import os
|
||||
import json
|
||||
|
||||
BASE_URL = os.environ.get("APP_URL", "http://localhost:8792")
|
||||
|
||||
# Routes qui doivent retourner 200 (publiques)
|
||||
PUBLIC_ROUTES_200 = [
|
||||
"/",
|
||||
"/dashboard",
|
||||
]
|
||||
|
||||
# Routes API qui doivent retourner 200 ou 401 (jamais 500)
|
||||
API_ROUTES_NO_500 = [
|
||||
"/api",
|
||||
"/api/races",
|
||||
"/api/scoring",
|
||||
"/api/weather",
|
||||
"/api/odds_history",
|
||||
]
|
||||
|
||||
|
||||
class TestSmoke:
|
||||
"""Tests de smoke : l'app répond correctement aux requêtes de base."""
|
||||
|
||||
@pytest.mark.smoke
|
||||
@pytest.mark.parametrize("route", PUBLIC_ROUTES_200)
|
||||
def test_route_publique_accessible(self, route):
|
||||
"""Les routes publiques doivent retourner 200."""
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}{route}", timeout=10)
|
||||
assert resp.status_code in (200, 304), (
|
||||
f"Route publique inaccessible: {route} → {resp.status_code}"
|
||||
)
|
||||
assert len(resp.content) > 0, f"Réponse vide sur {route}"
|
||||
except requests.exceptions.ConnectionError:
|
||||
pytest.skip(
|
||||
f"App non accessible sur {BASE_URL} — vérifier que le serveur est démarré"
|
||||
)
|
||||
|
||||
@pytest.mark.smoke
|
||||
@pytest.mark.parametrize("route", API_ROUTES_NO_500)
|
||||
def test_api_pas_derreur_serveur(self, route):
|
||||
"""Les routes API ne doivent jamais retourner 5xx."""
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}{route}", timeout=10)
|
||||
assert resp.status_code < 500, (
|
||||
f"Erreur serveur sur {route}: {resp.status_code}\n{resp.text[:200]}"
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
pytest.skip(f"App non accessible sur {BASE_URL}")
|
||||
|
||||
@pytest.mark.smoke
|
||||
def test_api_today_retourne_json(self):
|
||||
"""L'endpoint principal /api doit retourner du JSON valide."""
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}/api", timeout=10)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
assert data is not None, "Réponse JSON nulle"
|
||||
assert isinstance(data, (list, dict)), (
|
||||
f"Type de réponse inattendu: {type(data)}"
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
pytest.skip(f"App non accessible sur {BASE_URL}")
|
||||
except json.JSONDecodeError as e:
|
||||
pytest.fail(f"/api ne retourne pas du JSON valide: {e}")
|
||||
|
||||
@pytest.mark.smoke
|
||||
def test_contenu_html_portail_valide(self):
|
||||
"""Le portail doit contenir un titre et du contenu significatif."""
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}/", timeout=10)
|
||||
if resp.status_code == 200:
|
||||
content = resp.text
|
||||
assert "<html" in content.lower() or "<!doctype" in content.lower(), (
|
||||
"La page d'accueil ne retourne pas du HTML"
|
||||
)
|
||||
assert len(content) > 500, (
|
||||
f"Page d'accueil trop courte ({len(content)} chars)"
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
pytest.skip(f"App non accessible sur {BASE_URL}")
|
||||
|
||||
@pytest.mark.smoke
|
||||
def test_headers_securite_presents(self):
|
||||
"""Les headers de sécurité de base doivent être présents."""
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}/", timeout=10)
|
||||
if resp.status_code != 200:
|
||||
return
|
||||
|
||||
# En production (derrière Nginx), ces headers doivent être présents
|
||||
# En dev direct Flask, ils peuvent être absents — on note seulement
|
||||
security_headers = {
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": None, # SAMEORIGIN ou DENY
|
||||
"X-XSS-Protection": None,
|
||||
}
|
||||
|
||||
missing = []
|
||||
for header, expected_value in security_headers.items():
|
||||
if header not in resp.headers:
|
||||
missing.append(header)
|
||||
|
||||
if missing:
|
||||
# Warning seulement — bloquant uniquement en prod derrière Nginx
|
||||
pytest.warns(UserWarning, match=r".*") if False else None
|
||||
print(f"⚠️ Headers sécurité manquants (requis en prod): {missing}")
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
pytest.skip(f"App non accessible sur {BASE_URL}")
|
||||
|
||||
@pytest.mark.smoke
|
||||
def test_api_races_format_reponse(self):
|
||||
"""L'endpoint /api/races doit retourner une liste structurée."""
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}/api/races", timeout=10)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
assert isinstance(data, (list, dict)), (
|
||||
f"Format inattendu pour /api/races: {type(data)}"
|
||||
)
|
||||
if isinstance(data, list) and len(data) > 0:
|
||||
first = data[0]
|
||||
# Vérifier la présence de champs clés
|
||||
expected_fields = ["date", "course", "hippodrome"]
|
||||
present = [
|
||||
f
|
||||
for f in expected_fields
|
||||
if f in first
|
||||
or any(k in first for k in [f, f.upper(), f.replace("_", "")])
|
||||
]
|
||||
assert len(present) > 0, (
|
||||
f"Champs attendus absents de /api/races. Champs présents: {list(first.keys())}"
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
pytest.skip(f"App non accessible sur {BASE_URL}")
|
||||
except json.JSONDecodeError:
|
||||
pytest.fail("/api/races ne retourne pas du JSON valide")
|
||||
|
||||
|
||||
class TestSmokeDatabase:
|
||||
"""Tests smoke sur la base de données."""
|
||||
|
||||
@pytest.mark.smoke
|
||||
def test_base_donnees_accessible(self):
|
||||
"""La base de données SQLite doit être accessible et contenir des données."""
|
||||
import sqlite3
|
||||
|
||||
db_path = "/home/h3r7/turf_saas/turf_saas.db"
|
||||
|
||||
if not __import__("os").path.exists(db_path):
|
||||
pytest.skip(f"Base de données non trouvée: {db_path}")
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
# Vérifier que les tables essentielles existent
|
||||
c.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
tables = {row[0] for row in c.fetchall()}
|
||||
conn.close()
|
||||
|
||||
expected_tables = ["predictions", "results"]
|
||||
for table in expected_tables:
|
||||
assert table in tables, (
|
||||
f"Table manquante dans la BDD: {table}. Tables présentes: {tables}"
|
||||
)
|
||||
|
||||
@pytest.mark.smoke
|
||||
def test_donnees_predictions_disponibles(self):
|
||||
"""Des prédictions doivent être présentes dans la BDD."""
|
||||
import sqlite3
|
||||
|
||||
db_path = "/home/h3r7/turf_saas/turf_saas.db"
|
||||
|
||||
if not __import__("os").path.exists(db_path):
|
||||
pytest.skip(f"Base de données non trouvée: {db_path}")
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
c = conn.cursor()
|
||||
c.execute("SELECT COUNT(*) FROM predictions")
|
||||
count = c.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
# Au moins quelques données pour que le SaaS soit utile
|
||||
assert count >= 0, "Table predictions accessible"
|
||||
if count == 0:
|
||||
print("⚠️ Aucune prédiction en base — le scraper doit être lancé")
|
||||
1007
train_ensemble.py
1007
train_ensemble.py
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user