# lnaive_scheduler.py (2.4 KB)
# schedulers/lnaive_scheduler.py
from .lnaive_nb import naive_network_benchmarking_with_budget
  3. def lnaive_budget_scheduler(
  4. node_path_list, # 例: [2, 2, 2] … 各ペアのパス本数
  5. importance_list, # 例: [0.3, 0.5, 0.7] … 長さは node_path_list と同じ(ここでは未使用)
  6. bounces, # 例: [1,2,3,4](重複なし)
  7. C_total, # 総予算(切り捨て配分、超過しない)
  8. network_generator, # callable: (path_num, pair_idx) -> network
  9. return_details=False,
  10. ):
  11. num_pairs = len(node_path_list)
  12. assert num_pairs == len(importance_list), "length mismatch: node_path_list vs importance_list"
  13. if num_pairs == 0:
  14. return ([], 0, []) if return_details else ([], 0)
  15. assert len(bounces) == len(set(bounces)), "bounces must be unique"
  16. assert all(isinstance(w, int) and w > 0 for w in bounces), "bounces must be positive ints"
  17. # 均等配分:1ペアあたりの割当
  18. C_per_pair = int(C_total // max(num_pairs, 1))
  19. per_pair_results = []
  20. per_pair_details = []
  21. total_cost = 0
  22. for pair_idx, path_num in enumerate(node_path_list):
  23. if path_num <= 0:
  24. per_pair_results.append((False, 0, None))
  25. if return_details:
  26. per_pair_details.append({"alloc_by_path": {}, "est_fid_by_path": {}})
  27. continue
  28. network = network_generator(path_num, pair_idx)
  29. path_list = list(range(1, path_num + 1))
  30. if return_details:
  31. correctness, cost, best_path_fidelity, alloc_by_path, est_fid_by_path = \
  32. naive_network_benchmarking_with_budget(
  33. network, path_list, list(bounces), C_per_pair, return_details=True
  34. )
  35. per_pair_details.append({
  36. "alloc_by_path": {int(k): int(v) for k, v in alloc_by_path.items()},
  37. "est_fid_by_path": {int(k): float(v) for k, v in est_fid_by_path.items()},
  38. })
  39. else:
  40. correctness, cost, best_path_fidelity = naive_network_benchmarking_with_budget(
  41. network, path_list, list(bounces), C_per_pair
  42. )
  43. per_pair_results.append((bool(correctness), int(cost), best_path_fidelity))
  44. total_cost += int(cost)
  45. return (per_pair_results, total_cost, per_pair_details) if return_details \
  46. else (per_pair_results, total_cost)