- train_ensemble.py: full training pipeline with 100-trial Optuna studies for XGBoost and LightGBM, MLP (256-128-64), SHAP feature selection, weighted soft-voting ensemble, benchmark report generation - predict_v2.py: production prediction module with model cache invalidation - combined_api.py: add /api/v1/predictions, /api/v1/model/status, /api/v1/model/invalidate-cache endpoints using ensemble model - tests/test_ml_ensemble.py: regression, latency and API tests Baseline XGBoost Precision@3: 0.5287 (holdout 20% temporal) Deploy threshold: +5% = 0.5551 Co-Authored-By: Paperclip <noreply@paperclip.ing>
334 lines
13 KiB
Python
334 lines
13 KiB
Python
"""
|
|
ML ensemble tests — HRT-32 Sprint 6-7

Regression, benchmark and latency tests for the new ensemble model.

Usage:
    pytest tests/test_ml_ensemble.py -v
    pytest tests/test_ml_ensemble.py -v -m regression
    pytest tests/test_ml_ensemble.py -v -m latency
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import pickle
|
|
import sqlite3
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pytest
|
|
import requests
|
|
|
|
# Runtime locations used by the suite. Every location can be overridden
# through the environment so the tests can target another deployment; the
# defaults are the paths the suite has always used.
BASE_URL = os.environ.get("APP_URL", "http://localhost:8790")
DB_PATH = os.environ.get("DB_PATH", "/home/h3r7/turf_saas/turf.db")
# MODELS_DIR follows the same override convention as DB_PATH.
MODELS_DIR = Path(os.environ.get("MODELS_DIR", "/home/h3r7/turf_saas/models"))

# Artifacts the tests read from MODELS_DIR.
ENSEMBLE_PATH = MODELS_DIR / "ensemble_top3.pkl"
BENCHMARK_PATH = MODELS_DIR / "benchmark_report.json"
|
|
|
|
|
|
# ─── Fixtures ────────────────────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.fixture(scope="session")
def ensemble_model():
    """Return the unpickled ensemble model stored at ENSEMBLE_PATH.

    Tests depending on this fixture are skipped when the file is absent.
    """
    if ENSEMBLE_PATH.exists():
        # NOTE: unpickling executes code embedded in the file; this is only
        # acceptable because the artifact is a locally produced model file.
        with ENSEMBLE_PATH.open("rb") as model_file:
            return pickle.load(model_file)

    # pytest.skip raises, so nothing is returned on this path.
    pytest.skip(
        f"Ensemble model not found at {ENSEMBLE_PATH}. Run train_ensemble.py first."
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def benchmark_report():
    """Return the parsed JSON benchmark report stored at BENCHMARK_PATH.

    Tests depending on this fixture are skipped when the file is absent.
    """
    if BENCHMARK_PATH.exists():
        return json.loads(BENCHMARK_PATH.read_text())

    # pytest.skip raises, so nothing is returned on this path.
    pytest.skip(f"Benchmark report not found at {BENCHMARK_PATH}.")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def holdout_data():
    """Load the holdout slice (last 20% of rows, in temporal order).

    Reads every runner row with a recorded finishing position
    (``ordre_arrivee > 0``) from ``pmu_partants``, left-joined with its race
    in ``pmu_courses``, ordered by date, meeting, race and runner number.

    Returns:
        A copy of the final 20% of the ordered rows as a DataFrame.

    Note:
        The split is made on the row count, so the boundary is not aligned
        to a race: the first race of the slice may be incomplete.
    """
    query = """
        SELECT p.*, c.distance, c.discipline, c.specialite,
               c.nb_declares_partants, c.montant_prix, c.penetrometre_intitule
        FROM pmu_partants p
        LEFT JOIN pmu_courses c ON p.date_programme=c.date_programme
            AND p.num_reunion=c.num_reunion AND p.num_course=c.num_course
        WHERE p.ordre_arrivee > 0
        ORDER BY p.date_programme, p.num_reunion, p.num_course, p.num_pmu
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close the connection even when the query fails (e.g. missing table).
        conn.close()

    cutoff = int(len(df) * 0.80)
    return df.iloc[cutoff:].copy()
|
|
|
|
|
|
@pytest.fixture(scope="session")
def predict_v2():
    """Execute predict_v2.py from its absolute path and return the module.

    The module object is only handed back to the caller; this fixture does
    not add it to ``sys.modules``.
    """
    from importlib import util as import_util

    module_spec = import_util.spec_from_file_location(
        "predict_v2", "/home/h3r7/turf_saas/predict_v2.py"
    )
    module = import_util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return module
|
|
|
|
|
|
# ─── Model Existence Tests ────────────────────────────────────────────────────
|
|
|
|
|
|
class TestModelFiles:
    """Check that the expected model artifacts are present on disk."""

    def test_ensemble_model_exists(self):
        """The pickled ensemble must exist at ENSEMBLE_PATH."""
        is_present = ENSEMBLE_PATH.exists()
        assert is_present, f"Ensemble model missing: {ENSEMBLE_PATH}"

    def test_benchmark_report_exists(self):
        """The JSON benchmark report must exist at BENCHMARK_PATH."""
        is_present = BENCHMARK_PATH.exists()
        assert is_present, f"Benchmark report missing: {BENCHMARK_PATH}"

    def test_models_dir_contains_expected_files(self):
        """MODELS_DIR must hold the model plus both report formats."""
        artifact_names = (
            "ensemble_top3.pkl",
            "benchmark_report.json",
            "benchmark_report.md",
        )
        # Stops at the first missing artifact, checked in the order above.
        for artifact in artifact_names:
            target = MODELS_DIR / artifact
            assert target.exists(), f"Missing: {target}"
|
|
|
|
|
|
# ─── Benchmark Tests ──────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestBenchmark:
    """Sanity checks on the metrics stored in the benchmark report."""

    @pytest.mark.regression
    def test_ensemble_beats_baseline_or_meets_threshold(self, benchmark_report):
        """Ensemble Precision@3 must be at least the baseline's Precision@3."""
        baseline_p3 = benchmark_report["baseline"]["precision_at3"]
        ensemble_p3 = benchmark_report["ensemble"]["precision_at3"]
        not_worse = ensemble_p3 >= baseline_p3
        assert not_worse, (
            f"Ensemble Precision@3 {ensemble_p3:.4f} < baseline {baseline_p3:.4f}"
        )

    @pytest.mark.regression
    def test_ensemble_auc_above_random(self, benchmark_report):
        """Ensemble AUC must be strictly above 0.60 (random is 0.50)."""
        ensemble_auc = benchmark_report["ensemble"]["auc"]
        assert ensemble_auc > 0.60, f"Ensemble AUC {ensemble_auc:.4f} <= 0.60"

    @pytest.mark.regression
    def test_optuna_ran_minimum_trials(self, benchmark_report):
        """The report must record at least 100 Optuna trials."""
        trial_count = benchmark_report["optuna"]["n_trials"]
        assert trial_count >= 100, (
            f"Only {trial_count} Optuna trials (minimum 100 required)"
        )

    @pytest.mark.regression
    def test_no_precision_regression(self, benchmark_report):
        """Ensemble Precision@3 must reach the naive random level (~30%)."""
        precision = benchmark_report["ensemble"]["precision_at3"]
        assert precision >= 0.30, (
            f"Precision@3 {precision:.4f} is below random baseline (~0.30)"
        )

    def test_benchmark_has_all_required_models(self, benchmark_report):
        """The report must list results for xgboost, lightgbm and mlp."""
        reported = benchmark_report.get("individual_models", {}).keys()
        missing = {"xgboost", "lightgbm", "mlp"} - set(reported)
        assert not missing, f"Missing model benchmarks: {missing}"
|
|
|
|
|
|
# ─── Regression Tests ─────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestPrecisionRegression:
    """Holdout regression checks: precision must not degrade."""

    @staticmethod
    def _positive_class_proba(model, partants):
        """Return column 1 of ``model.predict_proba`` for the runner records.

        Features come from ``predict_v2.build_feature_df``; only the
        FEATURE_COLS actually present are kept, and missing values are
        filled with 0 before scoring.
        """
        from predict_v2 import build_feature_df, FEATURE_COLS

        feature_df = build_feature_df(partants)
        usable_cols = [col for col in FEATURE_COLS if col in feature_df.columns]
        matrix = feature_df[usable_cols].fillna(0)
        return model.predict_proba(matrix)[:, 1]

    @pytest.mark.regression
    def test_precision_at3_on_holdout(self, ensemble_model, holdout_data):
        """Mean per-race Precision@3 on the holdout must be at least 0.30."""
        race_keys = ["date_programme", "num_reunion", "num_course"]

        df = holdout_data.copy()
        df["top3"] = (df["ordre_arrivee"] <= 3).astype(int)

        # Scores are attached positionally, so they line up with df's rows.
        scored = df[race_keys].copy()
        scored["proba"] = self._positive_class_proba(
            ensemble_model, df.to_dict("records")
        )
        scored["actual"] = df["top3"].values

        # Per-race Precision@3: share of the 3 highest-scored runners that
        # actually finished in the top 3. Races with fewer than 3 rows are
        # ignored.
        precisions = [
            race.nlargest(3, "proba")["actual"].sum() / 3.0
            for _, race in scored.groupby(race_keys)
            if len(race) >= 3
        ]

        p_at3 = float(np.mean(precisions)) if precisions else 0.0
        print(f"\n Holdout Precision@3: {p_at3:.4f} over {len(precisions)} races")

        # Must beat random baseline (30%)
        assert p_at3 >= 0.30, f"Holdout Precision@3 {p_at3:.4f} < 0.30"

    @pytest.mark.regression
    def test_no_all_zero_predictions(self, ensemble_model, holdout_data):
        """Scores for the first 50 holdout rows must be non-trivial and varied."""
        sample = holdout_data.head(50).to_dict("records")
        proba = self._positive_class_proba(ensemble_model, sample)

        assert proba.max() > 0.01, "All predictions are near 0 — model appears broken"
        assert proba.std() > 0.01, (
            "All predictions have identical probability — no discrimination"
        )
|
|
|
|
|
|
# ─── Latency Tests ────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestPredictionLatency:
    """Latency budgets for ``predict_proba`` on holdout data.

    Both tests time only the model's ``predict_proba`` call; feature
    building happens before the clock starts.
    """

    @pytest.mark.latency
    def test_single_race_latency(self, ensemble_model, holdout_data):
        """Mean ``predict_proba`` time for one race must be < 200 ms.

        The race used is the first group when the holdout is grouped by
        (date_programme, num_reunion, num_course). One untimed warm-up call
        precedes ten timed calls, whose mean is asserted.
        """
        from predict_v2 import build_feature_df, FEATURE_COLS

        race_keys = ["date_programme", "num_reunion", "num_course"]
        # groupby iterates groups in sorted key order, so the first item is
        # the race with the smallest key; its frame holds that race's rows.
        first_group = next(iter(holdout_data.groupby(race_keys)), None)
        assert first_group is not None, "Holdout slice contains no races"
        _, race_df = first_group
        partants = race_df.to_dict("records")

        feature_df = build_feature_df(partants)
        available = [c for c in FEATURE_COLS if c in feature_df.columns]
        X = feature_df[available].fillna(0)

        # Warm-up call, excluded from the measurement.
        ensemble_model.predict_proba(X)

        n_runs = 10
        t0 = time.perf_counter()
        for _ in range(n_runs):
            ensemble_model.predict_proba(X)
        elapsed_ms = (time.perf_counter() - t0) / n_runs * 1000

        print(f"\n Single-race latency: {elapsed_ms:.2f} ms ({len(partants)} horses)")
        assert elapsed_ms < 200, (
            f"Prediction latency {elapsed_ms:.1f} ms exceeds 200 ms limit"
        )

    @pytest.mark.latency
    def test_full_day_latency(self, ensemble_model, holdout_data):
        """A single ``predict_proba`` call over one day's rows must take < 5 s.

        The day used is the ``date_programme`` of the first holdout row.
        """
        from predict_v2 import build_feature_df, FEATURE_COLS

        day = holdout_data["date_programme"].iloc[0]
        day_df = holdout_data[holdout_data["date_programme"] == day]
        partants = day_df.to_dict("records")

        feature_df = build_feature_df(partants)
        available = [c for c in FEATURE_COLS if c in feature_df.columns]
        X = feature_df[available].fillna(0)

        # Only the elapsed time matters here; the predictions are discarded.
        t0 = time.perf_counter()
        ensemble_model.predict_proba(X)
        elapsed_ms = (time.perf_counter() - t0) * 1000

        print(
            f"\n Full day latency: {elapsed_ms:.2f} ms ({len(partants)} horses, {day})"
        )
        assert elapsed_ms < 5000, (
            f"Full-day prediction {elapsed_ms:.0f} ms exceeds 5s limit"
        )
|
|
|
|
|
|
# ─── API Endpoint Tests ───────────────────────────────────────────────────────
|
|
|
|
|
|
class TestV1PredictionsAPI:
    """Tests for the new /api/v1/predictions endpoint.

    Every test is skipped when the API server cannot be reached.
    """

    def _api_available(self):
        """Return True when the status endpoint answers within 3 seconds.

        Any HTTP status counts as "available"; only transport-level failures
        (connection refused, timeout, invalid URL, ...) return False.
        """
        try:
            requests.get(f"{BASE_URL}/api/v1/model/status", timeout=3)
        except requests.RequestException:
            # Only network/transport errors mean "server not running";
            # unrelated exceptions should surface instead of causing a skip.
            return False
        return True

    @pytest.mark.api
    def test_model_status_endpoint(self):
        """GET /api/v1/model/status returns 200 and JSON with ``ensemble_available``."""
        if not self._api_available():
            pytest.skip("API server not running")
        resp = requests.get(f"{BASE_URL}/api/v1/model/status", timeout=10)
        assert resp.status_code == 200
        data = resp.json()
        assert "ensemble_available" in data

    @pytest.mark.api
    def test_v1_predictions_no_500(self):
        """GET /api/v1/predictions must not return 5xx."""
        if not self._api_available():
            pytest.skip("API server not running")
        resp = requests.get(f"{BASE_URL}/api/v1/predictions", timeout=30)
        assert resp.status_code < 500, (
            f"Server error: {resp.status_code}\n{resp.text[:200]}"
        )

    @pytest.mark.api
    def test_v1_predictions_returns_json(self):
        """GET /api/v1/predictions returns valid JSON with expected keys.

        Skipped on 503, which this suite treats as "ensemble not deployed".
        """
        if not self._api_available():
            pytest.skip("API server not running")
        resp = requests.get(f"{BASE_URL}/api/v1/predictions", timeout=30)
        if resp.status_code == 503:
            pytest.skip("Ensemble model not yet deployed")
        assert resp.status_code == 200
        data = resp.json()
        assert "model_version" in data, "Missing model_version in response"
        assert "races" in data or "predictions" in data, (
            "Missing races/predictions in response"
        )

    @pytest.mark.api
    def test_v1_predictions_latency(self):
        """GET /api/v1/predictions must respond in < 3 seconds.

        Uses the API-reported ``latency_ms`` when the response carries it and
        falls back to the client-side ``Response.elapsed`` otherwise.
        Skipped on 503, which this suite treats as "ensemble not deployed".
        """
        if not self._api_available():
            pytest.skip("API server not running")
        resp = requests.get(f"{BASE_URL}/api/v1/predictions", timeout=30)
        if resp.status_code == 503:
            pytest.skip("Ensemble model not yet deployed")
        # Without this check any other status (404, 500, ...) would let the
        # test pass without asserting anything.
        assert resp.status_code == 200, (
            f"Unexpected status {resp.status_code}\n{resp.text[:200]}"
        )
        data = resp.json()
        # Fall back to the measured round trip so a missing latency_ms field
        # cannot make the check vacuous.
        measured_ms = resp.elapsed.total_seconds() * 1000
        latency = data.get("latency_ms", measured_ms)
        assert latency < 3000, f"API latency {latency:.0f} ms > 3000 ms"
|