# Source code for chaotic_pfc.plotting.figures

"""
figures.py
==========
Publication-quality SVG figures with LaTeX-style labels.

All text uses matplotlib's mathtext engine (no external LaTeX needed).
Figures are saved as .svg by default for vector-quality output.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from numpy.typing import NDArray
from scipy.signal import freqz

from chaotic_pfc._i18n import t

# ── Global RC params for LaTeX-like rendering ───────────────────────────────


def setup_rc() -> None:
    """Apply the module's publication-style settings to matplotlib's rcParams.

    Text is rendered by matplotlib's built-in mathtext engine with the STIX
    font set (visually close to Computer Modern), so no external LaTeX
    installation is needed. SVG output stores text as vector paths, so files
    render identically on machines that lack the fonts.

    This updates the global ``plt.rcParams`` and therefore affects every
    figure created afterwards in the same process.
    """
    # Typography: mathtext + STIX instead of a real TeX run.
    typography = {
        "text.usetex": False,
        "mathtext.fontset": "stix",  # STIX ≈ Computer Modern
        "font.family": "STIXGeneral",  # text font matching the math font
        "svg.fonttype": "path",  # glyphs exported as paths (portable SVG)
        "axes.unicode_minus": False,
        "axes.formatter.use_mathtext": True,
    }
    # Font sizes for body text, axis labels, titles, ticks and legends.
    sizes = {
        "font.size": 12,
        "axes.labelsize": 14,
        "axes.titlesize": 14,
        "xtick.labelsize": 11,
        "ytick.labelsize": 11,
        "legend.fontsize": 11,
    }
    # On-screen / saved resolution and the tight bounding box used on save.
    output = {
        "figure.dpi": 150,
        "savefig.dpi": 150,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.05,
    }
    # Stroke widths for axes frame, major ticks and plotted lines.
    strokes = {
        "axes.linewidth": 1.2,
        "xtick.major.width": 1.0,
        "ytick.major.width": 1.0,
        "lines.linewidth": 1.5,
    }
    # Merged in the same key order as a single literal would have.
    plt.rcParams.update({**typography, **sizes, **output, **strokes})
# ── Colour palette ──────────────────────────────────────────────────────────
# Colours shared by the plotting functions. Keys ending in "_t" are used for
# time-domain traces, keys ending in "_f" for frequency-domain traces.
C = dict(
    msg_t="#0073BD",  # blue (time-domain message)
    msg_f="#804000",  # brown (freq-domain message)
    sig_t="#E00000",  # red (time-domain signal)
    sig_f="#660066",  # purple (freq-domain signal)
    traj="#000000",  # black (attractor)
    traj2="#E00000",  # red (second trajectory)
)


def _style(ax: Axes, ts: int = 12) -> None:
    """Give ``ax`` the shared look used by every figure in this module.

    Draws a faint grid, points ticks inward and sets all spines to width 1.2.

    Parameters
    ----------
    ax
        Axes modified in place.
    ts
        Tick-label font size.
    """
    ax.grid(True, alpha=0.25, linewidth=0.5)
    ax.tick_params(labelsize=ts, width=1.2, direction="in")
    for spine in ax.spines.values():
        spine.set_linewidth(1.2)


def _save(fig: Figure, path: str | Path | None, **savefig_kwargs) -> None:
    """Save ``fig`` to ``path``; do nothing when ``path`` is ``None``.

    Missing parent directories are created first. Any extra keyword
    arguments are passed unchanged to ``fig.savefig()``.
    """
    if path is None:
        return
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, **savefig_kwargs)


# ── 1. Attractor ────────────────────────────────────────────────────────────
def plot_attractor(
    X: NDArray,
    Y: NDArray,
    title: str = "",
    xlabel: str = r"$x_1[n]$",
    ylabel: str = r"$x_2[n]$",
    save_path: str | None = None,
) -> Figure:
    """Draw the phase-space portrait (``X`` against ``Y``) of a 2-D trajectory.

    Parameters
    ----------
    X, Y
        State-variable trajectories of equal length.
    title
        Figure title; an empty string draws no title.
    xlabel, ylabel
        Axis labels (mathtext). The defaults read ``x_1[n]`` and ``x_2[n]``.
    save_path
        Optional output path. Its extension picks the format (``.svg``,
        ``.png``, ...).

    Returns
    -------
    Figure
        The created figure, left open so callers can annotate, show or
        close it.
    """
    label_size = 12
    # "," is matplotlib's pixel marker: one dot per state, no connecting line.
    point_style = dict(color=C["traj"], alpha=0.8, markersize=0.3)

    figure, axis = plt.subplots(figsize=(6, 5))
    axis.plot(X, Y, ",", **point_style)

    if title:
        axis.set_title(title, fontsize=13)
    axis.set_xlabel(xlabel, fontsize=label_size)
    axis.set_ylabel(ylabel, fontsize=label_size)

    _style(axis)
    figure.tight_layout()
    _save(figure, save_path)
    return figure
# ── 2. Sensitivity (SDIC) ──────────────────────────────────────────────────
def plot_sensitivity(
    n: NDArray,
    X1: NDArray,
    X2: NDArray,
    save_path: str | None = None,
    lang: str = "pt",
    *,
    labels: tuple[str, str] | None = None,
) -> Figure:
    r"""Overlay two Hénon trajectories to illustrate sensitivity to ICs.

    Plots two state trajectories that start from slightly different initial
    conditions, making visible how they diverge over time — the classic
    demonstration of sensitive dependence on initial conditions.

    Parameters
    ----------
    n
        Sample index axis, shape ``(N,)``.
    X1, X2
        Two state trajectories evaluated on ``n``. Typically differing only
        by ``x0_2 = x0_1 + 1e-4``.
    save_path
        If given, the figure is written to this path.
    lang
        Language code for the figure title (``"pt"`` or ``"en"``).
    labels
        Legend entries for ``X1`` and ``X2`` (mathtext allowed). When omitted,
        the legend reads ``x_0=0`` and ``x_0=10^{-4}``. Those values are
        hard-coded and say nothing about the data actually passed in, so
        supply your own labels whenever the trajectories start elsewhere.

    Returns
    -------
    Figure
        The matplotlib ``Figure`` object.

    Raises
    ------
    ValueError
        If ``labels`` does not contain exactly two entries.
    """
    # Validate before creating the figure so a bad call leaves no open figure.
    if labels is None:
        labels = (r"$x^{(1)}[n],\; x_0=0$", r"$x^{(2)}[n],\; x_0=10^{-4}$")
    elif len(labels) != 2:
        raise ValueError(f"labels must contain exactly two entries, got {len(labels)}")
    label1, label2 = labels

    fig, ax = plt.subplots(figsize=(9, 4))
    ax.plot(n, X1, label=label1, color="steelblue", linewidth=1.2)
    ax.plot(n, X2, label=label2, color="tomato", linewidth=1.2, alpha=0.85)
    ax.set_xlabel(r"$n$", fontsize=12)
    ax.set_ylabel(r"$x[n]$", fontsize=12)
    ax.set_title(
        t("sensitivity.title", lang=lang),
        fontsize=13,
    )
    ax.legend(fontsize=11, framealpha=0.9)
    _style(ax)
    fig.tight_layout()
    _save(fig, save_path)
    return fig
# ── 3. Communication 4×2 grid (time + PSD) ─────────────────────────────────
@dataclass
class PlotGridOptions:
    """Bundle of optional styling settings for :func:`plot_comm_grid`.

    Every field has a default, so build an instance with only the settings
    you want to change and pass it as ``plot_comm_grid(..., opts=...)``.
    """

    # Samples shown in the time-domain column. Built by a factory: slice
    # objects are unhashable before Python 3.12, and dataclasses (3.11+)
    # refuse unhashable plain defaults.
    time_window: slice = field(default_factory=lambda: slice(0, 300))
    # Figure-level title; an empty string means no title.
    suptitle: str = field(default="")
    # y-limits of the message row.
    y_lim_msg: tuple[float, float] = field(default=(-1.5, 1.5))
    # y-limits of the transmitted / received signal rows.
    y_lim_sig: tuple[float, float] = field(default=(-2.5, 2.5))
    # y-limits of the recovered-message row; None falls back to y_lim_msg.
    y_lim_mhat: tuple[float, float] | None = field(default=None)
    # Channel coefficients whose frequency response is overlaid, or None.
    h_channel: NDArray | None = field(default=None)
    # Output path for the figure; None means the figure is not saved.
    save_path: str | None = field(default=None)
def _limits_equal(value: object, default: tuple[float, float]) -> bool:
    """Return True if the y-limit argument ``value`` equals ``default``.

    ``value`` is converted to a tuple first so lists and NumPy arrays compare
    as one bool instead of raising "truth value ... is ambiguous".
    """
    try:
        return tuple(value) == default
    except TypeError:  # not iterable (e.g. None)
        return False


def _resolve_grid_options(
    opts: PlotGridOptions | None,
    time_window: slice | None,
    suptitle: str,
    y_lim_msg: tuple,
    y_lim_sig: tuple,
    y_lim_mhat: tuple | None,
    h_channel: NDArray | None,
    save_path: str | None,
) -> tuple:
    """Merge ``opts`` into the keyword arguments of :func:`plot_comm_grid`.

    A keyword still at its signature default (or ``None`` / empty) is taken
    from ``opts``; any other value passed explicitly wins. Afterwards
    ``time_window=None`` becomes "all samples" and a missing ``y_lim_mhat``
    falls back to ``y_lim_msg``.
    """
    if opts is not None:
        if time_window is None or (
            isinstance(time_window, slice) and time_window == slice(0, 300)
        ):
            time_window = opts.time_window
        if not suptitle:
            suptitle = opts.suptitle
        if _limits_equal(y_lim_msg, (-1.5, 1.5)):
            y_lim_msg = opts.y_lim_msg
        if _limits_equal(y_lim_sig, (-2.5, 2.5)):
            y_lim_sig = opts.y_lim_sig
        if y_lim_mhat is None:
            y_lim_mhat = opts.y_lim_mhat
        if h_channel is None:
            h_channel = opts.h_channel
        if save_path is None:
            save_path = opts.save_path

    if time_window is None:
        # n[None] would add an axis; treat None as "show every sample".
        time_window = slice(None)
    if y_lim_mhat is None:
        y_lim_mhat = y_lim_msg
    return time_window, suptitle, y_lim_msg, y_lim_sig, y_lim_mhat, h_channel, save_path


def plot_comm_grid(
    n: NDArray,
    m: NDArray,
    s: NDArray,
    r: NDArray,
    m_hat: NDArray,
    omega: NDArray,
    psd_m: NDArray,
    psd_s: NDArray,
    psd_r: NDArray,
    psd_mhat: NDArray,
    *,
    opts: PlotGridOptions | None = None,
    time_window: slice = slice(0, 300),
    suptitle: str = "",
    y_lim_msg: tuple = (-1.5, 1.5),
    y_lim_sig: tuple = (-2.5, 2.5),
    y_lim_mhat: tuple | None = None,
    h_channel: NDArray | None = None,
    save_path: str | None = None,
    lang: str = "pt",
) -> Figure:
    r"""Draw the 4×2 communication-chain figure: time domain left, PSD right.

    Rows, top to bottom: message ``m[n]``, transmitted ``s[n]``, received
    ``r[n]`` and recovered ``m_hat[n]``. Message rows are drawn as dots and
    signal rows as lines. Only the bottom row carries x tick labels.

    Parameters
    ----------
    n
        Sample-index axis shared by the four time-domain signals.
    m, s, r, m_hat
        Time-domain signals, indexable like ``n``.
    omega
        Frequency axis of the PSDs. The right column is clipped to ``[0, 1]``
        and labelled omega/pi, so ``omega`` is presumably already normalised
        by pi (TODO confirm with callers).
    psd_m, psd_s, psd_r, psd_mhat
        PSDs evaluated on ``omega``. The y-range is fixed to
        ``[-0.05, 1.08]``, which assumes peak-normalised PSDs (TODO confirm).
    opts
        Optional :class:`PlotGridOptions`. For each keyword below, a value
        left at its default (``None``, ``""``, or the default tuple/slice) is
        taken from ``opts``. Any other value passed explicitly takes
        precedence over ``opts``. Passing a default value explicitly cannot
        be told apart from omitting it.
    time_window
        Samples shown in the left column; ``None`` shows every sample.
    suptitle
        Figure title; an empty string draws none.
    y_lim_msg
        y-limits of the message row.
    y_lim_sig
        y-limits of the ``s`` and ``r`` rows.
    y_lim_mhat
        y-limits of the recovered-message row; defaults to ``y_lim_msg``.
    h_channel
        If given, passed to ``scipy.signal.freqz``. Its magnitude response is
        overlaid (dashed) on the S(omega) panel.
    save_path
        If given, the figure is written to this path.
    lang
        Language code for the column titles.

    Returns
    -------
    Figure
        The created figure.

    Raises
    ------
    ValueError
        If ``time_window`` selects no samples of ``n``.
    """
    (
        time_window,
        suptitle,
        y_lim_msg,
        y_lim_sig,
        y_lim_mhat,
        h_channel,
        save_path,
    ) = _resolve_grid_options(
        opts, time_window, suptitle, y_lim_msg, y_lim_sig, y_lim_mhat, h_channel, save_path
    )

    nn = n[time_window]
    # Check before any figure exists. Failing later (at nn[0]) would leave an
    # open, half-drawn figure registered with pyplot.
    if np.size(nn) == 0:
        raise ValueError(
            f"time_window={time_window!r} selects no samples from n (len {len(n)})"
        )

    fig, axes = plt.subplots(4, 2, figsize=(14, 12))
    if suptitle:
        fig.suptitle(suptitle, fontsize=14, y=0.995)

    # Per-row config: (signal, psd, ylabel_t, ylabel_f, colour_t, colour_f,
    #                  draw time trace as dots?, time-domain y-limits)
    rows = [
        (
            m,
            psd_m,
            r"$(a)\; m[n]$",
            r"$(e)\; \mathcal{M}(\omega)$",
            C["msg_t"],
            C["msg_f"],
            True,
            y_lim_msg,
        ),
        (s, psd_s, r"$(b)\; s[n]$", r"$(f)\; S(\omega)$", C["sig_t"], C["sig_f"], False, y_lim_sig),
        (r, psd_r, r"$(c)\; r[n]$", r"$(g)\; R(\omega)$", C["sig_t"], C["sig_f"], False, y_lim_sig),
        (
            m_hat,
            psd_mhat,
            r"$(d)\; \hat{m}[n]$",
            r"$(h)\; \hat{\mathcal{M}}(\omega)$",
            C["msg_t"],
            C["msg_f"],
            True,
            y_lim_mhat,
        ),
    ]
    last_row = len(rows) - 1

    for i, (sig, psd, lbl_t, lbl_f, ct, cf, dots, ylim) in enumerate(rows):
        # ---- Time domain (left column) ----
        ax_t = axes[i, 0]
        if dots:
            ax_t.plot(nn, sig[time_window], ".", markersize=5, color=ct)
        else:
            ax_t.plot(nn, sig[time_window], color=ct, linewidth=1.2)
        ax_t.set_ylabel(lbl_t, fontsize=12)
        ax_t.set_xlim([nn[0], nn[-1]])
        ax_t.set_ylim(ylim)
        if i < last_row:
            ax_t.set_xticklabels([])
        else:
            ax_t.set_xlabel(r"$n$", fontsize=12)
        _style(ax_t)

        # ---- PSD (right column) ----
        ax_f = axes[i, 1]
        ax_f.plot(omega, psd, color=cf, linewidth=1.2)
        ax_f.set_ylabel(lbl_f, fontsize=12)
        ax_f.set_ylim([-0.05, 1.08])
        ax_f.set_xlim([0, 1.0])
        if i < last_row:
            ax_f.set_xticklabels([])
        else:
            ax_f.set_xlabel(r"$\omega / \pi$", fontsize=12)
        _style(ax_f)

        # Overlay the channel's magnitude response on the S(omega) panel.
        if i == 1 and h_channel is not None:
            w_h, H = freqz(h_channel, worN=1024, whole=False)
            ax_f.plot(
                w_h / np.pi,  # rad/sample -> multiples of pi, matching the x-axis
                np.abs(H),
                "k--",
                linewidth=1.5,
                label=r"$|H_{ch}(e^{j\omega})|$",
                alpha=0.7,
            )
            ax_f.legend(fontsize=9, loc="upper right")

    # Column titles
    axes[0, 0].set_title(t("comm.time_domain", lang=lang), fontsize=12)
    axes[0, 1].set_title(t("comm.psd", lang=lang), fontsize=12)

    fig.tight_layout()
    _save(fig, save_path)
    return fig