"""
2019_U5_2_shorei.py
都道府県別健康寿命の決定要因：社会経済・医療アクセス要因の重回帰分析
2019年度 統計データ分析コンペティション 審査員奨励賞（大学生部門）
教育用再現コード ― 実データ(SSDSE-B-2026)のみ使用
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2019_U5_2_shorei.py
# ============================================================


import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy import stats
from scipy.stats import kruskal

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

# ─────────────────────────────────────────
# データ読み込み
# ─────────────────────────────────────────
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
df_b = df_b[df_b['地域コード'].str.match(r'^R\d{5}', na=False)].copy()
df_b['年度'] = df_b['年度'].astype(int)

print("=== df_b.columns ===")
print(df_b.columns.tolist())
print()
print("=== df_b.head() ===")
print(df_b.head())

# ─────────────────────────────────────────
# 2019年度の横断面データを作成
# ─────────────────────────────────────────
YEAR = 2019
df_cross = df_b[df_b['年度'] == YEAR].copy().reset_index(drop=True)
print(f"\n=== {YEAR}年度 都道府県数: {len(df_cross)} ===")

# ---- 分析変数の作成 ----
# 目的変数：死亡率（千人当たり）― 高いほど健康状態が悪い
df_cross['死亡率'] = df_cross['死亡数'] / df_cross['総人口'] * 1000

# 説明変数
# (1) 高齢化率（%）― 高齢者が多いほど死亡率は上がる可能性
df_cross['高齢化率'] = df_cross['65歳以上人口'] / df_cross['総人口'] * 100

# (2) 消費支出（万円/世帯月額）― 社会経済的地位の代理変数
df_cross['消費支出_万円'] = df_cross['消費支出（二人以上の世帯）'] / 10000

# (3) 保健医療費（万円/世帯月額）― 医療へのアクセス・支出
df_cross['保健医療費_万円'] = df_cross['保健医療費（二人以上の世帯）'] / 10000

# (4) 病院数（千人当たり）― 医療インフラアクセス
df_cross['病院数_千人'] = df_cross['一般病院数'] / df_cross['総人口'] * 1000

# (5) 食料費（万円/世帯月額）― 食生活・栄養の代理変数
df_cross['食料費_万円'] = df_cross['食料費（二人以上の世帯）'] / 10000

print("\n=== 主要変数の記述統計 ===")
key_vars = ['死亡率', '高齢化率', '消費支出_万円', '保健医療費_万円', '病院数_千人', '食料費_万円']
print(df_cross[key_vars].describe().round(3))

# 地域区分（6地域）
region_map = {
    '北海道': '北海道・東北', '青森県': '北海道・東北', '岩手県': '北海道・東北',
    '宮城県': '北海道・東北', '秋田県': '北海道・東北', '山形県': '北海道・東北',
    '福島県': '北海道・東北',
    '茨城県': '関東', '栃木県': '関東', '群馬県': '関東', '埼玉県': '関東',
    '千葉県': '関東', '東京都': '関東', '神奈川県': '関東',
    '新潟県': '中部', '富山県': '中部', '石川県': '中部', '福井県': '中部',
    '山梨県': '中部', '長野県': '中部', '岐阜県': '中部', '静岡県': '中部', '愛知県': '中部',
    '三重県': '近畿', '滋賀県': '近畿', '京都府': '近畿', '大阪府': '近畿',
    '兵庫県': '近畿', '奈良県': '近畿', '和歌山県': '近畿',
    '鳥取県': '中国・四国', '島根県': '中国・四国', '岡山県': '中国・四国',
    '広島県': '中国・四国', '山口県': '中国・四国', '徳島県': '中国・四国',
    '香川県': '中国・四国', '愛媛県': '中国・四国', '高知県': '中国・四国',
    '福岡県': '九州・沖縄', '佐賀県': '九州・沖縄', '長崎県': '九州・沖縄',
    '熊本県': '九州・沖縄', '大分県': '九州・沖縄', '宮崎県': '九州・沖縄',
    '鹿児島県': '九州・沖縄', '沖縄県': '九州・沖縄'
}
df_cross['地域'] = df_cross['都道府県'].map(region_map)

region_order = ['北海道・東北', '関東', '中部', '近畿', '中国・四国', '九州・沖縄']
region_colors = {
    '北海道・東北': '#2196F3',
    '関東':       '#4CAF50',
    '中部':       '#FF9800',
    '近畿':       '#9C27B0',
    '中国・四国':  '#F44336',
    '九州・沖縄':  '#009688'
}

# ─────────────────────────────────────────
# OLS 重回帰分析
# ─────────────────────────────────────────
y_col = '死亡率'
X_cols = ['高齢化率', '消費支出_万円', '保健医療費_万円', '病院数_千人']

df_reg = df_cross[['都道府県', '地域', y_col] + X_cols + ['食料費_万円']].dropna().copy()
print(f"\n=== 回帰分析対象サンプル数: {len(df_reg)} ===")

X_raw = sm.add_constant(df_reg[X_cols].astype(float))
y = df_reg[y_col].astype(float)
model = sm.OLS(y, X_raw).fit()

print("\n=== OLS 回帰結果 ===")
print(model.summary())

# 標準化偏回帰係数
X_std_vals = (df_reg[X_cols] - df_reg[X_cols].mean()) / df_reg[X_cols].std()
X_std = sm.add_constant(X_std_vals.astype(float))
model_std = sm.OLS(y, X_std).fit()

print("\n=== 標準化偏回帰係数 ===")
for name, coef, pval in zip(['(定数項)'] + X_cols,
                             model_std.params, model_std.pvalues):
    stars = '***' if pval < 0.001 else '**' if pval < 0.01 else '*' if pval < 0.05 else ''
    print(f"  {name:18s}: β={coef:+.4f}  p={pval:.4f} {stars}")

print(f"\nR²     = {model.rsquared:.4f}")
print(f"Adj R² = {model.rsquared_adj:.4f}")
print(f"F-stat = {model.fvalue:.4f}")
print(f"F-prob = {model.f_pvalue:.6f}")

# Cook's Distance
influence = model.get_influence()
cooks_d = influence.cooks_distance[0]
cooks_threshold = 4 / len(df_reg)
high_influence = df_reg['都道府県'].values[cooks_d > cooks_threshold]
print(f"\n=== Cook's距離 > {cooks_threshold:.3f} の都道府県 ===")
for pref, cd in zip(df_reg['都道府県'].values, cooks_d):
    if cd > cooks_threshold:
        print(f"  {pref}: D={cd:.4f}")

# ─────────────────────────────────────────
# 図1：相関ヒートマップ
# ─────────────────────────────────────────
heat_vars = ['死亡率', '高齢化率', '消費支出_万円', '保健医療費_万円', '病院数_千人', '食料費_万円']
heat_labels = ['死亡率\n(千人当)', '高齢化率\n(%)', '消費支出\n(万円)', '保健医療費\n(万円)', '病院数\n(千人当)', '食料費\n(万円)']
n_vars_heat = len(heat_vars)

corr_matrix = df_cross[heat_vars].corr()
pval_matrix = pd.DataFrame(np.ones((n_vars_heat, n_vars_heat)),
                            index=heat_vars, columns=heat_vars)
for i, v1 in enumerate(heat_vars):
    for j, v2 in enumerate(heat_vars):
        if i != j:
            r, p = stats.pearsonr(df_cross[v1].dropna(), df_cross[v2].dropna())
            pval_matrix.loc[v1, v2] = p

fig1, ax1 = plt.subplots(figsize=(8, 6.5))
im = ax1.imshow(corr_matrix.values, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
plt.colorbar(im, ax=ax1, shrink=0.8, label='Pearson相関係数 r')

ax1.set_xticks(range(n_vars_heat))
ax1.set_yticks(range(n_vars_heat))
ax1.set_xticklabels(heat_labels, fontsize=10)
ax1.set_yticklabels(heat_labels, fontsize=10)

for i in range(n_vars_heat):
    for j in range(n_vars_heat):
        r_val = corr_matrix.values[i, j]
        p_val = pval_matrix.values[i, j]
        stars = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else ''
        txt = f"{r_val:.2f}{stars}"
        color = 'white' if abs(r_val) > 0.5 else 'black'
        ax1.text(j, i, txt, ha='center', va='center', fontsize=10,
                 color=color, fontweight='bold' if stars else 'normal')

ax1.set_title('図1：主要変数間のPearson相関係数ヒートマップ\n（2019年度 都道府県別，N=47）\n*p<0.05, **p<0.01, ***p<0.001',
              fontsize=11, pad=12)
plt.tight_layout()
fig1.savefig(os.path.join(FIG_DIR, '2019_U5_2_fig1.png'), bbox_inches='tight')
plt.close(fig1)
print("\n図1 保存完了")

# ─────────────────────────────────────────
# 図2：消費支出 vs 死亡率 散布図（地域別色分け）
# ─────────────────────────────────────────
fig2, ax2 = plt.subplots(figsize=(9, 7))

for region in region_order:
    mask = df_cross['地域'] == region
    ax2.scatter(df_cross.loc[mask, '消費支出_万円'],
                df_cross.loc[mask, '死亡率'],
                c=region_colors[region], label=region, s=60, alpha=0.85, zorder=3)

# 都道府県ラベル
for _, row in df_cross.iterrows():
    ax2.annotate(row['都道府県'].replace('県', '').replace('都', '').replace('府', '').replace('道', ''),
                 (row['消費支出_万円'], row['死亡率']),
                 fontsize=6.5, ha='left', va='bottom',
                 xytext=(2, 2), textcoords='offset points', color='#333333')

# 回帰直線
x_plot = df_cross['消費支出_万円'].dropna()
y_plot = df_cross['死亡率'].dropna()
idx_common = df_cross[['消費支出_万円', '死亡率']].dropna().index
x_c = df_cross.loc[idx_common, '消費支出_万円']
y_c = df_cross.loc[idx_common, '死亡率']
r_val, p_val = stats.pearsonr(x_c, y_c)
slope, intercept, _, _, _ = stats.linregress(x_c, y_c)
x_line = np.linspace(x_c.min(), x_c.max(), 100)
ax2.plot(x_line, slope * x_line + intercept, 'k--', linewidth=1.5, alpha=0.7)

stars = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else ''
ax2.text(0.05, 0.95, f'r = {r_val:.3f}{stars}\np = {p_val:.4f}\nN = 47',
         transform=ax2.transAxes, fontsize=11,
         verticalalignment='top',
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

ax2.set_xlabel('消費支出（万円/月，二人以上の世帯）', fontsize=12)
ax2.set_ylabel('死亡率（千人当たり）', fontsize=12)
ax2.set_title('図2：消費支出と死亡率の散布図\n（2019年度 都道府県別，地域別色分け）', fontsize=12)
ax2.legend(loc='upper right', fontsize=9, title='地域', title_fontsize=9)
ax2.grid(True, alpha=0.3)
fig2.tight_layout()
fig2.savefig(os.path.join(FIG_DIR, '2019_U5_2_fig2.png'), bbox_inches='tight')
plt.close(fig2)
print("図2 保存完了")

# ─────────────────────────────────────────
# 図3：地域別死亡率の箱ひげ図（Kruskal-Wallis検定）
# ─────────────────────────────────────────
fig3, ax3 = plt.subplots(figsize=(9, 6))

groups = [df_cross.loc[df_cross['地域'] == reg, '死亡率'].dropna().values
          for reg in region_order]
stat_kw, p_kw = kruskal(*groups)

bp = ax3.boxplot(groups, patch_artist=True, notch=False,
                 medianprops=dict(color='black', linewidth=2),
                 whiskerprops=dict(linewidth=1.5),
                 capprops=dict(linewidth=1.5))

for patch, region in zip(bp['boxes'], region_order):
    patch.set_facecolor(region_colors[region])
    patch.set_alpha(0.7)

# 個別データ点をオーバーレイ
for i, (region, grp) in enumerate(zip(region_order, groups), start=1):
    jitter = np.random.default_rng(seed=42 + i).uniform(-0.15, 0.15, len(grp))
    ax3.scatter([i + j for j in jitter], grp,
                color=region_colors[region], s=30, alpha=0.8, zorder=3)

ax3.set_xticks(range(1, len(region_order) + 1))
ax3.set_xticklabels(region_order, fontsize=9, rotation=15)
ax3.set_ylabel('死亡率（千人当たり）', fontsize=12)
ax3.set_title('図3：地域別死亡率の比較（箱ひげ図）\n（2019年度，N=47）', fontsize=12)
ax3.text(0.98, 0.97,
         f'Kruskal-Wallis: H={stat_kw:.2f}, p={p_kw:.4f}',
         transform=ax3.transAxes, ha='right', va='top', fontsize=10,
         bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.9))
ax3.grid(True, axis='y', alpha=0.3)
fig3.tight_layout()
fig3.savefig(os.path.join(FIG_DIR, '2019_U5_2_fig3.png'), bbox_inches='tight')
plt.close(fig3)
print("図3 保存完了")

# ─────────────────────────────────────────
# 図4：標準化偏回帰係数プロット（95%CI）
# ─────────────────────────────────────────
std_coefs = model_std.params[1:]   # 定数項を除く
std_cis   = model_std.conf_int().iloc[1:]
std_pvals = model_std.pvalues[1:]

display_names = {
    '高齢化率':    '高齢化率（%）',
    '消費支出_万円': '消費支出（万円）',
    '保健医療費_万円': '保健医療費（万円）',
    '病院数_千人':  '病院数（千人当）',
}

fig4, ax4 = plt.subplots(figsize=(7.5, 5))

bar_colors = []
for col in X_cols:
    p = std_pvals[col]
    c = std_coefs[col]
    if p < 0.05 and c > 0:
        bar_colors.append('#E53935')  # 赤：正の有意
    elif p < 0.05 and c < 0:
        bar_colors.append('#1E88E5')  # 青：負の有意
    else:
        bar_colors.append('#9E9E9E')  # グレー：非有意

y_pos = range(len(X_cols))
ax4.barh(list(y_pos), [std_coefs[c] for c in X_cols],
         color=bar_colors, alpha=0.85, height=0.55)

# 95%CI エラーバー
for i, col in enumerate(X_cols):
    lo = std_cis.loc[col, 0]
    hi = std_cis.loc[col, 1]
    ax4.plot([lo, hi], [i, i], 'k-', linewidth=2)
    ax4.plot([lo, lo], [i - 0.12, i + 0.12], 'k-', linewidth=1.5)
    ax4.plot([hi, hi], [i - 0.12, i + 0.12], 'k-', linewidth=1.5)

    # p値ラベル
    p = std_pvals[col]
    stars = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else 'n.s.'
    ax4.text(std_coefs[col] + (0.02 if std_coefs[col] >= 0 else -0.02),
             i, f' {stars}', va='center', fontsize=11,
             ha='left' if std_coefs[col] >= 0 else 'right')

ax4.set_yticks(list(y_pos))
ax4.set_yticklabels([display_names[c] for c in X_cols], fontsize=11)
ax4.axvline(0, color='black', linewidth=0.8, linestyle='-')
ax4.set_xlabel('標準化偏回帰係数（β）', fontsize=12)
ax4.set_title('図4：重回帰分析の標準化偏回帰係数\n（目的変数：死亡率，エラーバー=95%CI）', fontsize=12)

# 凡例
from matplotlib.patches import Patch
legend_elements = [Patch(facecolor='#E53935', alpha=0.85, label='正の有意（p<0.05）'),
                   Patch(facecolor='#1E88E5', alpha=0.85, label='負の有意（p<0.05）'),
                   Patch(facecolor='#9E9E9E', alpha=0.85, label='非有意')]
ax4.legend(handles=legend_elements, loc='lower right', fontsize=9)
ax4.grid(True, axis='x', alpha=0.3)
fig4.tight_layout()
fig4.savefig(os.path.join(FIG_DIR, '2019_U5_2_fig4.png'), bbox_inches='tight')
plt.close(fig4)
print("図4 保存完了")

print("\nDONE: 2019_U5_2_shorei")
