#!/usr/bin/env python3
"""
Objective risk profiler for FarmMatch.

What it does:
- Reads analysis_output.csv and enriched_data.json
- For each property with lat/lon, queries Overpass for nearest hospital
- Computes a simple risk band from hospital distance (Laag/Gemiddeld/Hoog, i.e. Low/Medium/High)
- Writes hospital_distance_m and risk_profile_objective, and fills missing risk_profile in CSV/JSON with the objective band
- Caches Overpass lookups to reduce API calls

Usage:
  ../venv/bin/python3.14 risk_features.py
"""
import json
import math
import time
from pathlib import Path
from typing import Dict, Optional, Tuple

import pandas as pd
import requests

# Files read and rewritten by this script, plus the on-disk lookup cache.
# All three are relative paths, so they resolve against the working directory.
ANALYSIS_CSV = Path("analysis_output.csv")
ENRICHED_JSON = Path("enriched_data.json")
CACHE_FILE = Path("risk_features_cache.json")

# Hospital-distance cut-offs, in meters, used to pick the risk band.
LOW_THRESHOLD = 15 * 1000     # at most 15 km -> lowest band
MEDIUM_THRESHOLD = 40 * 1000  # at most 40 km -> middle band; farther -> highest band

# Overpass interpreter endpoints, used in rotation by pick_overpass().
OVERPASS_URLS = [
    f"https://{host}/api/interpreter"
    for host in (
        "overpass-api.de",
        "lz4.overpass-api.de",
        "overpass.kumi.systems",
    )
]

def pick_overpass(attempt: int) -> str:
    """Return the Overpass endpoint to use for attempt number *attempt*.

    Endpoints are taken from ``OVERPASS_URLS`` in round-robin order, so
    consecutive attempt numbers map to different servers.
    """
    slot = attempt % len(OVERPASS_URLS)
    return OVERPASS_URLS[slot]


def haversine(lat1, lon1, lat2, lon2) -> float:
    """Return the great-circle distance in meters between two WGS84 points.

    The haversine formula is evaluated on a sphere of radius 6,371,000 m.

    Args:
        lat1, lon1: Latitude and longitude of the first point, in degrees.
        lat2, lon2: Latitude and longitude of the second point, in degrees.

    Returns:
        The distance along the sphere's surface, in meters.
    """
    earth_radius_m = 6_371_000
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    half_dphi = math.radians(lat2 - lat1) / 2
    half_dlambda = math.radians(lon2 - lon1) / 2
    a = (
        math.sin(half_dphi) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(half_dlambda) ** 2
    )
    # Floating-point rounding can push `a` marginally above 1 for (near-)
    # antipodal points, which would make sqrt(1 - a) raise ValueError.
    # A plain comparison is used so that NaN inputs still propagate as NaN.
    if a > 1.0:
        a = 1.0
    return 2 * earth_radius_m * math.atan2(math.sqrt(a), math.sqrt(1 - a))


def load_cache() -> Dict:
    """Load the distance cache from ``CACHE_FILE``.

    Returns:
        A dict mapping cache keys to numeric distances. An empty dict is
        returned when the file is missing, unreadable, not valid JSON, or
        does not contain a JSON object.
    """
    try:
        loaded = json.loads(CACHE_FILE.read_text(encoding="utf-8"))
    except FileNotFoundError:
        # First run: there is simply nothing cached yet.
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers both malformed JSON and undecodable bytes.
        print(f"⚠️ Ignoring unreadable cache {CACHE_FILE}: {exc}")
        return {}

    # Callers index and assign into the cache, so anything but a JSON object
    # (e.g. a list) would fail later; start fresh instead.
    if not isinstance(loaded, dict):
        print(f"⚠️ Ignoring cache {CACHE_FILE}: expected a JSON object.")
        return {}

    # Keep only numeric entries; cached values are rounded and compared as
    # distances by the callers. bool is excluded because it subclasses int.
    return {
        key: value
        for key, value in loaded.items()
        if isinstance(value, (int, float)) and not isinstance(value, bool)
    }


def save_cache(cache: Dict):
    """Serialize *cache* as indented JSON and overwrite ``CACHE_FILE`` with it."""
    payload = json.dumps(cache, indent=2)
    CACHE_FILE.write_text(payload)


def nearest_hospital(lat: float, lon: float, cache: Dict) -> Optional[float]:
    """Return the distance in meters to the nearest hospital within 50 km.

    Results are looked up in *cache* first, keyed by the coordinates rounded
    to three decimals. On a miss, Overpass is queried for every
    ``amenity=hospital`` node, way and relation within 50 km, and the
    smallest haversine distance is stored in *cache* and returned.

    Endpoints are rotated via ``pick_overpass`` for up to two passes over
    ``OVERPASS_URLS``, pausing two seconds after each failed attempt.

    Args:
        lat: Latitude of the property, in degrees.
        lon: Longitude of the property, in degrees.
        cache: Mutable mapping of cache key to distance; updated in place
            when a distance is found.

    Returns:
        The distance in meters, or ``None`` when no hospital was found in
        range or every attempt failed. ``None`` results are not cached, so
        they are retried on the next run.
    """
    key = f"{round(lat, 3)}_{round(lon, 3)}"
    if key in cache:
        return cache[key]

    radius_m = 50_000
    # `out` carries no element limit on purpose: Overpass does not order its
    # output by distance, so limiting it to one element would return an
    # arbitrary hospital rather than the nearest one. All candidates are
    # fetched and the minimum is taken locally.
    query = f"""
    [out:json][timeout:25];
    (
      node["amenity"="hospital"](around:{radius_m},{lat},{lon});
      way["amenity"="hospital"](around:{radius_m},{lat},{lon});
      relation["amenity"="hospital"](around:{radius_m},{lat},{lon});
    );
    out center;
    """
    payload = query.encode("utf-8")
    total_attempts = len(OVERPASS_URLS) * 2

    for attempt in range(total_attempts):
        url = pick_overpass(attempt)
        distances = []
        try:
            resp = requests.post(url, data=payload, timeout=45)
            resp.raise_for_status()
            data = resp.json()
            for el in data.get("elements", []):
                # Nodes carry lat/lon directly; ways and relations expose a
                # computed "center" because of `out center`.
                point = el if ("lat" in el and "lon" in el) else el.get("center")
                if not point:
                    continue
                distances.append(haversine(lat, lon, point["lat"], point["lon"]))
            # Overpass reports runtime problems (e.g. a query timeout) in a
            # "remark" field of an otherwise successful response.
            failed = bool(data.get("remark"))
        except (requests.RequestException, ValueError, KeyError, TypeError, AttributeError):
            # Network/HTTP error, invalid JSON, or an unexpected payload shape.
            failed = True
            distances = []

        if distances:
            best = min(distances)
            cache[key] = best
            return best

        if not failed:
            # A clean, empty answer means there is no hospital within the
            # search radius; asking another mirror would not change that.
            return None

        # Rotate endpoints with a brief backoff; no need to wait after the
        # final attempt.
        if attempt < total_attempts - 1:
            time.sleep(2)

    return None


def risk_band_from_distance(dist_m: Optional[float]) -> Optional[str]:
    """Map a hospital distance in meters to a risk band label.

    Returns ``None`` when no distance is available, ``"Laag"`` for distances
    up to and including ``LOW_THRESHOLD``, ``"Gemiddeld"`` up to and
    including ``MEDIUM_THRESHOLD``, and ``"Hoog"`` for anything else.
    """
    if dist_m is None:
        return None

    # Bands in ascending order of their inclusive upper bound.
    bands = ((LOW_THRESHOLD, "Laag"), (MEDIUM_THRESHOLD, "Gemiddeld"))
    for upper_bound, label in bands:
        if dist_m <= upper_bound:
            return label
    return "Hoog"


def update_csv(cache: Dict) -> Tuple[int, int]:
    """Add hospital distance and risk bands to ``ANALYSIS_CSV`` in place.

    For every row with usable ``Latitude``/``Longitude`` values, the nearest
    hospital distance is looked up and written to ``hospital_distance_m``
    (rounded to 0.1 m). The derived band is always written to
    ``risk_profile_objective`` and is copied into ``risk_profile`` only when
    that cell is empty. Missing columns are created. The file is rewritten
    even when no row changed.

    Args:
        cache: Distance cache passed through to ``nearest_hospital``.

    Returns:
        ``(updated, total)``: the number of rows that received a band, and
        the number of rows that had usable coordinates. ``(0, 0)`` when the
        CSV does not exist.
    """
    if not ANALYSIS_CSV.exists():
        print(f"❌ {ANALYSIS_CSV} not found.")
        return 0, 0

    df = pd.read_csv(ANALYSIS_CSV)

    # Ensure columns exist
    for col in ["Latitude", "Longitude", "hospital_distance_m", "risk_profile_objective", "risk_profile"]:
        if col not in df.columns:
            df[col] = None

    # read_csv infers float64 for a column that is entirely empty. Storing a
    # string band in such a column warns or fails depending on the pandas
    # version, so the label columns are held as object dtype.
    for col in ("risk_profile_objective", "risk_profile"):
        df[col] = df[col].astype(object)

    updated = 0
    total = 0

    for idx, row in df.iterrows():
        try:
            latf = float(row["Latitude"])
            lonf = float(row["Longitude"])
        except (TypeError, ValueError):
            # Missing (None/NA) or non-numeric coordinates.
            continue
        # Rejects NaN (empty cells) as well as +/-inf, neither of which can
        # be sent to Overpass as a coordinate.
        if not (math.isfinite(latf) and math.isfinite(lonf)):
            continue

        total += 1
        dist = nearest_hospital(latf, lonf, cache)
        band = risk_band_from_distance(dist)
        if dist is not None:
            df.at[idx, "hospital_distance_m"] = round(dist, 1)
        if not band:
            continue

        # Only fill risk_profile when it has no value yet; the objective
        # column always reflects the computed band.
        if pd.isna(row.get("risk_profile")):
            df.at[idx, "risk_profile"] = band
        df.at[idx, "risk_profile_objective"] = band
        updated += 1

    df.to_csv(ANALYSIS_CSV, index=False, encoding="utf-8")
    return updated, total


def update_json(cache: Dict) -> Tuple[int, int]:
    """Add hospital distance and risk bands to ``ENRICHED_JSON`` in place.

    For every entry with usable ``lat``/``lon`` values, the nearest hospital
    distance is looked up and stored under ``hospital_distance_m`` (rounded
    to 0.1 m). The derived band is always stored under
    ``risk_profile_objective`` and is copied into ``risk_profile`` only when
    that key is missing or falsy. The file is rewritten even when no entry
    changed.

    Args:
        cache: Distance cache passed through to ``nearest_hospital``.

    Returns:
        ``(updated, total)``: the number of entries that received a band,
        and the number of entries that had usable coordinates. ``(0, 0)``
        when the JSON file does not exist.
    """
    if not ENRICHED_JSON.exists():
        print(f"❌ {ENRICHED_JSON} not found.")
        return 0, 0

    # NOTE(review): iteration below assumes the top-level JSON value is a
    # list of objects; confirm against the producer of this file.
    data = json.loads(ENRICHED_JSON.read_text(encoding="utf-8"))
    updated = 0
    total = 0

    for prop in data:
        lat = prop.get("lat")
        lon = prop.get("lon")
        if lat is None or lon is None:
            continue
        try:
            latf = float(lat)
            lonf = float(lon)
        except (TypeError, ValueError):
            continue
        # json.loads accepts NaN/Infinity literals, and float() accepts the
        # strings "nan"/"inf"; such values cannot be queried and would only
        # burn through every retry.
        if not (math.isfinite(latf) and math.isfinite(lonf)):
            continue

        total += 1
        dist = nearest_hospital(latf, lonf, cache)
        band = risk_band_from_distance(dist)
        if dist is not None:
            prop["hospital_distance_m"] = round(dist, 1)
        if not band:
            continue

        # Only fill risk_profile when it has no value yet; the objective
        # key always reflects the computed band.
        if not prop.get("risk_profile"):
            prop["risk_profile"] = band
        prop["risk_profile_objective"] = band
        updated += 1

    ENRICHED_JSON.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    return updated, total


def main():
    """Run the risk profiling over the CSV and JSON files and print a summary.

    The lookup cache is loaded once, shared by both update passes, and
    written back afterwards.
    """
    cache = load_cache()
    print("🔍 Running objective risk profiling (hospital proximity)...")
    try:
        csv_updated, csv_total = update_csv(cache)
        json_updated, json_total = update_json(cache)
    finally:
        # Persist whatever was looked up even if a pass fails or the run is
        # interrupted, so completed Overpass lookups are not repeated.
        save_cache(cache)

    print("\n📊 Risk profiling summary:")
    print(f"  CSV updated: {csv_updated} / {csv_total} with coords")
    print(f"  JSON updated: {json_updated} / {json_total} with coords")
    print("✅ Done. Re-run quality_gate.py to see coverage.")


# Run the profiler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
