#!/usr/bin/env python3
"""
Calibrate Cyber Prairie scoring weights against user verdicts.

Uses coordinate descent to find weights that maximize pairwise concordance:
yes properties should score higher than maybe properties, and maybe
properties higher than no properties.

Usage:
    python3 calibrate_weights.py              # Run calibration, show report
    python3 calibrate_weights.py --apply      # Write optimized weights to config
    python3 calibrate_weights.py --simulate   # Same report, marked as a dry run (nothing written)
"""
import argparse
import bisect
import json
import math
from pathlib import Path

from store import load

# Baseline criterion weights, copied from cyber_prairie_score.py.
# The optimizer starts from these, and reports compare against them.
CURRENT_WEIGHTS = dict(
    workshop=3.0,
    location_view=3.0,
    food_experience=2.5,
    guest_accommodation=3.0,
    livability=3.5,
    environmental_risk=2.0,
    design_story=2.5,
    market_garden=2.0,
    land_size=1.5,
    renovation_scope=1.5,
    local_market=2.5,
)

# Fallback scores used by compute_score when a criterion has no value.
DEFAULTS = dict(renovation_scope=3, local_market=3)

# Numeric rank of each user verdict; a higher target should score higher.
VERDICT_TARGETS = dict(yes=5, maybe=3, no=1)

# Weight values tried per criterion during coordinate descent: 0.5 .. 5.0 in 0.5 steps.
WEIGHT_CANDIDATES = [step / 2 for step in range(1, 11)]


def score_land_size(land_m2):
    """Map a plot area in square metres to a 1-5 score.

    Returns None when the area is missing or zero. Otherwise the result is
    5 for >= 50000, 4 for >= 30000, 3 for >= 15000, 2 for >= 6000, and 1 for
    anything smaller.
    """
    if not land_m2:
        return None
    tiers = ((50000, 5), (30000, 4), (15000, 3), (6000, 2))
    for minimum_m2, tier_score in tiers:
        if land_m2 >= minimum_m2:
            return tier_score
    return 1


def compute_livability_from_amenities(amenities):
    """Derive a 1-5 livability score from amenity distances.

    Intended to mirror the equivalent logic in cyber_prairie_score.py; keep
    the two in sync.

    Each known amenity present (truthy) in ``amenities`` adds its weight to
    the achievable total. It earns full points when within the "near"
    distance and partial points when within the "far" distance. An entry
    without a ``km`` value is treated as 999 km away. The earned/achievable
    ratio is scaled to 0-5, rounded with Python's ``round``, and clamped to
    1-5.

    Returns None when ``amenities`` is empty/None or contains none of the
    known amenities.
    """
    if not amenities:
        return None

    # (key, weight, near km, points if near, far km, points if far)
    rules = (
        ('bakery',        2, 10, 2, 20, 1),
        ('supermarket',   2, 10, 2, 20, 1),
        ('hospital',      2, 30, 2, 50, 1),
        ('train_station', 1, 20, 1, 50, 0.5),
        ('town',          1, 15, 1, 30, 0.5),
    )

    earned = 0
    achievable = 0
    for key, weight, near_km, near_pts, far_km, far_pts in rules:
        entry = amenities.get(key)
        if not entry:
            continue
        distance = entry.get('km', 999)
        achievable += weight
        if distance <= near_km:
            earned += near_pts
        elif distance <= far_km:
            earned += far_pts

    if achievable == 0:
        return None
    return max(1, min(5, round(earned / achievable * 5)))


def get_criterion_scores(p):
    """Extract every criterion score for one property record.

    Intended to mirror compute_cp_score in cyber_prairie_score.py; keep the
    two in sync.

    Args:
        p: Property dict from the store.

    Returns:
        Dict keyed by the same criterion names as CURRENT_WEIGHTS. A value is
        None when the property has no score for that criterion.
    """
    # `criteria` may be absent or explicitly null in stored records; treat
    # both as "no manual scores" instead of crashing on None.get().
    criteria = p.get('criteria') or {}

    # Prefer livability derived from amenity distances; fall back to the
    # manually entered criterion only when amenities give no score.
    livability = compute_livability_from_amenities(p.get('amenities'))
    if livability is None:
        livability = criteria.get('livability')

    # Environmental risk: a numeric risk_score wins. Otherwise map the
    # risk_profile label (Dutch: laag/gemiddeld/hoog = low/medium/high) so
    # that lower risk gives a higher score.
    risk_score = p.get('risk_score')
    if risk_score is not None:
        env_risk = round(risk_score)
    else:
        profile = (p.get('risk_profile') or '').strip().lower()
        env_risk = {'laag': 5, 'gemiddeld': 3, 'hoog': 1}.get(profile)

    return {
        'workshop': criteria.get('workshop'),
        'location_view': criteria.get('location'),
        'food_experience': criteria.get('food_experience'),
        'guest_accommodation': criteria.get('guest_accommodation'),
        'livability': livability,
        'environmental_risk': env_risk,
        'design_story': criteria.get('design_story'),
        'market_garden': criteria.get('market_garden'),
        'land_size': score_land_size(p.get('land_size_m2')),
        # Never scored per property; compute_score substitutes DEFAULTS.
        'renovation_scope': None,
        'local_market': criteria.get('local_market'),
    }


def compute_score(scores, weights):
    """Return the weighted mean of criterion scores under ``weights``.

    A criterion whose score is absent or None falls back to DEFAULTS. If it
    has no default either, it is left out of both the numerator and the
    denominator. Returns 0 when no criterion contributes any weight.
    """
    total = 0.0
    norm = 0.0
    for criterion, w in weights.items():
        value = scores.get(criterion)
        value = DEFAULTS.get(criterion) if value is None else value
        if value is None:
            continue
        total += value * w
        norm += w
    if norm > 0:
        return total / norm
    return 0


def load_verdicts():
    """Load every stored property that carries a recognised user_verdict.

    The property's status is ignored. Only string verdicts that are keys of
    VERDICT_TARGETS are kept.

    Returns:
        List of dicts with keys url, verdict, target, scores (from
        get_criterion_scores), location (always a string, possibly empty),
        price and user_flags (always a list).
    """
    verdicts = []
    for url, p in load().items():
        verdict = p.get('user_verdict')
        # Guard against non-string values (e.g. a stray list) that would make
        # the dict membership test raise TypeError.
        if not isinstance(verdict, str) or verdict not in VERDICT_TARGETS:
            continue
        verdicts.append({
            'url': url,
            'verdict': verdict,
            'target': VERDICT_TARGETS[verdict],
            'scores': get_criterion_scores(p),
            # `or ''` (not a .get default) so an explicit null title still
            # yields a string; the report slices this value.
            'location': p.get('location') or p.get('city') or p.get('title') or '',
            'price': p.get('price'),
            'user_flags': p.get('user_flags') or [],
        })
    return verdicts


def concordance(verdicts, weights):
    """Measure how well ``weights`` rank yes > maybe > no.

    Returns the fraction (0-1) of concordant pairs among all pairs of
    verdicts whose targets differ. A pair is concordant when the
    higher-target property scores strictly higher. Exact score ties count as
    half a concordant pair. Returns 0 when there are no such pairs (e.g.
    every verdict has the same value).

    Scores are grouped by target and sorted, so each cross-target comparison
    is a binary search instead of an explicit loop over every pair. The
    result is identical to checking every pair, but this function is called
    hundreds of times by optimize_weights, so the saving matters.
    """
    groups = {}
    for v in verdicts:
        groups.setdefault(v['target'], []).append(compute_score(v['scores'], weights))
    for group_scores in groups.values():
        group_scores.sort()

    concordant = 0.0
    total = 0
    targets = sorted(groups)
    for idx, lo_target in enumerate(targets):
        lo_scores = groups[lo_target]
        for hi_target in targets[idx + 1:]:
            hi_scores = groups[hi_target]
            total += len(hi_scores) * len(lo_scores)
            for s in hi_scores:
                # Number of lower-target scores strictly below s (concordant).
                below = bisect.bisect_left(lo_scores, s)
                # Number of lower-target scores exactly equal to s (ties).
                tied = bisect.bisect_right(lo_scores, s) - below
                concordant += below + 0.5 * tied
    return concordant / total if total > 0 else 0


def optimize_weights(verdicts, current_weights, rounds=3):
    """Tune weights by coordinate descent to maximize concordance.

    Each round visits every criterion in turn. For each one it tries every
    value in WEIGHT_CANDIDATES while holding the others fixed, and keeps a
    value only if it strictly improves concordance. Ties therefore keep the
    current weight, and concordance never decreases.

    Args:
        verdicts: Records as produced by load_verdicts().
        current_weights: Starting weights; not modified.
        rounds: Number of full passes over all criteria.

    Returns:
        (weights, history): the tuned weights dict, and a list of
        (round_number, concordance_after_round) tuples, one per round.
    """
    weights = dict(current_weights)
    history = []
    # Concordance of `weights` as it stands. It is carried forward instead
    # of being recomputed for every criterion and for the history entry.
    best_c = None

    for round_num in range(rounds):
        if best_c is None:
            best_c = concordance(verdicts, weights)
        for criterion in weights:
            incumbent = weights[criterion]
            best_w = incumbent
            for w in WEIGHT_CANDIDATES:
                if w == incumbent:
                    continue  # already scored: this is best_c's configuration
                trial = dict(weights)
                trial[criterion] = w
                c = concordance(verdicts, trial)
                if c > best_c:
                    best_c = c
                    best_w = w
            weights[criterion] = best_w
        history.append((round_num + 1, best_c))

    return weights, history


def criterion_analysis(verdicts):
    """Print per-criterion mean scores of 'yes' vs 'no' verdicts.

    'maybe' verdicts are ignored. The delta (yes minus no) is labelled
    STRONG when its magnitude exceeds 1.0, weak when it exceeds 0.3, and '-'
    otherwise. A criterion with no known score on either side prints n/a.
    """
    scores_by_label = {
        label: [v['scores'] for v in verdicts if v['verdict'] == label]
        for label in ('yes', 'no')
    }

    def known(label, criterion):
        # Values of `criterion` for one verdict group, skipping missing ones.
        return [
            val
            for val in (s.get(criterion) for s in scores_by_label[label])
            if val is not None
        ]

    print("\n  Criterion averages (yes vs no):")
    print(f"  {'Criterion':<25} {'Yes avg':>8} {'No avg':>8} {'Delta':>8} {'Signal':>8}")
    print(f"  {'-'*25} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")

    for criterion in CURRENT_WEIGHTS:
        yes_vals = known('yes', criterion)
        no_vals = known('no', criterion)
        if not (yes_vals and no_vals):
            print(f"  {criterion:<25} {'n/a':>8} {'n/a':>8} {'':>8} {'':>8}")
            continue

        yes_avg = sum(yes_vals) / len(yes_vals)
        no_avg = sum(no_vals) / len(no_vals)
        delta = yes_avg - no_avg
        magnitude = abs(delta)
        if magnitude > 1.0:
            signal = 'STRONG'
        elif magnitude > 0.3:
            signal = 'weak'
        else:
            signal = '-'
        print(f"  {criterion:<25} {yes_avg:>8.1f} {no_avg:>8.1f} {delta:>+8.1f} {signal:>8}")


def _print_weight_changes(optimized):
    """Print current vs optimized weight per criterion, flagging moves >= 1.0."""
    print(f"\n  {'Criterion':<25} {'Current':>8} {'Optimized':>10} {'Change':>8}")
    print(f"  {'-'*25} {'-'*8} {'-'*10} {'-'*8}")
    for criterion, curr in CURRENT_WEIGHTS.items():
        opt = optimized[criterion]
        delta = opt - curr
        marker = ' ***' if abs(delta) >= 1.0 else ''
        print(f"  {criterion:<25} {curr:>8.1f} {opt:>10.1f} {delta:>+8.1f}{marker}")


def _print_ranking(verdicts, optimized):
    """Print each property's score under current and optimized weights.

    Rows are grouped yes -> maybe -> no. Within a group they are ordered by
    optimized score, highest first, so the table reads as a ranking.
    """
    print("\n  Property ranking (current → optimized):")
    print(f"  {'Verdict':<8} {'Location':<30} {'Current':>8} {'Optimized':>9} {'Move':>6}")
    print(f"  {'-'*8} {'-'*30} {'-'*8} {'-'*9} {'-'*6}")

    rows = [
        (v, compute_score(v['scores'], CURRENT_WEIGHTS), compute_score(v['scores'], optimized))
        for v in verdicts
    ]
    rows.sort(key=lambda row: (-row[0]['target'], -row[2]))

    for v, curr_s, opt_s in rows:
        # Location can be None/missing for sparse records; never crash the report.
        loc = str(v.get('location') or '')[:28]
        move = opt_s - curr_s
        marker = '++' if move > 0.3 else '--' if move < -0.3 else ''
        # Width 6 matches the 'Move' header column.
        print(f"  {v['verdict']:<8} {loc:<30} {curr_s:>8.2f} {opt_s:>9.2f} {move:>+6.2f} {marker}")


def main():
    """Command-line entry point: report calibration results, optionally save weights.

    Always prints the verdict summary, concordance before and after
    optimization, the per-criterion analysis, weight changes and the
    property ranking. ``--apply`` also writes the optimized weights to
    weights.json next to this script. ``--simulate`` only adds a note that
    nothing was written.
    """
    parser = argparse.ArgumentParser(description='Calibrate Paradisomatch scoring weights')
    parser.add_argument('--apply', action='store_true', help='Write optimized weights to config')
    parser.add_argument('--simulate', action='store_true', help='Show ranking with new weights')
    args = parser.parse_args()

    verdicts = load_verdicts()
    if not verdicts:
        print("No verdict data found. Use verdict.py to record user verdicts first.")
        return

    yes_count = sum(1 for v in verdicts if v['verdict'] == 'yes')
    maybe_count = sum(1 for v in verdicts if v['verdict'] == 'maybe')
    no_count = sum(1 for v in verdicts if v['verdict'] == 'no')

    print("=" * 70)
    print("  FARMMATCH WEIGHT CALIBRATION")
    print("=" * 70)
    print(f"  Verdicts: {len(verdicts)} total ({yes_count} yes, {maybe_count} maybe, {no_count} no)")
    print(f"  Pairs to rank: {yes_count * no_count} (yes vs no) + {yes_count * maybe_count + maybe_count * no_count} (with maybe)")

    current_conc = concordance(verdicts, CURRENT_WEIGHTS)
    print(f"\n  Current weights concordance: {current_conc:.2%}")

    criterion_analysis(verdicts)

    # Single source of truth for the round count, so the message matches the run.
    rounds = 3
    print(f"\n  Optimizing weights (coordinate descent, {rounds} rounds)...")
    optimized, history = optimize_weights(verdicts, CURRENT_WEIGHTS, rounds=rounds)
    opt_conc = concordance(verdicts, optimized)

    print(f"  Optimized concordance: {opt_conc:.2%}")
    for rnd, c in history:
        print(f"    Round {rnd}: {c:.2%}")

    _print_weight_changes(optimized)
    _print_ranking(verdicts, optimized)

    if args.simulate:
        print("\n  [Simulation mode — no changes written]")

    if args.apply:
        config_path = Path(__file__).parent / 'weights.json'
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(optimized, f, indent=2)
        print(f"\n  Optimized weights saved to {config_path}")
        print("  To use: update CRITERIA in cyber_prairie_score.py or load from weights.json")

    if not args.apply and not args.simulate:
        print("\n  Run with --apply to save optimized weights")
        print("  Run with --simulate to see detailed ranking")


if __name__ == '__main__':
    main()
