evaluationgap.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. # evaluationgap.py — Gap sweep: x = gap, y = accuracy (mean ± 95% CI)
  2. # Supports:
  3. # (2a) Random gap mode : alpha = alpha_base, beta = alpha - gap, then random sampling (utils.fidelity)
  4. # (2b) Fixed gap mode : deterministic arithmetic sequence with gap (utils.fidelity)
  5. #
  6. # Both modes inject true_fid_by_path with 1-origin keys and normalize est_fid_by_path to 1-origin.
  7. import os
  8. import json
  9. import time
  10. import pickle
  11. import hashlib
  12. import shutil
  13. from typing import List, Sequence, Dict, Any, Tuple
  14. import numpy as np
  15. import matplotlib as mpl
  16. import matplotlib.pyplot as plt
  17. from cycler import cycler
  18. from network import QuantumNetwork
  19. from schedulers import run_scheduler
  20. from viz.plots import mean_ci95
  21. from utils.ids import to_idx0, normalize_to_1origin, is_keys_1origin
  22. from utils.fidelity import (
  23. generate_fidelity_list_fix_gap,
  24. _generate_fidelity_list_random_rng,
  25. )
  26. # ---- Matplotlib style (align with evaluation.py) ----
  27. mpl.rcParams["figure.constrained_layout.use"] = True
  28. mpl.rcParams["savefig.bbox"] = "tight"
  29. mpl.rcParams["font.family"] = "serif"
  30. mpl.rcParams["font.serif"] = [
  31. "TeX Gyre Termes",
  32. "Nimbus Roman",
  33. "Liberation Serif",
  34. "DejaVu Serif",
  35. ]
  36. mpl.rcParams["font.size"] = 20
  37. _default_cycler = (
  38. cycler(color=["#4daf4a", "#377eb8", "#e41a1c", "#984ea3", "#ff7f00", "#a65628"])
  39. + cycler(marker=["s", "v", "o", "x", "*", "+"])
  40. + cycler(linestyle=[":", "--", "-", "-.", "--", ":"])
  41. )
  42. plt.rc("axes", prop_cycle=_default_cycler)
  43. # -----------------------------
  44. # Cache helpers (gap sweep)
  45. # -----------------------------
  46. def _gap_sweep_signature(gap_list: Sequence[float], scheduler_names: Sequence[str], noise_model: str,
  47. node_path_list: Sequence[int], importance_list: Sequence[float],
  48. bounces: Sequence[int], repeat: int,
  49. mode: str, # "random" or "fixed"
  50. importance_mode: str = "fixed", importance_uniform: Tuple[float, float] = (0.0, 1.0),
  51. seed: int = None, alpha_base: float = 0.95, variance: float = 0.10,
  52. C_total: int = 5000) -> Tuple[Dict[str, Any], str]:
  53. payload = {
  54. "gap_list": list(map(float, gap_list)),
  55. "scheduler_names": list(scheduler_names),
  56. "noise_model": str(noise_model),
  57. "node_path_list": list(map(int, node_path_list)),
  58. "importance_list": list(importance_list) if importance_list is not None else None,
  59. "importance_mode": str(importance_mode),
  60. "importance_uniform": list(importance_uniform) if importance_uniform is not None else None,
  61. "bounces": list(map(int, bounces)),
  62. "repeat": int(repeat),
  63. "seed": int(seed) if seed is not None else None,
  64. "mode": str(mode), # "random" / "fixed"
  65. "alpha_base": float(alpha_base),
  66. "variance": float(variance),
  67. "C_total": int(C_total),
  68. "version": 4, # schema: 1-origin injection & normalized est keys; fidelity_bank per gap stored
  69. }
  70. sig = hashlib.md5(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:10]
  71. return payload, sig
  72. def _shared_gap_path(noise_model: str, sig: str) -> str:
  73. root_dir = os.path.dirname(os.path.abspath(__file__))
  74. outdir = os.path.join(root_dir, "outputs")
  75. os.makedirs(outdir, exist_ok=True)
  76. return os.path.join(outdir, f"shared_gap_{noise_model}_{sig}.pickle")
  77. def _run_or_load_shared_gap_sweep(
  78. gap_list: Sequence[float], scheduler_names: Sequence[str], noise_model: str,
  79. node_path_list: Sequence[int], importance_list: Sequence[float],
  80. bounces=(1, 2, 3, 4), repeat: int = 10,
  81. importance_mode: str = "fixed", importance_uniform: Tuple[float, float] = (0.0, 1.0),
  82. seed: int = None, alpha_base: float = 0.95, variance: float = 0.10,
  83. C_total: int = 5000, mode: str = "random",
  84. verbose: bool = True, print_every: int = 1,
  85. ) -> Dict[str, Any]:
  86. """
  87. For each gap in gap_list, run `repeat` times. For each (gap, repeat) we create ONE fidelity_bank
  88. and reuse it for:
  89. - network generation (per pair)
  90. - true_fid_by_path injection
  91. so that there is no re-sampling mismatch.
  92. """
  93. config, sig = _gap_sweep_signature(
  94. gap_list, scheduler_names, noise_model,
  95. node_path_list, importance_list, bounces, repeat,
  96. mode=mode,
  97. importance_mode=importance_mode, importance_uniform=importance_uniform,
  98. seed=seed, alpha_base=alpha_base, variance=variance, C_total=C_total
  99. )
  100. cache_path = _shared_gap_path(noise_model, sig)
  101. lock_path = cache_path + ".lock"
  102. STALE_LOCK_SECS = 6 * 60 * 60
  103. HEARTBEAT_EVERY = 5.0
  104. rng = np.random.default_rng(seed)
  105. # Fast path: cached
  106. if os.path.exists(cache_path):
  107. if verbose:
  108. print(f"[gap-shared] Load cached: {os.path.basename(cache_path)}", flush=True)
  109. with open(cache_path, "rb") as f:
  110. return pickle.load(f)
  111. # Lock acquisition (single writer)
  112. got_lock = False
  113. while True:
  114. try:
  115. fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
  116. os.close(fd)
  117. got_lock = True
  118. break
  119. except FileExistsError:
  120. if os.path.exists(cache_path):
  121. with open(cache_path, "rb") as f:
  122. return pickle.load(f)
  123. try:
  124. age = time.time() - os.path.getmtime(lock_path)
  125. except OSError:
  126. age = 0
  127. if age > STALE_LOCK_SECS:
  128. try: os.remove(lock_path)
  129. except FileNotFoundError: pass
  130. continue
  131. if verbose:
  132. print("[gap-shared] Waiting for cache to be ready...", flush=True)
  133. time.sleep(1.0)
  134. try:
  135. if verbose:
  136. print(f"[gap-shared] Run gap sweep and cache to: {os.path.basename(cache_path)}", flush=True)
  137. data = {name: {k: [] for k in range(len(gap_list))} for name in scheduler_names}
  138. last_hb = time.time()
  139. for r in range(repeat):
  140. if verbose and ((r + 1) % print_every == 0 or r == 0):
  141. print(f"[gap-shared] Repeat {r+1}/{repeat}", flush=True)
  142. # Importance per repeat
  143. if str(importance_mode).lower() == "uniform":
  144. a, b = map(float, importance_uniform)
  145. imp_list_r = [float(rng.uniform(a, b)) for _ in node_path_list]
  146. else:
  147. imp_list_r = list(importance_list)
  148. # Sweep gaps
  149. for k, gap in enumerate(gap_list):
  150. if verbose:
  151. print(f"=== [GAP {noise_model}] gap={gap} ({k+1}/{len(gap_list)}), mode={mode} ===", flush=True)
  152. # Heartbeat
  153. now = time.time()
  154. if now - last_hb >= HEARTBEAT_EVERY:
  155. try: os.utime(lock_path, None)
  156. except FileNotFoundError: pass
  157. last_hb = now
  158. # (重要) gap×repeat ごとに fidelity_bank を先に作って保存 → 再利用
  159. fidelity_bank: List[List[float]] = []
  160. for pair_idx, path_num in enumerate(node_path_list):
  161. if mode == "fixed":
  162. # 等差列: fidelity_max から gap ずつ下げる
  163. fids = generate_fidelity_list_fix_gap(
  164. path_num=int(path_num), gap=float(gap), fidelity_max=1.0
  165. )
  166. else:
  167. # ランダム: alpha=alpha_base, beta=alpha_base-gap
  168. alpha = float(alpha_base)
  169. beta = float(alpha_base) - float(gap)
  170. fids = _generate_fidelity_list_random_rng(
  171. rng=rng, path_num=int(path_num),
  172. alpha=alpha, beta=beta, variance=float(variance)
  173. )
  174. fidelity_bank.append(fids)
  175. # network generator uses the saved bank
  176. def network_generator(path_num: int, pair_idx: int):
  177. return QuantumNetwork(path_num, fidelity_bank[pair_idx], noise_model)
  178. for name in scheduler_names:
  179. per_pair_results, total_cost, per_pair_details = run_scheduler(
  180. node_path_list=node_path_list, importance_list=imp_list_r,
  181. scheduler_name=name,
  182. bounces=list(bounces),
  183. C_total=int(C_total),
  184. network_generator=network_generator,
  185. return_details=True,
  186. )
  187. # Inject truth (1..L) and normalize estimated map (to 1..L)
  188. for d, det in enumerate(per_pair_details):
  189. L = int(node_path_list[d])
  190. est_map = det.get("est_fid_by_path", {})
  191. if est_map:
  192. est_map_norm = normalize_to_1origin({int(k): float(v) for k, v in est_map.items()}, L)
  193. else:
  194. est_map_norm = {}
  195. # true map from the saved fidelity_bank (no re-sampling)
  196. true_list = fidelity_bank[d] # 0-origin
  197. true_map = {pid: float(true_list[to_idx0(pid)]) for pid in range(1, L + 1)}
  198. if est_map_norm and not is_keys_1origin(est_map_norm.keys(), L):
  199. raise RuntimeError(f"[inject] est_fid_by_path keys not 1..{L} (pair={d})")
  200. det["est_fid_by_path"] = est_map_norm
  201. det["true_fid_by_path"] = true_map
  202. data[name][k].append({
  203. "per_pair_results": per_pair_results,
  204. "per_pair_details": per_pair_details,
  205. "total_cost": total_cost,
  206. "importance_list": imp_list_r,
  207. "gap": float(gap),
  208. "C_total": int(C_total),
  209. "alpha_base": float(alpha_base),
  210. "variance": float(variance),
  211. "mode": str(mode),
  212. "node_path_list": list(map(int, node_path_list)),
  213. })
  214. payload = {
  215. "config": config,
  216. "gap_list": list(map(float, gap_list)),
  217. "data": data,
  218. }
  219. # atomic write
  220. tmp = cache_path + ".tmp"
  221. with open(tmp, "wb") as f:
  222. pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
  223. os.replace(tmp, cache_path)
  224. return payload
  225. finally:
  226. if got_lock:
  227. try: os.remove(lock_path)
  228. except FileNotFoundError: pass
  229. # -----------------------------
  230. # Public APIs
  231. # -----------------------------
  232. def plot_accuracy_vs_gap(
  233. gap_list: Sequence[float], scheduler_names: Sequence[str], noise_model: str,
  234. node_path_list: Sequence[int], importance_list: Sequence[float],
  235. bounces=(1, 2, 3, 4), repeat: int = 10,
  236. importance_mode: str = "fixed", importance_uniform: Tuple[float, float] = (0.0, 1.0),
  237. seed: int = None, alpha_base: float = 0.95, variance: float = 0.10,
  238. C_total_override: int = None,
  239. verbose: bool = True, print_every: int = 1,
  240. ) -> str:
  241. """
  242. (2a) Gap vs Accuracy — Random mode (utils.fidelity)
  243. """
  244. file_name = f"plot_accuracy_vs_gap_random_{noise_model}"
  245. root_dir = os.path.dirname(os.path.abspath(__file__))
  246. outdir = os.path.join(root_dir, "outputs")
  247. os.makedirs(outdir, exist_ok=True)
  248. C_total = int(C_total_override) if C_total_override is not None else 5000
  249. payload = _run_or_load_shared_gap_sweep(
  250. gap_list, scheduler_names, noise_model,
  251. node_path_list, importance_list,
  252. bounces=bounces, repeat=repeat,
  253. importance_mode=importance_mode, importance_uniform=importance_uniform,
  254. seed=seed, alpha_base=alpha_base, variance=variance,
  255. C_total=C_total, mode="random",
  256. verbose=verbose, print_every=print_every,
  257. )
  258. # Collect accuracy arrays per gap
  259. results = {name: {"accs": [[] for _ in gap_list]} for name in scheduler_names}
  260. for name in scheduler_names:
  261. for k in range(len(gap_list)):
  262. for run in payload["data"][name][k]:
  263. per_pair_results = run["per_pair_results"]
  264. vals = []
  265. for r in per_pair_results:
  266. if isinstance(r, tuple):
  267. c = r[0]
  268. elif isinstance(r, (int, float, bool)):
  269. c = bool(r)
  270. else:
  271. raise TypeError(f"per_pair_results element has unexpected type: {type(r)} -> {r}")
  272. vals.append(1.0 if c else 0.0)
  273. acc = float(np.mean(vals)) if vals else 0.0
  274. results[name]["accs"][k].append(acc)
  275. # Plot
  276. plt.rc("axes", prop_cycle=_default_cycler)
  277. fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)
  278. xs = list(map(float, gap_list))
  279. for name, data in results.items():
  280. means, halfs = [], []
  281. for vals in data["accs"]:
  282. m, h = mean_ci95(vals)
  283. means.append(m); halfs.append(h)
  284. means = np.asarray(means); halfs = np.asarray(halfs)
  285. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  286. ax.plot(xs, means, linewidth=2.0, label=label)
  287. ax.fill_between(xs, means - halfs, means + halfs, alpha=0.25)
  288. ax.set_xlabel("Gap (alpha - beta)")
  289. ax.set_ylabel("Average Correctness (mean ± 95% CI)")
  290. ax.grid(True); ax.legend(title="Scheduler", fontsize=14, title_fontsize=18)
  291. pdf = os.path.join(outdir, f"{file_name}.pdf")
  292. plt.savefig(pdf)
  293. if shutil.which("pdfcrop"):
  294. os.system(f'pdfcrop --margins "8 8 8 8" "{pdf}" "{pdf}"')
  295. print(f"Saved: {pdf}", flush=True)
  296. return pdf
  297. def plot_accuracy_vs_gap_fixgap(
  298. gap_list: Sequence[float], scheduler_names: Sequence[str], noise_model: str,
  299. node_path_list: Sequence[int], importance_list: Sequence[float],
  300. bounces=(1, 2, 3, 4), repeat: int = 10,
  301. importance_mode: str = "fixed", importance_uniform: Tuple[float, float] = (0.0, 1.0),
  302. seed: int = None, fidelity_max: float = 1.0,
  303. C_total_override: int = None,
  304. verbose: bool = True, print_every: int = 1,
  305. ) -> str:
  306. """
  307. (2b) Gap vs Accuracy — Fixed arithmetic sequence mode (utils.fidelity)
  308. """
  309. # 固定列では rng は使わないが、署名の再現性のため seed を渡しておく
  310. file_name = f"plot_accuracy_vs_gap_fixed_{noise_model}"
  311. root_dir = os.path.dirname(os.path.abspath(__file__))
  312. outdir = os.path.join(root_dir, "outputs")
  313. os.makedirs(outdir, exist_ok=True)
  314. # alpha_base/variance は未使用だが、シグネチャ整合のためデフォルト値を渡す
  315. C_total = int(C_total_override) if C_total_override is not None else 5000
  316. payload = _run_or_load_shared_gap_sweep(
  317. gap_list, scheduler_names, noise_model,
  318. node_path_list, importance_list,
  319. bounces=bounces, repeat=repeat,
  320. importance_mode=importance_mode, importance_uniform=importance_uniform,
  321. seed=seed, alpha_base=0.95, variance=0.10,
  322. C_total=C_total, mode="fixed",
  323. verbose=verbose, print_every=print_every,
  324. )
  325. # Collect accuracy arrays per gap
  326. results = {name: {"accs": [[] for _ in gap_list]} for name in scheduler_names}
  327. for name in scheduler_names:
  328. for k in range(len(gap_list)):
  329. for run in payload["data"][name][k]:
  330. per_pair_results = run["per_pair_results"]
  331. vals = []
  332. for r in per_pair_results:
  333. if isinstance(r, tuple):
  334. c = r[0]
  335. elif isinstance(r, (int, float, bool)):
  336. c = bool(r)
  337. else:
  338. raise TypeError(f"per_pair_results element has unexpected type: {type(r)} -> {r}")
  339. vals.append(1.0 if c else 0.0)
  340. acc = float(np.mean(vals)) if vals else 0.0
  341. results[name]["accs"][k].append(acc)
  342. # Plot
  343. plt.rc("axes", prop_cycle=_default_cycler)
  344. fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)
  345. xs = list(map(float, gap_list))
  346. for name, data in results.items():
  347. means, halfs = [], []
  348. for vals in data["accs"]:
  349. m, h = mean_ci95(vals)
  350. means.append(m)
  351. halfs.append(h)
  352. means = np.asarray(means)
  353. halfs = np.asarray(halfs)
  354. label = name.replace("Vanilla NB", "VanillaNB").replace("Succ. Elim. NB", "SuccElimNB")
  355. ax.plot(xs, means, linewidth=2.0, label=label)
  356. ax.fill_between(xs, means - halfs, means + halfs, alpha=0.25)
  357. ax.set_xlabel("Gap (arithmetic sequence)")
  358. ax.set_ylabel("Average Correctness (mean ± 95% CI)")
  359. ax.grid(True)
  360. ax.legend(title="Scheduler", fontsize=14, title_fontsize=18)
  361. pdf = os.path.join(outdir, f"{file_name}.pdf")
  362. plt.savefig(pdf)
  363. if shutil.which("pdfcrop"):
  364. os.system(f'pdfcrop --margins "8 8 8 8" "{pdf}" "{pdf}"')
  365. print(f"Saved: {pdf}", flush=True)
  366. return pdf