# evaluation.py — Run the shared sweep once; all plots aggregate from the cache (reproducible via seed)
import math
import os
import pickle
import time
import shutil
import json
import hashlib
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler

# Metrics / visualization helpers are factored out (UNIX-style separation of concerns)
from metrics.widths import (
    ci_radius_hoeffding,
    sum_weighted_widths_all_links,
    sum_weighted_min_widths_perpair,
    sum_widths_all_links,
    sum_minwidths_perpair,
)
from viz.plots import mean_ci95, plot_with_ci_band
from network import QuantumNetwork
from schedulers import run_scheduler  # scheduler dispatch
from utils.ids import to_idx0, normalize_to_1origin, is_keys_1origin
from utils.fidelity import (
    generate_fidelity_list_avg_gap,
    generate_fidelity_list_fix_gap,
    generate_fidelity_list_random,
    _generate_fidelity_list_random_rng,
)
import matplotlib as mpl

mpl.rcParams["figure.constrained_layout.use"] = True
mpl.rcParams["savefig.bbox"] = "tight"  # applied to every savefig call

# ---- Matplotlib style (compatibility first: hex colors & safe marker/linestyle choices) ----
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.serif"] = [
    "TeX Gyre Termes",
    "Nimbus Roman",
    "Liberation Serif",
    "DejaVu Serif",
]
mpl.rcParams["font.size"] = 20
default_cycler = (
    cycler(color=["#4daf4a", "#377eb8", "#e41a1c", "#984ea3", "#ff7f00", "#a65628"])
    + cycler(marker=["s", "v", "o", "x", "*", "+"])
    + cycler(linestyle=[":", "--", "-", "-.", "--", ":"])
)
plt.rc("axes", prop_cycle=default_cycler)
# =========================
# Progress helpers
# =========================
def _start_timer():
    return {"t0": time.time(), "last": time.time()}


def _tick(timer):
    now = time.time()
    dt_total = now - timer["t0"]
    dt_step = now - timer["last"]
    timer["last"] = now
    return dt_total, dt_step


def _log(msg):
    print(msg, flush=True)
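
# Usage sketch for the timer helpers (illustrative only; they are not called by the
# plotting functions below and are kept for ad-hoc profiling of long loops):
#
#   timer = _start_timer()
#   for step in range(3):
#       ...  # do work
#       dt_total, dt_step = _tick(timer)
#       _log(f"step {step}: +{dt_step:.2f}s (total {dt_total:.2f}s)")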
# =========================
# Shared sweep (cache) helpers with file lock
# =========================
def _sweep_signature(budget_list, scheduler_names, noise_model,
                     node_path_list, importance_list, bounces, repeat,
                     importance_mode="fixed", importance_uniform=(0.0, 1.0), seed=None):
    payload = {
        "budget_list": list(budget_list),
        "scheduler_names": list(scheduler_names),
        "noise_model": str(noise_model),
        "node_path_list": list(node_path_list),
        "importance_list": list(importance_list) if importance_list is not None else None,
        "importance_mode": str(importance_mode),
        "importance_uniform": list(importance_uniform) if importance_uniform is not None else None,
        "bounces": list(bounces),
        "repeat": int(repeat),
        "seed": int(seed) if seed is not None else None,
        "version": 5,  # schema: true_fid_by_path unified to 1-origin keys
    }
    sig = hashlib.md5(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:10]
    return payload, sig

def _shared_sweep_path(noise_model, sig):
    root_dir = os.path.dirname(os.path.abspath(__file__))
    outdir = os.path.join(root_dir, "outputs")
    os.makedirs(outdir, exist_ok=True)
    return os.path.join(outdir, f"shared_sweep_{noise_model}_{sig}.pickle")
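
# Example of how a sweep configuration maps to a cache file (illustrative; the
# argument values below are made up). Runs with an identical configuration hash to
# the same 10-character signature and therefore share one cache file, while any
# change to the payload (including "version") invalidates older caches:
#
#   cfg, sig = _sweep_signature(
#       budget_list=[100, 200], scheduler_names=["Vanilla NB"], noise_model="depolarizing",
#       node_path_list=[3, 3], importance_list=[1.0, 0.5], bounces=(1, 2, 3, 4),
#       repeat=10, seed=0,
#   )
#   print(_shared_sweep_path("depolarizing", sig))
#   # -> .../outputs/shared_sweep_depolarizing_<10 hex chars>.pickle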

def _run_or_load_shared_sweep(
    budget_list, scheduler_names, noise_model,
    node_path_list, importance_list,
    bounces=(1, 2, 3, 4), repeat=10,
    importance_mode="fixed", importance_uniform=(0.0, 1.0),
    seed=None,
    verbose=True, print_every=1,
):
    config, sig = _sweep_signature(
        budget_list, scheduler_names, noise_model,
        node_path_list, importance_list, bounces, repeat,
        importance_mode=importance_mode, importance_uniform=importance_uniform, seed=seed
    )
    cache_path = _shared_sweep_path(noise_model, sig)
    lock_path = cache_path + ".lock"
    STALE_LOCK_SECS = 6 * 60 * 60  # reclaim the lock after 6 hours without mtime updates
    HEARTBEAT_EVERY = 5.0          # heartbeat interval (seconds) for the producing process
    rng = np.random.default_rng(seed)  # random generator (the core of reproducibility)

    # If a cache already exists, load it immediately.
    if os.path.exists(cache_path):
        if verbose: _log(f"[shared] Load cached sweep: {os.path.basename(cache_path)}")
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    # --- Acquire the lock (only one process performs the initial generation) ---
    got_lock = False
    while True:
        try:
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            os.close(fd)
            got_lock = True
            break
        except FileExistsError:
            # Another process is generating the cache: wait for it to finish (no timeout).
            if os.path.exists(cache_path):
                with open(cache_path, "rb") as f:
                    return pickle.load(f)
            # Stale-lock detection: reclaim the lock if its mtime has not been updated for a long time.
            try:
                age = time.time() - os.path.getmtime(lock_path)
            except OSError:
                age = 0
            if age > STALE_LOCK_SECS:
                if verbose: _log("[shared] Stale lock detected. Removing...")
                try:
                    os.remove(lock_path)
                except FileNotFoundError:
                    pass
                continue
            # Still in progress: keep waiting.
            if verbose: _log("[shared] Waiting for cache to be ready...")
            time.sleep(1.0)
    try:
        if verbose: _log(f"[shared] Run sweep and cache to: {os.path.basename(cache_path)}")
        data = {name: {k: [] for k in range(len(budget_list))} for name in scheduler_names}
        last_hb = time.time()
        # === One repeat = one fixed topology, evaluated across every budget ===
        for r in range(repeat):
            if verbose and ((r + 1) % print_every == 0 or r == 0):
                _log(f"[shared] Repeat {r+1}/{repeat} (fixed topology)")
            # Fixed topology reused throughout this repeat (rng-based generation).
            fidelity_bank = [_generate_fidelity_list_random_rng(rng, n) for n in node_path_list]
            # Importance per repeat (fixed values or a uniform sample drawn with rng).
            if str(importance_mode).lower() == "uniform":
                a, b = map(float, importance_uniform)
                imp_list_r = [float(rng.uniform(a, b)) for _ in node_path_list]
            else:
                imp_list_r = list(importance_list)

            def network_generator(path_num, pair_idx):
                return QuantumNetwork(path_num, fidelity_bank[pair_idx], noise_model)

            # Keep the same topology and vary only the budget.
            for k, C_total in enumerate(budget_list):
                if verbose:
                    _log(f"=== [SHARED {noise_model}] Budget={C_total} ({k+1}/{len(budget_list)}) ===")
                # Heartbeat (refresh the lock's mtime).
                now = time.time()
                if now - last_hb >= HEARTBEAT_EVERY:
                    try:
                        os.utime(lock_path, None)
                    except FileNotFoundError:
                        pass
                    last_hb = now
                for name in scheduler_names:
                    per_pair_results, total_cost, per_pair_details = run_scheduler(
                        node_path_list=node_path_list, importance_list=imp_list_r,
                        scheduler_name=name,
                        bounces=list(bounces),
                        C_total=int(C_total),
                        network_generator=network_generator,
                        return_details=True,
                    )
                    # --- Inject the true fidelities (true_fid_by_path) into per_pair_details ---
                    # Keys follow the est_fid_by_path convention (normalized to integers 1..L);
                    # if no estimates exist, 1..L is used directly.
                    for d, det in enumerate(per_pair_details):
                        true_list = fidelity_bank[d]              # 0-origin list of true fidelities
                        est_map = det.get("est_fid_by_path", {})  # expected to be keyed by {1..L}
                        L = len(true_list)
                        # 1) Normalize the estimate dict to 1-origin keys (also absorbs 0-origin input).
                        if est_map:
                            est_map_norm = normalize_to_1origin(
                                {int(k): float(v) for k, v in est_map.items()}, L
                            )
                        else:
                            est_map_norm = {}  # nothing measured: keep empty (treated as a 0 contribution downstream)
                        # 2) Build the true-value dict with 1-origin keys (true_list is 0-origin, hence to_idx0).
                        true_map = {pid: float(true_list[to_idx0(pid)]) for pid in range(1, L + 1)}
                        # 3) Strict check (optional, but useful for catching bugs early).
                        if est_map_norm and not is_keys_1origin(est_map_norm.keys(), L):
                            raise RuntimeError(f"[inject] est_fid_by_path keys not 1..{L} (pair={d})")
                        det["est_fid_by_path"] = est_map_norm
                        det["true_fid_by_path"] = true_map
                    data[name][k].append({
                        "per_pair_results": per_pair_results,
                        "per_pair_details": per_pair_details,
                        "total_cost": total_cost,
                        "importance_list": imp_list_r
                    })
        payload = {"config": config, "budget_list": list(budget_list), "data": data}
        # Atomic write.
        tmp = cache_path + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp, cache_path)
        return payload
    finally:
        if got_lock:
            try:
                os.remove(lock_path)
            except FileNotFoundError:
                pass
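
# Shape of the cached payload built above, for reference when writing new aggregation
# code (the inner dicts come from run_scheduler plus the injection step; access path
# is scheduler name -> budget index -> repeat index):
#
#   payload = _run_or_load_shared_sweep(budget_list, scheduler_names, noise_model,
#                                       node_path_list, importance_list, seed=0)
#   run = payload["data"][scheduler_names[0]][0][0]
#   run["per_pair_results"]   # per-destination correctness flags (bool/number or (bool, ...) tuples)
#   run["per_pair_details"]   # per-destination dicts with est_fid_by_path / true_fid_by_path (keys 1..L)
#   run["total_cost"]         # measurement cost actually spent
#   run["importance_list"]    # importance weights used in this repeat
#
# This mirrors how the plotting functions below consume the cache.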

# =========================
# 1) Accuracy: mean ± 95% CI
# =========================
def plot_accuracy_vs_budget(
    budget_list, scheduler_names, noise_model,
    node_path_list, importance_list,
    bounces=(1, 2, 3, 4), repeat=10,
    importance_mode="fixed", importance_uniform=(0.0, 1.0), seed=None,
    verbose=True, print_every=1,
):
    file_name = f"plot_accuracy_vs_budget_{noise_model}"
    root_dir = os.path.dirname(os.path.abspath(__file__))
    outdir = os.path.join(root_dir, "outputs")
    os.makedirs(outdir, exist_ok=True)
    payload = _run_or_load_shared_sweep(
        budget_list, scheduler_names, noise_model,
        node_path_list, importance_list,
        bounces=bounces, repeat=repeat,
        importance_mode=importance_mode, importance_uniform=importance_uniform, seed=seed,
        verbose=verbose, print_every=print_every,
    )
    # Collect: accumulate per-run accuracy (fraction of 0/1 correctness) for each budget.
    results = {name: {"accs": [[] for _ in budget_list]} for name in scheduler_names}
    for name in scheduler_names:
        for k in range(len(budget_list)):
            for run in payload["data"][name][k]:
                per_pair_results = run["per_pair_results"]
                # Normalize each element of per_pair_results to a bool, then map to 0/1.
                vals = []
                for r in per_pair_results:
                    if isinstance(r, tuple):
                        c = r[0]
                    elif isinstance(r, (int, float, bool)):
                        c = bool(r)
                    else:
                        raise TypeError(
                            f"per_pair_results element has unexpected type: {type(r)} -> {r}"
                        )
                    vals.append(1.0 if c else 0.0)
                acc = float(np.mean(vals)) if vals else 0.0
                results[name]["accs"][k].append(acc)
    # Plot (mean ± 95% CI).
    plt.rc("axes", prop_cycle=default_cycler)
    fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)
    xs = list(budget_list)
    for name, data in results.items():
        means, halfs = [], []
        for vals in data["accs"]:
            m, h = mean_ci95(vals)  # uses viz.plots.mean_ci95
            means.append(m); halfs.append(h)
        means = np.asarray(means); halfs = np.asarray(halfs)
        label = name.replace("Vanilla NB", "VanillaNB").replace("Succ. Elim. NB", "SuccElimNB")
        ax.plot(xs, means, linewidth=2.0, label=label)
        ax.fill_between(xs, means - halfs, means + halfs, alpha=0.25)
    ax.set_xlabel("Total Budget (C)")
    ax.set_ylabel("Average Correctness (mean ± 95% CI)")
    ax.grid(True); ax.legend(title="Scheduler", fontsize=14, title_fontsize=18)
    pdf = os.path.join(outdir, f"{file_name}.pdf")
    plt.savefig(pdf)
    if shutil.which("pdfcrop"):
        os.system(f'pdfcrop --margins "8 8 8 8" {pdf} {pdf}')
    _log(f"Saved: {pdf}")
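
# For reference: mean_ci95 is imported from viz.plots and is assumed here to return
# (mean, half_width) of a 95% confidence interval. The stand-in below is a minimal
# sketch with the same contract under a normal approximation; it is illustrative only,
# is not used by the plotting functions, and the real implementation in viz/plots.py
# may differ (e.g., a t-based interval).
def _mean_ci95_sketch(vals):
    # Empty input: no mean, no interval.
    arr = np.asarray(vals, dtype=float)
    if arr.size == 0:
        return 0.0, 0.0
    m = float(arr.mean())
    # A single sample has an undefined sample standard deviation; report zero width.
    if arr.size == 1:
        return m, 0.0
    # Normal-approximation half-width: 1.96 * standard error of the mean.
    sem = float(arr.std(ddof=1) / np.sqrt(arr.size))
    return m, 1.96 * sem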

# =========================
# 2) Value vs Used (x = mean actual cost, y = mean ± 95% CI of Σ_d I_d * true_fid(j*_d))
#    Note: j*_d is the link with the highest *estimated* fidelity at destination d (path_id in 1..L).
# =========================
def plot_value_vs_used(
    budget_list, scheduler_names, noise_model,
    node_path_list, importance_list,
    bounces=(1, 2, 3, 4), repeat=10, importance_mode="fixed", importance_uniform=(0.0, 1.0), seed=None,
    verbose=True, print_every=1,
):
    file_name = f"plot_value_vs_used_{noise_model}"
    root_dir = os.path.dirname(os.path.abspath(__file__))
    outdir = os.path.join(root_dir, "outputs")
    os.makedirs(outdir, exist_ok=True)
    payload = _run_or_load_shared_sweep(
        budget_list, scheduler_names, noise_model,
        node_path_list, importance_list,
        bounces=bounces, repeat=repeat,
        importance_mode=importance_mode, importance_uniform=importance_uniform, seed=seed,
        verbose=verbose, print_every=print_every,
    )
    results = {name: {"values": [[] for _ in budget_list], "costs": [[] for _ in budget_list]} for name in scheduler_names}
    for name in scheduler_names:
        for k in range(len(budget_list)):
            for run in payload["data"][name][k]:
                per_pair_details = run["per_pair_details"]
                total_cost = int(run["total_cost"])
                # y: value = Σ_d I_d * true_fid(j*_d)
                #    where j*_d = argmax_l est_fid_by_path[d][l]
                value = 0.0
                I_used = run.get("importance_list", importance_list)
                for d, det in enumerate(per_pair_details):
                    est = det.get("est_fid_by_path", {})     # {path_id (1..L): estimated fidelity}
                    true_ = det.get("true_fid_by_path", {})  # {path_id (1..L): true fidelity}
                    # 1) A missing true-value dict indicates a configuration mismatch: fail loudly.
                    if not true_:
                        raise RuntimeError(f"[value] true_fid_by_path missing for pair {d}")
                    # 2) If at least one link has an estimate, pick the currently estimated best j*
                    #    and use its *true* fidelity.
                    if est:
                        j_star = max(est, key=lambda l: float(est.get(l, 0.0)))
                        if j_star not in true_:
                            raise RuntimeError(
                                f"[value] true_fid_by_path lacks j* (pair={d}, j*={j_star})."
                            )
                        best_true = float(true_[j_star])
                    else:
                        # No estimates at all: contributes 0 (as before).
                        best_true = 0.0
                    I = float(I_used[d]) if d < len(I_used) else 1.0
                    value += I * best_true
                results[name]["values"][k].append(float(value))
                results[name]["costs"][k].append(total_cost)
    # Plot (with a 95% CI band on y).
    plt.rc("axes", prop_cycle=default_cycler)
    fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)
    for name, dat in results.items():
        # x: mean used cost at each budget
        x_means = [float(np.mean(v)) if v else 0.0 for v in dat["costs"]]
        # y: mean ± 95% CI of the value defined above at each budget
        y_means, y_halfs = [], []
        for vals in dat["values"]:
            m, h = mean_ci95(vals)  # viz.plots.mean_ci95
            y_means.append(float(m))
            y_halfs.append(float(h))
        x_means = np.asarray(x_means)
        y_means = np.asarray(y_means)
        y_halfs = np.asarray(y_halfs)
        label = name.replace("Vanilla NB", "VanillaNB").replace("Succ. Elim. NB", "SuccElimNB")
        ax.plot(x_means, y_means, linewidth=2.0, marker="o", label=label)
        ax.fill_between(x_means, y_means - y_halfs, y_means + y_halfs, alpha=0.25)
    ax.set_xlabel("Total Measured Cost (used)")
    ax.set_ylabel("Σ_d I_d · true_fid(j*_d) (mean ± 95% CI)")
    ax.grid(True); ax.legend(title="Scheduler")
    pdf = os.path.join(outdir, f"{file_name}.pdf")
    plt.savefig(pdf)
    if shutil.which("pdfcrop"):
        os.system(f'pdfcrop --margins "8 8 8 8" {pdf} {pdf}')
    _log(f"Saved: {pdf}")
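
# Worked toy example of the value metric (numbers are made up): two destinations with
# importances I = [1.0, 0.5], estimates and true fidelities keyed by path_id 1..L:
#
#   est_d0 = {1: 0.80, 2: 0.90}; true_d0 = {1: 0.85, 2: 0.88}   # j*_0 = 2 -> true 0.88
#   est_d1 = {1: 0.70};          true_d1 = {1: 0.75, 2: 0.95}   # j*_1 = 1 -> true 0.75
#   value  = 1.0 * 0.88 + 0.5 * 0.75  # = 1.255
#
# Destination 1 is effectively penalized for never measuring path 2, because j* is
# chosen from the estimates while the payoff uses the true fidelity.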

def plot_value_vs_budget(
    budget_list, scheduler_names, noise_model,
    node_path_list, importance_list,
    bounces=(1, 2, 3, 4), repeat=10, importance_mode="fixed", importance_uniform=(0.0, 1.0), seed=None,
    verbose=True, print_every=1,
):
    """
    Plot the mean ± 95% CI of Σ_d I_d * true_fid(j*_d) (y) against the allocated budget (x = budget_list).
    j*_d is the link with the highest estimate at that point; as long as at least one link has an
    estimate, that j* is used even if not every link was measured.
    Output: outputs/plot_value_vs_budget_{noise_model}.pdf
    """
    file_name = f"plot_value_vs_budget_{noise_model}"
    root_dir = os.path.dirname(os.path.abspath(__file__))
    outdir = os.path.join(root_dir, "outputs")
    os.makedirs(outdir, exist_ok=True)
    # Run or load the shared sweep (cache).
    payload = _run_or_load_shared_sweep(
        budget_list, scheduler_names, noise_model,
        node_path_list, importance_list,
        bounces=bounces, repeat=repeat,
        importance_mode=importance_mode, importance_uniform=importance_uniform, seed=seed,
        verbose=verbose, print_every=print_every,
    )
    # Accumulate value and (for reference) used cost per scheduler and per budget.
    results = {name: {"values": [[] for _ in budget_list], "costs": [[] for _ in budget_list]} for name in scheduler_names}
    for name in scheduler_names:
        for k in range(len(budget_list)):
            for run in payload["data"][name][k]:
                per_pair_details = run["per_pair_details"]
                total_cost = int(run["total_cost"])  # for reference only (not used on the x-axis here)
                I_used = run.get("importance_list", importance_list)
                # y: value = Σ_d I_d * true_fid(j*_d)
                #    j*_d = argmax_l est_fid_by_path[d][l] (the current j* is used if any estimate exists)
                value = 0.0
                for d, det in enumerate(per_pair_details):
                    est = det.get("est_fid_by_path", {})     # {path_id (1..L): est}
                    true_ = det.get("true_fid_by_path", {})  # {path_id (1..L): true}
                    # A missing true-value dict indicates a configuration mismatch.
                    if not true_:
                        raise RuntimeError(f"[value] true_fid_by_path missing for pair {d}")
                    # If at least one estimate exists, use the true fidelity of the current j*.
                    if est:
                        j_star = max(est, key=lambda l: float(est.get(l, 0.0)))
                        if j_star not in true_:
                            raise RuntimeError(f"[value] true_fid_by_path lacks j* (pair={d}, j*={j_star}).")
                        best_true = float(true_[j_star])
                    else:
                        # No estimates at all: contributes 0.
                        best_true = 0.0
                    I = float(I_used[d]) if d < len(I_used) else 1.0
                    value += I * best_true
                results[name]["values"][k].append(float(value))
                results[name]["costs"][k].append(total_cost)  # kept, but not plotted
    # === Plot (x: allocated budget = budget_list, y: mean ± 95% CI of value) ===
    plt.rc("axes", prop_cycle=default_cycler)
    fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)
    x_vals = np.asarray(list(budget_list), dtype=float)  # x-axis: allocated budget
    for name, dat in results.items():
        # y: mean ± 95% CI of the value at each budget
        y_means, y_halfs = [], []
        for vals in dat["values"]:
            m, h = mean_ci95(vals)
            y_means.append(float(m))
            y_halfs.append(float(h))
        y_means = np.asarray(y_means)
        y_halfs = np.asarray(y_halfs)
        label = name.replace("Vanilla NB", "VanillaNB").replace("Succ. Elim. NB", "SuccElimNB")
        ax.plot(x_vals, y_means, linewidth=2.0, marker="o", label=label)
        ax.fill_between(x_vals, y_means - y_halfs, y_means + y_halfs, alpha=0.25)
    ax.set_xlabel("Total Budget (C)")
    ax.set_ylabel("Σ_d I_d · true_fid(j*_d) (mean ± 95% CI)")
    ax.grid(True); ax.legend(title="Scheduler")
    pdf = os.path.join(outdir, f"{file_name}.pdf")
    plt.savefig(pdf)
    if shutil.which("pdfcrop"):
        os.system(f'pdfcrop --margins "8 8 8 8" {pdf} {pdf}')
    _log(f"Saved: {pdf}")
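
# Minimal driver sketch (illustrative only; the noise model, the "Uniform" scheduler
# name, and the sweep sizes below are placeholders, not the project's configuration).
# All three plots reuse a single cached sweep because they share the same signature:
#
#   if __name__ == "__main__":
#       common = dict(
#           budget_list=[100, 200, 400, 800],
#           scheduler_names=["Uniform", "Vanilla NB", "Succ. Elim. NB"],
#           noise_model="depolarizing",
#           node_path_list=[3, 3, 3], importance_list=[1.0, 1.0, 1.0],
#           bounces=(1, 2, 3, 4), repeat=10, seed=0,
#       )
#       plot_accuracy_vs_budget(**common)
#       plot_value_vs_used(**common)
#       plot_value_vs_budget(**common)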