# Source code for demregpy.plotting

"""
Plotting helpers for demregpy.
"""

import importlib

import numpy as np

# Names exported by ``from demregpy.plotting import *``; the ``_``-prefixed helpers stay private.
__all__ = ["plot_dem", "plot_loci_curves"]


def _resolve_axes(plt, *, ax=None, fig=None):
    if ax is not None:
        return ax
    if fig is None:
        _, ax = plt.subplots()
        return ax
    if len(fig.axes) == 0:
        return fig.add_subplot(111)
    if len(fig.axes) == 1:
        return fig.axes[0]
    raise ValueError("fig must have exactly one axes unless ax is provided")


def _logt_bin_widths(logt):
    logt = np.asarray(logt, dtype=float)
    if logt.ndim != 1:
        raise ValueError("logt must be one-dimensional")
    if logt.size < 2:
        raise ValueError("logt must contain at least two temperature points")
    if not np.all(np.diff(logt) > 0):
        raise ValueError("logt must be strictly increasing")

    edges = np.empty(logt.size + 1, dtype=float)
    edges[1:-1] = 0.5 * (logt[:-1] + logt[1:])
    edges[0] = logt[0] - 0.5 * (logt[1] - logt[0])
    edges[-1] = logt[-1] + 0.5 * (logt[-1] - logt[-2])
    return np.diff(edges)


def _default_loci_ylabel(dem_space):
    if dem_space:
        return r"Loci Curve [$\mathrm{cm}^{-5}\,\mathrm{K}^{-1}$]"
    return r"EM Loci Curve [$\mathrm{cm}^{-5}$]"


def _axis_positive_ydata(ax):
    values = []
    for line in ax.lines:
        ydata = np.asarray(line.get_ydata(), dtype=float)
        mask = np.isfinite(ydata) & (ydata > 0)
        if np.any(mask):
            values.append(ydata[mask])
    if not values:
        return np.array([], dtype=float)
    return np.concatenate(values)


def _loci_envelope(loci):
    valid = np.isfinite(loci) & (loci > 0)
    envelope = np.full(loci.shape[0], np.nan, dtype=float)
    if np.any(valid):
        envelope[valid.any(axis=1)] = np.min(np.where(valid, loci, np.inf), axis=1)[valid.any(axis=1)]
    return envelope


def _coerce_errorbar(err, reference, name, reference_name):
    """Validate one optional error-bar argument of ``plot_dem``.

    Accepts ``None`` (no error bars), a scalar, a symmetric array with the
    same shape as ``reference``, or an asymmetric ``(2, N)`` array of
    (lower, upper) errors — the forms ``Axes.errorbar`` accepts.
    """
    if err is None:
        return None
    err = np.asarray(err)
    if err.ndim == 0 or err.shape in (reference.shape, (2,) + reference.shape):
        return err
    raise ValueError(
        f"{name} must have the same shape as {reference_name}, "
        f"or shape (2, {reference.size}) for asymmetric errors"
    )


def plot_dem(
    logt,
    dem,
    *,
    elogt=None,
    edem=None,
    ax=None,
    fig=None,
    label=None,
    color=None,
    ecolor=None,
    fmt="o",
    capsize=0,
    elinewidth=2,
    xlabel=r"$\log_{10} T$",
    ylabel=r"DEM [$\mathrm{cm}^{-5}\,\mathrm{K}^{-1}$]",
    yscale="log",
    **kwargs,
):
    """
    Plot a one-dimensional DEM with optional horizontal and vertical error bars.

    Parameters
    ----------
    logt : array_like
        Temperature-bin centres in log10(T).
    dem : array_like
        DEM values for each temperature bin.
    elogt : float or array_like, optional
        Horizontal uncertainty in log10(T): a scalar, an array with the same
        shape as ``logt``, or a ``(2, N)`` array of (lower, upper) errors.
    edem : float or array_like, optional
        Vertical uncertainty on the DEM: a scalar, an array with the same
        shape as ``dem``, or a ``(2, N)`` array of (lower, upper) errors.
    ax : `matplotlib.axes.Axes`, optional
        Axes to draw on. If not given, ``fig`` is used, or a new figure and
        axes are created.
    fig : `matplotlib.figure.Figure`, optional
        Figure to draw on if ``ax`` is not given. If the figure already has
        an axes, it must have exactly one.
    label : str, optional
        Label for the plotted series.
    color : str, optional
        Matplotlib colour for markers and line.
    ecolor : str, optional
        Matplotlib colour for the error bars. Defaults to ``color`` if given.
    fmt : str, optional
        Errorbar marker and line format.
    capsize : float, optional
        Errorbar cap size.
    elinewidth : float, optional
        Errorbar line width.
    xlabel : str, optional
        X-axis label.
    ylabel : str, optional
        Y-axis label.
    yscale : str or None, optional
        Y-axis scale. Defaults to ``"log"``; ``None`` leaves it unchanged.
    **kwargs
        Additional keyword arguments passed to ``Axes.errorbar``.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        Axes used for the plot.
    container : `matplotlib.container.ErrorbarContainer`
        Matplotlib container returned by ``Axes.errorbar``.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If the inputs are not 1-D, have mismatched shapes, or ``fig`` holds
        more than one axes while ``ax`` is not given.
    """
    # Imported lazily so the rest of demregpy works without matplotlib.
    try:
        plt = importlib.import_module("matplotlib.pyplot")
    except ImportError as exc:
        raise ImportError("plot_dem requires matplotlib to be installed.") from exc

    logt = np.asarray(logt)
    dem = np.asarray(dem)
    if logt.ndim != 1 or dem.ndim != 1:
        raise ValueError("plot_dem expects one-dimensional logt and dem arrays")
    if logt.shape != dem.shape:
        raise ValueError("logt and dem must have the same shape")
    elogt = _coerce_errorbar(elogt, logt, "elogt", "logt")
    edem = _coerce_errorbar(edem, dem, "edem", "dem")

    ax = _resolve_axes(plt, ax=ax, fig=fig)
    # Error bars follow the series colour unless a separate one was requested.
    if ecolor is None and color is not None:
        ecolor = color

    container = ax.errorbar(
        logt,
        dem,
        xerr=elogt,
        yerr=edem,
        fmt=fmt,
        label=label,
        color=color,
        ecolor=ecolor,
        capsize=capsize,
        elinewidth=elinewidth,
        **kwargs,
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if yscale is not None:
        ax.set_yscale(yscale)
    return ax, container
def _compute_loci(logt, dn_in, tresp, dem_space):
    """Return the ``(nt, nf)`` loci ``dn_in / response``, invalid points as NaN.

    With ``dem_space`` the response is first multiplied per temperature bin
    by ``T * dlnT`` (``10**logt * ln(10**dlogt)``), the bin-aware scaling
    ``plot_loci_curves`` documents as matching ``dn2dem``.  Non-finite and
    non-positive loci (zero response, non-positive counts) become NaN.
    """
    response = tresp
    if dem_space:
        dlogt = _logt_bin_widths(logt)
        response = tresp * (10.0 ** logt * np.log(10.0 ** dlogt))[:, np.newaxis]
    with np.errstate(divide="ignore", invalid="ignore"):
        loci = dn_in[np.newaxis, :] / response
        loci[~np.isfinite(loci) | (loci <= 0)] = np.nan
    return loci


def _infer_ylim(envelope, existing_positive):
    """Return padded ``(ymin, ymax)`` limits, or ``None`` if none can be derived.

    The reference values are the positive finite points of the minimum-loci
    envelope plus any positive data already on the axes; the limits are 0.5x
    the smallest and 5x the largest of them.
    """
    envelope_positive = envelope[np.isfinite(envelope) & (envelope > 0)]
    reference = np.concatenate([existing_positive, envelope_positive])
    # If the envelope has no positive point then no loci value was positive
    # either, so there is nothing else to fall back on.
    if reference.size == 0:
        return None
    ymin = 0.5 * np.min(reference)
    ymax = 5.0 * np.max(reference)
    if np.isfinite(ymin) and np.isfinite(ymax) and 0 < ymin < ymax:
        return ymin, ymax
    return None


def plot_loci_curves(
    logt,
    dn_in,
    tresp,
    *,
    channels=None,
    ax=None,
    fig=None,
    dem_space=True,
    show_minimum=True,
    minimum_kwargs=None,
    xlabel=r"$\log_{10} T$",
    ylabel=None,
    yscale="log",
    ylim=None,
    **kwargs,
):
    """
    Plot filter loci curves for a single observation vector.

    By default the response matrix is scaled in the same bin-aware way used by
    ``dn2dem``, so the loci curves can be overplotted directly on a DEM axis.
    Set ``dem_space=False`` to plot the raw ``DN / R(T)`` EM loci curves instead.

    Parameters
    ----------
    logt : array_like
        Temperature-bin centres in log10(T).
    dn_in : array_like
        Input counts for one spectrum, with shape ``(nf,)``.
    tresp : array_like
        Temperature response matrix with shape ``(nt, nf)``.
    channels : sequence of str, optional
        Channel labels for the individual loci curves.
    ax : `matplotlib.axes.Axes`, optional
        Axes to draw on.
    fig : `matplotlib.figure.Figure`, optional
        Figure to draw on if ``ax`` is not given. If the figure already has
        an axes, it must have exactly one.
    dem_space : bool, optional
        Apply the same bin-width scaling used by ``dn2dem`` so the curves are
        in DEM-like units and can be overplotted directly on a DEM axis.
    show_minimum : bool, optional
        Plot the pointwise minimum loci envelope.
    minimum_kwargs : dict, optional
        Keyword arguments passed to the minimum-envelope line.
    xlabel : str, optional
        X-axis label.
    ylabel : str, optional
        Y-axis label. If omitted, a default is chosen from ``dem_space``.
    yscale : str or None, optional
        Y-axis scale. Defaults to ``"log"``; ``None`` leaves it unchanged.
    ylim : tuple[float, float], optional
        Explicit y limits. If omitted, robust limits are inferred from the
        loci envelope and any existing lines already on the axes.
    **kwargs
        Additional keyword arguments passed to ``Axes.plot`` for each loci curve.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        Axes used for the plot.
    lines : list[`matplotlib.lines.Line2D`]
        Line objects for the individual loci curves, followed by the minimum
        envelope if ``show_minimum=True``.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If the array shapes are inconsistent, ``channels`` has the wrong
        length, ``logt`` is unusable for ``dem_space`` scaling, or ``fig``
        holds more than one axes while ``ax`` is not given.
    """
    # Imported lazily so the rest of demregpy works without matplotlib.
    try:
        plt = importlib.import_module("matplotlib.pyplot")
    except ImportError as exc:
        raise ImportError("plot_loci_curves requires matplotlib to be installed.") from exc

    logt = np.asarray(logt, dtype=float)
    dn_in = np.asarray(dn_in, dtype=float)
    tresp = np.asarray(tresp, dtype=float)
    if logt.ndim != 1 or dn_in.ndim != 1:
        raise ValueError("plot_loci_curves expects one-dimensional logt and dn_in arrays")
    if tresp.ndim != 2:
        raise ValueError("tresp must have shape (nt, nf)")
    if tresp.shape[0] != logt.size:
        raise ValueError("The first axis of tresp must match the length of logt")
    if tresp.shape[1] != dn_in.size:
        raise ValueError("The second axis of tresp must match the length of dn_in")
    if channels is not None and len(channels) != dn_in.size:
        raise ValueError("channels must have the same length as dn_in")

    # Compute (and thereby fully validate) the loci before touching any
    # figure, so a bad logt does not leave an empty figure behind.
    loci = _compute_loci(logt, dn_in, tresp, dem_space)
    envelope = _loci_envelope(loci)

    ax = _resolve_axes(plt, ax=ax, fig=fig)
    # Snapshot data already on the axes (e.g. an overplotted DEM) *before*
    # adding the loci, so inferred limits keep it in view.
    existing_positive = _axis_positive_ydata(ax) if ylim is None else None

    labels = channels if channels is not None else [None] * dn_in.size
    lines = []
    for idx, label in enumerate(labels):
        (line,) = ax.plot(logt, loci[:, idx], label=label, **kwargs)
        lines.append(line)

    if show_minimum:
        minimum_style = {
            "color": "k",
            "linestyle": "--",
            "linewidth": 2,
            "label": "Minimum loci",
        }
        if minimum_kwargs is not None:
            minimum_style.update(minimum_kwargs)
        (line,) = ax.plot(logt, envelope, **minimum_style)
        lines.append(line)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(_default_loci_ylabel(dem_space) if ylabel is None else ylabel)
    if yscale is not None:
        ax.set_yscale(yscale)
    if ylim is not None:
        ax.set_ylim(*ylim)
    else:
        limits = _infer_ylim(envelope, existing_positive)
        if limits is not None:
            ax.set_ylim(*limits)
    return ax, lines