# schedulers/greedy_scheduler.py
from .lonline_nb import lonline_network_benchmarking


def _merge_details(pair_details, alloc_by_path, est_fid_by_path):
    """Fold one benchmarking run's per-path details into ``pair_details`` in place.

    Budgets in ``alloc_by_path`` are added to whatever was already allocated to
    the same path.  Fidelity estimates in ``est_fid_by_path`` overwrite earlier
    estimates for the same path (the later run wins).
    """
    merged_alloc = pair_details["alloc_by_path"]
    for path, budget in alloc_by_path.items():
        key = int(path)
        merged_alloc[key] = merged_alloc.get(key, 0) + int(budget)
    pair_details["est_fid_by_path"].update(
        {int(path): float(fid) for path, fid in est_fid_by_path.items()}
    )


def _benchmark_pair(network_generator, path_num, pair_idx, bounces, budget,
                    return_details, pair_details):
    """Build the network for pair ``pair_idx`` and run one benchmarking pass on it.

    Paths are labelled ``1..path_num``.  Returns ``(correctness, cost,
    best_path_fid)`` exactly as reported by ``lonline_network_benchmarking``.
    When ``return_details`` is true, the per-path allocation/estimate details
    of this pass are merged into ``pair_details``.
    """
    network = network_generator(path_num, pair_idx)
    path_list = list(range(1, path_num + 1))
    if return_details:
        correctness, cost, best_path_fid, alloc, est = lonline_network_benchmarking(
            network, path_list, list(bounces), int(budget), return_details=True
        )
        _merge_details(pair_details, alloc, est)
    else:
        correctness, cost, best_path_fid = lonline_network_benchmarking(
            network, path_list, list(bounces), int(budget)
        )
    return correctness, cost, best_path_fid


def greedy_budget_scheduler(
    node_path_list,
    importance_list,
    bounces,
    C_total,
    network_generator,
    C_initial_per_pair=40,
    return_details=False,
):
    """Split a total benchmarking budget across node pairs with a greedy policy.

    The budget is spent in two phases:

    1. Probe phase: each pair (in input order) is benchmarked with at most
       ``C_initial_per_pair`` budget.  This stops once ``C_total`` is used up.
    2. Greedy phase: pairs are ranked by ``importance * probed best-path
       fidelity`` (highest first).  Each pair, in rank order, is offered the
       *entire* remaining budget.  Later pairs only get what earlier ones left.

    Args:
        node_path_list: Number of paths for each pair, e.g. ``[2, 2, 2]``.
            Pairs with ``<= 0`` paths are never benchmarked.
        importance_list: Importance weight per pair, e.g. ``[0.3, 0.5, 0.7]``.
            Must have the same length as ``node_path_list``.
        bounces: Bounce lengths to use, e.g. ``[1, 2, 3, 4]``.  They must be
            unique, positive, non-bool ints.
        C_total: Total budget shared by all pairs.
        network_generator: Callable ``(path_num, pair_idx) -> network``.  It is
            called afresh for every benchmarking pass.
        C_initial_per_pair: Probe budget per pair in phase 1.
        return_details: If true, also collect per-path allocations and
            fidelity estimates for each pair.

    Returns:
        ``(per_pair_results, consumed_total)``, or
        ``(per_pair_results, consumed_total, per_pair_details)`` when
        ``return_details`` is true.

        - ``per_pair_results[i]`` is ``(correctness, cost, best_path_fid)``.
          ``cost`` is summed over both phases.  ``correctness`` and
          ``best_path_fid`` come from the last pass run on that pair.  Pairs
          never benchmarked keep ``(False, 0, None)``.
        - ``per_pair_details[i]`` is
          ``{"alloc_by_path": {path: budget}, "est_fid_by_path": {path: fid}}``.

    Raises:
        AssertionError: On mismatched list lengths or invalid ``bounces``.

    NOTE(review): ``consumed_total`` is the sum of reported costs and is not
    clamped.  It can exceed ``C_total`` only if the benchmark reports a cost
    above the budget it was given; confirm that this cannot happen.
    """
    num_pairs = len(node_path_list)
    assert num_pairs == len(importance_list), "length mismatch: node_path_list vs importance_list"
    if num_pairs == 0:
        return ([], 0, []) if return_details else ([], 0)
    assert len(bounces) == len(set(bounces)), "bounces must be unique"
    # bool is a subclass of int, so exclude it explicitly: True/False are not bounce lengths.
    assert all(
        isinstance(w, int) and not isinstance(w, bool) and w > 0 for w in bounces
    ), "bounces must be positive ints"

    initial_est_fids = [0.0] * num_pairs
    initial_costs = [0] * num_pairs
    per_pair_results = [(False, 0, None)] * num_pairs  # tuples are immutable, so sharing is safe
    per_pair_details = [
        {"alloc_by_path": {}, "est_fid_by_path": {}} for _ in range(num_pairs)
    ]
    consumed_total = 0

    # --- Step 1: probe every pair cheaply to get an initial fidelity estimate ---
    for pair_idx, path_num in enumerate(node_path_list):
        if consumed_total >= C_total or path_num <= 0:
            continue
        C_probe = min(int(C_initial_per_pair), max(int(C_total) - int(consumed_total), 0))
        if C_probe <= 0:
            break
        correctness, cost, best_path_fid = _benchmark_pair(
            network_generator, path_num, pair_idx, bounces, C_probe,
            return_details, per_pair_details[pair_idx],
        )
        consumed_total += int(cost)
        initial_costs[pair_idx] = int(cost)
        # A pair whose probe produced no fidelity ranks as if its fidelity were 0.
        initial_est_fids[pair_idx] = float(best_path_fid) if best_path_fid is not None else 0.0
        per_pair_results[pair_idx] = (bool(correctness), int(cost), best_path_fid)

    remaining = max(int(C_total) - int(consumed_total), 0)

    # --- Step 2: rank pairs by importance * estimated fidelity (stable, descending) ---
    scores = [(idx, importance_list[idx] * initial_est_fids[idx]) for idx in range(num_pairs)]
    scores.sort(key=lambda x: x[1], reverse=True)

    # --- Step 3: greedily hand the whole remaining budget to the top-ranked pairs ---
    for pair_idx, _score in scores:
        if remaining <= 0:
            break
        path_num = node_path_list[pair_idx]
        if path_num <= 0:
            continue
        correctness, cost, best_path_fid = _benchmark_pair(
            network_generator, path_num, pair_idx, bounces, remaining,
            return_details, per_pair_details[pair_idx],
        )
        # The greedy pass's outcome replaces the probe's; its cost is added to the probe's.
        per_pair_results[pair_idx] = (
            bool(correctness),
            int(initial_costs[pair_idx] + int(cost)),
            best_path_fid,
        )
        remaining -= int(cost)
        consumed_total += int(cost)

    if return_details:
        return (per_pair_results, int(consumed_total), per_pair_details)
    return (per_pair_results, int(consumed_total))