"""
2018_H3_suri.py
都道府県別高齢化率の時系列分析と将来予測：指数平滑法・Holt線形トレンドモデル
統計数理賞（高校生部門） 2018年度
教育用再現コード — 実データ: SSDSE-B-2026.csv
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2018_H3_suri.py
# ============================================================


import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import statsmodels.api as sm
from statsmodels.tsa.holtwinters import ExponentialSmoothing, SimpleExpSmoothing
from statsmodels.graphics.tsaplots import plot_acf
from sklearn.metrics import mean_squared_error, mean_absolute_error

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 150

FIG_DIR = 'html/figures'
DATA_B  = 'data/raw/SSDSE-B-2026.csv'
os.makedirs(FIG_DIR, exist_ok=True)

# ── データ読み込み ──────────────────────────────────────
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
df_b = df_b[df_b['地域コード'].str.match(r'^R\d{5}', na=False)].copy()
df_b['年度'] = df_b['年度'].astype(int)

print("Columns:", df_b.columns.tolist())
print("Years:", sorted(df_b['年度'].unique()))
print("Shape:", df_b.shape)

# 高齢化率を計算（65歳以上人口 / 総人口 × 100）
df_b['高齢化率'] = df_b['65歳以上人口'] / df_b['総人口'] * 100
aging_col = '高齢化率'

print(f"\n高齢化率 統計: \n{df_b[aging_col].describe()}")

# 都道府県名の正規化（「県」「都」「道」「府」を除く短縮名）
def shorten_pref(name):
    return name.replace('県','').replace('都','').replace('道','').replace('府','')

df_b['都道府県短'] = df_b['都道府県'].apply(shorten_pref)

# 地域マッピング（短縮名ベース）
region_map = {
    '北海道': '北海道・東北', '青森': '北海道・東北', '岩手': '北海道・東北', '宮城': '北海道・東北',
    '秋田': '北海道・東北', '山形': '北海道・東北', '福島': '北海道・東北',
    '茨城': '関東', '栃木': '関東', '群馬': '関東', '埼玉': '関東', '千葉': '関東', '東京': '関東', '神奈川': '関東',
    '新潟': '中部', '富山': '中部', '石川': '中部', '福井': '中部', '山梨': '中部',
    '長野': '中部', '岐阜': '中部', '静岡': '中部', '愛知': '中部',
    '三重': '近畿', '滋賀': '近畿', '京都': '近畿', '大阪': '近畿', '兵庫': '近畿', '奈良': '近畿', '和歌山': '近畿',
    '鳥取': '中国・四国', '島根': '中国・四国', '岡山': '中国・四国', '広島': '中国・四国',
    '山口': '中国・四国', '徳島': '中国・四国', '香川': '中国・四国', '愛媛': '中国・四国', '高知': '中国・四国',
    '福岡': '九州・沖縄', '佐賀': '九州・沖縄', '長崎': '九州・沖縄', '熊本': '九州・沖縄',
    '大分': '九州・沖縄', '宮崎': '九州・沖縄', '鹿児島': '九州・沖縄', '沖縄': '九州・沖縄'
}
region_colors = {
    '北海道・東北': '#4e9af1',
    '関東':         '#e05c5c',
    '中部':         '#f0a500',
    '近畿':         '#5cb85c',
    '中国・四国':   '#9b59b6',
    '九州・沖縄':   '#f39c12'
}

df_b['地域'] = df_b['都道府県短'].map(region_map).fillna('その他')

# ── 全国平均時系列 ──────────────────────────────────────
nat_avg = df_b.groupby('年度')[aging_col].mean().sort_index()
years = nat_avg.index.values
print(f"\n全国平均高齢化率（年度別）:\n{nat_avg}")

# ── Simple Exponential Smoothing (SES) ──────────────────
ses = SimpleExpSmoothing(nat_avg.values).fit(optimized=True)
ses_fitted   = ses.fittedvalues
ses_forecast = ses.forecast(3)
print(f"\nSES alpha={ses.params['smoothing_level']:.4f}")

# ── Holt 線形トレンドモデル ──────────────────────────────
holt = ExponentialSmoothing(nat_avg.values, trend='add').fit(optimized=True)
holt_fitted   = holt.fittedvalues
holt_forecast = holt.forecast(3)
print(f"Holt alpha={holt.params['smoothing_level']:.4f}, beta={holt.params['smoothing_trend']:.4f}")

# ── 予測精度 ────────────────────────────────────────────
rmse_ses  = np.sqrt(mean_squared_error(nat_avg.values[1:], ses_fitted[1:]))
rmse_holt = np.sqrt(mean_squared_error(nat_avg.values[1:], holt_fitted[1:]))
mae_ses   = mean_absolute_error(nat_avg.values[1:], ses_fitted[1:])
mae_holt  = mean_absolute_error(nat_avg.values[1:], holt_fitted[1:])
print(f"SES  RMSE={rmse_ses:.4f}, MAE={mae_ses:.4f}")
print(f"Holt RMSE={rmse_holt:.4f}, MAE={mae_holt:.4f}")

forecast_years = np.array([years[-1]+1, years[-1]+2, years[-1]+3])

# ══════════════════════════════════════════════════════════
# 図1：都道府県別高齢化率ランキング棒グラフ（最新年）
# ══════════════════════════════════════════════════════════
latest_year = int(nat_avg.index.max())
df_latest = df_b[df_b['年度'] == latest_year].copy()
df_latest = df_latest.sort_values(aging_col, ascending=True).reset_index(drop=True)
nat_mean_latest = nat_avg[latest_year]

fig1, ax1 = plt.subplots(figsize=(10, 12))

bar_colors = [region_colors.get(r, '#aaaaaa') for r in df_latest['地域']]
bars = ax1.barh(df_latest['都道府県短'], df_latest[aging_col],
                color=bar_colors, edgecolor='white', linewidth=0.4, height=0.75)

ax1.axvline(nat_mean_latest, color='black', linestyle='--', linewidth=1.8,
            label=f'全国平均 {nat_mean_latest:.1f}%')

# 凡例（地域色）
import matplotlib.patches as mpatches
legend_patches = [mpatches.Patch(color=c, label=r) for r, c in region_colors.items()]
legend_patches.append(mpatches.Patch(color='black', label=f'全国平均 {nat_mean_latest:.1f}%',
                                     fill=False, linestyle='--'))
ax1.legend(handles=legend_patches, loc='lower right', fontsize=9, framealpha=0.9)

ax1.set_xlabel('高齢化率（65歳以上人口比率）[%]', fontsize=11)
ax1.set_title(f'都道府県別高齢化率ランキング（{latest_year}年度）\n地域色分け・全国平均点線', fontsize=13, fontweight='bold')
ax1.set_xlim(0, max(df_latest[aging_col]) * 1.08)
ax1.xaxis.set_major_formatter(mticker.FormatStrFormatter('%.0f%%'))
ax1.grid(axis='x', linestyle=':', alpha=0.5)
fig1.tight_layout()
fig1.savefig(os.path.join(FIG_DIR, '2018_H3_fig1.png'), bbox_inches='tight')
plt.close(fig1)
print("Saved fig1")

# ══════════════════════════════════════════════════════════
# 図2：全国平均高齢化率 + SES / Holt フィット + 3年先予測
# ══════════════════════════════════════════════════════════
fig2, ax2 = plt.subplots(figsize=(11, 6))

ax2.plot(years, nat_avg.values, 'ko-', linewidth=2, markersize=6, label='実績値', zorder=5)

ax2.plot(years, ses_fitted, 'b-', linewidth=1.8, alpha=0.85,
         label=f'SES（α={ses.params["smoothing_level"]:.3f}）RMSE={rmse_ses:.3f}')
ax2.plot(forecast_years, ses_forecast, 'b--', linewidth=1.8, alpha=0.85)
ax2.plot([years[-1], forecast_years[0]], [ses_fitted[-1], ses_forecast[0]], 'b--', linewidth=1.8, alpha=0.85)

ax2.plot(years, holt_fitted, 'r-', linewidth=1.8, alpha=0.85,
         label=f'Holt線形（α={holt.params["smoothing_level"]:.3f}, β={holt.params["smoothing_trend"]:.3f}）RMSE={rmse_holt:.3f}')
ax2.plot(forecast_years, holt_forecast, 'r--', linewidth=1.8, alpha=0.85)
ax2.plot([years[-1], forecast_years[0]], [holt_fitted[-1], holt_forecast[0]], 'r--', linewidth=1.8, alpha=0.85)

# 予測区間帯
ax2.axvspan(years[-1], forecast_years[-1], alpha=0.06, color='gray', label='予測期間')
ax2.axvline(years[-1], color='gray', linestyle=':', linewidth=1)
ax2.text(years[-1]+0.1, ax2.get_ylim()[0] if ax2.get_ylim()[0] > 0 else 20,
         '← 実績  予測 →', fontsize=9, color='gray', va='bottom')

ax2.set_xlabel('年度', fontsize=11)
ax2.set_ylabel('高齢化率（全国平均）[%]', fontsize=11)
ax2.set_title('全国平均高齢化率の時系列分析\n指数平滑法（SES）・Holt線形トレンドモデル + 3年先予測', fontsize=13, fontweight='bold')
ax2.legend(fontsize=10, loc='upper left')
ax2.set_xticks(list(years) + list(forecast_years))
ax2.set_xticklabels([str(y) for y in years] + [f'{y}(予)' for y in forecast_years],
                     rotation=45, ha='right', fontsize=9)
ax2.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1f%%'))
ax2.grid(linestyle=':', alpha=0.5)
fig2.tight_layout()
fig2.savefig(os.path.join(FIG_DIR, '2018_H3_fig2.png'), bbox_inches='tight')
plt.close(fig2)
print("Saved fig2")

# ══════════════════════════════════════════════════════════
# 図3：ACF プロット（全国平均高齢化率）
# ══════════════════════════════════════════════════════════
fig3, ax3 = plt.subplots(figsize=(10, 4))
max_lags = min(8, len(nat_avg) - 2)
plot_acf(nat_avg.values, lags=max_lags, ax=ax3,
         title='自己相関関数 (ACF)：全国平均高齢化率', alpha=0.05)
ax3.set_xlabel('ラグ（年）', fontsize=11)
ax3.set_ylabel('自己相関係数', fontsize=11)
ax3.set_title('自己相関関数 (ACF)：全国平均高齢化率\n（高い自己相関は強いトレンドの存在を示す）',
              fontsize=12, fontweight='bold')

# 解説注釈
ax3.text(0.02, 0.05,
         '青い影：95%信頼区間\nバーが区間外 → 有意な自己相関あり',
         transform=ax3.transAxes, fontsize=9, color='#1565C0',
         va='bottom', bbox=dict(boxstyle='round,pad=0.3', facecolor='#EFF3FF', alpha=0.8))

fig3.tight_layout()
fig3.savefig(os.path.join(FIG_DIR, '2018_H3_fig3.png'), bbox_inches='tight')
plt.close(fig3)
print("Saved fig3")

# ══════════════════════════════════════════════════════════
# 図4：地域別高齢化率の時系列推移（6地域・折れ線グラフ）
# ══════════════════════════════════════════════════════════
df_region = df_b.copy()
region_ts = df_region.groupby(['年度', '地域'])[aging_col].mean().reset_index()

fig4, axes = plt.subplots(1, 2, figsize=(14, 6))
ax4L, ax4R = axes

# 左：時系列折れ線
for region, color in region_colors.items():
    sub = region_ts[region_ts['地域'] == region].sort_values('年度')
    if len(sub) == 0:
        continue
    ax4L.plot(sub['年度'], sub[aging_col], '-o', color=color, linewidth=2,
              markersize=5, label=region)

ax4L.set_xlabel('年度', fontsize=11)
ax4L.set_ylabel('高齢化率（地域平均）[%]', fontsize=11)
ax4L.set_title('6地域別 高齢化率の時系列推移\n（2012〜2023年）', fontsize=12, fontweight='bold')
ax4L.legend(fontsize=9, loc='upper left')
ax4L.set_xticks(sorted(df_b['年度'].unique()))
ax4L.set_xticklabels([str(y) for y in sorted(df_b['年度'].unique())],
                      rotation=45, ha='right', fontsize=8)
ax4L.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.0f%%'))
ax4L.grid(linestyle=':', alpha=0.5)

# 右：最新年の地域格差（棒グラフ）
latest_region = region_ts[region_ts['年度'] == latest_year].copy()
latest_region = latest_region.sort_values(aging_col, ascending=False)
bar_cols_r = [region_colors.get(r, '#aaaaaa') for r in latest_region['地域']]

ax4R.bar(latest_region['地域'], latest_region[aging_col],
         color=bar_cols_r, edgecolor='white', linewidth=0.4)
ax4R.axhline(nat_mean_latest, color='black', linestyle='--', linewidth=1.8,
             label=f'全国平均 {nat_mean_latest:.1f}%')
ax4R.set_xlabel('地域', fontsize=11)
ax4R.set_ylabel('高齢化率（地域平均）[%]', fontsize=11)
ax4R.set_title(f'{latest_year}年度の地域間格差\n（6地域別平均 vs 全国平均）', fontsize=12, fontweight='bold')
ax4R.legend(fontsize=10)
ax4R.set_xticklabels(latest_region['地域'], rotation=25, ha='right', fontsize=9)
ax4R.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.0f%%'))
ax4R.grid(axis='y', linestyle=':', alpha=0.5)

# 値ラベル
for bar, val in zip(ax4R.patches, latest_region[aging_col]):
    ax4R.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.15,
              f'{val:.1f}%', ha='center', va='bottom', fontsize=9, fontweight='bold')

fig4.suptitle('地域別高齢化率の時系列分析と地域格差', fontsize=14, fontweight='bold', y=1.01)
fig4.tight_layout()
fig4.savefig(os.path.join(FIG_DIR, '2018_H3_fig4.png'), bbox_inches='tight')
plt.close(fig4)
print("Saved fig4")

print("\nDONE: 2018_H3_suri")
print(f"SES  alpha={ses.params['smoothing_level']:.4f}  RMSE={rmse_ses:.4f}  MAE={mae_ses:.4f}")
print(f"Holt alpha={holt.params['smoothing_level']:.4f}, beta={holt.params['smoothing_trend']:.4f}  RMSE={rmse_holt:.4f}  MAE={mae_holt:.4f}")
print(f"SES  3年先予測: {ses_forecast}")
print(f"Holt 3年先予測: {holt_forecast}")
