"""
教育用再現コード: 2025年 統計データ分析コンペティション 優秀賞（大学生）
=================================================================
論文タイトル：女性就業率と少子化の動態
手法：VAR（ベクトル自己回帰）モデル + Granger因果性検定

【分析概要】
  SSDSE-B（都道府県別パネル、2012-2023年）を用いて
  合計特殊出生率（TFR）と婚姻率の間のGranger因果性を検証する。
  婚姻率 = 婚姻件数 / 15～64歳人口 × 1000

  分析1：全国年度平均の時系列でVAR(2)推定 → Granger因果検定 + IRF
  分析2：都道府県ごとの散布図（複数年）

【使用データ（実データ）】
  data/raw/SSDSE-B-2026.csv — 47都道府県 パネル 2012-2023
  出典: 政府統計の総合窓口（e-Stat）、SSDSE（教育用標準データセット）

【出力図】
  html/figures/2025_U2_fig1_trend.png  — TFR・婚姻率の時系列推移
  html/figures/2025_U2_fig2_irf.png   — インパルス応答関数（IRF）
  html/figures/2025_U2_fig3_granger.png — Granger因果検定の結果
  html/figures/2025_U2_fig4_scatter.png — 都道府県別散布図（複数年）
=================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2025_U2_yushu.py
# ============================================================


import os
import warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from statsmodels.tsa.vector_ar.var_model import VAR
import statsmodels.api as sm

warnings.filterwarnings('ignore')

# ──────────────────────────────────────────────────────────────
# 共通設定
# ──────────────────────────────────────────────────────────────
plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

FIGURE_DIR = os.path.normpath('html/figures')
DATA_DIR   = os.path.normpath('data/raw')
os.makedirs(FIGURE_DIR, exist_ok=True)

COLORS = {
    'primary':   '#1565C0',
    'secondary': '#E65100',
    'success':   '#2E7D32',
    'danger':    '#C62828',
    'purple':    '#6A1B9A',
    'teal':      '#00695C',
    'gray':      '#616161',
}

# ──────────────────────────────────────────────────────────────
# データ読み込み（SSDSE-B-2026）
# ──────────────────────────────────────────────────────────────
print("=" * 65)
print("■ データ読み込み: SSDSE-B-2026.csv（47都道府県、2012-2023年）")
print("=" * 65)

df_raw = pd.read_csv(
    os.path.join(DATA_DIR, 'SSDSE-B-2026.csv'),
    encoding='cp932',
    header=1
)

# 47都道府県のみ抽出（R + 5桁数字）
df_pref = df_raw[df_raw['地域コード'].str.match(r'^R\d{5}$', na=False)].copy()

# 数値変換
for col in ['合計特殊出生率', '婚姻件数', '15～64歳人口', '総人口',
            '65歳以上人口', '15歳未満人口']:
    df_pref[col] = pd.to_numeric(df_pref[col], errors='coerce')

# 婚姻率 = 婚姻件数 / 15～64歳人口 × 1000（労働年齢人口千人あたり婚姻件数）
df_pref['婚姻率'] = df_pref['婚姻件数'] / df_pref['15～64歳人口'] * 1000

print(f"データ件数: {len(df_pref)}（{df_pref['年度'].nunique()}年 × {df_pref['都道府県'].nunique()}都道府県）")
print(f"年度: {sorted(df_pref['年度'].unique())}")

# ──────────────────────────────────────────────────────────────
# 分析1: VARモデル（全国平均の時系列）
# ──────────────────────────────────────────────────────────────
print("\n" + "=" * 65)
print("■ 分析1：VARモデル + Granger因果性検定（全国年度平均）")
print("=" * 65)

# 全国年度平均を計算
annual = df_pref.groupby('年度')[['合計特殊出生率', '婚姻率']].mean().sort_index()
annual.columns = ['TFR', '婚姻率']
annual = annual.dropna()

print(f"\n【全国年度平均（{annual.index[0]}-{annual.index[-1]}年）】")
print(annual.round(4))

# VAR モデル推定（ラグ次数を AIC で選択、最大2）
var_model = VAR(annual)
var_result = var_model.fit(maxlags=2, ic='aic')
lag_order = var_result.k_ar
print(f"\n【VAR モデル情報】")
print(f"  選択ラグ次数 : {lag_order}")
print(f"  標本数       : T={len(annual)}")
print(f"  AIC          : {var_result.aic:.4f}")
print(f"  BIC          : {var_result.bic:.4f}")

# Granger 因果性検定
print(f"\n【Granger因果性検定（Wald検定）】")
print(f"  帰無仮説：原因変数の係数がすべてゼロ（Granger因果なし）\n")
granger_results = {}

test_pairs = [
    ('TFR',    '婚姻率',  '婚姻率 → TFR（婚姻率が TFR を予測するか）'),
    ('婚姻率', 'TFR',     'TFR → 婚姻率（TFR が婚姻率を予測するか）'),
]

for response, causing, desc in test_pairs:
    try:
        gc = var_result.test_causality(response, causing=causing, kind='wald')
        stat = float(gc.test_statistic)
        pval = float(gc.pvalue)
        df_gc = int(gc.df)
    except Exception:
        gc = var_result.test_causality(response, causing=causing, kind='f')
        stat = float(gc.test_statistic)
        pval = float(gc.pvalue)
        df_gc = lag_order
    sig = '***' if pval < 0.01 else '**' if pval < 0.05 else '*' if pval < 0.1 else 'n.s.'
    granger_results[desc] = {'stat': stat, 'pvalue': pval, 'df': df_gc, 'sig': sig,
                              'response': response, 'causing': causing}
    print(f"  {desc}")
    print(f"    Wald統計量={stat:.3f}  df={df_gc}  p値={pval:.4f}  {sig}")

# インパルス応答関数（IRF）
n_periods = 8
irf_obj = var_result.irf(n_periods)

# TFR の変数インデックス
var_names = list(annual.columns)
idx_tfr    = var_names.index('TFR')
idx_marr   = var_names.index('婚姻率')

# 婚姻率 → TFR の IRF
irf_marr_to_tfr = irf_obj.irfs[:, idx_tfr, idx_marr]
# TFR → 婚姻率 の IRF
irf_tfr_to_marr = irf_obj.irfs[:, idx_marr, idx_tfr]

# 漸近デルタ法による IRF 信頼区間（statsmodels 組み込み、ランダム数不使用）
print("\nIRF 信頼区間計算中（漸近デルタ法）...")
irf_se = irf_obj.stderr()  # shape: (n_periods+1, n_vars, n_vars) — Lütkepohl delta method
irf_mt_lo = irf_marr_to_tfr - 1.96 * irf_se[:, idx_tfr, idx_marr]
irf_mt_hi = irf_marr_to_tfr + 1.96 * irf_se[:, idx_tfr, idx_marr]
irf_tm_lo = irf_tfr_to_marr - 1.96 * irf_se[:, idx_marr, idx_tfr]
irf_tm_hi = irf_tfr_to_marr + 1.96 * irf_se[:, idx_marr, idx_tfr]
print("完了（漸近95%CI）")


# ──────────────────────────────────────────────────────────────
# 分析2: 都道府県別パネル（TFR vs 婚姻率、複数年）
# ──────────────────────────────────────────────────────────────
print("\n" + "=" * 65)
print("■ 分析2：都道府県別 TFR × 婚姻率（地域間比較）")
print("=" * 65)

# 3つの代表年度を使用
years_scatter = [2014, 2018, 2022]
df_scatter = df_pref[df_pref['年度'].isin(years_scatter)].dropna(
    subset=['合計特殊出生率', '婚姻率']
).copy()
print(f"\n散布図データ: {len(df_scatter)}件（{years_scatter}年 × 47都道府県）")

# 都道府県別回帰係数（2022年の例）
df22 = df_scatter[df_scatter['年度'] == 2022].copy()
X22 = sm.add_constant(df22['婚姻率'].values)
fit22 = sm.OLS(df22['合計特殊出生率'].values, X22).fit()
print(f"\n2022年 都道府県別 OLS: 婚姻率 → TFR")
print(f"  係数={fit22.params[1]:.4f}, p={fit22.pvalues[1]:.4f}, R²={fit22.rsquared:.3f}")


# ================================================================
# 図1：時系列プロット（TFR と婚姻率の推移）
# ================================================================
print("\n図1: 時系列推移プロットを作成中...")

fig1, ax1a = plt.subplots(figsize=(10, 5))
ax1b = ax1a.twinx()

years_ts = annual.index.values

l1 = ax1a.plot(years_ts, annual['TFR'], color=COLORS['primary'],
               linewidth=2.5, marker='o', markersize=7,
               markerfacecolor='white', markeredgewidth=2,
               label='合計特殊出生率（TFR）')
l2 = ax1b.plot(years_ts, annual['婚姻率'], color=COLORS['secondary'],
               linewidth=2.5, marker='s', markersize=7,
               markerfacecolor='white', markeredgewidth=2,
               linestyle='--', label='婚姻率（千人あたり）')

ax1a.set_xlabel('年度', fontsize=12)
ax1a.set_ylabel('合計特殊出生率（TFR）', fontsize=12, color=COLORS['primary'])
ax1b.set_ylabel('婚姻率（15-64歳人口千人あたり）', fontsize=12, color=COLORS['secondary'])
ax1a.tick_params(axis='y', labelcolor=COLORS['primary'])
ax1b.tick_params(axis='y', labelcolor=COLORS['secondary'])

ax1a.set_title('全国平均 合計特殊出生率（TFR）と婚姻率の推移（2012-2023年）\n'
               'データ出典: SSDSE-B-2026（e-Stat）', fontsize=12, fontweight='bold')
ax1a.set_xlim(2011.5, 2023.5)
ax1a.grid(True, alpha=0.3)

lines = l1 + l2
labels = [l.get_label() for l in lines]
ax1a.legend(lines, labels, loc='upper right', fontsize=10)

# 注釈
ax1a.annotate('COVID-19\n（2020年）', xy=(2020, annual.loc[2020, 'TFR']),
              xytext=(2020.5, annual.loc[2020, 'TFR'] + 0.04),
              arrowprops=dict(arrowstyle='->', color='gray', lw=1.2),
              fontsize=8, color='gray')

plt.tight_layout()
fig1.savefig(os.path.join(FIGURE_DIR, '2025_U2_fig1_trend.png'), bbox_inches='tight', dpi=150)
plt.close(fig1)
print("  → 2025_U2_fig1_trend.png 保存完了")


# ================================================================
# 図2：インパルス応答関数（IRF）
# ================================================================
print("図2: インパルス応答関数（IRF）を作成中...")

fig2, axes2 = plt.subplots(1, 2, figsize=(13, 5))
h_vals = np.arange(n_periods + 1)

# 左パネル: 婚姻率ショック → TFR
ax2a = axes2[0]
ax2a.plot(h_vals, irf_marr_to_tfr, color=COLORS['primary'],
          linewidth=2.5, marker='o', markersize=5, label='IRF 点推定値')
ax2a.fill_between(h_vals, irf_mt_lo, irf_mt_hi,
                  color=COLORS['primary'], alpha=0.2, label='漸近95%CI（デルタ法）')
ax2a.axhline(0, color='black', linewidth=0.8, linestyle='--')
ax2a.set_xlabel('ホライゾン（年）', fontsize=11)
ax2a.set_ylabel('TFR の応答', fontsize=11)
ax2a.set_title('IRF：婚姻率への1単位ショック → TFR の応答\n（漸近デルタ法 95%CI）',
               fontsize=11, fontweight='bold')
ax2a.legend(fontsize=9)
ax2a.grid(True, alpha=0.3)
ax2a.set_xlim(-0.3, n_periods + 0.3)

# 右パネル: TFR ショック → 婚姻率
ax2b = axes2[1]
ax2b.plot(h_vals, irf_tfr_to_marr, color=COLORS['secondary'],
          linewidth=2.5, marker='s', markersize=5, label='IRF 点推定値')
ax2b.fill_between(h_vals, irf_tm_lo, irf_tm_hi,
                  color=COLORS['secondary'], alpha=0.2, label='漸近95%CI（デルタ法）')
ax2b.axhline(0, color='black', linewidth=0.8, linestyle='--')
ax2b.set_xlabel('ホライゾン（年）', fontsize=11)
ax2b.set_ylabel('婚姻率の応答', fontsize=11)
ax2b.set_title('IRF：TFR への1単位ショック → 婚姻率の応答\n（漸近デルタ法 95%CI）',
               fontsize=11, fontweight='bold')
ax2b.legend(fontsize=9)
ax2b.grid(True, alpha=0.3)
ax2b.set_xlim(-0.3, n_periods + 0.3)

# 学習ポイントの注釈
fig2.text(0.5, -0.04,
          '【学習ポイント】IRFは「1変数へのショックが他変数に与える動学的影響」を可視化する。\n'
          '95%CI（薄色帯）がゼロを含む場合、その応答は統計的に有意でない。',
          ha='center', fontsize=9, style='italic', color='gray',
          bbox=dict(boxstyle='round', facecolor='#FFF9C4', alpha=0.7))

plt.tight_layout()
fig2.savefig(os.path.join(FIGURE_DIR, '2025_U2_fig2_irf.png'), bbox_inches='tight', dpi=150)
plt.close(fig2)
print("  → 2025_U2_fig2_irf.png 保存完了")


# ================================================================
# 図3：Granger 因果性検定の結果
# ================================================================
print("図3: Granger因果性検定結果を作成中...")

fig3, axes3 = plt.subplots(1, 2, figsize=(13, 5))

# 左パネル: Wald 統計量の棒グラフ
ax3a = axes3[0]
labels_gc = list(granger_results.keys())
wald_vals  = [granger_results[k]['stat']   for k in labels_gc]
pvals_gc   = [granger_results[k]['pvalue'] for k in labels_gc]
sigs_gc    = [granger_results[k]['sig']    for k in labels_gc]
df_gc_vals = [granger_results[k]['df']     for k in labels_gc]

bar_colors_gc = []
for p in pvals_gc:
    if p < 0.01:   bar_colors_gc.append(COLORS['danger'])
    elif p < 0.05: bar_colors_gc.append(COLORS['secondary'])
    elif p < 0.1:  bar_colors_gc.append(COLORS['success'])
    else:          bar_colors_gc.append(COLORS['gray'])

y_pos = np.arange(len(labels_gc))
ax3a.barh(y_pos, wald_vals, color=bar_colors_gc, alpha=0.85,
          edgecolor='white', linewidth=0.8, height=0.5)

# 臨界値
chi2_05 = 3.841 * lag_order
chi2_01 = 6.635 * lag_order
ax3a.axvline(chi2_05, color='orange', linestyle='--', linewidth=1.5,
             label=f'χ²₅%(df={lag_order})={chi2_05:.2f}')
ax3a.axvline(chi2_01, color='red',    linestyle=':',  linewidth=1.5,
             label=f'χ²₁%(df={lag_order})={chi2_01:.2f}')

ax3a.set_yticks(y_pos)
# 短縮ラベル
short_labels = ['婚姻率 → TFR', 'TFR → 婚姻率']
ax3a.set_yticklabels(short_labels, fontsize=11)
ax3a.set_xlabel('Wald統計量（χ²）', fontsize=11)
ax3a.set_title(f'Granger因果性検定（VAR({lag_order})、Wald検定）',
               fontsize=12, fontweight='bold')

for i, (v, sig, pv) in enumerate(zip(wald_vals, sigs_gc, pvals_gc)):
    ax3a.text(v + 0.05, i, f'{v:.2f} {sig}\n(p={pv:.3f})',
              va='center', fontsize=9, fontweight='bold')

legend_patches = [
    plt.Rectangle((0, 0), 1, 1, color=COLORS['danger'],    label='p < 0.01 ***'),
    plt.Rectangle((0, 0), 1, 1, color=COLORS['secondary'], label='p < 0.05 **'),
    plt.Rectangle((0, 0), 1, 1, color=COLORS['success'],   label='p < 0.10 *'),
    plt.Rectangle((0, 0), 1, 1, color=COLORS['gray'],       label='n.s.'),
]
ax3a.legend(handles=legend_patches + [
    plt.Line2D([0], [0], color='orange', linestyle='--', label=f'χ²₅%={chi2_05:.2f}'),
    plt.Line2D([0], [0], color='red',    linestyle=':',  label=f'χ²₁%={chi2_01:.2f}'),
], loc='lower right', fontsize=8)
ax3a.set_xlim(0, max(wald_vals) * 1.6 + chi2_01 * 0.2)
ax3a.grid(axis='x', alpha=0.3)
ax3a.invert_yaxis()

# 右パネル: p値の比較とVAR係数サマリー
ax3b = axes3[1]
# VAR 係数行列を棒グラフで可視化（ラグ1のみ）
coef_mat = var_result.coefs[0]   # shape: (n_vars, n_vars)  [i, j]: i が被説明変数, j が説明変数
bse_mat  = np.sqrt(np.diag(var_result.cov_params()[:var_result.neqs**2, :var_result.neqs**2]).reshape(var_result.neqs, var_result.neqs))

# 4つのクロス係数
coef_labels = [
    f'TFR(t-1)→TFR',
    f'婚姻率(t-1)→TFR',
    f'TFR(t-1)→婚姻率',
    f'婚姻率(t-1)→婚姻率',
]
coef_vals = [coef_mat[idx_tfr, idx_tfr],
             coef_mat[idx_tfr, idx_marr],
             coef_mat[idx_marr, idx_tfr],
             coef_mat[idx_marr, idx_marr]]

if lag_order >= 1:
    coef_vals_disp = coef_vals
else:
    coef_vals_disp = coef_vals

x_pos_b = np.arange(len(coef_labels))
bar_col_b = [COLORS['primary'] if c >= 0 else COLORS['danger'] for c in coef_vals_disp]
ax3b.bar(x_pos_b, coef_vals_disp, color=bar_col_b, alpha=0.85,
         edgecolor='white', linewidth=0.8)
ax3b.axhline(0, color='black', linewidth=0.8, linestyle='--')
ax3b.set_xticks(x_pos_b)
ax3b.set_xticklabels(coef_labels, fontsize=9, rotation=15, ha='right')
ax3b.set_ylabel('VAR係数値', fontsize=11)
ax3b.set_title(f'VAR({lag_order}) ラグ1係数（ラグ1の係数行列）',
               fontsize=11, fontweight='bold')
ax3b.grid(axis='y', alpha=0.3)
for i, v in enumerate(coef_vals_disp):
    ax3b.text(i, v + (0.01 if v >= 0 else -0.02), f'{v:.3f}',
              ha='center', fontsize=9, fontweight='bold')

plt.tight_layout()
fig3.savefig(os.path.join(FIGURE_DIR, '2025_U2_fig3_granger.png'), bbox_inches='tight', dpi=150)
plt.close(fig3)
print("  → 2025_U2_fig3_granger.png 保存完了")


# ================================================================
# 図4：都道府県別散布図（TFR vs 婚姻率、複数年）
# ================================================================
print("図4: 都道府県別散布図（複数年）を作成中...")

fig4, axes4 = plt.subplots(1, 3, figsize=(15, 5))
fig4.suptitle('都道府県別 合計特殊出生率（TFR）× 婚姻率（複数年比較）\n'
              'データ出典: SSDSE-B-2026（e-Stat）',
              fontsize=12, fontweight='bold')

scatter_colors = [COLORS['primary'], COLORS['secondary'], COLORS['success']]

for ax, yr, col in zip(axes4, years_scatter, scatter_colors):
    df_yr = df_scatter[df_scatter['年度'] == yr].dropna(
        subset=['合計特殊出生率', '婚姻率']
    )
    x = df_yr['婚姻率'].values
    y = df_yr['合計特殊出生率'].values

    ax.scatter(x, y, color=col, alpha=0.7, s=55,
               edgecolors='white', linewidth=0.7, zorder=3)

    # 回帰直線
    X_fit = sm.add_constant(x)
    fit_yr = sm.OLS(y, X_fit).fit()
    x_line = np.linspace(x.min(), x.max(), 100)
    y_line = fit_yr.params[0] + fit_yr.params[1] * x_line
    ax.plot(x_line, y_line, color=col, linewidth=2.0, linestyle='-', alpha=0.9, zorder=2)

    r2 = fit_yr.rsquared
    pv = fit_yr.pvalues[1]
    sig = '***' if pv < 0.01 else '**' if pv < 0.05 else '*' if pv < 0.1 else 'n.s.'

    ax.set_xlabel('婚姻率（15-64歳人口千人あたり）', fontsize=10)
    ax.set_ylabel('合計特殊出生率（TFR）', fontsize=10)
    ax.set_title(f'{yr}年度\nβ={fit_yr.params[1]:.3f}{sig}  R²={r2:.3f}',
                 fontsize=11, fontweight='bold')
    ax.grid(True, alpha=0.3)

    # 上位・下位都道府県にラベル
    df_yr_sorted = df_yr.sort_values('合計特殊出生率')
    for _, row in df_yr_sorted.tail(3).iterrows():
        ax.annotate(row['都道府県'],
                    (row['婚姻率'], row['合計特殊出生率']),
                    fontsize=7, alpha=0.8,
                    xytext=(3, 2), textcoords='offset points')
    for _, row in df_yr_sorted.head(2).iterrows():
        ax.annotate(row['都道府県'],
                    (row['婚姻率'], row['合計特殊出生率']),
                    fontsize=7, alpha=0.8,
                    xytext=(3, -8), textcoords='offset points')

plt.tight_layout()
fig4.savefig(os.path.join(FIGURE_DIR, '2025_U2_fig4_scatter.png'), bbox_inches='tight', dpi=150)
plt.close(fig4)
print("  → 2025_U2_fig4_scatter.png 保存完了")


print("\n" + "=" * 65)
print("✓ 全図の生成完了")
print("=" * 65)
print("\n【主要結果サマリー】")
for desc, res in granger_results.items():
    print(f"  {desc}: Wald={res['stat']:.3f}, p={res['pvalue']:.4f} {res['sig']}")
print(f"\n  IRF（婚姻率→TFR）: h=1期後の応答={irf_marr_to_tfr[1]:.4f}")
print(f"  IRF（TFR→婚姻率）: h=1期後の応答={irf_tfr_to_marr[1]:.4f}")
print(f"\n  2022年 都道府県別 OLS: 婚姻率β={fit22.params[1]:.4f}, R²={fit22.rsquared:.3f}")
