greedy_scheduler.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # schedulers/greedy_scheduler.py
  2. from .lonline_nb import lonline_init, lonline_continue
  3. def greedy_budget_scheduler(
  4. node_path_list, # 例: [2, 2, 2] … 各ペアのパス本数
  5. importance_list, # 例: [0.3, 0.5, 0.7] … 長さは node_path_list と同じ
  6. bounces, # 例: [1,2,3,4](重複なし)
  7. C_total, # 総予算
  8. network_generator, # callable: (path_num, pair_idx) -> network
  9. min_sets_per_link=4, # 互換用(lonline 側で min=4 を保証)
  10. return_details=False,
  11. # 既定値は現状コードと同じ:C_const=0.01, delta=0.1
  12. C_const=0.01,
  13. delta=0.1,
  14. ):
  15. """
  16. Two-Phase Greedy スケジューラ(1-origin対応)
  17. - 入出力キーは常に 1..L
  18. """
  19. # 前処理
  20. N_pairs = len(node_path_list)
  21. networks = [None] * N_pairs
  22. states = [None] * N_pairs
  23. per_pair_results = [(False, 0, None)] * N_pairs
  24. per_pair_details = [dict(alloc_by_path={}, est_fid_by_path={}) for _ in range(N_pairs)]
  25. init_costs = [0] * N_pairs
  26. f_init = [0.0] * N_pairs
  27. consumed_total = 0
  28. # -----------------------
  29. # フェーズ1: 広域探索
  30. # -----------------------
  31. for pair_idx, path_num in enumerate(node_path_list):
  32. if consumed_total >= C_total or path_num <= 0:
  33. per_pair_results[pair_idx] = (False, 0, None)
  34. continue
  35. print(f"[INIT] pair={pair_idx} remain={int(C_total)-int(consumed_total)} "
  36. f"paths={path_num} bounces={bounces}")
  37. remaining = int(C_total) - int(consumed_total)
  38. if remaining <= 0:
  39. break
  40. # ★ 1-origin の path_list
  41. path_list = list(range(1, int(path_num) + 1))
  42. network = network_generator(int(path_num), pair_idx)
  43. networks[pair_idx] = network
  44. if return_details:
  45. correctness, cost, best_path_fid, alloc0, est0, state = lonline_init(
  46. network, path_list, list(bounces), int(remaining),
  47. return_details=True, C_const=C_const, delta=delta, min_sets=4
  48. )
  49. for l, b in alloc0.items():
  50. per_pair_details[pair_idx]["alloc_by_path"][int(l)] = \
  51. per_pair_details[pair_idx]["alloc_by_path"].get(int(l), 0) + int(b)
  52. per_pair_details[pair_idx]["est_fid_by_path"].update(
  53. {int(k): float(v) for k, v in est0.items()}
  54. )
  55. print(f"[INIT<-] pair={pair_idx} cost={int(cost)} best_path_fid={best_path_fid} "
  56. f"s={state.get('s') if state else None} "
  57. f"k={len(state.get('candidate_set', [])) if state else None}")
  58. _sum_after_init = sum(per_pair_details[pair_idx]["alloc_by_path"].values())
  59. print(f"[CHECK:init] pair={pair_idx} sum_alloc_by_path={_sum_after_init} "
  60. f"(should equal init cost={int(cost)})")
  61. else:
  62. correctness, cost, best_path_fid, state = lonline_init(
  63. network, path_list, list(bounces), int(remaining),
  64. return_details=False, C_const=C_const, delta=delta, min_sets=4
  65. )
  66. print(f"[INIT<-] pair={pair_idx} cost={int(cost)} best_path_fid={best_path_fid} "
  67. f"s={state.get('s') if state else None} "
  68. f"k={len(state.get('candidate_set', [])) if state else None}")
  69. init_costs[pair_idx] = int(cost)
  70. consumed_total += int(cost)
  71. states[pair_idx] = state
  72. f_init[pair_idx] = float(best_path_fid) if best_path_fid is not None else 0.0
  73. per_pair_results[pair_idx] = (bool(correctness), int(cost), best_path_fid)
  74. if consumed_total >= C_total:
  75. break
  76. print(f"[CHECK:after-init] sum_init={sum(init_costs)} consumed_total_after_init={consumed_total}")
  77. # V_n = I_n * f_n^(init)
  78. def _score(idx):
  79. imp = importance_list[idx] if importance_list is not None else 1.0
  80. return float(imp) * float(f_init[idx])
  81. order = sorted(
  82. [i for i in range(N_pairs) if (states[i] is not None and node_path_list[i] > 0)],
  83. key=_score,
  84. reverse=True
  85. )
  86. debug_scores = [(i, _score(i)) for i in range(N_pairs)]
  87. print(f"[ORDER] by importance*init_fid desc: {sorted(debug_scores, key=lambda x: x[1], reverse=True)}")
  88. # -----------------------
  89. # フェーズ2: 集中的活用
  90. # -----------------------
  91. for pair_idx in order:
  92. print(f"[GREEDY] pre-loop pair={pair_idx} consumed_total={consumed_total}")
  93. if consumed_total >= C_total:
  94. break
  95. if states[pair_idx] is None:
  96. continue
  97. network = networks[pair_idx]
  98. state = states[pair_idx]
  99. while consumed_total < C_total:
  100. remaining = int(C_total) - int(consumed_total)
  101. if remaining <= 0:
  102. break
  103. print(f"[GREEDY] pair={pair_idx} remain={remaining} "
  104. f"s={state.get('s') if state else None} "
  105. f"k={len(state.get('candidate_set', [])) if state else None}")
  106. if return_details:
  107. correctness, cost, best_path_fid, alloc_more, est_more, new_state, insufficient = lonline_continue(
  108. network, int(remaining), state=state, return_details=True
  109. )
  110. print(f"[GREEDY<-] pair={pair_idx} cost={int(cost)} insufficient={bool(insufficient)} "
  111. f"best_path_fid={best_path_fid} "
  112. f"s'={new_state.get('s') if new_state else None} "
  113. f"k'={len(new_state.get('candidate_set', [])) if new_state else None} "
  114. f"consumed_total→{consumed_total + int(cost)}")
  115. for l, b in alloc_more.items():
  116. per_pair_details[pair_idx]["alloc_by_path"][int(l)] = \
  117. per_pair_details[pair_idx]["alloc_by_path"].get(int(l), 0) + int(b)
  118. per_pair_details[pair_idx]["est_fid_by_path"].update(
  119. {int(k): float(v) for k, v in est_more.items()}
  120. )
  121. _sum_after_round = sum(per_pair_details[pair_idx]["alloc_by_path"].values())
  122. print(f"[CHECK:round] pair={pair_idx} add={int(cost)} sum_alloc_by_path={_sum_after_round}")
  123. else:
  124. correctness, cost, best_path_fid, new_state, insufficient = lonline_continue(
  125. network, int(remaining), state=state, return_details=False
  126. )
  127. print(f"[GREEDY<-] pair={pair_idx} cost={int(cost)} insufficient={bool(insufficient)} "
  128. f"best_path_fid={best_path_fid} "
  129. f"s'={new_state.get('s') if new_state else None} "
  130. f"k'={len(new_state.get('candidate_set', [])) if new_state else None} "
  131. f"consumed_total→{consumed_total + int(cost)}")
  132. consumed_total += int(cost)
  133. print(f"[GREEDY] post-accum pair={pair_idx} consumed_total={consumed_total}")
  134. state = new_state
  135. states[pair_idx] = new_state
  136. prev_correctness, prev_cost, prev_best = per_pair_results[pair_idx]
  137. per_pair_results[pair_idx] = (
  138. bool(prev_correctness and correctness),
  139. int(prev_cost) + int(cost),
  140. best_path_fid,
  141. )
  142. if bool(insufficient):
  143. print(f"[GREEDY] break(insufficient) pair={pair_idx}")
  144. break
  145. cand = list(new_state.get("candidate_set", []))
  146. if len(cand) <= 1:
  147. print(f"[GREEDY] converged pair={pair_idx} "
  148. f"s={new_state.get('s')} k={len(cand)} consumed_total={consumed_total}")
  149. break
  150. if consumed_total >= C_total:
  151. break
  152. print(f"[CHECK:return] consumed_total={consumed_total}")
  153. return (per_pair_results, int(consumed_total), per_pair_details) if return_details \
  154. else (per_pair_results, int(consumed_total))