evaluation.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575
  1. # evaluation.py — Run shared sweep once; all plots aggregate from cache (Py3.8-safe)
  2. import math
  3. import os
  4. import pickle
  5. import time
  6. import shutil
  7. import json
  8. import hashlib
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. from cycler import cycler
# metrics / viz factored out into separate modules (UNIX-style separation of concerns)
  13. from metrics.widths import (
  14. ci_radius_hoeffding,
  15. sum_weighted_widths_all_links,
  16. sum_weighted_min_widths_perpair,
  17. sum_widths_all_links,
  18. sum_minwidths_perpair,
  19. )
  20. from viz.plots import mean_ci95, plot_with_ci_band
  21. from network import QuantumNetwork
  22. from schedulers import run_scheduler # スケジューラ呼び出し
# ---- Matplotlib style (compatibility-minded: hex colors & safe marker symbols) ----
plt.rc("font", family="Times New Roman")
plt.rc("font", size=20)
# One combined property cycler: color, marker and linestyle advance together,
# so each plotted line gets a distinct (color, marker, linestyle) triple.
default_cycler = (
    cycler(color=["#4daf4a", "#377eb8", "#e41a1c", "#984ea3", "#ff7f00", "#a65628"])
    + cycler(marker=["s", "v", "o", "x", "*", "+"])
    + cycler(linestyle=[":", "--", "-", "-.", "--", ":"])
)
# Install as the default axes cycle; the plot functions below also re-apply it
# explicitly before each figure.
plt.rc("axes", prop_cycle=default_cycler)
  32. # =========================
  33. # Fidelity generators
  34. # =========================
  35. def generate_fidelity_list_avg_gap(path_num):
  36. result = []
  37. fidelity_max = 1
  38. fidelity_min = 0.9
  39. gap = (fidelity_max - fidelity_min) / path_num
  40. fidelity = fidelity_max
  41. for _ in range(path_num):
  42. result.append(fidelity)
  43. fidelity -= gap
  44. assert len(result) == path_num
  45. return result
  46. def generate_fidelity_list_fix_gap(path_num, gap, fidelity_max=1):
  47. result = []
  48. fidelity = fidelity_max
  49. for _ in range(path_num):
  50. result.append(fidelity)
  51. fidelity -= gap
  52. assert len(result) == path_num
  53. return result
  54. def generate_fidelity_list_random(path_num, alpha=0.95, beta=0.85, variance=0.1):
  55. """Generate `path_num` links with a guaranteed top-1 gap."""
  56. while True:
  57. mean = [alpha] + [beta] * (path_num - 1)
  58. result = []
  59. for i in range(path_num):
  60. mu = mean[i]
  61. # [0.8, 1.0] の範囲に入るまでサンプリング
  62. while True:
  63. r = np.random.normal(mu, variance)
  64. if 0.8 <= r <= 1.0:
  65. break
  66. result.append(r)
  67. assert len(result) == path_num
  68. sorted_res = sorted(result, reverse=True)
  69. if sorted_res[0] - sorted_res[1] > 0.02:
  70. return result
  71. # =========================
  72. # Progress helpers
  73. # =========================
  74. def _start_timer():
  75. return {"t0": time.time(), "last": time.time()}
  76. def _tick(timer):
  77. now = time.time()
  78. dt_total = now - timer["t0"]
  79. dt_step = now - timer["last"]
  80. timer["last"] = now
  81. return dt_total, dt_step
  82. def _log(msg):
  83. print(msg, flush=True)
  84. # =========================
  85. # Shared sweep (cache) helpers with file lock
  86. # =========================
  87. def _sweep_signature(budget_list, scheduler_names, noise_model,
  88. node_path_list, importance_list, bounces, repeat):
  89. payload = {
  90. "budget_list": list(budget_list),
  91. "scheduler_names": list(scheduler_names),
  92. "noise_model": str(noise_model),
  93. "node_path_list": list(node_path_list),
  94. "importance_list": list(importance_list),
  95. "bounces": list(bounces),
  96. "repeat": int(repeat),
  97. "version": 1,
  98. }
  99. sig = hashlib.md5(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:10]
  100. return payload, sig
  101. def _shared_sweep_path(noise_model, sig):
  102. root_dir = os.path.dirname(os.path.abspath(__file__))
  103. outdir = os.path.join(root_dir, "outputs")
  104. os.makedirs(outdir, exist_ok=True)
  105. return os.path.join(outdir, f"shared_sweep_{noise_model}_{sig}.pickle")
def _run_or_load_shared_sweep(
    budget_list, scheduler_names, noise_model,
    node_path_list, importance_list,
    bounces=(1,2,3,4), repeat=10,
    verbose=True, print_every=1,
):
    """Run the (budget x repeat x scheduler) sweep once and cache the result.

    Safe for concurrent callers with identical parameters: the first process
    to atomically create the lock file performs the sweep and writes the
    pickle; every other process waits for the cache to appear and loads it.
    A heartbeat refreshes the lock's mtime so waiters can reclaim a lock
    whose owner died (no mtime update for STALE_LOCK_SECS).

    Returns the payload dict {"config", "budget_list", "data"}, where
    data[name][k] is the list of per-repeat run records for budget_list[k].
    """
    config, sig = _sweep_signature(budget_list, scheduler_names, noise_model,
                                   node_path_list, importance_list, bounces, repeat)
    cache_path = _shared_sweep_path(noise_model, sig)
    lock_path = cache_path + ".lock"
    STALE_LOCK_SECS = 6 * 60 * 60  # reclaim the lock after 6h without mtime updates
    HEARTBEAT_EVERY = 5.0  # seconds between lock-mtime refreshes by the producer
    # Fast path: cache already exists -> load and return immediately.
    if os.path.exists(cache_path):
        if verbose: _log(f"[shared] Load cached sweep: {os.path.basename(cache_path)}")
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    # --- Acquire the lock (only one process performs the initial sweep) ---
    got_lock = False
    while True:
        try:
            # O_CREAT | O_EXCL makes lock-file creation atomic across processes.
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            os.close(fd)
            got_lock = True
            break
        except FileExistsError:
            # Another process is producing the cache: wait for completion (no timeout).
            if os.path.exists(cache_path):
                with open(cache_path, "rb") as f:
                    return pickle.load(f)
            # Stale-lock detection: reclaim if the mtime has not moved for too long.
            try:
                age = time.time() - os.path.getmtime(lock_path)
            except OSError:
                age = 0
            if age > STALE_LOCK_SECS:
                if verbose: _log("[shared] Stale lock detected. Removing...")
                try: os.remove(lock_path)
                except FileNotFoundError: pass
                continue
            # Producer still alive: poll again shortly.
            if verbose: _log("[shared] Waiting for cache to be ready...")
            time.sleep(1.0)
    try:
        if verbose: _log(f"[shared] Run sweep and cache to: {os.path.basename(cache_path)}")
        data = {name: {k: [] for k in range(len(budget_list))} for name in scheduler_names}
        last_hb = time.time()
        for k, C_total in enumerate(budget_list):
            if verbose: _log(f"=== [SHARED {noise_model}] Budget={C_total} ({k+1}/{len(budget_list)}) ===")
            for r in range(repeat):
                if verbose and ((r + 1) % print_every == 0 or r == 0):
                    _log(f" [repeat {r+1}/{repeat}]")
                # Heartbeat: touch the lock's mtime to signal "still alive" to waiters.
                now = time.time()
                if now - last_hb >= HEARTBEAT_EVERY:
                    try: os.utime(lock_path, None)
                    except FileNotFoundError: pass
                    last_hb = now
                # One repeat = one sampled topology, shared by all schedulers
                # so they are compared on identical link fidelities.
                fidelity_bank = [generate_fidelity_list_random(n) for n in node_path_list]
                def network_generator(path_num, pair_idx):
                    return QuantumNetwork(path_num, fidelity_bank[pair_idx], noise_model)
                for name in scheduler_names:
                    per_pair_results, total_cost, per_pair_details = run_scheduler(
                        node_path_list=node_path_list,
                        importance_list=importance_list,
                        scheduler_name=name,
                        bounces=list(bounces),
                        C_total=int(C_total),
                        network_generator=network_generator,
                        return_details=True,
                    )
                    data[name][k].append({
                        "per_pair_results": per_pair_results,
                        "per_pair_details": per_pair_details,
                        "total_cost": total_cost,
                    })
        payload = {"config": config, "budget_list": list(budget_list), "data": data}
        # Atomic publish: write to a temp file, then rename over the target.
        tmp = cache_path + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp, cache_path)
        return payload
    finally:
        # Always release the lock we own, even if the sweep raised.
        if got_lock:
            try: os.remove(lock_path)
            except FileNotFoundError: pass
  194. # =========================
# 1) Accuracy: mean only (no CI band)
  196. # =========================
  197. def plot_accuracy_vs_budget(
  198. budget_list, scheduler_names, noise_model,
  199. node_path_list, importance_list,
  200. bounces=(1,2,3,4), repeat=10,
  201. verbose=True, print_every=1,
  202. ):
  203. file_name = f"plot_accuracy_vs_budget_{noise_model}"
  204. root_dir = os.path.dirname(os.path.abspath(__file__))
  205. outdir = os.path.join(root_dir, "outputs")
  206. os.makedirs(outdir, exist_ok=True)
  207. payload = _run_or_load_shared_sweep(
  208. budget_list, scheduler_names, noise_model,
  209. node_path_list, importance_list,
  210. bounces=bounces, repeat=repeat,
  211. verbose=verbose, print_every=print_every,
  212. )
  213. results = {name: {"accs": [[] for _ in budget_list]} for name in scheduler_names}
  214. for name in scheduler_names:
  215. for k in range(len(budget_list)):
  216. for run in payload["data"][name][k]:
  217. per_pair_results = run["per_pair_results"]
  218. acc = float(np.mean([1.0 if c else 0.0 for (c, _cost, _bf) in per_pair_results])) if per_pair_results else 0.0
  219. results[name]["accs"][k].append(acc)
  220. # plot
  221. plt.rc("axes", prop_cycle=default_cycler)
  222. fig, ax = plt.subplots()
  223. xs = list(budget_list)
  224. for name, data in results.items():
  225. avg_accs = [float(np.mean(v)) if v else 0.0 for v in data["accs"]]
  226. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  227. ax.plot(xs, avg_accs, linewidth=2.0, label=label)
  228. ax.set_xlabel("Total Budget (C)")
  229. ax.set_ylabel("Average Correctness")
  230. ax.grid(True); ax.legend(title="Scheduler", fontsize=14, title_fontsize=18)
  231. plt.tight_layout()
  232. pdf = f"{file_name}.pdf"
  233. plt.savefig(pdf);
  234. if shutil.which("pdfcrop"): os.system(f"pdfcrop {pdf} {pdf}")
  235. _log(f"Saved: {pdf}")
  236. # =========================
# 2) Value vs Used (x = mean measured cost)
  238. # =========================
  239. def plot_value_vs_used(
  240. budget_list, scheduler_names, noise_model,
  241. node_path_list, importance_list,
  242. bounces=(1,2,3,4), repeat=10,
  243. verbose=True, print_every=1,
  244. ):
  245. file_name = f"plot_value_vs_used_{noise_model}"
  246. root_dir = os.path.dirname(os.path.abspath(__file__))
  247. outdir = os.path.join(root_dir, "outputs")
  248. os.makedirs(outdir, exist_ok=True)
  249. payload = _run_or_load_shared_sweep(
  250. budget_list, scheduler_names, noise_model,
  251. node_path_list, importance_list,
  252. bounces=bounces, repeat=repeat,
  253. verbose=verbose, print_every=print_every,
  254. )
  255. results = {name: {"values": [[] for _ in budget_list], "costs": [[] for _ in budget_list]} for name in scheduler_names}
  256. for name in scheduler_names:
  257. for k in range(len(budget_list)):
  258. for run in payload["data"][name][k]:
  259. per_pair_details = run["per_pair_details"]
  260. total_cost = int(run["total_cost"])
  261. # value = Σ_d I_d Σ_l est(d,l) * alloc(d,l)
  262. value = 0.0
  263. for d, det in enumerate(per_pair_details):
  264. alloc = det.get("alloc_by_path", {})
  265. est = det.get("est_fid_by_path", {})
  266. inner = sum(float(est.get(l, 0.0)) * int(b) for l, b in alloc.items())
  267. I = float(importance_list[d]) if d < len(importance_list) else 1.0
  268. value += I * inner
  269. results[name]["values"][k].append(float(value))
  270. results[name]["costs"][k].append(total_cost)
  271. # plot
  272. plt.rc("axes", prop_cycle=default_cycler)
  273. fig, ax = plt.subplots()
  274. for name, dat in results.items():
  275. xs = [float(np.mean(v)) if v else 0.0 for v in dat["costs"]]
  276. ys = [float(np.mean(v)) if v else 0.0 for v in dat["values"]]
  277. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  278. ax.plot(xs, ys, linewidth=2.0, marker="o", label=label)
  279. ax.set_xlabel("Total Measured Cost (used)")
  280. ax.set_ylabel("Total Value (Σ I_d Σ f̂_{d,l}·B_{d,l})")
  281. ax.grid(True); ax.legend(title="Scheduler")
  282. plt.tight_layout()
  283. pdf = f"{file_name}.pdf"
  284. plt.savefig(pdf);
  285. if shutil.which("pdfcrop"): os.system(f"pdfcrop {pdf} {pdf}")
  286. _log(f"Saved: {pdf}")
  287. # =========================
# 3) Value vs Budget target (x = target budget)
  289. # =========================
  290. def plot_value_vs_budget_target(
  291. budget_list, scheduler_names, noise_model,
  292. node_path_list, importance_list,
  293. bounces=(1,2,3,4), repeat=10,
  294. verbose=True, print_every=1,
  295. ):
  296. file_name = f"plot_value_vs_budget_target_{noise_model}"
  297. root_dir = os.path.dirname(os.path.abspath(__file__))
  298. outdir = os.path.join(root_dir, "outputs")
  299. os.makedirs(outdir, exist_ok=True)
  300. payload = _run_or_load_shared_sweep(
  301. budget_list, scheduler_names, noise_model,
  302. node_path_list, importance_list,
  303. bounces=bounces, repeat=repeat,
  304. verbose=verbose, print_every=print_every,
  305. )
  306. results = {name: {"values": [[] for _ in budget_list]} for name in scheduler_names}
  307. for name in scheduler_names:
  308. for k in range(len(budget_list)):
  309. for run in payload["data"][name][k]:
  310. per_pair_details = run["per_pair_details"]
  311. value = 0.0
  312. for d, det in enumerate(per_pair_details):
  313. alloc = det.get("alloc_by_path", {})
  314. est = det.get("est_fid_by_path", {})
  315. inner = sum(float(est.get(l, 0.0)) * int(b) for l, b in alloc.items())
  316. I = float(importance_list[d]) if d < len(importance_list) else 1.0
  317. value += I * inner
  318. results[name]["values"][k].append(float(value))
  319. # plot
  320. plt.rc("axes", prop_cycle=default_cycler)
  321. fig, ax = plt.subplots()
  322. xs = list(budget_list)
  323. for name, dat in results.items():
  324. ys = [float(np.mean(v)) if v else 0.0 for v in dat["values"]]
  325. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  326. ax.plot(xs, ys, linewidth=2.0, marker="o", label=label)
  327. ax.set_xlabel("Budget (target)")
  328. ax.set_ylabel("Total Value (Σ I_d Σ f̂_{d,l}·B_{d,l})")
  329. ax.grid(True); ax.legend(title="Scheduler")
  330. plt.tight_layout()
  331. pdf = f"{file_name}.pdf"
  332. plt.savefig(pdf);
  333. if shutil.which("pdfcrop"): os.system(f"pdfcrop {pdf} {pdf}")
  334. _log(f"Saved: {pdf}")
  335. # =========================
# 4) Width (UB - LB), unweighted: sum over all links
  337. # =========================
  338. def plot_widthsum_alllinks_vs_budget(
  339. budget_list, scheduler_names, noise_model,
  340. node_path_list, importance_list,
  341. bounces=(1,2,3,4), repeat=10, delta=0.1,
  342. verbose=True, print_every=1,
  343. ):
  344. file_name = f"plot_widthsum_alllinks_vs_budget_{noise_model}"
  345. root_dir = os.path.dirname(os.path.abspath(__file__))
  346. outdir = os.path.join(root_dir, "outputs")
  347. os.makedirs(outdir, exist_ok=True)
  348. payload = _run_or_load_shared_sweep(
  349. budget_list, scheduler_names, noise_model,
  350. node_path_list, importance_list,
  351. bounces=bounces, repeat=repeat,
  352. verbose=verbose, print_every=print_every,
  353. )
  354. results = {name: {"sums": [[] for _ in budget_list]} for name in scheduler_names}
  355. for name in scheduler_names:
  356. for k in range(len(budget_list)):
  357. for run in payload["data"][name][k]:
  358. per_pair_details = run["per_pair_details"]
  359. v = sum_widths_all_links(per_pair_details, delta=delta)
  360. results[name]["sums"][k].append(v)
  361. # plot (mean ± 95%CI)
  362. plt.rc("axes", prop_cycle=default_cycler)
  363. fig, ax = plt.subplots()
  364. xs = list(budget_list)
  365. for name, dat in results.items():
  366. means, halfs = [], []
  367. for vals in dat["sums"]:
  368. m, h = mean_ci95(vals); means.append(m); halfs.append(h)
  369. means = np.asarray(means); halfs = np.asarray(halfs)
  370. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  371. ax.plot(xs, means, linewidth=2.0, marker="o", label=label)
  372. ax.fill_between(xs, means - halfs, means + halfs, alpha=0.25)
  373. ax.set_xlabel("Budget (target)")
  374. ax.set_ylabel("Sum of (UB - LB) over all links")
  375. ax.grid(True); ax.legend(title="Scheduler")
  376. plt.tight_layout()
  377. pdf = f"{file_name}.pdf"
  378. plt.savefig(pdf);
  379. if shutil.which("pdfcrop"): os.system(f"pdfcrop {pdf} {pdf}")
  380. _log(f"Saved: {pdf}")
  381. # =========================
# 5) Width (UB - LB), unweighted: sum of per-pair minimum widths
  383. # =========================
  384. def plot_minwidthsum_perpair_vs_budget(
  385. budget_list, scheduler_names, noise_model,
  386. node_path_list, importance_list,
  387. bounces=(1,2,3,4), repeat=10, delta=0.1,
  388. verbose=True, print_every=1,
  389. ):
  390. file_name = f"plot_minwidthsum_perpair_vs_budget_{noise_model}"
  391. root_dir = os.path.dirname(os.path.abspath(__file__))
  392. outdir = os.path.join(root_dir, "outputs")
  393. os.makedirs(outdir, exist_ok=True)
  394. payload = _run_or_load_shared_sweep(
  395. budget_list, scheduler_names, noise_model,
  396. node_path_list, importance_list,
  397. bounces=bounces, repeat=repeat,
  398. verbose=verbose, print_every=print_every,
  399. )
  400. results = {name: {"sums": [[] for _ in budget_list]} for name in scheduler_names}
  401. for name in scheduler_names:
  402. for k in range(len(budget_list)):
  403. for run in payload["data"][name][k]:
  404. per_pair_details = run["per_pair_details"]
  405. v = sum_minwidths_perpair(per_pair_details, delta=delta)
  406. results[name]["sums"][k].append(v)
  407. # plot (mean ± 95%CI)
  408. plt.rc("axes", prop_cycle=default_cycler)
  409. fig, ax = plt.subplots()
  410. xs = list(budget_list)
  411. for name, dat in results.items():
  412. means, halfs = [], []
  413. for vals in dat["sums"]:
  414. m, h = mean_ci95(vals); means.append(m); halfs.append(h)
  415. means = np.asarray(means); halfs = np.asarray(halfs)
  416. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  417. ax.plot(xs, means, linewidth=2.0, marker="o", label=label)
  418. ax.fill_between(xs, means - halfs, means + halfs, alpha=0.25)
  419. ax.set_xlabel("Budget (target)")
  420. ax.set_ylabel("Sum over pairs of min (UB - LB)")
  421. ax.grid(True); ax.legend(title="Scheduler")
  422. plt.tight_layout()
  423. pdf = f"{file_name}.pdf"
  424. plt.savefig(pdf);
  425. if shutil.which("pdfcrop"): os.system(f"pdfcrop {pdf} {pdf}")
  426. _log(f"Saved: {pdf}")
  427. # =========================
# 6) Width (UB - LB), weighted: sum of I_d·width over all links
  429. # =========================
  430. def plot_widthsum_alllinks_weighted_vs_budget(
  431. budget_list, scheduler_names, noise_model,
  432. node_path_list, importance_list,
  433. bounces=(1,2,3,4), repeat=10, delta=0.1,
  434. verbose=True, print_every=1,
  435. ):
  436. file_name = f"plot_widthsum_alllinks_weighted_vs_budget_{noise_model}"
  437. root_dir = os.path.dirname(os.path.abspath(__file__))
  438. outdir = os.path.join(root_dir, "outputs")
  439. os.makedirs(outdir, exist_ok=True)
  440. payload = _run_or_load_shared_sweep(
  441. budget_list, scheduler_names, noise_model,
  442. node_path_list, importance_list,
  443. bounces=bounces, repeat=repeat,
  444. verbose=verbose, print_every=print_every,
  445. )
  446. results = {name: {"sums": [[] for _ in budget_list]} for name in scheduler_names}
  447. for name in scheduler_names:
  448. for k in range(len(budget_list)):
  449. for run in payload["data"][name][k]:
  450. per_pair_details = run["per_pair_details"]
  451. v = sum_weighted_widths_all_links(per_pair_details, importance_list, delta=delta)
  452. results[name]["sums"][k].append(v)
  453. # plot (mean ± 95%CI)
  454. plt.rc("axes", prop_cycle=default_cycler)
  455. fig, ax = plt.subplots()
  456. xs = list(budget_list)
  457. for name, dat in results.items():
  458. means, halfs = [], []
  459. for vals in dat["sums"]:
  460. m, h = mean_ci95(vals); means.append(m); halfs.append(h)
  461. means = np.asarray(means); halfs = np.asarray(halfs)
  462. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  463. ax.plot(xs, means, linewidth=2.0, marker="o", label=label)
  464. ax.fill_between(xs, means - halfs, means + halfs, alpha=0.25)
  465. ax.set_xlabel("Budget (target)")
  466. ax.set_ylabel("Weighted Sum of Widths Σ_d Σ_l I_d (UB - LB)")
  467. ax.grid(True); ax.legend(title="Scheduler", fontsize=14, title_fontsize=18)
  468. plt.tight_layout()
  469. pdf = f"{file_name}.pdf"
  470. plt.savefig(pdf);
  471. if shutil.which("pdfcrop"): os.system(f"pdfcrop {pdf} {pdf}")
  472. _log(f"Saved: {pdf}")
  473. # =========================
# 7) Width (UB - LB), weighted: sum over pairs of I_d·(minimum width)
  475. # =========================
  476. def plot_minwidthsum_perpair_weighted_vs_budget(
  477. budget_list, scheduler_names, noise_model,
  478. node_path_list, importance_list,
  479. bounces=(1,2,3,4), repeat=10, delta=0.1,
  480. verbose=True, print_every=1,
  481. ):
  482. file_name = f"plot_minwidthsum_perpair_weighted_vs_budget_{noise_model}"
  483. root_dir = os.path.dirname(os.path.abspath(__file__))
  484. outdir = os.path.join(root_dir, "outputs")
  485. os.makedirs(outdir, exist_ok=True)
  486. payload = _run_or_load_shared_sweep(
  487. budget_list, scheduler_names, noise_model,
  488. node_path_list, importance_list,
  489. bounces=bounces, repeat=repeat,
  490. verbose=verbose, print_every=print_every,
  491. )
  492. results = {name: {"sums": [[] for _ in budget_list]} for name in scheduler_names}
  493. for name in scheduler_names:
  494. for k in range(len(budget_list)):
  495. for run in payload["data"][name][k]:
  496. per_pair_details = run["per_pair_details"]
  497. v = sum_weighted_min_widths_perpair(per_pair_details, importance_list, delta=delta)
  498. results[name]["sums"][k].append(v)
  499. # plot (mean ± 95%CI)
  500. plt.rc("axes", prop_cycle=default_cycler)
  501. fig, ax = plt.subplots()
  502. xs = list(budget_list)
  503. for name, dat in results.items():
  504. means, halfs = [], []
  505. for vals in dat["sums"]:
  506. m, h = mean_ci95(vals); means.append(m); halfs.append(h)
  507. means = np.asarray(means); halfs = np.asarray(halfs)
  508. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  509. ax.plot(xs, means, linewidth=2.0, marker="o", label=label)
  510. ax.fill_between(xs, means - halfs, means + halfs, alpha=0.25)
  511. ax.set_xlabel("Budget (target)")
  512. ax.set_ylabel("Weighted sum over pairs of min (UB - LB) (× I_d)")
  513. ax.grid(True); ax.legend(title="Scheduler")
  514. plt.tight_layout()
  515. pdf = f"{file_name}.pdf"
  516. plt.savefig(pdf);
  517. if shutil.which("pdfcrop"): os.system(f"pdfcrop {pdf} {pdf}")
  518. _log(f"Saved: {pdf}")