main.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. main.py — evaluation.py の各種プロットを一括実行
  5. """
  6. from multiprocessing.pool import Pool
  7. import os
  8. import random
  9. try:
  10. from utils import set_random_seed
  11. except Exception:
  12. def set_random_seed(seed: int = 12):
  13. random.seed(seed)
  14. try:
  15. import numpy as np
  16. np.random.seed(seed)
  17. except Exception:
  18. pass
  19. from evaluation import (
  20. plot_accuracy_vs_budget,
  21. plot_value_vs_used,
  22. plot_value_vs_budget_target,
  23. plot_widthsum_alllinks_vs_budget,
  24. plot_minwidthsum_perpair_vs_budget,
  25. plot_widthsum_alllinks_weighted_vs_budget,
  26. plot_minwidthsum_perpair_weighted_vs_budget,
  27. )
  28. def main():
  29. set_random_seed(12)
  30. num_workers = max(1, (os.cpu_count() or 4) // 2)
  31. noise_model_list = ["Depolar"]
  32. scheduler_names = ["LNaive", "Greedy"]
  33. node_path_list = [5, 5, 5]
  34. importance_list = [0.3, 0.6, 0.9]
  35. budget_list = [1000,2000,3000,4000,5000,6000,7000,8000,9000,10000]
  36. bounces = (1, 2, 3, 4)
  37. repeat = 10
  38. delta = 0.1
  39. print("=== Config ===")
  40. print(f"workers={num_workers}, noise_models={noise_model_list}")
  41. print(f"schedulers={scheduler_names}")
  42. print(f"node_path_list={node_path_list}, importance_list={importance_list}")
  43. print(f"budgets={budget_list}, bounces={bounces}, repeat={repeat}, delta={delta}")
  44. print("================\n")
  45. p = Pool(processes=num_workers)
  46. jobs = []
  47. for noise_model in noise_model_list:
  48. jobs.append(p.apply_async(
  49. plot_accuracy_vs_budget,
  50. args=(budget_list, scheduler_names, noise_model,
  51. node_path_list, importance_list, bounces, repeat),
  52. kwds={"verbose": True}
  53. ))
  54. jobs.append(p.apply_async(
  55. plot_value_vs_used,
  56. args=(budget_list, scheduler_names, noise_model,
  57. node_path_list, importance_list, bounces, repeat),
  58. kwds={"verbose": True}
  59. ))
  60. jobs.append(p.apply_async(
  61. plot_value_vs_budget_target,
  62. args=(budget_list, scheduler_names, noise_model,
  63. node_path_list, importance_list, bounces, repeat),
  64. kwds={"verbose": True}
  65. ))
  66. jobs.append(p.apply_async(
  67. plot_widthsum_alllinks_vs_budget,
  68. args=(budget_list, scheduler_names, noise_model,
  69. node_path_list, importance_list, bounces, repeat),
  70. kwds={"delta": delta, "verbose": True}
  71. ))
  72. jobs.append(p.apply_async(
  73. plot_minwidthsum_perpair_vs_budget,
  74. args=(budget_list, scheduler_names, noise_model,
  75. node_path_list, importance_list, bounces, repeat),
  76. kwds={"delta": delta, "verbose": True}
  77. ))
  78. jobs.append(p.apply_async(
  79. plot_widthsum_alllinks_weighted_vs_budget,
  80. args=(budget_list, scheduler_names, noise_model,
  81. node_path_list, importance_list, bounces, repeat),
  82. kwds={"delta": delta, "verbose": True}
  83. ))
  84. jobs.append(p.apply_async(
  85. plot_minwidthsum_perpair_weighted_vs_budget,
  86. args=(budget_list, scheduler_names, noise_model,
  87. node_path_list, importance_list, bounces, repeat),
  88. kwds={"delta": delta, "verbose": True}
  89. ))
  90. p.close(); p.join()
  91. for j in jobs: j.get()
  92. print("\nAll jobs finished.")
  93. print("Pickles -> ./outputs/, PDF -> カレントディレクトリ に保存されます。")
  94. if __name__ == "__main__":
  95. main()