| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- # schedulers/greedy_scheduler.py
- from .lonline_nb import lonline_network_benchmarking
- def greedy_budget_scheduler(
- node_path_list, # 例: [2, 2, 2] … 各ペアのパス本数
- importance_list, # 例: [0.3, 0.5, 0.7] … 長さは node_path_list と同じ
- bounces, # 例: [1,2,3,4](重複なし)
- C_total, # 総予算
- network_generator, # callable: (path_num, pair_idx) -> network
- C_initial_per_pair=40, # 各ペアの初期プローブ予算
- ):
- """
- Greedy スケジューラ(最小構成)
- 手順:
- 1) 各ペアに小さな初期予算 C_initial_per_pair を配って lonline を一度だけ実行し、
- おおまかなベスト経路忠実度を得る。
- 2) importance * estimated_fidelity のスコアでペアを降順ソート。
- 3) 残余予算をスコア上位のペアにまとめて与え、lonline をもう一度実行。
- 予算が尽きるまで繰り返す(この実装では「まとめて全部」与える)。
- 返り値:
- per_pair_results: List[ (correct: bool, cost: int, best_path_fidelity: float|None) ] (ペア順)
- total_cost: int
- """
- num_pairs = len(node_path_list)
- assert num_pairs == len(importance_list), "length mismatch: node_path_list vs importance_list"
- if num_pairs == 0:
- return [], 0
- # bounces は重複なし・正の整数を仮定
- assert len(bounces) == len(set(bounces)), "bounces must be unique"
- assert all(isinstance(w, int) and w > 0 for w in bounces), "bounces must be positive ints"
- # --- Step 1: 各ペアを軽くプローブ(初期推定) ---
- initial_est_fids = [0.0] * num_pairs
- initial_costs = [0] * num_pairs
- per_pair_results = [(False, 0, None)] * num_pairs # プレースホルダ
- consumed_total = 0
- for pair_idx, path_num in enumerate(node_path_list):
- if consumed_total >= C_total or path_num <= 0:
- continue
- C_probe = min(int(C_initial_per_pair), max(int(C_total) - int(consumed_total), 0))
- if C_probe <= 0:
- break
- network = network_generator(path_num, pair_idx)
- path_list = list(range(1, path_num + 1))
- correctness, cost, best_path_fid = lonline_network_benchmarking(
- network, path_list, list(bounces), int(C_probe)
- )
- consumed_total += int(cost)
- initial_costs[pair_idx] = int(cost)
- initial_est_fids[pair_idx] = float(best_path_fid) if best_path_fid is not None else 0.0
- per_pair_results[pair_idx] = (bool(correctness), int(cost), best_path_fid)
- remaining = max(int(C_total) - int(consumed_total), 0)
- # --- Step 2: importance * estimated_fidelity で優先度付け ---
- scores = [(idx, importance_list[idx] * initial_est_fids[idx]) for idx in range(num_pairs)]
- scores.sort(key=lambda x: x[1], reverse=True)
- # --- Step 3: 残余予算を Greedy に配分(上位にまとめて) ---
- for pair_idx, _score in scores:
- if remaining <= 0:
- break
- path_num = node_path_list[pair_idx]
- if path_num <= 0:
- continue
- network = network_generator(path_num, pair_idx)
- path_list = list(range(1, path_num + 1))
- # 残りを丸ごと与える(シンプル設計)
- correctness, cost, best_path_fid = lonline_network_benchmarking(
- network, path_list, list(bounces), int(remaining)
- )
- per_pair_results[pair_idx] = (
- bool(correctness),
- int(initial_costs[pair_idx] + int(cost)),
- best_path_fid,
- )
- remaining -= int(cost)
- consumed_total += int(cost)
- return per_pair_results, int(consumed_total)
|