"""
2021_U2_yushu.py
================
教育用ハンズオン教材：合計特殊出生率の地域格差 - パネルデータによる決定要因分析
（2021年度 統計データ分析コンペティション 優秀賞 大学生・一般の部）

手法:
  - PanelOLS 固定効果モデル (linearmodels)
  - Hausman 検定 (FE vs RE 比較)
  - 相関分析・時系列分析

データ: SSDSE-B-2026.csv（都道府県別パネル, 2012-2023）

使用変数:
  TFR                 = 合計特殊出生率
  保育所密度           = 保育所等数 / 総人口 × 10000
  女性就業率_代理       = 15〜64歳人口（女）/ 15〜64歳人口 × 100
  婚姻率              = 婚姻件数 / 総人口 × 1000
  高齢化率             = 65歳以上人口 / 総人口 × 100
  消費支出_log        = log(消費支出（二人以上の世帯）)

出力 (html/figures/):
  2021_U2_fig1_timeseries.png  - TFRの時系列推移（地域別平均）
  2021_U2_fig2_scatter.png     - 婚姻率 vs TFR 散布図（2021年）
  2021_U2_fig3_fe_coef.png     - 固定効果モデル 係数プロット
  2021_U2_fig4_hausman.png     - Hausman検定: FE vs RE 係数比較
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2021_U2_yushu.py
# ============================================================


import os
import warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rcParams

# ── フォント設定（Hiragino Sans / fallback） ──────────────────────────
rcParams['font.family'] = ['Hiragino Sans', 'Hiragino Kaku Gothic ProN',
                            'AppleGothic', 'Noto Sans CJK JP', 'sans-serif']
rcParams['axes.unicode_minus'] = False

# ── パス設定 ──────────────────────────────────────────────────────────
FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

# ── 地域区分 ──────────────────────────────────────────────────────────
REGION_MAP = {
    '北海道': '北海道', '青森県': '東北', '岩手県': '東北', '宮城県': '東北',
    '秋田県': '東北', '山形県': '東北', '福島県': '東北',
    '茨城県': '関東', '栃木県': '関東', '群馬県': '関東', '埼玉県': '関東',
    '千葉県': '関東', '東京都': '関東', '神奈川県': '関東',
    '新潟県': '中部', '富山県': '中部', '石川県': '中部', '福井県': '中部',
    '山梨県': '中部', '長野県': '中部', '岐阜県': '中部', '静岡県': '中部', '愛知県': '中部',
    '三重県': '近畿', '滋賀県': '近畿', '京都府': '近畿', '大阪府': '近畿',
    '兵庫県': '近畿', '奈良県': '近畿', '和歌山県': '近畿',
    '鳥取県': '中国', '島根県': '中国', '岡山県': '中国', '広島県': '中国', '山口県': '中国',
    '徳島県': '四国', '香川県': '四国', '愛媛県': '四国', '高知県': '四国',
    '福岡県': '九州・沖縄', '佐賀県': '九州・沖縄', '長崎県': '九州・沖縄',
    '熊本県': '九州・沖縄', '大分県': '九州・沖縄', '宮崎県': '九州・沖縄',
    '鹿児島県': '九州・沖縄', '沖縄県': '九州・沖縄',
}

REGION_ORDER = ['北海道', '東北', '関東', '中部', '近畿', '中国', '四国', '九州・沖縄']
REGION_COLORS = {
    '北海道': '#1f77b4', '東北': '#ff7f0e', '関東': '#2ca02c',
    '中部': '#d62728', '近畿': '#9467bd', '中国': '#8c564b',
    '四国': '#e377c2', '九州・沖縄': '#17becf',
}

# ── データ読込・変数作成 ──────────────────────────────────────────────
print("データ読込中...")
df_raw = pd.read_csv(DATA_B, encoding='cp932', header=1)
df_raw = df_raw.rename(columns={
    '年度': 'year',
    '都道府県': 'pref',
    '総人口': 'pop_total',
    '15～64歳人口': 'pop_1564',
    '15～64歳人口（女）': 'pop_1564_f',
    '65歳以上人口': 'pop_65plus',
    '合計特殊出生率': 'TFR',
    '婚姻件数': 'marriages',
    '保育所等数': 'nurseries',
    '消費支出（二人以上の世帯）': 'consumption',
})

# 都道府県のみ抽出（市区町村除外済: SSDSE-Bは都道府県レベル）
df = df_raw[['year', 'pref', 'pop_total', 'pop_1564', 'pop_1564_f',
             'pop_65plus', 'TFR', 'marriages', 'nurseries', 'consumption']].copy()

# 派生変数
df['保育所密度']     = df['nurseries'] / df['pop_total'] * 10000
df['女性就業率_代理'] = df['pop_1564_f'] / df['pop_1564'] * 100
df['婚姻率']         = df['marriages'] / df['pop_total'] * 1000
df['高齢化率']       = df['pop_65plus'] / df['pop_total'] * 100
df['消費支出_log']   = np.log(df['consumption'])
df['地域']           = df['pref'].map(REGION_MAP)

# 欠損除外
explainers = ['保育所密度', '女性就業率_代理', '婚姻率', '高齢化率', '消費支出_log']
df = df.dropna(subset=['TFR'] + explainers)
print(f"  有効サンプル: {len(df)}行, {df['pref'].nunique()}都道府県, {df['year'].nunique()}年次")

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Fig 1: TFRの時系列推移（地域別平均）
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
print("Fig1: TFRの時系列推移（地域別平均）作成中...")

df_region_yr = (df.groupby(['year', '地域'])['TFR']
                  .mean().reset_index())
df_nat_yr    = df.groupby('year')['TFR'].mean()

fig1, ax1 = plt.subplots(figsize=(10, 6), dpi=150)

for region in REGION_ORDER:
    sub = df_region_yr[df_region_yr['地域'] == region].sort_values('year')
    if sub.empty:
        continue
    ax1.plot(sub['year'], sub['TFR'],
             color=REGION_COLORS[region], marker='o', markersize=4,
             linewidth=1.8, label=region)

ax1.plot(df_nat_yr.index, df_nat_yr.values,
         color='black', linestyle='--', linewidth=2.5,
         marker='D', markersize=5, label='全国平均', zorder=5)

ax1.axhline(y=2.07, color='red', linestyle=':', linewidth=1.2, alpha=0.7)
ax1.text(2023.1, 2.09, '人口置換水準\n(2.07)', fontsize=9, color='red', va='bottom')

ax1.set_xlabel('年度', fontsize=12)
ax1.set_ylabel('合計特殊出生率 (TFR)', fontsize=12)
ax1.set_title('合計特殊出生率の時系列推移（地域別平均, 2012–2023）', fontsize=14, fontweight='bold')
ax1.set_xticks(sorted(df['year'].unique()))
ax1.set_xticklabels([str(y) for y in sorted(df['year'].unique())], rotation=45, ha='right')
ax1.set_ylim(0.8, 2.2)
ax1.legend(loc='upper right', fontsize=9, ncol=2, framealpha=0.8)
ax1.grid(axis='y', linestyle='--', alpha=0.4)
ax1.spines[['top', 'right']].set_visible(False)

fig1.tight_layout()
out1 = os.path.join(FIG_DIR, '2021_U2_fig1_timeseries.png')
fig1.savefig(out1, dpi=150, bbox_inches='tight')
plt.close(fig1)
print(f"  -> {out1}")

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Fig 2: 婚姻率 vs TFR 散布図（2021年, 都道府県ラベル, 地域色）
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
print("Fig2: 婚姻率 vs TFR 散布図（2021年）作成中...")

df2021 = df[df['year'] == 2021].copy()

fig2, ax2 = plt.subplots(figsize=(11, 8), dpi=150)

for region in REGION_ORDER:
    sub = df2021[df2021['地域'] == region]
    ax2.scatter(sub['婚姻率'], sub['TFR'],
                color=REGION_COLORS[region], s=80, alpha=0.85,
                edgecolors='white', linewidths=0.5, label=region, zorder=3)
    for _, row in sub.iterrows():
        name = row['pref'].replace('県', '').replace('都', '').replace('道', '').replace('府', '')
        ax2.annotate(name, (row['婚姻率'], row['TFR']),
                     textcoords='offset points', xytext=(4, 3),
                     fontsize=7.5, color=REGION_COLORS.get(row['地域'], 'gray'), alpha=0.9)

# 回帰直線
from scipy import stats as scipy_stats
slope, intercept, r_val, p_val, _ = scipy_stats.linregress(df2021['婚姻率'], df2021['TFR'])
x_range = np.linspace(df2021['婚姻率'].min(), df2021['婚姻率'].max(), 100)
ax2.plot(x_range, slope * x_range + intercept,
         color='black', linestyle='--', linewidth=1.5, alpha=0.7,
         label=f'回帰直線 (r={r_val:.3f}, p={p_val:.3f})')

ax2.set_xlabel('婚姻率（件/千人）', fontsize=12)
ax2.set_ylabel('合計特殊出生率 (TFR)', fontsize=12)
ax2.set_title('婚姻率と合計特殊出生率の関係（2021年, 都道府県別）', fontsize=14, fontweight='bold')
ax2.legend(loc='upper left', fontsize=9, framealpha=0.85)
ax2.grid(linestyle='--', alpha=0.3)
ax2.spines[['top', 'right']].set_visible(False)

fig2.tight_layout()
out2 = os.path.join(FIG_DIR, '2021_U2_fig2_scatter.png')
fig2.savefig(out2, dpi=150, bbox_inches='tight')
plt.close(fig2)
print(f"  -> {out2}")

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# パネルデータ準備 & PanelOLS FE / RE 推定
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
print("PanelOLS 固定効果・変量効果モデル推定中...")

from linearmodels.panel import PanelOLS, RandomEffects

df_panel = df.set_index(['pref', 'year'])
y  = df_panel['TFR']
X  = df_panel[explainers]

import statsmodels.api as sm
X_const = sm.add_constant(X)

# 固定効果モデル (time_effects=True で年次FEも含む)
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    fe_model = PanelOLS(y, X_const, entity_effects=True, time_effects=True)
    fe_res   = fe_model.fit(cov_type='clustered', cluster_entity=True)

    re_model = RandomEffects(y, X_const)
    re_res   = re_model.fit()

print(fe_res.summary)

# FE係数（const除く）
fe_params = fe_res.params.drop('const', errors='ignore')
fe_se     = fe_res.std_errors.drop('const', errors='ignore')
fe_pvals  = fe_res.pvalues.drop('const', errors='ignore')

# RE係数（const除く）
re_params = re_res.params.drop('const', errors='ignore')
re_se     = re_res.std_errors.drop('const', errors='ignore')

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Fig 3: FE 係数プロット
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
print("Fig3: FE係数プロット作成中...")

var_labels = {
    '保育所密度':     '保育所密度\n(施設数/万人)',
    '女性就業率_代理': '女性就業率\n代理変数(%)',
    '婚姻率':         '婚姻率\n(件/千人)',
    '高齢化率':       '高齢化率\n(%)',
    '消費支出_log':   '消費支出\n(対数)',
}

n_vars  = len(fe_params)
y_pos   = np.arange(n_vars)
colors3 = []
for v, p in zip(fe_params.index, fe_pvals.values):
    if p < 0.01:
        colors3.append('#C62828')
    elif p < 0.05:
        colors3.append('#E65100')
    elif p < 0.10:
        colors3.append('#F9A825')
    else:
        colors3.append('#9E9E9E')

fig3, ax3 = plt.subplots(figsize=(9, 5), dpi=150)

bars = ax3.barh(y_pos, fe_params.values, xerr=1.96 * fe_se.values,
                color=colors3, edgecolor='white', height=0.55,
                error_kw={'elinewidth': 1.5, 'ecolor': '#555', 'capsize': 4})
ax3.axvline(0, color='black', linewidth=1.0, linestyle='-')

labels = [var_labels.get(v, v) for v in fe_params.index]
ax3.set_yticks(y_pos)
ax3.set_yticklabels(labels, fontsize=11)
ax3.set_xlabel('推定係数（95%信頼区間）', fontsize=12)
ax3.set_title('TFR決定要因の固定効果推定値\n（都道府県・年次FE, クラスター標準誤差）',
              fontsize=13, fontweight='bold')

legend_handles = [
    mpatches.Patch(color='#C62828', label='p < 0.01'),
    mpatches.Patch(color='#E65100', label='p < 0.05'),
    mpatches.Patch(color='#F9A825', label='p < 0.10'),
    mpatches.Patch(color='#9E9E9E', label='有意でない'),
]
ax3.legend(handles=legend_handles, loc='lower right', fontsize=9, framealpha=0.85)

# p値注記
for i, (coef, p) in enumerate(zip(fe_params.values, fe_pvals.values)):
    star = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.10 else 'n.s.'
    offset = 1.96 * fe_se.values[i] + abs(coef) * 0.02 + 0.002
    ax3.text(coef + (offset if coef >= 0 else -offset), i,
             star, va='center', ha='left' if coef >= 0 else 'right',
             fontsize=10, color='#333')

ax3.grid(axis='x', linestyle='--', alpha=0.3)
ax3.spines[['top', 'right']].set_visible(False)

fig3.tight_layout()
out3 = os.path.join(FIG_DIR, '2021_U2_fig3_fe_coef.png')
fig3.savefig(out3, dpi=150, bbox_inches='tight')
plt.close(fig3)
print(f"  -> {out3}")

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Hausman 検定 & Fig 4: FE vs RE 係数比較
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
print("Hausman検定・Fig4作成中...")

# Hausman統計量の手動計算
try:
    # linearmodels のHausman検定
    from linearmodels.panel.results import compare
    # coefficient difference test
    b_diff = fe_params.values - re_params[fe_params.index].values
    fe_cov = fe_res.cov.loc[fe_params.index, fe_params.index].values
    re_cov = re_res.cov.loc[fe_params.index, fe_params.index].values
    cov_diff = fe_cov - re_cov
    # Hausman statistic
    try:
        hausman_stat = float(b_diff @ np.linalg.pinv(cov_diff) @ b_diff)
        from scipy.stats import chi2 as chi2_dist
        hausman_pval = 1 - chi2_dist.cdf(hausman_stat, df=len(b_diff))
        hausman_ok = True
    except Exception:
        hausman_ok = False
except Exception:
    hausman_ok = False

if hausman_ok:
    print(f"  Hausman統計量: {hausman_stat:.4f}, p値: {hausman_pval:.4f}")
    # Negative Hausman stat (numerical issue): treat as non-rejection
    if hausman_stat < 0:
        hausman_ok = False
        print("  ※ 統計量が負 (数値的問題) → FE/RE差の視覚的比較を表示")
    else:
        verdict = "固定効果モデルを採択" if hausman_pval < 0.05 else "変量効果モデルを採択"
        print(f"  判定: {verdict}")

# Fig4: FE vs RE 係数比較パネル
fig4, ax4 = plt.subplots(figsize=(10, 5.5), dpi=150)

x_pos = np.arange(len(fe_params))
width = 0.35

re_vals = re_params[fe_params.index].values
re_sev  = re_se[fe_params.index].values

bars_fe = ax4.bar(x_pos - width/2, fe_params.values, width,
                   yerr=1.96 * fe_se.values, capsize=4,
                   color='#1565C0', alpha=0.85, label='固定効果 (FE)',
                   error_kw={'elinewidth': 1.5, 'ecolor': '#0D47A1'})
bars_re = ax4.bar(x_pos + width/2, re_vals, width,
                   yerr=1.96 * re_sev, capsize=4,
                   color='#E65100', alpha=0.75, label='変量効果 (RE)',
                   error_kw={'elinewidth': 1.5, 'ecolor': '#BF360C'})

ax4.axhline(0, color='black', linewidth=0.8)
tick_labels = [var_labels.get(v, v) for v in fe_params.index]
ax4.set_xticks(x_pos)
ax4.set_xticklabels(tick_labels, fontsize=10)
ax4.set_ylabel('推定係数', fontsize=12)

if hausman_ok and hausman_stat > 0:
    title_str = (f'Hausman検定: FE vs RE 係数比較\n'
                 f'χ²({len(b_diff)}) = {hausman_stat:.3f}, p = {hausman_pval:.4f} → {verdict}')
else:
    # Summarize coefficient differences as informal Hausman-style comparison
    max_diff_var = fe_params.index[np.argmax(np.abs(fe_params.values - re_vals))]
    title_str = ('Hausman検定: FE vs RE 係数比較\n'
                 f'（最大乖離変数: {var_labels.get(max_diff_var, max_diff_var).replace(chr(10), " ")}）'
                 '  → 両モデルの差から内生性を診断')

ax4.set_title(title_str, fontsize=12, fontweight='bold')
ax4.legend(fontsize=11, framealpha=0.85)
ax4.grid(axis='y', linestyle='--', alpha=0.3)
ax4.spines[['top', 'right']].set_visible(False)

# 差の大きい箇所に矢印注記
for i, (fe_v, re_v, var) in enumerate(zip(fe_params.values, re_vals, fe_params.index)):
    diff = abs(fe_v - re_v)
    if diff > 0.03:
        ax4.annotate('', xy=(x_pos[i] + width/2, re_v),
                     xytext=(x_pos[i] - width/2, fe_v),
                     arrowprops=dict(arrowstyle='<->', color='darkgreen',
                                     lw=1.2, connectionstyle='arc3,rad=0.15'))

fig4.tight_layout()
out4 = os.path.join(FIG_DIR, '2021_U2_fig4_hausman.png')
fig4.savefig(out4, dpi=150, bbox_inches='tight')
plt.close(fig4)
print(f"  -> {out4}")

print("\n全図作成完了。")
print(f"  Fig1: {out1}")
print(f"  Fig2: {out2}")
print(f"  Fig3: {out3}")
print(f"  Fig4: {out4}")
