"""
教育用再現コード: 2022年度 統計データ分析コンペティション 審査員奨励賞（高校生の部）
=================================================================
論文タイトル：気候と健康：気温・降水量が死亡率に与える影響の分析

【分析概要】
  データ：SSDSE-B-2026.csv（都道府県データ、2012〜2023年）
  目的変数：粗死亡率 = 死亡数 / 総人口 × 1000
  主要説明変数：
    - 年平均気温      （B4101）
    - 最高気温        （B4102：日最高気温の月平均の最高値）
    - 降水日数        （B4106：年間）
    - 高齢化率        = 65歳以上人口 / 総人口 × 100
    - 保健医療費割合  = 保健医療費（二人以上の世帯）/ 消費支出（二人以上の世帯）× 100

【分析の要】
  粗死亡率は高齢化率と強く相関するため、高齢化率を制御した上で
  気象要因（気温・降水日数）の純効果を重回帰で分析する。
  高齢化率なしの単回帰と高齢化率ありの重回帰を比較することで、
  「交絡変数の制御」という重要な概念を学ぶことができる。

  Step1. 時系列分析（死亡率の地域別推移、2012〜2023年）
  Step2. 散布図（高齢化率 vs 死亡率、2022年、都道府県ラベル付き）
  Step3. OLS重回帰係数プロット（死亡率の決定要因）
  Step4. 気温帯別 死亡率箱ひげ図（年平均気温を3グループに分けた比較）

【変数定義（SSDSE-B-2026列名）】
  SSDSE-B-2026 : 年度
  Prefecture   : 都道府県
  A1101 : 総人口
  A1303 : 65歳以上人口
  A4200 : 死亡数
  B4101 : 年平均気温（℃）
  B4102 : 最高気温（日最高気温の月平均の最高値）（℃）
  B4106 : 降水日数（年間）
  L3221 : 消費支出（二人以上の世帯）（円/月）
  L322106 : 保健医療費（二人以上の世帯）（円/月）

【派生変数の計算式】
  粗死亡率       = A4200 / A1101 × 1000  （人口千人あたり死亡数）
  高齢化率       = A1303 / A1101 × 100   （65歳以上の割合 %）
  保健医療費割合  = L322106 / L3221 × 100 （消費支出に占める保健医療費 %）

【データ出典】
  SSDSE-B-2026.csv: 社会・人口統計体系（都道府県データ）
  統計数理研究所・統計センターより公表の実データ

【データサイエンス学習ポイント】
  1. 粗死亡率と年齢調整死亡率の違い（高齢化率を制御する必要性）
  2. 交絡変数の制御（高齢化率を加えた重回帰の意義）
  3. 連続変数のグループ化（pd.qcut でのカテゴリ分け）
  4. 気候変動と健康政策（高温・熱中症対策の統計的根拠）
=================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2022_H5_12_shorei.py
# ============================================================


import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# ──────────────────────────────────────────────────────────────
# 共通設定
# ──────────────────────────────────────────────────────────────
plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

import os
FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

# ================================================================
# ■ 実データ読み込み（SSDSE-B-2026: 都道府県データ）
# ================================================================
df_raw = pd.read_csv(DATA_B, encoding='shift-jis', header=0)
# 行0がラベル行、行1以降が実データ
df_all = df_raw.iloc[1:].copy()
df_all.columns = df_raw.columns

# 数値列に変換
NUM_COLS = ['A1101', 'A1303', 'A4200', 'B4101', 'B4102', 'B4106', 'L3221', 'L322106']
for c in NUM_COLS:
    df_all[c] = pd.to_numeric(df_all[c], errors='coerce')

df_all['年度']    = df_all['SSDSE-B-2026'].astype(int)
df_all['都道府県'] = df_all['Prefecture'].astype(str)

print("=" * 65)
print("■ SSDSE-B-2026 読み込み完了")
print(f"  年度範囲: {df_all['年度'].min()}〜{df_all['年度'].max()}")
print(f"  都道府県数（各年）: {df_all.groupby('年度')['都道府県'].count().iloc[0]}件")
print("=" * 65)

# ================================================================
# ■ 特徴量エンジニアリング（全年度）
# ================================================================
df_all['粗死亡率']      = df_all['A4200'] / df_all['A1101'] * 1000
df_all['高齢化率']      = df_all['A1303'] / df_all['A1101'] * 100
df_all['保健医療費割合'] = df_all['L322106'] / df_all['L3221'] * 100

# 2022年データ抽出
df_2022 = df_all[df_all['年度'] == 2022].copy().reset_index(drop=True)

print(f"\n■ 2022年データ: {len(df_2022)}都道府県")
print(f"  粗死亡率（平均）: {df_2022['粗死亡率'].mean():.2f} ‰")
print(f"  高齢化率（平均）: {df_2022['高齢化率'].mean():.1f}%")
print(f"  年平均気温（平均）: {df_2022['B4101'].mean():.1f}℃")
print(f"  降水日数（平均）: {df_2022['B4106'].mean():.1f}日")

# ================================================================
# ■ OLS重回帰分析（2022年、粗死亡率の決定要因）
# ================================================================
print("\n" + "=" * 65)
print("■ OLS重回帰分析（2022年）")
print("=" * 65)

REG_VARS   = ['高齢化率', 'B4101', 'B4102', 'B4106', '保健医療費割合']
REG_LABELS = ['高齢化率(%)', '年平均気温(℃)', '最高気温(℃)', '降水日数(日)', '保健医療費割合(%)']

df_reg = df_2022[['粗死亡率'] + REG_VARS].dropna()
y_ols  = df_reg['粗死亡率'].values
X_ols  = df_reg[REG_VARS].values

# 標準化（係数の大きさを比較するため）
X_std = (X_ols - X_ols.mean(axis=0)) / X_ols.std(axis=0)
X_with_const = sm.add_constant(X_std)
ols_model = sm.OLS(y_ols, X_with_const).fit()

print(ols_model.summary2())
print(f"\n  R²        = {ols_model.rsquared:.4f}")
print(f"  自由度修正済R² = {ols_model.rsquared_adj:.4f}")

# 相関分析（高齢化率 vs 粗死亡率）
r_aging, p_aging = stats.pearsonr(df_2022['高齢化率'].dropna(),
                                   df_2022['粗死亡率'].dropna())
print(f"\n■ 相関分析（高齢化率 vs 粗死亡率）: r={r_aging:.4f}, p={p_aging:.4f}")

# 相関分析（年平均気温 vs 粗死亡率）
r_temp, p_temp = stats.pearsonr(df_2022['B4101'].dropna(),
                                 df_2022['粗死亡率'].dropna())
print(f"■ 相関分析（年平均気温 vs 粗死亡率）: r={r_temp:.4f}, p={p_temp:.4f}")

# ================================================================
# ■ 地方区分マップ
# ================================================================
REGION_MAP = {
    '北海道': '北海道・東北',
    '青森県': '北海道・東北', '岩手県': '北海道・東北', '宮城県': '北海道・東北',
    '秋田県': '北海道・東北', '山形県': '北海道・東北', '福島県': '北海道・東北',
    '茨城県': '関東', '栃木県': '関東', '群馬県': '関東', '埼玉県': '関東',
    '千葉県': '関東', '東京都': '関東', '神奈川県': '関東',
    '新潟県': '中部', '富山県': '中部', '石川県': '中部', '福井県': '中部',
    '山梨県': '中部', '長野県': '中部', '岐阜県': '中部', '静岡県': '中部', '愛知県': '中部',
    '三重県': '近畿', '滋賀県': '近畿', '京都府': '近畿', '大阪府': '近畿',
    '兵庫県': '近畿', '奈良県': '近畿', '和歌山県': '近畿',
    '鳥取県': '中国・四国', '島根県': '中国・四国', '岡山県': '中国・四国',
    '広島県': '中国・四国', '山口県': '中国・四国',
    '徳島県': '中国・四国', '香川県': '中国・四国', '愛媛県': '中国・四国', '高知県': '中国・四国',
    '福岡県': '九州・沖縄', '佐賀県': '九州・沖縄', '長崎県': '九州・沖縄',
    '熊本県': '九州・沖縄', '大分県': '九州・沖縄', '宮崎県': '九州・沖縄',
    '鹿児島県': '九州・沖縄', '沖縄県': '九州・沖縄',
}
REGION_ORDER = ['北海道・東北', '関東', '中部', '近畿', '中国・四国', '九州・沖縄']
REGION_COLORS = {
    '北海道・東北': '#1565C0',
    '関東':         '#E65100',
    '中部':         '#2E7D32',
    '近畿':         '#6A1B9A',
    '中国・四国':   '#795548',
    '九州・沖縄':   '#00695C',
}
df_all['地方'] = df_all['都道府県'].map(REGION_MAP)
df_2022['地方'] = df_2022['都道府県'].map(REGION_MAP)

# ================================================================
# ■ 図1: 死亡率の時系列推移（地域別平均, 2012-2023）
# ================================================================
print("\n" + "=" * 65)
print("■ 図1: 死亡率の時系列推移（地域別平均）")
print("=" * 65)

df_region_ts = (
    df_all.groupby(['年度', '地方'])['粗死亡率']
    .mean()
    .reset_index()
)

fig1, ax1 = plt.subplots(figsize=(11, 6))

for region in REGION_ORDER:
    sub = df_region_ts[df_region_ts['地方'] == region].sort_values('年度')
    if len(sub) == 0:
        continue
    ax1.plot(sub['年度'], sub['粗死亡率'],
             marker='o', linewidth=2, markersize=5,
             color=REGION_COLORS[region], label=region)

# 2022年に垂直線（分析基準年）
ax1.axvline(2022, color='#C62828', linestyle='--', linewidth=1.5, alpha=0.8)
ax1.text(2022.05, ax1.get_ylim()[1] * 0.98, '2022年\n（分析基準年）',
         ha='left', va='top', fontsize=9, color='#C62828', fontweight='bold')

ax1.set_xlabel('年度', fontsize=12)
ax1.set_ylabel('粗死亡率（人口千人あたり）', fontsize=12)
ax1.set_title('粗死亡率の時系列推移（地方別平均, 2012〜2023年）\n〜高齢化の進展に伴い全地域で上昇傾向〜',
              fontsize=13, fontweight='bold')
ax1.set_xticks(sorted(df_all['年度'].unique()))
ax1.xaxis.set_tick_params(rotation=45)
ax1.legend(loc='upper left', fontsize=9.5, framealpha=0.88)
ax1.grid(True, alpha=0.3)
plt.tight_layout()
fig1.savefig(os.path.join(FIG_DIR, '2022_H5_12_fig1_timeseries.png'),
             bbox_inches='tight', dpi=150)
plt.close(fig1)
print("図1保存: 2022_H5_12_fig1_timeseries.png")

# ================================================================
# ■ 図2: 高齢化率 vs 死亡率 散布図（2022年, 都道府県ラベル付き）
# ================================================================
print("\n■ 図2: 高齢化率 vs 死亡率 散布図（2022年）")

fig2, ax2 = plt.subplots(figsize=(11, 7))

x2    = df_2022['高齢化率'].values
y2    = df_2022['粗死亡率'].values
temp2 = df_2022['B4101'].values
pref2 = df_2022['都道府県'].values

sc2 = ax2.scatter(x2, y2, c=temp2,
                  cmap='RdYlBu_r', s=75, alpha=0.85,
                  edgecolors='#333', linewidth=0.5, zorder=3,
                  vmin=10, vmax=24)
cbar2 = fig2.colorbar(sc2, ax=ax2, label='年平均気温（℃）', shrink=0.8)
cbar2.ax.tick_params(labelsize=9)

# 回帰直線
valid2 = np.isfinite(x2) & np.isfinite(y2)
z2 = np.polyfit(x2[valid2], y2[valid2], 1)
xs2 = np.linspace(x2[valid2].min(), x2[valid2].max(), 100)
ax2.plot(xs2, np.poly1d(z2)(xs2), 'b--', linewidth=1.8, alpha=0.7,
         label=f'回帰直線 (r={r_aging:.3f}, p={p_aging:.4f})')

# 都道府県ラベル（注目都市）
LABEL_PREFS = {
    '東京都', '北海道', '沖縄県', '秋田県', '山形県', '青森県',
    '高知県', '島根県', '愛知県', '神奈川県', '埼玉県',
    '鹿児島県', '長崎県', '和歌山県', '奈良県',
}
for xi, yi, pref in zip(x2, y2, pref2):
    if pref in LABEL_PREFS:
        short = pref.replace('県', '').replace('都', '').replace('道', '').replace('府', '')
        ax2.annotate(short, (xi, yi),
                     xytext=(4, 3), textcoords='offset points',
                     fontsize=7.5, color='#1A237E', fontweight='bold')

ax2.set_xlabel('高齢化率（65歳以上人口割合 %）', fontsize=12)
ax2.set_ylabel('粗死亡率（人口千人あたり）', fontsize=12)
ax2.set_title('高齢化率と粗死亡率の関係（2022年, 都道府県別）\n〜色：年平均気温。高齢化率が粗死亡率を強く規定する〜',
              fontsize=13, fontweight='bold')
ax2.legend(fontsize=9.5)
ax2.grid(True, alpha=0.3)
plt.tight_layout()
fig2.savefig(os.path.join(FIG_DIR, '2022_H5_12_fig2_scatter.png'),
             bbox_inches='tight', dpi=150)
plt.close(fig2)
print("図2保存: 2022_H5_12_fig2_scatter.png")

# ================================================================
# ■ 図3: OLS回帰係数プロット（粗死亡率の決定要因）
# ================================================================
print("\n■ 図3: OLS回帰係数プロット")

coefs3 = np.asarray(ols_model.params[1:])
ses3   = np.asarray(ols_model.bse[1:])
pvals3 = np.asarray(ols_model.pvalues[1:])

COEF_COLORS = []
for p in pvals3:
    if p < 0.01:
        COEF_COLORS.append('#C62828')
    elif p < 0.05:
        COEF_COLORS.append('#FF8F00')
    elif p < 0.10:
        COEF_COLORS.append('#FDD835')
    else:
        COEF_COLORS.append('#9E9E9E')

fig3, ax3 = plt.subplots(figsize=(9, 5.5))
y_pos3 = np.arange(len(REG_LABELS))

ax3.barh(y_pos3, coefs3, color=COEF_COLORS, alpha=0.82,
         edgecolor='white', height=0.55)
ax3.errorbar(coefs3, y_pos3, xerr=1.96 * ses3,
             fmt='none', color='#222', capsize=5, linewidth=1.5)
ax3.axvline(0, color='gray', linestyle='--', linewidth=1.0)
ax3.set_yticks(y_pos3)
ax3.set_yticklabels(REG_LABELS, fontsize=11)
ax3.set_xlabel('標準化回帰係数（±1.96SE）', fontsize=11)
ax3.set_title(
    f'粗死亡率の決定要因 — OLS重回帰係数（2022年, N={len(df_reg)}都道府県）\n'
    f'R²={ols_model.rsquared:.3f}（adj. R²={ols_model.rsquared_adj:.3f}）',
    fontsize=12, fontweight='bold'
)
ax3.invert_yaxis()
ax3.grid(axis='x', alpha=0.3)

# p値ラベル
for i, (c, p) in enumerate(zip(coefs3, pvals3)):
    sig = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else 'n.s.'
    offset = 0.015
    ha = 'left' if c >= 0 else 'right'
    ax3.text(c + (offset if c >= 0 else -offset), i,
             f' {c:.3f} {sig}',
             va='center', ha=ha, fontsize=8.5)

from matplotlib.patches import Patch
ax3.legend(handles=[
    Patch(color='#C62828', alpha=0.85, label='p<0.01'),
    Patch(color='#FF8F00', alpha=0.85, label='p<0.05'),
    Patch(color='#FDD835', alpha=0.85, label='p<0.10'),
    Patch(color='#9E9E9E', alpha=0.85, label='n.s.'),
], fontsize=9, loc='lower right')
plt.tight_layout()
fig3.savefig(os.path.join(FIG_DIR, '2022_H5_12_fig3_coef.png'),
             bbox_inches='tight', dpi=150)
plt.close(fig3)
print("図3保存: 2022_H5_12_fig3_coef.png")

# ================================================================
# ■ 図4: 気温帯別 死亡率箱ひげ図（年平均気温を3グループに分けた比較）
# ================================================================
print("\n■ 図4: 気温帯別 死亡率箱ひげ図")

# 2022年の全都道府県を年平均気温で3グループに分類（qcut）
df_box = df_2022[['都道府県', 'B4101', '粗死亡率', '高齢化率']].dropna().copy()
df_box['気温帯'] = pd.qcut(df_box['B4101'], q=3,
                            labels=['低温帯\n（寒冷地）', '中温帯\n（温暖地）', '高温帯\n（温暖地）'])

# グループ統計
print("\n  気温帯別 統計:")
grp_stat = df_box.groupby('気温帯', observed=True).agg(
    n=('粗死亡率', 'count'),
    気温平均=('B4101', 'mean'),
    死亡率平均=('粗死亡率', 'mean'),
    死亡率中央値=('粗死亡率', 'median'),
    高齢化率平均=('高齢化率', 'mean'),
)
print(grp_stat.to_string())

# ANOVA検定
groups_by_temp = [g['粗死亡率'].values for _, g in df_box.groupby('気温帯', observed=True)]
f_stat, p_anova = stats.f_oneway(*groups_by_temp)
print(f"\n  一元配置ANOVA: F={f_stat:.3f}, p={p_anova:.4f}")

# 気温帯ラベルと統計情報
temp_labels = df_box.groupby('気温帯', observed=True)['B4101'].mean()
box_colors = ['#1565C0', '#43A047', '#E65100']

fig4, ax4 = plt.subplots(figsize=(9, 6))

# 気温帯ごとにデータ収集
categories = df_box['気温帯'].cat.categories
data_groups = [df_box[df_box['気温帯'] == cat]['粗死亡率'].values for cat in categories]
n_groups    = [len(g) for g in data_groups]
mean_vals   = [g.mean() for g in data_groups]
temp_means  = [df_box[df_box['気温帯'] == cat]['B4101'].mean() for cat in categories]

bp = ax4.boxplot(data_groups,
                 labels=[str(c) for c in categories],
                 patch_artist=True,
                 widths=0.55,
                 medianprops=dict(color='white', linewidth=2.5),
                 boxprops=dict(linewidth=1.5),
                 whiskerprops=dict(linewidth=1.5),
                 capprops=dict(linewidth=1.5),
                 flierprops=dict(marker='o', markersize=5, alpha=0.6))

for patch, color in zip(bp['boxes'], box_colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.78)

# 平均値をダイヤモンドでプロット
for i, mv in enumerate(mean_vals):
    ax4.scatter(i + 1, mv, marker='D', color='white', s=55, zorder=5,
                edgecolors=box_colors[i], linewidths=1.5)

# 各グループの n数・平均気温を注記
for i, (n, tm) in enumerate(zip(n_groups, temp_means)):
    ax4.text(i + 1, ax4.get_ylim()[0] + 0.15,
             f'n={n}\n平均気温\n{tm:.1f}℃',
             ha='center', va='bottom', fontsize=8.5, color='#333')

# ANOVA p値
sig_txt = f'一元配置ANOVA: F={f_stat:.2f}, p={p_anova:.4f}'
ax4.set_xlabel('気温帯（年平均気温の三分位）', fontsize=12)
ax4.set_ylabel('粗死亡率（人口千人あたり）', fontsize=12)
ax4.set_title(
    f'気温帯別 粗死亡率の分布（2022年, N=47都道府県）\n〜{sig_txt}〜',
    fontsize=12, fontweight='bold'
)
ax4.grid(axis='y', alpha=0.3)

# 凡例（菱形=平均値）
from matplotlib.lines import Line2D
legend_el = [Line2D([0], [0], marker='D', color='w',
                     markerfacecolor='gray', markeredgecolor='gray',
                     markersize=8, label='平均値')]
ax4.legend(handles=legend_el, fontsize=9, loc='upper right')
plt.tight_layout()
fig4.savefig(os.path.join(FIG_DIR, '2022_H5_12_fig4_boxplot.png'),
             bbox_inches='tight', dpi=150)
plt.close(fig4)
print("図4保存: 2022_H5_12_fig4_boxplot.png")

# ================================================================
# ■ 完了メッセージ
# ================================================================
print("\n" + "=" * 65)
print("■ 全図生成完了（4枚）")
print(f"  fig1_timeseries.png : 死亡率の時系列推移（地方別平均）")
print(f"  fig2_scatter.png    : 高齢化率 vs 死亡率 散布図（2022年）")
print(f"  fig3_coef.png       : OLS重回帰係数プロット")
print(f"  fig4_boxplot.png    : 気温帯別 死亡率箱ひげ図")
print("=" * 65)
print(f"\n  OLS結果サマリ（2022年, N={len(df_reg)}）:")
print(f"    R²={ols_model.rsquared:.3f}, adj.R²={ols_model.rsquared_adj:.3f}")
print(f"    高齢化率 β={ols_model.params[1]:.3f} (p={ols_model.pvalues[1]:.4f})")
print(f"    年平均気温 β={ols_model.params[2]:.3f} (p={ols_model.pvalues[2]:.4f})")
print(f"    最高気温   β={ols_model.params[3]:.3f} (p={ols_model.pvalues[3]:.4f})")
print(f"    降水日数   β={ols_model.params[4]:.3f} (p={ols_model.pvalues[4]:.4f})")
