w_naive_scheduler.py~ 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # schedulers/lnaive_scheduler.py
  2. from .lnaive_nb import naive_network_benchmarking_with_budget
  3. # 追加:重要度に比例して C_total を N ペアへ割り当てる
  4. def _allocate_budget_by_importance(weights, C_total: int):
  5. # クリップ&型
  6. w = [max(0.0, float(x)) for x in weights]
  7. W = sum(w)
  8. if C_total <= 0:
  9. return [0] * len(w)
  10. if W <= 0.0:
  11. # 全て0なら均等割
  12. base = C_total // max(1, len(w))
  13. rem = C_total - base * len(w)
  14. alloc = [base] * len(w)
  15. for i in range(rem):
  16. alloc[i] += 1
  17. return alloc
  18. # 連続値の割当 → 切り捨て → 余りを小数部の大きい順で配分
  19. quotas = [C_total * wi / W for wi in w]
  20. floors = [int(q) for q in quotas]
  21. rem = C_total - sum(floors)
  22. frac = [(q - f, idx) for idx, (q, f) in enumerate(zip(quotas, floors))]
  23. frac.sort(reverse=True) # 小数部の大きい順
  24. for k in range(rem):
  25. floors[frac[k][1]] += 1
  26. return floors
  27. def lnaive_budget_scheduler(
  28. node_path_list, # 例: [2, 2, 2] … 各ペアのパス本数
  29. importance_list, # 例: [0.3, 0.5, 0.7] … 長さは node_path_list と同じ(ここでは未使用)
  30. bounces, # 例: [1,2,3,4](重複なし)
  31. C_total, # 総予算(切り捨て配分、超過しない)
  32. network_generator, # callable: (path_num, pair_idx) -> network
  33. return_details=False,
  34. ):
  35. num_pairs = len(node_path_list)
  36. assert num_pairs == len(importance_list), "length mismatch: node_path_list vs importance_list"
  37. if num_pairs == 0:
  38. return ([], 0, []) if return_details else ([], 0)
  39. assert len(bounces) == len(set(bounces)), "bounces must be unique"
  40. assert all(isinstance(w, int) and w > 0 for w in bounces), "bounces must be positive ints"
  41. # 均等配分:1ペアあたりの割当
  42. C_per_pair_list = _allocate_budget_by_importance(importance_list, int(C_total))
  43. per_pair_results = []
  44. per_pair_details = []
  45. total_cost = 0
  46. for pair_idx, path_num in enumerate(node_path_list):
  47. if path_num <= 0:
  48. per_pair_results.append((False, 0, None))
  49. if return_details:
  50. per_pair_details.append({"alloc_by_path": {}, "est_fid_by_path": {}})
  51. continue
  52. network = network_generator(path_num, pair_idx)
  53. path_list = list(range(1, path_num + 1))
  54. if return_details:
  55. correctness, cost, best_path_fidelity, alloc_by_path, est_fid_by_path = \
  56. naive_network_benchmarking_with_budget(
  57. network, path_list, list(bounces), C_pair, return_details=True
  58. )
  59. per_pair_details.append({
  60. "alloc_by_path": {int(k): int(v) for k, v in alloc_by_path.items()},
  61. "est_fid_by_path": {int(k): float(v) for k, v in est_fid_by_path.items()},
  62. })
  63. else:
  64. correctness, cost, best_path_fidelity = naive_network_benchmarking_with_budget(
  65. network, path_list, list(bounces), C_pair
  66. )
  67. per_pair_results.append((bool(correctness), int(cost), best_path_fidelity))
  68. total_cost += int(cost)
  69. return (per_pair_results, total_cost, per_pair_details) if return_details \
  70. else (per_pair_results, total_cost)