"""
2018_U2_yushu.py
少子化における未婚化・晩婚化の影響：婚姻率と合計特殊出生率の地域パネル分析
優秀賞（大学生・一般の部） — 2018年度 統計データ分析コンペティション
教育用再現コード（SSDSE-B-2026 実データ使用）
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2018_U2_yushu.py
# ============================================================


import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import statsmodels.api as sm
from scipy import stats

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

# ── データ読み込み ──────────────────────────────────────────────────────────
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
df_b = df_b[df_b['地域コード'].str.match(r'^R\d{5}', na=False)].copy()
df_b['年度'] = df_b['年度'].astype(int)

print("columns:", df_b.columns.tolist())

# ── 派生変数の作成 ──────────────────────────────────────────────────────────
# 合計特殊出生率（TFR）
TFR_col     = '合計特殊出生率'

# 婚姻率 = 婚姻件数 / 総人口 × 1000（人口千対）
df_b['婚姻率（千対）'] = df_b['婚姻件数'] / df_b['総人口'] * 1000
婚姻率col = '婚姻率（千対）'

# 女性労働力率プロキシ = 15-64歳女性人口 / 総人口 × 100
df_b['女性労働力率（％）'] = df_b['15～64歳人口（女）'] / df_b['総人口'] * 100
女性就業col = '女性労働力率（％）'

# 消費支出（実数、円）
消費支出col = '消費支出（二人以上の世帯）'

# 高齢化率 = 65歳以上人口 / 総人口 × 100
df_b['高齢化率（％）'] = df_b['65歳以上人口'] / df_b['総人口'] * 100
高齢化率col = '高齢化率（％）'

X_cols = [婚姻率col, 女性就業col, 消費支出col, 高齢化率col]

print("TFR col:", TFR_col)
print("X cols:", X_cols)
print("年度:", sorted(df_b['年度'].unique()))

# ── 地域分類 ───────────────────────────────────────────────────────────────
region_map = {
    '北海道': '北海道・東北', '青森県': '北海道・東北', '岩手県': '北海道・東北', '宮城県': '北海道・東北',
    '秋田県': '北海道・東北', '山形県': '北海道・東北', '福島県': '北海道・東北',
    '茨城県': '関東', '栃木県': '関東', '群馬県': '関東', '埼玉県': '関東',
    '千葉県': '関東', '東京都': '関東', '神奈川県': '関東',
    '新潟県': '中部', '富山県': '中部', '石川県': '中部', '福井県': '中部', '山梨県': '中部',
    '長野県': '中部', '岐阜県': '中部', '静岡県': '中部', '愛知県': '中部',
    '三重県': '近畿', '滋賀県': '近畿', '京都府': '近畿', '大阪府': '近畿',
    '兵庫県': '近畿', '奈良県': '近畿', '和歌山県': '近畿',
    '鳥取県': '中国・四国', '島根県': '中国・四国', '岡山県': '中国・四国', '広島県': '中国・四国',
    '山口県': '中国・四国', '徳島県': '中国・四国', '香川県': '中国・四国', '愛媛県': '中国・四国', '高知県': '中国・四国',
    '福岡県': '九州・沖縄', '佐賀県': '九州・沖縄', '長崎県': '九州・沖縄', '熊本県': '九州・沖縄',
    '大分県': '九州・沖縄', '宮崎県': '九州・沖縄', '鹿児島県': '九州・沖縄', '沖縄県': '九州・沖縄'
}
region_colors = {
    '北海道・東北': '#4e9af1',
    '関東':        '#e05c5c',
    '中部':        '#f0a500',
    '近畿':        '#5cb85c',
    '中国・四国':   '#9b59b6',
    '九州・沖縄':   '#f39c12'
}
df_b['地域'] = df_b['都道府県'].map(region_map)

# ── 代表6都道府県（地域ごと1都道府県）─────────────────────────────────────
rep_prefs = {
    '北海道・東北': '北海道',
    '関東':        '東京都',
    '中部':        '愛知県',
    '近畿':        '大阪府',
    '中国・四国':   '広島県',
    '九州・沖縄':   '福岡県',
}

# =========================================================================
# Figure 1: TFR 時系列推移（2012-2023年、6地域折れ線、COVID帯グレー）
# =========================================================================
fig1, ax1 = plt.subplots(figsize=(10, 5.5))

# COVID帯（2020-2021）
ax1.axvspan(2019.6, 2021.4, color='lightgray', alpha=0.45, zorder=0, label='COVID-19期')

years = sorted(df_b['年度'].unique())

for region, pref in rep_prefs.items():
    sub = df_b[df_b['都道府県'] == pref].sort_values('年度')
    ax1.plot(sub['年度'].values, sub[TFR_col].values,
             color=region_colors[region],
             linewidth=2.2, marker='o', markersize=5, label=f'{pref}（{region}）')

ax1.set_xlabel('年度', fontsize=12)
ax1.set_ylabel('合計特殊出生率', fontsize=12)
ax1.set_title('合計特殊出生率の推移（代表6都道府県、2012–2023年）', fontsize=14, fontweight='bold')
ax1.legend(fontsize=9, loc='lower left')
ax1.set_xticks(years)
ax1.tick_params(axis='x', rotation=45)
ax1.grid(axis='y', linestyle='--', alpha=0.5)
ax1.set_xlim(min(years) - 0.5, max(years) + 0.5)

fig1.tight_layout()
out1 = os.path.join(FIG_DIR, '2018_U2_fig1.png')
fig1.savefig(out1, bbox_inches='tight')
plt.close(fig1)
print(f"Saved: {out1}")

# =========================================================================
# Figure 2: 婚姻率 vs TFR 散布図（最新年断面、47都道府県ラベル付き）
# =========================================================================
latest_year = df_b['年度'].max()
df_latest   = df_b[df_b['年度'] == latest_year].copy()

fig2, ax2 = plt.subplots(figsize=(11, 8))

for _, row in df_latest.iterrows():
    reg = row['地域']
    col = region_colors.get(reg, '#888888')
    ax2.scatter(row[婚姻率col], row[TFR_col], color=col, s=60, zorder=3, alpha=0.85)
    short = row['都道府県'].replace('県', '').replace('府', '').replace('都', '').replace('道', '')
    ax2.annotate(short, (row[婚姻率col], row[TFR_col]),
                 fontsize=7.5, ha='left', va='bottom',
                 xytext=(2, 2), textcoords='offset points')

# 回帰直線
x_reg = df_latest[婚姻率col].astype(float).values
y_reg = df_latest[TFR_col].astype(float).values
valid  = ~(np.isnan(x_reg) | np.isnan(y_reg))
x_reg, y_reg = x_reg[valid], y_reg[valid]
slope, intercept, r_val, p_val, _ = stats.linregress(x_reg, y_reg)
xline = np.linspace(x_reg.min(), x_reg.max(), 100)
ax2.plot(xline, slope * xline + intercept, 'k--', linewidth=1.5)

# 凡例
handles = [mpatches.Patch(color=c, label=r) for r, c in region_colors.items()]
handles.append(plt.Line2D([0], [0], color='k', linestyle='--',
               label=f'回帰直線 r={r_val:.3f}  p={p_val:.3f}'))
ax2.legend(handles=handles, fontsize=9, loc='upper left')

ax2.set_xlabel('婚姻率（人口千対）', fontsize=12)
ax2.set_ylabel('合計特殊出生率', fontsize=12)
ax2.set_title(f'婚姻率と合計特殊出生率の関係（{latest_year}年、47都道府県）', fontsize=13, fontweight='bold')
ax2.grid(linestyle='--', alpha=0.4)

fig2.tight_layout()
out2 = os.path.join(FIG_DIR, '2018_U2_fig2.png')
fig2.savefig(out2, bbox_inches='tight')
plt.close(fig2)
print(f"Saved: {out2}")

# =========================================================================
# パネル分析（固定効果・変量効果 + Hausman検定）
# =========================================================================
hausman_stat   = None
hausman_p      = None
hausman_result = "（Hausman検定実行済み）"
fe_params = None
re_params = None

try:
    from linearmodels.panel import PanelOLS, RandomEffects

    df_panel = df_b.dropna(subset=[TFR_col] + X_cols).copy()
    df_panel = df_panel.set_index(['都道府県', '年度'])

    # 固定効果モデル（Clustered SE by entity）
    fe = PanelOLS(
        df_panel[TFR_col].astype(float),
        df_panel[X_cols].astype(float),
        entity_effects=True
    ).fit(cov_type='clustered', cluster_entity=True)

    # 変量効果モデル
    re_exog = sm.add_constant(df_panel[X_cols].astype(float))
    re = RandomEffects(df_panel[TFR_col].astype(float), re_exog).fit()

    fe_params = fe.params
    re_params = re.params[X_cols]

    # ── Hausman検定（手動実装）─────────────────────────────────────────────
    diff       = (fe_params - re_params).values
    var_fe     = fe.cov.loc[X_cols, X_cols].values
    var_re     = re.cov.loc[X_cols, X_cols].values
    var_diff   = var_fe - var_re
    # 正定値化（数値誤差対策）
    eigvals    = np.linalg.eigvalsh(var_diff)
    if eigvals.min() < 0:
        var_diff += np.eye(len(diff)) * (-eigvals.min() + 1e-8)
    try:
        hausman_stat = float(diff @ np.linalg.inv(var_diff) @ diff)
        df_h         = len(diff)
        hausman_p    = 1 - stats.chi2.cdf(hausman_stat, df_h)
        if hausman_p < 0.05:
            hausman_result = (f"Hausman検定: χ²={hausman_stat:.3f}  p={hausman_p:.3f}  "
                              f"→ 固定効果モデルを採用（RE不一致）")
        else:
            hausman_result = (f"Hausman検定: χ²={hausman_stat:.3f}  p={hausman_p:.3f}  "
                              f"→ 変量効果モデル採用可（RE一致）")
    except Exception as he:
        hausman_result = f"Hausman検定（簡易）: {he}"

    print("\n=== 固定効果モデル ===")
    print(fe.summary)
    print("\n=== 変量効果モデル ===")
    print(re.summary)
    print("\n", hausman_result)

except Exception as e:
    print(f"Panel error: {e}")
    # フォールバック: プーリングOLS
    df_ols = df_b.dropna(subset=[TFR_col] + X_cols).copy()
    X_ols  = sm.add_constant(df_ols[X_cols].astype(float))
    ols    = sm.OLS(df_ols[TFR_col].astype(float), X_ols).fit()
    fe_params = ols.params[X_cols]
    re_params = fe_params.copy()
    hausman_result = "（フォールバック: プーリングOLS）"
    print(ols.summary())

# =========================================================================
# Figure 3: FE vs RE 係数比較（横棒グラフ）
# =========================================================================
short_labels = {
    婚姻率col:  '婚姻率\n（千対）',
    女性就業col:'女性\n労働力率',
    消費支出col:'消費支出',
    高齢化率col:'高齢化率',
}

fig3, ax3 = plt.subplots(figsize=(9, 5))
y_pos = np.arange(len(X_cols))
bar_height = 0.35

fe_vals = [float(fe_params[c]) for c in X_cols]
re_vals = [float(re_params[c]) for c in X_cols]

ax3.barh(y_pos + bar_height/2, fe_vals, bar_height,
         color='#1565C0', alpha=0.85, label='固定効果モデル（FE）')
ax3.barh(y_pos - bar_height/2, re_vals, bar_height,
         color='#E65100', alpha=0.75, label='変量効果モデル（RE）')

ax3.set_yticks(y_pos)
ax3.set_yticklabels([short_labels[c] for c in X_cols], fontsize=11)
ax3.axvline(0, color='black', linewidth=0.8)
ax3.set_xlabel('偏回帰係数', fontsize=11)
title_str = f'固定効果 vs 変量効果モデルの係数比較\n{hausman_result}'
ax3.set_title(title_str, fontsize=11, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(axis='x', linestyle='--', alpha=0.4)

fig3.tight_layout()
out3 = os.path.join(FIG_DIR, '2018_U2_fig3.png')
fig3.savefig(out3, bbox_inches='tight')
plt.close(fig3)
print(f"Saved: {out3}")

# =========================================================================
# Figure 4: TFRランキング棒グラフ（47都道府県、地域色分け、全国平均線）
# =========================================================================
df_rank = df_latest[['都道府県', '地域', TFR_col]].dropna().sort_values(TFR_col, ascending=True).copy()
national_avg = df_rank[TFR_col].mean()

colors_bar = [region_colors.get(r, '#888') for r in df_rank['地域']]

fig4, ax4 = plt.subplots(figsize=(10, 12))
ax4.barh(range(len(df_rank)), df_rank[TFR_col].values,
         color=colors_bar, alpha=0.85, edgecolor='white', linewidth=0.5)

short_names = [p.replace('県', '').replace('府', '').replace('都', '').replace('道', '')
               for p in df_rank['都道府県']]
ax4.set_yticks(range(len(df_rank)))
ax4.set_yticklabels(short_names, fontsize=9)
ax4.axvline(national_avg, color='red', linewidth=1.8, linestyle='--',
            label=f'全国平均 {national_avg:.3f}')
ax4.set_xlabel('合計特殊出生率', fontsize=11)
ax4.set_title(f'都道府県別 合計特殊出生率ランキング（{latest_year}年）', fontsize=13, fontweight='bold')

# 地域色凡例
handles_r = [mpatches.Patch(color=c, label=r) for r, c in region_colors.items()]
handles_r.append(plt.Line2D([0], [0], color='red', linestyle='--',
                 label=f'全国平均 {national_avg:.3f}'))
ax4.legend(handles=handles_r, fontsize=8, loc='lower right')

ax4.grid(axis='x', linestyle='--', alpha=0.4)
fig4.tight_layout()
out4 = os.path.join(FIG_DIR, '2018_U2_fig4.png')
fig4.savefig(out4, bbox_inches='tight')
plt.close(fig4)
print(f"Saved: {out4}")

print("\nDONE: 2018_U2_yushu")
