simulation.py~ 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
# simulation.py
# Produces ONE CSV of raw data, including "scheduler" and "used" (actual cost spent) columns.
import csv
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from fidelity import generate_fidelity_list_random, generate_importance_list_random
  9. @dataclass
  10. class SimConfig:
  11. n_pairs: int = 3
  12. links_per_pair: int = 5
  13. budgets: List[int] = None
  14. trials: int = 10
  15. seed: int = 42
  16. init_samples_per_link: int = 4
  17. delta: float = 0.1
  18. cost_per_sample: int = 1
  19. schedulers: List[str] = None # names for series
  20. def __post_init__(self):
  21. if self.budgets is None:
  22. self.budgets = [500, 1000, 2000, 5000]
  23. if self.schedulers is None:
  24. self.schedulers = ["GreedySimple"]
  25. def _radius(n: int, delta: float) -> float:
  26. if n <= 0:
  27. return float("inf")
  28. return math.sqrt(0.5 * math.log(2.0 / max(1e-12, delta)) / n)
  29. def _argmax(d: Dict[int, float]) -> int:
  30. best_k = None
  31. best_v = -1e9
  32. for k, v in d.items():
  33. if v > best_v or (v == best_v and (best_k is None or k < best_k)):
  34. best_k, best_v = k, v
  35. return best_k if best_k is not None else -1
  36. def _run_scheduler_greedy_simple(n_pairs, links_per_pair, budgets_sorted, delta, init_samples_per_link, cost_per_sample, true_fids_per_pair, importances, writer, trial, scheduler_name):
  37. # Per-budget deterministic re-run for snapshots
  38. rng_state0 = random.getstate()
  39. for b in budgets_sorted:
  40. random.setstate(rng_state0)
  41. est: Dict[Tuple[int,int], float] = {}
  42. ns: Dict[Tuple[int,int], int] = {}
  43. for p in range(n_pairs):
  44. for l in range(links_per_pair):
  45. est[(p,l)] = 0.0
  46. ns[(p,l)] = 0
  47. used = 0
  48. # phase-1: uniform
  49. stop = False
  50. for p in range(n_pairs):
  51. for l in range(links_per_pair):
  52. for _ in range(init_samples_per_link):
  53. x = 1 if random.random() < true_fids_per_pair[p][l] else 0
  54. ns[(p,l)] += 1
  55. est[(p,l)] += (x - est[(p,l)]) / ns[(p,l)]
  56. used += cost_per_sample
  57. if used >= b:
  58. stop = True
  59. break
  60. if stop: break
  61. if stop: break
  62. # phase-2: greedy
  63. while used < b:
  64. for p in range(n_pairs):
  65. cur = {l: est[(p,l)] for l in range(links_per_pair)}
  66. best_l = max(cur.keys(), key=lambda kk: cur[kk])
  67. x = 1 if random.random() < true_fids_per_pair[p][best_l] else 0
  68. ns[(p,best_l)] += 1
  69. est[(p,best_l)] += (x - est[(p,best_l)]) / ns[(p,best_l)]
  70. used += cost_per_sample
  71. if used >= b:
  72. break
  73. # emit rows (same 'used' for all rows in this (trial,budget,scheduler))
  74. for p in range(n_pairs):
  75. tb = max(range(links_per_pair), key=lambda l: true_fids_per_pair[p][l])
  76. cur_map = {l: est[(p,l)] for l in range(links_per_pair)}
  77. ca = max(range(links_per_pair), key=lambda l: cur_map[l])
  78. for l in range(links_per_pair):
  79. k = (p,l)
  80. r = _radius(ns[k], delta)
  81. lb = max(0.0, est[k] - r if ns[k] > 0 else 0.0)
  82. ub = min(1.0, est[k] + r if ns[k] > 0 else 1.0)
  83. writer.writerow([
  84. scheduler_name, trial, b, used, p, l, importances[p], true_fids_per_pair[p][l],
  85. est[k], ns[k], lb, ub,
  86. 1 if l == tb else 0,
  87. 1 if l == ca else 0,
  88. ])
  89. random.setstate(rng_state0)
  90. def run_simulation(csv_path: str, cfg: SimConfig) -> None:
  91. random.seed(cfg.seed)
  92. budgets_sorted = sorted(set(cfg.budgets))
  93. with open(csv_path, "w", newline="") as f:
  94. w = csv.writer(f)
  95. w.writerow([
  96. "scheduler","trial","budget_target","used","pair_id","link_id","importance","true_fid",
  97. "est_mean","n_samples","lb","ub","is_true_best","is_pair_current_argmax"
  98. ])
  99. for trial in range(cfg.trials):
  100. # Ground truth per pair
  101. true_fids_per_pair: Dict[int, Dict[int, float]] = {}
  102. importances: Dict[int, float] = {}
  103. for p in range(cfg.n_pairs):
  104. true_list = generate_fidelity_list_random(cfg.links_per_pair)
  105. true_fids_per_pair[p] = {i: true_list[i] for i in range(cfg.links_per_pair)}
  106. importances[p] = generate_importance_list_random(1)[0]
  107. for sched_name in cfg.schedulers:
  108. _run_scheduler_greedy_simple(
  109. cfg.n_pairs, cfg.links_per_pair, budgets_sorted, cfg.delta,
  110. cfg.init_samples_per_link, cfg.cost_per_sample,
  111. true_fids_per_pair, importances, w, trial, sched_name
  112. )