#!/usr/bin/env python3
"""
Génère un flux GTFS-realtime d'alertes de service (fichier service_alerts_alt.pb) à partir de l'API Cityway MRN.

Convention GTFS utilisée :
- route_id = "<OPERATEUR>:<CODE_LIGNE>"  ex: "TCAR:20", "TAE:111", "TNI:301"
- stop_id  = "<OPERATEUR>:<CODE_ARRÊT>" ex: "TCAR:VERGE1"
- agency_id = "TCAR" / "TAE" / "TNI"
"""

import time
from datetime import datetime, timezone

import requests
from google.transit import gtfs_realtime_pb2


# Cityway MRN disruptions endpoint, fetched by build_feed(); the JSON response
# is expected to carry the list of disruptions under its "data" key.
API_URL = "https://api.mrn.cityway.fr/disrupt/api/v1/fr/disruptions"


# ---------- Helpers ----------

from zoneinfo import ZoneInfo

def parse_iso_to_timestamp(value: str | None) -> int | None:
    if not value:
        return None
    try:
        dt = datetime.fromisoformat(value)
    except ValueError:
        return None

    # Interpréter les dates Cityway en heure locale France
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=ZoneInfo("Europe/Paris"))

    return int(dt.timestamp())


def map_cause(type_str: str | None) -> int:
    """Translate the Cityway ``type`` field into a GTFS-RT ``Alert.Cause``.

    The text is lower-cased and searched for keywords. Keyword groups are
    tried in order and the first group with a match decides the cause.
    None, an empty string, or text with no known keyword gives UNKNOWN_CAUSE.
    """
    causes = gtfs_realtime_pb2.Alert.Cause
    if not type_str:
        return causes.UNKNOWN_CAUSE

    lowered = type_str.lower()
    keyword_rules = (
        (("travaux",), causes.CONSTRUCTION),
        (("incident", "technique"), causes.TECHNICAL_PROBLEM),
        (("manifestation",), causes.DEMONSTRATION),
        (("grève", "social"), causes.STRIKE),
    )
    for keywords, cause in keyword_rules:
        if any(word in lowered for word in keywords):
            return cause

    return causes.UNKNOWN_CAUSE


def map_effect(severity: str | None) -> int:
    """Translate the Cityway ``severity`` field into a GTFS-RT ``Alert.Effect``.

    Matching is a case-insensitive substring search, checked in priority order:
    interruption/stop -> NO_SERVICE, "perturb..." -> MODIFIED_SERVICE,
    delay -> DELAY. None, "" or unmatched text give UNKNOWN_EFFECT.
    """
    effects = gtfs_realtime_pb2.Alert.Effect
    lowered = severity.lower() if severity else ""

    keyword_rules = (
        (("interrompu", "arrêt"), effects.NO_SERVICE),
        (("perturb",), effects.MODIFIED_SERVICE),
        (("retard", "délai"), effects.DELAY),
    )
    return next(
        (
            effect
            for keywords, effect in keyword_rules
            if any(word in lowered for word in keywords)
        ),
        effects.UNKNOWN_EFFECT,
    )


# ---------- Construction du flux GTFS-RT ----------

def _operator_code(item: dict) -> str | None:
    """Return the operator code (e.g. "TCAR") attached to a Cityway line/stop, or None."""
    return (item.get("operator") or {}).get("code") or None


def _primary_operator_code(disruption: dict) -> str | None:
    """Return the first operator code found on the affected lines, else on the affected stops."""
    for key in ("affectedLines", "affectedStops"):
        for item in disruption.get(key) or []:
            op_code = _operator_code(item)
            if op_code:
                return op_code
    return None


def _qualified_id(code: str, op_code: str | None) -> str:
    """Build a GTFS id following the "<OPERATOR>:<CODE>" convention.

    A code that already contains ":" is treated as already qualified and is
    returned unchanged. So is a code whose operator is unknown.
    """
    if op_code and ":" not in code:
        return f"{op_code}:{code}"
    return code


def _add_informed_entities(alert, items, id_field: str, fallback_op: str | None) -> None:
    """Append one EntitySelector per affected line or stop to *alert*.

    *id_field* is the selector field to fill ("route_id" or "stop_id").
    *fallback_op* (the disruption's primary operator) qualifies a code whose
    own item carries no operator.
    """
    for item in items or []:
        code = item.get("code")
        own_op = _operator_code(item)
        if not code and not own_op:
            # An EntitySelector with no field set is invalid in GTFS-RT.
            continue

        sel = alert.informed_entity.add()
        if code:
            op_code = own_op or fallback_op
            setattr(sel, id_field, _qualified_id(str(code), op_code))
            if op_code:
                sel.agency_id = op_code
        else:
            # Operator only (no code): keep the agency-wide selector.
            sel.agency_id = own_op


def build_feed() -> gtfs_realtime_pb2.FeedMessage:
    """Fetch Cityway disruptions and convert them into a GTFS-RT FeedMessage.

    Each disruption becomes one FeedEntity carrying an Alert. Route and stop
    ids are emitted as "<OPERATOR>:<CODE>", following the module's convention.
    The operator is taken from the line/stop itself. If it has none, the
    disruption's primary operator is used (first operator found on its lines,
    then on its stops).

    Raises:
        requests.HTTPError: if the API answers with an error status.
        requests.RequestException: on network failure or timeout.
        ValueError: if the response body is not valid JSON.
    """
    # 1. Fetch from the API
    resp = requests.get(API_URL, timeout=10)
    resp.raise_for_status()
    payload = resp.json()

    # "or []" also covers an explicit null "data" field.
    disruptions = payload.get("data") or []
    print(f"Nombre de perturbations dans l'API : {len(disruptions)}")

    # 2. Build the FeedMessage header
    feed = gtfs_realtime_pb2.FeedMessage()
    header = feed.header
    header.gtfs_realtime_version = "2.0"
    header.incrementality = gtfs_realtime_pb2.FeedHeader.FULL_DATASET
    header.timestamp = int(time.time())

    # 3. One entity per disruption
    seen_ids: set[str] = set()
    for index, p in enumerate(disruptions):
        entity = feed.entity.add()
        entity_id = str(p.get("id") or p.get("internalId") or "")
        if not entity_id or entity_id in seen_ids:
            # GTFS-RT requires entity ids to be non-empty and unique within the feed.
            entity_id = f"{entity_id or 'alert'}-{index}"
        seen_ids.add(entity_id)
        entity.id = entity_id
        alert = entity.alert

        # --- active_period (either bound may be missing) ---
        start_ts = parse_iso_to_timestamp(p.get("effectiveStartDate"))
        end_ts = parse_iso_to_timestamp(p.get("effectiveEndDate"))
        if start_ts is not None or end_ts is not None:
            period = alert.active_period.add()
            if start_ts is not None:
                period.start = start_ts
            if end_ts is not None:
                period.end = end_ts

        # --- header_text / description_text (French) ---
        title = p.get("title") or ""
        desc = p.get("description") or ""
        if title:
            alert.header_text.translation.add(text=title, language="fr")
        if desc:
            # TODO: the description may contain HTML; consider stripping tags.
            alert.description_text.translation.add(text=desc, language="fr")

        # --- url: first attachment, if any ---
        attachments = p.get("attachments") or []
        if attachments:
            url = attachments[0].get("url")
            if url:
                alert.url.translation.add(text=url, language="fr")

        # --- cause / effect ---
        alert.cause = map_cause(p.get("type"))
        alert.effect = map_effect(p.get("severity"))

        # --- informed entities: lines -> route_id, stops -> stop_id ---
        primary_op_code = _primary_operator_code(p)
        _add_informed_entities(alert, p.get("affectedLines"), "route_id", primary_op_code)
        _add_informed_entities(alert, p.get("affectedStops"), "stop_id", primary_op_code)
        # NOTE(review): GTFS-RT requires at least one informed_entity per alert;
        # a disruption with no usable line/stop still yields an alert without one.

    return feed


def main(out_file: str = "service_alerts_alt.pb") -> None:
    """Build the GTFS-RT alerts feed and write it to *out_file*.

    The feed is first written to "<out_file>.tmp" and then moved into place
    with os.replace, which is an atomic rename on POSIX. A reader polling
    *out_file* therefore never sees a partially written protobuf. The
    temporary file is removed if writing fails.

    Args:
        out_file: Destination path (defaults to the historical file name).
    """
    import os

    feed = build_feed()
    data = feed.SerializeToString()

    tmp_file = f"{out_file}.tmp"
    try:
        with open(tmp_file, "wb") as f:
            f.write(data)
        os.replace(tmp_file, out_file)
    except BaseException:
        # Best-effort cleanup of the partial file; the original error is re-raised.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise
    print(f"Flux GTFS-RT écrit dans {out_file}")


# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()