"""
教育用再現コード: 2022年 統計データ分析コンペティション 審査員奨励賞 [大学生・一般の部]
=================================================================
論文タイトル：大学進学率の地域格差：地理的異質性と集積効果のパネル分析
受賞：審査員奨励賞（大学生・一般の部）

【分析概要】
  データ：SSDSE-B-2026.csv（都道府県別パネルデータ, 2012〜2023年度）
  対象：全47都道府県 × 12年（2012〜2023）
  目的変数：大学進学率 = 高等学校卒業者のうち進学者数 / 高等学校卒業者数 × 100 (%)

  分析の流れ
  1. 時系列：都道府県別大学進学率の推移（地域グループ別）
  2. 都道府県別進学率の格差（箱ひげ図 + 2022年ランキング）
  3. OLS回帰分析：進学率の決定要因
  4. Ward法クラスタリング：都道府県グループ分類

【被説明変数】
  大学進学率 (%) = 高校卒業者のうち進学者数 / 高校卒業者数 × 100

【説明変数】
  大学数（都道府県内大学数）
  高齢化率 = 65歳以上人口 / 総人口 × 100
  消費支出_log = log(消費支出（二人以上の世帯）)
  教育費（二人以上の世帯）(千円)
  人口密度代理 = 総人口（千人）

【データ出典】
  SSDSE-B-2026.csv: 社会・人口統計体系（都道府県データ）

【データサイエンス学習ポイント】
  1. 教育指標の比率変数設計（進学率の定義と解釈）
  2. 地域格差の可視化（箱ひげ図 + ランキング）
  3. OLS回帰の実施と多重共線性チェック（VIF）
  4. Ward法による都道府県クラスタリング
=================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2022_U5_10_shorei.py
# ============================================================


import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import pdist
from sklearn.preprocessing import StandardScaler
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
import warnings
warnings.filterwarnings('ignore')

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

import os
FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

# ================================================================
# ■ 実データ読み込み
# ================================================================
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
df_b = df_b[df_b['地域コード'].str.match(r'^R\d{5}$', na=False)].copy()
df_b['年度'] = df_b['年度'].astype(int)
df_b = df_b.sort_values(['都道府県', '年度']).reset_index(drop=True)

# ================================================================
# ■ 変数生成
# ================================================================
df_b['大学進学率'] = df_b['高等学校卒業者のうち進学者数'] / df_b['高等学校卒業者数'].replace(0, np.nan) * 100
df_b['高齢化率'] = df_b['65歳以上人口'] / df_b['総人口'] * 100
df_b['消費支出_log'] = np.log(df_b['消費支出（二人以上の世帯）'].clip(lower=1))
df_b['教育費_千円'] = df_b['教育費（二人以上の世帯）'] / 1000
df_b['総人口_千人'] = df_b['総人口'] / 1000

# 地域区分
region_map = {
    '北海道': '北海道・東北', '青森県': '北海道・東北', '岩手県': '北海道・東北',
    '宮城県': '北海道・東北', '秋田県': '北海道・東北', '山形県': '北海道・東北',
    '福島県': '北海道・東北',
    '茨城県': '関東', '栃木県': '関東', '群馬県': '関東', '埼玉県': '関東',
    '千葉県': '関東', '東京都': '関東', '神奈川県': '関東',
    '新潟県': '中部', '富山県': '中部', '石川県': '中部', '福井県': '中部',
    '山梨県': '中部', '長野県': '中部', '岐阜県': '中部', '静岡県': '中部', '愛知県': '中部',
    '三重県': '近畿', '滋賀県': '近畿', '京都府': '近畿', '大阪府': '近畿',
    '兵庫県': '近畿', '奈良県': '近畿', '和歌山県': '近畿',
    '鳥取県': '中国・四国', '島根県': '中国・四国', '岡山県': '中国・四国',
    '広島県': '中国・四国', '山口県': '中国・四国', '徳島県': '中国・四国',
    '香川県': '中国・四国', '愛媛県': '中国・四国', '高知県': '中国・四国',
    '福岡県': '九州・沖縄', '佐賀県': '九州・沖縄', '長崎県': '九州・沖縄',
    '熊本県': '九州・沖縄', '大分県': '九州・沖縄', '宮崎県': '九州・沖縄',
    '鹿児島県': '九州・沖縄', '沖縄県': '九州・沖縄',
}
df_b['地域'] = df_b['都道府県'].map(region_map)
region_colors = {
    '北海道・東北': '#4e9af1', '関東': '#e05c5c', '中部': '#f0a500',
    '近畿': '#5cb85c', '中国・四国': '#9b59b6', '九州・沖縄': '#f39c12',
}

# ================================================================
# ■ Fig1: 大学進学率の時系列推移（地域別平均）
# ================================================================
fig, ax = plt.subplots(figsize=(10, 5))
yearly = df_b.groupby(['年度', '地域'])['大学進学率'].mean().reset_index()
for reg, grp in yearly.groupby('地域'):
    ax.plot(grp['年度'], grp['大学進学率'], marker='o', markersize=4,
            label=reg, color=region_colors.get(reg, 'gray'))
ax.set_xlabel('年度', fontsize=12)
ax.set_ylabel('大学進学率（%）', fontsize=12)
ax.set_title('地域別 大学進学率の推移（2012〜2023年）', fontsize=14, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2022_U5_10_fig1_ts.png'), dpi=150, bbox_inches='tight')
plt.close()
print("Fig1 saved")

# ================================================================
# ■ Fig2: 都道府県別進学率ランキング（2022年）
# ================================================================
df_2022 = df_b[df_b['年度'] == 2022].copy()
df_sorted = df_2022.sort_values('大学進学率', ascending=True).dropna(subset=['大学進学率'])
fig, ax = plt.subplots(figsize=(8, 12))
colors = [region_colors.get(df_sorted[df_sorted['都道府県'] == pref]['地域'].values[0], 'gray')
          if pref in df_sorted['都道府県'].values else 'gray'
          for pref in df_sorted['都道府県']]
bars = ax.barh(df_sorted['都道府県'], df_sorted['大学進学率'], color=colors, alpha=0.8)
ax.set_xlabel('大学進学率（%）', fontsize=12)
ax.set_title('都道府県別 大学進学率ランキング（2022年）', fontsize=13, fontweight='bold')
ax.axvline(df_2022['大学進学率'].mean(), color='red', linestyle='--', linewidth=1.5,
           label=f'全国平均: {df_2022["大学進学率"].mean():.1f}%')
ax.legend(fontsize=10)
ax.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2022_U5_10_fig2_rank.png'), dpi=150, bbox_inches='tight')
plt.close()
print("Fig2 saved")

# ================================================================
# ■ Fig3: OLS回帰分析 — 係数プロット + VIF
# ================================================================
xvars = ['大学数', '高齢化率', '消費支出_log', '教育費_千円']
df_reg = df_2022[['大学進学率'] + xvars].dropna()
X = sm.add_constant(df_reg[xvars])
res = sm.OLS(df_reg['大学進学率'], X).fit()

# VIF計算
vif_df = pd.DataFrame({
    '変数': xvars,
    'VIF': [variance_inflation_factor(df_reg[xvars].values, i) for i in range(len(xvars))]
})
print(vif_df)
print(res.summary())

coefs = res.params.drop('const')
ses = res.bse.drop('const')
pvals = res.pvalues.drop('const')

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# 係数プロット
colors_c = ['#e05c5c' if p < 0.05 else '#888888' for p in pvals]
ax1.barh(range(len(coefs)), coefs, xerr=1.96 * ses, color=colors_c, alpha=0.8,
         error_kw={'elinewidth': 1.5, 'capsize': 4})
ax1.set_yticks(range(len(coefs)))
ax1.set_yticklabels(coefs.index, fontsize=10)
ax1.axvline(0, color='black', linewidth=0.8)
ax1.set_xlabel('回帰係数', fontsize=12)
ax1.set_title(f'OLS回帰係数（R²={res.rsquared:.3f}）\n（赤=p<0.05）', fontsize=12, fontweight='bold')
ax1.grid(axis='x', alpha=0.3)
# VIFテーブル
ax2.axis('off')
tdata = [[v, f'{vf:.2f}'] for v, vf in zip(vif_df['変数'], vif_df['VIF'])]
table = ax2.table(cellText=tdata, colLabels=['変数', 'VIF'],
                  cellLoc='center', loc='center', bbox=[0.1, 0.2, 0.8, 0.6])
table.auto_set_font_size(False)
table.set_fontsize(10)
ax2.set_title('分散拡大要因（VIF）\nVIF > 10 は多重共線性の懸念', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2022_U5_10_fig3_ols.png'), dpi=150, bbox_inches='tight')
plt.close()
print("Fig3 saved")

# ================================================================
# ■ Fig4: Ward法クラスタリング（デンドログラム）
# ================================================================
cluster_vars = ['大学進学率', '高齢化率', '消費支出_log', '教育費_千円']
df_clust = df_2022[['都道府県'] + cluster_vars].dropna().set_index('都道府県')
scaler = StandardScaler()
X_scaled = scaler.fit_transform(df_clust)
Z = linkage(X_scaled, method='ward')

fig, ax = plt.subplots(figsize=(12, 6))
dendrogram(Z, labels=df_clust.index.tolist(), ax=ax,
           color_threshold=Z[-3, 2], leaf_rotation=90, leaf_font_size=8)
ax.set_title('Ward法クラスタリング：大学進学率・高齢化・消費水準（2022年）', fontsize=13, fontweight='bold')
ax.set_xlabel('都道府県', fontsize=11)
ax.set_ylabel('距離（Ward法）', fontsize=11)
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2022_U5_10_fig4_cluster.png'), dpi=150, bbox_inches='tight')
plt.close()
print("Fig4 saved")
print("All figures saved!")
