plots.py~ 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # viz/plots.py — Python 3.8 compatible plotting utilities
  2. from __future__ import annotations
  3. import math
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from cycler import cycler
  7. # ----- Unified style (kept lightweight) -----
  8. default_cycler = (
  9. cycler(color=["C0", "C1", "C2", "C3", "C4", "C5"])
  10. + cycler(marker=["s", "D", "^", "v", "o", "x"])
  11. + cycler(linestyle=[":", "--", "-", "-.", "--", ":"])
  12. )
  13. plt.rc("axes", prop_cycle=default_cycler)
  14. # ----- 95% CI (t critical) -----
  15. def tcrit_95(n: int) -> float:
  16. """Return ~95% t critical value for sample size n (simple, conservative)."""
  17. if n <= 1:
  18. return float("inf")
  19. if n < 30:
  20. # Conservative constant (close to df=9..29 range)
  21. return 2.262
  22. return 1.96
  23. def mean_ci95(vals):
  24. """Return (mean, half_width) for 95% CI."""
  25. arr = np.array(list(vals), dtype=float)
  26. n = len(arr)
  27. if n == 0:
  28. return 0.0, 0.0
  29. if n == 1:
  30. return float(arr[0]), 0.0
  31. m = float(arr.mean())
  32. s = float(arr.std(ddof=1))
  33. half = tcrit_95(n) * (s / math.sqrt(n))
  34. return m, half
  35. def plot_with_ci_band(ax, xs, mean, half, *, label, line_kwargs=None, band_kwargs=None):
  36. """Plot mean line and shaded CI band (Python 3.8 safe)."""
  37. line_kwargs = {} if line_kwargs is None else dict(line_kwargs)
  38. band = {"alpha": 0.25}
  39. if band_kwargs is not None:
  40. band.update(dict(band_kwargs))
  41. line, = ax.plot(xs, mean, label=label, **line_kwargs)
  42. ax.fill_between(xs, mean - half, mean + half, **band)
  43. return line