# schedulers/greedy_scheduler.py
from .lonline_nb import lonline_network_benchmarking


def _benchmark_pair(pair_idx, path_num, bounces, budget, network_generator):
    """Run one lonline benchmarking pass for a single pair.

    Builds the pair's network through ``network_generator(path_num, pair_idx)``,
    labels its paths ``1..path_num`` and calls ``lonline_network_benchmarking``
    with a fresh copy of ``bounces`` and the integer ``budget``.

    Returns:
        ``(correct, cost, best_path_fid)`` with ``correct`` coerced to ``bool``
        and ``cost`` coerced to ``int``; ``best_path_fid`` is passed through
        exactly as reported (it may be ``None``).
    """
    network = network_generator(path_num, pair_idx)
    path_list = list(range(1, path_num + 1))
    correctness, cost, best_path_fid = lonline_network_benchmarking(
        network, path_list, list(bounces), int(budget)
    )
    return bool(correctness), int(cost), best_path_fid


def greedy_budget_scheduler(
    node_path_list,
    importance_list,
    bounces,
    C_total,
    network_generator,
    C_initial_per_pair=40,
):
    """Greedy budget scheduler (minimal configuration).

    Procedure:
      1. Give every pair a small initial budget (at most
         ``C_initial_per_pair``) and run lonline once to obtain a rough
         estimate of the best-path fidelity.
      2. Rank the pairs by ``importance * estimated_fidelity``, descending.
         Pairs that were never probed, or whose probe reported no fidelity,
         count as fidelity 0.0.
      3. Hand the *entire* remaining budget to the best-ranked pair and run
         lonline again; whatever that run leaves unspent is offered to the
         next pair, and so on until the budget is exhausted.

    Args:
        node_path_list: Number of paths for each pair, e.g. ``[2, 2, 2]``.
            Pairs with a non-positive path count are never benchmarked.
        importance_list: Importance weight for each pair, e.g.
            ``[0.3, 0.5, 0.7]``; must have the same length as
            ``node_path_list``.
        bounces: Bounce counts, e.g. ``[1, 2, 3, 4]``; must be unique
            positive ints.
        C_total: Total budget (truncated to ``int``).
        network_generator: Callable ``(path_num, pair_idx) -> network``.
        C_initial_per_pair: Upper bound on the initial probe budget of each
            pair (truncated to ``int``).

    Returns:
        ``(per_pair_results, total_cost)`` where ``per_pair_results`` is a
        list, in pair order, of ``(correct: bool, cost: int,
        best_path_fidelity: float | None)`` and ``total_cost`` is the summed
        cost reported by all lonline runs. Pairs that were never benchmarked
        keep the placeholder ``(False, 0, None)``.

    Raises:
        AssertionError: If the two lists differ in length, or if ``bounces``
            contains duplicates or anything other than positive ints.
    """
    num_pairs = len(node_path_list)

    # Explicit raises instead of ``assert`` so that the checks survive
    # ``python -O``; AssertionError is kept so existing callers see the same
    # exception type and message as before.
    if num_pairs != len(importance_list):
        raise AssertionError("length mismatch: node_path_list vs importance_list")
    if num_pairs == 0:
        return [], 0
    if len(bounces) != len(set(bounces)):
        raise AssertionError("bounces must be unique")
    if not all(isinstance(w, int) and w > 0 for w in bounces):
        raise AssertionError("bounces must be positive ints")

    # Normalise the budgets once instead of re-converting on every iteration.
    total_budget = int(C_total)
    probe_cap = int(C_initial_per_pair)

    # --- Step 1: lightly probe every pair (initial estimate) ---
    initial_est_fids = [0.0] * num_pairs
    initial_costs = [0] * num_pairs
    per_pair_results = [(False, 0, None)] * num_pairs  # placeholders
    consumed_total = 0

    for pair_idx, path_num in enumerate(node_path_list):
        budget_left = total_budget - consumed_total
        if budget_left <= 0:
            break  # nothing left to probe with; later pairs keep the placeholder
        if path_num <= 0:
            continue
        probe_budget = min(probe_cap, budget_left)
        if probe_budget <= 0:
            break  # only reachable when the per-pair probe cap is non-positive

        correct, cost, best_path_fid = _benchmark_pair(
            pair_idx, path_num, bounces, probe_budget, network_generator
        )
        consumed_total += cost
        initial_costs[pair_idx] = cost
        if best_path_fid is not None:
            initial_est_fids[pair_idx] = float(best_path_fid)
        per_pair_results[pair_idx] = (correct, cost, best_path_fid)

    remaining = max(total_budget - consumed_total, 0)

    # --- Step 2: prioritise by importance * estimated fidelity ---
    # sorted() is stable, also with reverse=True, so tied pairs keep their
    # original relative order.
    ranked_pairs = sorted(
        range(num_pairs),
        key=lambda idx: importance_list[idx] * initial_est_fids[idx],
        reverse=True,
    )

    # --- Step 3: greedily hand out the remaining budget (all of it at once) ---
    for pair_idx in ranked_pairs:
        if remaining <= 0:
            break
        path_num = node_path_list[pair_idx]
        if path_num <= 0:
            continue

        correct, cost, best_path_fid = _benchmark_pair(
            pair_idx, path_num, bounces, remaining, network_generator
        )
        # NOTE(review): the second run's outcome replaces the probe's outcome
        # wholesale (only the costs are accumulated), even when the second run
        # reports no best-path fidelity - confirm this is intended.
        per_pair_results[pair_idx] = (
            correct,
            initial_costs[pair_idx] + cost,
            best_path_fid,
        )
        remaining -= cost
        consumed_total += cost

    return per_pair_results, int(consumed_total)