evaluationgap.py~ 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # evaluationgap.py — Sweep x-axis over "gap", y-axis = accuracy (mean ± 95% CI)
  2. # Random fidelity generator version where alpha - beta = gap.
  3. import os
  4. import json
  5. import time
  6. import pickle
  7. import hashlib
  8. import shutil
  9. import numpy as np
  10. import matplotlib as mpl
  11. import matplotlib.pyplot as plt
  12. from cycler import cycler
  13. from network import QuantumNetwork
  14. from schedulers import run_scheduler
  15. from viz.plots import mean_ci95
  16. # ---- Matplotlib global style (match evaluation.py) ----
  17. mpl.rcParams["figure.constrained_layout.use"] = True
  18. mpl.rcParams["savefig.bbox"] = "tight"
  19. mpl.rcParams["font.family"] = "serif"
  20. mpl.rcParams["font.serif"] = [
  21. "TeX Gyre Termes",
  22. "Nimbus Roman",
  23. "Liberation Serif",
  24. "DejaVu Serif",
  25. ]
  26. mpl.rcParams["font.size"] = 20
  27. default_cycler = (
  28. cycler(color=["#4daf4a", "#377eb8", "#e41a1c", "#984ea3", "#ff7f00", "#a65628"])
  29. + cycler(marker=["s", "v", "o", "x", "*", "+"])
  30. + cycler(linestyle=[":", "--", "-", "-.", "--", ":"])
  31. )
  32. plt.rc("axes", prop_cycle=default_cycler)
  33. # -----------------------------
  34. # Random fidelity generators (alpha - beta = gap)
  35. # -----------------------------
  36. def _generate_fidelity_list_random_rng(rng, path_num, alpha=0.95, beta=0.85, variance=0.1):
  37. """
  38. Generate `path_num` fidelities with top-1 mean alpha and others mean beta,
  39. each sampled from Normal(mu, variance), clamped to [0.8, 1.0].
  40. Ensures a visible top-1 gap (>0.02) in the sorted list.
  41. """
  42. while True:
  43. mean = [alpha] + [beta] * (path_num - 1)
  44. res = []
  45. for mu in mean:
  46. # Rejection sample into [0.8, 1.0]
  47. while True:
  48. r = rng.normal(mu, variance)
  49. if 0.8 <= r <= 1.0:
  50. break
  51. res.append(float(r))
  52. sorted_res = sorted(res, reverse=True)
  53. if len(sorted_res) >= 2 and (sorted_res[0] - sorted_res[1]) > 0.02:
  54. return res
  55. def _fidelity_list_gap_random(path_num, gap, rng,
  56. alpha_base=0.95, variance=0.1):
  57. """
  58. Build a fidelity list of length `path_num` using:
  59. alpha = alpha_base
  60. beta = alpha - gap
  61. With random jitter via Normal(mu, variance), clamped to [0.8, 1.0].
  62. """
  63. alpha = float(alpha_base)
  64. beta = float(alpha_base - gap)
  65. # keep beta within [0.8, alpha)
  66. beta = min(max(beta, 0.8), max(alpha - 1e-6, 0.8))
  67. return _generate_fidelity_list_random_rng(rng, path_num, alpha=alpha, beta=beta, variance=variance)
  68. # -----------------------------
  69. # Cache helpers (gap sweep)
  70. # -----------------------------
  71. def _gap_sweep_signature(gap_list, scheduler_names, noise_model,
  72. node_path_list, importance_list, bounces, repeat,
  73. importance_mode="fixed", importance_uniform=(0.0, 1.0), seed=None,
  74. alpha_base=0.95, variance=0.10):
  75. payload = {
  76. "gap_list": list(map(float, gap_list)),
  77. "scheduler_names": list(scheduler_names),
  78. "noise_model": str(noise_model),
  79. "node_path_list": list(node_path_list),
  80. "importance_list": list(importance_list) if importance_list is not None else None,
  81. "importance_mode": str(importance_mode),
  82. "importance_uniform": list(importance_uniform) if importance_uniform is not None else None,
  83. "bounces": list(bounces),
  84. "repeat": int(repeat),
  85. "seed": int(seed) if seed is not None else None,
  86. # fidelity-generation mode & params
  87. "fidelity_mode": "random_gap_alpha_beta",
  88. "alpha_base": float(alpha_base),
  89. "variance": float(variance),
  90. "version": 2,
  91. }
  92. sig = hashlib.md5(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:10]
  93. return payload, sig
  94. def _shared_gap_path(noise_model, sig):
  95. root_dir = os.path.dirname(os.path.abspath(__file__))
  96. outdir = os.path.join(root_dir, "outputs")
  97. os.makedirs(outdir, exist_ok=True)
  98. return os.path.join(outdir, f"shared_gap_{noise_model}_{sig}.pickle")
  99. def _run_or_load_shared_gap_sweep(
  100. gap_list, scheduler_names, noise_model,
  101. node_path_list, importance_list,
  102. bounces=(1, 2, 3, 4), repeat=10,
  103. importance_mode="fixed", importance_uniform=(0.0, 1.0),
  104. seed=None, alpha_base=0.95, variance=0.10,
  105. C_total=5000,
  106. verbose=True, print_every=1,
  107. ):
  108. """
  109. For each gap in gap_list, run `repeat` times over the same topology generator (per-repeat),
  110. and evaluate every scheduler. Cache the whole sweep with a single file lock.
  111. """
  112. config, sig = _gap_sweep_signature(
  113. gap_list, scheduler_names, noise_model,
  114. node_path_list, importance_list, bounces, repeat,
  115. importance_mode=importance_mode, importance_uniform=importance_uniform, seed=seed,
  116. alpha_base=alpha_base, variance=variance,
  117. )
  118. cache_path = _shared_gap_path(noise_model, sig)
  119. lock_path = cache_path + ".lock"
  120. STALE_LOCK_SECS = 6 * 60 * 60
  121. HEARTBEAT_EVERY = 5.0
  122. rng = np.random.default_rng(seed)
  123. # Fast path: cached
  124. if os.path.exists(cache_path):
  125. with open(cache_path, "rb") as f:
  126. return pickle.load(f)
  127. # Lock acquisition loop
  128. got_lock = False
  129. while True:
  130. try:
  131. fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
  132. os.close(fd)
  133. got_lock = True
  134. break
  135. except FileExistsError:
  136. # If cache appeared while waiting, load immediately.
  137. if os.path.exists(cache_path):
  138. with open(cache_path, "rb") as f:
  139. return pickle.load(f)
  140. try:
  141. age = time.time() - os.path.getmtime(lock_path)
  142. except OSError:
  143. age = 0
  144. if age > STALE_LOCK_SECS:
  145. try:
  146. os.remove(lock_path)
  147. except FileNotFoundError:
  148. pass
  149. continue
  150. time.sleep(1.0)
  151. try:
  152. if verbose:
  153. print(f"[gap-shared] Run gap sweep and cache to: {os.path.basename(cache_path)}", flush=True)
  154. data = {name: {k: [] for k in range(len(gap_list))} for name in scheduler_names}
  155. last_hb = time.time()
  156. # Repeat loop: per-repeat we will re-sample importance (if requested)
  157. for r in range(repeat):
  158. if verbose and ((r + 1) % print_every == 0 or r == 0):
  159. print(f"[gap-shared] Repeat {r+1}/{repeat}", flush=True)
  160. # Importance list per repeat
  161. if str(importance_mode).lower() == "uniform":
  162. a, b = map(float, importance_uniform)
  163. imp_list_r = [float(rng.uniform(a, b)) for _ in node_path_list]
  164. else:
  165. imp_list_r = list(importance_list)
  166. # Sweep over gaps
  167. for k, gap in enumerate(gap_list):
  168. if verbose:
  169. print(f"=== [GAP {noise_model}] gap={gap} ({k+1}/{len(gap_list)}) ===", flush=True)
  170. # Heartbeat
  171. now = time.time()
  172. if now - last_hb >= HEARTBEAT_EVERY:
  173. try:
  174. os.utime(lock_path, None)
  175. except FileNotFoundError:
  176. pass
  177. last_hb = now
  178. # Network generator for this 'gap' (fresh fidelities for each pair)
  179. def network_generator(path_num, pair_idx):
  180. fids = _fidelity_list_gap_random(
  181. path_num=path_num,
  182. gap=float(gap),
  183. rng=rng,
  184. alpha_base=alpha_base,
  185. variance=variance,
  186. )
  187. return QuantumNetwork(path_num, fids, noise_model)
  188. for name in scheduler_names:
  189. per_pair_results, total_cost, per_pair_details = run_scheduler(
  190. node_path_list=node_path_list, importance_list=imp_list_r,
  191. scheduler_name=name,
  192. bounces=list(bounces),
  193. C_total=int(C_total),
  194. network_generator=network_generator,
  195. return_details=True,
  196. )
  197. data[name][k].append({
  198. "per_pair_results": per_pair_results,
  199. "per_pair_details": per_pair_details,
  200. "total_cost": total_cost,
  201. "importance_list": imp_list_r,
  202. "gap": float(gap),
  203. "C_total": int(C_total),
  204. "alpha_base": float(alpha_base),
  205. "variance": float(variance),
  206. })
  207. payload = {
  208. "config": config,
  209. "gap_list": list(map(float, gap_list)),
  210. "data": data,
  211. }
  212. tmp = cache_path + ".tmp"
  213. with open(tmp, "wb") as f:
  214. pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
  215. os.replace(tmp, cache_path)
  216. return payload
  217. finally:
  218. if got_lock:
  219. try:
  220. os.remove(lock_path)
  221. except FileNotFoundError:
  222. pass
  223. # -----------------------------
  224. # Public API: plot (mean ± 95% CI)
  225. # -----------------------------
  226. def plot_accuracy_vs_gap(
  227. gap_list, scheduler_names, noise_model,
  228. node_path_list, importance_list,
  229. bounces=(1, 2, 3, 4), repeat=10,
  230. importance_mode="fixed", importance_uniform=(0.0, 1.0),
  231. seed=None,
  232. alpha_base=0.95, variance=0.10,
  233. C_total_override=None,
  234. verbose=True, print_every=1,
  235. ):
  236. """
  237. Make a plot with x = gap, y = accuracy (mean ± 95% CI).
  238. Uses alpha - beta = gap; fidelities are sampled per pair from Normal(mu, variance) clamped to [0.8,1.0].
  239. """
  240. file_name = f"plot_accuracy_vs_gap_{noise_model}"
  241. root_dir = os.path.dirname(os.path.abspath(__file__))
  242. outdir = os.path.join(root_dir, "outputs")
  243. os.makedirs(outdir, exist_ok=True)
  244. C_total = int(C_total_override) if C_total_override is not None else 5000
  245. payload = _run_or_load_shared_gap_sweep(
  246. gap_list, scheduler_names, noise_model,
  247. node_path_list, importance_list,
  248. bounces=bounces, repeat=repeat,
  249. importance_mode=importance_mode, importance_uniform=importance_uniform, seed=seed,
  250. alpha_base=alpha_base, variance=variance,
  251. C_total=C_total,
  252. verbose=verbose, print_every=print_every,
  253. )
  254. # Collect accuracy arrays per gap
  255. results = {name: {"accs": [[] for _ in gap_list]} for name in scheduler_names}
  256. for name in scheduler_names:
  257. for k in range(len(gap_list)):
  258. for run in payload["data"][name][k]:
  259. per_pair_results = run["per_pair_results"]
  260. vals = []
  261. for r in per_pair_results:
  262. if isinstance(r, tuple):
  263. c = r[0]
  264. elif isinstance(r, (int, float, bool)):
  265. c = bool(r)
  266. else:
  267. raise TypeError(f"per_pair_results element has unexpected type: {type(r)} -> {r}")
  268. vals.append(1.0 if c else 0.0)
  269. acc = float(np.mean(vals)) if vals else 0.0
  270. results[name]["accs"][k].append(acc)
  271. # Plot
  272. plt.rc("axes", prop_cycle=default_cycler)
  273. fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)
  274. xs = list(map(float, gap_list))
  275. for name, data in results.items():
  276. means, halfs = [], []
  277. for vals in data["accs"]:
  278. m, h = mean_ci95(vals)
  279. means.append(m); halfs.append(h)
  280. means = np.asarray(means); halfs = np.asarray(halfs)
  281. label = name.replace("Vanilla NB","VanillaNB").replace("Succ. Elim. NB","SuccElimNB")
  282. ax.plot(xs, means, linewidth=2.0, label=label)
  283. ax.fill_between(xs, means - halfs, means + halfs, alpha=0.25)
  284. ax.set_xlabel("Gap (alpha - beta)")
  285. ax.set_ylabel("Average Correctness (mean ± 95% CI)")
  286. ax.grid(True); ax.legend(title="Scheduler", fontsize=14, title_fontsize=18)
  287. pdf = os.path.join(outdir, f"{file_name}.pdf")
  288. plt.savefig(pdf)
  289. if shutil.which("pdfcrop"):
  290. os.system(f'pdfcrop --margins "8 8 8 8" "{pdf}" "{pdf}"')
  291. print(f"Saved: {pdf}", flush=True)
  292. return pdf
  293. if __name__ == "__main__":
  294. # Minimal example (safe defaults). Adjust as needed.
  295. gaps = [0.005, 0.01, 0.02, 0.03]
  296. scheds = ["Vanilla NB", "Succ. Elim. NB", "Greedy Two-Phase"]
  297. noise = "Depolar"
  298. node_paths = [5, 5, 5] # 3 pairs, each with 5 candidate links
  299. importances = [1.0, 1.0, 1.0]
  300. plot_accuracy_vs_gap(
  301. gap_list=gaps,
  302. scheduler_names=scheds,
  303. noise_model=noise,
  304. node_path_list=node_paths,
  305. importance_list=importances,
  306. bounces=(1,2,3,4),
  307. repeat=5,
  308. importance_mode="fixed",
  309. seed=42,
  310. alpha_base=0.95,
  311. variance=0.10,
  312. C_total_override=5000,
  313. )