| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- # simulation.py
- # Produces ONE CSV of raw data, including "scheduler" and "used" (actual cost spent) columns.
- from dataclasses import dataclass
- from typing import List, Dict, Tuple
- import math
- import csv
- import random
- from fidelity import generate_fidelity_list_random, generate_importance_list_random
@dataclass
class SimConfig:
    """Settings for one simulation run.

    ``budgets`` and ``schedulers`` use ``None`` as a "not supplied" marker;
    ``__post_init__`` swaps it for a fresh default list, so instances never
    share a mutable default.
    """

    # Number of pairs simulated per trial.
    n_pairs: int = 3
    # Number of candidate links for every pair.
    links_per_pair: int = 5
    # Budget targets to snapshot; None -> [500, 1000, 2000, 5000].
    budgets: List[int] = None
    # Number of independent trials.
    trials: int = 10
    # Seed handed to random.seed() at the start of a run.
    seed: int = 42
    # Samples drawn from each link during the uniform warm-up phase.
    init_samples_per_link: int = 4
    # Passed to _radius() when computing the lb/ub columns.
    delta: float = 0.1
    # Amount added to the "used" counter for every sample drawn.
    cost_per_sample: int = 1
    # names for series; None -> ["GreedySimple"].
    schedulers: List[str] = None

    def __post_init__(self):
        """Replace the ``None`` markers with per-instance default lists."""
        self.budgets = (
            [500, 1000, 2000, 5000] if self.budgets is None else self.budgets
        )
        self.schedulers = (
            ["GreedySimple"] if self.schedulers is None else self.schedulers
        )
def _radius(n: int, delta: float) -> float:
    """Return the confidence radius ``sqrt(0.5 * ln(2 / delta) / n)``.

    Args:
        n: Number of samples behind the estimate.
        delta: Failure probability; values below 1e-12 are clamped to 1e-12
            so the division and logarithm stay finite.

    Returns:
        ``inf`` when ``n`` is zero or negative (no information yet),
        otherwise the radius computed from the formula above.
    """
    if n > 0:
        # Clamp tiny / zero / negative deltas before dividing by them.
        clamped = delta if delta > 1e-12 else 1e-12
        log_term = math.log(2.0 / clamped)
        return math.sqrt(0.5 * log_term / n)
    return float("inf")
def _argmax(d: Dict[int, float]) -> int:
    """Return the key with the largest value, preferring the smallest key on ties.

    Args:
        d: Mapping from integer key to score.

    Returns:
        The winning key, or -1 when ``d`` is empty (or holds only NaN values,
        which never compare greater than or equal to anything).
    """
    best_k = None
    # Start from -inf rather than a finite sentinel so that arbitrarily
    # negative scores (including -inf itself) can still win.
    best_v = float("-inf")
    for k, v in d.items():
        if v > best_v or (v == best_v and (best_k is None or k < best_k)):
            best_k, best_v = k, v
    return best_k if best_k is not None else -1
def _greedy_simple_draw(est, ns, key, true_fid):
    """Draw one Bernoulli(true_fid) outcome for ``key`` and update its running mean."""
    x = 1 if random.random() < true_fid else 0
    ns[key] += 1
    est[key] += (x - est[key]) / ns[key]


def _greedy_simple_emit_rows(writer, scheduler_name, trial, budget, used, pairs, links,
                             est, ns, delta, true_fids_per_pair, importances):
    """Write one CSV row per (pair, link) describing the state reached for ``budget``."""
    for p in pairs:
        fids = true_fids_per_pair[p]
        # max() keeps the first maximum, so ties resolve to the lowest link id.
        true_best = max(links, key=lambda l: fids[l])
        current_best = max(links, key=lambda l: est[(p, l)])
        for l in links:
            k = (p, l)
            if ns[k] > 0:
                r = _radius(ns[k], delta)
                lb = max(0.0, est[k] - r)
                ub = min(1.0, est[k] + r)
            else:
                # Unsampled link: the interval is the trivial [0, 1].
                lb, ub = 0.0, 1.0
            writer.writerow([
                scheduler_name, trial, budget, used, p, l, importances[p], fids[l],
                est[k], ns[k], lb, ub,
                1 if l == true_best else 0,
                1 if l == current_best else 0,
            ])


def _run_scheduler_greedy_simple(
    n_pairs: int,
    links_per_pair: int,
    budgets_sorted: List[int],
    delta: float,
    init_samples_per_link: int,
    cost_per_sample: int,
    true_fids_per_pair: Dict[int, Dict[int, float]],
    importances: Dict[int, float],
    writer,
    trial: int,
    scheduler_name: str,
) -> None:
    """Simulate the simple greedy scheduler and write one snapshot per budget.

    For every budget the run restarts from scratch with the global ``random``
    state rewound to what it was on entry, so all budgets replay the same
    random stream. Each run has two phases:

    1. uniform warm-up: up to ``init_samples_per_link`` samples per link,
       visited pair by pair, link by link;
    2. greedy: round-robin over pairs, each time sampling the link whose
       current estimate is highest (lowest link id on ties).

    Sampling stops as soon as the spent cost reaches the budget; a sample is
    only drawn while ``used < budget``, so a non-positive budget draws nothing.
    The global RNG state is restored to its entry value before returning.

    Args:
        n_pairs: Number of pairs; pair ids are ``0 .. n_pairs - 1``.
        links_per_pair: Number of links per pair; link ids are
            ``0 .. links_per_pair - 1``.
        budgets_sorted: Budget targets to snapshot (each is run independently).
        delta: Forwarded to ``_radius`` for the ``lb`` / ``ub`` columns.
        init_samples_per_link: Warm-up samples per link.
        cost_per_sample: Cost added to ``used`` per sample; must be positive.
        true_fids_per_pair: ``true_fids_per_pair[p][l]`` is the success
            probability used when sampling link ``l`` of pair ``p``.
        importances: ``importances[p]`` is copied into each row of pair ``p``.
        writer: Object with a ``writerow(list)`` method (e.g. ``csv.writer``).
        trial: Trial index copied into each row.
        scheduler_name: Label copied into each row.

    Raises:
        ValueError: If ``cost_per_sample`` is not positive (the budget could
            never be reached and the sampling loop would not terminate).
    """
    if cost_per_sample <= 0:
        raise ValueError("cost_per_sample must be positive")
    if n_pairs <= 0 or links_per_pair <= 0:
        # Nothing to sample and nothing to report.
        return

    pairs = range(n_pairs)
    links = range(links_per_pair)

    # Per-budget deterministic re-run for snapshots
    rng_state0 = random.getstate()
    try:
        for b in budgets_sorted:
            random.setstate(rng_state0)
            est: Dict[Tuple[int, int], float] = {(p, l): 0.0 for p in pairs for l in links}
            ns: Dict[Tuple[int, int], int] = {(p, l): 0 for p in pairs for l in links}
            used = 0

            # phase-1: uniform
            warmup = (
                (p, l)
                for p in pairs
                for l in links
                for _ in range(init_samples_per_link)
            )
            for p, l in warmup:
                if used >= b:
                    break
                _greedy_simple_draw(est, ns, (p, l), true_fids_per_pair[p][l])
                used += cost_per_sample

            # phase-2: greedy
            while used < b:
                for p in pairs:
                    if used >= b:
                        break
                    best_l = max(links, key=lambda l: est[(p, l)])
                    _greedy_simple_draw(est, ns, (p, best_l), true_fids_per_pair[p][best_l])
                    used += cost_per_sample

            # emit rows (same 'used' for all rows in this (trial,budget,scheduler))
            _greedy_simple_emit_rows(
                writer, scheduler_name, trial, b, used, pairs, links,
                est, ns, delta, true_fids_per_pair, importances,
            )
    finally:
        # Leave the global RNG exactly as the caller handed it to us.
        random.setstate(rng_state0)
def run_simulation(csv_path: str, cfg: "SimConfig") -> None:
    """Run every trial and scheduler in ``cfg`` and write the raw rows to a CSV.

    The global ``random`` module is seeded with ``cfg.seed`` first. The file
    at ``csv_path`` is overwritten and starts with a header row. For each
    trial, fresh ground truth is generated pair by pair (fidelities first,
    then that pair's importance), after which every name in
    ``cfg.schedulers`` is run through ``_run_scheduler_greedy_simple`` on that
    same ground truth.

    Args:
        csv_path: Destination file path.
        cfg: Simulation settings.
    """
    header = [
        "scheduler", "trial", "budget_target", "used", "pair_id", "link_id",
        "importance", "true_fid", "est_mean", "n_samples", "lb", "ub",
        "is_true_best", "is_pair_current_argmax",
    ]

    random.seed(cfg.seed)
    # De-duplicate and order the budget targets once, up front.
    budget_levels = sorted(set(cfg.budgets))

    with open(csv_path, "w", newline="") as out_file:
        out = csv.writer(out_file)
        out.writerow(header)

        for trial in range(cfg.trials):
            # Ground truth per pair
            true_fids_per_pair: Dict[int, Dict[int, float]] = {}
            importances: Dict[int, float] = {}
            for pair in range(cfg.n_pairs):
                fid_values = generate_fidelity_list_random(cfg.links_per_pair)
                true_fids_per_pair[pair] = {
                    link: fid_values[link] for link in range(cfg.links_per_pair)
                }
                importances[pair] = generate_importance_list_random(1)[0]

            # Every scheduler name is evaluated against the same ground truth.
            for sched_name in cfg.schedulers:
                _run_scheduler_greedy_simple(
                    cfg.n_pairs,
                    cfg.links_per_pair,
                    budget_levels,
                    cfg.delta,
                    cfg.init_samples_per_link,
                    cfg.cost_per_sample,
                    true_fids_per_pair,
                    importances,
                    out,
                    trial,
                    sched_name,
                )
|