"""
2024_U1_daijin.py
COVID-19の5類感染症移行後における宿泊者数損失の要因分析
総務大臣賞 [大学生・一般の部]
中江芙佳、緒方奏士、山本真大、佐々木大地（同志社大学文化情報学部）

実データ（SSDSE-B-2026）による教育用再現コード
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル:
#     ・SSDSE-B-2026.csv
#       → data/raw/SSDSE-B-2026.csv に配置
#     ・SSDSE-E-2026.csv
#       → data/raw/SSDSE-E-2026.csv に配置
#
#   ダウンロード先:
#     https://www.nstac.go.jp/use/literacy/ssdse/
#     （SSDSE-B（社会・人口統計体系 都道府県データ） の CSV をダウンロード）
#     （SSDSE-E（社会・人口統計体系 都道府県の指標2） の CSV をダウンロード）
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2024_U1_daijin.py
# ============================================================


import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from scipy import stats
from numpy.linalg import lstsq
from sklearn.preprocessing import StandardScaler

plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['axes.unicode_minus'] = False

import os
FIGDIR = os.path.normpath('html/figures') + os.sep
DATA_B = 'data/raw/SSDSE-B-2026.csv'
DATA_E = 'data/raw/SSDSE-E-2026.csv'
os.makedirs(FIGDIR, exist_ok=True)

# ----------------------------------------------------------------
# データ読み込み: SSDSE-B 2014-2023, 47都道府県
# ----------------------------------------------------------------
df_b = pd.read_csv(DATA_B, encoding='cp932', header=1)
mask = df_b['地域コード'].str.match(r'^R\d{5}$', na=False) & (df_b['地域コード'] != 'R00000')
df_b = df_b[mask].copy()
df_b = df_b[df_b['年度'].between(2014, 2023)].copy()

# 都道府県リスト（北海道順）
pref_order = (df_b[df_b['年度'] == 2023]
              .sort_values('地域コード')['都道府県'].tolist())

# 延べ宿泊者数のパネルデータ（都道府県×年度）
pivot = df_b.pivot_table(index='都道府県', columns='年度',
                         values='延べ宿泊者数', aggfunc='first')
pivot = pivot.loc[pref_order]

years = list(range(2014, 2024))

# ================================================================
# Holt-Winters 指数平滑化 (加法モデル, 季節なし年次データ)
# 2014-2019を学習 → 2020-2023を反事実推定
# ================================================================
def holt_winters_additive(data, alpha=0.4, beta=0.2):
    """Simple Holt linear (double exponential smoothing)"""
    n = len(data)
    L = np.zeros(n)
    B = np.zeros(n)
    L[0] = data[0]
    B[0] = data[1] - data[0]
    for t in range(1, n):
        L[t] = alpha * data[t] + (1 - alpha) * (L[t-1] + B[t-1])
        B[t] = beta * (L[t] - L[t-1]) + (1 - beta) * B[t-1]
    return L, B

train_years = list(range(2014, 2020))   # 6 years
predict_years = list(range(2020, 2024)) # 4 years

actual_all = pivot[years].values          # (47, 10)
cf_all = np.zeros_like(actual_all, dtype=float)

for i, pref in enumerate(pref_order):
    train_vals = pivot.loc[pref, train_years].values.astype(float)
    L, B = holt_winters_additive(train_vals)
    # training fit
    for j, yr in enumerate(train_years):
        cf_all[i, j] = train_vals[j]   # actual = counterfactual for pre-COVID
    # forecast 4 steps ahead
    L_last, B_last = L[-1], B[-1]
    for h, yr in enumerate(predict_years, start=1):
        cf_all[i, len(train_years) + h - 1] = L_last + h * B_last

# actual stays as real data
actual_total = actual_all.sum(axis=0)   # sum over 47 prefs
cf_total = cf_all.sum(axis=0)

# Loss: counterfactual - actual for 2020-2023 (indices 6-9)
loss_by_pref = (cf_all[:, 6:] - actual_all[:, 6:]).sum(axis=1)

# ================================================================
# 図1: 全国合計 実績 vs カウンターファクチャル（年次棒グラフ+線）
# ================================================================
fig, ax = plt.subplots(figsize=(10, 5))

years_arr = np.array(years)
ax.bar(years_arr, actual_total / 1e6, color='steelblue', alpha=0.6, label='実績値（延べ宿泊者数）')
ax.plot(years_arr, cf_total / 1e6, '--o', color='darkred', lw=2,
        label='カウンターファクチャル（Holt-Winters予測）')

# Loss shading
covid_idx = 6
for j in range(covid_idx, len(years)):
    ax.bar(years_arr[j], (cf_total[j] - actual_total[j]) / 1e6,
           bottom=actual_total[j] / 1e6, color='red', alpha=0.3)

ax.axvline(2019.5, color='red', lw=1.5, linestyle=':')
ax.axvline(2022.5, color='orange', lw=1.5, linestyle=':')
ax.text(2020.0, ax.get_ylim()[1] * 0.95, 'COVID-19\n感染拡大', fontsize=9, color='red', ha='center')
ax.text(2023.0, ax.get_ylim()[1] * 0.95, '5類移行\n(2023)', fontsize=9, color='orange', ha='center')

ax.set_title("図1: 全国延べ宿泊者数 実績 vs カウンターファクチャル（Holt-Winters予測）", fontsize=13)
ax.set_xlabel("年度")
ax.set_ylabel("延べ宿泊者数（百万人泊）")
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U1_fig1_trend.png", dpi=150)
plt.close()
print("fig1 saved")

# ================================================================
# 図2: 都道府県別宿泊者数損失（上位・下位）
# ================================================================
high_tourism = ['沖縄県', '東京都', '大阪府', '北海道', '京都府']

fig, ax = plt.subplots(figsize=(12, 8))
loss_series = pd.Series(loss_by_pref / 1e6, index=pref_order).sort_values(ascending=True)
colors = ['#C62828' if pref in high_tourism else '#1565C0' for pref in loss_series.index]

ax.barh(loss_series.index, loss_series.values, color=colors, alpha=0.8, edgecolor='white')
ax.axvline(loss_series.mean(), color='gray', lw=1.5, linestyle='--')
ax.text(loss_series.mean() + 0.5, 2, f'平均={loss_series.mean():.1f}', fontsize=9, color='gray')

red_patch = mpatches.Patch(color='#C62828', alpha=0.8, label='主要観光都道府県')
blue_patch = mpatches.Patch(color='#1565C0', alpha=0.8, label='その他')
ax.legend(handles=[red_patch, blue_patch], fontsize=10)

ax.set_title("図2: 都道府県別 COVID-19 延べ宿泊者数損失累計（2020-2023年）", fontsize=13)
ax.set_xlabel("損失宿泊者数（百万人泊）")
ax.grid(True, axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U1_fig2_loss.png", dpi=150)
plt.close()
print("fig2 saved")

# ================================================================
# 図3: 相関ヒートマップ（説明変数は実SSDSE-B/E変数）
# ================================================================
# SSDSE-E から都道府県特性を取得
df_e_raw = pd.read_csv(DATA_E, encoding='cp932', header=0)
df_e = df_e_raw.iloc[2:].copy()
df_e.columns = df_e_raw.iloc[1].values
df_e = df_e[df_e['都道府県'] != '全国'].reset_index(drop=True)

# SSDSE-B 2019年断面（コロナ前）
df_2019 = df_b[df_b['年度'] == 2019].set_index('都道府県')

# 外国人依存度 = 外国人延べ宿泊者数 / 延べ宿泊者数（2019）
df_2019_use = df_2019.reindex(pref_order)
foreign_ratio = (df_2019_use['外国人延べ宿泊者数'].astype(float) /
                 df_2019_use['延べ宿泊者数'].astype(float).replace(0, np.nan)).fillna(0)
hotel_count = df_2019_use['旅館営業施設数（ホテルを含む）'].astype(float)
hotel_rooms = df_2019_use['旅館営業施設客室数（ホテルを含む）'].astype(float)

# 人口規模
population_2019 = df_2019_use['総人口'].astype(float)

# SSDSE-E: 1人当たり県民所得
df_e_indexed = df_e.set_index('都道府県')
income_col = '1人当たり県民所得（平成27年基準）'

# align prefecture names
def align_pref(series, target_list):
    """pref_order と SSDSE-E の都道府県名を合わせる"""
    out = pd.Series(np.nan, index=target_list)
    for p in target_list:
        # exact match
        if p in series.index:
            out[p] = series[p]
        else:
            # try without 県/府/都/道 suffix
            short = p.rstrip('県府都道')
            matches = [k for k in series.index if k.startswith(short) or short in k]
            if matches:
                out[p] = series[matches[0]]
    return out

income_raw = pd.to_numeric(df_e_indexed[income_col], errors='coerce')
income = align_pref(income_raw, pref_order).fillna(income_raw.mean())

df_reg = pd.DataFrame({
    '宿泊者数損失（百万）': loss_by_pref / 1e6,
    '外国人依存度': foreign_ratio.values,
    '旅館施設数（千）': hotel_count.values / 1000,
    '旅館客室数（万）': hotel_rooms.values / 10000,
    '1人当たり所得（万）': income.values / 10000,
}, index=pref_order)

corr = df_reg.corr()

fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(corr.values, cmap='RdYlBu_r', vmin=-1, vmax=1, aspect='auto')
plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

labels = corr.columns.tolist()
ax.set_xticks(range(len(labels)))
ax.set_yticks(range(len(labels)))
ax.set_xticklabels(labels, rotation=35, ha='right', fontsize=9)
ax.set_yticklabels(labels, fontsize=9)

for i in range(len(labels)):
    for j in range(len(labels)):
        ax.text(j, i, f"{corr.values[i, j]:.2f}", ha='center', va='center',
                fontsize=9, color='black' if abs(corr.values[i, j]) < 0.7 else 'white')

ax.set_title("図3: 変数間の相関ヒートマップ", fontsize=13)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U1_fig3_corr.png", dpi=150)
plt.close()
print("fig3 saved")

# ================================================================
# 図4: 重回帰係数（標準化）
# ================================================================
feature_names = ['外国人依存度', '旅館施設数（千）', '旅館客室数（万）', '1人当たり所得（万）']
X = df_reg[feature_names].values
y = df_reg['宿泊者数損失（百万）'].values
n = len(y)

scaler = StandardScaler()
X_std = scaler.fit_transform(X)

X_with_const = np.column_stack([np.ones(n), X_std])
coef, _, _, _ = lstsq(X_with_const, y, rcond=None)
coefs = coef[1:]

residuals = y - X_with_const @ coef
sigma2 = residuals @ residuals / (n - X_with_const.shape[1])
cov = sigma2 * np.linalg.inv(X_with_const.T @ X_with_const)
se = np.sqrt(np.diag(cov))[1:]
t_stats = coefs / se
p_vals = 2 * (1 - stats.t.cdf(np.abs(t_stats), df=n - X_with_const.shape[1]))

fig, ax = plt.subplots(figsize=(8, 5))
colors_coef = ['#C62828' if c > 0 else '#1565C0' for c in coefs]
ax.barh(feature_names, coefs, color=colors_coef, alpha=0.8)

for i, (c, p, se_i) in enumerate(zip(coefs, p_vals, se)):
    sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else ""))
    ax.text(c + (0.05 if c >= 0 else -0.05), i, sig, va='center',
            ha='left' if c >= 0 else 'right', fontsize=12)
    ax.errorbar(c, i, xerr=1.96 * se_i, fmt='none', color='black', capsize=4)

ax.axvline(0, color='black', lw=1)
ax.set_title("図4: 重回帰分析 標準化係数（宿泊者数損失を従属変数）", fontsize=13)
ax.set_xlabel("標準化回帰係数（95% CI）")
ax.grid(True, axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig(FIGDIR + "2024_U1_fig4_coef.png", dpi=150)
plt.close()
print("fig4 saved")
print("All figures saved.")
