"""
2024_U3_suri.py
日本における人口集中と経済成長の関係性―閾値回帰モデルを用いた都道府県別分析―
統計数理賞 [大学生・一般の部]
北岡和真ほか（南山大学経済学部）

実データ（SSDSE-B-2026, SSDSE-E-2026）による教育用再現コード
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#     ・SSDSE-E-2026.csv
#       → data/raw/SSDSE-E-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#     （SSDSE-E（社会・人口統計体系 都道府県の指標2） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2024_U3_suri.py
# ============================================================


import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from scipy import stats
from numpy.linalg import lstsq

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False

import os
FIGDIR = os.path.normpath('html/figures') + os.sep
DATA_B = 'data/raw/SSDSE-B-2026.csv'
DATA_E = 'data/raw/SSDSE-E-2026.csv'
os.makedirs(FIGDIR, exist_ok=True)

# ----------------------------------------------------------------
# データ読み込み
# ----------------------------------------------------------------
# SSDSE-E: 人口, 面積, 1人当たり県民所得（断面データ）
df_e_raw = pd.read_csv(DATA_E, encoding='cp932', header=0)
df_e = df_e_raw.iloc[2:].copy()
df_e.columns = df_e_raw.iloc[1].values
df_e = df_e[df_e['都道府県'] != '全国'].reset_index(drop=True)

# 数値変換
for col in ['総人口', '総面積（北方地域及び竹島を除く）', '1人当たり県民所得（平成27年基準）']:
    df_e[col] = pd.to_numeric(df_e[col], errors='coerce')

# 人口密度 (人/km²) : 総面積 [ha] / 100 → km²
df_e['面積_km2'] = df_e['総面積（北方地域及び竹島を除く）'] / 100
df_e['人口密度'] = df_e['総人口'] / df_e['面積_km2']

# SSDSE-B: 転入率・転出率 → 経済活力代理指標
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
mask = (df_b['地域コード'].str.match(r'^R\d{5}$', na=False) &
        (df_b['地域コード'] != 'R00000'))
df_b = df_b[mask].copy()

# 複数年平均（2018-2022）で安定した指標を構築
df_b_avg = (df_b[df_b['年度'].between(2018, 2022)]
            .groupby('都道府県')[['総人口', '転入者数（日本人移動者）', '転出者数（日本人移動者）']]
            .mean())
df_b_avg['転入率'] = df_b_avg['転入者数（日本人移動者）'] / df_b_avg['総人口'] * 1000
df_b_avg['転出率'] = df_b_avg['転出者数（日本人移動者）'] / df_b_avg['総人口'] * 1000
df_b_avg['転入超過率'] = df_b_avg['転入率'] - df_b_avg['転出率']

# ----------------------------------------------------------------
# 都道府県名の統一（SSDSE-E は「北海道」,「青森県」形式）
# ----------------------------------------------------------------
def normalize_pref(s):
    return str(s).rstrip('県府都道').strip()

df_e['pref_short'] = df_e['都道府県'].apply(normalize_pref)
df_b_avg['pref_short'] = [normalize_pref(p) for p in df_b_avg.index]

df_merged = df_e.merge(
    df_b_avg.reset_index()[['都道府県', 'pref_short', '転入超過率']],
    on='pref_short', how='inner', suffixes=('_e', '_b'))

# 最終データ: 人口密度(threshold変数), 1人当たり所得(経済指標)
df_merged['income'] = pd.to_numeric(df_merged['1人当たり県民所得（平成27年基準）'], errors='coerce')
# 都道府県列: SSDSE-E側
pref_col = '都道府県_e' if '都道府県_e' in df_merged.columns else '都道府県'
df_merged['都道府県_label'] = df_merged[pref_col]
df_clean = df_merged[['都道府県_label', '人口密度', 'income', '転入超過率']].dropna()
df_clean = df_clean.rename(columns={'都道府県_label': '都道府県'})
df_clean = df_clean.reset_index(drop=True)

prefectures = df_clean['都道府県'].tolist()
x_all = df_clean['人口密度'].values.astype(float)
y_all = df_clean['income'].values.astype(float)

print(f"サンプル数: {len(df_clean)}")
print(f"人口密度: 最小={x_all.min():.0f}, 最大={x_all.max():.0f}, 中央値={np.median(x_all):.0f}")
print(f"1人当たり所得: 最小={y_all.min():.0f}, 最大={y_all.max():.0f}")

# ================================================================
# 閾値回帰: グリッドサーチで最小RSSのτを探索
# ================================================================
# 人口密度の5～95パーセンタイルをsearch rangeに
pct5, pct95 = np.percentile(x_all, 5), np.percentile(x_all, 95)
tau_candidates = np.linspace(pct5, pct95, 200)
rss_list = []

def fit_ols(x, y):
    X = np.column_stack([np.ones(len(x)), x])
    b, _, _, _ = lstsq(X, y, rcond=None)
    return b

for tau in tau_candidates:
    mask_low = x_all < tau
    mask_high = ~mask_low
    rss = 0
    for mask in [mask_low, mask_high]:
        if mask.sum() < 4:
            rss += 1e15
            continue
        xi = x_all[mask]
        yi = y_all[mask]
        xi_c = np.column_stack([np.ones(len(xi)), xi])
        b, _, _, _ = lstsq(xi_c, yi, rcond=None)
        rss += np.sum((yi - xi_c @ b) ** 2)
    rss_list.append(rss)

rss_arr = np.array(rss_list)
tau_hat = tau_candidates[np.argmin(rss_arr)]
print(f"推定閾値 τ̂ = {tau_hat:.1f} 人/km²")

mask_low = x_all < tau_hat
mask_high = ~mask_low

b_low = fit_ols(x_all[mask_low], y_all[mask_low])
b_high = fit_ols(x_all[mask_high], y_all[mask_high])
b_linear = fit_ols(x_all, y_all)

# ================================================================
# 図1: 散布図 + 閾値回帰の2本の回帰直線
# ================================================================
fig, ax = plt.subplots(figsize=(9, 6))

ax.scatter(x_all[mask_low], y_all[mask_low], alpha=0.5, color='#1565C0', s=50,
           label=f'人口密度 < τ={tau_hat:.0f}（低密度域）')
ax.scatter(x_all[mask_high], y_all[mask_high], alpha=0.5, color='#C62828', s=50,
           label=f'人口密度 >= τ={tau_hat:.0f}（高密度域）')

x_low_line = np.linspace(x_all[mask_low].min(), tau_hat, 100)
x_high_line = np.linspace(tau_hat, x_all[mask_high].max(), 100)
x_all_line = np.linspace(x_all.min(), x_all.max(), 200)

ax.plot(x_low_line, b_low[0] + b_low[1] * x_low_line, '-', color='#1565C0', lw=2.5,
        label=f'低密度域回帰直線 (β={b_low[1]:+.2f})')
ax.plot(x_high_line, b_high[0] + b_high[1] * x_high_line, '-', color='#C62828', lw=2.5,
        label=f'高密度域回帰直線 (β={b_high[1]:+.2f})')
ax.plot(x_all_line, b_linear[0] + b_linear[1] * x_all_line, '--', color='gray', lw=1.5,
        label='線形OLS（比較）')

ax.axvline(tau_hat, color='orange', lw=2, linestyle='--', label=f'閾値 τ̂={tau_hat:.0f}人/km²')

# 主要都道府県にラベル
for i, pref in enumerate(prefectures):
    if any(k in pref for k in ['東京', '大阪', '沖縄', '北海道', '鳥取', '高知']):
        ax.annotate(pref.replace('都', '').replace('道', '').replace('府', '').replace('県', ''),
                    (x_all[i], y_all[i]), textcoords='offset points',
                    xytext=(5, 3), fontsize=8)

ax.set_xlabel("人口密度（人/km²）")
ax.set_ylabel("1人当たり県民所得（万円）")
ax.set_title("図1: 閾値回帰モデル — 散布図と2本の回帰直線\n（閾値変数: 人口密度, 目的変数: 1人当たり県民所得）", fontsize=13)
ax.legend(fontsize=9, loc='upper left')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U3_fig1_scatter.png", dpi=150)
plt.close()
print("fig1 saved")

# ================================================================
# 図2: グリッドサーチのRSS曲線
# ================================================================
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(tau_candidates, rss_arr / 1e6, color='#1565C0', lw=2)
ax.axvline(tau_hat, color='red', lw=2, linestyle='--',
           label=f'RSS最小点: τ̂ = {tau_hat:.0f}人/km²')
ax.scatter([tau_hat], [rss_arr.min() / 1e6], color='red', s=100, zorder=5)
ax.annotate(f'τ̂ = {tau_hat:.0f}\nRSS = {rss_arr.min()/1e6:.1f}',
            xy=(tau_hat, rss_arr.min() / 1e6),
            xytext=(tau_hat + 200, rss_arr.min() / 1e6 + np.ptp(rss_arr / 1e6) * 0.1),
            arrowprops=dict(arrowstyle='->', color='red'),
            fontsize=11, color='red')

ax.set_xlabel("候補閾値 τ（人口密度, 人/km²）")
ax.set_ylabel("RSS（残差平方和, ×10⁶）")
ax.set_title("図2: グリッドサーチによるRSS曲線（最小点が推定閾値τ̂）", fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U3_fig2_rss.png", dpi=150)
plt.close()
print("fig2 saved")

# ================================================================
# 図3: 線形OLS vs 閾値回帰 モデル比較
# ================================================================
rss_linear = np.sum((y_all - (b_linear[0] + b_linear[1] * x_all)) ** 2)
rss_threshold = rss_arr.min()
n_total = len(x_all)
df1_f, df2_f = 2, n_total - 4
F_test = ((rss_linear - rss_threshold) / df1_f) / (rss_threshold / df2_f)
p_test = 1 - stats.f.cdf(F_test, df1_f, df2_f)

ss_tot = np.sum((y_all - y_all.mean()) ** 2)
r2_linear = 1 - rss_linear / ss_tot
r2_threshold = 1 - rss_threshold / ss_tot

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

ax = axes[0]
model_names = ['線形OLS', '閾値回帰']
r2_vals = [r2_linear, r2_threshold]
colors_m = ['#90A4AE', '#2E7D32']
bars = ax.bar(model_names, r2_vals, color=colors_m, alpha=0.85, width=0.5)
for bar, val in zip(bars, r2_vals):
    ax.text(bar.get_x() + bar.get_width() / 2, val + 0.005,
            f'R²={val:.3f}', ha='center', fontsize=12, fontweight='bold')
ax.set_ylim(0, max(r2_vals) * 1.25)
ax.set_ylabel("R²")
ax.set_title("モデル適合度比較（R²）")
ax.grid(True, axis='y', alpha=0.3)

sig_str = '***' if p_test < 0.001 else ('**' if p_test < 0.01 else '*')
info_text = (
    f"閾値回帰 vs 線形OLS\n\n"
    f"F検定統計量: {F_test:.2f}\n"
    f"p値: {p_test:.4f} {sig_str}\n\n"
    f"線形OLS  R² = {r2_linear:.3f}\n"
    f"閾値回帰 R² = {r2_threshold:.3f}\n\n"
    f"RSS改善: {(rss_linear - rss_threshold)/1e6:.2f}×10⁶\n"
    f"τ̂ = {tau_hat:.0f} 人/km²"
)
axes[1].text(0.5, 0.5, info_text, transform=axes[1].transAxes,
             ha='center', va='center', fontsize=12,
             bbox=dict(boxstyle='round', facecolor='#E8F5E9', alpha=0.8))
axes[1].axis('off')
axes[1].set_title("F検定結果（閾値効果の有意性）")

fig.suptitle("図3: 線形OLS vs 閾値回帰モデルの比較", fontsize=13)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U3_fig3_panel.png", dpi=150)
plt.close()
print("fig3 saved")

# ================================================================
# 図4: 都道府県別 — 高域/低域の効果分布
# ================================================================
fig, ax = plt.subplots(figsize=(12, 7))
for i, pref in enumerate(prefectures):
    xv = x_all[i]
    yv = y_all[i]
    pref_short = pref.rstrip('県府都道')
    if xv < tau_hat:
        ax.scatter(xv, yv, color='#1565C0', s=60, alpha=0.7, zorder=3)
        ax.text(xv + 10, yv, pref_short, fontsize=7, color='#1565C0', alpha=0.85)
    else:
        ax.scatter(xv, yv, color='#C62828', s=60, alpha=0.7, zorder=3)
        ax.text(xv + 10, yv, pref_short, fontsize=7, color='#C62828', alpha=0.85)

ax.axvline(tau_hat, color='orange', lw=2, linestyle='--',
           label=f'閾値 τ̂={tau_hat:.0f}人/km²')

x_lo = np.linspace(x_all[mask_low].min(), tau_hat, 100)
x_hi = np.linspace(tau_hat, x_all[mask_high].max(), 100)
ax.plot(x_lo, b_low[0] + b_low[1] * x_lo, '-', color='#1565C0', lw=2.5, alpha=0.6)
ax.plot(x_hi, b_high[0] + b_high[1] * x_hi, '-', color='#C62828', lw=2.5, alpha=0.6)

low_patch = mpatches.Patch(color='#1565C0', alpha=0.7, label='低密度域（密度 < τ）')
high_patch = mpatches.Patch(color='#C62828', alpha=0.7, label='高密度域（密度 >= τ）')
ax.legend(handles=[low_patch, high_patch,
                   mpatches.Patch(color='orange', label=f'τ={tau_hat:.0f}人/km²')],
          fontsize=10)

ax.set_xlabel("人口密度（人/km²）")
ax.set_ylabel("1人当たり県民所得（万円）")
ax.set_title("図4: 都道府県別 人口密度と1人当たり県民所得の分布", fontsize=13)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U3_fig4_map.png", dpi=150)
plt.close()
print("fig4 saved")
print("All figures saved.")
