# lonline_nb.py
  2. import math
  3. def _ns_for_round(s, k, C_const, delta, min_sets):
  4. """LinkSelFiE 由来の Ns(s,k) を計算(k=len(candidate_set))。"""
  5. Ns = math.ceil(C_const * (2 ** (2 * s)) * math.log2(max((2 ** s) * k / delta, 2)))
  6. return max(Ns, min_sets)
  7. def lonline_init(
  8. network, path_list, bounces, C_budget,
  9. *, return_details=False, C_const=0.01, delta=0.1, min_sets=4
  10. ):
  11. """
  12. 広域探索フェーズ(s=1 の 1 ラウンドのみ)。候補全リンクに一律 Ns セットを投入できる場合に限り実行。
  13. 出力:
  14. correctness, cost, best_path_fidelity, [alloc_by_path, est_fid_by_path,] state
  15. """
  16. # ★ 受け取った path_list を“公開ID(=1..L)”として、そのまま使う
  17. candidate_set = list(path_list)
  18. alloc_by_path = {int(p): 0 for p in path_list}
  19. est_fid_by_path = {}
  20. estimated_fidelities = {}
  21. cost = 0
  22. if not candidate_set or C_budget <= 0:
  23. base = (False, 0, None)
  24. if return_details:
  25. base += (alloc_by_path, est_fid_by_path)
  26. state = {
  27. "s": 1, "candidate_set": candidate_set, "estimated_fidelities": estimated_fidelities,
  28. "alloc_by_path": alloc_by_path, "est_fid_by_path": est_fid_by_path, "bounces": list(bounces),
  29. "C_const": C_const, "delta": delta, "min_sets": min_sets
  30. }
  31. return (*base, state)
  32. c_B = sum(bounces) if sum(bounces) > 0 else 1
  33. s = 1
  34. Ns = _ns_for_round(s, len(candidate_set), C_const, delta, min_sets)
  35. # 候補全リンクに一律 Ns セットを入れられるかを事前判定(途中打ち切りなしを保証)
  36. round_cost_all = len(candidate_set) * Ns * c_B
  37. if round_cost_all > C_budget:
  38. base = (False, 0, None)
  39. if return_details:
  40. base += (alloc_by_path, est_fid_by_path)
  41. state = {
  42. "s": s, "candidate_set": candidate_set, "estimated_fidelities": estimated_fidelities,
  43. "alloc_by_path": alloc_by_path, "est_fid_by_path": est_fid_by_path, "bounces": list(bounces),
  44. "C_const": C_const, "delta": delta, "min_sets": min_sets
  45. }
  46. return (*base, state)
  47. # ここから候補全リンクに一律 Ns セット投入(途中打ち切りなし)
  48. sample_times = {h: int(Ns) for h in bounces}
  49. p_s, measured = {}, []
  50. for path in list(candidate_set):
  51. p, used = network.benchmark_path(path, bounces, sample_times)
  52. cost += int(used)
  53. fidelity = p + (1 - p) / 2.0
  54. estimated_fidelities[path] = fidelity
  55. p_s[path] = p
  56. measured.append(path)
  57. alloc_by_path[int(path)] = alloc_by_path.get(int(path), 0) + int(used)
  58. est_fid_by_path[int(path)] = float(fidelity)
  59. # 逐次除去(現時点では 2^{-s} 閾値ルール)
  60. if p_s:
  61. p_max = max(p_s.values())
  62. new_cand = [path for path in measured if (p_s[path] + 2 ** (-s) > p_max - 2 ** (-s))]
  63. candidate_set = new_cand or candidate_set
  64. best_path_fid = None
  65. if estimated_fidelities:
  66. best_path = max(estimated_fidelities, key=estimated_fidelities.get)
  67. best_path_fid = estimated_fidelities[best_path]
  68. correctness = (best_path == getattr(network, "best_path", None))
  69. else:
  70. correctness = False
  71. state = {
  72. "s": s, "candidate_set": candidate_set, "estimated_fidelities": estimated_fidelities,
  73. "alloc_by_path": alloc_by_path, "est_fid_by_path": est_fid_by_path, "bounces": list(bounces),
  74. "C_const": C_const, "delta": delta, "min_sets": min_sets
  75. }
  76. base = (bool(correctness), int(cost), best_path_fid)
  77. if return_details:
  78. base += (alloc_by_path, est_fid_by_path)
  79. return (*base, state)
  80. def lonline_continue(
  81. network, C_budget, *, state, return_details=False,
  82. C_const=None, delta=None, min_sets=None
  83. ):
  84. """
  85. 集中的活用フェーズ。state を引き継ぎ s=2 以降を「フェーズ単位」で実施。
  86. 1フェーズ分の均等投入コストが入らない場合は何も実行せず insufficient_budget=True を返す。
  87. 出力:
  88. correctness, cost, best_path_fidelity, [alloc_by_path, est_fid_by_path,] new_state, insufficient_budget
  89. """
  90. # 引き継ぎ
  91. s = int(state.get("s", 1))
  92. candidate_set = list(state.get("candidate_set", []))
  93. estimated_fidelities = dict(state.get("estimated_fidelities", {}))
  94. alloc_by_path = dict(state.get("alloc_by_path", {}))
  95. est_fid_by_path = dict(state.get("est_fid_by_path", {}))
  96. bounces = list(state.get("bounces", []))
  97. C_const = state.get("C_const", 0.01) if C_const is None else C_const
  98. delta = state.get("delta", 0.1) if delta is None else delta
  99. min_sets = state.get("min_sets", 4) if min_sets is None else min_sets
  100. cost = 0
  101. if not candidate_set or C_budget <= 0 or len(candidate_set) <= 1:
  102. best_path_fid = None
  103. if estimated_fidelities:
  104. best_path = max(estimated_fidelities, key=estimated_fidelities.get)
  105. best_path_fid = estimated_fidelities[best_path]
  106. correctness = (best_path == getattr(network, "best_path", None))
  107. else:
  108. correctness = False
  109. base = (bool(correctness), int(cost), best_path_fid)
  110. if return_details:
  111. base += (alloc_by_path, est_fid_by_path)
  112. return (*base, {**state, "s": s}, False)
  113. c_B = sum(bounces) if sum(bounces) > 0 else 1
  114. insufficient_budget = False
  115. while cost < C_budget and len(candidate_set) > 1:
  116. s += 1
  117. Ns = _ns_for_round(s, len(candidate_set), C_const, delta, min_sets)
  118. # 候補全リンクに一律 Ns セット(途中打ち切りなし)
  119. round_cost_all = len(candidate_set) * Ns * c_B
  120. if cost + round_cost_all > C_budget:
  121. insufficient_budget = True
  122. s -= 1
  123. break
  124. sample_times = {h: int(Ns) for h in bounces}
  125. p_s, measured = {}, []
  126. for path in list(candidate_set):
  127. p, used = network.benchmark_path(path, bounces, sample_times)
  128. cost += int(used)
  129. fidelity = p + (1 - p) / 2.0
  130. estimated_fidelities[path] = fidelity
  131. p_s[path] = p
  132. measured.append(path)
  133. alloc_by_path[int(path)] = alloc_by_path.get(int(path), 0) + int(used)
  134. est_fid_by_path[int(path)] = float(fidelity)
  135. if not p_s:
  136. break
  137. p_max = max(p_s.values())
  138. new_cand = [path for path in measured if (p_s[path] + 2 ** (-s) > p_max - 2 ** (-s))]
  139. candidate_set = new_cand or candidate_set
  140. best_path_fid = None
  141. if estimated_fidelities:
  142. best_path = max(estimated_fidelities, key=estimated_fidelities.get)
  143. best_path_fid = estimated_fidelities[best_path]
  144. correctness = (best_path == getattr(network, "best_path", None))
  145. else:
  146. correctness = False
  147. new_state = {
  148. "s": s, "candidate_set": candidate_set, "estimated_fidelities": estimated_fidelities,
  149. "alloc_by_path": alloc_by_path, "est_fid_by_path": est_fid_by_path, "bounces": bounces,
  150. "C_const": C_const, "delta": delta, "min_sets": min_sets
  151. }
  152. base = (bool(correctness), int(cost), best_path_fid)
  153. if return_details:
  154. base += (alloc_by_path, est_fid_by_path)
  155. return (*base, new_state, insufficient_budget)
  156. def _dry_phase_cost(state, C_budget, C_const=None, delta=None, min_sets=None):
  157. s = int(state.get("s", 1))
  158. k = len(state.get("candidate_set", []))
  159. bounces = state.get("bounces", [])
  160. c_B = sum(bounces) if bounces else 1
  161. C_const = state.get("C_const", 0.01) if C_const is None else C_const
  162. delta = state.get("delta", 0.1) if delta is None else delta
  163. min_sets = state.get("min_sets", 4) if min_sets is None else min_sets
  164. def Ns(s, k):
  165. val = math.ceil(C_const * (2**(2*s)) * math.log2(max((2**s)*k/delta, 2)))
  166. return max(val, min_sets)
  167. need_s2 = k * Ns(2, k) * c_B
  168. need_s3 = k * Ns(3, k) * c_B
  169. return dict(s=s, k=k, c_B=c_B, need_s2=need_s2, need_s3=need_s3, C_budget=int(C_budget))