#!/usr/bin/env python3
"""
Billing Blueprint — Stripe integration (Sprint 5-6: HRT-31).

Endpoints
---------
POST /api/v1/billing/checkout — create a Stripe Checkout session (auth required)
POST /api/v1/billing/portal   — create a Stripe Customer Portal session (auth required)
POST /api/v1/billing/webhook  — Stripe webhook handler (public, signature-verified)
GET  /api/v1/billing/status   — current subscription status (auth required)

Environment variables
---------------------
STRIPE_SECRET_KEY       Stripe secret key (sk_live_... or sk_test_...)
STRIPE_PUBLISHABLE_KEY  Stripe publishable key (pk_...)
STRIPE_WEBHOOK_SECRET   webhook signing secret (whsec_...)
STRIPE_PRICE_PREMIUM    Stripe Price ID for the Premium plan (price_...)
STRIPE_PRICE_PRO        Stripe Price ID for the Pro plan (price_...)
APP_BASE_URL            public base URL, e.g. https://turf-ia.h3r7.tech
                        (default http://localhost:8793)
"""

import json
import logging
import os
from datetime import datetime, timedelta, timezone

import stripe
from flask import Blueprint, jsonify, request

from saas_auth import require_auth as jwt_required_middleware
from billing_db import get_db, migrate_billing_tables

logger = logging.getLogger("turf_saas.billing")

billing_bp = Blueprint("billing", __name__, url_prefix="/api/v1/billing")

# ──────────────────────────────────────────────────────────────
# Stripe configuration
# ──────────────────────────────────────────────────────────────
# All settings are read once, at import time. Unset variables become ""
# (APP_BASE_URL falls back to the local development server instead).
STRIPE_PUBLISHABLE_KEY = os.environ.get("STRIPE_PUBLISHABLE_KEY", "")
STRIPE_WEBHOOK_SECRET = os.environ.get("STRIPE_WEBHOOK_SECRET", "")
stripe.api_key = os.environ.get("STRIPE_SECRET_KEY", "")
APP_BASE_URL = os.environ.get("APP_BASE_URL", "http://localhost:8793")

# Internal plan name -> Stripe Price ID ("" when the env var is unset).
PLAN_PRICE_IDS = dict(
    premium=os.environ.get("STRIPE_PRICE_PREMIUM", ""),
    pro=os.environ.get("STRIPE_PRICE_PRO", ""),
)

# Internal plan name -> human-readable label.
PLAN_NAMES = dict(free="Free", premium="Premium", pro="Pro")

# ──────────────────────────────────────────────────────────────
# DB helpers
# ──────────────────────────────────────────────────────────────
def _sget(obj, key, default=None):
    """Read *key* from a plain dict or a Stripe ``StripeObject``.

    Subscript access is used because it works on both types, whereas
    ``.get()`` is assumed unavailable on Stripe objects (stripe-python v7+
    favours attribute-style access).

    Returns:
        ``obj[key]``, or *default* when the key is missing, its value is
        ``None``, or *obj* is not subscriptable (e.g. ``None``).
    """
    try:
        val = obj[key]
    except (KeyError, TypeError):
        return default
    return default if val is None else val


def _get_active_subscription(db, user_id):
    """Return the most recent ``saas_subscriptions`` row for *user_id*, or None.

    Despite the name, rows are not filtered by status: the newest row by
    ``start_date`` is returned whatever its status.
    """
    return db.execute(
        """SELECT * FROM saas_subscriptions
           WHERE user_id = ?
           ORDER BY start_date DESC
           LIMIT 1""",
        (str(user_id),),
    ).fetchone()


def _upsert_subscription(db, user_id, **fields):
    """Update the user's latest subscription row, or insert one if none exists.

    Args:
        db: open DB connection using ``?`` placeholders.
        user_id: owner of the row; stored as ``str``.
        **fields: column -> value pairs (this module uses ``plan``,
            ``stripe_customer_id``, ``stripe_subscription_id``, ``status``,
            ``current_period_end``, ``grace_period_end``, ``end_date``).
            Column names are interpolated into the SQL text, so they must
            come from code, never from request data.

    Does not commit; the caller owns the transaction.
    """
    existing = _get_active_subscription(db, user_id)
    if existing:
        if not fields:
            # Nothing to change — "UPDATE ... SET  WHERE" would be invalid SQL.
            return
        set_clause = ", ".join(f"{col} = ?" for col in fields)
        db.execute(
            f"UPDATE saas_subscriptions SET {set_clause} WHERE id = ?",
            [*fields.values(), existing["id"]],
        )
        return

    columns = ["user_id", *fields]
    placeholders = ", ".join(["?"] * len(columns))
    db.execute(
        f"INSERT INTO saas_subscriptions ({', '.join(columns)}) VALUES ({placeholders})",
        [str(user_id), *fields.values()],
    )


def _update_user_plan(db, user_id, plan: str):
    """Set ``saas_users.plan`` for *user_id* (no commit)."""
    db.execute("UPDATE saas_users SET plan = ? WHERE id = ?", (plan, str(user_id)))


def _get_or_create_stripe_customer(user, db) -> str:
    """Return the user's stored Stripe customer ID, creating a Customer if needed.

    A newly created customer ID is NOT persisted here; the caller stores it.
    """
    sub = _get_active_subscription(db, user["id"])
    if sub and sub["stripe_customer_id"]:
        return sub["stripe_customer_id"]
    customer = stripe.Customer.create(
        email=user["email"],
        metadata={"user_id": str(user["id"])},
    )
    return customer["id"]


def _record_billing_event(
    db, stripe_event_id: str, event_type: str, user_id=None, payload=None
):
    """Insert a ``billing_events`` audit row (INSERT OR IGNORE on the event id).

    Best-effort: failures are logged as warnings and never raised.
    """
    try:
        db.execute(
            """INSERT OR IGNORE INTO billing_events
               (stripe_event_id, event_type, user_id, payload)
               VALUES (?, ?, ?, ?)""",
            (
                stripe_event_id,
                event_type,
                user_id,
                json.dumps(payload) if payload else None,
            ),
        )
    except Exception as e:
        logger.warning("Could not record billing event %s: %s", stripe_event_id, e)


def _event_already_processed(db, stripe_event_id: str) -> bool:
    """Return True if *stripe_event_id* already has a ``billing_events`` row.

    Handlers record their event inside the same transaction as their state
    changes, so a row means the event was handled and committed. Stripe may
    deliver an event more than once; this lets duplicates be skipped.
    Lookup failures are logged and treated as "not processed".
    """
    try:
        row = db.execute(
            "SELECT 1 FROM billing_events WHERE stripe_event_id = ? LIMIT 1",
            (stripe_event_id,),
        ).fetchone()
    except Exception as e:
        logger.warning("Could not check billing event %s: %s", stripe_event_id, e)
        return False
    return row is not None


# ──────────────────────────────────────────────────────────────
# POST /api/v1/billing/checkout
# ──────────────────────────────────────────────────────────────
@billing_bp.route("/checkout", methods=["POST"])
@jwt_required_middleware
def create_checkout():
    """
    Create a Stripe Checkout session for upgrading to Premium or Pro.
    ---
    tags:
      - Billing
    security:
      - Bearer: []
    requestBody:
      required: true
      content:
        application/json:
          schema:
            type: object
            required: [plan]
            properties:
              plan:
                type: string
                enum: [premium, pro]
    responses:
      200:
        description: Checkout session URL
        schema:
          type: object
          properties:
            checkout_url:
              type: string
            session_id:
              type: string
            plan:
              type: string
            publishable_key:
              type: string
      400:
        description: Invalid plan, or user already on the requested plan
      503:
        description: Stripe not configured, or Stripe API error
    """
    if not stripe.api_key:
        return jsonify({"error": "Stripe non configuré"}), 503

    # A non-object JSON body or a non-string "plan" must yield a 400, not an
    # AttributeError (HTTP 500).
    body = request.get_json(silent=True)
    raw_plan = body.get("plan") if isinstance(body, dict) else None
    plan = raw_plan.strip().lower() if isinstance(raw_plan, str) else ""
    if plan not in ("premium", "pro"):
        return jsonify({"error": "Plan invalide. Choisir 'premium' ou 'pro'"}), 400

    price_id = PLAN_PRICE_IDS.get(plan)
    if not price_id:
        return jsonify({"error": f"Prix Stripe non configuré pour le plan {plan}"}), 503

    user = request.current_user
    if user["plan"] == plan:
        return jsonify({"error": f"Vous êtes déjà sur le plan {plan}"}), 400

    db = get_db()
    try:
        customer_id = _get_or_create_stripe_customer(user, db)
        # Persist the customer ID before calling Checkout so a retry reuses
        # the same Stripe Customer instead of creating a duplicate.
        _upsert_subscription(
            db, user["id"], stripe_customer_id=customer_id, plan=user["plan"]
        )
        db.commit()

        user_meta = {"user_id": str(user["id"]), "plan": plan}
        checkout_session = stripe.checkout.Session.create(
            customer=customer_id,
            payment_method_types=["card"],
            line_items=[{"price": price_id, "quantity": 1}],
            mode="subscription",
            success_url=f"{APP_BASE_URL}/billing/success?session_id={{CHECKOUT_SESSION_ID}}",
            cancel_url=f"{APP_BASE_URL}/billing/cancel",
            metadata=dict(user_meta),
            subscription_data={"metadata": dict(user_meta)},
        )
    except stripe.StripeError as e:
        logger.error("Stripe checkout error for user %s: %s", user["id"], e)
        return jsonify({"error": "Erreur Stripe", "detail": str(e)}), 503
    finally:
        db.close()

    return jsonify(
        {
            "checkout_url": checkout_session.url,
            "session_id": checkout_session.id,
            "plan": plan,
            "publishable_key": STRIPE_PUBLISHABLE_KEY,
        }
    ), 200


# ──────────────────────────────────────────────────────────────
# POST /api/v1/billing/portal
# ──────────────────────────────────────────────────────────────
@billing_bp.route("/portal", methods=["POST"])
@jwt_required_middleware
def create_portal():
    """
    Create a Stripe Customer Portal session for managing subscription.
    ---
    tags:
      - Billing
    security:
      - Bearer: []
    responses:
      200:
        description: Portal session URL
      400:
        description: No Stripe customer found
      503:
        description: Stripe not configured or API error
    """
    if not stripe.api_key:
        return jsonify({"error": "Stripe non configuré"}), 503

    user = request.current_user
    db = get_db()
    try:
        sub = _get_active_subscription(db, user["id"])
        customer_id = sub["stripe_customer_id"] if sub else None
        if not customer_id:
            return jsonify(
                {
                    "error": "Aucun abonnement Stripe trouvé. "
                    "Souscrivez d'abord à un plan payant."
                }
            ), 400
        portal_session = stripe.billing_portal.Session.create(
            customer=customer_id,
            return_url=f"{APP_BASE_URL}/account",
        )
    except stripe.StripeError as e:
        logger.error("Stripe portal error for user %s: %s", user["id"], e)
        return jsonify({"error": "Erreur Stripe", "detail": str(e)}), 503
    finally:
        db.close()

    return jsonify({"portal_url": portal_session.url}), 200


# ──────────────────────────────────────────────────────────────
# GET /api/v1/billing/status
# ──────────────────────────────────────────────────────────────
@billing_bp.route("/status", methods=["GET"])
@jwt_required_middleware
def billing_status():
    """
    Return current subscription status for the authenticated user.
    ---
    tags:
      - Billing
    security:
      - Bearer: []
    responses:
      200:
        description: Subscription status
    """
    user = request.current_user
    db = get_db()
    try:
        sub = _get_active_subscription(db, user["id"])
    finally:
        db.close()

    if not sub:
        # Same key set as the subscribed case, so clients can rely on it.
        return jsonify(
            {
                "plan": "free",
                "status": "active",
                "stripe_customer_id": None,
                "stripe_subscription_id": None,
                "start_date": None,
                "end_date": None,
                "current_period_end": None,
                "grace_period_end": None,
            }
        ), 200

    return jsonify(
        {
            "plan": sub["plan"],
            "status": sub["status"] or "active",
            "stripe_customer_id": sub["stripe_customer_id"],
            "stripe_subscription_id": sub["stripe_subscription_id"],
            "start_date": sub["start_date"],
            "end_date": sub["end_date"],
            "current_period_end": sub["current_period_end"],
            "grace_period_end": sub["grace_period_end"],
        }
    ), 200


# ──────────────────────────────────────────────────────────────
# POST /api/v1/billing/webhook
# ──────────────────────────────────────────────────────────────
# Secret-key prefixes of live-mode Stripe keys (standard and restricted).
_LIVE_KEY_PREFIXES = ("sk_live_", "rk_live_")


def _parse_webhook_event(payload, sig_header):
    """Verify and decode a webhook body.

    With STRIPE_WEBHOOK_SECRET set, the Stripe-Signature header is verified.
    Without it, unsigned events are accepted only when the configured Stripe
    key is not a live key. Otherwise anyone could POST a forged event (e.g.
    a fake checkout.session.completed) and upgrade an account for free.

    Returns:
        ``(event, None)`` on success, or ``(None, (response, status))``.
    """
    if STRIPE_WEBHOOK_SECRET:
        try:
            event = stripe.Webhook.construct_event(
                payload, sig_header, STRIPE_WEBHOOK_SECRET
            )
        except stripe.SignatureVerificationError as e:
            logger.warning("Stripe webhook signature invalid: %s", e)
            return None, (jsonify({"error": "Signature invalide"}), 400)
        except ValueError as e:
            logger.warning("Stripe webhook payload invalid: %s", e)
            return None, (jsonify({"error": "Payload invalide"}), 400)
    else:
        if (stripe.api_key or "").startswith(_LIVE_KEY_PREFIXES):
            # 503 makes Stripe retry, so events are processed once the
            # signing secret is configured.
            logger.error(
                "STRIPE_WEBHOOK_SECRET not set with a live Stripe key — "
                "rejecting unsigned webhook"
            )
            return None, (jsonify({"error": "Webhook Stripe non configuré"}), 503)
        logger.warning("STRIPE_WEBHOOK_SECRET not set — skipping signature check!")
        try:
            event = stripe.Event.construct_from(json.loads(payload), stripe.api_key)
        except Exception as e:
            return None, (jsonify({"error": "Payload invalide", "detail": str(e)}), 400)

    if not _sget(event, "type") or not _sget(event, "id"):
        return None, (jsonify({"error": "Payload invalide"}), 400)
    return event, None


@billing_bp.route("/webhook", methods=["POST"])
def stripe_webhook():
    """
    Stripe webhook handler — no auth, signature-verified.

    Handled events:
      checkout.session.completed    → activate subscription
      customer.subscription.created → sync plan/status
      customer.subscription.updated → sync plan/status
      customer.subscription.deleted → downgrade to free
      invoice.payment_failed        → set past_due + 3-day grace period
      invoice.payment_succeeded     → clear grace period

    Events already recorded in billing_events are acknowledged and skipped.
    """
    event, error_response = _parse_webhook_event(
        request.get_data(), request.headers.get("Stripe-Signature", "")
    )
    if error_response is not None:
        return error_response

    event_type = _sget(event, "type")
    event_id = _sget(event, "id")
    logger.info("Stripe webhook received: %s (%s)", event_type, event_id)

    db = get_db()
    try:
        if _event_already_processed(db, event_id):
            logger.info("Stripe webhook %s already processed — skipping", event_id)
            return jsonify({"status": "ok"}), 200

        handler = _WEBHOOK_HANDLERS.get(event_type)
        if handler is None:
            logger.debug("Unhandled Stripe event type: %s", event_type)
        else:
            handler(db, event)
        db.commit()
    except Exception as e:
        db.rollback()
        # Non-2xx makes Stripe retry; keep the traceback for diagnosis.
        logger.exception("Error processing Stripe webhook %s: %s", event_id, e)
        return jsonify({"error": "Erreur interne"}), 500
    finally:
        db.close()

    return jsonify({"status": "ok"}), 200


# ──────────────────────────────────────────────────────────────
# Webhook handlers
# ──────────────────────────────────────────────────────────────
# Subscription statuses for which the user must not keep paid access.
_NO_ACCESS_STATUSES = frozenset({"canceled", "incomplete_expired", "unpaid"})


def _resolve_user_from_customer(db, customer_id):
    """Return the user ID (as ``str``) owning *customer_id*, or None.

    Looks in saas_subscriptions first, then falls back to the ``user_id``
    metadata set on the Stripe Customer when it was created.
    """
    if not customer_id:
        return None
    row = db.execute(
        "SELECT user_id FROM saas_subscriptions WHERE stripe_customer_id = ? LIMIT 1",
        (customer_id,),
    ).fetchone()
    if row:
        return row["user_id"]

    try:
        customer = stripe.Customer.retrieve(customer_id)
    except Exception as e:
        logger.warning("Could not retrieve Stripe customer %s: %s", customer_id, e)
        return None
    uid = _sget(_sget(customer, "metadata") or {}, "user_id")
    # str, like every other user ID this module reads or stores.
    return str(uid) if uid else None


def _resolve_user_for_subscription(db, customer_id, sub_obj):
    """Resolve the user via the customer first, then the object's metadata."""
    user_id = _resolve_user_from_customer(db, customer_id)
    if user_id:
        return user_id
    meta_uid = _sget(_sget(sub_obj, "metadata") or {}, "user_id")
    return str(meta_uid) if meta_uid else None


def _resolve_plan_from_price(price_id: str, default: str = "premium") -> str:
    """Map a Stripe price ID to an internal plan name.

    Unknown (or empty) price IDs map to *default* ("premium" unless given,
    which is safer than granting "pro").
    """
    for plan, pid in PLAN_PRICE_IDS.items():
        if pid and pid == price_id:
            return plan
    return default


def _period_end_iso(timestamp):
    """Convert a Unix timestamp in seconds to an ISO-8601 UTC string (None if falsy)."""
    if not timestamp:
        return None
    return datetime.fromtimestamp(timestamp, tz=timezone.utc).isoformat()


def _first_subscription_item(sub_obj):
    """Return the first entry of a subscription's ``items.data``, or None."""
    items_data = _sget(_sget(sub_obj, "items") or {}, "data") or []
    return items_data[0] if items_data else None


def _subscription_price_id(item):
    """Return ``item.price.id`` for a subscription item, or None."""
    return _sget(_sget(item, "price") or {}, "id")


def _subscription_period_end(sub_obj):
    """Return the current period end of a subscription as ISO string, or None.

    Read from the subscription itself, falling back to its first item
    (some newer Stripe API versions only expose it per item).
    """
    ts = _sget(sub_obj, "current_period_end") or _sget(
        _first_subscription_item(sub_obj), "current_period_end"
    )
    return _period_end_iso(ts)


def _handle_checkout_completed(db, event):
    """checkout.session.completed → activate the subscription for the user."""
    session = event["data"]["object"]
    customer_id = _sget(session, "customer")
    subscription_id = _sget(session, "subscription")
    metadata = _sget(session, "metadata") or {}
    plan = _sget(metadata, "plan") or "premium"

    meta_uid = _sget(metadata, "user_id")
    user_id = str(meta_uid) if meta_uid else _resolve_user_from_customer(db, customer_id)
    if not user_id:
        logger.error(
            "checkout.session.completed: cannot resolve user for customer %s",
            customer_id,
        )
        return

    current_period_end = None
    if subscription_id:
        try:
            sub = stripe.Subscription.retrieve(subscription_id)
        except Exception as e:
            logger.warning("Could not fetch subscription %s: %s", subscription_id, e)
        else:
            current_period_end = _subscription_period_end(sub)
            # Trust the purchased price when it is one of ours; otherwise keep
            # the plan from the session metadata.
            plan = _resolve_plan_from_price(
                _subscription_price_id(_first_subscription_item(sub)), default=plan
            )

    _upsert_subscription(
        db,
        user_id,
        plan=plan,
        stripe_customer_id=customer_id,
        stripe_subscription_id=subscription_id,
        status="active",
        current_period_end=current_period_end,
        grace_period_end=None,
        end_date=None,  # clear any end date left by a previous cancellation
    )
    _update_user_plan(db, user_id, plan)
    _record_billing_event(db, event["id"], event["type"], user_id=user_id)
    logger.info("checkout.session.completed: user %s upgraded to %s", user_id, plan)


def _handle_subscription_updated(db, event):
    """customer.subscription.created/updated → sync status and plan."""
    sub_obj = event["data"]["object"]
    customer_id = _sget(sub_obj, "customer")
    subscription_id = _sget(sub_obj, "id")
    stripe_status = _sget(sub_obj, "status") or "active"
    plan = _resolve_plan_from_price(
        _subscription_price_id(_first_subscription_item(sub_obj))
    )

    user_id = _resolve_user_for_subscription(db, customer_id, sub_obj)
    if not user_id:
        logger.error(
            "subscription.updated: cannot resolve user for customer %s", customer_id
        )
        return

    fields = {
        "plan": plan,
        "stripe_customer_id": customer_id,
        "stripe_subscription_id": subscription_id,
        "status": stripe_status,
    }
    current_period_end = _subscription_period_end(sub_obj)
    if current_period_end:
        # Only overwrite when known, so a missing value doesn't erase it.
        fields["current_period_end"] = current_period_end
    _upsert_subscription(db, user_id, **fields)

    # The row keeps the subscribed plan; effective access is revoked while the
    # subscription is in a non-paying terminal state.
    effective_plan = "free" if stripe_status in _NO_ACCESS_STATUSES else plan
    _update_user_plan(db, user_id, effective_plan)
    _record_billing_event(db, event["id"], event["type"], user_id=user_id)
    logger.info(
        "subscription.updated: user %s plan=%s status=%s", user_id, plan, stripe_status
    )


def _handle_subscription_deleted(db, event):
    """customer.subscription.deleted → downgrade to free.

    Ignored when the deleted subscription is not the one currently stored for
    the user (e.g. an old subscription removed after switching plans).
    """
    sub_obj = event["data"]["object"]
    customer_id = _sget(sub_obj, "customer")
    subscription_id = _sget(sub_obj, "id")

    user_id = _resolve_user_for_subscription(db, customer_id, sub_obj)
    if not user_id:
        logger.error(
            "subscription.deleted: cannot resolve user for customer %s", customer_id
        )
        return

    current = _get_active_subscription(db, user_id)
    current_sub_id = current["stripe_subscription_id"] if current else None
    if subscription_id and current_sub_id and current_sub_id != subscription_id:
        logger.info(
            "subscription.deleted: %s is not the current subscription (%s) "
            "of user %s — ignoring",
            subscription_id,
            current_sub_id,
            user_id,
        )
        _record_billing_event(db, event["id"], event["type"], user_id=user_id)
        return

    _upsert_subscription(
        db,
        user_id,
        plan="free",
        stripe_subscription_id=None,
        status="canceled",
        end_date=datetime.now(timezone.utc).isoformat(),
        current_period_end=None,
        grace_period_end=None,
    )
    _update_user_plan(db, user_id, "free")
    _record_billing_event(db, event["id"], event["type"], user_id=user_id)
    logger.info("subscription.deleted: user %s downgraded to free", user_id)


def _handle_payment_failed(db, event):
    """invoice.payment_failed → mark past_due with a 3-day grace period."""
    invoice = event["data"]["object"]
    customer_id = _sget(invoice, "customer")
    subscription_id = _sget(invoice, "subscription")

    user_id = _resolve_user_from_customer(db, customer_id)
    if not user_id:
        logger.error(
            "invoice.payment_failed: cannot resolve user for customer %s", customer_id
        )
        return

    grace_end = (datetime.now(timezone.utc) + timedelta(days=3)).isoformat()
    _upsert_subscription(db, user_id, status="past_due", grace_period_end=grace_end)
    _record_billing_event(
        db,
        event["id"],
        event["type"],
        user_id=user_id,
        payload={"subscription_id": subscription_id},
    )
    # TODO: send notification email via /api/notifications
    logger.warning(
        "invoice.payment_failed: user %s past_due, grace period until %s",
        user_id,
        grace_end,
    )


def _handle_payment_succeeded(db, event):
    """invoice.payment_succeeded → clear past_due / grace period."""
    invoice = event["data"]["object"]
    customer_id = _sget(invoice, "customer")
    user_id = _resolve_user_from_customer(db, customer_id)
    if not user_id:
        logger.warning(
            "invoice.payment_succeeded: cannot resolve user for customer %s",
            customer_id,
        )
        return

    # Period end of the first invoice line. _sget (not .get) so it works on
    # StripeObjects as well as dicts.
    lines_data = _sget(_sget(invoice, "lines") or {}, "data") or []
    period_end = (
        _sget(_sget(lines_data[0], "period") or {}, "end") if lines_data else None
    )

    fields = {"status": "active", "grace_period_end": None}
    current_period_end = _period_end_iso(period_end)
    if current_period_end:
        # Only overwrite when known, so a missing value doesn't erase it.
        fields["current_period_end"] = current_period_end
    _upsert_subscription(db, user_id, **fields)
    _record_billing_event(db, event["id"], event["type"], user_id=user_id)
    logger.info("invoice.payment_succeeded: user %s payment cleared", user_id)


# Stripe event type -> handler(db, event). Looked up by stripe_webhook at
# request time, so defining it after the route is fine.
_WEBHOOK_HANDLERS = {
    "checkout.session.completed": _handle_checkout_completed,
    "customer.subscription.created": _handle_subscription_updated,
    "customer.subscription.updated": _handle_subscription_updated,
    "customer.subscription.deleted": _handle_subscription_deleted,
    "invoice.payment_failed": _handle_payment_failed,
    "invoice.payment_succeeded": _handle_payment_succeeded,
}


# ──────────────────────────────────────────────────────────────
# On-import: ensure DB migration ran
# ──────────────────────────────────────────────────────────────
try:
    migrate_billing_tables()
except Exception as _e:
    logger.warning("billing_db migration skipped (test env?): %s", _e)