"""
教育用再現コード: 2023年 統計データ分析コンペティション 優秀賞（大学生）
=================================================================
論文タイトル：市町村費負担教員任用の規定要因
            ―ハードルモデルを用いた多変量解析から―

手法：ハードルモデル（2部構成）
  Part 1: ロジスティック回帰 — 都道府県に「高学校密度」があるか（2値）
  Part 2: OLS 回帰         — 高学校密度都道府県の教員/児童比を予測
  補助：VIF（分散拡大係数）による多重共線性チェック

【分析概要】
  ハードルモデルは、ゼロ過剰データや「2段階の意思決定」を持つデータに
  適用される2部構成モデル。
  第1部で「ハードルを超えるか（binary）」を確率モデルで推定し、
  第2部で「超えた場合の大きさ」を連続モデルで推定する。

  本分析では:
    - 第1部: 学校密度（E2101/A1101×10000）が中央値を超えるか → ロジスティック回帰
    - 第2部: 高密度都道府県の教員/児童比（E2401/E2501）を予測 → OLS回帰

【説明変数（都道府県レベル）】
  x1  高齢化率    A1303/A1101 × 100
  x2  消費支出    L3221
  x3  住宅地標準価格 C5401
  x4  合計特殊出生率 A4103
  x5  転出率      A5102/A1101 × 100
  x6  有効求人倍率  F3103/F3102
  x7  高校→大学進学率 E4602/E4601 × 100

【使用データ（実データ）】
  data/raw/SSDSE-B-2026.csv — 47都道府県、2022年度
  出典: 政府統計の総合窓口（e-Stat）、SSDSE（教育用標準データセット）

【出力図】
  html/figures/2023_U2_fig1_corr.png      — 相関ヒートマップ
  html/figures/2023_U2_fig2_logit_or.png  — ロジスティック回帰オッズ比棒グラフ
  html/figures/2023_U2_fig3_ols_coef.png  — OLS係数棒グラフ（第2部）
  html/figures/2023_U2_fig4_actual_pred.png — 実測値 vs 予測値散布図
=================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2023_U2_yushu.py
# ============================================================


import os
import warnings
import re
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import statsmodels.formula.api as smf
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor

warnings.filterwarnings('ignore')

# ──────────────────────────────────────────────────────────────
# パス設定
# ──────────────────────────────────────────────────────────────
_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in dir() else os.getcwd()
FIG_DIR = os.path.join(_dir, '..', 'html', 'figures')
DATA_B  = os.path.join(_dir, '..', 'data', 'raw', 'SSDSE-B-2026.csv')
os.makedirs(FIG_DIR, exist_ok=True)

# ──────────────────────────────────────────────────────────────
# 共通設定
# ──────────────────────────────────────────────────────────────
plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

COLORS = {
    'primary':   '#1565C0',
    'secondary': '#E65100',
    'success':   '#2E7D32',
    'danger':    '#C62828',
    'purple':    '#6A1B9A',
    'teal':      '#00695C',
    'gray':      '#616161',
}

# ──────────────────────────────────────────────────────────────
# データ読み込み（SSDSE-B-2026、header=0 で変数コードを列名に）
# ──────────────────────────────────────────────────────────────
print("=" * 65)
print("■ データ読み込み: SSDSE-B-2026.csv（47都道府県、2022年度）")
print("=" * 65)

# header=0 → 1行目の英字コード（A1101, E2101, …）が列名になる
# 2行目（日本語ラベル）をスキップして実数行のみ使用
df_raw = pd.read_csv(DATA_B, encoding='cp932', header=0, low_memory=False)

# 列名を標準化（最初の3列を固定名に）
df_raw.columns = ['年度', '地域コード', '都道府県'] + list(df_raw.columns[3:])

# 1行目は日本語ラベル行なので除外
df_raw = df_raw.iloc[1:].reset_index(drop=True)

# 47都道府県のみ抽出（地域コード = R + 5桁数字）
df_pref = df_raw[df_raw['地域コード'].astype(str).str.match(r'^R\d{5}$', na=False)].copy()

# 2022年度のみ
df_pref = df_pref[df_pref['年度'].astype(str) == '2022'].copy()
print(f"抽出後: {len(df_pref)}都道府県（2022年度）")

# ──────────────────────────────────────────────────────────────
# 変数構築
# ──────────────────────────────────────────────────────────────
# 必要列を数値変換
num_cols = ['E2101', 'A1101', 'A1303', 'L3221', 'C5401', 'A4103',
            'A5102', 'F3102', 'F3103', 'E4601', 'E4602', 'E2401', 'E2501']
for c in num_cols:
    df_pref[c] = pd.to_numeric(df_pref[c], errors='coerce')

# 目的変数・予測変数の計算
df_pref['school_per_pop']  = df_pref['E2101'] / df_pref['A1101'] * 10000  # 学校密度（1万人対）
df_pref['teacher_ratio']   = df_pref['E2401'] / df_pref['E2501']           # 教員/児童比

# 説明変数
df_pref['x1_aging_rate']   = df_pref['A1303'] / df_pref['A1101'] * 100    # 高齢化率（%）
df_pref['x2_consumption']  = df_pref['L3221']                              # 消費支出（円）
df_pref['x3_land_price']   = df_pref['C5401']                              # 住宅地標準価格（円/m²）
df_pref['x4_tfr']          = df_pref['A4103']                              # 合計特殊出生率
df_pref['x5_outmigrate']   = df_pref['A5102'] / df_pref['A1101'] * 100    # 転出率（%）
df_pref['x6_job_ratio']    = df_pref['F3103'] / df_pref['F3102']           # 有効求人倍率
df_pref['x7_college_rate'] = df_pref['E4602'] / df_pref['E4601'] * 100    # 高校→大学進学率（%）

# 第1部目的変数: 学校密度が中央値を超えるか（binary）
median_density = df_pref['school_per_pop'].median()
df_pref['school_dense'] = (df_pref['school_per_pop'] > median_density).astype(int)

print(f"\n学校密度（1万人対）中央値: {median_density:.4f}")
print(f"school_dense=1（高密度）: {df_pref['school_dense'].sum()}都道府県")
print(f"school_dense=0（低密度）: {(df_pref['school_dense']==0).sum()}都道府県")

# 欠損値確認
analysis_cols = ['school_dense', 'teacher_ratio',
                 'x1_aging_rate', 'x2_consumption', 'x3_land_price',
                 'x4_tfr', 'x5_outmigrate', 'x6_job_ratio', 'x7_college_rate']
df_ana = df_pref[['都道府県', 'school_per_pop'] + analysis_cols].dropna().copy()
print(f"\n欠損値除去後: {len(df_ana)}都道府県")

# ──────────────────────────────────────────────────────────────
# 変数名マッピング
# ──────────────────────────────────────────────────────────────
PRED_COLS = ['x1_aging_rate', 'x2_consumption', 'x3_land_price',
             'x4_tfr', 'x5_outmigrate', 'x6_job_ratio', 'x7_college_rate']
PRED_LABELS = {
    'x1_aging_rate':   '高齢化率（%）',
    'x2_consumption':  '消費支出（円）',
    'x3_land_price':   '住宅地標準価格（円/m²）',
    'x4_tfr':          '合計特殊出生率',
    'x5_outmigrate':   '転出率（%）',
    'x6_job_ratio':    '有効求人倍率',
    'x7_college_rate': '高校→大学進学率（%）',
}

# ──────────────────────────────────────────────────────────────
# VIF チェック（多重共線性診断）
# ──────────────────────────────────────────────────────────────
print("\n" + "=" * 65)
print("■ VIF チェック（全47都道府県、説明変数7個）")
print("=" * 65)

X_vif = sm.add_constant(df_ana[PRED_COLS].astype(float))
vif_df = pd.DataFrame({
    '変数': PRED_COLS,
    'VIF': [variance_inflation_factor(X_vif.values, i + 1)
            for i in range(len(PRED_COLS))]
})
vif_df['ラベル'] = vif_df['変数'].map(PRED_LABELS)
print(vif_df[['ラベル', 'VIF']].to_string(index=False))
print("（VIF < 5: 問題なし  5-10: 中程度  >10: 深刻な多重共線性）")

# ──────────────────────────────────────────────────────────────
# Part 1: ロジスティック回帰（全47都道府県）
# ──────────────────────────────────────────────────────────────
print("\n" + "=" * 65)
print("■ Part 1: ロジスティック回帰（ハードルモデル第1部）")
print("  目的変数: school_dense（学校密度 > 中央値 → 1）")
print("=" * 65)

formula_logit = (
    'school_dense ~ x1_aging_rate + x2_consumption + x3_land_price '
    '+ x4_tfr + x5_outmigrate + x6_job_ratio + x7_college_rate'
)

logit_model = smf.logit(formula_logit, data=df_ana).fit(maxiter=200, disp=False)
print(logit_model.summary())

# オッズ比と95%CI
or_params  = np.exp(logit_model.params)
or_ci      = np.exp(logit_model.conf_int())
or_pvals   = logit_model.pvalues

print("\n【オッズ比と95%CI】")
or_df = pd.DataFrame({
    'OR':     or_params,
    'CI_lo':  or_ci[0],
    'CI_hi':  or_ci[1],
    'p値':    or_pvals,
})
print(or_df.round(4))

# ──────────────────────────────────────────────────────────────
# Part 2: OLS 回帰（school_dense==1 の都道府県のみ）
# ──────────────────────────────────────────────────────────────
print("\n" + "=" * 65)
print("■ Part 2: OLS 回帰（ハードルモデル第2部）")
print("  対象: school_dense==1 の都道府県")
print("  目的変数: teacher_ratio（教員/児童比）")
print("=" * 65)

df_part2 = df_ana[df_ana['school_dense'] == 1].copy()
print(f"Part 2 サンプルサイズ: {len(df_part2)}都道府県")

Y2 = df_part2['teacher_ratio'].astype(float)
X2 = sm.add_constant(df_part2[PRED_COLS].astype(float))
ols_model = sm.OLS(Y2, X2).fit()
print(ols_model.summary())

# 予測値
pred_vals = ols_model.fittedvalues

print("\n【OLS係数サマリー】")
coef_df = pd.DataFrame({
    '係数':   ols_model.params,
    '標準誤差': ols_model.bse,
    'p値':    ols_model.pvalues,
})
print(coef_df.round(6))

# ──────────────────────────────────────────────────────────────
# 図1: 相関ヒートマップ
# ──────────────────────────────────────────────────────────────
print("\n図1: 相関ヒートマップを作成中...")

corr_cols  = ['school_per_pop', 'teacher_ratio'] + PRED_COLS
corr_label = {
    'school_per_pop':  '学校密度\n(万人対)',
    'teacher_ratio':   '教員/\n児童比',
    'x1_aging_rate':   '高齢化率',
    'x2_consumption':  '消費支出',
    'x3_land_price':   '住宅地\n標準価格',
    'x4_tfr':          'TFR',
    'x5_outmigrate':   '転出率',
    'x6_job_ratio':    '有効\n求人倍率',
    'x7_college_rate': '進学率',
}

df_corr = df_ana[corr_cols].astype(float).rename(columns=corr_label)
corr_mat = df_corr.corr()

fig1, ax1 = plt.subplots(figsize=(10, 8))
cmap = sns.diverging_palette(220, 20, as_cmap=True)
sns.heatmap(
    corr_mat, ax=ax1, cmap=cmap, center=0,
    vmin=-1, vmax=1,
    annot=True, fmt='.2f', annot_kws={'size': 9},
    linewidths=0.5, linecolor='white',
    square=True, cbar_kws={'shrink': 0.8}
)
ax1.set_title(
    '相関ヒートマップ（都道府県別、2022年度、N=47）\n'
    'データ出典: SSDSE-B-2026（e-Stat）',
    fontsize=12, fontweight='bold', pad=14
)
ax1.tick_params(axis='x', labelsize=9, rotation=0)
ax1.tick_params(axis='y', labelsize=9, rotation=0)
plt.tight_layout()
fig1.savefig(os.path.join(FIG_DIR, '2023_U2_fig1_corr.png'), bbox_inches='tight', dpi=150)
plt.close(fig1)
print("  → 2023_U2_fig1_corr.png 保存完了")

# ──────────────────────────────────────────────────────────────
# 図2: ロジスティック回帰 オッズ比棒グラフ（第1部）
# ──────────────────────────────────────────────────────────────
print("図2: ロジスティック回帰 オッズ比棒グラフを作成中...")

# 定数項を除く
or_plot = or_df.drop('Intercept', errors='ignore').copy()
or_plot['label'] = [PRED_LABELS.get(v, v) for v in or_plot.index]
or_plot = or_plot.reset_index(drop=False)

fig2, ax2 = plt.subplots(figsize=(10, 6))
y_pos = np.arange(len(or_plot))

bar_colors2 = []
for p in or_plot['p値']:
    if p < 0.01:   bar_colors2.append(COLORS['danger'])
    elif p < 0.05: bar_colors2.append(COLORS['secondary'])
    elif p < 0.1:  bar_colors2.append(COLORS['success'])
    else:          bar_colors2.append(COLORS['gray'])

bars2 = ax2.barh(y_pos, or_plot['OR'], color=bar_colors2, alpha=0.85,
                 edgecolor='white', linewidth=0.8, height=0.55)
ax2.errorbar(
    or_plot['OR'], y_pos,
    xerr=[or_plot['OR'] - or_plot['CI_lo'], or_plot['CI_hi'] - or_plot['OR']],
    fmt='none', color='black', capsize=4, linewidth=1.5, capthick=1.5
)
ax2.axvline(1.0, color='black', linewidth=1.2, linestyle='--', alpha=0.7, label='OR = 1（基準線）')

ax2.set_yticks(y_pos)
ax2.set_yticklabels(or_plot['label'], fontsize=10)
ax2.set_xlabel('オッズ比（OR）± 95%CI', fontsize=11)
ax2.set_title(
    'Part 1: ロジスティック回帰 — オッズ比（OR）\n'
    '目的変数: 学校密度（1万人対）> 中央値（school_dense = 1）\n'
    'データ出典: SSDSE-B-2026（e-Stat）',
    fontsize=11, fontweight='bold'
)

legend_patches = [
    plt.Rectangle((0, 0), 1, 1, color=COLORS['danger'],    label='p < 0.01 ***'),
    plt.Rectangle((0, 0), 1, 1, color=COLORS['secondary'], label='p < 0.05 **'),
    plt.Rectangle((0, 0), 1, 1, color=COLORS['success'],   label='p < 0.10 *'),
    plt.Rectangle((0, 0), 1, 1, color=COLORS['gray'],       label='n.s.'),
]
ax2.legend(handles=legend_patches + [
    plt.Line2D([0], [0], color='black', linestyle='--', label='OR = 1（基準線）')
], loc='lower right', fontsize=9)

for i, (_, row) in enumerate(or_plot.iterrows()):
    sig = '***' if row['p値'] < 0.01 else '**' if row['p値'] < 0.05 else '*' if row['p値'] < 0.1 else ''
    ax2.text(
        max(row['CI_hi'], row['OR']) + 0.02, i,
        f"{row['OR']:.3f}{sig}",
        va='center', fontsize=9
    )

ax2.grid(axis='x', alpha=0.3)
ax2.invert_yaxis()
plt.tight_layout()
fig2.savefig(os.path.join(FIG_DIR, '2023_U2_fig2_logit_or.png'), bbox_inches='tight', dpi=150)
plt.close(fig2)
print("  → 2023_U2_fig2_logit_or.png 保存完了")

# ──────────────────────────────────────────────────────────────
# 図3: OLS 係数棒グラフ（第2部）
# ──────────────────────────────────────────────────────────────
print("図3: OLS係数棒グラフを作成中...")

coef_plot = coef_df.drop('const', errors='ignore').copy()
coef_plot['label'] = [PRED_LABELS.get(v, v) for v in coef_plot.index]
coef_plot = coef_plot.reset_index(drop=False)

# 95%CI
ci_ols = ols_model.conf_int()
ci_ols.columns = ['lo', 'hi']
ci_ols = ci_ols.drop('const', errors='ignore')
coef_plot['ci_lo'] = ci_ols['lo'].values
coef_plot['ci_hi'] = ci_ols['hi'].values

fig3, ax3 = plt.subplots(figsize=(10, 6))
y_pos3 = np.arange(len(coef_plot))

bar_colors3 = []
for p in coef_plot['p値']:
    if p < 0.01:   bar_colors3.append(COLORS['danger'])
    elif p < 0.05: bar_colors3.append(COLORS['secondary'])
    elif p < 0.1:  bar_colors3.append(COLORS['success'])
    else:          bar_colors3.append(COLORS['gray'])

ax3.barh(y_pos3, coef_plot['係数'], color=bar_colors3, alpha=0.85,
         edgecolor='white', linewidth=0.8, height=0.55)
ax3.errorbar(
    coef_plot['係数'], y_pos3,
    xerr=[coef_plot['係数'] - coef_plot['ci_lo'], coef_plot['ci_hi'] - coef_plot['係数']],
    fmt='none', color='black', capsize=4, linewidth=1.5, capthick=1.5
)
ax3.axvline(0, color='black', linewidth=1.2, linestyle='--', alpha=0.7)

ax3.set_yticks(y_pos3)
ax3.set_yticklabels(coef_plot['label'], fontsize=10)
ax3.set_xlabel('OLS 回帰係数 ± 95%CI', fontsize=11)
ax3.set_title(
    f'Part 2: OLS 回帰係数（school_dense=1 の都道府県、N={len(df_part2)}）\n'
    '目的変数: 教員/児童比（E2401/E2501）\n'
    'データ出典: SSDSE-B-2026（e-Stat）',
    fontsize=11, fontweight='bold'
)

legend_patches3 = [
    plt.Rectangle((0, 0), 1, 1, color=COLORS['danger'],    label='p < 0.01 ***'),
    plt.Rectangle((0, 0), 1, 1, color=COLORS['secondary'], label='p < 0.05 **'),
    plt.Rectangle((0, 0), 1, 1, color=COLORS['success'],   label='p < 0.10 *'),
    plt.Rectangle((0, 0), 1, 1, color=COLORS['gray'],       label='n.s.'),
]
ax3.legend(handles=legend_patches3, loc='lower right', fontsize=9)

for i, (_, row) in enumerate(coef_plot.iterrows()):
    sig = '***' if row['p値'] < 0.01 else '**' if row['p値'] < 0.05 else '*' if row['p値'] < 0.1 else ''
    offset = row['ci_hi'] if row['係数'] >= 0 else row['ci_lo']
    ax3.text(
        offset, i,
        f"  {row['係数']:.5f}{sig}",
        va='center', fontsize=8
    )

ax3.grid(axis='x', alpha=0.3)
ax3.invert_yaxis()
plt.tight_layout()
fig3.savefig(os.path.join(FIG_DIR, '2023_U2_fig3_ols_coef.png'), bbox_inches='tight', dpi=150)
plt.close(fig3)
print("  → 2023_U2_fig3_ols_coef.png 保存完了")

# ──────────────────────────────────────────────────────────────
# 図4: 実測値 vs 予測値 散布図（Part 2 OLS）
# ──────────────────────────────────────────────────────────────
print("図4: 実測値 vs 予測値散布図を作成中...")

actual_vals = Y2.values
prefs_part2 = df_part2['都道府県'].values

fig4, ax4 = plt.subplots(figsize=(8, 7))

ax4.scatter(actual_vals, pred_vals.values,
            color=COLORS['primary'], alpha=0.75, s=70,
            edgecolors='white', linewidth=0.8, zorder=3)

# 対角線（完全一致線）
all_vals = np.concatenate([actual_vals, pred_vals.values])
v_min, v_max = all_vals.min(), all_vals.max()
margin = (v_max - v_min) * 0.05
ax4.plot([v_min - margin, v_max + margin], [v_min - margin, v_max + margin],
         color='black', linewidth=1.2, linestyle='--', alpha=0.7, zorder=2, label='完全一致線（y=x）')

# 都道府県名ラベル
for pref, act, pred in zip(prefs_part2, actual_vals, pred_vals.values):
    ax4.annotate(
        pref, (act, pred),
        fontsize=6.5, alpha=0.8,
        xytext=(3, 3), textcoords='offset points'
    )

r2_val   = ols_model.rsquared
rmse_val = np.sqrt(np.mean((actual_vals - pred_vals.values) ** 2))
ax4.set_xlabel('実測値（教員/児童比）', fontsize=11)
ax4.set_ylabel('予測値（OLS、教員/児童比）', fontsize=11)
ax4.set_title(
    f'Part 2 OLS: 実測値 vs 予測値（school_dense=1、N={len(df_part2)}）\n'
    f'R² = {r2_val:.3f}  RMSE = {rmse_val:.6f}\n'
    'データ出典: SSDSE-B-2026（e-Stat）',
    fontsize=11, fontweight='bold'
)
ax4.set_xlim(v_min - margin, v_max + margin)
ax4.set_ylim(v_min - margin, v_max + margin)
ax4.legend(fontsize=9)
ax4.grid(True, alpha=0.3)
plt.tight_layout()
fig4.savefig(os.path.join(FIG_DIR, '2023_U2_fig4_actual_pred.png'), bbox_inches='tight', dpi=150)
plt.close(fig4)
print("  → 2023_U2_fig4_actual_pred.png 保存完了")

# ──────────────────────────────────────────────────────────────
# 最終サマリー
# ──────────────────────────────────────────────────────────────
print("\n" + "=" * 65)
print("✓ 全図の生成完了")
print("=" * 65)

print("\n【ハードルモデル 結果サマリー】")
print(f"\n[VIF]")
for _, row in vif_df.iterrows():
    flag = " ← 注意（多重共線性）" if row['VIF'] > 5 else ""
    print(f"  {row['ラベル']:20s}: VIF = {row['VIF']:.2f}{flag}")

print(f"\n[Part 1: ロジスティック回帰]  N=47")
print(f"  疑似R²（McFadden）: {logit_model.prsquared:.3f}")
print(f"  AIC: {logit_model.aic:.2f}")
print("  オッズ比（有意な変数のみ）:")
for var in or_plot['index']:
    row_or = or_plot[or_plot['index'] == var].iloc[0]
    if row_or['p値'] < 0.1:
        sig = '***' if row_or['p値'] < 0.01 else '**' if row_or['p値'] < 0.05 else '*'
        print(f"    {PRED_LABELS.get(var, var):20s}: OR={row_or['OR']:.3f} {sig}  (p={row_or['p値']:.4f})")

print(f"\n[Part 2: OLS 回帰]  N={len(df_part2)}（school_dense=1 の都道府県）")
print(f"  R²  : {ols_model.rsquared:.3f}")
print(f"  調整済みR²: {ols_model.rsquared_adj:.3f}")
print(f"  RMSE: {rmse_val:.6f}")
print(f"  AIC : {ols_model.aic:.2f}")
print("  係数（有意な変数のみ）:")
for var in coef_plot['index']:
    row_c = coef_plot[coef_plot['index'] == var].iloc[0]
    if row_c['p値'] < 0.1:
        sig = '***' if row_c['p値'] < 0.01 else '**' if row_c['p値'] < 0.05 else '*'
        print(f"    {PRED_LABELS.get(var, var):20s}: β={row_c['係数']:.6f} {sig}  (p={row_c['p値']:.4f})")
