"""
2019_H3_suri.py
都道府県別出生率の予測モデル構築：Ridge回帰・Lasso回帰による変数選択と正則化
統計数理賞（高校生部門） 2019年度

SSDSE-B-2026.csv の実データのみ使用。合成データ禁止。
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2019_H3_suri.py
# ============================================================


import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy import stats
from sklearn.linear_model import Ridge, Lasso, RidgeCV, LassoCV
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

# ============================================================
# データ読み込み
# ============================================================
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
df_b = df_b[df_b['地域コード'].str.match(r'^R\d{5}', na=False)].copy()
df_b['年度'] = df_b['年度'].astype(int)

print("全カラム:", df_b.columns.tolist())
print("利用可能年度:", sorted(df_b['年度'].unique()))

# ============================================================
# 断面データ抽出（最新年度 2023年）
# ============================================================
YEAR = 2023
df = df_b[df_b['年度'] == YEAR].copy().reset_index(drop=True)
print(f"\n使用年度: {YEAR}年, 都道府県数: {len(df)}")

# ============================================================
# 変数定義
# ============================================================
TARGET = '合計特殊出生率'

# 人口ベースの比率変数を計算
df['高齢化率'] = df['65歳以上人口'] / df['総人口'] * 100
df['年少人口率'] = df['15歳未満人口'] / df['総人口'] * 100
df['婚姻率'] = df['婚姻件数'] / df['総人口'] * 1000
df['保育所定員率'] = df['保育所等定員数'] / df['総人口'] * 10000
df['保育待機児童率'] = df['保育所等利用待機児童数'] / df['保育所等在所児数'] * 100
df['人口密度代理'] = df['総人口'] / 10000  # 万人単位（都市化の代理）

# 説明変数リスト（8変数）
FEAT_COLS = [
    '婚姻率',           # 婚姻率（人口千人当たり）
    '保育所定員率',      # 保育所定員数（人口万人当たり）
    '保育待機児童率',    # 保育待機児童率（在所児対比%）
    '高齢化率',          # 65歳以上人口比率
    '年少人口率',        # 15歳未満人口比率
    '人口密度代理',      # 総人口（万人）：都市化・人口規模の代理
    '消費支出（二人以上の世帯）',   # 消費支出
    '保健医療費（二人以上の世帯）', # 保健医療費
]

FEAT_LABELS = {
    '婚姻率':           '婚姻率\n(‰)',
    '保育所定員率':      '保育所定員率\n(万人対)',
    '保育待機児童率':    '保育待機\n児童率(%)',
    '高齢化率':          '高齢化率\n(%)',
    '年少人口率':        '年少人口率\n(%)',
    '人口密度代理':      '総人口\n(万人)',
    '消費支出（二人以上の世帯）':   '消費支出\n(円)',
    '保健医療費（二人以上の世帯）': '保健医療費\n(円)',
}

# 欠損値処理
cols_needed = [TARGET] + FEAT_COLS + ['都道府県']
df_clean = df[cols_needed].dropna().copy()
print(f"\n欠損除外後: {len(df_clean)} 都道府県")

X = df_clean[FEAT_COLS].values
y = df_clean[TARGET].values.astype(float)
prefs = df_clean['都道府県'].values

print(f"\n目的変数（合計特殊出生率）の統計:")
print(f"  平均={y.mean():.3f}, 標準偏差={y.std():.3f}, 最小={y.min():.2f}, 最大={y.max():.2f}")

# ============================================================
# 標準化
# ============================================================
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# ============================================================
# OLS回帰（比較用）
# ============================================================
X_ols = sm.add_constant(X_scaled)
ols_model = sm.OLS(y, X_ols).fit()
print("\n=== OLS結果 ===")
print(ols_model.summary())

# ============================================================
# Ridge回帰（交差検証でλ最適化）
# ============================================================
ridge_cv = RidgeCV(alphas=np.logspace(-3, 3, 50), cv=5)
ridge_cv.fit(X_scaled, y)
best_alpha_ridge = ridge_cv.alpha_
print(f"\n=== Ridge最適λ = {best_alpha_ridge:.4f} ===")

ridge_best = Ridge(alpha=best_alpha_ridge)
ridge_best.fit(X_scaled, y)
ridge_coefs = ridge_best.coef_
ridge_cv_score = cross_val_score(ridge_best, X_scaled, y, cv=5, scoring='neg_mean_squared_error')
print(f"Ridge CV MSE: {-ridge_cv_score.mean():.6f} ± {ridge_cv_score.std():.6f}")

# ============================================================
# Lasso回帰（交差検証でλ最適化）
# ============================================================
lasso_cv = LassoCV(alphas=np.logspace(-3, 1, 50), cv=5, max_iter=10000, random_state=42)
lasso_cv.fit(X_scaled, y)
best_alpha_lasso = lasso_cv.alpha_
print(f"\n=== Lasso最適λ = {best_alpha_lasso:.4f} ===")

lasso_best = Lasso(alpha=best_alpha_lasso, max_iter=10000)
lasso_best.fit(X_scaled, y)
lasso_coefs = lasso_best.coef_
lasso_cv_score = cross_val_score(lasso_best, X_scaled, y, cv=5, scoring='neg_mean_squared_error')
print(f"Lasso CV MSE: {-lasso_cv_score.mean():.6f} ± {lasso_cv_score.std():.6f}")

# Lassoが選択した変数
print("\n=== Lassoが選択した変数（非ゼロ係数）===")
for i, (col, coef) in enumerate(zip(FEAT_COLS, lasso_coefs)):
    if abs(coef) > 1e-6:
        print(f"  {col}: {coef:.4f}")
    else:
        print(f"  {col}: 0 (除外)")

# ============================================================
# OLS係数（標準化後）
# ============================================================
ols_coefs = ols_model.params[1:]  # 定数項を除く

print("\n=== 係数比較（標準化後）===")
print(f"{'変数':<30} {'OLS':>8} {'Ridge':>8} {'Lasso':>8}")
for col, oc, rc, lc in zip(FEAT_COLS, ols_coefs, ridge_coefs, lasso_coefs):
    print(f"{col:<30} {oc:>8.4f} {rc:>8.4f} {lc:>8.4f}")

# ============================================================
# 地域区分の定義（Figure 1用）
# ============================================================
region_map = {
    '北海道': '北海道・東北', '青森県': '北海道・東北', '岩手県': '北海道・東北',
    '宮城県': '北海道・東北', '秋田県': '北海道・東北', '山形県': '北海道・東北',
    '福島県': '北海道・東北',
    '茨城県': '関東', '栃木県': '関東', '群馬県': '関東', '埼玉県': '関東',
    '千葉県': '関東', '東京都': '関東', '神奈川県': '関東',
    '新潟県': '中部', '富山県': '中部', '石川県': '中部', '福井県': '中部',
    '山梨県': '中部', '長野県': '中部', '岐阜県': '中部', '静岡県': '中部',
    '愛知県': '中部',
    '三重県': '近畿', '滋賀県': '近畿', '京都府': '近畿', '大阪府': '近畿',
    '兵庫県': '近畿', '奈良県': '近畿', '和歌山県': '近畿',
    '鳥取県': '中国・四国', '島根県': '中国・四国', '岡山県': '中国・四国',
    '広島県': '中国・四国', '山口県': '中国・四国',
    '徳島県': '中国・四国', '香川県': '中国・四国', '愛媛県': '中国・四国',
    '高知県': '中国・四国',
    '福岡県': '九州・沖縄', '佐賀県': '九州・沖縄', '長崎県': '九州・沖縄',
    '熊本県': '九州・沖縄', '大分県': '九州・沖縄', '宮崎県': '九州・沖縄',
    '鹿児島県': '九州・沖縄', '沖縄県': '九州・沖縄',
}
region_colors = {
    '北海道・東北': '#4472C4',
    '関東':         '#ED7D31',
    '中部':         '#A9D18E',
    '近畿':         '#FFC000',
    '中国・四国':   '#9DC3E6',
    '九州・沖縄':   '#FF7F7F',
}

df_clean['地域'] = df_clean['都道府県'].map(region_map)
df_clean_sorted = df_clean.sort_values(TARGET, ascending=True).reset_index(drop=True)
national_avg = y.mean()

# ============================================================
# Figure 1: 合計特殊出生率 都道府県ランキング（横棒グラフ）
# ============================================================
fig1, ax1 = plt.subplots(figsize=(10, 12))

bar_colors = [region_colors.get(r, '#999') for r in df_clean_sorted['地域']]
bars = ax1.barh(
    range(len(df_clean_sorted)),
    df_clean_sorted[TARGET],
    color=bar_colors,
    edgecolor='white',
    linewidth=0.5,
    height=0.75
)

ax1.axvline(national_avg, color='darkred', linestyle='--', linewidth=1.5,
            label=f'全国平均: {national_avg:.2f}')

ax1.set_yticks(range(len(df_clean_sorted)))
ax1.set_yticklabels(df_clean_sorted['都道府県'], fontsize=9)
ax1.set_xlabel('合計特殊出生率', fontsize=12)
ax1.set_title(f'合計特殊出生率 都道府県ランキング（{YEAR}年）', fontsize=14, fontweight='bold')
ax1.set_xlim(0.8, 2.2)

# 凡例
from matplotlib.patches import Patch
legend_patches = [Patch(color=c, label=r) for r, c in region_colors.items()]
legend_patches.append(plt.Line2D([0], [0], color='darkred', linestyle='--', linewidth=1.5,
                                  label=f'全国平均: {national_avg:.2f}'))
ax1.legend(handles=legend_patches, loc='lower right', fontsize=9, framealpha=0.9)

ax1.grid(axis='x', alpha=0.3, linestyle=':')
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)

plt.tight_layout()
fig1_path = os.path.join(FIG_DIR, '2019_H3_fig1.png')
plt.savefig(fig1_path, bbox_inches='tight')
plt.close()
print(f"\n[保存] {fig1_path}")

# ============================================================
# Figure 2: Lasso 係数パス図
# ============================================================
alphas = np.logspace(-3, 1, 100)
coefs = []
for a in alphas:
    lasso = Lasso(alpha=a, max_iter=10000)
    lasso.fit(X_scaled, y)
    coefs.append(lasso.coef_)
coefs = np.array(coefs)

fig2, ax2 = plt.subplots(figsize=(10, 6))

colors_path = plt.cm.tab10(np.linspace(0, 1, len(FEAT_COLS)))
for i, (col, color) in enumerate(zip(FEAT_COLS, colors_path)):
    ax2.semilogx(alphas, coefs[:, i], color=color,
                 label=FEAT_LABELS.get(col, col), linewidth=2)

ax2.axvline(best_alpha_lasso, color='red', linestyle='--', linewidth=2,
            label=f'最適λ = {best_alpha_lasso:.4f}')
ax2.axhline(0, color='black', linewidth=0.5, linestyle='-')

ax2.set_xlabel('正則化パラメータ λ（対数スケール）', fontsize=12)
ax2.set_ylabel('標準化係数', fontsize=12)
ax2.set_title('Lasso 係数パス図：λ増大に伴う係数の収縮', fontsize=14, fontweight='bold')
ax2.legend(loc='upper right', fontsize=9, ncol=2, framealpha=0.9)
ax2.grid(True, alpha=0.3, linestyle=':')
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)

plt.tight_layout()
fig2_path = os.path.join(FIG_DIR, '2019_H3_fig2.png')
plt.savefig(fig2_path, bbox_inches='tight')
plt.close()
print(f"[保存] {fig2_path}")

# ============================================================
# Figure 3: Ridge vs Lasso vs OLS 係数比較（グループ棒グラフ）
# ============================================================
short_labels = [FEAT_LABELS.get(c, c) for c in FEAT_COLS]
n_vars = len(FEAT_COLS)
x = np.arange(n_vars)
width = 0.25

fig3, ax3 = plt.subplots(figsize=(13, 6))

bars_ols   = ax3.bar(x - width, ols_coefs, width, label='OLS',   color='#4472C4', alpha=0.85)
bars_ridge = ax3.bar(x,         ridge_coefs, width, label='Ridge', color='#ED7D31', alpha=0.85)
bars_lasso = ax3.bar(x + width, lasso_coefs, width, label='Lasso', color='#70AD47', alpha=0.85)

ax3.axhline(0, color='black', linewidth=0.8)
ax3.set_xticks(x)
ax3.set_xticklabels(short_labels, fontsize=9, rotation=0)
ax3.set_ylabel('標準化係数', fontsize=12)
ax3.set_title('OLS・Ridge・Lasso の係数比較（標準化後）', fontsize=14, fontweight='bold')
ax3.legend(fontsize=11, framealpha=0.9)
ax3.grid(axis='y', alpha=0.3, linestyle=':')
ax3.spines['top'].set_visible(False)
ax3.spines['right'].set_visible(False)

# 最適λをテキストで注釈
ax3.text(0.02, 0.98, f'Ridge 最適λ={best_alpha_ridge:.3f}\nLasso 最適λ={best_alpha_lasso:.4f}',
         transform=ax3.transAxes, fontsize=10, verticalalignment='top',
         bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

plt.tight_layout()
fig3_path = os.path.join(FIG_DIR, '2019_H3_fig3.png')
plt.savefig(fig3_path, bbox_inches='tight')
plt.close()
print(f"[保存] {fig3_path}")

# ============================================================
# Figure 4: 交差検証スコア（MSE）vs λ（両対数スケール）
# ============================================================
alphas_ridge_grid = np.logspace(-3, 3, 60)
alphas_lasso_grid = np.logspace(-3, 1, 60)

ridge_cv_mses = []
for a in alphas_ridge_grid:
    r = Ridge(alpha=a)
    scores = cross_val_score(r, X_scaled, y, cv=5, scoring='neg_mean_squared_error')
    ridge_cv_mses.append(-scores.mean())

lasso_cv_mses = []
for a in alphas_lasso_grid:
    l = Lasso(alpha=a, max_iter=10000)
    scores = cross_val_score(l, X_scaled, y, cv=5, scoring='neg_mean_squared_error')
    lasso_cv_mses.append(-scores.mean())

fig4, ax4 = plt.subplots(figsize=(10, 6))

ax4.loglog(alphas_ridge_grid, ridge_cv_mses, color='#ED7D31', linewidth=2.5,
           label='Ridge CV-MSE')
ax4.loglog(alphas_lasso_grid, lasso_cv_mses, color='#70AD47', linewidth=2.5,
           label='Lasso CV-MSE')

# 最適λの縦線
ax4.axvline(best_alpha_ridge, color='#ED7D31', linestyle='--', linewidth=1.8,
            label=f'Ridge 最適λ = {best_alpha_ridge:.3f}')
ax4.axvline(best_alpha_lasso, color='#70AD47', linestyle='--', linewidth=1.8,
            label=f'Lasso 最適λ = {best_alpha_lasso:.4f}')

ax4.set_xlabel('正則化パラメータ λ（対数スケール）', fontsize=12)
ax4.set_ylabel('交差検証 MSE（対数スケール）', fontsize=12)
ax4.set_title('5分割交差検証による MSE vs λ（Ridge・Lasso）', fontsize=14, fontweight='bold')
ax4.legend(fontsize=10, framealpha=0.9)
ax4.grid(True, which='both', alpha=0.3, linestyle=':')
ax4.spines['top'].set_visible(False)
ax4.spines['right'].set_visible(False)

plt.tight_layout()
fig4_path = os.path.join(FIG_DIR, '2019_H3_fig4.png')
plt.savefig(fig4_path, bbox_inches='tight')
plt.close()
print(f"[保存] {fig4_path}")

# ============================================================
# 最終サマリ
# ============================================================
print("\n" + "="*60)
print("最終結果サマリ")
print("="*60)
print(f"使用データ: SSDSE-B-2026.csv, {YEAR}年断面, {len(df_clean)}都道府県")
print(f"目的変数: 合計特殊出生率（平均={y.mean():.3f}）")
print(f"\nRidge 最適λ: {best_alpha_ridge:.4f}")
print(f"Ridge CV-MSE: {-ridge_cv_score.mean():.6f}")
print(f"\nLasso 最適λ: {best_alpha_lasso:.4f}")
print(f"Lasso CV-MSE: {-lasso_cv_score.mean():.6f}")
print(f"\nLasso 変数選択（非ゼロ = 選択）:")
for col, coef in zip(FEAT_COLS, lasso_coefs):
    status = "★ 選択" if abs(coef) > 1e-6 else "  除外"
    print(f"  {status}: {col} ({coef:.4f})")

print("\nDONE: 2019_H3_suri")
