lonline_nb.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # lonline_nb.py
  2. import math
  3. def lonline_network_benchmarking(network, path_list, bounces, C_budget, return_details=False):
  4. """
  5. L-Online 風の逐次削除型 NB。
  6. 返り値(常に一貫):
  7. return_details=False:
  8. (correctness: bool, cost: int, best_path_fidelity: float|None)
  9. return_details=True:
  10. (correctness: bool, cost: int, best_path_fidelity: float|None,
  11. alloc_by_path: dict[int,int], est_fid_by_path: dict[int,float])
  12. 想定 I/F:
  13. network.benchmark_path(path, bounces, sample_times) -> (p, used_cost)
  14. 忠実度変換: fidelity = p + (1 - p)/2
  15. """
  16. candidate_set = list(path_list)
  17. # 既存コード由来のパラメータ(必要に応じて合わせてください)
  18. s = 0
  19. C = 0.01
  20. delta = 0.1
  21. # 集計器
  22. cost = 0
  23. estimated_fidelities = {}
  24. # 詳細返却用の器(return_details に関わらず初期化:どの分岐でも形を揃える)
  25. alloc_by_path = {int(p): 0 for p in path_list}
  26. est_fid_by_path = {}
  27. if not candidate_set or C_budget <= 0:
  28. # 何も測れないケースでも形は揃えて返す
  29. if return_details:
  30. return False, int(cost), None, alloc_by_path, est_fid_by_path
  31. return False, int(cost), None
  32. # 1 経路を 1 サンプル測るコストの近似(ここでは hop 重みの和)
  33. cost_per_sample_unit = sum(bounces) if sum(bounces) > 0 else 1
  34. # ---- メインループ ----
  35. while cost < C_budget and len(candidate_set) > 1:
  36. s += 1
  37. # ラウンド s のサンプル回数(既存式)
  38. Ns = math.ceil(C * (2 ** (2 * s)) * math.log2(max((2 ** s) * len(candidate_set) / delta, 2)))
  39. if Ns < 4:
  40. Ns = 4
  41. # このラウンドで 1 経路に必要なコスト目安
  42. cost_needed_for_one_path = Ns * cost_per_sample_unit
  43. # 2 ラウンド目以降で 1 経路すら回せないなら終了
  44. if cost + cost_needed_for_one_path > C_budget and s > 1:
  45. break
  46. # hop ごとに同じ Ns を配る(network 側の想定 I/F に合わせる)
  47. sample_times = {h: int(Ns) for h in bounces}
  48. # ラウンド内の観測
  49. p_s = {}
  50. measured_paths = []
  51. for path in list(candidate_set):
  52. if cost + cost_needed_for_one_path > C_budget:
  53. continue # 予算が入らない経路はこのラウンドでは測らない
  54. # 実測
  55. p, used = network.benchmark_path(path, bounces, sample_times)
  56. cost += int(used)
  57. # 忠実度推定を更新(既存式)
  58. fidelity = p + (1 - p) / 2.0
  59. estimated_fidelities[path] = fidelity
  60. p_s[path] = p
  61. measured_paths.append(path)
  62. # 詳細集計
  63. alloc_by_path[int(path)] = alloc_by_path.get(int(path), 0) + int(used)
  64. est_fid_by_path[int(path)] = float(fidelity)
  65. # このラウンドで 1 本も測れなかったら終了
  66. if not p_s:
  67. break
  68. # 連続削除(幅 2^{-s})
  69. p_max = max(p_s.values())
  70. new_candidate_set = []
  71. for path in measured_paths:
  72. if p_s[path] + 2 ** (-s) > p_max - 2 ** (-s):
  73. new_candidate_set.append(path)
  74. # 全消し回避:空になったら据え置き
  75. candidate_set = new_candidate_set or candidate_set
  76. # 1 本も推定できなかった場合
  77. if not estimated_fidelities:
  78. if return_details:
  79. return False, int(cost), None, alloc_by_path, est_fid_by_path
  80. return False, int(cost), None
  81. # 最良推定パスと正解判定
  82. best_path = max(estimated_fidelities, key=estimated_fidelities.get)
  83. best_path_fidelity = estimated_fidelities[best_path]
  84. correctness = (best_path == getattr(network, "best_path", None))
  85. if return_details:
  86. return bool(correctness), int(cost), best_path_fidelity, alloc_by_path, est_fid_by_path
  87. return bool(correctness), int(cost), best_path_fidelity
  88. # 互換用エイリアス(古い呼び名を使っているコード向け)
  89. lonline_network_benchmarking_with_budget = lonline_network_benchmarking