greedy_scheduler.py~ 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
# schedulers/greedy_scheduler.py
  2. from .lonline_nb import lonline_network_benchmarking
  3. def greedy_budget_scheduler(
  4. node_path_list, # 例: [2, 2, 2] … 各ペアのパス本数
  5. importance_list, # 例: [0.3, 0.5, 0.7] … 長さは node_path_list と同じ
  6. bounces, # 例: [1,2,3,4](重複なし)
  7. C_total, # 総予算
  8. network_generator, # callable: (path_num, pair_idx) -> network
  9. C_initial_per_pair=40, # 各ペアの初期プローブ予算
  10. ):
  11. """
  12. Greedy スケジューラ(最小構成)
  13. 手順:
  14. 1) 各ペアに小さな初期予算 C_initial_per_pair を配って lonline を一度だけ実行し、
  15. おおまかなベスト経路忠実度を得る。
  16. 2) importance * estimated_fidelity のスコアでペアを降順ソート。
  17. 3) 残余予算をスコア上位のペアにまとめて与え、lonline をもう一度実行。
  18. 予算が尽きるまで繰り返す(この実装では「まとめて全部」与える)。
  19. 返り値:
  20. per_pair_results: List[ (correct: bool, cost: int, best_path_fidelity: float|None) ] (ペア順)
  21. total_cost: int
  22. """
  23. num_pairs = len(node_path_list)
  24. assert num_pairs == len(importance_list), "length mismatch: node_path_list vs importance_list"
  25. if num_pairs == 0:
  26. return [], 0
  27. # bounces は重複なし・正の整数を仮定
  28. assert len(bounces) == len(set(bounces)), "bounces must be unique"
  29. assert all(isinstance(w, int) and w > 0 for w in bounces), "bounces must be positive ints"
  30. # --- Step 1: 各ペアを軽くプローブ(初期推定) ---
  31. initial_est_fids = [0.0] * num_pairs
  32. initial_costs = [0] * num_pairs
  33. per_pair_results = [(False, 0, None)] * num_pairs # プレースホルダ
  34. consumed_total = 0
  35. for pair_idx, path_num in enumerate(node_path_list):
  36. if consumed_total >= C_total or path_num <= 0:
  37. continue
  38. C_probe = min(int(C_initial_per_pair), max(int(C_total) - int(consumed_total), 0))
  39. if C_probe <= 0:
  40. break
  41. network = network_generator(path_num, pair_idx)
  42. path_list = list(range(1, path_num + 1))
  43. correctness, cost, best_path_fid = lonline_network_benchmarking(
  44. network, path_list, list(bounces), int(C_probe)
  45. )
  46. consumed_total += int(cost)
  47. initial_costs[pair_idx] = int(cost)
  48. initial_est_fids[pair_idx] = float(best_path_fid) if best_path_fid is not None else 0.0
  49. per_pair_results[pair_idx] = (bool(correctness), int(cost), best_path_fid)
  50. remaining = max(int(C_total) - int(consumed_total), 0)
  51. # --- Step 2: importance * estimated_fidelity で優先度付け ---
  52. scores = [(idx, importance_list[idx] * initial_est_fids[idx]) for idx in range(num_pairs)]
  53. scores.sort(key=lambda x: x[1], reverse=True)
  54. # --- Step 3: 残余予算を Greedy に配分(上位にまとめて) ---
  55. for pair_idx, _score in scores:
  56. if remaining <= 0:
  57. break
  58. path_num = node_path_list[pair_idx]
  59. if path_num <= 0:
  60. continue
  61. network = network_generator(path_num, pair_idx)
  62. path_list = list(range(1, path_num + 1))
  63. # 残りを丸ごと与える(シンプル設計)
  64. correctness, cost, best_path_fid = lonline_network_benchmarking(
  65. network, path_list, list(bounces), int(remaining)
  66. )
  67. per_pair_results[pair_idx] = (
  68. bool(correctness),
  69. int(initial_costs[pair_idx] + int(cost)),
  70. best_path_fid,
  71. )
  72. remaining -= int(cost)
  73. consumed_total += int(cost)
  74. return per_pair_results, int(consumed_total)