"""
教育用再現コード: 2020年度 統計数理賞（高校生の部）
=================================================================
論文タイトル：地域の所得格差・教育環境と学力・進学率の関係
受賞：統計数理賞（高校生の部）

【分析概要】
  データ：SSDSE-B-2026（47都道府県・2022年度断面）
  目的変数：大学進学率（高等学校卒業者のうち進学者数 / 高等学校卒業者数 × 100）
  説明変数：
    - 教育費（消費支出における教育費：家庭の教育投資）
    - 学校数/万人（小中高合計 / 総人口 × 10000：教育施設密度）
    - 消費支出総額（所得水準の代理）
    - 総人口の対数（都市化度の代理）
    - 高齢化率（65歳以上人口 / 総人口 × 100）
    - 合計特殊出生率

【分析手法】
  1. 散布図（教育費 vs 大学進学率）：地域色分け・回帰直線
  2. 都道府県別大学進学率ランキング（上位・下位各10県）
  3. 相関ヒートマップ（Pearson相関）
  4. 重回帰分析（OLS）・標準化偏回帰係数プロット

【学習ポイント 4項目】
  1. 散布図と回帰直線：2変数の線形関係の可視化
  2. 重回帰分析（OLS）：statsmodelsによる推定
  3. 標準化偏回帰係数：各説明変数の相対的な影響力の比較
  4. 相関ヒートマップ：多変数間の関係を俯瞰する

【使用データ】
  SSDSE-B-2026.csv（SSDSE: 社会・人口統計体系データセット）
  ※ 合成データ・乱数生成（np.random.seed等）は一切使用しない。
=================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2020_H3_suri.py
# ============================================================


import os
import numpy as np
import pandas as pd

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import statsmodels.api as sm
from scipy import stats

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

# ─── パス設定 ────────────────────────────────────────────────────────────────
FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

# ─── データ読み込み（47都道府県のみ） ────────────────────────────────────────
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
df_b = df_b[df_b['地域コード'].str.match(r'^R\d{5}', na=False)].copy()
df_b['年度'] = df_b['年度'].astype(int)

# ─── 地域マップ ──────────────────────────────────────────────────────────────
region_map = {
    '北海道': '北海道・東北', '青森県': '北海道・東北', '岩手県': '北海道・東北',
    '宮城県': '北海道・東北', '秋田県': '北海道・東北', '山形県': '北海道・東北',
    '福島県': '北海道・東北', '茨城県': '関東', '栃木県': '関東', '群馬県': '関東',
    '埼玉県': '関東', '千葉県': '関東', '東京都': '関東', '神奈川県': '関東',
    '新潟県': '中部', '富山県': '中部', '石川県': '中部', '福井県': '中部',
    '山梨県': '中部', '長野県': '中部', '岐阜県': '中部', '静岡県': '中部', '愛知県': '中部',
    '三重県': '近畿', '滋賀県': '近畿', '京都府': '近畿', '大阪府': '近畿',
    '兵庫県': '近畿', '奈良県': '近畿', '和歌山県': '近畿',
    '鳥取県': '中国・四国', '島根県': '中国・四国', '岡山県': '中国・四国',
    '広島県': '中国・四国', '山口県': '中国・四国', '徳島県': '中国・四国',
    '香川県': '中国・四国', '愛媛県': '中国・四国', '高知県': '中国・四国',
    '福岡県': '九州・沖縄', '佐賀県': '九州・沖縄', '長崎県': '九州・沖縄',
    '熊本県': '九州・沖縄', '大分県': '九州・沖縄', '宮崎県': '九州・沖縄',
    '鹿児島県': '九州・沖縄', '沖縄県': '九州・沖縄'
}
region_colors = {
    '北海道・東北': '#4e9af1',
    '関東':         '#e05c5c',
    '中部':         '#f0a500',
    '近畿':         '#5cb85c',
    '中国・四国':   '#9b59b6',
    '九州・沖縄':   '#f39c12'
}
region_order = ['北海道・東北', '関東', '中部', '近畿', '中国・四国', '九州・沖縄']

# ─── 2022年度断面データ抽出と変数作成 ────────────────────────────────────────
df = df_b[df_b['年度'] == 2022].copy()

# 目的変数
df['大学進学率'] = df['高等学校卒業者のうち進学者数'] / df['高等学校卒業者数'] * 100

# 説明変数
df['教育費']         = df['教育費（二人以上の世帯）']           # 円/月
df['消費支出総額']   = df['消費支出（二人以上の世帯）']          # 円/月
df['高齢化率']       = df['65歳以上人口'] / df['総人口'] * 100   # %
df['合計特殊出生率'] = df['合計特殊出生率']
df['学校数合計']     = df['小学校数'] + df['中学校数'] + df['高等学校数']
df['学校数_万人']    = df['学校数合計'] / df['総人口'] * 10000   # 校/万人
df['人口対数']       = np.log(df['総人口'])                       # 都市化の代理

# 地域ブロック付与
df['地域'] = df['都道府県'].map(region_map)

# 分析用列の選択（欠損なし確認）
analysis_cols = ['都道府県', '地域', '大学進学率', '教育費', '学校数_万人',
                 '消費支出総額', '人口対数', '高齢化率', '合計特殊出生率']
df_ana = df[analysis_cols].dropna().reset_index(drop=True)
print(f"分析サンプル数: {len(df_ana)} 都道府県 (2022年度)")

# ─── 重回帰分析（OLS） ───────────────────────────────────────────────────────
y_col = '大学進学率'
x_cols = ['教育費', '学校数_万人', '消費支出総額', '人口対数', '高齢化率', '合計特殊出生率']

Y = df_ana[y_col]
X = df_ana[x_cols]
X_sm = sm.add_constant(X)

model = sm.OLS(Y, X_sm).fit()
print("\n=== OLS 重回帰分析結果 ===")
print(model.summary())

# 標準化偏回帰係数
X_std = (X - X.mean()) / X.std()
Y_std = (Y - Y.mean()) / Y.std()
X_std_sm = sm.add_constant(X_std)
model_std = sm.OLS(Y_std, X_std_sm).fit()
beta_std = model_std.params.drop('const')
pval_std = model_std.pvalues.drop('const')

print("\n=== 標準化偏回帰係数 ===")
for v in x_cols:
    sig = '***' if pval_std[v] < 0.001 else ('**' if pval_std[v] < 0.01 else ('*' if pval_std[v] < 0.05 else ''))
    print(f"  {v:12s}: β = {beta_std[v]:+.4f}  p = {pval_std[v]:.4f} {sig}")

print(f"\nR² = {model.rsquared:.4f},  自由度調整済みR² = {model.rsquared_adj:.4f}")

# 地域別平均進学率
region_mean = df_ana.groupby('地域')['大学進学率'].agg(['mean', 'std', 'count']).reset_index()
region_mean.columns = ['地域', '平均', '標準偏差', 'n']
print("\n=== 地域別大学進学率 ===")
print(region_mean.to_string(index=False))

# ─── 相関係数行列 ─────────────────────────────────────────────────────────────
corr_cols = ['大学進学率', '教育費', '学校数_万人', '消費支出総額', '人口対数', '高齢化率', '合計特殊出生率']
corr_labels = ['大学\n進学率', '教育費', '学校数\n/万人', '消費\n支出', '人口\n対数', '高齢化\n率', '合計\n出生率']
corr_mat = df_ana[corr_cols].corr()
print("\n=== 相関行列 ===")
print(corr_mat.round(3).to_string())

# ─── 図1: 散布図（教育費 vs 大学進学率） ─────────────────────────────────────
fig1, ax1 = plt.subplots(figsize=(10, 7))

for region in region_order:
    sub = df_ana[df_ana['地域'] == region]
    ax1.scatter(sub['教育費'], sub['大学進学率'],
                color=region_colors[region], label=region,
                s=70, alpha=0.85, zorder=3)
    for _, row in sub.iterrows():
        pref_short = row['都道府県'].replace('県', '').replace('府', '').replace('都', '').replace('道', '')
        ax1.annotate(pref_short,
                     xy=(row['教育費'], row['大学進学率']),
                     xytext=(3, 3), textcoords='offset points',
                     fontsize=7.5, color='#333333', alpha=0.9)

# 回帰直線
slope, intercept, r_val, p_val, se = stats.linregress(df_ana['教育費'], df_ana['大学進学率'])
x_line = np.linspace(df_ana['教育費'].min(), df_ana['教育費'].max(), 200)
ax1.plot(x_line, intercept + slope * x_line, color='#333333', linewidth=1.8,
         linestyle='--', alpha=0.8, label=f'回帰直線 (r={r_val:.2f}, p={p_val:.3f})')

ax1.set_xlabel('教育費（消費支出内, 円/月）', fontsize=13)
ax1.set_ylabel('大学進学率（%）', fontsize=13)
ax1.set_title('図1  教育費 vs 大学進学率（47都道府県, 2022年度）', fontsize=14, fontweight='bold')
ax1.legend(loc='lower right', fontsize=9, framealpha=0.8)
ax1.grid(True, alpha=0.3)
ax1.tick_params(labelsize=10)
fig1.tight_layout()
fig1.savefig(os.path.join(FIG_DIR, '2020_H3_fig1.png'), bbox_inches='tight')
plt.close(fig1)
print("fig1 saved.")

# ─── 図2: 都道府県別大学進学率ランキング（上位10 + 下位10） ─────────────────
df_rank = df_ana.sort_values('大学進学率', ascending=False).reset_index(drop=True)
top10    = df_rank.head(10)
bottom10 = df_rank.tail(10).sort_values('大学進学率', ascending=True)

fig2, axes = plt.subplots(1, 2, figsize=(13, 5.5))

for ax, sub_df, title_str in [
    (axes[0], top10,    '上位10都道府県'),
    (axes[1], bottom10, '下位10都道府県')
]:
    colors = [region_colors[r] for r in sub_df['地域']]
    bars = ax.barh(sub_df['都道府県'], sub_df['大学進学率'],
                   color=colors, edgecolor='white', height=0.65)
    for bar, val in zip(bars, sub_df['大学進学率']):
        ax.text(bar.get_width() + 0.3, bar.get_y() + bar.get_height()/2,
                f'{val:.1f}%', va='center', fontsize=9)
    ax.set_xlabel('大学進学率（%）', fontsize=11)
    ax.set_title(title_str, fontsize=13, fontweight='bold')
    ax.set_xlim(0, df_ana['大学進学率'].max() * 1.12)
    ax.grid(axis='x', alpha=0.3)
    ax.tick_params(labelsize=10)

# 凡例
patches = [mpatches.Patch(color=region_colors[r], label=r) for r in region_order]
fig2.legend(handles=patches, loc='lower center', ncol=3, fontsize=9,
            bbox_to_anchor=(0.5, -0.04), framealpha=0.85)
fig2.suptitle('図2  都道府県別大学進学率ランキング（2022年度）',
              fontsize=14, fontweight='bold', y=1.01)
fig2.tight_layout()
fig2.savefig(os.path.join(FIG_DIR, '2020_H3_fig2.png'), bbox_inches='tight')
plt.close(fig2)
print("fig2 saved.")

# ─── 図3: 相関ヒートマップ ───────────────────────────────────────────────────
fig3, ax3 = plt.subplots(figsize=(8.5, 7))
n_vars = len(corr_cols)
im = ax3.imshow(corr_mat.values, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
ax3.set_xticks(range(n_vars))
ax3.set_yticks(range(n_vars))
ax3.set_xticklabels(corr_labels, fontsize=10)
ax3.set_yticklabels(corr_labels, fontsize=10)

for i in range(n_vars):
    for j in range(n_vars):
        val = corr_mat.values[i, j]
        text_color = 'white' if abs(val) > 0.6 else 'black'
        ax3.text(j, i, f'{val:.2f}', ha='center', va='center',
                 fontsize=9.5, color=text_color, fontweight='bold' if abs(val) >= 0.5 else 'normal')

cbar = fig3.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)
cbar.set_label('Pearson相関係数', fontsize=10)
ax3.set_title('図3  教育・所得関連変数のPearson相関ヒートマップ（2022年度）',
              fontsize=13, fontweight='bold', pad=12)
fig3.tight_layout()
fig3.savefig(os.path.join(FIG_DIR, '2020_H3_fig3.png'), bbox_inches='tight')
plt.close(fig3)
print("fig3 saved.")

# ─── 図4: 標準化偏回帰係数プロット ───────────────────────────────────────────
var_labels = {
    '教育費':       '教育費\n（消費支出内）',
    '学校数_万人':  '学校数\n/万人',
    '消費支出総額': '消費支出\n総額',
    '人口対数':     '人口\n（対数）',
    '高齢化率':     '高齢化率',
    '合計特殊出生率':'合計\n特殊出生率',
}
# 係数の絶対値で降順ソート
beta_sorted = beta_std.reindex(x_cols).sort_values(key=lambda x: x.abs())
labels_sorted = [var_labels[v] for v in beta_sorted.index]
pvals_sorted  = pval_std.reindex(beta_sorted.index)

bar_colors = ['#e05c5c' if b > 0 else '#4e9af1' for b in beta_sorted.values]
edge_colors = []
for p in pvals_sorted.values:
    if p < 0.05:
        edge_colors.append('#cc0000' if True else '#00008b')
    else:
        edge_colors.append('gray')

fig4, ax4 = plt.subplots(figsize=(9, 5.5))
bars4 = ax4.barh(labels_sorted, beta_sorted.values,
                 color=bar_colors, edgecolor='white', height=0.6, alpha=0.88)

# 有意性マーク
for i, (bar, p, b) in enumerate(zip(bars4, pvals_sorted.values, beta_sorted.values)):
    sig_mark = '***' if p < 0.001 else ('**' if p < 0.01 else ('*' if p < 0.05 else ''))
    offset = 0.01 if b >= 0 else -0.01
    ha_val  = 'left' if b >= 0 else 'right'
    ax4.text(b + offset, bar.get_y() + bar.get_height()/2,
             f'{b:+.3f} {sig_mark}', va='center', ha=ha_val, fontsize=9)

ax4.axvline(0, color='black', linewidth=0.8)
ax4.set_xlabel('標準化偏回帰係数 (β)', fontsize=12)
ax4.set_title('図4  大学進学率に対する標準化偏回帰係数\n（重回帰OLS, 2022年度, n=47）',
              fontsize=13, fontweight='bold')
ax4.set_xlim(beta_sorted.min() * 1.35, beta_sorted.max() * 1.35)
ax4.grid(axis='x', alpha=0.3)
ax4.tick_params(axis='y', labelsize=10)

# 凡例
pos_patch = mpatches.Patch(color='#e05c5c', label='正の効果')
neg_patch = mpatches.Patch(color='#4e9af1', label='負の効果')
ax4.legend(handles=[pos_patch, neg_patch], fontsize=9, loc='lower right')

# R²テキスト
ax4.text(0.98, 0.04, f'R²={model.rsquared:.3f}, Adj.R²={model.rsquared_adj:.3f}',
         transform=ax4.transAxes, ha='right', va='bottom',
         fontsize=9, color='#555555')
ax4.text(0.98, 0.10, '* p<.05  ** p<.01  *** p<.001',
         transform=ax4.transAxes, ha='right', va='bottom',
         fontsize=8.5, color='#555555')

fig4.tight_layout()
fig4.savefig(os.path.join(FIG_DIR, '2020_H3_fig4.png'), bbox_inches='tight')
plt.close(fig4)
print("fig4 saved.")

# ─── 統計サマリー出力（HTML作成用） ──────────────────────────────────────────
print("\n=== HTML用統計値サマリー ===")
print(f"全国平均大学進学率: {df_ana['大学進学率'].mean():.1f}%")
print(f"最高: {df_rank.iloc[0]['都道府県']} {df_rank.iloc[0]['大学進学率']:.1f}%")
print(f"最低: {df_rank.iloc[-1]['都道府県']} {df_rank.iloc[-1]['大学進学率']:.1f}%")
print(f"標準偏差: {df_ana['大学進学率'].std():.2f}%")
print(f"教育費と進学率の相関: r = {corr_mat.loc['大学進学率','教育費']:.3f}")
print(f"消費支出と進学率の相関: r = {corr_mat.loc['大学進学率','消費支出総額']:.3f}")
print(f"高齢化率と進学率の相関: r = {corr_mat.loc['大学進学率','高齢化率']:.3f}")
print(f"人口対数と進学率の相関: r = {corr_mat.loc['大学進学率','人口対数']:.3f}")
print(f"重回帰: R² = {model.rsquared:.4f}, Adj.R² = {model.rsquared_adj:.4f}")
print(f"F統計量 p値: {model.f_pvalue:.4f}")
print("\n地域別平均:")
for _, row in region_mean.iterrows():
    print(f"  {row['地域']}: {row['平均']:.1f}% (SD={row['標準偏差']:.1f}, n={int(row['n'])})")

# 標準化係数（上位3）
beta_abs_sorted = beta_std.abs().sort_values(ascending=False)
print("\n標準化係数（影響力大→小）:")
for v in beta_abs_sorted.index:
    p = pval_std[v]
    sig = '***' if p < 0.001 else ('**' if p < 0.01 else ('*' if p < 0.05 else 'n.s.'))
    print(f"  {v}: β={beta_std[v]:+.3f} ({sig})")

print("\n全図の生成が完了しました。")
print(f"保存先: {os.path.abspath(FIG_DIR)}")
