evaluationpair.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. # evaluationpair.py — Sweep "number of destination pairs" (x) vs Accuracy (y)
  2. # Designed to align with evaluation.py pipeline (1-origin keys, utils.ids normalization).
  3. #
  4. # Produces: outputs/plot_accuracy_vs_pairs_<noise_model>.pdf
import hashlib
import json
import os
import pickle
import shutil
import subprocess
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler

from network import QuantumNetwork
from schedulers import run_scheduler
from utils.fidelity import (
    generate_fidelity_list_avg_gap,
    generate_fidelity_list_fix_gap,
    generate_fidelity_list_random,
    _generate_fidelity_list_random_rng,
)
from utils.ids import to_idx0, normalize_to_1origin, is_keys_1origin
from viz.plots import mean_ci95
  25. # ---- Matplotlib style (match evaluation.py) ----
  26. mpl.rcParams["figure.constrained_layout.use"] = True
  27. mpl.rcParams["savefig.bbox"] = "tight"
  28. mpl.rcParams["font.family"] = "serif"
  29. mpl.rcParams["font.serif"] = [
  30. "TeX Gyre Termes",
  31. "Nimbus Roman",
  32. "Liberation Serif",
  33. "DejaVu Serif",
  34. ]
  35. mpl.rcParams["font.size"] = 20
  36. _default_cycler = (
  37. cycler(color=["#4daf4a", "#377eb8", "#e41a1c", "#984ea3", "#ff7f00", "#a65628"])
  38. + cycler(marker=["s", "v", "o", "x", "*", "+"])
  39. + cycler(linestyle=[":", "--", "-", "-.", "--", ":"])
  40. )
  41. plt.rc("axes", prop_cycle=_default_cycler)
  42. # =========================
  43. # Utilities
  44. # =========================
  45. def _log(msg: str):
  46. print(msg, flush=True)
  47. def _generate_fidelity_list_random_rng(rng: np.random.Generator, path_num: int,
  48. alpha: float = 0.90, beta: float = 0.85, variance: float = 0.1):
  49. """Generate `path_num` link fidelities in [0.8, 1.0], ensuring a small top-1 gap."""
  50. while True:
  51. mean = [alpha] + [beta] * (path_num - 1)
  52. res = []
  53. for mu in mean:
  54. while True:
  55. r = rng.normal(mu, variance)
  56. if 0.8 <= r <= 1.0:
  57. break
  58. res.append(float(r))
  59. sorted_res = sorted(res, reverse=True)
  60. if sorted_res[0] - sorted_res[1] > 0.02:
  61. return res
  62. # =========================
  63. # Pair-sweep cache helpers
  64. # =========================
  65. def _sweep_signature_pairs(pairs_list, paths_per_pair, C_total, scheduler_names, noise_model,
  66. bounces, repeat, importance_mode="fixed", importance_uniform=(0.0,1.0), seed=None):
  67. payload = {
  68. "pairs_list": list(pairs_list),
  69. "paths_per_pair": int(paths_per_pair),
  70. "C_total": int(C_total),
  71. "scheduler_names": list(scheduler_names),
  72. "noise_model": str(noise_model),
  73. "bounces": list(bounces),
  74. "repeat": int(repeat),
  75. "importance_mode": str(importance_mode),
  76. "importance_uniform": list(importance_uniform) if importance_uniform is not None else None,
  77. "seed": int(seed) if seed is not None else None,
  78. "version": 2, # ★ schema: per_pair_details の est/true_fid_by_path を 1-origin へ統一
  79. }
  80. sig = hashlib.md5(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:10]
  81. return payload, sig
  82. def _shared_pair_sweep_path(noise_model: str, sig: str):
  83. root_dir = os.path.dirname(os.path.abspath(__file__))
  84. outdir = os.path.join(root_dir, "outputs")
  85. os.makedirs(outdir, exist_ok=True)
  86. return os.path.join(outdir, f"pair_sweep_{noise_model}_{sig}.pickle")
def _run_or_load_pair_sweep(
    pairs_list, paths_per_pair, C_total, scheduler_names, noise_model,
    bounces=(1,2,3,4), repeat=10,
    importance_mode="fixed", importance_uniform=(0.0,1.0),
    seed=None,
    verbose=True, print_every=1,
):
    """Run the pair-count sweep once and cache it, or load the existing cache.

    Multiple processes may call this concurrently with the same config: a lock
    file next to the cache path elects a single producer; everyone else waits
    for the cache and loads it. The returned payload is
    {"config": <signature payload>, "pairs_list": [...], "data": {scheduler: {k: [run, ...]}}}
    where k indexes positions in `pairs_list` and each run dict holds
    per_pair_results/per_pair_details/total_cost/importance_list/node_path_list.
    """
    config, sig = _sweep_signature_pairs(
        pairs_list, paths_per_pair, C_total, scheduler_names, noise_model,
        bounces, repeat, importance_mode=importance_mode, importance_uniform=importance_uniform, seed=seed
    )
    cache_path = _shared_pair_sweep_path(noise_model, sig)
    lock_path = cache_path + ".lock"
    STALE_LOCK_SECS = 6 * 60 * 60  # a lock untouched for 6 h is assumed dead and reclaimed
    HEARTBEAT_EVERY = 5.0          # seconds between lock-mtime refreshes while producing
    rng = np.random.default_rng(seed)
    # Quick load if exists
    if os.path.exists(cache_path):
        if verbose: _log(f"[pair-sweep] Load cached: {os.path.basename(cache_path)}")
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    # Acquire lock (single producer; others wait)
    got_lock = False
    while True:
        try:
            # O_CREAT | O_EXCL makes creation atomic: exactly one process wins the lock.
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            os.close(fd)
            got_lock = True
            break
        except FileExistsError:
            # Someone else holds the lock; maybe they already published the cache.
            if os.path.exists(cache_path):
                with open(cache_path, "rb") as f:
                    return pickle.load(f)
            try:
                age = time.time() - os.path.getmtime(lock_path)
            except OSError:
                # Lock vanished between the check and the stat; treat as fresh and retry.
                age = 0
            if age > STALE_LOCK_SECS:
                if verbose: _log("[pair-sweep] Stale lock detected. Removing...")
                try: os.remove(lock_path)
                except FileNotFoundError: pass
                continue
            if verbose: _log("[pair-sweep] Waiting for cache to be ready...")
            time.sleep(1.0)
    try:
        if verbose: _log(f"[pair-sweep] Run sweep and cache to: {os.path.basename(cache_path)}")
        # data[scheduler][k] -> list of run dicts (one per repeat) for pairs_list[k].
        data = {name: {k: [] for k in range(len(pairs_list))} for name in scheduler_names}
        last_hb = time.time()
        for r in range(repeat):
            if verbose and ((r + 1) % print_every == 0 or r == 0):
                _log(f"[pair-sweep] Repeat {r+1}/{repeat}")
            # For each N (number of destination pairs), build one fixed topology per repeat
            for k, N_pairs in enumerate(pairs_list):
                # Heartbeat: touch the lock so waiters don't misjudge it as stale.
                now = time.time()
                if now - last_hb >= HEARTBEAT_EVERY:
                    try: os.utime(lock_path, None)
                    except FileNotFoundError: pass
                    last_hb = now
                node_path_list = [int(paths_per_pair)] * int(N_pairs)
                # Fidelity bank for this N (used consistently across schedulers)
                fidelity_bank = [_generate_fidelity_list_random_rng(rng, paths_per_pair) for _ in node_path_list]
                # Importance list for this N
                if str(importance_mode).lower() == "uniform":
                    a, b = map(float, importance_uniform)
                    importance_list = [float(rng.uniform(a, b)) for _ in node_path_list]
                else:
                    # fixed mode: default all ones
                    importance_list = [1.0 for _ in node_path_list]
                def network_generator(path_num, pair_idx):
                    # Fresh network per pair, sharing the fixed fidelity bank across schedulers.
                    return QuantumNetwork(path_num, fidelity_bank[pair_idx], noise_model)
                for name in scheduler_names:
                    per_pair_results, total_cost, per_pair_details = run_scheduler(
                        node_path_list=node_path_list,
                        importance_list=importance_list,
                        scheduler_name=name,
                        bounces=list(bounces),
                        C_total=int(C_total),
                        network_generator=network_generator,
                        return_details=True,
                    )
                    # As in evaluation.py: inject the true-fidelity dict (keys 1..L) and
                    # normalize the estimated dict to 1-origin keys.
                    for d, det in enumerate(per_pair_details):
                        L = node_path_list[d]
                        est_map = det.get("est_fid_by_path", {})
                        if est_map:
                            est_map_norm = normalize_to_1origin({int(k): float(v) for k, v in est_map.items()}, L)
                        else:
                            est_map_norm = {}
                        true_map = {pid: float(fidelity_bank[d][to_idx0(pid)]) for pid in range(1, L + 1)}
                        if est_map_norm and not is_keys_1origin(est_map_norm.keys(), L):
                            raise RuntimeError(f"[inject] est_fid_by_path keys not 1..{L} (pair={d})")
                        det["est_fid_by_path"] = est_map_norm
                        det["true_fid_by_path"] = true_map
                    data[name][k].append({
                        "per_pair_results": per_pair_results,
                        "per_pair_details": per_pair_details,
                        "total_cost": total_cost,
                        "importance_list": importance_list,
                        "node_path_list": node_path_list,
                    })
        payload = {"config": config, "pairs_list": list(pairs_list), "data": data}
        # Atomic publish: write to a temp file, then rename over the final path so
        # waiters never observe a partially written pickle.
        tmp = cache_path + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp, cache_path)
        return payload
    finally:
        # Always release the lock we created, even if the sweep raised.
        if got_lock:
            try: os.remove(lock_path)
            except FileNotFoundError: pass
  198. # =========================
  199. # Plot: Accuracy (mean ± 95% CI) vs #Destination Pairs
  200. # =========================
  201. def plot_accuracy_vs_pairs(
  202. pairs_list, paths_per_pair, C_total, scheduler_names, noise_model,
  203. bounces=(1,2,3,4), repeat=10,
  204. importance_mode="fixed", importance_uniform=(0.0,1.0),
  205. seed=None,
  206. verbose=True, print_every=1,
  207. ):
  208. """
  209. pairs_list: list[int] # x-axis = number of destination pairs (N)
  210. paths_per_pair: int # number of candidate links per pair (each L_n = paths_per_pair)
  211. C_total: int # total budget for the whole experiment (fixed while N varies)
  212. scheduler_names: list[str]
  213. noise_model: str
  214. bounces: tuple/list[int] # NB bounce vector
  215. repeat: int # repeats per N
  216. importance_mode: "fixed" or "uniform"
  217. importance_uniform: (a,b) # when uniform, sample I_n ~ U[a,b]
  218. seed: int
  219. """
  220. file_name = f"plot_accuracy_vs_pairs_{noise_model}"
  221. root_dir = os.path.dirname(os.path.abspath(__file__))
  222. outdir = os.path.join(root_dir, "outputs")
  223. os.makedirs(outdir, exist_ok=True)
  224. payload = _run_or_load_pair_sweep(
  225. pairs_list, paths_per_pair, C_total, scheduler_names, noise_model,
  226. bounces=bounces, repeat=repeat,
  227. importance_mode=importance_mode, importance_uniform=importance_uniform,
  228. seed=seed, verbose=verbose, print_every=print_every
  229. )
  230. results = {name: {"accs": [[] for _ in pairs_list]} for name in scheduler_names}
  231. for name in scheduler_names:
  232. for k in range(len(pairs_list)):
  233. for run in payload["data"][name][k]:
  234. per_pair_results = run["per_pair_results"]
  235. # Normalize elements to bool → 0/1
  236. vals = []
  237. for r in per_pair_results:
  238. if isinstance(r, tuple):
  239. c = r[0]
  240. elif isinstance(r, (int, float, bool)):
  241. c = bool(r)
  242. else:
  243. raise TypeError(f"Unexpected per_pair_results element: {type(r)} -> {r}")
  244. vals.append(1.0 if c else 0.0)
  245. acc = float(np.mean(vals)) if vals else 0.0
  246. results[name]["accs"][k].append(acc)
  247. # Plot
  248. plt.rc("axes", prop_cycle=_default_cycler)
  249. fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)
  250. xs = list(pairs_list)
  251. for name, data in results.items():
  252. means, halfs = [], []
  253. for vals in data["accs"]:
  254. m, h = mean_ci95(vals)
  255. means.append(m); halfs.append(h)
  256. means = np.asarray(means); halfs = np.asarray(halfs)
  257. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  258. ax.plot(xs, means, linewidth=2.0, label=label)
  259. ax.fill_between(xs, means - halfs, means + halfs, alpha=0.25)
  260. ax.set_xlabel("Number of Destination Pairs (N)")
  261. ax.set_ylabel("Average Correctness (mean ± 95% CI)")
  262. ax.grid(True); ax.legend(title="Scheduler", fontsize=14, title_fontsize=18)
  263. pdf = os.path.join(outdir, f"{file_name}.pdf")
  264. plt.savefig(pdf)
  265. if shutil.which("pdfcrop"):
  266. os.system(f'pdfcrop --margins "8 8 8 8" "{pdf}" "{pdf}"')
  267. _log(f"Saved: {pdf}")
  268. return {
  269. "pdf": pdf,
  270. "pairs_list": list(pairs_list),
  271. "config": payload["config"],
  272. }
  273. if __name__ == "__main__":
  274. # Minimal CLI for quick testing
  275. pairs_list = [1, 2, 3, 4, 5, 6]
  276. paths_per_pair = 5
  277. C_total = 6000
  278. scheduler_names = ["Greedy", "LNaive"]
  279. noise_model = "Depolar"
  280. bounces = (1,2,3,4)
  281. repeat = 10
  282. importance_mode = "uniform"
  283. importance_uniform = (0.0, 1.0)
  284. seed = 12
  285. plot_accuracy_vs_pairs(
  286. pairs_list, paths_per_pair, C_total, scheduler_names, noise_model,
  287. bounces=bounces, repeat=repeat,
  288. importance_mode=importance_mode, importance_uniform=importance_uniform,
  289. seed=seed, verbose=True
  290. )