"""
教育用再現コード: 2022年 統計データ分析コンペティション 審査員奨励賞 [大学生・一般の部]
=================================================================
論文タイトル：出生率の地域差：男性育児参加と女性社会進出のパネル分析
受賞：審査員奨励賞（大学生・一般の部）

【分析概要】
  データ：SSDSE-B-2026.csv（都道府県別パネルデータ, 2012〜2023年度）
  対象：全47都道府県 × 12年（2012〜2023）
  目的変数：合計特殊出生率

  分析の流れ
  ・時系列：都道府県別の合計特殊出生率推移（地域グループ別）
  ・相関ヒートマップ：説明変数間の相関構造
  ・パネルOLS（固定効果モデル）：出生率の決定要因推定
  ・散布図：女性就業率代理 vs 合計特殊出生率

【被説明変数】
  合計特殊出生率

【説明変数】
  女性比率（15〜64歳）= 15〜64歳人口（女）/ 15〜64歳人口 × 100
  保育所密度 = 保育所等数 / 小学校児童数 × 1000
  保健医療費（二人以上の世帯）
  教育費（二人以上の世帯）
  消費支出（二人以上の世帯）（log）
  高齢化率 = 65歳以上人口 / 総人口 × 100

【推定手法】
  linearmodels PanelOLS: entity_effects=True, cov_type='clustered'

【データ出典】
  SSDSE-B-2026.csv: 社会・人口統計体系（都道府県別データ）

【データサイエンス学習ポイント】
  1. 合計特殊出生率の地域差と時系列変動の把握
  2. 女性社会進出の代理変数（就業参加率）の設計
  3. 固定効果パネルモデルによる地域固定要因の除去
  4. 保育所密度が出生率に与える影響の推定
=================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2022_U5_9_shorei.py
# ============================================================


import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import warnings
warnings.filterwarnings('ignore')

from linearmodels.panel import PanelOLS, RandomEffects
import statsmodels.api as sm
from scipy import stats

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

import os
FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

# ================================================================
# ■ 実データ読み込み
# ================================================================
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
df_b = df_b[df_b['地域コード'].str.match(r'^R\d{5}$', na=False)].copy()
df_b['年度'] = df_b['年度'].astype(int)
df_b = df_b.sort_values(['都道府県', '年度']).reset_index(drop=True)

# ================================================================
# ■ 変数生成
# ================================================================
df_b['TFR'] = df_b['合計特殊出生率']
df_b['女性就業率_代理'] = df_b['15～64歳人口（女）'] / df_b['15～64歳人口'] * 100
df_b['保育所密度'] = df_b['保育所等数'] / df_b['小学校児童数'].replace(0, np.nan) * 1000
df_b['高齢化率'] = df_b['65歳以上人口'] / df_b['総人口'] * 100
df_b['消費支出_log'] = np.log(df_b['消費支出（二人以上の世帯）'].clip(lower=1))
df_b['保健医療費'] = df_b['保健医療費（二人以上の世帯）'] / 1000  # 千円
df_b['教育費'] = df_b['教育費（二人以上の世帯）'] / 1000

# 地域区分
region_map = {
    '北海道': '北海道・東北', '青森県': '北海道・東北', '岩手県': '北海道・東北',
    '宮城県': '北海道・東北', '秋田県': '北海道・東北', '山形県': '北海道・東北',
    '福島県': '北海道・東北',
    '茨城県': '関東', '栃木県': '関東', '群馬県': '関東', '埼玉県': '関東',
    '千葉県': '関東', '東京都': '関東', '神奈川県': '関東',
    '新潟県': '中部', '富山県': '中部', '石川県': '中部', '福井県': '中部',
    '山梨県': '中部', '長野県': '中部', '岐阜県': '中部', '静岡県': '中部', '愛知県': '中部',
    '三重県': '近畿', '滋賀県': '近畿', '京都府': '近畿', '大阪府': '近畿',
    '兵庫県': '近畿', '奈良県': '近畿', '和歌山県': '近畿',
    '鳥取県': '中国・四国', '島根県': '中国・四国', '岡山県': '中国・四国',
    '広島県': '中国・四国', '山口県': '中国・四国', '徳島県': '中国・四国',
    '香川県': '中国・四国', '愛媛県': '中国・四国', '高知県': '中国・四国',
    '福岡県': '九州・沖縄', '佐賀県': '九州・沖縄', '長崎県': '九州・沖縄',
    '熊本県': '九州・沖縄', '大分県': '九州・沖縄', '宮崎県': '九州・沖縄',
    '鹿児島県': '九州・沖縄', '沖縄県': '九州・沖縄',
}
df_b['地域'] = df_b['都道府県'].map(region_map)

region_colors = {
    '北海道・東北': '#4e9af1', '関東': '#e05c5c', '中部': '#f0a500',
    '近畿': '#5cb85c', '中国・四国': '#9b59b6', '九州・沖縄': '#f39c12',
}

# ================================================================
# ■ Fig1: TFR時系列（地域別平均）
# ================================================================
fig, ax = plt.subplots(figsize=(10, 5))
yearly = df_b.groupby(['年度', '地域'])['TFR'].mean().reset_index()
for reg, grp in yearly.groupby('地域'):
    ax.plot(grp['年度'], grp['TFR'], marker='o', markersize=4,
            label=reg, color=region_colors.get(reg, 'gray'))
ax.set_xlabel('年度', fontsize=12)
ax.set_ylabel('合計特殊出生率', fontsize=12)
ax.set_title('地域別 合計特殊出生率の推移（2012〜2023年）', fontsize=14, fontweight='bold')
ax.legend(fontsize=9, loc='upper right')
ax.grid(alpha=0.3)
ax.axhline(2.07, color='red', linestyle='--', alpha=0.5, label='人口置換水準 2.07')
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2022_U5_9_fig1_tfr_ts.png'), dpi=150, bbox_inches='tight')
plt.close()
print("Fig1 saved")

# ================================================================
# ■ Fig2: 相関ヒートマップ
# ================================================================
analysis_vars = ['TFR', '女性就業率_代理', '保育所密度', '高齢化率', '消費支出_log', '保健医療費', '教育費']
df_2022 = df_b[df_b['年度'] == 2022][analysis_vars].dropna()
corr = df_2022.corr()

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(corr, cmap='RdBu_r', vmin=-1, vmax=1)
ax.set_xticks(range(len(analysis_vars)))
ax.set_yticks(range(len(analysis_vars)))
ax.set_xticklabels(analysis_vars, rotation=45, ha='right', fontsize=9)
ax.set_yticklabels(analysis_vars, fontsize=9)
for i in range(len(analysis_vars)):
    for j in range(len(analysis_vars)):
        ax.text(j, i, f'{corr.iloc[i, j]:.2f}', ha='center', va='center',
                fontsize=8, color='white' if abs(corr.iloc[i, j]) > 0.5 else 'black')
plt.colorbar(im, ax=ax, label='相関係数')
ax.set_title('分析変数間の相関ヒートマップ（2022年, N=47）', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2022_U5_9_fig2_corr.png'), dpi=150, bbox_inches='tight')
plt.close()
print("Fig2 saved")

# ================================================================
# ■ Fig3: パネルOLS固定効果推定 — 係数プロット
# ================================================================
panel_vars = ['女性就業率_代理', '保育所密度', '高齢化率', '消費支出_log', '保健医療費']
df_panel = df_b[['年度', '都道府県', 'TFR'] + panel_vars].dropna()
df_panel = df_panel.set_index(['都道府県', '年度'])

y = df_panel['TFR']
X = sm.add_constant(df_panel[panel_vars])

try:
    mod = PanelOLS(y, X, entity_effects=True, drop_absorbed=True)
    res = mod.fit(cov_type='clustered', cluster_entity=True)
    coefs = res.params.drop('const', errors='ignore')
    ses = res.std_errors.drop('const', errors='ignore')
    pvals = res.pvalues.drop('const', errors='ignore')

    fig, ax = plt.subplots(figsize=(8, 5))
    y_pos = range(len(coefs))
    colors = ['#e05c5c' if p < 0.05 else '#888888' for p in pvals]
    ax.barh(y_pos, coefs, xerr=1.96 * ses, color=colors, alpha=0.8,
            error_kw={'elinewidth': 1.5, 'capsize': 4})
    ax.set_yticks(y_pos)
    ax.set_yticklabels(coefs.index, fontsize=10)
    ax.axvline(0, color='black', linewidth=0.8)
    ax.set_xlabel('係数（固定効果推定）', fontsize=12)
    ax.set_title('合計特殊出生率の決定要因 — FEパネルOLS係数\n（赤=p<0.05, 灰=非有意）', fontsize=12, fontweight='bold')
    ax.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(FIG_DIR, '2022_U5_9_fig3_fe_coef.png'), dpi=150, bbox_inches='tight')
    plt.close()
    print("Fig3 saved")
    print(res.summary.tables[1])
except Exception as e:
    print(f"FE model error: {e}")
    # Fallback: OLS
    df_2022_f = df_b[df_b['年度'] == 2022][['TFR'] + panel_vars].dropna()
    X_ols = sm.add_constant(df_2022_f[panel_vars])
    res_ols = sm.OLS(df_2022_f['TFR'], X_ols).fit()
    coefs = res_ols.params.drop('const')
    ses = res_ols.bse.drop('const')
    pvals = res_ols.pvalues.drop('const')
    fig, ax = plt.subplots(figsize=(8, 5))
    y_pos = range(len(coefs))
    colors = ['#e05c5c' if p < 0.05 else '#888888' for p in pvals]
    ax.barh(y_pos, coefs, xerr=1.96 * ses, color=colors, alpha=0.8,
            error_kw={'elinewidth': 1.5, 'capsize': 4})
    ax.set_yticks(y_pos)
    ax.set_yticklabels(coefs.index, fontsize=10)
    ax.axvline(0, color='black', linewidth=0.8)
    ax.set_xlabel('係数（OLS推定）', fontsize=12)
    ax.set_title('合計特殊出生率の決定要因 — OLS係数\n（赤=p<0.05）', fontsize=12, fontweight='bold')
    ax.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(FIG_DIR, '2022_U5_9_fig3_fe_coef.png'), dpi=150, bbox_inches='tight')
    plt.close()
    print("Fig3 (OLS fallback) saved")

# ================================================================
# ■ Fig4: 散布図 女性就業率代理 vs TFR（2022年）
# ================================================================
df_2022_sc = df_b[df_b['年度'] == 2022][['都道府県', 'TFR', '女性就業率_代理', '地域', '保育所密度']].dropna()

fig, ax = plt.subplots(figsize=(9, 6))
for reg, grp in df_2022_sc.groupby('地域'):
    ax.scatter(grp['女性就業率_代理'], grp['TFR'],
               color=region_colors.get(reg, 'gray'), label=reg, s=60, alpha=0.8)
for _, row in df_2022_sc.iterrows():
    ax.annotate(row['都道府県'][:2], (row['女性就業率_代理'], row['TFR']),
                fontsize=6, alpha=0.6,
                xytext=(2, 2), textcoords='offset points')
# 回帰直線
x_vals = df_2022_sc['女性就業率_代理']
y_vals = df_2022_sc['TFR']
slope, intercept, r, p, _ = stats.linregress(x_vals, y_vals)
xr = np.linspace(x_vals.min(), x_vals.max(), 100)
ax.plot(xr, slope * xr + intercept, 'k--', linewidth=1.5,
        label=f'回帰直線 (r={r:.2f}, p={p:.3f})')
ax.set_xlabel('女性就業率代理（15〜64歳女性比率, %）', fontsize=12)
ax.set_ylabel('合計特殊出生率', fontsize=12)
ax.set_title('女性就業率代理 vs 合計特殊出生率（2022年, N=47）', fontsize=13, fontweight='bold')
ax.legend(fontsize=8, loc='upper left', ncol=2)
ax.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2022_U5_9_fig4_scatter.png'), dpi=150, bbox_inches='tight')
plt.close()
print("Fig4 saved")
print("All figures saved successfully!")
