"""
2022_U4 分析スクリプト: 生活系ごみ排出量の決定要因分析
==========================================================
【論文】生活系ごみ排出量と事業系ごみ排出量による回帰分析
        統計活用奨励賞（大学生・一般の部）

【手法】
  - 重回帰分析（OLS）: 1人1日当たりのごみ排出量を目的変数
  - 説明変数: 食料費割合、教養娯楽費割合、65歳以上割合、年平均気温、ごみのリサイクル率
  - VIF（分散拡大因子）による多重共線性チェック
  - 標準化回帰係数（β係数）による変数間の相対的重要度比較

【入力】
  data/raw/SSDSE-B-2026.csv  （都道府県別統計、2022年度 N=47）

【出力】
  html/figures/2022_U4_fig1_corr.png   相関ヒートマップ
  html/figures/2022_U4_fig2_vif.png    VIF棒グラフ
  html/figures/2022_U4_fig3_coef.png   標準化回帰係数プロット（95%CI付き）
  html/figures/2022_U4_fig4_scatter.png 消費支出 vs 1人1日ごみ排出量 散布図
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2022_U4_katsuyo.py
# ============================================================


import os
import warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from scipy import stats

warnings.filterwarnings('ignore')

# ── パス設定 ──────────────────────────────────────────────────────────
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
FIG_DIR = 'html/figures'
os.makedirs(FIG_DIR, exist_ok=True)

plt.rcParams.update({
    'font.family': 'Hiragino Sans',
    'axes.unicode_minus': False,
    'figure.dpi': 150,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# ── データ読み込み ─────────────────────────────────────────────────────
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
df_b = df_b[df_b['地域コード'].str.match(r'^R\d{5}$', na=False)].copy()
df_2022 = df_b[df_b['年度'] == 2022].copy()
df_2022 = df_2022.reset_index(drop=True)

print(f"データ読み込み完了: {len(df_2022)}都道府県（2022年度）")

# ── 変数作成 ──────────────────────────────────────────────────────────
# 消費支出に占める構成比（%）
df_2022['食料費割合']   = (df_2022['食料費（二人以上の世帯）']   / df_2022['消費支出（二人以上の世帯）']) * 100
df_2022['教養娯楽費割合'] = (df_2022['教養娯楽費（二人以上の世帯）'] / df_2022['消費支出（二人以上の世帯）']) * 100
df_2022['65歳以上割合']  = (df_2022['65歳以上人口'] / df_2022['総人口']) * 100

# 変数名リスト
TARGET  = '1人1日当たりの排出量'   # g/人/日
PRED_NAMES = {
    'food_ratio':    '食料費割合（%）',
    'culture_ratio': '教養娯楽費割合（%）',
    'elder_ratio':   '65歳以上割合（%）',
    'temp':          '年平均気温（℃）',
    'recycle':       'リサイクル率（%）',
}

df_2022['food_ratio']    = df_2022['食料費割合']
df_2022['culture_ratio'] = df_2022['教養娯楽費割合']
df_2022['elder_ratio']   = df_2022['65歳以上割合']
df_2022['temp']          = df_2022['年平均気温']
df_2022['recycle']       = df_2022['ごみのリサイクル率']
df_2022['y']             = df_2022[TARGET]

PREDS = list(PRED_NAMES.keys())
analysis_cols = ['都道府県', 'y'] + PREDS
df_ana = df_2022[analysis_cols].dropna().copy()

print(f"分析対象: {len(df_ana)}都道府県（欠損除外後）")
print(f"\n目的変数 ({TARGET}) 記述統計:")
print(df_ana['y'].describe().round(1))

# ── 記述統計 ──────────────────────────────────────────────────────────
print(f"\n説明変数 記述統計:")
for p, name in PRED_NAMES.items():
    s = df_ana[p]
    print(f"  {name}: mean={s.mean():.2f}, std={s.std():.2f}, "
          f"min={s.min():.2f}, max={s.max():.2f}")

# ── OLS 重回帰分析 ──────────────────────────────────────────────────
y  = df_ana['y'].values
X_raw = df_ana[PREDS].values

X_sm = sm.add_constant(X_raw)
model = sm.OLS(y, X_sm).fit()

print("\n" + "=" * 60)
print("■ OLS 重回帰分析結果")
print("=" * 60)
print(model.summary())

# ── VIF 計算（定数項を含めた行列で各説明変数のVIFを計算）──────────
# statsmodels の variance_inflation_factor は X の i番目の列を他の列で回帰する
# 正しい使い方: sm.add_constant した行列を渡し、定数列(index=0)を除く各列を計算
X_vif = sm.add_constant(X_raw)
vif_vals = []
for i in range(1, X_vif.shape[1]):   # index 0 は定数列
    vif_vals.append(variance_inflation_factor(X_vif, i))

print("\n■ VIF（分散拡大因子）")
for name, vif in zip(PRED_NAMES.values(), vif_vals):
    flag = "  ← 問題なし" if vif < 5 else ("  ← 要注意" if vif < 10 else "  ← 問題あり")
    print(f"  {name}: VIF = {vif:.2f}{flag}")

# ── 標準化回帰係数 ─────────────────────────────────────────────────
y_std  = (y - y.mean()) / y.std()
X_std  = (X_raw - X_raw.mean(axis=0)) / X_raw.std(axis=0)
X_std_sm = sm.add_constant(X_std)
model_std = sm.OLS(y_std, X_std_sm).fit()

beta_coef  = model_std.params[1:]   # 定数項を除く
_ci_all    = model_std.conf_int(alpha=0.05)       # ndarray shape (k+1, 2)
beta_ci    = _ci_all[1:]                           # 定数項を除く 95%CI
beta_pvals = model_std.pvalues[1:]

print("\n■ 標準化回帰係数（β係数）")
for name, beta, pval in zip(PRED_NAMES.values(), beta_coef, beta_pvals):
    sig = "**" if pval < 0.01 else ("*" if pval < 0.05 else "")
    print(f"  {name}: β = {beta:.4f}  (p={pval:.4f}) {sig}")

print(f"\n  R² = {model.rsquared:.4f},  Adj. R² = {model.rsquared_adj:.4f}")
print(f"  F統計量 = {model.fvalue:.2f},  p(F) = {model.f_pvalue:.4f}")

# ════════════════════════════════════════════════════════════════════
# Figure 1: 相関ヒートマップ
# ════════════════════════════════════════════════════════════════════
label_map = {'y': '1人1日\nごみ排出量'}
label_map.update({p: n.replace('（%）','').replace('（℃）','') for p, n in PRED_NAMES.items()})

corr_cols   = ['y'] + PREDS
corr_labels = [label_map[c] for c in corr_cols]
corr_mat    = df_ana[corr_cols].corr()

fig, ax = plt.subplots(figsize=(8, 6.5))
im = ax.imshow(corr_mat.values, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')

n = len(corr_cols)
ax.set_xticks(range(n))
ax.set_yticks(range(n))
ax.set_xticklabels(corr_labels, fontsize=10)
ax.set_yticklabels(corr_labels, fontsize=10)
plt.colorbar(im, ax=ax, shrink=0.85, label='相関係数 r')

for i in range(n):
    for j in range(n):
        v = corr_mat.values[i, j]
        c = 'white' if abs(v) > 0.6 else 'black'
        ax.text(j, i, f'{v:.2f}', ha='center', va='center', fontsize=10,
                color=c, fontweight='bold')

ax.set_title('図1：相関ヒートマップ\n（目的変数＋説明変数、2022年度 N=47都道府県）',
             fontsize=12, fontweight='bold', pad=12)
fig.tight_layout()
out1 = os.path.join(FIG_DIR, '2022_U4_fig1_corr.png')
fig.savefig(out1, bbox_inches='tight', dpi=150)
plt.close(fig)
print(f"\n[保存] {out1}")

# ════════════════════════════════════════════════════════════════════
# Figure 2: VIF 棒グラフ
# ════════════════════════════════════════════════════════════════════
short_names = [n.replace('（%）','').replace('（℃）','') for n in PRED_NAMES.values()]

colors_vif = ['#E53935' if v >= 10 else ('#FFA726' if v >= 5 else '#42A5F5')
              for v in vif_vals]

fig, ax = plt.subplots(figsize=(8, 4.5))
bars = ax.barh(short_names, vif_vals, color=colors_vif, edgecolor='white',
               linewidth=0.8, alpha=0.9)
ax.axvline(5,  color='#FFA726', linewidth=1.5, linestyle='--', label='VIF=5（注意水準）')
ax.axvline(10, color='#E53935', linewidth=1.5, linestyle='--', label='VIF=10（問題水準）')

for bar, v in zip(bars, vif_vals):
    ax.text(v + 0.05, bar.get_y() + bar.get_height()/2,
            f'{v:.2f}', va='center', fontsize=11, fontweight='bold')

ax.set_xlabel('VIF（分散拡大因子）', fontsize=11)
ax.set_title('図2：VIF による多重共線性チェック\n（VIF < 5: 問題なし、5–10: 要注意、≥ 10: 問題あり）',
             fontsize=12, fontweight='bold', pad=10)
ax.legend(fontsize=10, loc='lower right')
ax.set_xlim(0, max(vif_vals) * 1.2 + 0.5)
fig.tight_layout()
out2 = os.path.join(FIG_DIR, '2022_U4_fig2_vif.png')
fig.savefig(out2, bbox_inches='tight', dpi=150)
plt.close(fig)
print(f"[保存] {out2}")

# ════════════════════════════════════════════════════════════════════
# Figure 3: 標準化回帰係数プロット（95% CI付き）
# ════════════════════════════════════════════════════════════════════
# beta_ci は ndarray shape (n_vars, 2)  or DataFrame — 両対応で numpy化
beta_ci_arr = np.array(beta_ci)
ci_lo = beta_ci_arr[:, 0]
ci_hi = beta_ci_arr[:, 1]
err_lo = beta_coef - ci_lo
err_hi = ci_hi - beta_coef

# 係数の大きい順にソート
order   = np.argsort(np.abs(beta_coef))[::-1]
b_sorted   = beta_coef[order]
lo_sorted  = err_lo[order]
hi_sorted  = err_hi[order]
names_sorted = [short_names[i] for i in order]
sig_sorted   = np.array(beta_pvals)[order]

colors_coef = ['#C62828' if b > 0 else '#1565C0' for b in b_sorted]
alpha_vals  = [1.0 if p < 0.05 else 0.45 for p in sig_sorted]

fig, ax = plt.subplots(figsize=(8, 5))
for i, (name, b, lo, hi, col, al, pv) in enumerate(
        zip(names_sorted, b_sorted, lo_sorted, hi_sorted, colors_coef, alpha_vals, sig_sorted)):
    ax.barh(i, b, xerr=[[lo], [hi]], color=col, alpha=al, edgecolor='white',
            linewidth=0.8, capsize=5, error_kw={'elinewidth': 1.5, 'ecolor': '#555'})
    sig_mark = '**' if pv < 0.01 else ('*' if pv < 0.05 else 'n.s.')
    x_pos = b + (hi if b >= 0 else -lo) * 1.05
    ax.text(x_pos, i, f' {sig_mark}  β={b:.3f}',
            va='center', fontsize=9, color='#333')

ax.axvline(0, color='black', linewidth=0.8, linestyle='-')
ax.set_yticks(range(len(names_sorted)))
ax.set_yticklabels(names_sorted, fontsize=11)
ax.set_xlabel('標準化回帰係数（β）', fontsize=11)
ax.set_title('図3：標準化回帰係数プロット（95%信頼区間付き）\n'
             '赤: 正の効果, 青: 負の効果, 透明: 非有意（p≥0.05）',
             fontsize=12, fontweight='bold', pad=10)

red_patch  = mpatches.Patch(color='#C62828', label='正の効果（β > 0）')
blue_patch = mpatches.Patch(color='#1565C0', label='負の効果（β < 0）')
ax.legend(handles=[red_patch, blue_patch], fontsize=10, loc='lower right')

x_vals = np.concatenate([b_sorted - lo_sorted, b_sorted + hi_sorted])
x_margin = max(abs(x_vals)) * 0.25
ax.set_xlim(min(x_vals) - x_margin, max(x_vals) + x_margin)

fig.tight_layout()
out3 = os.path.join(FIG_DIR, '2022_U4_fig3_coef.png')
fig.savefig(out3, bbox_inches='tight', dpi=150)
plt.close(fig)
print(f"[保存] {out3}")

# ════════════════════════════════════════════════════════════════════
# Figure 4: 消費支出構成 vs 1人1日ごみ排出量 散布図（都道府県ラベル）
# ════════════════════════════════════════════════════════════════════
pref_labels = df_ana['都道府県'].values
y_vals      = df_ana['y'].values
food_vals   = df_ana['food_ratio'].values
culture_vals = df_ana['culture_ratio'].values

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for ax, x_vals, xlabel, x_col, label_idx in [
    (axes[0], food_vals,    '食料費割合（%）',   'food',    True),
    (axes[1], culture_vals, '教養娯楽費割合（%）', 'culture', True),
]:
    sc = ax.scatter(x_vals, y_vals, s=55, c=y_vals, cmap='RdYlGn_r',
                    vmin=y_vals.min() - 20, vmax=y_vals.max() + 20,
                    edgecolors='#555', linewidths=0.5, zorder=3, alpha=0.9)

    # 都道府県ラベル（主要都道府県のみ）
    # 上位5位・下位5位 + 外れ値候補
    y_rank = pd.Series(y_vals)
    top_idx = set(y_rank.nlargest(5).index) | set(y_rank.nsmallest(5).index)
    for i, (px, py, lbl) in enumerate(zip(x_vals, y_vals, pref_labels)):
        if i in top_idx:
            ax.annotate(lbl, (px, py), fontsize=7.5, alpha=0.9,
                        xytext=(4, 2), textcoords='offset points',
                        color='#222')

    # 回帰直線
    slope, intercept, r_val, p_val, se = stats.linregress(x_vals, y_vals)
    x_line = np.linspace(x_vals.min(), x_vals.max(), 100)
    ax.plot(x_line, slope * x_line + intercept,
            color='#E53935', linewidth=1.8, linestyle='--',
            label=f'回帰直線 (r={r_val:.3f}, p={p_val:.3f})')

    ax.set_xlabel(xlabel, fontsize=11)
    ax.set_ylabel('1人1日当たりの排出量（g/人/日）', fontsize=11)
    ax.legend(fontsize=10, loc='best')
    ax.grid(axis='y', alpha=0.3)

plt.colorbar(sc, ax=axes[1], label='1人1日排出量（g/人/日）', shrink=0.9)

fig.suptitle('図4：消費支出構成割合 と 1人1日ごみ排出量 の関係\n'
             '（2022年度 都道府県別 N=47、ラベル: 排出量上位・下位5府県）',
             fontsize=12, fontweight='bold', y=1.01)
fig.tight_layout()
out4 = os.path.join(FIG_DIR, '2022_U4_fig4_scatter.png')
fig.savefig(out4, bbox_inches='tight', dpi=150)
plt.close(fig)
print(f"[保存] {out4}")

# ── 結果サマリー ──────────────────────────────────────────────────────
print("\n" + "=" * 60)
print("■ 分析サマリー")
print("=" * 60)
print(f"  目的変数: {TARGET}")
print(f"  サンプル数: N = {len(df_ana)}")
print(f"  R² = {model.rsquared:.4f}  ({model.rsquared:.1%}の変動を説明)")
print(f"  Adj. R² = {model.rsquared_adj:.4f}")
print(f"  F統計量 = {model.fvalue:.2f},  p = {model.f_pvalue:.4f}")
print()
print("  【標準化回帰係数 上位変数】")
sorted_abs = sorted(zip(PRED_NAMES.values(), beta_coef, np.array(beta_pvals)),
                    key=lambda x: abs(x[1]), reverse=True)
for name, beta, pv in sorted_abs:
    direction = "↑ 増加" if beta > 0 else "↓ 減少"
    sig = "**" if pv < 0.01 else ("*" if pv < 0.05 else "(n.s.)")
    print(f"  {name}: β={beta:+.4f} → 排出量{direction}  {sig}")
print()
print("  【多重共線性】")
for name, vif in zip(PRED_NAMES.values(), vif_vals):
    status = "問題なし" if vif < 5 else ("要注意" if vif < 10 else "問題あり")
    print(f"  {name}: VIF={vif:.2f} ({status})")

print("\n■ 完了。全4図を html/figures/ に保存しました。")
