| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- import math
- def lonline_network_benchmarking(network, path_list, bounces, C_budget):
- candidate_set = list(path_list)
- s = 0
- C = 0.01
- delta = 0.1
- cost = 0
- estimated_fidelities = {}
- # 1経路を1サンプル測る想定コスト(list 前提)
- cost_per_sample_unit = sum(bounces) if sum(bounces) > 0 else 1
- while cost < C_budget and len(candidate_set) > 1:
- s += 1
- Ns = math.ceil(C * (2 ** (2 * s)) * math.log2(max((2 ** s) * len(candidate_set) / delta, 2)))
- if Ns < 4:
- Ns = 4
- # --- 事前コスト見積り(lonline準拠) ---
- cost_needed_for_one_path = Ns * cost_per_sample_unit
- # 2ラウンド目以降で1経路ぶんすら入らないなら終了
- if cost + cost_needed_for_one_path > C_budget and s > 1:
- break
- # 1ラウンド目だけは Ns を縮退してでも1回は回す(入らなければ中止)
- """
- if cost + cost_needed_for_one_path > C_budget and s == 1:
- Ns_fit = (C_budget - cost) // max(cost_per_sample_unit, 1)
- if Ns_fit <= 0:
- break
- Ns = int(Ns_fit)
- cost_needed_for_one_path = Ns * cost_per_sample_unit
- """
- # ---------------------------------------
- sample_times = {i: Ns for i in bounces}
- p_s = {}
- measured_paths = [] # このラウンドで実際に測れた経路だけを削除判定に使う
- for path in list(candidate_set):
- if cost + cost_needed_for_one_path > C_budget:
- continue # 予算に入らない経路はこのラウンドはスキップ
- p, bounces_num = network.benchmark_path(path, bounces, sample_times)
- cost += bounces_num
- estimated_fidelities[path] = p + (1 - p) / 2 # ★ 既存式&変数名を踏襲
- p_s[path] = p
- measured_paths.append(path)
- if not p_s:
- break # このラウンドで1つも測れなかった
- # online_nb.py と同じ 2^{-s} 幅の連続削除
- p_max = max(p_s.values())
- new_candidate_set = []
- for path in measured_paths: # 測れたものだけで判定(KeyError防止)
- if p_s[path] + 2**(-s) > p_max - 2**(-s):
- new_candidate_set.append(path)
- # もし全消しになったら、保険として現集合を維持
- candidate_set = new_candidate_set or candidate_set
- if not estimated_fidelities:
- return None, cost, None
- best_path = max(estimated_fidelities, key=estimated_fidelities.get)
- best_path_fidelity = estimated_fidelities[best_path]
- correctness = (best_path == getattr(network, "best_path", None))
- return correctness, cost, best_path_fidelity
|