plots.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # viz/plots.py — Python 3.8 & older Matplotlib compatible
  2. from __future__ import annotations
  3. import math
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from cycler import cycler
  7. # ---- Safe color cycle (hex codes; no 'C0' refs) ----
  8. COLORS = [
  9. "#1f77b4", # blue
  10. "#ff7f0e", # orange
  11. "#2ca02c", # green
  12. "#d62728", # red
  13. "#9467bd", # purple
  14. "#8c564b", # brown
  15. ]
  16. plt.rc("axes", prop_cycle=cycler("color", COLORS))
  17. def tcrit_95(n: int) -> float:
  18. if n <= 1:
  19. return float("inf")
  20. if n < 30:
  21. return 2.262
  22. return 1.96
  23. def mean_ci95(vals):
  24. arr = np.array(list(vals), dtype=float)
  25. n = len(arr)
  26. if n == 0:
  27. return 0.0, 0.0
  28. if n == 1:
  29. return float(arr[0]), 0.0
  30. m = float(arr.mean())
  31. s = float(arr.std(ddof=1))
  32. half = tcrit_95(n) * (s / math.sqrt(n))
  33. return m, half
  34. def plot_with_ci_band(ax, xs, mean, half, *, label, line_kwargs=None, band_kwargs=None):
  35. """Plot mean line and shaded CI band (Py3.8 safe, old-mpl safe)."""
  36. line_kwargs = {} if line_kwargs is None else dict(line_kwargs)
  37. band = {"alpha": 0.25}
  38. if band_kwargs is not None:
  39. band.update(dict(band_kwargs))
  40. line, = ax.plot(xs, mean, label=label, **line_kwargs)
  41. ax.fill_between(xs, mean - half, mean + half, **band)
  42. return line