"""
2025_U4 分析スクリプト: Bresnahan-Reiss エントリーモデル
==========================================================
【論文】日本におけるMRI設置の現状と過剰導入の実証的検証
        川村結愛ら（大阪経済大学）統計活用奨励賞

【手法】
  - 順序プロビットモデル（Ordered Probit）でMRI台数ビンを推定
  - Bresnahan & Reiss (1991) 型必要人口閾値の推定
  - 競争効果（参入者数が増えると利潤が減少するか）の検証

【入力】
  data/2025_U4/2025_U4_panel.csv

【出力】
  html/figures/2025_U4_*.png
"""

# ============================================================
# 【データの準備】実行前に以下のデータファイルを用意してください
#
#   必要ファイル: data/raw/ 以下に SSDSE CSVを配置
#   ダウンロード先: https://www.nstac.go.jp/use/literacy/ssdse/
#
#   フォルダ配置（プロジェクトルートからの相対パス）:
#     code/                ← このスクリプトの場所
#     data/raw/            ← CSV ファイルをここに配置
#     html/figures/        ← 図の出力先（自動生成）
#
#   実行方法（ファイルを一切編集せず実行可能）:
#     python3 code/2025_U4_katsuyo.py
# ============================================================


import os
import warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from scipy import stats
from scipy.optimize import minimize
from scipy.special import ndtr   # Φ(x)

warnings.filterwarnings('ignore')

# ── パス設定 ──────────────────────────────────────────────────────────
DATA_DIR = 'data/2025_U4'
FIG_DIR  = 'html/figures'
os.makedirs(FIG_DIR, exist_ok=True)

plt.rcParams.update({
    'font.family': 'Hiragino Sans',
    'axes.unicode_minus': False,
    'figure.dpi': 150,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# ── データ読み込み ─────────────────────────────────────────────────────
df = pd.read_csv(os.path.join(DATA_DIR, '2025_U4_panel.csv'))
df = df.dropna(subset=['population', 'mri_total'])
df['ln_pop'] = np.log(df['population'].clip(lower=1))
df['ln_pop65'] = np.log((df['population'] * 0.28).clip(lower=1))  # 高齢化率約28%を推計に使用

print(f"データ: {len(df)}レコード, {df['year'].nunique()}年度")
print(f"MRI台数統計:\n{df['mri_total'].describe()}\n")

# ── Figure 1: MRI台数分布（年度別） ──────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharey=True)
YEARS = [2017, 2018, 2020]
colors = ['#2196F3', '#4CAF50', '#FF9800']
bins_label = ['0', '1-2', '3-5', '6-10', '11+']

for ax, year, col in zip(axes, YEARS, colors):
    d = df[df['year'] == year]
    counts = [
        (d['mri_total'] == 0).sum(),
        ((d['mri_total'] >= 1) & (d['mri_total'] <= 2)).sum(),
        ((d['mri_total'] >= 3) & (d['mri_total'] <= 5)).sum(),
        ((d['mri_total'] >= 6) & (d['mri_total'] <= 10)).sum(),
        (d['mri_total'] >= 11).sum(),
    ]
    bars = ax.bar(bins_label, counts, color=col, alpha=0.85, edgecolor='white', linewidth=0.8)
    for bar, cnt in zip(bars, counts):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                str(cnt), ha='center', va='bottom', fontsize=9)
    ax.set_title(f'{year}年度', fontsize=12, fontweight='bold')
    ax.set_xlabel('MRI台数（二次医療圏）', fontsize=10)
    ax.set_ylim(0, max(counts) * 1.18)

axes[0].set_ylabel('二次医療圏数', fontsize=10)
fig.suptitle('二次医療圏別MRI台数分布（2017〜2020年度）', fontsize=13, fontweight='bold', y=1.01)
fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, '2025_U4_dist.png'), bbox_inches='tight', dpi=150)
plt.close(fig)
print("Figure 1 保存: 2025_U4_dist.png")

# ── Figure 2: 人口 vs MRI台数 散布図（2020年） ───────────────────────
d20 = df[df['year'] == 2020].copy()

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# 左: 散布図（生値）
ax = axes[0]
sc = ax.scatter(d20['population'] / 1e4, d20['mri_total'],
                alpha=0.5, s=30, c=d20['mri_total'], cmap='YlOrRd',
                vmin=0, vmax=40, edgecolors='none')
plt.colorbar(sc, ax=ax, label='MRI台数')
ax.set_xlabel('人口（万人）', fontsize=11)
ax.set_ylabel('MRI台数', fontsize=11)
ax.set_title('人口規模とMRI台数（2020年）', fontsize=12, fontweight='bold')
ax.set_xlim(0, d20['population'].max() / 1e4 * 1.05)

# 右: 対数スケール
ax = axes[1]
ax.scatter(d20['ln_pop'], d20['mri_total'],
           alpha=0.5, s=30, c='#1976D2', edgecolors='none')
z = np.polyfit(d20['ln_pop'].dropna(), d20.loc[d20['ln_pop'].notna(), 'mri_total'], 1)
p = np.poly1d(z)
xr = np.linspace(d20['ln_pop'].min(), d20['ln_pop'].max(), 100)
ax.plot(xr, p(xr), 'r-', linewidth=1.5, label=f'線形近似 (β={z[0]:.2f})')
ax.set_xlabel('ln(人口)', fontsize=11)
ax.set_ylabel('MRI台数', fontsize=11)
ax.set_title('対数人口とMRI台数（2020年）', fontsize=12, fontweight='bold')
ax.legend(fontsize=10)

fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, '2025_U4_scatter.png'), bbox_inches='tight', dpi=150)
plt.close(fig)
print("Figure 2 保存: 2025_U4_scatter.png")

# ── 順序プロビット推定（Ordered Probit）──────────────────────────────
print("\n" + "=" * 55)
print("■ 順序プロビット推定（Bresnahan-Reiss型）")
print("=" * 55)

# 2020年データで推定
d = d20.copy()
d = d.dropna(subset=['mri_bin', 'ln_pop'])
y = d['mri_bin'].values.astype(int)   # 0,1,2,3,4
X = d['ln_pop'].values

# K=5カテゴリのordered probit
# P(Y=k) = Φ(μ_k - β*X) - Φ(μ_{k-1} - β*X)
# パラメータ: β, μ_1, δ_2=μ_2-μ_1>0, δ_3>0, δ_4>0

def op_loglik(params, y, X):
    beta = params[0]
    mu1  = params[1]
    d2   = np.exp(params[2])
    d3   = np.exp(params[3])
    d4   = np.exp(params[4])
    mu2  = mu1 + d2
    mu3  = mu2 + d3
    mu4  = mu3 + d4
    thresholds = [-np.inf, mu1, mu2, mu3, mu4, np.inf]
    lp = beta * X
    ll = 0.0
    for k in range(5):
        idx = y == k
        if idx.sum() == 0:
            continue
        p = ndtr(thresholds[k+1] - lp[idx]) - ndtr(thresholds[k] - lp[idx])
        p = np.clip(p, 1e-15, 1.0)
        ll += np.log(p).sum()
    return -ll

# 初期値
p0 = [1.0, 5.0, np.log(1.5), np.log(1.5), np.log(1.5)]
result = minimize(op_loglik, p0, args=(y, X),
                  method='Nelder-Mead',
                  options={'maxiter': 20000, 'xatol': 1e-8, 'fatol': 1e-8})

beta_hat = result.x[0]
mu1_hat  = result.x[1]
mu2_hat  = mu1_hat + np.exp(result.x[2])
mu3_hat  = mu2_hat + np.exp(result.x[3])
mu4_hat  = mu3_hat + np.exp(result.x[4])

print(f"\n  係数 β (ln_pop): {beta_hat:.4f}")
print(f"  閾値 μ₁: {mu1_hat:.4f}")
print(f"  閾値 μ₂: {mu2_hat:.4f}")
print(f"  閾値 μ₃: {mu3_hat:.4f}")
print(f"  閾値 μ₄: {mu4_hat:.4f}")
print(f"  対数尤度: {-result.fun:.2f}")

# 必要人口閾値（βが正なら μ_k/β で求まる）
# P(Y>=k) >= 0.5 となる X = μ_k / β
if beta_hat > 0:
    thresholds_pop = {
        1: np.exp(mu1_hat / beta_hat),
        2: np.exp(mu2_hat / beta_hat),
        3: np.exp(mu3_hat / beta_hat),
        4: np.exp(mu4_hat / beta_hat),
    }
    print(f"\n  【必要人口閾値（P(MRI≥N)=0.5）】")
    for n, pop in thresholds_pop.items():
        print(f"  MRI {n}台以上: {pop/1e4:.1f}万人")

    # 競争効果: 1台目→2台目の必要人口比 (BR比)
    # 1台あたりの必要人口増加率が1を超えると競争で利潤が圧縮されている
    ratios = {}
    pops = list(thresholds_pop.values())
    for i in range(1, 4):
        ratios[i+1] = pops[i] / pops[i-1]
    print(f"\n  【競争効果（BR比）: N台→N+1台の必要人口比】")
    for n, r in ratios.items():
        print(f"  {n-1}台→{n}台: {r:.3f} (>1.0 → 競争で利潤圧縮)")

# ── Figure 3: 必要人口閾値と予測確率 ─────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

# 左: 予測確率曲線
ax = axes[0]
ln_range = np.linspace(8, 14, 200)
pop_range = np.exp(ln_range) / 1e4
mus = [mu1_hat, mu2_hat, mu3_hat, mu4_hat]
mu_ext = [-np.inf] + mus + [np.inf]
labels_p = ['0台', '1-2台', '3-5台', '6-10台', '11台+']
cm = plt.cm.get_cmap('tab10')
for k in range(5):
    p_k = ndtr(mu_ext[k+1] - beta_hat * ln_range) - ndtr(mu_ext[k] - beta_hat * ln_range)
    ax.plot(pop_range, p_k, label=labels_p[k], linewidth=2, color=cm(k))
ax.set_xlabel('人口（万人）', fontsize=11)
ax.set_ylabel('確率', fontsize=11)
ax.set_title('MRI台数カテゴリの予測確率\n（順序プロビット）', fontsize=11, fontweight='bold')
ax.legend(fontsize=9, loc='center right')
ax.set_xlim(0, 250)
ax.set_ylim(0, 1)

# 右: 必要人口閾値の棒グラフ（競争効果可視化）
ax = axes[1]
if beta_hat > 0:
    ns = list(thresholds_pop.keys())
    pops_wan = [v / 1e4 for v in thresholds_pop.values()]
    bars = ax.bar([f'{n}台以上' for n in ns], pops_wan,
                  color=['#42A5F5', '#26C6DA', '#66BB6A', '#FFA726'],
                  edgecolor='white', linewidth=0.8, alpha=0.9)
    for bar, v in zip(bars, pops_wan):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                f'{v:.1f}万人', ha='center', va='bottom', fontsize=9)
    ax.set_ylabel('必要人口（万人）', fontsize=11)
    ax.set_title('MRI設置に必要な人口閾値\n（Bresnahan-Reiss推定値）',
                 fontsize=11, fontweight='bold')

fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, '2025_U4_br.png'), bbox_inches='tight', dpi=150)
plt.close(fig)
print("\nFigure 3 保存: 2025_U4_br.png")

# ── Figure 4: 地域別過剰設置マップ（散布図で代替）──────────────────
fig, ax = plt.subplots(figsize=(10, 6))

d20 = d20.copy()
d20['expected_bin'] = (beta_hat * d20['ln_pop'] - mu1_hat) / max(abs(beta_hat), 0.01)
d20['excess'] = d20['mri_total'] - d20['mri_bin']   # 実際 - ビン中央

# 人口 vs mri_per_100k（人口100万人あたりMRI台数）で過剰を可視化
sc = ax.scatter(d20['population'] / 1e4, d20['mri_per_100k'],
                c=d20['mri_per_100k'], cmap='RdYlGn_r',
                s=50, alpha=0.7, edgecolors='none',
                vmin=0, vmax=d20['mri_per_100k'].quantile(0.95))

plt.colorbar(sc, ax=ax, label='MRI台数/人口10万人')

# 日本平均線
avg_per100k = df[df['year'] == 2020]['mri_per_100k'].mean()
ax.axhline(avg_per100k, color='red', linewidth=1.5, linestyle='--',
           label=f'全国平均 {avg_per100k:.1f}台/10万人')

ax.set_xlabel('人口（万人）', fontsize=11)
ax.set_ylabel('MRI台数（人口10万人あたり）', fontsize=11)
ax.set_title('人口規模と人口当たりMRI密度（2020年）\n色が赤いほどMRI密度が高い',
             fontsize=11, fontweight='bold')
ax.legend(fontsize=10)
ax.set_xlim(0, d20['population'].max() / 1e4 * 1.05)

# 上位二次医療圏にラベル
top = d20.nlargest(5, 'mri_per_100k')
for _, row in top.iterrows():
    ax.annotate(row['sec_name'], (row['population']/1e4, row['mri_per_100k']),
                fontsize=7, alpha=0.8,
                xytext=(5, 2), textcoords='offset points')

fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, '2025_U4_excess.png'), bbox_inches='tight', dpi=150)
plt.close(fig)
print("Figure 4 保存: 2025_U4_excess.png")

# ── Figure 5: 年度変化（MRI台数の推移） ──────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# 共通の二次医療圏のみ追跡（2017〜2020で重複する sec_code）
common_secs = set(df[df['year']==2017]['sec_code']) & \
              set(df[df['year']==2018]['sec_code']) & \
              set(df[df['year']==2020]['sec_code'])
df_common = df[df['sec_code'].isin(common_secs)].copy()
pivot = df_common.pivot_table(index='sec_code', columns='year', values='mri_total')

ax = axes[0]
yr_totals = df.groupby('year')['mri_total'].sum()
ax.bar([str(y) for y in yr_totals.index], yr_totals.values,
       color=['#2196F3', '#4CAF50', '#FF9800'], alpha=0.85, edgecolor='white')
for i, (yr, v) in enumerate(yr_totals.items()):
    ax.text(i, v + 20, f'{v:,}台', ha='center', va='bottom', fontsize=11)
ax.set_ylabel('MRI台数合計', fontsize=11)
ax.set_title('全国MRI台数の推移（二次医療圏集計）', fontsize=11, fontweight='bold')

ax = axes[1]
change_17_20 = pivot[2020] - pivot[2017]
ax.hist(change_17_20.dropna(), bins=20, color='#7B1FA2', alpha=0.75, edgecolor='white')
ax.axvline(0, color='red', linewidth=1.5, linestyle='--')
ax.set_xlabel('MRI台数変化（2017→2020）', fontsize=11)
ax.set_ylabel('二次医療圏数', fontsize=11)
ax.set_title('MRI台数変化の分布（2017→2020）', fontsize=11, fontweight='bold')
mean_change = change_17_20.mean()
ax.axvline(mean_change, color='blue', linewidth=1.5, linestyle=':',
           label=f'平均変化 {mean_change:.1f}台')
ax.legend(fontsize=10)

fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, '2025_U4_trend.png'), bbox_inches='tight', dpi=150)
plt.close(fig)
print("Figure 5 保存: 2025_U4_trend.png")

# ── 記述統計サマリー ──────────────────────────────────────────────────
print("\n" + "=" * 55)
print("■ 記述統計（2020年）")
print("=" * 55)
desc_cols = ['mri_total', 'mri_3t', 'mri_15t', 'population', 'mri_per_100k']
print(d20[desc_cols].describe().round(2).to_string())

print(f"\n  [MRI密度] 全国平均: {avg_per100k:.2f}台/10万人")
print(f"  [最大]  {d20.loc[d20['mri_per_100k'].idxmax(), 'sec_name']}: "
      f"{d20['mri_per_100k'].max():.1f}台/10万人")
print(f"  [台数最大] {d20.loc[d20['mri_total'].idxmax(), 'sec_name']}: "
      f"{d20['mri_total'].max()}台")

print("\n■ 完了。全figureを html/figures/ に保存しました。")
