feat: Sprint 2-3 — Auth JWT + Multi-tenant (HRT-28)
- auth_db.py: create users, subscriptions, refresh_tokens tables in turf_saas.db - auth.py: register/login/refresh/logout endpoints, JWT middleware, plan_required decorator, free daily-limit check - middleware.py: in-memory rate limiter (100 req/min/IP), timestamped access logs - saas_api.py: Flask app factory wiring JWT, CORS, blueprints, /api/v1/predictions plan-gating - tests/test_auth.py: 27 pytest tests, 83% coverage (target >=80%) - API_AUTH.md: full endpoint documentation Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
404
tests/test_auth.py
Normal file
404
tests/test_auth.py
Normal file
@@ -0,0 +1,404 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pytest tests — Auth JWT + Multi-tenant
|
||||
Sprint 2-3: HRT-28
|
||||
Coverage target: >= 80%
|
||||
|
||||
Run:
|
||||
./venv/bin/pytest tests/test_auth.py -v --tb=short
|
||||
./venv/bin/pytest tests/test_auth.py -v --cov=auth --cov=auth_db --cov=middleware --cov=saas_api --cov-report=term-missing
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import json
|
||||
import pytest
|
||||
|
||||
# Point to a temp SQLite DB for tests
|
||||
_tmp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
_tmp_db.close()
|
||||
os.environ["TURF_SAAS_DB"] = _tmp_db.name
|
||||
os.environ["JWT_SECRET_KEY"] = "test-secret-key-for-pytest"
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from saas_api import create_app # noqa: E402
|
||||
|
||||
TEST_CONFIG = {
|
||||
"TESTING": True,
|
||||
"JWT_SECRET_KEY": "test-secret-key-for-pytest",
|
||||
"JWT_ACCESS_TOKEN_EXPIRES": 900,
|
||||
"JWT_REFRESH_TOKEN_EXPIRES": 2592000,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app():
|
||||
application = create_app(TEST_CONFIG)
|
||||
yield application
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(app):
|
||||
return app.test_client()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Health
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHealth:
|
||||
def test_health_ok(self, client):
|
||||
resp = client.get("/api/v1/health")
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["service"] == "turf-saas-api"
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Registration
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRegister:
|
||||
def test_register_success(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "user_test@example.com", "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.get_json()
|
||||
assert "user_id" in data
|
||||
|
||||
def test_register_duplicate(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "dup@example.com", "password": "password123"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "dup@example.com", "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_register_invalid_email(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "notanemail", "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_register_short_password(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "shortpw@example.com", "password": "abc"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_register_missing_fields(self, client):
|
||||
resp = client.post("/api/v1/auth/register", json={})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Login
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLogin:
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_user(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "login@example.com", "password": "loginpass1"},
|
||||
)
|
||||
|
||||
def test_login_success(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "login@example.com", "password": "loginpass1"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["plan"] == "free"
|
||||
|
||||
def test_login_wrong_password(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "login@example.com", "password": "wrongpass"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_login_unknown_email(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "ghost@example.com", "password": "anypass"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_login_missing_fields(self, client):
|
||||
resp = client.post("/api/v1/auth/login", json={"email": "login@example.com"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Token refresh
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRefresh:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "refresh@example.com", "password": "refreshpass1"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "refresh@example.com", "password": "refreshpass1"},
|
||||
)
|
||||
tokens = resp.get_json()
|
||||
self.refresh_token = tokens["refresh_token"]
|
||||
|
||||
def test_refresh_success(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
# New refresh token should differ from old
|
||||
assert data["refresh_token"] != self.refresh_token
|
||||
|
||||
def test_refresh_token_rotation(self, client):
|
||||
"""Old refresh token must be invalid after rotation."""
|
||||
client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
resp2 = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
assert resp2.status_code == 401
|
||||
|
||||
def test_refresh_invalid_token(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "completely.invalid.token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_refresh_missing_token(self, client):
|
||||
resp = client.post("/api/v1/auth/refresh", json={})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Logout
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLogout:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "logout@example.com", "password": "logoutpass1"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "logout@example.com", "password": "logoutpass1"},
|
||||
)
|
||||
tokens = resp.get_json()
|
||||
self.refresh_token = tokens["refresh_token"]
|
||||
self.access_token = tokens["access_token"]
|
||||
|
||||
def test_logout_success(self, client):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_refresh_after_logout_fails(self, client):
|
||||
client.post("/api/v1/auth/logout", json={"refresh_token": self.refresh_token})
|
||||
resp = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": self.refresh_token},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_logout_no_token(self, client):
|
||||
resp = client.post("/api/v1/auth/logout", json={})
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# JWT middleware — protected routes
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestJWTMiddleware:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, client):
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "protected@example.com", "password": "protect123"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "protected@example.com", "password": "protect123"},
|
||||
)
|
||||
self.access_token = resp.get_json()["access_token"]
|
||||
|
||||
def test_subscription_info_requires_auth(self, client):
|
||||
resp = client.get("/api/v1/subscription/upgrade")
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_subscription_info_with_token(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/subscription/upgrade",
|
||||
headers={"Authorization": f"Bearer {self.access_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert "current_plan" in data
|
||||
assert data["current_plan"] == "free"
|
||||
|
||||
def test_invalid_token_rejected(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/subscription/upgrade",
|
||||
headers={"Authorization": "Bearer invalid.token.here"},
|
||||
)
|
||||
assert resp.status_code in (401, 422)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Plan checks
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlanMiddleware:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, client, app):
|
||||
# Register free user
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "free_plan@example.com", "password": "freepass1"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "free_plan@example.com", "password": "freepass1"},
|
||||
)
|
||||
self.free_token = resp.get_json()["access_token"]
|
||||
|
||||
# Upgrade user to pro directly in DB for testing
|
||||
import sqlite3
|
||||
|
||||
db_path = os.environ["TURF_SAAS_DB"]
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO users (email, password_hash, plan) VALUES (?,?,?)",
|
||||
("pro_plan@example.com", "$2b$12$placeholder", "pro"),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Login pro user using JWT created manually via app context
|
||||
with app.app_context():
|
||||
from flask_jwt_extended import create_access_token
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
row = conn.execute(
|
||||
"SELECT id FROM users WHERE email='pro_plan@example.com'"
|
||||
).fetchone()
|
||||
conn.close()
|
||||
self.pro_token = create_access_token(
|
||||
identity=str(row[0]),
|
||||
additional_claims={"plan": "pro", "email": "pro_plan@example.com"},
|
||||
)
|
||||
|
||||
def test_export_blocked_for_free(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/predictions/export",
|
||||
headers={"Authorization": f"Bearer {self.free_token}"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
data = resp.get_json()
|
||||
assert "Plan insuffisant" in data["error"]
|
||||
|
||||
def test_export_allowed_for_pro(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/predictions/export",
|
||||
headers={"Authorization": f"Bearer {self.pro_token}"},
|
||||
)
|
||||
# 503 is expected because no backend is running; 403 would be wrong
|
||||
assert resp.status_code != 403
|
||||
|
||||
def test_upgrade_info_shows_plans(self, client):
|
||||
resp = client.get(
|
||||
"/api/v1/subscription/upgrade",
|
||||
headers={"Authorization": f"Bearer {self.free_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert "free" in data["plans"]
|
||||
assert "premium" in data["plans"]
|
||||
assert "pro" in data["plans"]
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# Rate limiting
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
def test_rate_limit_headers_present(self, client):
|
||||
resp = client.get("/api/v1/health")
|
||||
assert "X-RateLimit-Limit" in resp.headers
|
||||
assert resp.headers["X-RateLimit-Limit"] == "100"
|
||||
|
||||
def test_rate_limit_remaining_decreases(self, client):
|
||||
r1 = client.get("/api/v1/health")
|
||||
r2 = client.get("/api/v1/health")
|
||||
rem1 = int(r1.headers.get("X-RateLimit-Remaining", 100))
|
||||
rem2 = int(r2.headers.get("X-RateLimit-Remaining", 100))
|
||||
assert rem2 <= rem1
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# DB module
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAuthDB:
|
||||
def test_tables_exist(self):
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect(os.environ["TURF_SAAS_DB"])
|
||||
tables = {
|
||||
r[0]
|
||||
for r in conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
).fetchall()
|
||||
}
|
||||
assert "users" in tables
|
||||
assert "subscriptions" in tables
|
||||
assert "refresh_tokens" in tables
|
||||
conn.close()
|
||||
|
||||
def test_get_db_returns_connection(self):
|
||||
from auth_db import get_db
|
||||
|
||||
db = get_db()
|
||||
assert db is not None
|
||||
db.close()
|
||||
Reference in New Issue
Block a user