"""
教育用再現コード: 2024年 統計データ分析コンペティション 審査員奨励賞（大学生・一般）
=================================================================================
論文タイトル：データ駆動型因果探索による投票率変動要因の解明
              ―都道府県データを用いたNOTEARSモデルの適用―
受賞区分  ：審査員奨励賞 ［大学生・一般の部］
著者      ：関屋百々花ほか（一橋大学社会学部・商学部）

【分析概要】
  データ：SSDSE-B-2026（都道府県別パネルデータ 2022年）+
          SSDSE-E-2026（都道府県別クロスセクション）
  目的   ：NOTEARSアルゴリズムで因果DAGを推定し、
            地域指標間の因果方向を解明する

  Step1. 実データ読み込み（SSDSE-B・E 2022年 47都道府県）
  Step2. 派生変数の作成（転入率・高齢化率・大学進学率・保育所千対など）
  Step3. NOTEARS（簡略版）でDAGを推定
  Step4. 重回帰分析との比較

【NOTEARSの簡略実装】
  勾配降下法でDAG制約（自己ループなし）を満たす隣接行列Wを推定。
  実際のNOTEARS（Zheng et al. 2018）は指数行列によるDAG制約を使用するが、
  本コードでは教育目的のため簡略版（L1正則化付き勾配降下）を使用。

【データ出典】
  独立行政法人統計センター「SSDSE（教育用標準データセット）」
  https://www.nstac.go.jp/use/literacy/ssdse/

【図の出力】
  html/figures/2024_U5_4_fig1_dag.png          ... DAGの可視化
  html/figures/2024_U5_4_fig2_adj_matrix.png   ... 隣接行列ヒートマップ
  html/figures/2024_U5_4_fig3_compare.png      ... 重回帰 vs 因果探索係数比較
  html/figures/2024_U5_4_fig4_scatter.png      ... 高齢化率vs転入率の散布図
=================================================================================
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#     ・SSDSE-E-2026.csv
#       → data/raw/SSDSE-E-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#     （SSDSE-E（社会・人口統計体系 都道府県の指標2） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2024_U5_4_shorei.py
# ============================================================


import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import warnings
warnings.filterwarnings('ignore')

try:
    import networkx as nx
    HAS_NX = True
except ImportError:
    HAS_NX = False

import statsmodels.api as sm
from scipy import stats as scipy_stats

# ── パス設定 ─────────────────────────────────────────────────────────────────
FIG_DIR = 'html/figures'
DATA_DIR = 'data/raw'
os.makedirs(FIG_DIR, exist_ok=True)

plt.rcParams.update({
    'font.family':        'Hiragino Sans',
    'axes.unicode_minus': False,
    'figure.dpi':         150,
    'axes.spines.top':    False,
    'axes.spines.right':  False,
})

# ═══════════════════════════════════════════════════════════════════════════════
# ■ Step1. 実データ読み込み（SSDSE-B 2022年 47都道府県）
# ═══════════════════════════════════════════════════════════════════════════════
print("=" * 60)
print("■ 実データ読み込み（SSDSE-B-2026 / SSDSE-E-2026）")
print("=" * 60)

YEAR = 2022

# SSDSE-B
df_b_raw = pd.read_csv(os.path.join(DATA_DIR, 'SSDSE-B-2026.csv'),
                        encoding='cp932', header=1)
df_b = df_b_raw[
    (df_b_raw['年度'] == YEAR) &
    df_b_raw['地域コード'].str.match(r'^R\d{5}$', na=False)
].copy()
df_b = df_b[df_b['地域コード'] != 'R00000'].reset_index(drop=True)

# SSDSE-E（クロスセクション）
df_e_raw = pd.read_csv(os.path.join(DATA_DIR, 'SSDSE-E-2026.csv'),
                        encoding='cp932', header=0)
df_e = df_e_raw.iloc[1:].copy()
df_e.columns = df_e_raw.iloc[0].values
df_e = df_e.iloc[1:].copy()
df_e.columns = df_e_raw.iloc[1].values
df_e = df_e[df_e['都道府県'] != '全国'].reset_index(drop=True)

print(f"SSDSE-B: {len(df_b)}都道府県 (年度={YEAR})")
print(f"SSDSE-E: {len(df_e)}都道府県")

# ── 数値変換 ──
num_cols_b = ['総人口', '65歳以上人口', '合計特殊出生率',
              '転入者数（日本人移動者）', '転出者数（日本人移動者）',
              '保育所等数', '年平均気温',
              '高等学校卒業者数', '高等学校卒業者のうち進学者数']
for c in num_cols_b:
    df_b[c] = pd.to_numeric(df_b[c], errors='coerce')

num_cols_e = ['1人当たり県民所得（平成27年基準）', '医師数',
              '従業者数（民営）（医療、福祉）']
for c in num_cols_e:
    df_e[c] = pd.to_numeric(df_e[c], errors='coerce')

# ─ 派生変数の作成 ─
df_b['高齢化率'] = df_b['65歳以上人口'] / df_b['総人口']
df_b['転入率'] = (df_b['転入者数（日本人移動者）'] - df_b['転出者数（日本人移動者）']) / df_b['総人口'] * 1000
df_b['保育所千対'] = df_b['保育所等数'] / df_b['総人口'] * 10000
df_b['大学進学率'] = (df_b['高等学校卒業者のうち進学者数'] /
                    df_b['高等学校卒業者数'].replace(0, np.nan))

# SSDSE-B と SSDSE-E を都道府県でマージ
df_merged = df_b[['都道府県', '高齢化率', '転入率', '合計特殊出生率',
                   '保育所千対', '大学進学率', '年平均気温']].copy()

# 県民所得をEから追加
df_e_sub = df_e[['都道府県', '1人当たり県民所得（平成27年基準）']].copy()
df_e_sub.columns = ['都道府県', '県民所得']
df_merged = df_merged.merge(df_e_sub, on='都道府県', how='left')
df_merged['県民所得'] = pd.to_numeric(df_merged['県民所得'], errors='coerce')

# 変数選定（6変数）
VAR_NAMES = ['転入率', '高齢化率', '合計特殊出生率', '県民所得', '大学進学率', '保育所千対']
df_analysis = df_merged[VAR_NAMES].dropna()
PREFS = df_merged.loc[df_analysis.index, '都道府県'].values

print(f"\n分析対象: {len(df_analysis)}都道府県")
print(df_analysis.describe().round(3))

# 標準化
X = df_analysis.values.copy()
X_std = (X - X.mean(axis=0)) / X.std(axis=0)
d = X_std.shape[1]

# ═══════════════════════════════════════════════════════════════════════════════
# ■ Step2. NOTEARS簡略実装
# ═══════════════════════════════════════════════════════════════════════════════
def notears_linear(X, lambda1=0.05, max_iter=500, lr=0.005):
    """
    NOTEARSアルゴリズム簡略版（教育目的）
    X: 標準化済みデータ (N, d)
    returns: W (d, d) 隣接行列（W[i,j]はi→jの効果）
    論文: Zheng et al. 2018 "DAGs with NO TEARS"
    """
    n, d = X.shape
    W = np.zeros((d, d))
    for iteration in range(max_iter):
        residuals = X - X @ W
        grad = -X.T @ residuals / n
        W_new = W - lr * grad
        threshold = lambda1 * lr
        W_new = np.sign(W_new) * np.maximum(np.abs(W_new) - threshold, 0)
        np.fill_diagonal(W_new, 0)
        W = W_new
    W_sparse = W.copy()
    W_sparse[np.abs(W_sparse) < 0.03] = 0.0
    return W_sparse

print("\n■ NOTEARS（簡略版）実行中...")
W_est = notears_linear(X_std, lambda1=0.05, max_iter=500)

print("\n【推定された隣接行列】")
W_df = pd.DataFrame(W_est, index=VAR_NAMES, columns=VAR_NAMES)
print(W_df.round(3))

# ═══════════════════════════════════════════════════════════════════════════════
# ■ Step3. 重回帰分析（比較用）
# ═══════════════════════════════════════════════════════════════════════════════
X_reg = sm.add_constant(X_std[:, 1:])  # 転入率以外を説明変数
ols = sm.OLS(X_std[:, 0], X_reg).fit()  # 転入率を目的変数

print("\n【重回帰分析結果（目的変数：転入率）】")
print(ols.summary().tables[1])

ols_coefs = ols.params[1:]
notears_coefs = W_est[1:, 0]  # 各変数→転入率の効果

# ═══════════════════════════════════════════════════════════════════════════════
# ■ 図の生成（4枚）
# ═══════════════════════════════════════════════════════════════════════════════

COLORS = {'positive': '#1565C0', 'negative': '#C62828', 'neutral': '#90CAF9'}

# ────────────────────────────────────────────────────────────────────────────
# 図1：DAGの可視化
# ────────────────────────────────────────────────────────────────────────────
print("\n図1: DAGを作成中...")

fig1, ax1 = plt.subplots(figsize=(10, 8))
ax1.set_xlim(-1.5, 1.5)
ax1.set_ylim(-1.5, 1.5)
ax1.set_aspect('equal')
ax1.axis('off')

n_vars = len(VAR_NAMES)
angles = np.linspace(0, 2 * np.pi, n_vars, endpoint=False) + np.pi / 2
positions = {var: (np.cos(a), np.sin(a)) for var, a in zip(VAR_NAMES, angles)}

node_colors = {
    '転入率': '#E65100', '高齢化率': '#1565C0', '合計特殊出生率': '#2E7D32',
    '県民所得': '#6A1B9A', '大学進学率': '#00695C', '保育所千対': '#795548'
}
for var, (x, y) in positions.items():
    color = node_colors[var]
    circle = plt.Circle((x, y), 0.22, color=color, alpha=0.85, zorder=4)
    ax1.add_patch(circle)
    ax1.text(x, y, var, ha='center', va='center', fontsize=9, fontweight='bold',
             color='white', zorder=5)

threshold = 0.08
for i, src in enumerate(VAR_NAMES):
    for j, tgt in enumerate(VAR_NAMES):
        if abs(W_est[i, j]) > threshold:
            x1, y1 = positions[src]
            x2, y2 = positions[tgt]
            dx, dy = x2 - x1, y2 - y1
            dist = np.sqrt(dx**2 + dy**2)
            r = 0.22
            sx = x1 + r * dx / dist
            sy = y1 + r * dy / dist
            ex = x2 - r * dx / dist
            ey = y2 - r * dy / dist
            color = '#1565C0' if W_est[i, j] > 0 else '#C62828'
            lw = 1.5 + 2 * abs(W_est[i, j])
            ax1.annotate('', xy=(ex, ey), xytext=(sx, sy),
                         arrowprops=dict(arrowstyle='->', color=color,
                                          lw=lw, mutation_scale=15))
            mid_x = (sx + ex) / 2 + 0.05 * (-dy / dist)
            mid_y = (sy + ey) / 2 + 0.05 * (dx / dist)
            ax1.text(mid_x, mid_y, f'{W_est[i, j]:.2f}', fontsize=8,
                     color=color, ha='center', zorder=6)

ax1.set_title('NOTEARS推定による有向非巡回グラフ（DAG）\n都道府県指標の因果構造（SSDSE 2022年）',
              fontsize=13, fontweight='bold')
blue_patch = mpatches.Patch(color='#1565C0', label='正の因果効果')
red_patch = mpatches.Patch(color='#C62828', label='負の因果効果')
ax1.legend(handles=[blue_patch, red_patch], fontsize=10, loc='lower right')

plt.tight_layout()
fig1.savefig(os.path.join(FIG_DIR, '2024_U5_4_fig1_dag.png'), bbox_inches='tight', dpi=150)
plt.close(fig1)
print("  → 2024_U5_4_fig1_dag.png 保存完了")

# ────────────────────────────────────────────────────────────────────────────
# 図2：隣接行列ヒートマップ
# ────────────────────────────────────────────────────────────────────────────
print("図2: 隣接行列ヒートマップを作成中...")

fig2, ax2 = plt.subplots(figsize=(8, 7))
im = ax2.imshow(W_est, cmap='RdBu_r', vmin=-0.5, vmax=0.5, aspect='auto')
plt.colorbar(im, ax=ax2, label='因果効果の強さ')
ax2.set_xticks(range(d))
ax2.set_yticks(range(d))
ax2.set_xticklabels(VAR_NAMES, fontsize=9, rotation=20, ha='right')
ax2.set_yticklabels(VAR_NAMES, fontsize=9)
ax2.set_xlabel('効果を受ける変数（子ノード）', fontsize=11)
ax2.set_ylabel('効果を与える変数（親ノード）', fontsize=11)
ax2.set_title('NOTEARS推定隣接行列\nW[i,j]: i → j の因果効果', fontsize=12, fontweight='bold')
for i in range(d):
    for j in range(d):
        val = W_est[i, j]
        if abs(val) > 0.05:
            ax2.text(j, i, f'{val:.2f}', ha='center', va='center',
                     fontsize=9, fontweight='bold',
                     color='white' if abs(val) > 0.3 else 'black')
        else:
            ax2.text(j, i, '0', ha='center', va='center', fontsize=8, color='#aaa')

plt.tight_layout()
fig2.savefig(os.path.join(FIG_DIR, '2024_U5_4_fig2_adj_matrix.png'), bbox_inches='tight', dpi=150)
plt.close(fig2)
print("  → 2024_U5_4_fig2_adj_matrix.png 保存完了")

# ────────────────────────────────────────────────────────────────────────────
# 図3：重回帰 vs 因果探索 係数比較
# ────────────────────────────────────────────────────────────────────────────
print("図3: 係数比較を作成中...")

var_labels = VAR_NAMES[1:]
x = np.arange(len(var_labels))
width = 0.35

fig3, ax3 = plt.subplots(figsize=(10, 6))
bars_ols = ax3.bar(x - width/2, ols_coefs, width, label='重回帰（OLS）',
                    color='#1565C0', alpha=0.82, edgecolor='white')
bars_nt = ax3.bar(x + width/2, notears_coefs, width, label='NOTEARS（因果探索）',
                   color='#E65100', alpha=0.82, edgecolor='white')
ax3.axhline(0, color='black', linewidth=0.8)
ax3.set_xticks(x)
ax3.set_xticklabels(var_labels, fontsize=10, rotation=10, ha='right')
ax3.set_ylabel('標準化係数', fontsize=12)
ax3.set_title('重回帰 vs NOTEARS（因果探索）\n転入率への効果の比較（SSDSE実データ）',
              fontsize=13, fontweight='bold')
ax3.legend(fontsize=11)
ax3.grid(axis='y', alpha=0.3)

for bar in bars_ols:
    h = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2, h + (0.01 if h >= 0 else -0.02),
             f'{h:.2f}', ha='center', va='bottom' if h >= 0 else 'top', fontsize=8.5, color='#1565C0')
for bar in bars_nt:
    h = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2, h + (0.01 if h >= 0 else -0.02),
             f'{h:.2f}', ha='center', va='bottom' if h >= 0 else 'top', fontsize=8.5, color='#E65100')

plt.tight_layout()
fig3.savefig(os.path.join(FIG_DIR, '2024_U5_4_fig3_compare.png'), bbox_inches='tight', dpi=150)
plt.close(fig3)
print("  → 2024_U5_4_fig3_compare.png 保存完了")

# ────────────────────────────────────────────────────────────────────────────
# 図4：高齢化率 vs 転入率 散布図
# ────────────────────────────────────────────────────────────────────────────
print("図4: 高齢化率 vs 転入率散布図を作成中...")

aging_rate = df_analysis['高齢化率'].values * 100   # %表示
net_inflow = df_analysis['転入率'].values

r_val, p_val = scipy_stats.pearsonr(aging_rate, net_inflow)

# 県民所得でカテゴリ分け
income_vals = df_analysis['県民所得'].values
income_cat = pd.qcut(income_vals, q=3, labels=['低所得圏', '中所得圏', '高所得圏'])
cat_colors = {'低所得圏': '#43A047', '中所得圏': '#FB8C00', '高所得圏': '#E53935'}

fig4, axes4 = plt.subplots(1, 2, figsize=(13, 5))
fig4.suptitle('高齢化率と転入率の関係（都道府県別実データ）', fontsize=13, fontweight='bold')

ax4a = axes4[0]
for cat, col in cat_colors.items():
    mask = income_cat == cat
    ax4a.scatter(aging_rate[mask], net_inflow[mask], c=col, alpha=0.75, s=40, label=cat)
coef4 = np.polyfit(aging_rate, net_inflow, 1)
x_fit = np.linspace(aging_rate.min(), aging_rate.max(), 100)
ax4a.plot(x_fit, np.polyval(coef4, x_fit), 'k-', linewidth=2,
          label=f'回帰直線 (β={coef4[0]:+.3f})')
# 代表的な都道府県にラベル
for i, pref in enumerate(PREFS):
    if pref in ['東京都', '沖縄県', '秋田県', '神奈川県', '大阪府']:
        ax4a.annotate(pref.replace('県','').replace('府','').replace('都','').replace('道',''),
                      (aging_rate[i], net_inflow[i]),
                      textcoords='offset points', xytext=(5, 3), fontsize=8, color='#333')
ax4a.set_xlabel('高齢化率（65歳以上割合 %）', fontsize=11)
ax4a.set_ylabel('転入率（転入超過 /千人）', fontsize=11)
ax4a.set_title(f'高齢化率 → 転入率\nr = {r_val:.3f}', fontsize=11, fontweight='bold')
ax4a.legend(fontsize=8, markerscale=2)
ax4a.grid(True, alpha=0.2)
ax4a.text(0.05, 0.95, f'r = {r_val:.3f}\np = {p_val:.3f}', transform=ax4a.transAxes,
          fontsize=10, va='top', bbox=dict(boxstyle='round', facecolor='#E3F2FD', alpha=0.8))

ax4b = axes4[1]
for cat, col in cat_colors.items():
    mask = income_cat == cat
    mean_aging = aging_rate[mask].mean()
    mean_inflow = net_inflow[mask].mean()
    ax4b.scatter(mean_aging, mean_inflow, c=col, s=200, alpha=0.9, zorder=5, label=cat)
    ax4b.annotate(f'{cat}\n高齢化率={mean_aging:.1f}%\n転入率={mean_inflow:.2f}‰',
                  (mean_aging, mean_inflow), textcoords='offset points', xytext=(10, 5),
                  fontsize=9, color=col, fontweight='bold')

ax4b.plot(x_fit, np.polyval(coef4, x_fit), 'k--', linewidth=1.5, alpha=0.5)
ax4b.set_xlabel('高齢化率（平均）', fontsize=11)
ax4b.set_ylabel('転入率（平均）', fontsize=11)
ax4b.set_title('県民所得カテゴリ別の平均値\n（グループ間比較）', fontsize=11, fontweight='bold')
ax4b.legend(fontsize=9)
ax4b.grid(True, alpha=0.2)

plt.tight_layout()
fig4.savefig(os.path.join(FIG_DIR, '2024_U5_4_fig4_scatter.png'), bbox_inches='tight', dpi=150)
plt.close(fig4)
print("  → 2024_U5_4_fig4_scatter.png 保存完了")

print("\n" + "=" * 60)
print("✓ 全図の生成完了（4枚）")
print("=" * 60)
print("\n【主要知見】")
print(f"  高齢化率→転入率: NOTEARS係数={W_est[1,0]:.3f}, OLS係数={ols_coefs[0]:.3f}")
print(f"  合計特殊出生率→転入率: NOTEARS係数={W_est[2,0]:.3f}, OLS係数={ols_coefs[1]:.3f}")
print(f"  県民所得→転入率: NOTEARS係数={W_est[3,0]:.3f}")
print(f"  相関係数（高齢化率×転入率）: r = {r_val:.3f}")
print(f"  使用データ: SSDSE-B-2026, SSDSE-E-2026 ({YEAR}年, {len(df_analysis)}都道府県)")
