"""
教育用再現コード: 2021年度 統計データ分析コンペティション 統計活用奨励賞 [大学生・一般の部]
=================================================================================
論文タイトル：医療資源の地域格差：医療機関・医師数の分布と決定要因分析
受賞：統計活用奨励賞（大学生・一般の部）

【分析概要】
  データ：SSDSE-B-2026.csv（都道府県別パネルデータ, 2012〜2023年度）
  対象：全47都道府県 × 最大12年（2012〜2023）

  医療機関密度の3指標を構築し、都道府県間の地域格差を可視化する。
  さらに高齢化率・消費支出・保健医療費を説明変数とするOLS回帰で決定要因を分析する。

【計算指標】
  病院密度         = 一般病院数 / 総人口 × 10,000（人口1万人当たり）
  診療所密度       = 一般診療所数 / 総人口 × 10,000
  歯科密度         = 歯科診療所数 / 総人口 × 10,000
  医療機関密度_総合 = (一般病院数 + 一般診療所数) / 総人口 × 10,000
  高齢化率         = 65歳以上人口 / 総人口 × 100
  消費支出_log     = log(消費支出（二人以上の世帯）)
  保健医療費_千円  = 保健医療費（二人以上の世帯）/ 1,000

【分析の流れ】
  Fig1: 医療機関密度の都道府県別ランキング（2022年、積み上げ棒グラフ）
  Fig2: 高齢化率 vs 医療機関密度_総合 散布図（2022年、都道府県ラベル付き）
  Fig3: OLS回帰係数プロット（標準化係数、95%信頼区間付き）
  Fig4: 地域別・医療機関密度_総合の時系列推移（2012〜2023）

【データ出典】
  SSDSE-B-2026.csv: 社会・人口統計体系（都道府県データ）
  変数コード:
    I510120 → 一般病院数
    I5102   → 一般診療所数
    I5103   → 歯科診療所数
    A1101   → 総人口
    A1303   → 65歳以上人口
    L3221   → 消費支出（二人以上の世帯）
    L322106 → 保健医療費（二人以上の世帯）

【データサイエンス学習ポイント】
  1. 人口当たり医療機関密度の計算（標準化指標の必要性）
  2. ローレンツ曲線とジニ係数（医療格差の定量化コード）
  3. 医療資源の決定要因（需要・供給両面の経済学的解釈）
  4. 医療政策の地域格差対策（統計から見える問題点）
=================================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2021_U4_katsuyo.py
# ============================================================


import os
import warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from scipy import stats

warnings.filterwarnings('ignore')

# ── パス設定 ──────────────────────────────────────────────────────────────────
FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

plt.rcParams.update({
    'font.family': 'Hiragino Sans',
    'axes.unicode_minus': False,
    'figure.dpi': 150,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# ── データ読み込み ────────────────────────────────────────────────────────────
raw = pd.read_csv(DATA_B, encoding='cp932', header=0)
# 行0はJapanese label row; 実データは行1以降
df_all = raw.iloc[1:].copy()
df_all.columns = raw.columns

YEAR_COL = df_all.columns[0]  # 'SSDSE-B-2026'

# 都道府県コードのみ（R01000〜R47000）
df_all = df_all[df_all['Code'].str.match(r'^R\d{5}$')].copy()

# 数値変換
NUM_COLS = ['I510120', 'I5102', 'I5103', 'A1101', 'A1303', 'L3221', 'L322106']
for col in NUM_COLS:
    df_all[col] = pd.to_numeric(df_all[col], errors='coerce')

df_all['year'] = df_all[YEAR_COL].astype(int)

# ── 派生指標の計算 ────────────────────────────────────────────────────────────
df_all['病院密度']          = df_all['I510120'] / df_all['A1101'] * 10000
df_all['診療所密度']        = df_all['I5102']   / df_all['A1101'] * 10000
df_all['歯科密度']          = df_all['I5103']   / df_all['A1101'] * 10000
df_all['医療機関密度_総合'] = (df_all['I510120'] + df_all['I5102']) / df_all['A1101'] * 10000
df_all['高齢化率']          = df_all['A1303']   / df_all['A1101'] * 100
df_all['消費支出_log']      = np.log(df_all['L3221'].clip(lower=1))
df_all['保健医療費_千円']   = df_all['L322106'] / 1000

print(f"読み込み完了: {len(df_all)}レコード, {df_all['year'].nunique()}年度, {df_all['Prefecture'].nunique()}都道府県")

# 2022年断面データ
d22 = df_all[df_all['year'] == 2022].copy()
d22 = d22.dropna(subset=['病院密度', '診療所密度', '歯科密度', '高齢化率']).sort_values('医療機関密度_総合', ascending=False)

# ────────────────────────────────────────────────────────────────────────────
# ジニ係数計算関数（参考値として表示）
# ────────────────────────────────────────────────────────────────────────────
def gini_coefficient(values):
    """ローレンツ曲線に基づくジニ係数を計算する。"""
    v = np.sort(np.array(values, dtype=float))
    n = len(v)
    idx = np.arange(1, n + 1)
    return (2 * np.dot(idx, v)) / (n * v.sum()) - (n + 1) / n

g_hosp  = gini_coefficient(d22['病院密度'].dropna())
g_clin  = gini_coefficient(d22['診療所密度'].dropna())
g_dent  = gini_coefficient(d22['歯科密度'].dropna())
g_total = gini_coefficient(d22['医療機関密度_総合'].dropna())
print(f"\nジニ係数 (2022年, 47都道府県):")
print(f"  病院密度:         {g_hosp:.4f}")
print(f"  診療所密度:       {g_clin:.4f}")
print(f"  歯科密度:         {g_dent:.4f}")
print(f"  医療機関密度_総合: {g_total:.4f}")

# ────────────────────────────────────────────────────────────────────────────
# Fig 1: 都道府県別ランキング（積み上げ棒グラフ）
# ────────────────────────────────────────────────────────────────────────────
fig1, ax1 = plt.subplots(figsize=(14, 8))

prefs    = d22['Prefecture'].values
hosp_v   = d22['病院密度'].values
clin_v   = d22['診療所密度'].values
dent_v   = d22['歯科密度'].values
x        = np.arange(len(prefs))
bar_w    = 0.72

b1 = ax1.bar(x, hosp_v, bar_w, label='一般病院', color='#1565C0', alpha=0.9)
b2 = ax1.bar(x, clin_v, bar_w, bottom=hosp_v, label='一般診療所', color='#42A5F5', alpha=0.9)
b3 = ax1.bar(x, dent_v, bar_w, bottom=hosp_v + clin_v, label='歯科診療所', color='#E65100', alpha=0.85)

ax1.set_xticks(x)
ax1.set_xticklabels(prefs, rotation=90, fontsize=8.5)
ax1.set_ylabel('人口1万人当たり医療機関数（機関）', fontsize=11)
ax1.set_title('Fig 1｜都道府県別 医療機関密度（2022年）\n病院・診療所・歯科の積み上げ',
              fontsize=13, fontweight='bold', pad=14)
ax1.legend(loc='upper right', fontsize=10)
ax1.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1f'))

# 全国平均線
nat_avg = (d22['医療機関密度_総合'] + d22['歯科密度']).mean()
ax1.axhline(nat_avg, color='gray', linestyle='--', linewidth=1.2, alpha=0.8)
ax1.text(len(prefs) - 0.5, nat_avg + 0.08, f'全国平均\n{nat_avg:.2f}', fontsize=8.5,
         color='gray', ha='right')

fig1.tight_layout()
fig1_path = os.path.join(FIG_DIR, '2021_U4_fig1_ranking.png')
fig1.savefig(fig1_path, dpi=150, bbox_inches='tight')
plt.close(fig1)
print(f"\nFig1 保存: {fig1_path}")

# ────────────────────────────────────────────────────────────────────────────
# Fig 2: 高齢化率 vs 医療機関密度 散布図
# ────────────────────────────────────────────────────────────────────────────
d22_sc = d22.dropna(subset=['高齢化率', '医療機関密度_総合'])

fig2, ax2 = plt.subplots(figsize=(10, 7))

# 地域区分（総務省8地域）
region_map = {
    '北海道': '北海道',
    '青森県': '東北', '岩手県': '東北', '宮城県': '東北', '秋田県': '東北',
    '山形県': '東北', '福島県': '東北',
    '茨城県': '関東', '栃木県': '関東', '群馬県': '関東', '埼玉県': '関東',
    '千葉県': '関東', '東京都': '関東', '神奈川県': '関東',
    '新潟県': '中部', '富山県': '中部', '石川県': '中部', '福井県': '中部',
    '山梨県': '中部', '長野県': '中部', '岐阜県': '中部', '静岡県': '中部', '愛知県': '中部',
    '三重県': '近畿', '滋賀県': '近畿', '京都府': '近畿', '大阪府': '近畿',
    '兵庫県': '近畿', '奈良県': '近畿', '和歌山県': '近畿',
    '鳥取県': '中国', '島根県': '中国', '岡山県': '中国', '広島県': '中国', '山口県': '中国',
    '徳島県': '四国', '香川県': '四国', '愛媛県': '四国', '高知県': '四国',
    '福岡県': '九州', '佐賀県': '九州', '長崎県': '九州', '熊本県': '九州',
    '大分県': '九州', '宮崎県': '九州', '鹿児島県': '九州', '沖縄県': '九州',
}
region_colors = {
    '北海道': '#1565C0', '東北': '#2E7D32', '関東': '#F57F17',
    '中部': '#6A1B9A', '近畿': '#C62828', '中国': '#00838F',
    '四国': '#E65100', '九州': '#37474F',
}

d22_sc = d22_sc.copy()
d22_sc['region'] = d22_sc['Prefecture'].map(region_map).fillna('その他')

plotted_regions = set()
for _, row in d22_sc.iterrows():
    reg = row['region']
    col = region_colors.get(reg, '#888888')
    label = reg if reg not in plotted_regions else '_nolegend_'
    ax2.scatter(row['高齢化率'], row['医療機関密度_総合'],
                color=col, s=60, alpha=0.85, label=label, zorder=3)
    plotted_regions.add(reg)
    ax2.annotate(row['Prefecture'].replace('県', '').replace('都', '').replace('府', '').replace('道', ''),
                 (row['高齢化率'], row['医療機関密度_総合']),
                 textcoords='offset points', xytext=(5, 3), fontsize=7.5, color='#333')

# 回帰直線
x_sc = d22_sc['高齢化率'].values
y_sc = d22_sc['医療機関密度_総合'].values
slope, intercept, r, p, _ = stats.linregress(x_sc, y_sc)
x_line = np.linspace(x_sc.min(), x_sc.max(), 100)
ax2.plot(x_line, slope * x_line + intercept, 'k--', linewidth=1.5, alpha=0.6,
         label=f'回帰直線 (r={r:.3f}, p={p:.3f})')

ax2.set_xlabel('高齢化率（65歳以上人口比率, %）', fontsize=11)
ax2.set_ylabel('医療機関密度_総合（人口1万人当たり）', fontsize=11)
ax2.set_title('Fig 2｜高齢化率 × 医療機関密度（2022年・47都道府県）',
              fontsize=13, fontweight='bold', pad=14)
ax2.legend(fontsize=9, loc='upper left', ncol=2)

fig2.tight_layout()
fig2_path = os.path.join(FIG_DIR, '2021_U4_fig2_scatter.png')
fig2.savefig(fig2_path, dpi=150, bbox_inches='tight')
plt.close(fig2)
print(f"Fig2 保存: {fig2_path}")

# ────────────────────────────────────────────────────────────────────────────
# Fig 3: OLS回帰係数プロット（標準化係数 + 95%CI）
# ────────────────────────────────────────────────────────────────────────────
# 2022年断面 OLS：被説明変数 = 医療機関密度_総合
d_ols = d22.dropna(subset=['医療機関密度_総合', '高齢化率', '消費支出_log', '保健医療費_千円']).copy()

def standardize(s):
    return (s - s.mean()) / s.std()

y_std = standardize(d_ols['医療機関密度_総合'])
X_dict = {
    '高齢化率': standardize(d_ols['高齢化率']),
    '消費支出(log)': standardize(d_ols['消費支出_log']),
    '保健医療費(千円)': standardize(d_ols['保健医療費_千円']),
}

coef_results = {}
for name, xvar in X_dict.items():
    X_mat = np.column_stack([np.ones(len(xvar)), xvar.values])
    beta, res, rank, sv = np.linalg.lstsq(X_mat, y_std.values, rcond=None)
    # 標準誤差
    n_, k_ = X_mat.shape
    y_hat = X_mat @ beta
    sigma2 = np.sum((y_std.values - y_hat)**2) / (n_ - k_)
    cov_mat = sigma2 * np.linalg.inv(X_mat.T @ X_mat)
    se = np.sqrt(np.diag(cov_mat))
    t_stat = beta[1] / se[1]
    p_val = 2 * stats.t.sf(abs(t_stat), df=n_ - k_)
    ci95 = 1.96 * se[1]
    coef_results[name] = {'coef': beta[1], 'ci': ci95, 'p': p_val}

# 複合OLS（全変数投入）
X_all = np.column_stack([
    np.ones(len(d_ols)),
    X_dict['高齢化率'].values,
    X_dict['消費支出(log)'].values,
    X_dict['保健医療費(千円)'].values,
])
beta_all, _, _, _ = np.linalg.lstsq(X_all, y_std.values, rcond=None)
n_a, k_a = X_all.shape
y_hat_a = X_all @ beta_all
sigma2_a = np.sum((y_std.values - y_hat_a)**2) / (n_a - k_a)
cov_all = sigma2_a * np.linalg.inv(X_all.T @ X_all)
se_all = np.sqrt(np.diag(cov_all))

multi_coefs = {
    '高齢化率': (beta_all[1], 1.96 * se_all[1],
                 2 * stats.t.sf(abs(beta_all[1] / se_all[1]), df=n_a - k_a)),
    '消費支出(log)': (beta_all[2], 1.96 * se_all[2],
                     2 * stats.t.sf(abs(beta_all[2] / se_all[2]), df=n_a - k_a)),
    '保健医療費(千円)': (beta_all[3], 1.96 * se_all[3],
                       2 * stats.t.sf(abs(beta_all[3] / se_all[3]), df=n_a - k_a)),
}

print("\n--- OLS結果 (標準化係数, N=47都道府県, 2022年) ---")
for vname, (c, ci, p) in multi_coefs.items():
    sig = '***' if p < 0.001 else ('**' if p < 0.01 else ('*' if p < 0.05 else 'n.s.'))
    print(f"  {vname:18s}: β={c:+.4f}  95%CI[{c-ci:+.4f}, {c+ci:+.4f}]  p={p:.4f} {sig}")

fig3, ax3 = plt.subplots(figsize=(8, 5))

var_labels = list(multi_coefs.keys())
coefs  = [multi_coefs[v][0] for v in var_labels]
cis    = [multi_coefs[v][1] for v in var_labels]
pvals  = [multi_coefs[v][2] for v in var_labels]
colors_ols = ['#1565C0' if c > 0 else '#C62828' for c in coefs]

y_pos = np.arange(len(var_labels))
ax3.barh(y_pos, coefs, xerr=cis, color=colors_ols, alpha=0.80,
         height=0.5, error_kw={'elinewidth': 2, 'capsize': 5, 'ecolor': '#333'})

for i, (c, p) in enumerate(zip(coefs, pvals)):
    sig_str = '***' if p < 0.001 else ('**' if p < 0.01 else ('*' if p < 0.05 else 'n.s.'))
    ax3.text(c + (cis[i] + 0.02) * np.sign(c), i, sig_str,
             va='center', fontsize=11, fontweight='bold',
             color='#333' if sig_str != 'n.s.' else '#999')

ax3.axvline(0, color='black', linewidth=1.0, linestyle='-')
ax3.set_yticks(y_pos)
ax3.set_yticklabels(var_labels, fontsize=11)
ax3.set_xlabel('標準化回帰係数（β）± 95%信頼区間', fontsize=11)
ax3.set_title('Fig 3｜医療機関密度の決定要因\nOLS標準化係数（2022年断面, N=47）', fontsize=13, fontweight='bold', pad=14)
ax3.text(0.99, 0.03, '***p<0.001  **p<0.01  *p<0.05  n.s.=有意でない',
         transform=ax3.transAxes, ha='right', va='bottom', fontsize=9, color='#555')

fig3.tight_layout()
fig3_path = os.path.join(FIG_DIR, '2021_U4_fig3_ols.png')
fig3.savefig(fig3_path, dpi=150, bbox_inches='tight')
plt.close(fig3)
print(f"Fig3 保存: {fig3_path}")

# ────────────────────────────────────────────────────────────────────────────
# Fig 4: 地域別 時系列推移（2012〜2023）
# ────────────────────────────────────────────────────────────────────────────
# 8地域集計
df_ts = df_all[df_all['year'].between(2012, 2023)].copy()
df_ts['region'] = df_ts['Prefecture'].map(region_map).fillna('その他')
df_ts = df_ts.dropna(subset=['医療機関密度_総合'])

region_ts = (
    df_ts.groupby(['year', 'region'])['医療機関密度_総合']
    .mean()
    .reset_index()
)

region_order = ['北海道', '東北', '関東', '中部', '近畿', '中国', '四国', '九州']
region_ls = ['-', '--', '-.', ':', '-', '--', '-.', ':']

fig4, ax4 = plt.subplots(figsize=(11, 6))

for reg, ls_ in zip(region_order, region_ls):
    sub = region_ts[region_ts['region'] == reg].sort_values('year')
    col = region_colors.get(reg, '#888')
    ax4.plot(sub['year'], sub['医療機関密度_総合'],
             marker='o', markersize=5, linewidth=2.2, linestyle=ls_,
             color=col, label=reg, alpha=0.9)

ax4.set_xlabel('年度', fontsize=11)
ax4.set_ylabel('医療機関密度_総合（人口1万人当たり）', fontsize=11)
ax4.set_title('Fig 4｜地域別 医療機関密度の時系列推移（2012〜2023年）\n8地域平均（一般病院＋一般診療所）',
              fontsize=13, fontweight='bold', pad=14)
ax4.set_xticks(range(2012, 2024))
ax4.xaxis.set_tick_params(rotation=45)
ax4.legend(loc='upper right', fontsize=10, ncol=2)
ax4.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1f'))

fig4.tight_layout()
fig4_path = os.path.join(FIG_DIR, '2021_U4_fig4_timeseries.png')
fig4.savefig(fig4_path, dpi=150, bbox_inches='tight')
plt.close(fig4)
print(f"Fig4 保存: {fig4_path}")

print("\n=== 全図の出力が完了しました ===")
print(f"出力先: {FIG_DIR}")
