"""
教育用再現コード: 2024年 統計データ分析コンペティション 審査員奨励賞（高校生）
=================================================================
論文タイトル：少子化進行抑止のための家庭・社会要因の探究

【分析概要】
  データ：SSDSE-B-2026（2022年度）, SSDSE-E-2026
  目的変数：合計特殊出生率（TFR）
  説明変数：婚姻率, 保育所充実度, 高齢化率, 大学進学率, 1人当たり県民所得, 年平均気温

  Step1. 相関ヒートマップ
  Step2. 主要変数 vs TFR の散布図
  Step3. 47都道府県 重回帰係数
  Step4. 東京都あり/なし比較（R²・係数）

【データサイエンス学習ポイント】
  1. 外れ値（東京都）が回帰結果に与える影響の検討
  2. 相関行列による多重共線性の事前確認
  3. 標準化係数による変数重要度の比較
  4. TFR（合計特殊出生率）という政策的目的変数の理解

【データ】実公的データ（SSDSE-B-2026, SSDSE-E-2026）を使用
=================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#     ・SSDSE-E-2026.csv
#       → data/raw/SSDSE-E-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#     （SSDSE-E（社会・人口統計体系 都道府県の指標2） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2024_H5_1_shorei.py
# ============================================================


import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy import stats
import os
import warnings
warnings.filterwarnings('ignore')

# ──────────────────────────────────────────────────────────────
# 共通設定
# ──────────────────────────────────────────────────────────────
plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

FIG_DIR = 'html/figures'
os.makedirs(FIG_DIR, exist_ok=True)

# ================================================================
# ■ Step 0. データ読み込み（実公的データ）
# ================================================================
print("=" * 65)
print("■ データ読み込み（SSDSE-B-2026 / SSDSE-E-2026）")
print("=" * 65)

DATA_DIR = 'data/raw'

# SSDSE-B 2022年度 都道府県データ
df_b_raw = pd.read_csv(os.path.join(DATA_DIR, 'SSDSE-B-2026.csv'), encoding='cp932', header=1)
mask_b = df_b_raw['地域コード'].str.match(r'^R\d{5}$', na=False) & (df_b_raw['年度'] == 2022)
df_b = df_b_raw[mask_b].copy().reset_index(drop=True)
print(f"SSDSE-B 2022: {len(df_b)}都道府県")

# SSDSE-E 都道府県断面データ
df_e_raw = pd.read_csv(os.path.join(DATA_DIR, 'SSDSE-E-2026.csv'), encoding='cp932', header=0)
df_e = df_e_raw.iloc[1:].copy()
df_e.columns = df_e_raw.iloc[0].values
df_e2 = df_e.iloc[1:].copy()
df_e2.columns = df_e.iloc[0].values
df_e2 = df_e2[df_e2['都道府県'] != '全国'].reset_index(drop=True)
print(f"SSDSE-E: {len(df_e2)}都道府県")

# SSDSE-E から 1人当たり県民所得 を数値変換
df_e2['1人当たり県民所得'] = pd.to_numeric(df_e2['1人当たり県民所得（平成27年基準）'], errors='coerce')

# SSDSE-B から数値変換
for col in ['合計特殊出生率', '婚姻件数', '保育所等数', '65歳以上人口',
            '15～64歳人口', '総人口', '年平均気温',
            '高等学校卒業者数', '高等学校卒業者のうち進学者数']:
    df_b[col] = pd.to_numeric(df_b[col], errors='coerce')

# ── 変数構築 ──
df_b['婚姻率'] = df_b['婚姻件数'] / df_b['15～64歳人口'] * 1000
df_b['保育所充実度'] = df_b['保育所等数'] / df_b['総人口'] * 1000
df_b['高齢化率'] = df_b['65歳以上人口'] / df_b['総人口'] * 100
df_b['大学進学率'] = df_b['高等学校卒業者のうち進学者数'] / df_b['高等学校卒業者数'] * 100

# SSDSE-B + SSDSE-E をマージ
df = df_b[['都道府県', '合計特殊出生率', '婚姻率', '保育所充実度',
           '高齢化率', '大学進学率', '年平均気温']].merge(
    df_e2[['都道府県', '1人当たり県民所得']], on='都道府県', how='inner')

# 全変数を数値型に
for col in df.columns[1:]:
    df[col] = pd.to_numeric(df[col], errors='coerce')

df = df.dropna().reset_index(drop=True)
print(f"分析対象: {len(df)}都道府県（欠損除外後）")

EXPLAIN_VARS = ['婚姻率', '保育所充実度', '高齢化率', '大学進学率', '1人当たり県民所得', '年平均気温']
TARGET = '合計特殊出生率'

print(f"\n基本統計:")
print(df[[TARGET] + EXPLAIN_VARS].describe().round(3))

# 標準化
df_std = df[EXPLAIN_VARS].copy()
for v in EXPLAIN_VARS:
    mu, sg = df_std[v].mean(), df_std[v].std()
    df_std[v + '_z'] = (df_std[v] - mu) / sg
Z_VARS = [v + '_z' for v in EXPLAIN_VARS]

# 重回帰：47都道府県
X_all = sm.add_constant(df_std[Z_VARS])
y = df[TARGET]
reg_all = sm.OLS(y, X_all).fit(cov_type='HC1')
print(f"\n重回帰（全{len(df)}都道府県）: R²={reg_all.rsquared:.3f}, adj.R²={reg_all.rsquared_adj:.3f}")

# 重回帰：東京都除く
df_notokyo = df[df['都道府県'] != '東京都'].copy().reset_index(drop=True)
df_std_nt = df_notokyo[EXPLAIN_VARS].copy()
for v in EXPLAIN_VARS:
    mu, sg = df_std_nt[v].mean(), df_std_nt[v].std()
    df_std_nt[v + '_z'] = (df_std_nt[v] - mu) / sg
X_nt = sm.add_constant(df_std_nt[Z_VARS])
y_nt = df_notokyo[TARGET]
reg_nt = sm.OLS(y_nt, X_nt).fit(cov_type='HC1')
print(f"重回帰（東京除く{len(df_notokyo)}都道府県）: R²={reg_nt.rsquared:.3f}, adj.R²={reg_nt.rsquared_adj:.3f}")

# ================================================================
# ■ 図1: 相関ヒートマップ
# ================================================================
print("\n図1: 相関ヒートマップを作成中...")

fig1, ax1 = plt.subplots(figsize=(9, 7))
fig1.suptitle('合計特殊出生率と説明変数の相関行列\n（2022年度 都道府県データ）',
              fontsize=12, fontweight='bold')

vars_for_corr = [TARGET] + EXPLAIN_VARS
labels_corr = ['TFR', '婚姻率', '保育所\n充実度', '高齢化率', '大学\n進学率', '県民\n所得', '年平均\n気温']
corr_mat = df[vars_for_corr].corr()

im = ax1.imshow(corr_mat.values, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
plt.colorbar(im, ax=ax1, label='Pearson 相関係数', shrink=0.8)
ax1.set_xticks(range(len(vars_for_corr)))
ax1.set_yticks(range(len(vars_for_corr)))
ax1.set_xticklabels(labels_corr, fontsize=9)
ax1.set_yticklabels(labels_corr, fontsize=9)
for i in range(len(vars_for_corr)):
    for j in range(len(vars_for_corr)):
        val = corr_mat.values[i, j]
        col = 'white' if abs(val) > 0.6 else 'black'
        ax1.text(j, i, f'{val:.2f}', ha='center', va='center', fontsize=9,
                 color=col, fontweight='bold' if abs(val) > 0.5 else 'normal')

ax1.set_title('相関係数ヒートマップ\n（赤：正の相関, 青：負の相関）', fontsize=10)
plt.tight_layout()
fig1.savefig(os.path.join(FIG_DIR, '2024_H5_1_fig1_corr.png'), bbox_inches='tight', dpi=150)
plt.close(fig1)
print("  → 2024_H5_1_fig1_corr.png 保存完了")

# ================================================================
# ■ 図2: 主要変数 vs TFR 散布図
# ================================================================
print("図2: 主要変数 vs TFR 散布図を作成中...")

fig2, axes2 = plt.subplots(2, 3, figsize=(15, 9))
fig2.suptitle('合計特殊出生率（TFR）と各説明変数の散布図\n（2022年度 47都道府県）',
              fontsize=13, fontweight='bold')

plot_vars = EXPLAIN_VARS
plot_xlabels = ['婚姻率 (婚姻件数/15-64歳人口×1000)', '保育所充実度 (保育所等数/人口×1000)',
                '高齢化率 (%)', '大学進学率 (%)', '1人当たり県民所得 (万円)',
                '年平均気温 (℃)']
axes2_flat = axes2.flatten()

for i, (var, xlabel) in enumerate(zip(plot_vars, plot_xlabels)):
    ax = axes2_flat[i]
    x_vals = df[var].values
    y_vals = df[TARGET].values

    # 都道府県ラベル色
    colors = []
    for pref in df['都道府県']:
        if pref == '東京都':
            colors.append('#E53935')
        elif pref == '沖縄県':
            colors.append('#43A047')
        else:
            colors.append('#1565C0')

    ax.scatter(x_vals, y_vals, c=colors, s=60, alpha=0.8, edgecolors='white', linewidth=0.5)

    # 回帰直線
    slope, intercept, r_val, p_val, _ = stats.linregress(x_vals, y_vals)
    x_line = np.linspace(x_vals.min(), x_vals.max(), 100)
    ax.plot(x_line, intercept + slope * x_line, '--', color='#FF6F00', linewidth=1.8)

    # 注目県ラベル
    for _, row in df[df['都道府県'].isin(['東京都', '沖縄県', '秋田県', '鳥取県'])].iterrows():
        short = row['都道府県'].replace('県', '').replace('府', '').replace('都', '').replace('道', '')
        ax.annotate(short, (row[var], row[TARGET]), fontsize=7.5, fontweight='bold',
                    xytext=(3, 3), textcoords='offset points')

    sig_str = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else 'n.s.'
    ax.set_xlabel(xlabel, fontsize=8)
    ax.set_ylabel('TFR', fontsize=9)
    ax.set_title(f'r = {r_val:.3f} ({sig_str})', fontsize=9, fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
fig2.savefig(os.path.join(FIG_DIR, '2024_H5_1_fig2_scatter.png'), bbox_inches='tight', dpi=150)
plt.close(fig2)
print("  → 2024_H5_1_fig2_scatter.png 保存完了")

# ================================================================
# ■ 図3: 47都道府県 重回帰係数
# ================================================================
print("図3: 重回帰係数を作成中...")

fig3, axes3 = plt.subplots(1, 2, figsize=(14, 6))
fig3.suptitle(f'TFR重回帰分析：{len(df)}都道府県\n(R²={reg_all.rsquared:.3f}, adj.R²={reg_all.rsquared_adj:.3f})',
              fontsize=12, fontweight='bold')

ax3a = axes3[0]
coefs = [reg_all.params.get(zv, 0) for zv in Z_VARS]
ses = [reg_all.bse.get(zv, 0) for zv in Z_VARS]
pvals = [reg_all.pvalues.get(zv, 1) for zv in Z_VARS]
bar_colors = ['#E53935' if c < 0 else '#1565C0' for c in coefs]
sorted_idx = np.argsort(coefs)

ax3a.barh(range(len(Z_VARS)),
          [coefs[i] for i in sorted_idx],
          xerr=[1.96 * ses[i] for i in sorted_idx],
          color=[bar_colors[i] for i in sorted_idx],
          alpha=0.85, edgecolor='white', capsize=4,
          error_kw={'elinewidth': 1.5, 'ecolor': '#555'})
ax3a.set_yticks(range(len(Z_VARS)))
ax3a.set_yticklabels([EXPLAIN_VARS[i] for i in sorted_idx], fontsize=10)
ax3a.axvline(0, color='black', linewidth=1.0)
ax3a.set_xlabel('標準化回帰係数 (±95%CI)', fontsize=10)
ax3a.set_title(f'TFR の重回帰係数\n({len(df)}都道府県, HC1標準誤差)', fontsize=10, fontweight='bold')
ax3a.grid(axis='x', alpha=0.3)

for i, idx in enumerate(sorted_idx):
    p = pvals[idx]
    sig = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else ''
    if sig:
        c = coefs[idx]
        se = ses[idx]
        ax3a.text(c + np.sign(c) * 1.96 * se + np.sign(c) * 0.01, i,
                  sig, va='center', ha='left' if c > 0 else 'right',
                  fontsize=10, fontweight='bold')

ax3b = axes3[1]
y_pred = reg_all.fittedvalues
residuals = y - y_pred
ax3b.scatter(y_pred, y, alpha=0.75, c='#1565C0', s=65,
             edgecolors='white', linewidth=0.5, zorder=3)
ax3b.plot([y.min(), y.max()], [y.min(), y.max()], '--', color='#E53935', linewidth=1.8, label='完全一致線')

# 東京都ハイライト
tokyo_row = df[df['都道府県'] == '東京都']
if len(tokyo_row) > 0:
    ti = tokyo_row.index[0]
    ax3b.scatter([y_pred.iloc[ti]], [y.iloc[ti]], s=180, c='#E53935', marker='*',
                 zorder=5, label='東京都')
    ax3b.annotate('東京都', (y_pred.iloc[ti], y.iloc[ti]),
                  fontsize=9, fontweight='bold', color='#E53935',
                  xytext=(6, 4), textcoords='offset points')

ax3b.set_xlabel('TFR 予測値', fontsize=11)
ax3b.set_ylabel('TFR 実測値', fontsize=11)
ax3b.set_title('予測値 vs 実測値\n（赤星：東京都）', fontsize=10, fontweight='bold')
ax3b.legend(fontsize=9)
ax3b.grid(True, alpha=0.3)

plt.tight_layout()
fig3.savefig(os.path.join(FIG_DIR, '2024_H5_1_fig3_coef.png'), bbox_inches='tight', dpi=150)
plt.close(fig3)
print("  → 2024_H5_1_fig3_coef.png 保存完了")

# ================================================================
# ■ 図4: 東京都あり/なし比較
# ================================================================
print("図4: 東京都あり/なし比較を作成中...")

fig4, axes4 = plt.subplots(1, 2, figsize=(14, 6))
fig4.suptitle('東京都 包含/除外による重回帰係数の変化（TFR モデル）',
              fontsize=12, fontweight='bold')

ax4a = axes4[0]
coefs_all = [reg_all.params.get(zv, 0) for zv in Z_VARS]
coefs_nt = [reg_nt.params.get(zv, 0) for zv in Z_VARS]
x_pos = np.arange(len(Z_VARS))
width = 0.35

ax4a.bar(x_pos - width / 2, coefs_all, width,
         label=f'47都道府県（東京含む, R²={reg_all.rsquared:.3f}）',
         color='#1565C0', alpha=0.82, edgecolor='white')
ax4a.bar(x_pos + width / 2, coefs_nt, width,
         label=f'46都道府県（東京除く, R²={reg_nt.rsquared:.3f}）',
         color='#43A047', alpha=0.82, edgecolor='white')
ax4a.set_xticks(x_pos)
ax4a.set_xticklabels(EXPLAIN_VARS, fontsize=8.5, rotation=30, ha='right')
ax4a.axhline(0, color='black', linewidth=1.0)
ax4a.set_ylabel('標準化回帰係数', fontsize=11)
ax4a.set_title('東京都の包含/除外による係数の変化', fontsize=10, fontweight='bold')
ax4a.legend(fontsize=8.5, loc='upper left')
ax4a.grid(axis='y', alpha=0.3)

ax4b = axes4[1]
# 婚姻率 vs TFR の散布図（東京ハイライト）
x_marr = df['婚姻率'].values
y_tfr = df[TARGET].values
is_tokyo = df['都道府県'] == '東京都'
is_okinawa = df['都道府県'] == '沖縄県'

ax4b.scatter(x_marr[~is_tokyo & ~is_okinawa], y_tfr[~is_tokyo & ~is_okinawa],
             c='#1565C0', s=65, alpha=0.8, edgecolors='white', linewidth=0.5, label='その他都道府県')
ax4b.scatter(x_marr[is_okinawa], y_tfr[is_okinawa],
             c='#43A047', s=120, alpha=0.9, edgecolors='white', linewidth=0.5, label='沖縄県', zorder=4)
ax4b.scatter(x_marr[is_tokyo], y_tfr[is_tokyo],
             c='#E53935', s=200, marker='*', alpha=0.95, label='東京都（外れ値）', zorder=5)

# 回帰直線（東京除く）
x_nt_m = x_marr[~is_tokyo]
y_nt_t = y_tfr[~is_tokyo]
slope_nt, intercept_nt, r_nt, p_nt, _ = stats.linregress(x_nt_m, y_nt_t)
x_line = np.linspace(x_marr.min(), x_marr.max(), 100)
ax4b.plot(x_line, intercept_nt + slope_nt * x_line, '--', color='#E65100',
          linewidth=2, label=f'回帰直線（東京除く, r={r_nt:.3f}）')

for _, row in df[df['都道府県'].isin(['東京都', '沖縄県', '秋田県'])].iterrows():
    short = row['都道府県'].replace('県', '').replace('都', '').replace('府', '').replace('道', '')
    ax4b.annotate(short, (row['婚姻率'], row[TARGET]),
                  fontsize=9, fontweight='bold', xytext=(5, 4), textcoords='offset points')

ax4b.set_xlabel('婚姻率 (婚姻件数/15-64歳人口×1000)', fontsize=10)
ax4b.set_ylabel('合計特殊出生率 (TFR)', fontsize=11)
ax4b.set_title('婚姻率 vs TFR（東京都の位置）\n（2022年度 実データ）', fontsize=10, fontweight='bold')
ax4b.legend(fontsize=8.5)
ax4b.grid(True, alpha=0.3)

plt.tight_layout()
fig4.savefig(os.path.join(FIG_DIR, '2024_H5_1_fig4_tokyo.png'), bbox_inches='tight', dpi=150)
plt.close(fig4)
print("  → 2024_H5_1_fig4_tokyo.png 保存完了")

print("\n" + "=" * 65)
print("✓ 全図の生成完了（4枚）")
print("=" * 65)
print(f"\n保存先: {FIG_DIR}")
print("  2024_H5_1_fig1_corr.png      - 相関ヒートマップ")
print("  2024_H5_1_fig2_scatter.png   - 主要変数 vs TFR 散布図")
print("  2024_H5_1_fig3_coef.png      - 重回帰係数（全都道府県）")
print("  2024_H5_1_fig4_tokyo.png     - 東京あり/なし比較")
