| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- # schedulers/greedy_scheduler.py
- from .lonline_nb import lonline_network_benchmarking
def _run_benchmark(benchmark_fn, network, path_num, bounces, budget,
                   details, return_details):
    """Run one benchmarking call for a pair and fold its details into ``details``.

    Returns ``(correctness, cost, best_path_fid)`` as produced by
    ``benchmark_fn``. When ``return_details`` is true, the per-path allocation
    reported by the call is *added* to ``details["alloc_by_path"]`` and its
    per-path fidelity estimates overwrite ``details["est_fid_by_path"]``
    (the most recent call wins).
    """
    path_list = list(range(1, path_num + 1))  # paths are numbered 1..path_num
    if not return_details:
        correctness, cost, best_path_fid = benchmark_fn(
            network, path_list, list(bounces), int(budget)
        )
        return correctness, cost, best_path_fid

    correctness, cost, best_path_fid, alloc, est = benchmark_fn(
        network, path_list, list(bounces), int(budget), return_details=True
    )
    alloc_by_path = details["alloc_by_path"]
    for path, count in alloc.items():
        alloc_by_path[int(path)] = alloc_by_path.get(int(path), 0) + int(count)
    details["est_fid_by_path"].update({int(k): float(v) for k, v in est.items()})
    return correctness, cost, best_path_fid


def greedy_budget_scheduler(
    node_path_list,
    importance_list,
    bounces,
    C_total,
    network_generator,
    C_initial_per_pair=40,
    return_details=False,
    *,
    benchmark_fn=None,
):
    """Split a total benchmarking budget across node pairs with a greedy rule.

    Step 1 probes each pair with at most ``C_initial_per_pair`` budget.
    Pairs with a non-positive path count are skipped, and probing stops once
    the budget is exhausted.

    Step 2 ranks pairs by ``importance * probed best fidelity``. A pair that
    was not probed, or whose probe returned no fidelity, scores 0.0. Ties keep
    their original order.

    Step 3 offers the *entire* remaining budget to each pair in rank order, so
    the top-ranked pair may consume all of it. The greedy run regenerates the
    network via ``network_generator`` and does not receive the probe's results.

    Args:
        node_path_list: Number of paths per pair, e.g. ``[2, 2, 2]``.
        importance_list: Importance weight per pair, same length as
            ``node_path_list``, e.g. ``[0.3, 0.5, 0.7]``.
        bounces: Unique positive integer bounce counts, e.g. ``[1, 2, 3, 4]``;
            forwarded to the benchmark unchanged.
        C_total: Total budget.
        network_generator: Callable ``(path_num, pair_idx) -> network``.
            Called once per benchmarking run.
        C_initial_per_pair: Probe budget per pair in step 1.
        return_details: When true, also return per-pair detail dicts.
        benchmark_fn: Benchmarking callable with the calling convention of
            ``lonline_network_benchmarking``, i.e.
            ``(network, path_list, bounces, budget[, return_details=True])``.
            Defaults to ``lonline_network_benchmarking``.

    Returns:
        ``(per_pair_results, consumed_total)``, or
        ``(per_pair_results, consumed_total, per_pair_details)`` when
        ``return_details`` is true.

        ``per_pair_results[i]`` is ``(correctness, total cost spent on pair i,
        best_path_fid of the latest run)``. Pairs never run stay at
        ``(False, 0, None)``.

        ``per_pair_details[i]`` is
        ``{"alloc_by_path": {...}, "est_fid_by_path": {...}}``.

        Costs are taken as reported by the benchmark; they are not clamped to
        the budget handed out.

    Raises:
        AssertionError: On mismatched list lengths, duplicate bounces, or
            non-positive/non-int bounces. These are ``assert`` checks, so they
            are skipped under ``python -O``.
    """
    num_pairs = len(node_path_list)
    assert num_pairs == len(importance_list), "length mismatch: node_path_list vs importance_list"
    if num_pairs == 0:
        return ([], 0, []) if return_details else ([], 0)
    assert len(bounces) == len(set(bounces)), "bounces must be unique"
    assert all(isinstance(w, int) and w > 0 for w in bounces), "bounces must be positive ints"

    # Resolved lazily so a caller-supplied benchmark never needs the default.
    bench = lonline_network_benchmarking if benchmark_fn is None else benchmark_fn

    initial_est_fids = [0.0] * num_pairs
    initial_costs = [0] * num_pairs
    per_pair_results = [(False, 0, None)] * num_pairs  # tuples are immutable, so sharing is safe
    per_pair_details = [{"alloc_by_path": {}, "est_fid_by_path": {}} for _ in range(num_pairs)]
    consumed_total = 0

    # --- Step 1: light probe of every pair (initial estimate) ---
    for pair_idx, path_num in enumerate(node_path_list):
        if consumed_total >= C_total or path_num <= 0:
            continue
        C_probe = min(int(C_initial_per_pair), max(int(C_total) - int(consumed_total), 0))
        if C_probe <= 0:
            break
        network = network_generator(path_num, pair_idx)
        correctness, cost, best_path_fid = _run_benchmark(
            bench, network, path_num, bounces, C_probe,
            per_pair_details[pair_idx], return_details,
        )
        consumed_total += int(cost)
        initial_costs[pair_idx] = int(cost)
        initial_est_fids[pair_idx] = float(best_path_fid) if best_path_fid is not None else 0.0
        per_pair_results[pair_idx] = (bool(correctness), int(cost), best_path_fid)

    remaining = max(int(C_total) - int(consumed_total), 0)

    # --- Step 2: priority = importance * estimated fidelity (stable, descending) ---
    order = sorted(
        range(num_pairs),
        key=lambda idx: importance_list[idx] * initial_est_fids[idx],
        reverse=True,
    )

    # --- Step 3: hand the whole remaining budget to pairs in priority order ---
    for pair_idx in order:
        if remaining <= 0:
            break
        path_num = node_path_list[pair_idx]
        if path_num <= 0:
            continue
        network = network_generator(path_num, pair_idx)
        correctness, cost, best_path_fid = _run_benchmark(
            bench, network, path_num, bounces, remaining,
            per_pair_details[pair_idx], return_details,
        )
        # Cost reported for the pair is cumulative over probe + greedy run.
        per_pair_results[pair_idx] = (
            bool(correctness),
            int(initial_costs[pair_idx] + int(cost)),
            best_path_fid,
        )
        remaining -= int(cost)
        consumed_total += int(cost)

    if return_details:
        return per_pair_results, int(consumed_total), per_pair_details
    return per_pair_results, int(consumed_total)
|