Source code for redback.plotting

from __future__ import annotations

from os.path import join
from typing import Any, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import redback
from redback.utils import KwargsAccessorWithDefault

class _FilenameGetter(object):
    def __init__(self, suffix: str) -> None:
        self.suffix = suffix

    def __get__(self, instance: Plotter, owner: object) -> str:
        return instance.get_filename(default=f"{instance.transient.name}_{self.suffix}.png")

    def __set__(self, instance: Plotter, value: object) -> None:
        pass


class _FilePathGetter(object):

    def __init__(self, directory_property: str, filename_property: str) -> None:
        self.directory_property = directory_property
        self.filename_property = filename_property

    def __get__(self, instance: Plotter, owner: object) -> str:
        return join(getattr(instance, self.directory_property), getattr(instance, self.filename_property))


[docs]class Plotter(object): capsize = KwargsAccessorWithDefault("capsize", 0.) legend_location = KwargsAccessorWithDefault("legend_location", "best") legend_cols = KwargsAccessorWithDefault("legend_cols", 2) band_colors = KwargsAccessorWithDefault("band_colors", None) color = KwargsAccessorWithDefault("color", "k") band_labels = KwargsAccessorWithDefault("band_labels", None) band_scaling = KwargsAccessorWithDefault("band_scaling", {}) dpi = KwargsAccessorWithDefault("dpi", 300) elinewidth = KwargsAccessorWithDefault("elinewidth", 2) errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x") model = KwargsAccessorWithDefault("model", None) ms = KwargsAccessorWithDefault("ms", 1) axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10) max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65) random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05) uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4) max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue") random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red") bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight") linewidth = KwargsAccessorWithDefault("linewidth", 2) zorder = KwargsAccessorWithDefault("zorder", -1) xy = KwargsAccessorWithDefault("xy", (0.95, 0.9)) xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction") horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right") annotation_size = KwargsAccessorWithDefault("annotation_size", 20) fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18) fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30) fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18) fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16) hspace = KwargsAccessorWithDefault("hspace", 0.04) wspace = KwargsAccessorWithDefault("wspace", 0.15) plot_others = KwargsAccessorWithDefault("plot_others", True) random_models = KwargsAccessorWithDefault("random_models", 100) uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models") credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9) plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True) set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True) xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 2.0) xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.5) ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 2.0) ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)
[docs] def __init__(self, transient: Union[redback.transient.Transient, None], **kwargs) -> None: """ :param transient: An instance of `redback.transient.Transient`. Contains the data to be plotted. :param kwargs: Additional kwargs the plotter uses. ------- :keyword capsize: Same as matplotlib capsize. :keyword bands_to_plot: List of bands to plot in plot lightcurve and multiband lightcurve. Default is active bands. :keyword legend_location: Same as matplotlib legend location. :keyword legend_cols: Same as matplotlib legend columns. :keyword color: Color of the data points. :keyword band_colors: A dictionary with the colors of the bands. :keyword band_labels: List with the names of the bands. :keyword band_scaling: Dict with the scaling for each band. First entry should be {type: '+' or 'x'} for different types. :keyword dpi: Same as matplotlib dpi. :keyword elinewidth: same as matplotlib elinewidth :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`. :keyword model: str or callable, the model to plot. :keyword ms: Same as matplotlib markersize. :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes. :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve. :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves. :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band. :keyword max_likelihood_color: Color of the maximum likelihood curve. :keyword random_sample_color: Color of the random sample curves. :keyword bbox_inches: Setting for saving plots. Default is 'tight'. :keyword linewidth: Same as matplotlib linewidth :keyword zorder: Same as matplotlib zorder :keyword xy: For `ax.annotate' x and y coordinates of the point to annotate. :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction' :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right' :keyword annotation_size: `size` argument of of `ax.annotate`. :keyword fontsize_axes: Font size of the x and y labels. :keyword fontsize_legend: Font size of the legend. :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots. Used on `supxlabel` and `supylabel`. :keyword fontsize_ticks: Font size of the axis ticks. :keyword hspace: Argument for `subplots_adjust`, sets horizontal spacing between panels. :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels. :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors :keyword random_models: Number of random draws to use to calculate credible bands or to plot. :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets. 'credible_intervals': Plot a credible interval that is calculated based on the available parameter sets. :keyword reference_mjd_date: Date to use as reference point for the x axis. Default is the first date in the data. :keyword credible_interval_level: 0.9: Plot the 90% credible interval. :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'. :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'. :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values. :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values. :keyword ylim_high_multiplier: Adjust the maximum ylim based on available x values. :keyword ylim_low_multiplier: Adjust the minimum ylim based on available x values. """ self.transient = transient self.kwargs = kwargs or dict() self._posterior_sorted = False
keyword_docstring = __init__.__doc__.split("-------")[1] def _get_times(self, axes: matplotlib.axes.Axes) -> np.ndarray: """ :param axes: The axes used in the plotting procedure. :type axes: matplotlib.axes.Axes :return: Linearly or logarithmically scaled time values depending on the y scale used in the plot. :rtype: np.ndarray """ if isinstance(axes, np.ndarray): ax = axes[0] else: ax = axes if ax.get_yscale() == 'linear': times = np.linspace(self._xlim_low, self._xlim_high, 200) else: times = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200)) if self.transient.use_phase_model: times = times + self._reference_mjd_date return times @property def _xlim_low(self) -> float: default = self.xlim_low_multiplier * self.transient.x[0] if default == 0: default += 1e-3 return self.kwargs.get("xlim_low", default) @property def _xlim_high(self) -> float: if self._x_err is None: default = self.xlim_high_multiplier * self.transient.x[-1] else: default = self.xlim_high_multiplier * (self.transient.x[-1] + self._x_err[1][-1]) return self.kwargs.get("xlim_high", default) @property def _ylim_low(self) -> float: default = self.ylim_low_multiplier * min(self.transient.y) return self.kwargs.get("ylim_low", default) @property def _ylim_high(self) -> float: default = self.ylim_high_multiplier * np.max(self.transient.y) return self.kwargs.get("ylim_high", default) @property def _x_err(self) -> Union[np.ndarray, None]: if self.transient.x_err is not None: return np.array([np.abs(self.transient.x_err[1, :]), self.transient.x_err[0, :]]) else: return None @property def _y_err(self) -> np.ndarray: if self.transient.y_err.ndim > 1.: return np.array([np.abs(self.transient.y_err[1, :]), self.transient.y_err[0, :]]) else: return np.array([np.abs(self.transient.y_err)]) @property def _lightcurve_plot_outdir(self) -> str: return self._get_outdir(join(self.transient.directory_structure.directory_path, self.model.__name__)) @property def _data_plot_outdir(self) -> str: return self._get_outdir(self.transient.directory_structure.directory_path) def _get_outdir(self, default: str) -> str: return self._get_kwarg_with_default(kwarg="outdir", default=default) def get_filename(self, default: str) -> str: return self._get_kwarg_with_default(kwarg="filename", default=default) def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any: return self.kwargs.get(kwarg, default) or default @property def _model_kwargs(self) -> dict: return self._get_kwarg_with_default("model_kwargs", dict()) @property def _posterior(self) -> pd.DataFrame: posterior = self.kwargs.get("posterior", pd.DataFrame()) if not self._posterior_sorted and posterior is not None: posterior.sort_values(by='log_likelihood', inplace=True) self._posterior_sorted = True return posterior @property def _max_like_params(self) -> pd.core.series.Series: return self._posterior.iloc[-1] def _get_random_parameters(self) -> list[pd.core.series.Series]: integers = np.arange(len(self._posterior)) indices = np.random.choice(integers, size=self.random_models) return [self._posterior.iloc[idx] for idx in indices] _data_plot_filename = _FilenameGetter(suffix="data") _lightcurve_plot_filename = _FilenameGetter(suffix="lightcurve") _residual_plot_filename = _FilenameGetter(suffix="residual") _multiband_data_plot_filename = _FilenameGetter(suffix="multiband_data") _multiband_lightcurve_plot_filename = _FilenameGetter(suffix="multiband_lightcurve") _data_plot_filepath = _FilePathGetter( directory_property="_data_plot_outdir", filename_property="_data_plot_filename") _lightcurve_plot_filepath = _FilePathGetter( directory_property="_lightcurve_plot_outdir", filename_property="_lightcurve_plot_filename") _residual_plot_filepath = _FilePathGetter( directory_property="_lightcurve_plot_outdir", filename_property="_residual_plot_filename") _multiband_data_plot_filepath = _FilePathGetter( directory_property="_data_plot_outdir", filename_property="_multiband_data_plot_filename") _multiband_lightcurve_plot_filepath = _FilePathGetter( directory_property="_lightcurve_plot_outdir", filename_property="_multiband_lightcurve_plot_filename") def _save_and_show(self, filepath: str, save: bool, show: bool) -> None: plt.tight_layout() if save: plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches, transparent=False, facecolor='white') if show: plt.show()
[docs]class IntegratedFluxPlotter(Plotter): @property def _xlabel(self) -> str: return r"Time since burst [s]" @property def _ylabel(self) -> str: return self.transient.ylabel
[docs] def plot_data( self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes: """Plots the Integrated flux data and returns Axes. :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot. :type axes: Union[matplotlib.axes.Axes, None], optional :param save: Whether to save the plot. (Default value = True) :type save: bool :param show: Whether to show the plot. (Default value = True) :type show: bool :return: The axes with the plot. :rtype: matplotlib.axes.Axes """ ax = axes or plt.gca() ax.errorbar(self.transient.x, self.transient.y, xerr=self._x_err, yerr=self._y_err, fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize) ax.set_xscale('log') ax.set_yscale('log') ax.set_xlim(self._xlim_low, self._xlim_high) ax.set_ylim(self._ylim_low, self._ylim_high) ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes) ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes) ax.annotate( self.transient.name, xy=self.xy, xycoords=self.xycoords, horizontalalignment=self.horizontalalignment, size=self.annotation_size) ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks) self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show) return ax
[docs] def plot_lightcurve( self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes: """Plots the Integrated flux data and the lightcurve and returns Axes. :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot. :type axes: Union[matplotlib.axes.Axes, None], optional :param save: Whether to save the plot. (Default value = True) :type save: bool :param show: Whether to show the plot. (Default value = True) :type show: bool :return: The axes with the plot. :rtype: matplotlib.axes.Axes """ axes = axes or plt.gca() axes = self.plot_data(axes=axes, save=False, show=False) times = self._get_times(axes) self._plot_lightcurves(axes, times) self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show) return axes
def _plot_lightcurves(self, axes: matplotlib.axes.Axes, times: np.ndarray) -> None: if self.plot_max_likelihood: ys = self.model(times, **self._max_like_params, **self._model_kwargs) axes.plot(times, ys, color=self.max_likelihood_color, alpha=self.max_likelihood_alpha, lw=self.linewidth) random_ys_list = [self.model(times, **random_params, **self._model_kwargs) for random_params in self._get_random_parameters()] if self.uncertainty_mode == "random_models": for ys in random_ys_list: axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder) elif self.uncertainty_mode == "credible_intervals": lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list, interval=self.credible_interval_level) axes.fill_between( times, lower_bound, upper_bound, alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color) def _plot_single_lightcurve(self, axes: matplotlib.axes.Axes, times: np.ndarray, params: dict) -> None: ys = self.model(times, **params, **self._model_kwargs) axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder)
[docs] def plot_residuals( self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes: """Plots the residual of the Integrated flux data returns Axes. :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot. :param save: Whether to save the plot. (Default value = True) :param show: Whether to show the plot. (Default value = True) :return: The axes with the plot. :rtype: matplotlib.axes.Axes """ if axes is None: fig, axes = plt.subplots( nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1])) axes[0] = self.plot_lightcurve(axes=axes[0], save=False, show=False) axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes) axes[0].set_xlabel("") ys = self.model(self.transient.x, **self._max_like_params, **self._model_kwargs) axes[1].errorbar( self.transient.x, self.transient.y - ys, xerr=self._x_err, yerr=self._y_err, fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize) axes[1].set_yscale("log") axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes) axes[1].tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks) self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show) return axes
[docs]class LuminosityPlotter(IntegratedFluxPlotter): pass
[docs]class MagnitudePlotter(Plotter): xlim_low_phase_model_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9) xlim_high_phase_model_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.1) xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.2) ylim_low_magnitude_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.8) ylim_high_magnitude_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.2) ncols = KwargsAccessorWithDefault("ncols", 2) @property def _colors(self) -> str: return self.kwargs.get("colors", self.transient.get_colors(self._filters)) @property def _xlabel(self) -> str: if self.transient.use_phase_model: default = f"Time since {self._reference_mjd_date} MJD [days]" else: default = self.transient.xlabel return self.kwargs.get("xlabel", default) @property def _ylabel(self) -> str: return self.kwargs.get("ylabel", self.transient.ylabel) @property def _get_bands_to_plot(self) -> list[str]: return self.kwargs.get("bands_to_plot", self.transient.active_bands) @property def _xlim_low(self) -> float: if self.transient.use_phase_model: default = (self.transient.x[0] - self._reference_mjd_date) * self.xlim_low_phase_model_multiplier else: default = self.xlim_low_multiplier * self.transient.x[0] if default == 0: default += 1e-3 return self.kwargs.get("xlim_low", default) @property def _xlim_high(self) -> float: if self.transient.use_phase_model: default = (self.transient.x[-1] - self._reference_mjd_date) * self.xlim_high_phase_model_multiplier else: default = self.xlim_high_multiplier * self.transient.x[-1] return self.kwargs.get("xlim_high", default) @property def _ylim_low_magnitude(self) -> float: return self.ylim_low_magnitude_multiplier * min(self.transient.y) @property def _ylim_high_magnitude(self) -> float: return self.ylim_high_magnitude_multiplier * np.max(self.transient.y) def _get_ylim_low_with_indices(self, indices: list) -> float: return self.ylim_low_multiplier * min(self.transient.y[indices]) def _get_ylim_high_with_indices(self, indices: list) -> float: return self.ylim_high_multiplier * np.max(self.transient.y[indices]) def _get_x_err(self, indices: list) -> np.ndarray: return self.transient.x_err[indices] if self.transient.x_err is not None else self.transient.x_err def _set_y_axis_data(self, ax: matplotlib.axes.Axes) -> None: if self.transient.magnitude_data: ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude) ax.invert_yaxis() else: ax.set_ylim(self._ylim_low, self._ylim_high) ax.set_yscale("log") def _set_y_axis_multiband_data(self, ax: matplotlib.axes.Axes, indices: list) -> None: if self.transient.magnitude_data: ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude) ax.invert_yaxis() else: ax.set_ylim(self._get_ylim_low_with_indices(indices=indices), self._get_ylim_high_with_indices(indices=indices)) ax.set_yscale("log") def _set_x_axis(self, axes: matplotlib.axes.Axes) -> None: if self.transient.use_phase_model: axes.set_xscale("log") axes.set_xlim(self._xlim_low, self._xlim_high) @property def _nrows(self) -> int: default = int(np.ceil(len(self._filters) / 2)) return self._get_kwarg_with_default("nrows", default=default) @property def _npanels(self) -> int: npanels = self._nrows * self.ncols if npanels < len(self._filters): raise ValueError(f"Insufficient number of panels. {npanels} panels were given " f"but {len(self._filters)} panels are needed.") return npanels @property def _figsize(self) -> tuple: default = (4 + 4 * self.ncols, 2 + 2 * self._nrows) return self._get_kwarg_with_default("figsize", default=default) @property def _reference_mjd_date(self) -> int: if self.transient.use_phase_model: return self.kwargs.get("reference_mjd_date", int(self.transient.x[0])) return 0 @property def band_label_generator(self): if self.band_labels is not None: return (bl for bl in self.band_labels)
[docs] def plot_data( self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes: """Plots the Magnitude data and returns Axes. :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot. :type axes: Union[matplotlib.axes.Axes, None], optional :param save: Whether to save the plot. (Default value = True) :type save: bool :param show: Whether to show the plot. (Default value = True) :type show: bool :return: The axes with the plot. :rtype: matplotlib.axes.Axes """ ax = axes or plt.gca() band_label_generator = self.band_label_generator for indices, band in zip(self.transient.list_of_band_indices, self.transient.unique_bands): if band in self._filters: color = self._colors[list(self._filters).index(band)] if band_label_generator is None: if band in self.band_scaling: label = str(self.band_scaling.get(band)) + ' ' + self.band_scaling.get("type") + ' ' + band else: label = band else: label = next(band_label_generator) elif self.plot_others: color = "black" label = None else: continue if isinstance(label, float): label = f"{label:.2e}" if self.band_colors is not None: color = self.band_colors[band] if band in self.band_scaling: if self.band_scaling.get("type") == 'x': ax.errorbar( self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices] * self.band_scaling.get(band), xerr=self._get_x_err(indices), yerr=self.transient.y_err[indices] * self.band_scaling.get(band), fmt=self.errorbar_fmt, ms=self.ms, color=color, elinewidth=self.elinewidth, capsize=self.capsize, label=label) elif self.band_scaling.get("type") == '+': ax.errorbar( self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices] + self.band_scaling.get(band), xerr=self._get_x_err(indices), yerr=self.transient.y_err[indices], fmt=self.errorbar_fmt, ms=self.ms, color=color, elinewidth=self.elinewidth, capsize=self.capsize, label=label) else: ax.errorbar( self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices], xerr=self._get_x_err(indices), yerr=self.transient.y_err[indices], fmt=self.errorbar_fmt, ms=self.ms, color=color, elinewidth=self.elinewidth, capsize=self.capsize, label=label) self._set_x_axis(axes=ax) self._set_y_axis_data(ax) ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes) ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes) ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks) ax.legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend) self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show) return ax
[docs] def plot_lightcurve( self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True)\ -> matplotlib.axes.Axes: """Plots the Magnitude data and returns Axes. :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot. :type axes: Union[matplotlib.axes.Axes, None], optional :param save: Whether to save the plot. (Default value = True) :type save: bool :param show: Whether to show the plot. (Default value = True) :type show: bool :return: The axes with the plot. :rtype: matplotlib.axes.Axes """ axes = axes or plt.gca() axes = self.plot_data(axes=axes, save=False, show=False) axes.set_yscale('log') times = self._get_times(axes) bands_to_plot = self._get_bands_to_plot color_max = self.max_likelihood_color color_sample = self.random_sample_color for band, color in zip(bands_to_plot, self.transient.get_colors(bands_to_plot)): if self.set_same_color_per_subplot is True: if self.band_colors is not None: color = self.band_colors[band] color_max = color color_sample = color sn_cosmo_band = redback.utils.sncosmo_bandname_from_band([band]) self._model_kwargs["bands"] = [sn_cosmo_band[0] for _ in range(len(times))] if isinstance(band, str): frequency = redback.utils.bands_to_frequency([band]) else: frequency = band self._model_kwargs['frequency'] = np.ones(len(times)) * frequency if self.plot_max_likelihood: ys = self.model(times, **self._max_like_params, **self._model_kwargs) if band in self.band_scaling: if self.band_scaling.get("type") == 'x': axes.plot(times - self._reference_mjd_date, ys * self.band_scaling.get(band), color=color_max, alpha=self.max_likelihood_alpha, lw=self.linewidth) elif self.band_scaling.get("type") == '+': axes.plot(times - self._reference_mjd_date, ys + self.band_scaling.get(band), color=color_max, alpha=self.max_likelihood_alpha, lw=self.linewidth) else: axes.plot(times - self._reference_mjd_date, ys, color=color_max, alpha=self.max_likelihood_alpha, lw=self.linewidth) random_ys_list = [self.model(times, **random_params, **self._model_kwargs) for random_params in self._get_random_parameters()] if self.uncertainty_mode == "random_models": for ys in random_ys_list: if band in self.band_scaling: if self.band_scaling.get("type") == 'x': axes.plot(times - self._reference_mjd_date, ys * self.band_scaling.get(band), color=color_sample, alpha=self.random_sample_alpha, lw=self.linewidth, zorder=-1) elif self.band_scaling.get("type") == '+': axes.plot(times - self._reference_mjd_date, ys + self.band_scaling.get(band), color=color_sample, alpha=self.random_sample_alpha, lw=self.linewidth, zorder=-1) else: axes.plot(times - self._reference_mjd_date, ys, color=color_sample, alpha=self.random_sample_alpha, lw=self.linewidth, zorder=-1) elif self.uncertainty_mode == "credible_intervals": if band in self.band_scaling: if self.band_scaling.get("type") == 'x': lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) * self.band_scaling.get(band), interval=self.credible_interval_level) elif self.band_scaling.get("type") == '+': lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) + self.band_scaling.get(band), interval=self.credible_interval_level) else: lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list), interval=self.credible_interval_level) axes.fill_between( times - self._reference_mjd_date, lower_bound, upper_bound, alpha=self.uncertainty_band_alpha, color=color_sample) self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show) return axes
def _check_valid_multiband_data_mode(self) -> bool: if self.transient.luminosity_data: redback.utils.logger.warning( f"Plotting multiband lightcurve/data not possible for {self.transient.data_mode}. Returning.") return False return True
[docs] def plot_multiband( self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes: """Plots the Magnitude multiband data and returns Axes. :param figure: Matplotlib figure to plot the data into. :type figure: matplotlib.figure.Figure :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot. :type axes: Union[matplotlib.axes.Axes, None], optional :param save: Whether to save the plot. (Default value = True) :type save: bool :param show: Whether to show the plot. (Default value = True) :type show: bool :return: The axes with the plot. :rtype: matplotlib.axes.Axes """ if not self._check_valid_multiband_data_mode(): return if figure is None or axes is None: figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all', figsize=self._figsize) axes = axes.ravel() band_label_generator = self.band_label_generator ii = 0 for indices, band, freq in zip( self.transient.list_of_band_indices, self.transient.unique_bands, self.transient.unique_frequencies): if band not in self._filters: continue x_err = self._get_x_err(indices) color = self._colors[list(self._filters).index(band)] if self.band_colors is not None: color = self.band_colors[band] if band_label_generator is None: label = self._get_multiband_plot_label(band, freq) else: label = next(band_label_generator) axes[ii].errorbar( self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices], xerr=x_err, yerr=self.transient.y_err[indices], fmt=self.errorbar_fmt, ms=self.ms, color=color, elinewidth=self.elinewidth, capsize=self.capsize, label=label) self._set_x_axis(axes[ii]) self._set_y_axis_multiband_data(axes[ii], indices) axes[ii].legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend) axes[ii].tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks) ii += 1 figure.supxlabel(self._xlabel, fontsize=self.fontsize_figure) figure.supylabel(self._ylabel, fontsize=self.fontsize_figure) plt.subplots_adjust(wspace=self.wspace, hspace=self.hspace) self._save_and_show(filepath=self._multiband_data_plot_filepath, save=save, show=show) return axes
@staticmethod def _get_multiband_plot_label(band: str, freq: float) -> str: if isinstance(band, str): if 1e10 < float(freq) < 1e16: label = band else: label = f"{freq:.2e}" else: label = f"{band:.2e}" return label @property def _filters(self) -> list[str]: filters = self.kwargs.get("filters", self.transient.active_bands) if 'bands_to_plot' in self.kwargs: filters = self.kwargs['bands_to_plot'] if filters is None: return self.transient.active_bands elif str(filters) == 'default': return self.transient.default_filters return filters
[docs] def plot_multiband_lightcurve( self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes: """Plots the Magnitude multiband lightcurve and returns Axes. :param figure: Matplotlib figure to plot the data into. :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot. :type axes: Union[matplotlib.axes.Axes, None], optional :param save: Whether to save the plot. (Default value = True) :type save: bool :param show: Whether to show the plot. (Default value = True) :type show: bool :return: The axes with the plot. :rtype: matplotlib.axes.Axes """ if not self._check_valid_multiband_data_mode(): return if figure is None or axes is None: figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all', figsize=self._figsize) axes = self.plot_multiband(figure=figure, axes=axes, save=False, show=False) times = self._get_times(axes) ii = 0 color_max = self.max_likelihood_color color_sample = self.random_sample_color for band, freq in zip(self.transient.unique_bands, self.transient.unique_frequencies): if band not in self._filters: continue new_model_kwargs = self._model_kwargs.copy() new_model_kwargs['frequency'] = freq new_model_kwargs['bands'] = band if self.set_same_color_per_subplot is True: color = self._colors[list(self._filters).index(band)] if self.band_colors is not None: color = self.band_colors[band] color_max = color color_sample = color if self.plot_max_likelihood: ys = self.model(times, **self._max_like_params, **new_model_kwargs) axes[ii].plot( times - self._reference_mjd_date, ys, color=color_max, alpha=self.max_likelihood_alpha, lw=self.linewidth) random_ys_list = [self.model(times, **random_params, **new_model_kwargs) for random_params in self._get_random_parameters()] if self.uncertainty_mode == "random_models": for random_ys in random_ys_list: axes[ii].plot(times - self._reference_mjd_date, random_ys, color=color_sample, alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder) elif self.uncertainty_mode == "credible_intervals": lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list, interval=self.credible_interval_level) axes[ii].fill_between( times - self._reference_mjd_date, lower_bound, upper_bound, alpha=self.uncertainty_band_alpha, color=color_sample) ii += 1 self._save_and_show(filepath=self._multiband_lightcurve_plot_filepath, save=save, show=show) return axes
[docs]class FluxDensityPlotter(MagnitudePlotter): pass
[docs]class IntegratedFluxOpticalPlotter(MagnitudePlotter): pass
[docs]class SpectraPlotter(Plotter): pass