"""
2023年度 統計データ分析コンペティション 総務大臣賞（大学生の部）
論文タイトル: 小中学生の不登校率における環境要因分析

目的変数の代理指標：
  高校→大学進学率 = 高校卒業者のうち進学者数 / 高校卒業者数 (E4602/E4601)
  (高い進学率 ≒ 教育成果への積極的関与、低い = 不関与の代理)

説明変数 (SSDSE-B 2022年度 都道府県データ):
  poverty    : 消費支出（L3221）  ← 高いほど豊か
  tfr        : 合計特殊出生率（A4103）← 低い=従来型家族少
  female_work: 15-64歳女性比率（A130202/A110102）
  pop_density: 総人口/住宅地価格（A1101/C5401）← 都市化代理
  school_res : 小学校教員数/児童数（E2401/E2501）← 高い=資源豊富
  land_price : 住宅地標準価格（C5401）
  temperature: 年平均気温（B4101）
  aging      : 高齢化率（A1303/A1101）

分析手法: 重回帰分析(OLS), VIF(多重共線性), 標準化係数
データ: SSDSE-B-2026.csv (実データ, 合成データ一切使用しない)
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2023_U1_daijin.py
# ============================================================


import os
import warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from scipy import stats

warnings.filterwarnings('ignore')

# ── パス設定 ─────────────────────────────────────────────
_dir    = os.path.dirname(os.path.abspath(__file__)) if '__file__' in dir() else os.getcwd()
FIG_DIR = os.path.join(_dir, '..', 'html', 'figures')
DATA_B  = os.path.join(_dir, '..', 'data', 'raw', 'SSDSE-B-2026.csv')

os.makedirs(FIG_DIR, exist_ok=True)

# ── フォント設定 ─────────────────────────────────────────
plt.rcParams.update({
    'font.family'     : ['Hiragino Sans', 'Hiragino Kaku Gothic ProN',
                          'AppleGothic', 'sans-serif'],
    'axes.unicode_minus': False,
    'figure.dpi'      : 150,
})
DPI = 150

# ── データ読み込み ──────────────────────────────────────
print("データを読み込み中...")
df_raw = pd.read_csv(DATA_B, header=1, encoding='cp932')

# 都道府県レベル (R\d{5}$) に絞り、2022年度を使用
mask = df_raw['地域コード'].str.match(r'^R\d{5}$', na=False)
df_raw = df_raw[mask].copy()
df_2022 = df_raw[df_raw['年度'] == 2022].copy()
print(f"2022年度 都道府県データ: {len(df_2022)}件")

# ── 変数の構築 ─────────────────────────────────────────
df = df_2022[['都道府県']].copy()

# 目的変数: 高校→大学等進学率
df['univ_rate'] = (
    df_2022['高等学校卒業者のうち進学者数'] / df_2022['高等学校卒業者数'] * 100
)

# 説明変数
df['poverty']    = df_2022['消費支出（二人以上の世帯）']               # L3221: 高いほど豊か
df['tfr']        = df_2022['合計特殊出生率']                           # A4103
df['female_work'] = (
    df_2022['15～64歳人口（女）'] / df_2022['15～64歳人口（女）'].replace(0, np.nan) * 100
    # 実際は 女性就業率ではなくデータに就業率がないので 15-64歳女性/全女性 比を使用
)
df['female_work'] = (
    df_2022['15～64歳人口（女）'] / df_2022['総人口（女）'] * 100       # A130202 / A110102 相当
)
df['pop_density_proxy'] = (
    df_2022['総人口'] / df_2022['標準価格（平均価格）（住宅地）'].replace(0, np.nan)
)                                                                        # A1101 / C5401
df['school_res'] = (
    df_2022['小学校教員数'] / df_2022['小学校児童数'].replace(0, np.nan) * 100
)                                                                        # E2401 / E2501
df['land_price'] = df_2022['標準価格（平均価格）（住宅地）']             # C5401
df['temperature'] = df_2022['年平均気温']                               # B4101
df['aging']      = (
    df_2022['65歳以上人口'] / df_2022['総人口'] * 100                   # A1303 / A1101
)

df = df.dropna()
print(f"欠損除外後: {len(df)}件 (都道府県)")

# 変数リスト
PRED_COLS = [
    'poverty', 'tfr', 'female_work', 'pop_density_proxy',
    'school_res', 'land_price', 'temperature', 'aging'
]
PRED_LABELS = {
    'poverty'          : '消費支出\n(千円/世帯)',
    'tfr'              : '合計特殊\n出生率',
    'female_work'      : '15-64歳\n女性比率(%)',
    'pop_density_proxy': '人口/地価\n(都市化代理)',
    'school_res'       : '小学校教員\n/児童比(%)',
    'land_price'       : '住宅地価格\n(円/㎡)',
    'temperature'      : '年平均\n気温(℃)',
    'aging'            : '高齢化率\n(%)',
}
Y_LABEL = '高校→大学進学率(%)'

X = df[PRED_COLS]
y = df['univ_rate']

print(f"\n目的変数 '{Y_LABEL}' の基本統計:")
print(y.describe().round(2))

# ── VIF 計算 ───────────────────────────────────────────
X_vif = sm.add_constant(X)
vif_vals = []
for i, col in enumerate(X_vif.columns):
    if col == 'const':
        continue
    vif_vals.append({
        'variable': col,
        'VIF'     : variance_inflation_factor(X_vif.values, i)
    })
vif_df = pd.DataFrame(vif_vals)
print("\nVIF:\n", vif_df.round(2).to_string(index=False))

# ── OLS 回帰 ───────────────────────────────────────────
X_ols = sm.add_constant(X)
model = sm.OLS(y, X_ols).fit()
print("\n" + "="*60)
print("OLS 回帰サマリー")
print("="*60)
print(model.summary())

# 標準化係数
X_std = (X - X.mean()) / X.std()
y_std = (y - y.mean()) / y.std()
X_std_ols = sm.add_constant(X_std)
model_std = sm.OLS(y_std, X_std_ols).fit()
coef_std = model_std.params.drop('const')
pvals    = model.pvalues.drop('const')

# ════════════════════════════════════════════════════════════
# 図1: 相関ヒートマップ
# ════════════════════════════════════════════════════════════
fig1_path = os.path.join(FIG_DIR, '2023_U1_fig1_heatmap.png')

plot_df = df[PRED_COLS + ['univ_rate']].copy()
corr = plot_df.corr()

short_labels = [PRED_LABELS.get(c, c) for c in PRED_COLS] + [Y_LABEL]

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(corr.values, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

n = len(corr)
ax.set_xticks(range(n))
ax.set_yticks(range(n))
ax.set_xticklabels(short_labels, fontsize=8, rotation=45, ha='right')
ax.set_yticklabels(short_labels, fontsize=8)

for i in range(n):
    for j in range(n):
        val = corr.values[i, j]
        color = 'white' if abs(val) > 0.6 else 'black'
        ax.text(j, i, f'{val:.2f}', ha='center', va='center', fontsize=8, color=color)

ax.set_title('図1: 変数間の相関ヒートマップ\n（目的変数: 高校→大学進学率）', fontsize=13, pad=14)
fig.tight_layout()
fig.savefig(fig1_path, dpi=DPI, bbox_inches='tight')
plt.close(fig)
print(f"\n保存: {fig1_path}")

# ════════════════════════════════════════════════════════════
# 図2: VIF 棒グラフ
# ════════════════════════════════════════════════════════════
fig2_path = os.path.join(FIG_DIR, '2023_U1_fig2_vif.png')

fig, ax = plt.subplots(figsize=(9, 5))
vif_labels = [PRED_LABELS.get(r['variable'], r['variable']) for _, r in vif_df.iterrows()]
vif_colors = ['#C62828' if v >= 10 else '#E65100' if v >= 5 else '#1565C0'
              for v in vif_df['VIF']]
bars = ax.bar(range(len(vif_df)), vif_df['VIF'], color=vif_colors, edgecolor='white', width=0.6)

ax.axhline(5,  color='#E65100', linestyle='--', lw=1.5, label='VIF = 5 (要注意)')
ax.axhline(10, color='#C62828', linestyle='--', lw=1.5, label='VIF = 10 (問題あり)')

for bar, v in zip(bars, vif_df['VIF']):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.15,
            f'{v:.1f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

ax.set_xticks(range(len(vif_df)))
ax.set_xticklabels(vif_labels, fontsize=9)
ax.set_ylabel('VIF 値', fontsize=11)
ax.set_title('図2: 分散拡大係数（VIF）— 多重共線性の診断', fontsize=13, pad=10)
ax.legend(fontsize=10)
ax.set_ylim(0, max(vif_df['VIF'].max() * 1.2, 12))
ax.yaxis.set_major_locator(mticker.MultipleLocator(2))
fig.tight_layout()
fig.savefig(fig2_path, dpi=DPI, bbox_inches='tight')
plt.close(fig)
print(f"保存: {fig2_path}")

# ════════════════════════════════════════════════════════════
# 図3: 標準化回帰係数
# ════════════════════════════════════════════════════════════
fig3_path = os.path.join(FIG_DIR, '2023_U1_fig3_coef.png')

coef_plot = pd.DataFrame({
    'var'  : PRED_COLS,
    'coef' : coef_std.values,
    'pval' : pvals.values,
})
coef_plot = coef_plot.sort_values('coef', ascending=True).reset_index(drop=True)

fig, ax = plt.subplots(figsize=(9, 6))
colors = []
for _, row in coef_plot.iterrows():
    if row['pval'] < 0.05:
        colors.append('#1565C0' if row['coef'] > 0 else '#C62828')
    else:
        colors.append('#BDBDBD')

bars = ax.barh(range(len(coef_plot)), coef_plot['coef'], color=colors, edgecolor='white', height=0.6)
ax.axvline(0, color='black', lw=0.8)

labels = [PRED_LABELS.get(v, v) for v in coef_plot['var']]
ax.set_yticks(range(len(coef_plot)))
ax.set_yticklabels(labels, fontsize=9)

for i, (_, row) in enumerate(coef_plot.iterrows()):
    sig = '**' if row['pval'] < 0.01 else '*' if row['pval'] < 0.05 else ''
    x_pos = row['coef'] + (0.015 if row['coef'] >= 0 else -0.015)
    ha = 'left' if row['coef'] >= 0 else 'right'
    ax.text(x_pos, i, f"{row['coef']:.3f}{sig}", va='center', ha=ha, fontsize=9)

ax.set_xlabel('標準化回帰係数 (β)', fontsize=11)
ax.set_title('図3: 標準化OLS回帰係数\n（青: 正・有意, 赤: 負・有意, 灰: 非有意  *p<0.05  **p<0.01）',
             fontsize=12, pad=10)
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor='#1565C0', label='正の効果（p<0.05）'),
    Patch(facecolor='#C62828', label='負の効果（p<0.05）'),
    Patch(facecolor='#BDBDBD', label='非有意（p≥0.05）'),
]
ax.legend(handles=legend_elements, fontsize=9, loc='lower right')
fig.tight_layout()
fig.savefig(fig3_path, dpi=DPI, bbox_inches='tight')
plt.close(fig)
print(f"保存: {fig3_path}")

# ════════════════════════════════════════════════════════════
# 図4: 散布図（消費支出 vs 大学進学率）+ 回帰直線
# ════════════════════════════════════════════════════════════
fig4_path = os.path.join(FIG_DIR, '2023_U1_fig4_scatter.png')

x_sc = df['poverty'].values
y_sc = df['univ_rate'].values
pref = df['都道府県'].values

slope, intercept, r_val, p_val, _ = stats.linregress(x_sc, y_sc)
x_line = np.linspace(x_sc.min(), x_sc.max(), 200)
y_line = intercept + slope * x_line

fig, ax = plt.subplots(figsize=(9, 6))
sc = ax.scatter(x_sc, y_sc, c='#1565C0', alpha=0.75, s=70, edgecolors='white', lw=0.5, zorder=3)
ax.plot(x_line, y_line, color='#E65100', lw=2, label=f'回帰直線  r={r_val:.3f}, p={p_val:.3f}')

# 注目点にラベル
highlight = ['東京都', '沖縄県', '大阪府', '愛知県', '北海道', '京都府', '鹿児島県']
for i, pref_name in enumerate(pref):
    if pref_name in highlight:
        ax.annotate(pref_name, (x_sc[i], y_sc[i]),
                    textcoords='offset points', xytext=(5, 4),
                    fontsize=8, color='#333')

ax.set_xlabel('消費支出（円/世帯）', fontsize=11)
ax.set_ylabel('高校→大学等進学率（%）', fontsize=11)
ax.set_title('図4: 消費支出と大学等進学率の関係\n（47都道府県 2022年度）', fontsize=13, pad=10)
ax.legend(fontsize=10)
ax.xaxis.set_major_formatter(mticker.FuncFormatter(lambda v, _: f'{v/10000:.0f}万'))
fig.tight_layout()
fig.savefig(fig4_path, dpi=DPI, bbox_inches='tight')
plt.close(fig)
print(f"保存: {fig4_path}")

print("\n全図表の生成が完了しました。")
print(f"  fig1 (ヒートマップ)    : {os.path.basename(fig1_path)}")
print(f"  fig2 (VIF棒グラフ)     : {os.path.basename(fig2_path)}")
print(f"  fig3 (標準化係数プロット): {os.path.basename(fig3_path)}")
print(f"  fig4 (散布図)          : {os.path.basename(fig4_path)}")
