greedy_scheduler.py 4.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # schedulers/greedy_scheduler.py
  2. from .lonline_nb import lonline_network_benchmarking
  3. def greedy_budget_scheduler(
  4. node_path_list, # 例: [2, 2, 2] … 各ペアのパス本数
  5. importance_list, # 例: [0.3, 0.5, 0.7] … 長さは node_path_list と同じ
  6. bounces, # 例: [1,2,3,4](重複なし)
  7. C_total, # 総予算
  8. network_generator, # callable: (path_num, pair_idx) -> network
  9. C_initial_per_pair=40, # 各ペアの初期プローブ予算
  10. return_details=False,
  11. ):
  12. num_pairs = len(node_path_list)
  13. assert num_pairs == len(importance_list), "length mismatch: node_path_list vs importance_list"
  14. if num_pairs == 0:
  15. return ([], 0, []) if return_details else ([], 0)
  16. assert len(bounces) == len(set(bounces)), "bounces must be unique"
  17. assert all(isinstance(w, int) and w > 0 for w in bounces), "bounces must be positive ints"
  18. # --- Step 1: 各ペアを軽くプローブ(初期推定) ---
  19. initial_est_fids = [0.0] * num_pairs
  20. initial_costs = [0] * num_pairs
  21. per_pair_results = [(False, 0, None)] * num_pairs
  22. per_pair_details = [{"alloc_by_path": {}, "est_fid_by_path": {}} for _ in range(num_pairs)]
  23. consumed_total = 0
  24. for pair_idx, path_num in enumerate(node_path_list):
  25. if consumed_total >= C_total or path_num <= 0:
  26. continue
  27. C_probe = min(int(C_initial_per_pair), max(int(C_total) - int(consumed_total), 0))
  28. if C_probe <= 0:
  29. break
  30. network = network_generator(path_num, pair_idx)
  31. path_list = list(range(1, path_num + 1))
  32. if return_details:
  33. correctness, cost, best_path_fid, alloc0, est0 = lonline_network_benchmarking(
  34. network, path_list, list(bounces), int(C_probe), return_details=True
  35. )
  36. # 詳細をマージ(配分は加算・推定は後勝ち)
  37. for l, b in alloc0.items():
  38. per_pair_details[pair_idx]["alloc_by_path"][int(l)] = \
  39. per_pair_details[pair_idx]["alloc_by_path"].get(int(l), 0) + int(b)
  40. per_pair_details[pair_idx]["est_fid_by_path"].update({int(k): float(v) for k, v in est0.items()})
  41. else:
  42. correctness, cost, best_path_fid = lonline_network_benchmarking(
  43. network, path_list, list(bounces), int(C_probe)
  44. )
  45. consumed_total += int(cost)
  46. initial_costs[pair_idx] = int(cost)
  47. initial_est_fids[pair_idx] = float(best_path_fid) if best_path_fid is not None else 0.0
  48. per_pair_results[pair_idx] = (bool(correctness), int(cost), best_path_fid)
  49. remaining = max(int(C_total) - int(consumed_total), 0)
  50. # --- Step 2: importance * estimated_fidelity で優先度付け ---
  51. scores = [(idx, importance_list[idx] * initial_est_fids[idx]) for idx in range(num_pairs)]
  52. scores.sort(key=lambda x: x[1], reverse=True)
  53. # --- Step 3: 残余予算を Greedy に配分(上位にまとめて) ---
  54. for pair_idx, _score in scores:
  55. if remaining <= 0:
  56. break
  57. path_num = node_path_list[pair_idx]
  58. if path_num <= 0:
  59. continue
  60. network = network_generator(path_num, pair_idx)
  61. path_list = list(range(1, path_num + 1))
  62. if return_details:
  63. correctness, cost, best_path_fid, alloc1, est1 = lonline_network_benchmarking(
  64. network, path_list, list(bounces), int(remaining), return_details=True
  65. )
  66. for l, b in alloc1.items():
  67. per_pair_details[pair_idx]["alloc_by_path"][int(l)] = \
  68. per_pair_details[pair_idx]["alloc_by_path"].get(int(l), 0) + int(b)
  69. per_pair_details[pair_idx]["est_fid_by_path"].update({int(k): float(v) for k, v in est1.items()})
  70. else:
  71. correctness, cost, best_path_fid = lonline_network_benchmarking(
  72. network, path_list, list(bounces), int(remaining)
  73. )
  74. per_pair_results[pair_idx] = (
  75. bool(correctness),
  76. int(initial_costs[pair_idx] + int(cost)),
  77. best_path_fid,
  78. )
  79. remaining -= int(cost)
  80. consumed_total += int(cost)
  81. return (per_pair_results, int(consumed_total), per_pair_details) if return_details \
  82. else (per_pair_results, int(consumed_total))