groups_nb.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # groups_nb.py
  2. import random # グループ内から測定リンクをランダム選択
  3. def groups_network_benchmarking_with_budget(network, path_list, bounces, C_budget, groups, return_details=True):
  4. fidelity = {}
  5. cost = 0
  6. # groups バリデーション(1-origin想定)
  7. ids_set = set(int(p) for p in path_list)
  8. flat = [int(x) for g in groups for x in g]
  9. if not flat:
  10. raise ValueError("groups must not be empty")
  11. if any((pid not in ids_set) for pid in flat):
  12. raise ValueError("groups contains invalid path id(s)")
  13. # 全被覆を強制したいなら以下を有効化(任意)
  14. # if set(flat) != ids_set:
  15. # raise ValueError("groups must cover all path ids exactly")
  16. per_sample_cost = sum(bounces) or 1
  17. # 等分配は「グループ本数」基準
  18. n_groups = max(1, len(groups))
  19. per_group_budget = int(C_budget) // n_groups
  20. Ns = per_group_budget // per_sample_cost
  21. if Ns <= 0:
  22. if return_details:
  23. return False, 0, None, {}, {}
  24. return False, 0, None
  25. # 追加: 詳細記録用
  26. alloc_by_path = {int(p): 0 for p in path_list}
  27. est_fid_by_path = {}
  28. # (変更後)
  29. # 各グループについて Ns 回まわし、毎回ランダムに 1 本を選んで測定
  30. sample_times_one = {h: 1 for h in bounces} # ★ 1回分のNBセット
  31. for grp in groups:
  32. f_sum = 0.0
  33. for _ in range(int(Ns)): # ★ Ns 回くり返す
  34. chosen = int(random.choice(grp))
  35. p, used = network.benchmark_path(chosen, bounces, sample_times_one)
  36. f = p + (1 - p) / 2.0 # 既存の忠実度変換式
  37. f_sum += f
  38. cost += int(used)
  39. alloc_by_path[chosen] = alloc_by_path.get(chosen, 0) + int(used)
  40. # ★ グループの推定値は Ns 回の平均を全リンクにコピー
  41. f_group = f_sum / float(Ns)
  42. for pid in grp:
  43. pid = int(pid)
  44. fidelity[pid] = f_group
  45. est_fid_by_path[pid] = float(f_group)
  46. if not fidelity:
  47. if return_details:
  48. return False, int(cost), None, alloc_by_path, est_fid_by_path
  49. return False, int(cost), None
  50. best_path = max(fidelity, key=fidelity.get)
  51. correctness = (best_path == getattr(network, "best_path", None))
  52. best_path_fidelity = fidelity[best_path]
  53. if return_details:
  54. return bool(correctness), int(cost), best_path_fidelity, alloc_by_path, est_fid_by_path
  55. return bool(correctness), int(cost), best_path_fidelity