simulation.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # simulation.py
  2. # -*- coding: utf-8 -*-
  3. from __future__ import annotations
  4. from dataclasses import dataclass
  5. from typing import Dict, Any, List
  6. import csv, os, math, random
  7. # ===== 既存ネットワークAPIに合わせたアダプタ =====
  8. class Adapter:
  9. """
  10. あなたの network.py / nb_protocol.py を変更せずに使うための薄いラッパ。
  11. - QuantumNetwork(path_num, fidelity_list, noise_model) を自前で構築
  12. - 単一ペア 'Alice-Bob' に path_id=1..path_num を割当
  13. - スケジューラが期待する nb_protocol 互換API(sample_path / true_fidelity)を Shim で提供
  14. """
  15. def __init__(self, noise_model: str, path_num: int, fidelity_list: List[float], seed: int | None = None):
  16. if seed is not None:
  17. random.seed(seed)
  18. import network as qnet
  19. # QuantumNetwork を直接構築(network.py を変更しない)
  20. self.net = qnet.QuantumNetwork(path_num=path_num, fidelity_list=fidelity_list, noise_model=noise_model)
  21. self.pairs = ["Alice-Bob"]
  22. self.paths_map = {"Alice-Bob": list(range(1, path_num + 1))}
  23. # nb_protocol 互換 Shim
  24. self.nbp = _NBPShim(self.net)
  25. # ---- ヘルパ ----
  26. def true_fidelity(self, path_id: Any) -> float:
  27. return self.nbp.true_fidelity(self.net, path_id)
  28. def list_pairs(self) -> List[Any]:
  29. return list(self.pairs)
  30. def list_paths_of(self, pair_id: Any) -> List[Any]:
  31. return list(self.paths_map.get(pair_id, []))
  32. # ---- スケジューラ呼び出し ----
  33. def run_scheduler(self, scheduler_name: str, budget_target: int,
  34. importance: Dict[Any, float]) -> Dict[str, Any]:
  35. """
  36. スケジューラに共通IFで実行要求する。
  37. 返り値の想定(辞書):
  38. {
  39. 'used_cost_total': int,
  40. 'per_pair_details': [
  41. {
  42. 'pair_id': pair_id,
  43. 'alloc_by_path': {path_id: sample_count, ...},
  44. 'est_fid_by_path': {path_id: mean_estimate, ...},
  45. 'best_pred_path': path_id,
  46. }, ...
  47. ]
  48. }
  49. """
  50. if scheduler_name == "greedy":
  51. from schedulers.greedy_scheduler import run as greedy_run
  52. return greedy_run(self.net, self.pairs, self.paths_map, budget_target, importance, self.nbp)
  53. elif scheduler_name == "naive":
  54. from schedulers.lnaive_scheduler import run as naive_run
  55. return naive_run(self.net, self.pairs, self.paths_map, budget_target, importance, self.nbp)
  56. elif scheduler_name == "online_nb":
  57. from schedulers.lonline_nb import run as onb_run
  58. return onb_run(self.net, self.pairs, self.paths_map, budget_target, importance, self.nbp)
  59. else:
  60. raise ValueError(f"unknown scheduler: {scheduler_name}")
  61. # ===== 便利関数 =====
  62. def hoeffding_radius(n: int, delta: float = 0.05) -> float:
  63. if n <= 0:
  64. return 1.0
  65. return math.sqrt(0.5 * math.log(2.0 / delta) / n)
  66. def clamp01(x: float) -> float:
  67. return 0.0 if x < 0.0 else (1.0 if x > 1.0 else x)
  68. # ===== CSV I/O =====
  69. CSV_HEADER = [
  70. "run_id", "noise", "scheduler", "budget_target",
  71. "used_cost_total",
  72. "pair_id", "path_id",
  73. "importance", # I_d
  74. "samples", # B_{d,l}
  75. "est_mean", "lb", "ub", "width",
  76. "is_best_true", "is_best_pred"
  77. ]
  78. def open_csv(path: str):
  79. os.makedirs(os.path.dirname(path), exist_ok=True)
  80. exists = os.path.exists(path)
  81. f = open(path, "a", newline="")
  82. w = csv.writer(f)
  83. if not exists:
  84. w.writerow(CSV_HEADER)
  85. return f, w
  86. # ===== メインシミュレーション =====
  87. @dataclass
  88. class ExperimentConfig:
  89. noise_model: str
  90. budgets: List[int]
  91. schedulers: List[str] # ["greedy", "naive", "online_nb", ...]
  92. repeats: int
  93. importance_mode: str = "both" # "both" / "weighted_only" / "unweighted_only"
  94. delta_ci: float = 0.05 # 95%CI相当
  95. out_csv: str = "outputs/raw_simulation_data.csv"
  96. seed: int | None = None
  97. # QuantumNetwork 構築用
  98. path_num: int = 5
  99. fidelity_list: List[float] | None = None
  100. def _importance_for_pairs(pairs: List[Any], mode: str) -> Dict[str, Dict[Any, float]]:
  101. res: Dict[str, Dict[Any, float]] = {}
  102. if mode in ("both", "unweighted_only"):
  103. res["unweighted"] = {p: 1.0 for p in pairs}
  104. if mode in ("both", "weighted_only"):
  105. # 重要度は例として一様乱数(必要なら差替え)
  106. res["weighted"] = {p: 0.5 + random.random() for p in pairs}
  107. return res
  108. def run_and_append_csv(cfg: ExperimentConfig) -> str:
  109. fid = cfg.fidelity_list or _default_fidelities(cfg.path_num)
  110. adp = Adapter(cfg.noise_model, cfg.path_num, fid, seed=cfg.seed)
  111. pairs = adp.list_pairs()
  112. importance_sets = _importance_for_pairs(pairs, cfg.importance_mode)
  113. f, w = open_csv(cfg.out_csv)
  114. try:
  115. run_id = 0
  116. for _ in range(cfg.repeats):
  117. run_id += 1
  118. for budget in cfg.budgets:
  119. for sched in cfg.schedulers:
  120. for imp_tag, I in importance_sets.items():
  121. # スケジューラ実行
  122. result = adp.run_scheduler(sched, budget, I)
  123. used_cost_total = int(result.get("used_cost_total", budget))
  124. per_pair_details: List[Dict[str, Any]] = result.get("per_pair_details", [])
  125. # 真の最良パス(正答率判定用)
  126. true_best_by_pair = {}
  127. for pair in pairs:
  128. paths = adp.list_paths_of(pair)
  129. best = None
  130. bestv = -1.0
  131. for pid in paths:
  132. tf = adp.true_fidelity(pid)
  133. if tf > bestv:
  134. bestv, best = tf, pid
  135. true_best_by_pair[pair] = best
  136. # CSV行を形成
  137. for det in per_pair_details:
  138. pair_id = det["pair_id"]
  139. alloc = det.get("alloc_by_path", {}) or {}
  140. est = det.get("est_fid_by_path", {}) or {}
  141. pred = det.get("best_pred_path")
  142. for path_id, samples in alloc.items():
  143. m = float(est.get(path_id, 0.5))
  144. r = hoeffding_radius(int(samples), cfg.delta_ci)
  145. lb = clamp01(m - r)
  146. ub = clamp01(m + r)
  147. width = ub - lb
  148. is_true_best = (true_best_by_pair.get(pair_id) == path_id)
  149. is_best_pred = (pred == path_id)
  150. w.writerow([
  151. f"{run_id}-{imp_tag}",
  152. cfg.noise_model,
  153. sched,
  154. budget,
  155. used_cost_total,
  156. pair_id,
  157. path_id,
  158. I.get(pair_id, 1.0),
  159. int(samples),
  160. m, lb, ub, width,
  161. int(is_true_best), int(is_best_pred),
  162. ])
  163. finally:
  164. f.close()
  165. return cfg.out_csv
  166. # ===== nb_protocol 互換 Shim =====
  167. class _NBPShim:
  168. """
  169. スケジューラが期待する nb_protocol 風のAPIを提供:
  170. - sample_path(net, path_id, n): QuantumNetwork.benchmark_path を呼ぶ
  171. - true_fidelity(net, path_id): 量子チャネルの ground truth を返す
  172. """
  173. def __init__(self, net):
  174. self.net = net
  175. def sample_path(self, net, path_id: int, n: int) -> float:
  176. # 1-bounce を n 回の測定にマップ(nb_protocol.NBProtocolAlice の設計に整合)
  177. p, _cost = self.net.benchmark_path(path_id, bounces=[1], sample_times={1: int(n)})
  178. return float(p)
  179. def true_fidelity(self, net, path_id: int) -> float:
  180. return float(self.net.quantum_channels[path_id - 1].fidelity)
  181. # ===== デフォルト忠実度の簡易生成(必要なら差替え) =====
  182. def _default_fidelities(path_num: int) -> List[float]:
  183. alpha, beta, var = 0.93, 0.85, 0.02
  184. res = [max(0.8, min(0.999, random.gauss(beta, var))) for _ in range(path_num)]
  185. best_idx = random.randrange(path_num)
  186. res[best_idx] = max(0.85, min(0.999, random.gauss(alpha, var)))
  187. return res