groups_scheduler.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # schedulers/groups_scheduler.py
  2. from .groups_nb import groups_network_benchmarking_with_budget
  3. # 追加:重要度→グループ化
  4. def _group_paths_by_importance(path_num: int, importance: float):
  5. """
  6. 重要度に応じて実リンクID(1..L)をグループ化した二次元配列を返す。
  7. 0.75 < I <= 1.00 : まとめない([ [1], [2], ..., [L] ])
  8. 0.50 < I <= 0.75 : 2本ずつ(余りは小グループでOK)
  9. 0.25 < I <= 0.50 : 3本ずつ(余りは小グループでOK)
  10. 0.00 < I <= 0.25 : 全部まとめる([ [1,2,...,L] ])
  11. """
  12. L = int(path_num)
  13. ids = list(range(1, L + 1))
  14. I = float(importance)
  15. if I > 0.75:
  16. return [[i] for i in ids]
  17. if I > 0.50:
  18. size = 2
  19. elif I > 0.25:
  20. size = 3
  21. else:
  22. return [ids] # 全まとめ
  23. groups = []
  24. for i in range(0, L, size):
  25. groups.append(ids[i:i+size])
  26. return groups
  27. def groups_budget_scheduler(
  28. node_path_list, # 例: [2, 2, 2] … 各ペアのパス本数
  29. importance_list, # 例: [0.3, 0.5, 0.7] … 長さは node_path_list と同じ(ここでは未使用)
  30. bounces, # 例: [1,2,3,4](重複なし)
  31. C_total, # 総予算(切り捨て配分、超過しない)
  32. network_generator, # callable: (path_num, pair_idx) -> network
  33. return_details=False,
  34. ):
  35. num_pairs = len(node_path_list)
  36. assert num_pairs == len(importance_list), "length mismatch: node_path_list vs importance_list"
  37. if num_pairs == 0:
  38. return ([], 0, []) if return_details else ([], 0)
  39. assert len(bounces) == len(set(bounces)), "bounces must be unique"
  40. assert all(isinstance(w, int) and w > 0 for w in bounces), "bounces must be positive ints"
  41. # 均等配分:1ペアあたりの割当
  42. C_per_pair = int(C_total // max(num_pairs, 1))
  43. per_pair_results = []
  44. per_pair_details = []
  45. total_cost = 0
  46. for pair_idx, path_num in enumerate(node_path_list):
  47. if path_num <= 0:
  48. per_pair_results.append((False, 0, None))
  49. if return_details:
  50. per_pair_details.append({"alloc_by_path": {}, "est_fid_by_path": {}})
  51. continue
  52. network = network_generator(path_num, pair_idx)
  53. # 追加:重要度に応じたグループ
  54. groups = _group_paths_by_importance(path_num, importance_list[pair_idx])
  55. path_list = list(range(1, path_num + 1))
  56. if return_details:
  57. correctness, cost, best_path_fidelity, alloc_by_path, est_fid_by_path = \
  58. groups_network_benchmarking_with_budget(
  59. network, path_list, list(bounces), C_per_pair, groups=groups, return_details=True
  60. )
  61. per_pair_details.append({
  62. "alloc_by_path": {int(k): int(v) for k, v in alloc_by_path.items()},
  63. "est_fid_by_path": {int(k): float(v) for k, v in est_fid_by_path.items()},
  64. })
  65. else:
  66. correctness, cost, best_path_fidelity = groups_network_benchmarking_with_budget(
  67. network, path_list, list(bounces), C_per_pair, groups=groups,
  68. )
  69. per_pair_results.append((bool(correctness), int(cost), best_path_fidelity))
  70. total_cost += int(cost)
  71. return (per_pair_results, total_cost, per_pair_details) if return_details \
  72. else (per_pair_results, total_cost)