"""
2024_U2_yushu.py
合計特殊出生率の決定要因の影響はコロナ禍で変化したのか
優秀賞 [大学生・一般の部]
天野葵、伊藤愛、神谷珠里（南山大学総合政策学部）

実データ（SSDSE-B-2026, SSDSE-E-2026）による教育用再現コード
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#     ・SSDSE-E-2026.csv
#       → data/raw/SSDSE-E-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#     （SSDSE-E（社会・人口統計体系 都道府県の指標2） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2024_U2_yushu.py
# ============================================================


import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from scipy import stats
from numpy.linalg import lstsq

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False

import os
FIGDIR = os.path.normpath('html/figures') + os.sep
DATA_B = 'data/raw/SSDSE-B-2026.csv'
DATA_E = 'data/raw/SSDSE-E-2026.csv'
os.makedirs(FIGDIR, exist_ok=True)

# ----------------------------------------------------------------
# データ読み込み: SSDSE-B 2015-2022
# ----------------------------------------------------------------
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
mask = (df_b['地域コード'].str.match(r'^R\d{5}$', na=False) &
        (df_b['地域コード'] != 'R00000') &
        df_b['年度'].between(2015, 2022))
df_b = df_b[mask].copy()

# SSDSE-E: 1人当たり県民所得（都道府県別、単年値）
df_e_raw = pd.read_csv(DATA_E, encoding='cp932', header=0)
df_e = df_e_raw.iloc[2:].copy()
df_e.columns = df_e_raw.iloc[1].values
df_e = df_e[df_e['都道府県'] != '全国'].reset_index(drop=True)
income_map = pd.to_numeric(
    df_e.set_index('都道府県')['1人当たり県民所得（平成27年基準）'], errors='coerce')

# ----------------------------------------------------------------
# パネルデータ構築
# ----------------------------------------------------------------
years = list(range(2015, 2023))
n_pref = 47

records = []
for _, row in df_b.iterrows():
    yr = int(row['年度'])
    pref = row['都道府県']
    pop = float(row['総人口']) if pd.notna(row['総人口']) else np.nan
    tfr = float(row['合計特殊出生率']) if pd.notna(row['合計特殊出生率']) else np.nan
    marriages = float(row['婚姻件数']) if pd.notna(row['婚姻件数']) else np.nan
    nursery = float(row['保育所等数']) if pd.notna(row['保育所等数']) else np.nan

    # 婚姻率 (人口千人あたり)
    mar_rate = (marriages / pop * 1000) if (pop and pop > 0) else np.nan
    # 保育所数 (人口万人あたり)
    nur_rate = (nursery / pop * 10000) if (pop and pop > 0) else np.nan

    # 所得: SSDSE-E の都道府県名を合わせる
    inc = np.nan
    for key in income_map.index:
        if pref.startswith(key.rstrip('県府都道')) or key.startswith(pref.rstrip('県府都道')):
            inc = float(income_map[key]) if pd.notna(income_map[key]) else np.nan
            break
    if np.isnan(inc) and pref in income_map.index:
        inc = float(income_map[pref])

    records.append({
        'pref': pref,
        'year': yr,
        'TFR': tfr,
        '婚姻率': mar_rate,
        '保育所数': nur_rate,
        '1人あたり所得': inc,
        'コロナダミー': 1 if yr >= 2020 else 0,
    })

df = pd.DataFrame(records)

# 所得の欠損を都道府県平均で補完（SSDSE-Eは単年のため）
inc_pref_mean = df.groupby('pref')['1人あたり所得'].transform(lambda x: x.fillna(x.mean()))
df['1人あたり所得'] = inc_pref_mean

# 数値型に変換・欠損除去
numeric_cols = ['TFR', '婚姻率', '保育所数', '1人あたり所得']
df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors='coerce')
df = df.dropna(subset=numeric_cols)

# ================================================================
# 図1: TFRと主要変数の時系列（全国平均）
# ================================================================
fig, axes = plt.subplots(2, 2, figsize=(11, 8))
yr_mean = df.groupby('year').mean(numeric_only=True)

ax = axes[0, 0]
ax.plot(yr_mean.index, yr_mean['TFR'], 'o-', color='#C62828', lw=2, ms=6)
ax.axvspan(2020, 2022.5, alpha=0.12, color='gray')
ax.set_title("合計特殊出生率（TFR）", fontsize=11)
ax.set_ylabel("TFR")
ax.grid(True, alpha=0.3)

ax = axes[0, 1]
ax.plot(yr_mean.index, yr_mean['婚姻率'], 's-', color='#1565C0', lw=2, ms=6)
ax.axvspan(2020, 2022.5, alpha=0.12, color='gray')
ax.set_title("婚姻率（人口千人あたり）", fontsize=11)
ax.set_ylabel("婚姻率")
ax.grid(True, alpha=0.3)

ax = axes[1, 0]
ax.plot(yr_mean.index, yr_mean['1人あたり所得'], '^-', color='#2E7D32', lw=2, ms=6)
ax.axvspan(2020, 2022.5, alpha=0.12, color='gray')
ax.set_title("1人あたり県民所得（万円）", fontsize=11)
ax.set_ylabel("所得")
ax.grid(True, alpha=0.3)

ax = axes[1, 1]
ax.plot(yr_mean.index, yr_mean['保育所数'], 'D-', color='#6A1B9A', lw=2, ms=6)
ax.axvspan(2020, 2022.5, alpha=0.12, color='gray')
ax.set_title("保育所数（人口万人あたり）", fontsize=11)
ax.set_ylabel("保育所数")
ax.grid(True, alpha=0.3)

for ax in axes.flat:
    ax.axvline(2020, color='gray', lw=1.2, linestyle='--')
    ax.set_xlabel("年")

fig.suptitle("図1: TFRと主要説明変数の時系列（47都道府県平均）\n灰色帯=コロナ禍（2020-2022）", fontsize=12)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U2_fig1_ts.png", dpi=150)
plt.close()
print("fig1 saved")

# ================================================================
# 図2: コロナ前後の回帰係数比較
# ================================================================
def ols_coef(df_sub, target, predictors):
    X = df_sub[predictors].values.astype(float)
    y_v = df_sub[target].values.astype(float)
    # standardize
    X_m, X_s = X.mean(axis=0), X.std(axis=0)
    X_s[X_s == 0] = 1
    X_std = (X - X_m) / X_s
    X_c = np.column_stack([np.ones(len(y_v)), X_std])
    coef_v, _, _, _ = lstsq(X_c, y_v, rcond=None)
    res = y_v - X_c @ coef_v
    n_r, k = X_c.shape
    sigma2 = res @ res / max(n_r - k, 1)
    try:
        cov = sigma2 * np.linalg.inv(X_c.T @ X_c)
        se = np.sqrt(np.diag(cov))[1:]
    except np.linalg.LinAlgError:
        se = np.ones(len(predictors)) * np.nan
    return coef_v[1:], se

predictors = ['婚姻率', '保育所数', '1人あたり所得']
df_pre = df[df['コロナダミー'] == 0]
df_covid = df[df['コロナダミー'] == 1]

coef_pre, se_pre = ols_coef(df_pre, 'TFR', predictors)
coef_covid, se_covid = ols_coef(df_covid, 'TFR', predictors)

x = np.arange(len(predictors))
w = 0.35

fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(x - w/2, coef_pre, w, label='コロナ前（2015-2019）',
       color='#1565C0', alpha=0.8, yerr=1.96 * se_pre, capsize=5)
ax.bar(x + w/2, coef_covid, w, label='コロナ禍（2020-2022）',
       color='#C62828', alpha=0.8, yerr=1.96 * se_covid, capsize=5)

ax.axhline(0, color='black', lw=1)
ax.set_xticks(x)
ax.set_xticklabels(predictors, fontsize=11)
ax.set_ylabel("標準化回帰係数（95% CI）")
ax.set_title("図2: コロナ前後の回帰係数比較（目的変数: TFR）", fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, axis='y', alpha=0.3)

for i, var in enumerate(predictors):
    change = coef_covid[i] - coef_pre[i]
    if abs(change) > 0.02:
        ax.annotate(f"Δ={change:+.3f}", xy=(i + w/2, coef_covid[i]),
                    xytext=(i + w/2 + 0.15, coef_covid[i] + 0.005),
                    fontsize=9, color='darkred')

plt.tight_layout()
plt.savefig(FIGDIR + "2024_U2_fig2_coef_compare.png", dpi=150)
plt.close()
print("fig2 saved")

# ================================================================
# 図3: 交互作用効果の可視化（婚姻率×コロナダミー, 保育所数×コロナダミー）
# ================================================================
fig, axes = plt.subplots(1, 2, figsize=(11, 5))

for ax, xvar, xlabel in [
    (axes[0], '婚姻率', '婚姻率（人口千人あたり）'),
    (axes[1], '保育所数', '保育所数（人口万人あたり）'),
]:
    for lbl, color, mask in [('コロナ前', '#1565C0', df['コロナダミー'] == 0),
                               ('コロナ禍', '#C62828', df['コロナダミー'] == 1)]:
        sub = df[mask].dropna(subset=[xvar, 'TFR'])
        ax.scatter(sub[xvar], sub['TFR'], alpha=0.15, color=color, s=15)
        x_line = np.linspace(sub[xvar].min(), sub[xvar].max(), 100)
        slope, intercept, _, _, _ = stats.linregress(sub[xvar], sub['TFR'])
        ax.plot(x_line, intercept + slope * x_line, color=color, lw=2.5,
                label=f'{lbl} (β={slope:.3f})')
    ax.set_xlabel(xlabel)
    ax.set_ylabel("TFR")
    ax.set_title(f"{xvar} × コロナダミー の交互作用", fontsize=11)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

fig.suptitle("図3: コロナ禍前後の交互作用効果（回帰直線の傾きの変化）", fontsize=13)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U2_fig3_interaction.png", dpi=150)
plt.close()
print("fig3 saved")

# ================================================================
# 図4: グレンジャー因果検定
# ================================================================
def granger_f(y_series, x_series, lag=1):
    n = len(y_series)
    y_v = y_series[lag:]
    y_lag = y_series[:-lag]
    x_lag = x_series[:-lag]
    X0 = np.column_stack([np.ones(n - lag), y_lag])
    b0, _, _, _ = lstsq(X0, y_v, rcond=None)
    rss0 = np.sum((y_v - X0 @ b0) ** 2)
    X1 = np.column_stack([np.ones(n - lag), y_lag, x_lag])
    b1, _, _, _ = lstsq(X1, y_v, rcond=None)
    rss1 = np.sum((y_v - X1 @ b1) ** 2)
    df1_v, df2_v = 1, n - lag - 3
    if df2_v <= 0 or rss1 == 0:
        return 0.0, 1.0
    F = ((rss0 - rss1) / df1_v) / (rss1 / df2_v)
    p = 1 - stats.f.cdf(F, df1_v, df2_v)
    return F, p

variables = ['婚姻率', '保育所数', '1人あたり所得']
f_stats = []
p_stats = []

for var in variables:
    F_list, P_list = [], []
    for pref, sub_df in df.groupby('pref'):
        sub = sub_df.sort_values('year')
        sub_clean = sub.dropna(subset=['TFR', var])
        if len(sub_clean) < 5:
            continue
        F, p = granger_f(sub_clean['TFR'].values, sub_clean[var].values, lag=1)
        F_list.append(F)
        P_list.append(p)
    f_stats.append(np.mean(F_list) if F_list else 0.0)
    p_stats.append(np.mean(P_list) if P_list else 1.0)

fig, ax = plt.subplots(figsize=(8, 5))
colors_g = ['#C62828' if p < 0.05 else '#90A4AE' for p in p_stats]
ax.barh(variables, f_stats, color=colors_g, alpha=0.8)
ax.axvline(stats.f.ppf(0.95, 1, 5), color='red', lw=1.5, linestyle='--',
           label='有意水準5%閾値（F臨界値）')

for i, (F, p) in enumerate(zip(f_stats, p_stats)):
    sig = "**" if p < 0.01 else ("*" if p < 0.05 else "n.s.")
    ax.text(F + 0.1, i, f"F={F:.2f} ({sig})", va='center', fontsize=10)

ax.set_title("図4: グレンジャー因果検定 F統計量\n（各変数→TFR, lag=1年, 47都道府県平均）", fontsize=12)
ax.set_xlabel("F統計量（都道府県平均）")
ax.legend(fontsize=10)
ax.grid(True, axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U2_fig4_granger.png", dpi=150)
plt.close()
print("fig4 saved")
print("All figures saved.")
