"""
教育用再現コード: 2024年 統計データ分析コンペティション 統計活用奨励賞（高校生）
=================================================================
論文タイトル：医療費削減に向けたスポーツ時間増加策のデータ分析
著者：過目今（法政大学国際高等学校）

【分析概要】
  データ：SSDSE-B-2026.csv（社会・人口統計体系）
          SSDSE-D-2023.csv（社会生活基本調査2021）

  Step1. 相関行列分析（生活時間変数と スポーツ行動者率・平均時間）
  Step2. 重回帰分析（年齢階層×性別：6モデル）
         年齢階層は都道府県の高齢化率三分位で代理設定
           青少年環境（高齢化率 低位 ）→ 若い住民構成の都道府県
           現役環境  （高齢化率 中位 ）→ 現役世代中心の都道府県
           シニア環境（高齢化率 高位 ）→ 高齢化が進んだ都道府県
         × 男性 / 女性（SSDSE-D の男女別データ）
  Step3. 標準化係数による変数間影響力比較（ヒートマップ）
  Step4. 睡眠時間 vs スポーツ時間 散布図（男女別）

  Key findings（実データから得られた傾向）:
    - 睡眠時間が長い都道府県ほど男性のスポーツ時間が長い傾向
    - 仕事時間が長い都道府県では男女ともスポーツ時間が短い傾向
    - 高齢化率が高い都道府県ほどスポーツ行動者率が低い傾向（男女共通）
    - 保健医療費が高い都道府県ではスポーツ行動者率が低い傾向

【データサイエンス学習ポイント】
  1. 年齢階層×性別の「格子状分析」デザイン
  2. 標準化係数による変数間影響力比較
  3. 都道府県の人口構成を年齢層代理変数として使う手法
  4. 実公的統計データ（SSDSE）の活用方法

【データ】SSDSE-B-2026.csv、SSDSE-D-2023.csv（実公的統計データ）
         合成データ・np.random は一切使用しない
=================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#     ・SSDSE-D-2023.csv
#       → data/raw/SSDSE-D-2023.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#     （SSDSE-D（社会・人口統計体系 都道府県の指標） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2024_H4_katsuyo.py
# ============================================================


import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from scipy import stats
import warnings
import os

warnings.filterwarnings('ignore')

# ──────────────────────────────────────────────────────────────
# 共通設定
# ──────────────────────────────────────────────────────────────
plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

DATA_DIR = 'data/raw'
FIG_DIR = 'html/figures'
os.makedirs(FIG_DIR, exist_ok=True)

# カラーパレット（年齢階層×性別）
COLORS = {
    '青少年_男': '#1565C0',
    '青少年_女': '#90CAF9',
    '現役_男':   '#2E7D32',
    '現役_女':   '#A5D6A7',
    'シニア_男': '#E65100',
    'シニア_女': '#FFCC80',
}

# ================================================================
# ■ Step 0. データ読み込みとマージ
# ================================================================

print("=" * 65)
print("■ Step 0. データ読み込み")
print("=" * 65)

# ─── SSDSE-D-2023（社会生活基本調査2021）────────────────────────
df_d = pd.read_csv(
    os.path.join(DATA_DIR, 'SSDSE-D-2023.csv'),
    encoding='cp932',
    header=1,
)
# 全国集計行を除外し、都道府県（47）のみ残す
df_d = df_d[df_d['地域コード'] != 'R00000'].copy()

# 数値列を変換（エラーは NaN）
time_use_cols = [
    '睡眠', '仕事', '学業', '家事',
    'スポーツ',                        # 平均時間（分/週全体の1日平均）
    '趣味・娯楽',
    'テレビ・ラジオ・新聞・雑誌',
    '休養・くつろぎ',
    '学習・自己啓発・訓練(学業以外)',
    '通勤・通学',
    '買い物',
    '交際・付き合い',
    'スポーツの総数',                  # 行動者率（%）
]
for col in time_use_cols:
    df_d[col] = pd.to_numeric(df_d[col], errors='coerce')

# 男女別に分割
df_d_male   = df_d[df_d['男女の別'] == '1_男'].reset_index(drop=True)
df_d_female = df_d[df_d['男女の別'] == '2_女'].reset_index(drop=True)
df_d_total  = df_d[df_d['男女の別'] == '0_総数'].reset_index(drop=True)

print(f"SSDSE-D: 男性 {len(df_d_male)} 都道府県, 女性 {len(df_d_female)} 都道府県")

# ─── SSDSE-B-2026（社会・人口統計体系）────────────────────────
df_b = pd.read_csv(
    os.path.join(DATA_DIR, 'SSDSE-B-2026.csv'),
    encoding='cp932',
    header=1,
)
# 2022年度データを使用
df_b = df_b[df_b['年度'] == 2022].copy()

# 必要な数値列を変換
b_numeric = [
    '総人口', '65歳以上人口', '15歳未満人口', '15～64歳人口',
    '保健医療費（二人以上の世帯）',
    '消費支出（二人以上の世帯）',
    '降水日数（年間）',
    '年平均気温',
]
for col in b_numeric:
    df_b[col] = pd.to_numeric(df_b[col], errors='coerce')

# 年齢構成比（年齢階層代理変数）
df_b['高齢化率']       = df_b['65歳以上人口']  / df_b['総人口'] * 100  # 65歳以上割合
df_b['若年人口率']     = df_b['15歳未満人口']  / df_b['総人口'] * 100  # 15歳未満割合
df_b['生産年齢人口率'] = df_b['15～64歳人口']  / df_b['総人口'] * 100  # 15-64歳割合
# 保健医療費を一人当たりに換算（千円）
df_b['保健医療費_千円'] = df_b['保健医療費（二人以上の世帯）'] / 1000

b_use_cols = [
    '地域コード',
    '高齢化率', '若年人口率', '生産年齢人口率',
    '保健医療費_千円',
    '降水日数（年間）',
    '年平均気温',
]
for col in ['高齢化率', '若年人口率', '生産年齢人口率', '保健医療費_千円',
            '降水日数（年間）', '年平均気温']:
    df_b[col] = pd.to_numeric(df_b[col], errors='coerce')

print(f"SSDSE-B: {len(df_b)} 都道府県（2022年度）")

# ─── マージ（男女各データ ← SSDSE-B）──────────────────────────
merged_m = df_d_male.merge(df_b[b_use_cols], on='地域コード', how='inner')
merged_f = df_d_female.merge(df_b[b_use_cols], on='地域コード', how='inner')

print(f"マージ後: 男性 {len(merged_m)} 都道府県, 女性 {len(merged_f)} 都道府県")

# ─── 年齢階層代理グループ（高齢化率三分位）──────────────────
q33 = merged_m['高齢化率'].quantile(1/3)
q67 = merged_m['高齢化率'].quantile(2/3)

def assign_age_group(df: pd.DataFrame, col: str = '高齢化率') -> pd.Series:
    """高齢化率の三分位でグループを割り当てる"""
    low  = df[col] <= q33
    high = df[col] >  q67
    mid  = ~low & ~high
    groups = pd.Series('現役', index=df.index)
    groups[low]  = '青少年'
    groups[high] = 'シニア'
    return groups

merged_m['年齢グループ'] = assign_age_group(merged_m)
merged_f['年齢グループ'] = assign_age_group(merged_f)

for g in ['青少年', '現役', 'シニア']:
    n = (merged_m['年齢グループ'] == g).sum()
    rng = merged_m.loc[merged_m['年齢グループ'] == g, '高齢化率']
    print(f"  {g}: {n}都道府県, 高齢化率 {rng.min():.1f}%–{rng.max():.1f}%")

# ─── 説明変数・目的変数の定義 ─────────────────────────────────
# 目的変数
Y_COL_RATE = 'スポーツの総数'   # 行動者率（%）
Y_COL_TIME = 'スポーツ'         # 平均時間（分/日）

# 説明変数（SSDSE-D 生活時間 + SSDSE-B 社会指標）
PRED_COLS = [
    '睡眠',                         # 睡眠時間（分/日）
    '仕事',                         # 仕事時間（分/日）
    'テレビ・ラジオ・新聞・雑誌',   # テレビ等（分/日）
    '趣味・娯楽',                   # 趣味・娯楽（分/日）
    '高齢化率',                     # 65歳以上割合（%）
    '保健医療費_千円',              # 保健医療費（千円）
    '降水日数（年間）',             # 降水日数（日）
]
PRED_LABELS = {
    '睡眠':                       '睡眠時間',
    '仕事':                       '仕事時間',
    'テレビ・ラジオ・新聞・雑誌': 'テレビ時間',
    '趣味・娯楽':                 '趣味・娯楽時間',
    '高齢化率':                   '高齢化率',
    '保健医療費_千円':            '保健医療費',
    '降水日数（年間）':           '降水日数',
}

# ================================================================
# ■ Step 1. 記述統計
# ================================================================

print("\n" + "=" * 65)
print("■ Step 1. 記述統計")
print("=" * 65)

desc_cols = [Y_COL_RATE, Y_COL_TIME] + PRED_COLS
print("\n男性（全47都道府県）:")
print(merged_m[desc_cols].describe().round(2).to_string())
print("\n女性（全47都道府県）:")
print(merged_f[desc_cols].describe().round(2).to_string())

# ================================================================
# ■ Step 2. VIF 確認
# ================================================================

print("\n" + "=" * 65)
print("■ Step 2. VIF（多重共線性確認）")
print("=" * 65)

def compute_vif(df: pd.DataFrame, cols: list) -> dict:
    """標準化後の VIF を計算する"""
    sub = df[cols].dropna()
    X_std = (sub - sub.mean()) / sub.std()
    X_arr = X_std.values
    vif = {}
    for i, c in enumerate(cols):
        vif[c] = variance_inflation_factor(X_arr, i)
    return vif

vif_m = compute_vif(merged_m, PRED_COLS)
print("\n男性 VIF:")
for col, v in vif_m.items():
    flag = ' ★多重共線性の可能性' if v > 5 else ''
    print(f"  {PRED_LABELS[col]:<18} VIF={v:.2f}{flag}")

# ================================================================
# ■ Step 3. 重回帰分析（6モデル：年齢層×性別）
# ================================================================

print("\n" + "=" * 65)
print("■ Step 3. 重回帰分析（6モデル）")
print("=" * 65)

AGE_GROUPS  = ['青少年', '現役', 'シニア']
GENDERS     = ['男性', '女性']
GENDER_KEYS = {'男性': merged_m, '女性': merged_f}
AGE_COLORS  = {'青少年': '#1565C0', '現役': '#2E7D32', 'シニア': '#E65100'}

reg_results: dict = {}   # key = (age_group, gender)

def run_ols(df_sub: pd.DataFrame, y_col: str, x_cols: list):
    """標準化 OLS を実行し statsmodels の結果を返す"""
    sub = df_sub[[y_col] + x_cols].dropna()
    y = sub[y_col]
    X_raw = sub[x_cols]
    # 標準化
    X_std = (X_raw - X_raw.mean()) / X_raw.std()
    X_fit = sm.add_constant(X_std)
    model = sm.OLS(y, X_fit).fit(cov_type='HC1')
    return model, sub

for age in AGE_GROUPS:
    for gender, merged_df in GENDER_KEYS.items():
        sub = merged_df[merged_df['年齢グループ'] == age].copy()
        model, sub_used = run_ols(sub, Y_COL_RATE, PRED_COLS)
        reg_results[(age, gender)] = {
            'model':   model,
            'sub':     sub_used,
            'n':       len(sub_used),
            'age':     age,
            'gender':  gender,
        }
        sig_vars = [PRED_LABELS[c] for c in PRED_COLS
                    if model.pvalues.get(c, 1) < 0.05]
        print(f"\n  {age}×{gender}: n={len(sub_used)}, R²={model.rsquared:.3f}, "
              f"adj.R²={model.rsquared_adj:.3f}")
        print(f"    有意な変数（p<0.05）: {sig_vars if sig_vars else 'なし'}")

# ================================================================
# ■ 図の生成（4枚）
# ================================================================

print("\n" + "=" * 65)
print("■ 図の生成（4枚）")
print("=" * 65)

# ─── 図1: 生活時間変数の相関行列（SSDSE-D 全体）──────────────
print("図1: 相関行列（生活時間変数）を作成中...")

# 相関を計算する変数（全47都道府県・男女総数）
corr_cols = [
    'スポーツの総数',
    'スポーツ',
    '睡眠',
    '仕事',
    'テレビ・ラジオ・新聞・雑誌',
    '趣味・娯楽',
    '休養・くつろぎ',
    '買い物',
    '交際・付き合い',
]
corr_labels = {
    'スポーツの総数':              'スポーツ\n行動者率',
    'スポーツ':                    'スポーツ\n平均時間',
    '睡眠':                        '睡眠',
    '仕事':                        '仕事',
    'テレビ・ラジオ・新聞・雑誌': 'テレビ等',
    '趣味・娯楽':                  '趣味\n娯楽',
    '休養・くつろぎ':              '休養',
    '買い物':                      '買い物',
    '交際・付き合い':              '交際',
}

# 男女総数データで相関を計算
corr_data = df_d_total[corr_cols].apply(pd.to_numeric, errors='coerce').dropna()
corr_mat = corr_data.corr()

fig1, ax1 = plt.subplots(figsize=(10, 8))
im1 = ax1.imshow(corr_mat.values, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
tick_labels = [corr_labels[c] for c in corr_cols]
ax1.set_xticks(range(len(corr_cols)))
ax1.set_xticklabels(tick_labels, fontsize=9)
ax1.set_yticks(range(len(corr_cols)))
ax1.set_yticklabels(tick_labels, fontsize=9)
for i in range(len(corr_cols)):
    for j in range(len(corr_cols)):
        val = corr_mat.values[i, j]
        txt_color = 'white' if abs(val) > 0.5 else 'black'
        ax1.text(j, i, f'{val:.2f}', ha='center', va='center',
                 fontsize=7.5, color=txt_color, fontweight='bold')
plt.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04,
             label='ピアソン相関係数')
ax1.set_title(
    '生活時間変数の相関行列\n'
    '（SSDSE-D 2023 社会生活基本調査2021 男女総数・47都道府県）',
    fontsize=11, fontweight='bold', pad=12,
)
plt.tight_layout()
fig1.savefig(os.path.join(FIG_DIR, '2024_H4_fig1_corr_matrix.png'),
             bbox_inches='tight', dpi=150)
plt.close(fig1)
print("  -> 2024_H4_fig1_corr_matrix.png 保存完了")

# ─── 図2: 6モデルの標準化回帰係数（3行×2列）─────────────────
print("図2: 6モデル回帰係数パネルを作成中...")

fig2, axes2 = plt.subplots(3, 2, figsize=(14, 13))
fig2.suptitle(
    '重回帰分析：年齢階層×性別の規定要因（標準化回帰係数）\n'
    '目的変数：スポーツ行動者率（%）  説明変数：標準化済み\n'
    '（出典：SSDSE-D 2023・SSDSE-B 2026 実データ）',
    fontsize=11, fontweight='bold',
)

gender_palette = {'男性': '#1565C0', '女性': '#C62828'}

for row_i, age in enumerate(AGE_GROUPS):
    for col_i, gender in enumerate(GENDERS):
        ax = axes2[row_i, col_i]
        res = reg_results[(age, gender)]
        model = res['model']
        n = res['n']

        coefs  = [model.params.get(c, 0)   for c in PRED_COLS]
        ses    = [model.bse.get(c, 0)       for c in PRED_COLS]
        pvals  = [model.pvalues.get(c, 1)   for c in PRED_COLS]
        labels = [PRED_LABELS[c]            for c in PRED_COLS]

        # 係数の大きさ順にソート
        order = sorted(range(len(coefs)), key=lambda i: coefs[i])
        c_sorted = [coefs[i]  for i in order]
        s_sorted = [ses[i]    for i in order]
        p_sorted = [pvals[i]  for i in order]
        l_sorted = [labels[i] for i in order]

        age_color = AGE_COLORS[age]
        bar_colors = [age_color if p < 0.05 else '#BDBDBD' for p in p_sorted]

        y_pos = range(len(PRED_COLS))
        ax.barh(y_pos, c_sorted,
                xerr=[1.96 * s for s in s_sorted],
                color=bar_colors, alpha=0.85, edgecolor='white',
                capsize=3, error_kw={'elinewidth': 1.2, 'ecolor': '#444'})
        ax.set_yticks(y_pos)
        ax.set_yticklabels(l_sorted, fontsize=8.5)
        ax.axvline(0, color='black', linewidth=0.9)
        ax.set_xlabel('標準化回帰係数（±95%CI）', fontsize=8)
        ax.set_title(
            f'{age}環境×{gender}  '
            f'n={n}  R²={model.rsquared:.3f}',
            fontsize=9, fontweight='bold', color=age_color,
        )
        ax.grid(axis='x', alpha=0.25)
        # 有意マーカー
        for yi, (c, p) in enumerate(zip(c_sorted, p_sorted)):
            if p < 0.05:
                ax.text(c + (0.05 if c >= 0 else -0.05), yi,
                        '*', ha='center', va='center',
                        fontsize=11, color=age_color, fontweight='bold')

# 凡例用パッチ
from matplotlib.patches import Patch
legend_elems = [
    Patch(facecolor='steelblue', label='有意（p<0.05）'),
    Patch(facecolor='#BDBDBD',   label='非有意（p≥0.05）'),
]
fig2.legend(handles=legend_elems, loc='lower center', ncol=2,
            fontsize=9, framealpha=0.9,
            bbox_to_anchor=(0.5, -0.01))

plt.tight_layout(rect=[0, 0.03, 1, 0.97])
fig2.savefig(os.path.join(FIG_DIR, '2024_H4_fig2_coef_panel.png'),
             bbox_inches='tight', dpi=150)
plt.close(fig2)
print("  -> 2024_H4_fig2_coef_panel.png 保存完了")

# ─── 図3: 睡眠時間 vs スポーツ時間（男女別散布図）──────────
print("図3: 睡眠時間 vs スポーツ時間 散布図を作成中...")

fig3, (ax3a, ax3b) = plt.subplots(1, 2, figsize=(13, 6))
fig3.suptitle(
    '睡眠時間 vs スポーツ時間（47都道府県）\n'
    '（出典：SSDSE-D 2023 社会生活基本調査2021）',
    fontsize=11, fontweight='bold',
)

def scatter_with_regression(ax, df, x_col, y_col, color, label, age_group_col=None):
    """散布図 + 回帰直線を描画し相関係数を表示する"""
    x = pd.to_numeric(df[x_col], errors='coerce')
    y = pd.to_numeric(df[y_col], errors='coerce')
    mask = x.notna() & y.notna()
    x, y = x[mask].values, y[mask].values

    # 年齢グループ別に色付け（都道府県ラベルなし）
    if age_group_col is not None and age_group_col in df.columns:
        age_colors_map = {'青少年': '#1565C0', '現役': '#2E7D32', 'シニア': '#E65100'}
        pt_colors = df.loc[mask, age_group_col].map(age_colors_map).fillna(color)
        ax.scatter(x, y, c=pt_colors, alpha=0.75, s=55, edgecolors='white', linewidth=0.5)
    else:
        ax.scatter(x, y, color=color, alpha=0.7, s=55, edgecolors='white', linewidth=0.5)

    # 都道府県名ラベル（一部のみ）
    pref_col = '都道府県' if '都道府県' in df.columns else None
    if pref_col:
        pref_vals = df.loc[mask, pref_col].reset_index(drop=True)
        for xi, yi, pref in zip(x, y, pref_vals):
            if pref in ['東京都', '大阪府', '愛知県', '北海道', '福岡県', '沖縄県',
                         '秋田県', '高知県', '山形県', '島根県']:
                ax.annotate(pref, (xi, yi), fontsize=7, color='#333',
                            xytext=(3, 3), textcoords='offset points')

    # 回帰直線
    slope, intercept, r, p_val, _ = stats.linregress(x, y)
    x_line = np.linspace(x.min(), x.max(), 100)
    ax.plot(x_line, intercept + slope * x_line, color='#333',
            linewidth=1.5, linestyle='--', alpha=0.8)

    r2 = r ** 2
    p_str = f'p={p_val:.3f}' if p_val >= 0.001 else 'p<0.001'
    ax.set_title(f'{label}\nr={r:.3f}, R²={r2:.3f}, {p_str}',
                 fontsize=10, fontweight='bold')
    ax.set_xlabel('睡眠時間（分/日）', fontsize=10)
    ax.set_ylabel('スポーツ時間（分/日）', fontsize=10)
    ax.grid(alpha=0.25)

scatter_with_regression(
    ax3a, merged_m, '睡眠', 'スポーツ',
    color='#1565C0', label='男性（全47都道府県）',
    age_group_col='年齢グループ',
)
scatter_with_regression(
    ax3b, merged_f, '睡眠', 'スポーツ',
    color='#C62828', label='女性（全47都道府県）',
    age_group_col='年齢グループ',
)

# 年齢グループ凡例
from matplotlib.patches import Patch as MPatch
legend_age = [
    MPatch(color='#1565C0', label='青少年環境（高齢化率 低位）'),
    MPatch(color='#2E7D32', label='現役環境（高齢化率 中位）'),
    MPatch(color='#E65100', label='シニア環境（高齢化率 高位）'),
]
ax3a.legend(handles=legend_age, fontsize=7.5, loc='upper left',
            title='年齢グループ代理', title_fontsize=7.5)
ax3b.legend(handles=legend_age, fontsize=7.5, loc='upper left',
            title='年齢グループ代理', title_fontsize=7.5)

plt.tight_layout()
fig3.savefig(os.path.join(FIG_DIR, '2024_H4_fig3_scatter_sleep_sport.png'),
             bbox_inches='tight', dpi=150)
plt.close(fig3)
print("  -> 2024_H4_fig3_scatter_sleep_sport.png 保存完了")

# ─── 図4: 標準化係数ヒートマップ（6グループ×7変数）───────────
print("図4: 標準化係数ヒートマップを作成中...")

# 標準化係数行列（rows=6グループ, cols=説明変数）
group_labels_ordered = [
    '青少年×男性', '青少年×女性',
    '現役×男性',   '現役×女性',
    'シニア×男性', 'シニア×女性',
]
group_keys_ordered = [
    ('青少年', '男性'), ('青少年', '女性'),
    ('現役',   '男性'), ('現役',   '女性'),
    ('シニア', '男性'), ('シニア', '女性'),
]

coef_matrix = np.zeros((len(group_keys_ordered), len(PRED_COLS)))
pval_matrix = np.ones((len(group_keys_ordered), len(PRED_COLS)))
r2_vals     = []

for gi, (age, gender) in enumerate(group_keys_ordered):
    res   = reg_results[(age, gender)]
    model = res['model']
    r2_vals.append(model.rsquared)
    for ci, col in enumerate(PRED_COLS):
        coef_matrix[gi, ci] = model.params.get(col, 0)
        pval_matrix[gi, ci] = model.pvalues.get(col, 1)

fig4, axes4 = plt.subplots(1, 2, figsize=(16, 5.5),
                            gridspec_kw={'width_ratios': [6, 1]})
fig4.suptitle(
    '標準化回帰係数ヒートマップ（6モデル×7説明変数）\n'
    '目的変数：スポーツ行動者率（%）\n'
    '（出典：SSDSE-D 2023・SSDSE-B 2026 実データ）',
    fontsize=11, fontweight='bold',
)

ax4 = axes4[0]
im4 = ax4.imshow(coef_matrix, cmap='RdBu_r', vmin=-1.5, vmax=1.5, aspect='auto')
ax4.set_xticks(range(len(PRED_COLS)))
ax4.set_xticklabels([PRED_LABELS[c] for c in PRED_COLS],
                    fontsize=9, rotation=25, ha='right')
ax4.set_yticks(range(len(group_labels_ordered)))
ax4.set_yticklabels(group_labels_ordered, fontsize=9)

for gi in range(len(group_keys_ordered)):
    for ci in range(len(PRED_COLS)):
        val = coef_matrix[gi, ci]
        p   = pval_matrix[gi, ci]
        txt_col = 'white' if abs(val) > 0.8 else 'black'
        cell_txt = f'{val:.2f}'
        if p < 0.05:
            cell_txt += '\n*'
        ax4.text(ci, gi, cell_txt, ha='center', va='center',
                 fontsize=7.5, color=txt_col, fontweight='bold')

# 横線で年齢グループを分ける
ax4.axhline(1.5, color='white', linewidth=2.5)
ax4.axhline(3.5, color='white', linewidth=2.5)

# 年齢グループラベル（左側）
age_group_labels = ['青少年\n環境', '現役\n環境', 'シニア\n環境']
for i, (row, lbl) in enumerate(zip([0.5, 2.5, 4.5], age_group_labels)):
    ax4.text(-0.75, row, lbl, ha='center', va='center',
             fontsize=8, color=list(AGE_COLORS.values())[i],
             fontweight='bold', transform=ax4.get_yaxis_transform())

plt.colorbar(im4, ax=ax4, fraction=0.04, pad=0.02,
             label='標準化回帰係数')
ax4.set_title('標準化係数（* p<0.05）', fontsize=10)

# R² バーグラフ（右パネル）
ax4r = axes4[1]
bar_colors_r2 = []
for age, gender in group_keys_ordered:
    base = AGE_COLORS[age]
    bar_colors_r2.append(base if gender == '男性' else base + '99')

y_pos = range(len(group_labels_ordered))
bars4 = []
for yi, (r2, (age, gender)) in enumerate(zip(r2_vals, group_keys_ordered)):
    alpha_val = 0.9 if gender == '男性' else 0.5
    b = ax4r.barh(yi, r2, color=AGE_COLORS[age], alpha=alpha_val, edgecolor='white')
    bars4.append(b[0])
ax4r.set_xlim(0, max(r2_vals) * 1.4 + 0.05)
ax4r.set_yticks(list(y_pos))
ax4r.set_yticklabels([''] * len(group_labels_ordered))
ax4r.set_xlabel('R²', fontsize=9)
ax4r.set_title('R²', fontsize=10)
ax4r.grid(axis='x', alpha=0.3)
for i, (bar, r2) in enumerate(zip(bars4, r2_vals)):
    ax4r.text(r2 + 0.01, bar.get_y() + bar.get_height() / 2,
              f'{r2:.2f}', va='center', fontsize=8)

plt.tight_layout()
fig4.savefig(os.path.join(FIG_DIR, '2024_H4_fig4_heatmap_coef.png'),
             bbox_inches='tight', dpi=150)
plt.close(fig4)
print("  -> 2024_H4_fig4_heatmap_coef.png 保存完了")

# ================================================================
# ■ 完了サマリ
# ================================================================

print("\n" + "=" * 65)
print("完了: 全図の生成完了（4枚）")
print("=" * 65)
print(f"\n保存先: {os.path.abspath(FIG_DIR)}")
print("  2024_H4_fig1_corr_matrix.png       - 生活時間変数の相関行列")
print("  2024_H4_fig2_coef_panel.png        - 6モデル回帰係数パネル")
print("  2024_H4_fig3_scatter_sleep_sport.png - 睡眠 vs スポーツ散布図")
print("  2024_H4_fig4_heatmap_coef.png      - 標準化係数ヒートマップ")
print()
print(f"使用データ: SSDSE-D-2023.csv（社会生活基本調査2021）")
print(f"            SSDSE-B-2026.csv（社会・人口統計体系 2022年度）")
print(f"合成データ: なし（np.random 未使用）")
