Source code for chaotic_pfc.plotting.figures

"""
figures.py
==========
Publication-quality SVG figures with LaTeX-style labels.

Typography is configured via :func:`setup_rc`, which auto-detects
whether a full LaTeX installation is available on the system:

* If ``latex`` is found on ``$PATH``, real LaTeX rendering is used
  (``usetex=True``, Computer Modern Roman).
* Otherwise, mathtext with STIX fonts is used as a fallback that
  visually matches LaTeX output without any external dependency.

The detection can be overridden via the environment variable
``CHAOTIC_PFC_FORCE_LATEX`` (``0`` forces fallback, ``1`` forces
real LaTeX).  This is useful in CI environments or for debugging.

All text is converted to vector paths in SVG output so that figures
render identically on any system.
"""

from __future__ import annotations

import os
import shutil
from dataclasses import dataclass, field
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from numpy.typing import NDArray
from scipy.signal import freqz

from chaotic_pfc._i18n import t

# ── LaTeX availability (cached at import time) ──────────────────────────────

_force_var = os.environ.get("CHAOTIC_PFC_FORCE_LATEX")
if _force_var == "1":
    _LATEX_AVAILABLE = True
elif _force_var == "0":
    _LATEX_AVAILABLE = False
else:
    _LATEX_AVAILABLE = shutil.which("latex") is not None


[docs] def latex_available() -> bool: """Return ``True`` if LaTeX is available for figure rendering. The return value is determined at import time by checking ``$PATH`` for the ``latex`` executable. Set the environment variable ``CHAOTIC_PFC_FORCE_LATEX=0`` or ``1`` to override. """ return _LATEX_AVAILABLE
# ── Global RC params for LaTeX-like rendering ───────────────────────────────

# rcParams shared by both rendering back-ends: text converted to SVG paths,
# font/tick/line sizes, DPI and tight bounding boxes.  Kept in a single place
# so the real-LaTeX and mathtext set-ups cannot drift apart.
_RC_COMMON: dict[str, object] = {
    "svg.fonttype": "path",
    "axes.unicode_minus": False,
    "font.size": 12,
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "figure.dpi": 150,
    "savefig.dpi": 150,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.05,
    "axes.linewidth": 1.2,
    "xtick.major.width": 1.0,
    "ytick.major.width": 1.0,
    "lines.linewidth": 1.5,
}

# Real LaTeX typesetting: Computer Modern Roman with amsmath/amssymb.
_RC_USETEX: dict[str, object] = {
    "text.usetex": True,
    "text.latex.preamble": r"\usepackage{amsmath}\usepackage{amssymb}",
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    "axes.formatter.use_mathtext": False,
}

# Dependency-free fallback: matplotlib mathtext with STIX fonts.
_RC_MATHTEXT: dict[str, object] = {
    "text.usetex": False,
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
    "axes.formatter.use_mathtext": True,
}


def setup_rc(usetex: bool | None = None) -> None:
    """Configure matplotlib for publication-quality LaTeX-style SVG output.

    Two typography back-ends are available:

    * real LaTeX (``text.usetex=True``, Computer Modern Roman, ``amsmath`` +
      ``amssymb`` preamble);
    * mathtext with STIX fonts, which visually matches LaTeX without any
      external dependency.

    Everything else (sizes, DPI, line widths, SVG text-as-paths) is identical
    for both back-ends.

    Parameters
    ----------
    usetex
        ``True``/``False`` selects the back-end explicitly.  ``None`` (the
        default) uses the import-time detection reported by
        :func:`latex_available`, which honours
        ``CHAOTIC_PFC_FORCE_LATEX=0``/``1``.  Forcing ``True`` without a LaTeX
        installation makes matplotlib fail when text is rendered.

    Notes
    -----
    Modifies the global ``matplotlib.pyplot.rcParams`` in place.
    """
    if usetex is None:
        usetex = _LATEX_AVAILABLE
    backend = _RC_USETEX if usetex else _RC_MATHTEXT
    # The two dicts have disjoint keys, so merge order does not matter.
    plt.rcParams.update({**backend, **_RC_COMMON})
# ── Colour palette ──────────────────────────────────────────────────────────

# Role -> hex colour, shared by every figure in this module.
C = dict(
    msg_t="#0073BD",  # blue   : message, time domain
    msg_f="#804000",  # brown  : message, frequency domain
    sig_t="#E00000",  # red    : signal, time domain
    sig_f="#660066",  # purple : signal, frequency domain
    traj="#000000",  # black  : attractor
    traj2="#E00000",  # red    : second trajectory
)


def _style(ax: Axes, ts: int = 12) -> None:
    """Give ``ax`` the house look.

    The axes get a faint thin grid, inward-pointing ticks with label size
    ``ts``, and tick marks and spines all drawn with width 1.2.
    """
    frame_width = 1.2
    ax.grid(True, alpha=0.25, linewidth=0.5)
    ax.tick_params(labelsize=ts, width=frame_width, direction="in")
    for spine in ax.spines.values():
        spine.set_linewidth(frame_width)


def _save(fig: Figure, path: str | Path | None, **savefig_kwargs) -> None:
    """Save ``fig`` to ``path``; do nothing when ``path`` is ``None``.

    Missing parent directories are created first.  Any extra keyword
    arguments are passed straight through to ``fig.savefig()``.
    """
    if path is None:
        return
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, **savefig_kwargs)


# ── 1. Attractor ────────────────────────────────────────────────────────────
def plot_attractor(
    X: NDArray,
    Y: NDArray,
    title: str = "",
    xlabel: str = r"$x_1[n]$",
    ylabel: str = r"$x_2[n]$",
    save_path: str | None = None,
) -> Figure:
    """Draw the phase portrait ``(X[k], Y[k])`` of a two-dimensional orbit.

    Each state is plotted with the pixel marker ``","`` in the palette's
    ``"traj"`` colour on a 6 x 5 inch figure.

    Parameters
    ----------
    X, Y
        State-variable trajectories of equal length.
    title
        Figure title; an empty string suppresses it.
    xlabel, ylabel
        Axis labels, LaTeX-style math for ``x_1`` / ``x_2`` by default.
    save_path
        Optional output path; its extension picks the format (``.svg``,
        ``.png``, ...).  Nothing is written when ``None``.

    Returns
    -------
    Figure
        The created figure, left open so the caller can annotate, show or
        close it.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    point_style = {"color": C["traj"], "alpha": 0.8, "markersize": 0.3}
    ax.plot(X, Y, ",", **point_style)

    if title:
        ax.set_title(title, fontsize=13)
    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_label(text, fontsize=12)

    _style(ax)
    fig.tight_layout()
    _save(fig, save_path)
    return fig
# ── 2. Sensitivity (SDIC) ──────────────────────────────────────────────────

# Default legend entries: they describe the canonical experiment
# (x0 = 0 versus x0 = 1e-4) and are only correct for that set-up.
_SDIC_DEFAULT_LABELS: tuple[str, str] = (
    r"$x^{(1)}[n],\; x_0=0$",
    r"$x^{(2)}[n],\; x_0=10^{-4}$",
)


def plot_sensitivity(
    n: NDArray,
    X1: NDArray,
    X2: NDArray,
    save_path: str | None = None,
    lang: str = "pt",
    labels: tuple[str, str] | None = None,
) -> Figure:
    """Overlay two Hénon trajectories to illustrate sensitivity to ICs.

    Plots two state trajectories that start from nearby initial conditions,
    which makes their divergence visible (the classic demonstration of
    chaos).

    Parameters
    ----------
    n
        Sample index axis, shape ``(N,)``.
    X1, X2
        Two state trajectories evaluated on ``n``.  Typically they differ
        only by ``x0_2 = x0_1 + 1e-4``.
    save_path
        If given, the figure is written to this path.
    lang
        Language code for the figure title (``"pt"`` or ``"en"``).
    labels
        Legend entries for ``X1`` and ``X2``.  The defaults describe
        ``x_0 = 0`` and ``x_0 = 10^{-4}``.  Pass your own labels whenever the
        trajectories come from different initial conditions, otherwise the
        legend is wrong.

    Returns
    -------
    Figure
        The matplotlib ``Figure`` object.
    """
    label_1, label_2 = _SDIC_DEFAULT_LABELS if labels is None else labels

    fig, ax = plt.subplots(figsize=(9, 4))
    ax.plot(n, X1, label=label_1, color="steelblue", linewidth=1.2)
    ax.plot(n, X2, label=label_2, color="tomato", linewidth=1.2, alpha=0.85)
    ax.set_xlabel(r"$n$", fontsize=12)
    ax.set_ylabel(r"$x[n]$", fontsize=12)
    ax.set_title(t("sensitivity.title", lang=lang), fontsize=13)
    ax.legend(fontsize=11, framealpha=0.9)
    _style(ax)
    fig.tight_layout()
    _save(fig, save_path)
    return fig
# ── 3. Communication 4×2 grid (time + PSD) ─────────────────────────────────


@dataclass
class PlotGridOptions:
    """Bundle of optional styling settings for :func:`plot_comm_grid`.

    Every field has a default, so set only what you need, e.g.
    ``PlotGridOptions(suptitle="Channel test", y_lim_sig=(-3, 3))``.
    """

    # Samples shown in the time-domain panels (the first 300 by default).
    # A factory keeps every instance's default independent.
    time_window: slice = field(default_factory=lambda: slice(0, 300))
    # Figure-level title; an empty string means no title is drawn.
    suptitle: str = field(default="")
    # Y-limits of the message panel m[n].
    y_lim_msg: tuple[float, float] = field(default=(-1.5, 1.5))
    # Y-limits shared by the transmitted s[n] and received r[n] panels.
    y_lim_sig: tuple[float, float] = field(default=(-2.5, 2.5))
    # Y-limits of the recovered-message panel; None reuses y_lim_msg.
    y_lim_mhat: tuple[float, float] | None = field(default=None)
    # FIR taps whose frequency response is overlaid on the s[n] PSD panel.
    h_channel: NDArray | None = field(default=None)
    # Output file for the figure; None means the figure is not saved.
    save_path: str | None = field(default=None)
# Sentinel meaning "keyword argument not supplied".  It lets ``None`` stay a
# meaningful explicit value (e.g. ``y_lim_mhat=None``) and lets an explicit
# keyword override the matching ``PlotGridOptions`` field even when it equals
# the built-in default.
_UNSET = object()

# Fallbacks used when neither a keyword argument nor ``opts`` supplies a value.
_DEFAULT_TIME_WINDOW = slice(0, 300)
_DEFAULT_Y_LIM_MSG: tuple[float, float] = (-1.5, 1.5)
_DEFAULT_Y_LIM_SIG: tuple[float, float] = (-2.5, 2.5)


def _resolve_grid_options(
    opts: PlotGridOptions | None,
    time_window,
    suptitle: str,
    y_lim_msg,
    y_lim_sig,
    y_lim_mhat,
    h_channel: NDArray | None,
    save_path: str | Path | None,
) -> PlotGridOptions:
    """Merge keyword arguments with ``opts`` (keywords win), then apply defaults.

    The returned options are fully resolved: ``time_window`` is never
    ``None``/unset, and ``y_lim_mhat`` falls back to ``y_lim_msg``.
    """
    if opts is not None:
        # ``None`` for time_window is accepted as "not given" for backward
        # compatibility with the previous signature.
        if time_window is _UNSET or time_window is None:
            time_window = opts.time_window
        if not suptitle:
            suptitle = opts.suptitle
        if y_lim_msg is _UNSET:
            y_lim_msg = opts.y_lim_msg
        if y_lim_sig is _UNSET:
            y_lim_sig = opts.y_lim_sig
        if y_lim_mhat is _UNSET:
            y_lim_mhat = opts.y_lim_mhat
        if h_channel is None:
            h_channel = opts.h_channel
        if save_path is None:
            save_path = opts.save_path

    if time_window is _UNSET or time_window is None:
        time_window = _DEFAULT_TIME_WINDOW
    if y_lim_msg is _UNSET:
        y_lim_msg = _DEFAULT_Y_LIM_MSG
    if y_lim_sig is _UNSET:
        y_lim_sig = _DEFAULT_Y_LIM_SIG
    if y_lim_mhat is _UNSET or y_lim_mhat is None:
        y_lim_mhat = y_lim_msg

    return PlotGridOptions(
        time_window=time_window,
        suptitle=suptitle,
        y_lim_msg=y_lim_msg,
        y_lim_sig=y_lim_sig,
        y_lim_mhat=y_lim_mhat,
        h_channel=h_channel,
        save_path=save_path,
    )


def plot_comm_grid(
    n: NDArray,
    m: NDArray,
    s: NDArray,
    r: NDArray,
    m_hat: NDArray,
    omega: NDArray,
    psd_m: NDArray,
    psd_s: NDArray,
    psd_r: NDArray,
    psd_mhat: NDArray,
    *,
    opts: PlotGridOptions | None = None,
    time_window: slice = _UNSET,  # type: ignore[assignment]
    suptitle: str = "",
    y_lim_msg: tuple[float, float] = _UNSET,  # type: ignore[assignment]
    y_lim_sig: tuple[float, float] = _UNSET,  # type: ignore[assignment]
    y_lim_mhat: tuple[float, float] | None = _UNSET,  # type: ignore[assignment]
    h_channel: NDArray | None = None,
    save_path: str | Path | None = None,
    lang: str = "pt",
) -> Figure:
    """4×2 grid: left = time domain, right = PSD.

    Rows: m[n], s[n], r[n], m̂[n].  If ``h_channel`` is provided, its
    frequency response is overlaid on the PSD of s[n].

    Explicitly passed keyword arguments override the corresponding *opts*
    fields.  See :class:`PlotGridOptions` for grouped defaults.

    Parameters
    ----------
    n
        Sample indices, shape ``(N,)``.
    m
        Original message, shape ``(N,)``.
    s
        Transmitted (modulated) signal, shape ``(N,)``.
    r
        Received signal after the channel, shape ``(N,)``.
    m_hat
        Recovered message estimate, shape ``(N,)``.
    omega
        Normalised frequency axis (ω/π), plotted on a fixed ``[0, 1]`` range.
    psd_m, psd_s, psd_r, psd_mhat
        PSD of each signal (same length as *omega*).  The PSD y-range is
        fixed to ``[-0.05, 1.08]``, so peak-normalised PSDs are presumably
        expected.
    opts
        :class:`PlotGridOptions` with grouped defaults (time window, axis
        limits, title, channel, save path).
    time_window
        Index (normally a ``slice``) selecting the samples shown in the
        time-domain panels.  Defaults to ``opts.time_window`` if *opts* is
        given, else ``slice(0, 300)``.  An explicitly passed value always
        wins, even ``slice(0, 300)``.
    suptitle
        Figure-level title.  Overrides ``opts.suptitle`` unless empty.
    y_lim_msg, y_lim_sig, y_lim_mhat
        Y-axis limits for the time-domain panels.  *y_lim_mhat* defaults to
        *y_lim_msg* unless explicitly set to a tuple.
    h_channel
        FIR coefficients for overlaying the channel frequency response on
        the PSD of s[n].
    save_path
        If provided, the figure is saved to this path via ``_save()``.
    lang
        Language for the column titles (``"pt"`` or ``"en"``).

    Returns
    -------
    matplotlib.figure.Figure
    """
    cfg = _resolve_grid_options(
        opts, time_window, suptitle, y_lim_msg, y_lim_sig, y_lim_mhat,
        h_channel, save_path,
    )
    window = cfg.time_window

    fig, axes = plt.subplots(4, 2, figsize=(14, 12))
    if cfg.suptitle:
        fig.suptitle(cfg.suptitle, fontsize=14, y=0.995)

    nn = n[window]

    # Row configs: (signal, ylabel_t, ylabel_f, color_t, color_f, dots?, ylim_t)
    rows = [
        (m, r"$(a)\; m[n]$", r"$(e)\; \mathcal{M}(\omega)$",
         C["msg_t"], C["msg_f"], True, cfg.y_lim_msg),
        (s, r"$(b)\; s[n]$", r"$(f)\; S(\omega)$",
         C["sig_t"], C["sig_f"], False, cfg.y_lim_sig),
        (r, r"$(c)\; r[n]$", r"$(g)\; R(\omega)$",
         C["sig_t"], C["sig_f"], False, cfg.y_lim_sig),
        (m_hat, r"$(d)\; \hat{m}[n]$", r"$(h)\; \hat{\mathcal{M}}(\omega)$",
         C["msg_t"], C["msg_f"], True, cfg.y_lim_mhat),
    ]
    psds = [psd_m, psd_s, psd_r, psd_mhat]
    last_row = len(rows) - 1

    for i, ((sig, lbl_t, lbl_f, ct, cf, dots, ylim), psd) in enumerate(
        zip(rows, psds, strict=True)
    ):
        # ---- Time domain (left column) ----
        ax_t = axes[i, 0]
        if dots:
            ax_t.plot(nn, sig[window], ".", markersize=5, color=ct)
        else:
            ax_t.plot(nn, sig[window], color=ct, linewidth=1.2)
        ax_t.set_ylabel(lbl_t, fontsize=12)
        ax_t.set_xlim([nn[0], nn[-1]])
        ax_t.set_ylim(ylim)
        # Only the bottom row carries x tick labels and the axis label.
        if i < last_row:
            ax_t.set_xticklabels([])
        else:
            ax_t.set_xlabel(r"$n$", fontsize=12)
        _style(ax_t)

        # ---- PSD (right column) ----
        ax_f = axes[i, 1]
        ax_f.plot(omega, psd, color=cf, linewidth=1.2)
        ax_f.set_ylabel(lbl_f, fontsize=12)
        ax_f.set_ylim([-0.05, 1.08])
        ax_f.set_xlim([0, 1.0])
        if i < last_row:
            ax_f.set_xticklabels([])
        else:
            ax_f.set_xlabel(r"$\omega / \pi$", fontsize=12)
        _style(ax_f)

        # Overlay the channel magnitude response on the s[n] PSD panel.
        if i == 1 and cfg.h_channel is not None:
            # freqz with whole=False samples [0, π); dividing by π maps the
            # grid onto the same ω/π axis as the PSDs.
            w_h, H = freqz(cfg.h_channel, worN=1024, whole=False)
            ax_f.plot(
                w_h / np.pi,
                np.abs(H),
                "k--",
                linewidth=1.5,
                label=r"$|H_{ch}(e^{j\omega})|$",
                alpha=0.7,
            )
            ax_f.legend(fontsize=9, loc="upper right")

    # Column titles
    axes[0, 0].set_title(t("comm.time_domain", lang=lang), fontsize=12)
    axes[0, 1].set_title(t("comm.psd", lang=lang), fontsize=12)

    fig.tight_layout()
    _save(fig, cfg.save_path)
    return fig