"""
教育用再現コード: 2024年 統計データ分析コンペティション 審査員奨励賞（大学生・一般）
=================================================================================
論文タイトル：XAIを用いた介護業界における地域別の従業者数の就業要因に関する一考察
受賞区分  ：審査員奨励賞 ［大学生・一般の部］
著者      ：宮内弘太（一般財団法人計量計画研究所）

【分析概要】
  データ：SSDSE-E-2026（都道府県別クロスセクション）+
          SSDSE-B-2026（都道府県別パネルデータ 2022年）
  目的   ：Random Forestで医療・福祉従業者密度を予測し、
            特徴量重要度（XAI代替）で各変数の影響を可視化・解釈する

  Step1. 実データ読み込み・変数作成（47都道府県）
  Step2. Random Forest回帰モデルの構築
  Step3. SHAP値または代替特徴量重要度の計算
  Step4. 部分従属プロット（PDP）の可視化

【変数説明】
  目的変数: 医療・福祉従業者密度（従業者数（民営）（医療、福祉）/ 総人口 × 10000）
  説明変数（すべてSSSDSE実データ）:
    - 高齢化率       : 65歳以上人口 / 総人口
    - 県民所得       : 1人当たり県民所得（SSDSE-E）
    - 人口密度       : 総人口 / 総面積
    - 医師数10万対   : 医師数 / 総人口 × 100000（SSDSE-E）
    - 婚姻率         : 婚姻件数 / 総人口 × 1000（SSDSE-B）
    - 年平均気温     : SSDSE-B

【データ出典】
  独立行政法人統計センター「SSDSE（教育用標準データセット）」
  https://www.nstac.go.jp/use/literacy/ssdse/

【図の出力】
  html/figures/2024_U5_6_fig1_rf_importance.png ... Random Forest特徴量重要度
  html/figures/2024_U5_6_fig2_shap_summary.png  ... SHAP summary plot（beeswarm風）
  html/figures/2024_U5_6_fig3_pdp.png           ... 部分従属プロット
  html/figures/2024_U5_6_fig4_scatter.png       ... 高齢化率 vs 医療・福祉従業者密度
=================================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-E-2026.csv
#       → data/raw/SSDSE-E-2026.csv に配置
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-E（社会・人口統計体系 都道府県の指標2） の CSV をダウンロード）
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2024_U5_6_shorei.py
# ============================================================


import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import warnings
warnings.filterwarnings('ignore')

from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score
from sklearn.inspection import partial_dependence
from scipy import stats as scipy_stats

try:
    import shap
    HAS_SHAP = True
except ImportError:
    HAS_SHAP = False
    print("SHAPが利用不可 → 代替実装（permutation importance近似）を使用")

# ── パス設定 ─────────────────────────────────────────────────────────────────
FIG_DIR = 'html/figures'
DATA_DIR = 'data/raw'
os.makedirs(FIG_DIR, exist_ok=True)

plt.rcParams.update({
    'font.family':        'Hiragino Sans',
    'axes.unicode_minus': False,
    'figure.dpi':         150,
    'axes.spines.top':    False,
    'axes.spines.right':  False,
})

# ═══════════════════════════════════════════════════════════════════════════════
# ■ Step1. 実データ読み込み・変数作成
# ═══════════════════════════════════════════════════════════════════════════════
print("=" * 60)
print("■ 実データ読み込み（SSDSE-E-2026 / SSDSE-B-2026）")
print("=" * 60)

# SSDSE-E
df_e_raw = pd.read_csv(os.path.join(DATA_DIR, 'SSDSE-E-2026.csv'),
                        encoding='cp932', header=0)
df_e = df_e_raw.iloc[1:].copy()
df_e.columns = df_e_raw.iloc[0].values
df_e = df_e.iloc[1:].copy()
df_e.columns = df_e_raw.iloc[1].values
df_e = df_e[df_e['都道府県'] != '全国'].set_index('都道府県').copy()

num_cols_e = ['総人口', '65歳以上人口', '総面積（北方地域及び竹島を除く）',
              '1人当たり県民所得（平成27年基準）', '医師数',
              '従業者数（民営）（医療、福祉）']
for c in num_cols_e:
    df_e[c] = pd.to_numeric(df_e[c], errors='coerce')

# SSDSE-B（2022年）
YEAR = 2022
df_b_raw = pd.read_csv(os.path.join(DATA_DIR, 'SSDSE-B-2026.csv'),
                        encoding='cp932', header=1)
df_b = df_b_raw[
    (df_b_raw['年度'] == YEAR) &
    df_b_raw['地域コード'].str.match(r'^R\d{5}$', na=False)
].copy()
df_b = df_b[df_b['地域コード'] != 'R00000'].set_index('都道府県')

for c in ['総人口', '65歳以上人口', '年平均気温', '婚姻件数']:
    df_b[c] = pd.to_numeric(df_b[c], errors='coerce')

# ─ 変数作成 ─
common_prefs = sorted(set(df_e.index) & set(df_b.index))
PREFS = common_prefs

# 目的変数：医療・福祉従業者密度（人口1万対）
care_density = (df_e.loc[PREFS, '従業者数（民営）（医療、福祉）'] /
                df_e.loc[PREFS, '総人口'] * 10000).values.astype(float)

# 説明変数
aging_rate = (df_e.loc[PREFS, '65歳以上人口'] /
              df_e.loc[PREFS, '総人口']).values.astype(float)
income     = df_e.loc[PREFS, '1人当たり県民所得（平成27年基準）'].values.astype(float)
area       = df_e.loc[PREFS, '総面積（北方地域及び竹島を除く）'].values.astype(float)
pop        = df_e.loc[PREFS, '総人口'].values.astype(float)
pop_density = pop / (area / 100)  # /km²
doctor_rate = (df_e.loc[PREFS, '医師数'] / pop * 100000).values.astype(float)
marriage    = (df_b.loc[PREFS, '婚姻件数'] / df_b.loc[PREFS, '総人口'] * 1000).values.astype(float)
temp        = df_b.loc[PREFS, '年平均気温'].values.astype(float)

FEATURE_NAMES = ['高齢化率', '県民所得', '人口密度（対数）', '医師数10万対', '婚姻率', '年平均気温']
X = np.column_stack([aging_rate, income, np.log1p(pop_density), doctor_rate, marriage, temp])
n_features = len(FEATURE_NAMES)

# 欠損除去
valid = ~np.any(np.isnan(X), axis=1) & ~np.isnan(care_density)
X = X[valid]
care_density = care_density[valid]
PREFS_V = [PREFS[i] for i in range(len(PREFS)) if valid[i]]
N = len(PREFS_V)

print(f"分析対象: {N}都道府県")
print(f"医療・福祉従業者密度: mean={care_density.mean():.1f}, std={care_density.std():.1f} (/万人)")

# ═══════════════════════════════════════════════════════════════════════════════
# ■ Step2. Random Forest回帰
# ═══════════════════════════════════════════════════════════════════════════════
rf = RandomForestRegressor(n_estimators=200, max_depth=6, min_samples_leaf=2,
                            random_state=0, n_jobs=-1)
rf.fit(X, care_density)

cv_r2 = cross_val_score(rf, X, care_density, cv=5, scoring='r2')
print(f"\n【Random Forest】")
print(f"  訓練R² = {rf.score(X, care_density):.3f}")
print(f"  5-fold CV R² = {cv_r2.mean():.3f} (±{cv_r2.std():.3f})")

importance = rf.feature_importances_
imp_df = pd.DataFrame({'変数': FEATURE_NAMES, '重要度': importance}).sort_values('重要度', ascending=False)
print(f"\n【特徴量重要度】")
print(imp_df.round(4))

# ═══════════════════════════════════════════════════════════════════════════════
# ■ Step3. SHAP値（または代替実装）
# ═══════════════════════════════════════════════════════════════════════════════
if HAS_SHAP:
    explainer = shap.TreeExplainer(rf)
    shap_values = explainer.shap_values(X)
    shap_available = True
    print("\n■ SHAP値を計算しました")
else:
    y_pred = rf.predict(X)
    y_mean = y_pred.mean()
    shap_values = np.zeros((N, n_features))
    for i in range(N):
        deviation = y_pred[i] - y_mean
        for j in range(n_features):
            shap_values[i, j] = deviation * importance[j]
    shap_available = False
    print("\n■ SHAP近似値（代替実装）を計算しました")

# ═══════════════════════════════════════════════════════════════════════════════
# ■ 図の生成（4枚）
# ═══════════════════════════════════════════════════════════════════════════════

# ────────────────────────────────────────────────────────────────────────────
# 図1：Random Forest特徴量重要度
# ────────────────────────────────────────────────────────────────────────────
print("\n図1: RF特徴量重要度を作成中...")

fig1, axes1 = plt.subplots(1, 2, figsize=(13, 5))
fig1.suptitle('Random Forest 特徴量重要度と累積重要度\n（医療・福祉従業者密度 / データ：SSDSE実データ）',
              fontsize=12, fontweight='bold')

imp_sorted = imp_df.sort_values('重要度')
top_feat = imp_df.iloc[0]['変数']
colors1 = ['#1565C0' if f == top_feat else '#2E7D32' if imp_df[imp_df['変数']==f]['重要度'].values[0] > 0.15
            else '#90CAF9' for f in imp_sorted['変数']]
axes1[0].barh(imp_sorted['変数'], imp_sorted['重要度'], color=colors1, edgecolor='white', alpha=0.88)
axes1[0].set_xlabel('特徴量重要度（不純度減少の平均）', fontsize=11)
axes1[0].set_title('Random Forest 特徴量重要度\n青：最重要変数', fontsize=11, fontweight='bold')
axes1[0].grid(axis='x', alpha=0.3)
for bar, val in zip(axes1[0].patches, imp_sorted['重要度']):
    axes1[0].text(val + 0.003, bar.get_y() + bar.get_height()/2,
                  f'{val:.3f}', va='center', fontsize=9)

imp_desc = imp_df.sort_values('重要度', ascending=False)
cumsum = imp_desc['重要度'].cumsum().values
x_tick = range(1, n_features + 1)
axes1[1].bar(x_tick, imp_desc['重要度'].values, color='#1565C0', alpha=0.6, edgecolor='white', label='重要度')
axes1[1].plot(x_tick, cumsum, 'ro-', linewidth=2, markersize=8, label='累積重要度')
axes1[1].axhline(0.8, color='gray', linestyle='--', linewidth=1.0, label='累積80%')
axes1[1].set_xticks(list(x_tick))
axes1[1].set_xticklabels([f.split('（')[0] for f in imp_desc['変数']], fontsize=8.5, rotation=15, ha='right')
axes1[1].set_ylabel('重要度', fontsize=11)
axes1[1].set_title('累積特徴量重要度', fontsize=11, fontweight='bold')
axes1[1].legend(fontsize=9)
axes1[1].grid(axis='y', alpha=0.3)

plt.tight_layout()
fig1.savefig(os.path.join(FIG_DIR, '2024_U5_6_fig1_rf_importance.png'), bbox_inches='tight', dpi=150)
plt.close(fig1)
print("  → 2024_U5_6_fig1_rf_importance.png 保存完了")

# ────────────────────────────────────────────────────────────────────────────
# 図2：SHAP summary plot（beeswarm風）
# ────────────────────────────────────────────────────────────────────────────
print("図2: SHAP summary plotを作成中...")

if HAS_SHAP:
    fig2, ax2 = plt.subplots(figsize=(10, 7))
    shap.summary_plot(shap_values, X, feature_names=FEATURE_NAMES,
                      show=False, plot_type='dot', color_bar=True)
    plt.title('SHAP Summary Plot（Beeswarm）\n医療・福祉従業者密度への各変数の影響',
              fontsize=13, fontweight='bold')
    plt.tight_layout()
    fig2.savefig(os.path.join(FIG_DIR, '2024_U5_6_fig2_shap_summary.png'), bbox_inches='tight', dpi=150)
    plt.close(fig2)
else:
    fig2, ax2 = plt.subplots(figsize=(10, 7))
    mean_abs_shap = np.abs(shap_values).mean(axis=0)
    order = np.argsort(mean_abs_shap)

    cmap = cm.get_cmap('RdBu_r')
    for plot_idx, feat_idx in enumerate(order):
        shap_col = shap_values[:, feat_idx]
        feat_col = X[:, feat_idx]
        feat_norm = (feat_col - feat_col.min()) / (feat_col.max() - feat_col.min() + 1e-8)
        y_jitter = np.linspace(-0.15, 0.15, N) + plot_idx
        ax2.scatter(shap_col, y_jitter, c=feat_norm, cmap='RdBu_r',
                    s=35, alpha=0.75, vmin=0, vmax=1)

    ax2.set_yticks(range(n_features))
    ax2.set_yticklabels([FEATURE_NAMES[i] for i in order], fontsize=10)
    ax2.axvline(0, color='black', linewidth=0.8)
    ax2.set_xlabel('SHAP値（医療・福祉従業者密度への影響）', fontsize=12)
    ax2.set_title('SHAP Summary Plot（代替実装）\n右=従業者密度増加に寄与、左=減少に寄与',
                  fontsize=12, fontweight='bold')
    ax2.grid(axis='x', alpha=0.2)
    sm_plot = plt.cm.ScalarMappable(cmap='RdBu_r', norm=plt.Normalize(0, 1))
    cbar = plt.colorbar(sm_plot, ax=ax2)
    cbar.set_label('特徴量値（低→高）', fontsize=10)

    plt.tight_layout()
    fig2.savefig(os.path.join(FIG_DIR, '2024_U5_6_fig2_shap_summary.png'), bbox_inches='tight', dpi=150)
    plt.close(fig2)

print("  → 2024_U5_6_fig2_shap_summary.png 保存完了")

# ────────────────────────────────────────────────────────────────────────────
# 図3：部分従属プロット（PDP）：高齢化率と医師数10万対
# ────────────────────────────────────────────────────────────────────────────
print("図3: 部分従属プロットを作成中...")

fig3, axes3 = plt.subplots(1, 2, figsize=(13, 5))
fig3.suptitle('部分従属プロット（PDP）：非線形効果の可視化', fontsize=13, fontweight='bold')

pdp_features = [0, 3]  # 高齢化率, 医師数10万対
pdp_labels = ['高齢化率（65歳以上割合）', '医師数10万対']
pdp_colors = ['#1565C0', '#E65100']

for ax, feat_idx, label, clr in zip(axes3, pdp_features, pdp_labels, pdp_colors):
    pdp_result = partial_dependence(rf, X, features=[feat_idx], grid_resolution=50)
    grid_vals = pdp_result['grid_values'][0]
    pdp_vals = pdp_result['average'][0]

    ax.plot(grid_vals, pdp_vals, color=clr, linewidth=2.5)
    ax.fill_between(grid_vals, pdp_vals - pdp_vals.std() * 0.5,
                    pdp_vals + pdp_vals.std() * 0.5, alpha=0.15, color=clr)
    ax.scatter(X[:, feat_idx], rf.predict(X),
               c='gray', alpha=0.3, s=15, zorder=2, label='実データ（予測値）')
    ax.set_xlabel(label, fontsize=11)
    ax.set_ylabel('医療・福祉従業者密度（予測値）', fontsize=11)
    ax.set_title(f'部分従属プロット：{label}', fontsize=11, fontweight='bold')
    ax.grid(True, alpha=0.2)
    ax.legend(fontsize=9)

plt.tight_layout()
fig3.savefig(os.path.join(FIG_DIR, '2024_U5_6_fig3_pdp.png'), bbox_inches='tight', dpi=150)
plt.close(fig3)
print("  → 2024_U5_6_fig3_pdp.png 保存完了")

# ────────────────────────────────────────────────────────────────────────────
# 図4：高齢化率 vs 医療・福祉従業者密度 散布図
# ────────────────────────────────────────────────────────────────────────────
print("図4: 高齢化率 vs 医療・福祉従業者密度散布図を作成中...")

aging_vals = X[:, 0] * 100  # %表示
r_val, p_val = scipy_stats.pearsonr(aging_vals, care_density)

fig4, axes4 = plt.subplots(1, 2, figsize=(13, 5))
fig4.suptitle(f'高齢化率と医療・福祉従業者密度（XAI最重要変数：{top_feat}）',
              fontsize=13, fontweight='bold')

# 左: 散布図（県民所得で色分け）
ax4a = axes4[0]
income_norm = (X[:, 1] - X[:, 1].min()) / (X[:, 1].max() - X[:, 1].min())
sc4 = ax4a.scatter(aging_vals, care_density, c=income_norm, cmap='RdYlBu',
                   s=60, alpha=0.85, edgecolors='white', linewidth=0.5)
plt.colorbar(sc4, ax=ax4a, label='県民所得（低→高）')
coef4 = np.polyfit(aging_vals, care_density, 1)
x_fit = np.linspace(aging_vals.min(), aging_vals.max(), 100)
ax4a.plot(x_fit, np.polyval(coef4, x_fit), 'k--', linewidth=2)
# 代表的な都道府県ラベル
care_rank = np.argsort(care_density)[::-1]
for idx in list(care_rank[:5]) + list(care_rank[-5:]):
    short = PREFS_V[idx].replace('県','').replace('府','').replace('都','').replace('道','')
    ax4a.annotate(short, (aging_vals[idx], care_density[idx]),
                  textcoords='offset points', xytext=(5, 3), fontsize=7.5, color='#333')
ax4a.set_xlabel('高齢化率（65歳以上割合 %）', fontsize=11)
ax4a.set_ylabel('医療・福祉従業者密度（/万人）', fontsize=11)
ax4a.set_title(f'高齢化率 → 医療・福祉従業者密度\nr = {r_val:.3f}', fontsize=11, fontweight='bold')
ax4a.grid(True, alpha=0.2)
ax4a.text(0.05, 0.95, f'r = {r_val:.3f}\n（SHAP最重要変数）',
          transform=ax4a.transAxes, fontsize=10, va='top',
          bbox=dict(boxstyle='round', facecolor='#E3F2FD', alpha=0.8))

# 右: SHAP値（高齢化率）vs 高齢化率
ax4b = axes4[1]
shap_aging = shap_values[:, 0]
ax4b.scatter(aging_vals, shap_aging, c='#1565C0', alpha=0.75, s=55)
coef4b = np.polyfit(aging_vals, shap_aging, 1)
ax4b.plot(x_fit, np.polyval(coef4b, x_fit), 'r--', linewidth=2, label='回帰直線')
ax4b.axhline(0, color='black', linewidth=0.8)
ax4b.set_xlabel('高齢化率（65歳以上割合 %）', fontsize=11)
ax4b.set_ylabel('SHAP値（医療・福祉従業者密度への影響）', fontsize=11)
ax4b.set_title('SHAP値 vs 高齢化率\n非線形効果の確認', fontsize=11, fontweight='bold')
ax4b.legend(fontsize=9)
ax4b.grid(True, alpha=0.2)

plt.tight_layout()
fig4.savefig(os.path.join(FIG_DIR, '2024_U5_6_fig4_scatter.png'), bbox_inches='tight', dpi=150)
plt.close(fig4)
print("  → 2024_U5_6_fig4_scatter.png 保存完了")

print("\n" + "=" * 60)
print("✓ 全図の生成完了（4枚）")
print("=" * 60)
print("\n【主要知見】")
print(f"  RF訓練R² = {rf.score(X, care_density):.3f}")
print(f"  5-fold CV R² = {cv_r2.mean():.3f}")
print(f"  最重要変数: {imp_df.iloc[0]['変数']} (重要度={imp_df.iloc[0]['重要度']:.3f})")
print(f"  高齢化率と医療・福祉従業者密度の相関: r = {r_val:.3f}")
print(f"  SHAPライブラリ利用: {HAS_SHAP}")
print(f"  使用データ: SSDSE-E-2026, SSDSE-B-2026 ({YEAR}年, {N}都道府県)")
