"""
2025_H4 分析スクリプト: 介護負担の地域要因分析
==============================================================
【論文】介護職の離職率に影響する地域要因の分析と介護人材確保策の推進
        勝田花梨（江戸川学園取手高等学校）統計活用奨励賞（高校生）

【本スクリプトの方針】
  介護職離職率は公的な都道府県別機械可読データが存在しないため、
  SSDSE の実公的データのみで再構成する。
  目的変数として「介護・看護時間（分/日 per 人口）」(SSDSE-D-2023) を使用。
  これは住民が実際に介護・看護に費やす時間であり、地域の介護負担度を示す。

【手法】
  1. 相関分析（ヒートマップ）
  2. 重回帰分析（OLS）, VIF確認
  3. 箱ひげ図（分布確認）
  4. k-meansクラスタリング（k=3）

【変数】
  目的変数: 介護・看護時間（分/日）       ← SSDSE-D-2023
  説明変数:
    - 高齢化率（%）         = 65歳以上人口 / 総人口 × 100    (SSDSE-B-2026)
    - 医師数_10万対         = 医師数 / 総人口 × 100000       (SSDSE-E-2026)
    - 一般病院数_10万対     = 一般病院数 / 総人口 × 100000   (SSDSE-E-2026)
    - 1人当たり県民所得     （SSDSE-E-2026）
    - 消費支出              （SSDSE-E-2026）

【入力】
  data/raw/SSDSE-B-2026.csv（2022年度, 47都道府県）
  data/raw/SSDSE-E-2026.csv（横断面, 47都道府県）
  data/raw/SSDSE-D-2023.csv（生活時間調査）

【出力】
  html/figures/2025_H4_fig1_dist.png    ... 介護・看護時間の分布
  html/figures/2025_H4_fig2_heatmap.png ... 相関ヒートマップ
  html/figures/2025_H4_fig3_reg.png     ... 回帰係数プロット
  html/figures/2025_H4_fig4_cluster.png ... クラスター分析結果
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#     ・SSDSE-E-2026.csv
#       → data/raw/SSDSE-E-2026.csv に配置
#     ・SSDSE-D-2023.csv
#       → data/raw/SSDSE-D-2023.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#     （SSDSE-E（社会・人口統計体系 都道府県の指標2） の CSV をダウンロード）
#     （SSDSE-D（社会・人口統計体系 都道府県の指標） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2025_H4_katsuyo.py
# ============================================================


import os
import warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import statsmodels.api as sm
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from statsmodels.stats.outliers_influence import variance_inflation_factor

warnings.filterwarnings('ignore')

# ── パス設定 ──────────────────────────────────────────────────────────
DATA_DIR = 'data/raw'
FIG_DIR  = 'html/figures'
os.makedirs(FIG_DIR, exist_ok=True)

plt.rcParams.update({
    'font.family': 'Hiragino Sans',
    'axes.unicode_minus': False,
    'figure.dpi': 150,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# ── SSDSE-B-2026.csv 読み込み（2022年度, 47都道府県）────────────────
df_b_raw = pd.read_csv(
    os.path.join(DATA_DIR, 'SSDSE-B-2026.csv'),
    encoding='cp932', header=1
)
df_b_raw['年度'] = pd.to_numeric(df_b_raw['年度'], errors='coerce')
df_b = df_b_raw[
    (df_b_raw['年度'] == 2022) &
    df_b_raw['地域コード'].str.match(r'^R\d{5}$', na=False)
].copy().set_index('都道府県')

for col in ['総人口', '65歳以上人口', '15～64歳人口']:
    df_b[col] = pd.to_numeric(df_b[col], errors='coerce')

# ── SSDSE-E-2026.csv 読み込み（横断面）──────────────────────────────
df_e_raw = pd.read_csv(
    os.path.join(DATA_DIR, 'SSDSE-E-2026.csv'),
    encoding='cp932', header=1
)
df_e = df_e_raw.iloc[1:].copy()
df_e.columns = df_e_raw.iloc[0].values
df_e = df_e[df_e['都道府県'] != '全国'].reset_index(drop=True).set_index('都道府県')

for col in ['医師数', '総人口', '65歳以上人口', '一般病院数', '一般診療所数',
            '1人当たり県民所得（平成27年基準）', '消費支出（二人以上の世帯）',
            '総面積（北方地域及び竹島を除く）']:
    df_e[col] = pd.to_numeric(df_e[col], errors='coerce')

# ── SSDSE-D-2023.csv 読み込み（生活時間, 都道府県別 総数）──────────
df_d_raw = pd.read_csv(
    os.path.join(DATA_DIR, 'SSDSE-D-2023.csv'),
    encoding='cp932', header=1
)
df_d = df_d_raw[
    (df_d_raw['男女の別'] == '0_総数') &
    (df_d_raw['地域コード'] != 'R00000')
].copy().set_index('都道府県')
df_d['介護・看護'] = pd.to_numeric(df_d['介護・看護'], errors='coerce')

# ── 共通都道府県リスト（47都道府県）────────────────────────────────
PREFS = df_e.index.tolist()

# ── 変数の計算 ──────────────────────────────────────────────────────
df = pd.DataFrame(index=PREFS)
df.index.name = 'pref'

# 目的変数: 介護・看護時間（分/日）from SSDSE-D
df['介護看護時間'] = df_d['介護・看護']

# 説明変数
# 高齢化率 (SSDSE-B 2022)
df['高齢化率'] = df_b['65歳以上人口'] / df_b['総人口'] * 100

# 医師数_10万対 (SSDSE-E)
df['医師数_10万対'] = df_e['医師数'] / df_e['総人口'] * 100000

# 一般病院数_10万対 (SSDSE-E)
df['病院数_10万対'] = df_e['一般病院数'] / df_e['総人口'] * 100000

# 1人当たり県民所得 (SSDSE-E)
df['1人当たり県民所得'] = df_e['1人当たり県民所得（平成27年基準）']

# 消費支出 (SSDSE-E)
df['消費支出'] = df_e['消費支出（二人以上の世帯）']

df = df.dropna().reset_index()

PREFS_USED = df['pref'].tolist()
print(f"分析対象: {len(df)}都道府県")
print("=== 記述統計 ===")
ANALYSIS_COLS = ['介護看護時間', '高齢化率', '医師数_10万対', '病院数_10万対',
                 '1人当たり県民所得', '消費支出']
print(df[ANALYSIS_COLS].describe().round(2))

# ── 標準化 ──────────────────────────────────────────────────────────
FEAT_COLS = ['高齢化率', '医師数_10万対', '病院数_10万対', '1人当たり県民所得', '消費支出']
scaler = StandardScaler()
X_scaled_arr = scaler.fit_transform(df[FEAT_COLS].values)
df_scaled = pd.DataFrame(X_scaled_arr, columns=[c + '_z' for c in FEAT_COLS])
df = pd.concat([df.reset_index(drop=True), df_scaled], axis=1)

# ── 重回帰分析 ──────────────────────────────────────────────────────
Z_COLS = [c + '_z' for c in FEAT_COLS]
X_reg = sm.add_constant(df[Z_COLS])
y_reg = df['介護看護時間']
res_ols = sm.OLS(y_reg, X_reg).fit()

print("\n=== 重回帰分析結果 ===")
print(res_ols.summary2().tables[1].to_string())
print(f"R²={res_ols.rsquared:.3f}, adj.R²={res_ols.rsquared_adj:.3f}")

# VIF
vif_vals = [variance_inflation_factor(X_reg.values, i + 1) for i in range(len(FEAT_COLS))]
print("\n=== VIF ===")
for n, v in zip(FEAT_COLS, vif_vals):
    print(f"  {n}: {v:.2f}")

# ── Figure 1: 介護・看護時間の分布（箱ひげ図＋ヒストグラム）──────
fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))

ax = axes[0]
bp = ax.boxplot(df['介護看護時間'], patch_artist=True, vert=True,
                boxprops=dict(facecolor='#BBDEFB', color='#1565C0'),
                medianprops=dict(color='#E53935', lw=2.5),
                whiskerprops=dict(color='#1565C0'),
                capprops=dict(color='#1565C0'),
                flierprops=dict(marker='o', color='#E53935', alpha=0.7))
ax.set_xticklabels(['47都道府県'])
ax.set_ylabel('介護・看護時間（分/日）')
ax.set_title('介護・看護時間の分布（箱ひげ図）', fontsize=12, fontweight='bold')
med = df['介護看護時間'].median()
ax.text(1.12, med, f'中央値\n{med:.1f}分', va='center', fontsize=9, color='#E53935')

ax2 = axes[1]
ax2.hist(df['介護看護時間'], bins=12, color='#1565C0', alpha=0.75, edgecolor='white')
ax2.axvline(df['介護看護時間'].mean(), color='#E53935', ls='--', lw=2,
            label=f'平均 {df["介護看護時間"].mean():.1f}分')
ax2.axvline(df['介護看護時間'].median(), color='#FB8C00', ls='-', lw=2,
            label=f'中央値 {df["介護看護時間"].median():.1f}分')
ax2.set_xlabel('介護・看護時間（分/日）')
ax2.set_ylabel('都道府県数')
ax2.set_title('介護・看護時間の頻度分布', fontsize=12, fontweight='bold')
ax2.legend(fontsize=10)

plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2025_H4_fig1_dist.png'), bbox_inches='tight')
plt.close()
print("Figure 1 saved.")

# ── Figure 2: 相関ヒートマップ ──────────────────────────────────────
fig, ax = plt.subplots(figsize=(7, 6))

corr_cols = ['介護看護時間'] + FEAT_COLS
corr_labels = ['介護・看護\n時間', '高齢化率', '医師数\n10万対', '病院数\n10万対',
               '県民所得', '消費支出']
corr_mat = df[corr_cols].corr()

im = ax.imshow(corr_mat.values, cmap='RdBu_r', vmin=-1, vmax=1)
n_vars = len(corr_cols)
ax.set_xticks(range(n_vars))
ax.set_yticks(range(n_vars))
ax.set_xticklabels(corr_labels, fontsize=9)
ax.set_yticklabels(corr_labels, fontsize=9)

for i in range(n_vars):
    for j in range(n_vars):
        val = corr_mat.values[i, j]
        color = 'white' if abs(val) > 0.5 else 'black'
        ax.text(j, i, f'{val:.2f}', ha='center', va='center',
                fontsize=10, fontweight='bold', color=color)

plt.colorbar(im, ax=ax, label='相関係数', shrink=0.7)
ax.set_title('相関係数ヒートマップ\n（介護・看護時間と説明変数）',
             fontsize=12, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2025_H4_fig2_heatmap.png'), bbox_inches='tight')
plt.close()
print("Figure 2 saved.")

# ── Figure 3: 回帰係数と95%信頼区間 ────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

var_labels_reg = ['高齢化率', '医師数\n10万対', '病院数\n10万対', '1人当たり\n県民所得', '消費支出']
coefs = res_ols.params[1:].values
cis   = res_ols.conf_int().values[1:]
pvals = res_ols.pvalues[1:].values

ax = axes[0]
bar_colors = ['#E53935' if c < 0 else '#1E88E5' for c in coefs]
ax.barh(range(len(FEAT_COLS)), coefs, color=bar_colors, alpha=0.85, height=0.5)
ax.errorbar(coefs, range(len(FEAT_COLS)),
            xerr=[coefs - cis[:, 0], cis[:, 1] - coefs],
            fmt='none', color='black', capsize=6, lw=2)
ax.axvline(0, color='black', lw=1.2)
ax.set_yticks(range(len(FEAT_COLS)))
sig_marks = ['***' if p < 0.001 else ('**' if p < 0.01 else ('*' if p < 0.05 else ''))
             for p in pvals]
ax.set_yticklabels([f'{l} {m}' for l, m in zip(var_labels_reg, sig_marks)], fontsize=11)
ax.set_xlabel('標準化回帰係数')
ax.set_title('重回帰分析結果\n（標準化係数・95%CI）', fontsize=12, fontweight='bold')

neg_patch = mpatches.Patch(color='#E53935', alpha=0.85, label='負の効果')
pos_patch = mpatches.Patch(color='#1E88E5', alpha=0.85, label='正の効果')
ax.legend(handles=[neg_patch, pos_patch], fontsize=9, loc='lower right')

# VIF プロット
ax2 = axes[1]
colors_vif = ['#E53935' if v >= 10 else ('#FB8C00' if v >= 5 else '#43A047')
              for v in vif_vals]
ax2.barh(range(len(FEAT_COLS)), vif_vals, color=colors_vif, alpha=0.85, height=0.5)
ax2.axvline(10, color='#E53935', ls='--', lw=1.5, label='VIF=10 (要注意)')
ax2.axvline(5,  color='#FB8C00', ls='--', lw=1.5, label='VIF=5 (注意)')
ax2.set_yticks(range(len(FEAT_COLS)))
ax2.set_yticklabels(var_labels_reg, fontsize=11)
ax2.set_xlabel('VIF値')
ax2.set_title('多重共線性チェック（VIF）', fontsize=12, fontweight='bold')
ax2.legend(fontsize=9)

for i, v in enumerate(vif_vals):
    ax2.text(v + 0.02, i, f'{v:.2f}', va='center', fontsize=10)

plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2025_H4_fig3_reg.png'), bbox_inches='tight')
plt.close()
print("Figure 3 saved.")

# ── k-meansクラスタリング（k=3, 政策グループ）──────────────────────
X_cluster = df[Z_COLS].values
km = KMeans(n_clusters=3, n_init=50, max_iter=500)
km.fit(X_cluster)
df['cluster_raw'] = km.labels_

# 介護・看護時間の平均でクラスターを命名（高い順）
cl_mean = df.groupby('cluster_raw')['介護看護時間'].mean().sort_values(ascending=False)
rank_map = {cl_mean.index[0]: 0,   # 介護負担大
            cl_mean.index[1]: 1,   # 中
            cl_mean.index[2]: 2}   # 小
df['cluster'] = df['cluster_raw'].map(rank_map)
CLUSTER_NAMES  = {0: '介護負担大群\n（高齢化先行型）',
                  1: '介護負担中群\n（移行型）',
                  2: '介護負担小群\n（都市・若年型）'}
CLUSTER_COLORS = {0: '#E53935', 1: '#FB8C00', 2: '#43A047'}

print("\n=== クラスター別プロファイル ===")
print(df.groupby('cluster')[['介護看護時間', '高齢化率', '医師数_10万対']].mean().round(2))

# ── Figure 4: クラスター分析結果 ────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(13, 5.5))

# 左: 散布図（高齢化率 vs 介護・看護時間）
ax = axes[0]
for cl, name in CLUSTER_NAMES.items():
    mask = df['cluster'] == cl
    ax.scatter(df[mask]['高齢化率'], df[mask]['介護看護時間'],
               color=CLUSTER_COLORS[cl], label=name.replace('\n', ' '),
               s=80, alpha=0.8, edgecolors='white', lw=0.5)

ax.set_xlabel('高齢化率（%）')
ax.set_ylabel('介護・看護時間（分/日）')
ax.set_title('高齢化率と介護・看護時間の関係\n（k-meansクラスタリング）',
             fontsize=12, fontweight='bold')
ax.legend(fontsize=9, title='クラスター')

# 全体回帰線
z = np.polyfit(df['高齢化率'], df['介護看護時間'], 1)
p_line = np.poly1d(z)
x_line = np.linspace(df['高齢化率'].min(), df['高齢化率'].max(), 100)
ax.plot(x_line, p_line(x_line), '--', color='gray', lw=1.5, alpha=0.7,
        label='全体回帰線')
ax.legend(fontsize=9, title='クラスター')

# 注目都道府県
for _, row in df.iterrows():
    if row['pref'] in ['東京都', '秋田県', '島根県', '沖縄県']:
        ax.annotate(row['pref'].replace('県', '').replace('都', '').replace('府', ''),
                   (row['高齢化率'], row['介護看護時間']),
                   fontsize=8, fontweight='bold', color='#333',
                   xytext=(5, 4), textcoords='offset points',
                   arrowprops=dict(arrowstyle='->', color='#666', lw=0.7))

# 右: クラスター別変数プロファイル（棒グラフ）
ax2 = axes[1]
cl_profile = df.groupby('cluster')[
    ['介護看護時間'] + [c + '_z' for c in FEAT_COLS[:3]]
].mean()

x = np.arange(3)
width = 0.22
plot_vars = ['介護看護時間', '高齢化率_z', '医師数_10万対_z', '病院数_10万対_z']
pv_labels = ['介護・看護時間\n（実値/分）', '高齢化率\n（標準化）',
             '医師数10万対\n（標準化）', '病院数10万対\n（標準化）']
pv_colors = ['#1565C0', '#43A047', '#E53935', '#FB8C00']

for i, (var, lab, col) in enumerate(zip(plot_vars, pv_labels, pv_colors)):
    vals = [cl_profile.loc[c, var] if c in cl_profile.index else 0 for c in range(3)]
    ax2.bar(x + (i - 1.5) * width, vals, width=width,
            label=lab, color=col, alpha=0.8)

ax2.set_xticks(x)
ax2.set_xticklabels([CLUSTER_NAMES[c].replace('\n', ' ') for c in range(3)], fontsize=9)
ax2.axhline(0, color='black', lw=0.8)
ax2.set_title('クラスター別変数プロファイル', fontsize=12, fontweight='bold')
ax2.legend(fontsize=8, ncol=2)
ax2.set_ylabel('値（介護・看護時間: 分、その他: 標準化値）')

plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, '2025_H4_fig4_cluster.png'), bbox_inches='tight')
plt.close()
print("Figure 4 saved.")

print("\n=== クラスター別都道府県 ===")
for cl in range(3):
    prefs_in_cl = df[df['cluster'] == cl]['pref'].tolist()
    print(f"  {CLUSTER_NAMES[cl].replace(chr(10), '')}: {', '.join(prefs_in_cl)}")

print("\n分析完了。html/figures/ に図を保存しました。")
