import bilby.core.prior
import numpy as np
import os
from typing import Union
import warnings
import matplotlib
import pandas as pd
from bilby.core.result import Result
from bilby.core.result import _determine_file_name # noqa
import redback.transient.transient
from redback import model_library
from redback.transient import TRANSIENT_DICT
from redback.utils import MetaDataAccessor, logger
# Module-level side effect: installs an "ignore" filter for every warning
# category as soon as this module is imported, so warnings raised afterwards
# are not shown.
warnings.simplefilter(action='ignore')
def _smart_corner_title(median, minus, plus):
    """Build a LaTeX corner-plot title of the form ``median_{-minus}^{+plus}``.

    The number of decimal places is derived from the smaller of the two
    uncertainties (capped at 4) so that small errors are not printed as 0.00.
    Values whose magnitude is >= 1e4 or < 1e-2 are written as a mantissa
    multiplied by a power of ten.

    :param median: Median value.
    :param minus: Lower uncertainty (positive number).
    :param plus: Upper uncertainty (positive number).
    :return: LaTeX-formatted title string.
    """

    def _decimal_places(smallest, minimum):
        # One more digit than the order of magnitude of ``smallest``, never
        # fewer than ``minimum`` and never more than 4; 2 if ``smallest`` is
        # not strictly positive.
        if smallest > 0:
            places = max(minimum, int(np.ceil(-np.log10(smallest))) + 1)
        else:
            places = 2
        return min(places, 4)

    # A zero median carries no scale of its own, so borrow the larger uncertainty.
    magnitude = max(abs(plus), abs(minus)) if median == 0 else abs(median)

    if magnitude != 0 and (magnitude >= 1e4 or magnitude < 1e-2):
        # Scientific notation: express all three numbers in units of 10**exponent.
        exponent = int(np.floor(np.log10(magnitude)))
        scale = 10 ** exponent
        mantissa = median / scale
        upper = plus / scale
        lower = minus / scale
        dp = _decimal_places(max(min(upper, lower), 1e-10 * abs(mantissa)), 0)
        body = rf"{{{mantissa:.{dp}f}}}_{{-{lower:.{dp}f}}}^{{+{upper:.{dp}f}}}"
        return rf"${body} \times 10^{{{exponent}}}$"

    # Plain decimal notation, always with at least one decimal place.
    dp = _decimal_places(max(min(plus, minus), 1e-10 * max(abs(median), 1)), 1)
    return rf"${{{median:.{dp}f}}}_{{-{minus:.{dp}f}}}^{{+{plus:.{dp}f}}}$"
[docs]
class RedbackResult(Result):
    """Extension of the bilby `Result` that exposes the `Transient` plotting methods.

    The transient used during the run is rebuilt on demand from `meta_data`
    (see the `transient` property), and the `plot_*` methods forward to it
    together with the posterior and the stored model keyword arguments.
    """

    # Values stored in `meta_data`, exposed as attributes through MetaDataAccessor.
    model = MetaDataAccessor('model')
    transient_type = MetaDataAccessor('transient_type')
    model_kwargs = MetaDataAccessor('model_kwargs')
    name = MetaDataAccessor('name')
    path = MetaDataAccessor('path')

    def __init__(
            self, label: str = 'no_label', outdir: str = '.', sampler: str = None, search_parameter_keys: list = None,
            fixed_parameter_keys: list = None, constraint_parameter_keys: list = None,
            priors: Union[dict, bilby.core.prior.PriorDict] = None, sampler_kwargs: dict = None,
            injection_parameters: dict = None, meta_data: dict = None, posterior: pd.DataFrame = None,
            samples: pd.DataFrame = None, nested_samples: pd.DataFrame = None, log_evidence: float = np.nan,
            log_evidence_err: float = np.nan, information_gain: float = np.nan, log_noise_evidence: float = np.nan,
            log_bayes_factor: float = np.nan, log_likelihood_evaluations: np.ndarray = None,
            log_prior_evaluations: int = None, sampling_time: float = None, nburn: int = None,
            num_likelihood_evaluations: int = None, walkers: int = None, max_autocorrelation_time: float = None,
            use_ratio: bool = None, parameter_labels: list = None, parameter_labels_with_unit: list = None,
            version: str = None) -> None:
        """Constructor for an extension of the regular bilby `Result`. This result adds the capability of utilising the
        plotting methods of the `Transient` such as `plot_lightcurve`. The class does this by reconstructing the
        `Transient` object that was used during the run by saving the required information in `meta_data`.

        All arguments are forwarded unchanged to the bilby `Result` constructor.

        :param label: Labels of files produced by this class.
        :type label: str, optional
        :param outdir: Output directory of the result. Default is the current directory.
        :type outdir: str, optional
        :param sampler: The sampler used during the run.
        :type sampler: str, optional
        :param search_parameter_keys: The parameters that were sampled in.
        :type search_parameter_keys: list, optional
        :param fixed_parameter_keys: Parameters that had a `DeltaFunction` prior
        :type fixed_parameter_keys: list, optional
        :param constraint_parameter_keys: Parameters that had a `Constraint` prior
        :type constraint_parameter_keys: list, optional
        :param priors: Dictionary of priors.
        :type priors: Union[dict, bilby.core.prior.PriorDict]
        :param sampler_kwargs: Any keyword arguments passed to the sampling package.
        :type sampler_kwargs: dict, optional
        :param injection_parameters: True parameters if the dataset is simulated.
        :type injection_parameters: dict, optional
        :param meta_data: Additional dictionary. Contains the data used during the run and
                          is used to reconstruct the `Transient` object used during the run.
        :type meta_data: dict, optional
        :param posterior: Posterior samples with log likelihood and log prior values.
        :type posterior: pd.DataFrame, optional
        :param samples: An array of the output posterior samples.
        :type samples: np.ndarray, optional
        :param nested_samples: An array of the unweighted samples
        :type nested_samples: np.ndarray, optional
        :param log_evidence: The log evidence value if provided.
        :type log_evidence: float, optional
        :param log_evidence_err: The log evidence error value if provided
        :type log_evidence_err: float, optional
        :param information_gain: The information gain calculated.
        :type information_gain: float, optional
        :param log_noise_evidence: The log noise evidence.
        :type log_noise_evidence: float, optional
        :param log_bayes_factor: The log Bayes factor if we sampled using the likelihood ratio.
        :type log_bayes_factor: float, optional
        :param log_likelihood_evaluations: The evaluations of the likelihood for each sample point
        :type log_likelihood_evaluations: np.ndarray, optional
        :param log_prior_evaluations: Number of log prior evaluations.
        :type log_prior_evaluations: int, optional
        :param sampling_time: The time taken to complete the sampling in seconds.
        :type sampling_time: float, optional
        :param nburn: The number of burn-in steps discarded for MCMC samplers
        :type nburn: int, optional
        :param num_likelihood_evaluations: Number of total likelihood evaluations.
        :type num_likelihood_evaluations: int, optional
        :param walkers: The samplers taken by an ensemble MCMC samplers.
        :type walkers: array_like, optional
        :param max_autocorrelation_time: The estimated maximum autocorrelation time for MCMC samplers.
        :type max_autocorrelation_time: float, optional
        :param use_ratio:
            A boolean stating whether the likelihood ratio, as opposed to the
            likelihood was used during sampling.
        :type use_ratio: bool, optional
        :param parameter_labels: List of the latex-formatted parameter labels.
        :type parameter_labels: list, optional
        :param parameter_labels_with_unit: List of the latex-formatted parameter labels with units.
        :type parameter_labels_with_unit: list, optional
        :param version: Version information for software used to generate the result. Note,
                        this information is generated when the result object is initialized.
        :type version: str
        """
        super().__init__(
            label=label, outdir=outdir, sampler=sampler,
            search_parameter_keys=search_parameter_keys, fixed_parameter_keys=fixed_parameter_keys,
            constraint_parameter_keys=constraint_parameter_keys, priors=priors,
            sampler_kwargs=sampler_kwargs, injection_parameters=injection_parameters,
            meta_data=meta_data, posterior=posterior, samples=samples,
            nested_samples=nested_samples, log_evidence=log_evidence,
            log_evidence_err=log_evidence_err, information_gain=information_gain,
            log_noise_evidence=log_noise_evidence, log_bayes_factor=log_bayes_factor,
            log_likelihood_evaluations=log_likelihood_evaluations,
            log_prior_evaluations=log_prior_evaluations, sampling_time=sampling_time, nburn=nburn,
            num_likelihood_evaluations=num_likelihood_evaluations, walkers=walkers,
            max_autocorrelation_time=max_autocorrelation_time, use_ratio=use_ratio,
            parameter_labels=parameter_labels, parameter_labels_with_unit=parameter_labels_with_unit,
            version=version)

    @property
    def transient(self) -> redback.transient.transient.Transient:
        """Reconstruct the transient used during sampling time using the metadata information.

        A new object is built on every access by calling the class registered
        in `TRANSIENT_DICT` under `transient_type` with `meta_data` as keyword
        arguments.

        :return: The reconstructed Transient.
        :rtype: redback.transient.transient.Transient
        :raises KeyError: If `transient_type` is not a key of `TRANSIENT_DICT`.
        :raises Exception: Whatever the transient constructor raises is logged and re-raised.
        """
        logger.debug(f"Reconstructing transient of type '{self.transient_type}' from metadata")
        # The lookup is kept apart from the construction so that a KeyError
        # raised inside the constructor is not reported as an unknown type.
        try:
            transient_class = TRANSIENT_DICT[self.transient_type]
        except KeyError:
            logger.error(f"Unknown transient type '{self.transient_type}'. Available types: {list(TRANSIENT_DICT.keys())}")
            raise
        try:
            transient_obj = transient_class(**self.meta_data)
        except Exception as e:
            logger.error(f"Failed to reconstruct transient '{self.transient_type}': {e}")
            raise
        logger.debug(f"Successfully reconstructed transient '{self.name}'")
        return transient_obj

    def plot_corner(self, parameters=None, priors=None, titles=True, save=True,
                    filename=None, dpi=300, **kwargs):
        """Wrapper around bilby's plot_corner that applies smart title formatting.

        Titles are formatted in scientific notation when the median or uncertainties
        span many orders of magnitude (e.g. ek ~ 1e51 erg), and pick enough decimal
        places so that uncertainties are never displayed as 0.00.

        All extra keyword arguments are forwarded to corner.corner via bilby; unless
        stated otherwise their defaults are whatever bilby's ``plot_corner`` uses.
        Useful ones:

        **Selecting and labelling parameters**

        :param parameters: List of parameter names to plot, or a dict mapping name -> label.
            e.g. ``parameters=['mej', 'vej']`` or
            ``parameters={'mej': r'$M_{\\rm ej}~(M_\\odot)$', 'vej': r'$v_{\\rm ej}$'}``
        :param labels: List of LaTeX labels, one per parameter (overrides names on axes).
            e.g. ``labels=[r'$M_{\\rm ej}~(M_\\odot)$', r'$f_{\\rm Ni}$', ...]``
        :param priors: bilby PriorDict to overplot prior distributions on the 1-D marginals.

        **Font sizes**

        :param title_kwargs: Dict of kwargs passed to ``ax.set_title``.
            e.g. ``title_kwargs={'fontsize': 20}`` (default fontsize is 16).
        :param label_kwargs: Dict of kwargs passed to the axis label setters.
            e.g. ``label_kwargs={'fontsize': 20}``

        **Smoothing and appearance**

        :param smooth: Gaussian smoothing sigma applied to the 2-D histograms.
            e.g. ``smooth=1.8``
        :param smooth1d: Gaussian smoothing sigma for the 1-D marginals.
        :param bins: Number of histogram bins.
        :param color: Colour of the contours and histograms. e.g. ``color='steelblue'``
        :param quantiles: Quantiles used for the title uncertainties, default ``[0.16, 0.84]``.
            Pass ``quantiles=None`` to suppress the smart titles.
        :param levels: Contour levels for 2-D panels, e.g. ``levels=[0.5, 0.9]``.
        :param fill_contours: Whether to fill the 2-D contours.
        :param plot_datapoints: Whether to scatter raw samples.
        :param show_titles: Forwarded to corner; a diagonal panel that already has a
            title is left untouched by the smart formatting.

        **Titles and saving**

        :param titles: Whether to add the smart titles to the diagonal panels (default True).
        :param save: Whether to save the figure to disk (default True). The figure is
            closed after saving.
        :param filename: Output filename. Defaults to ``<outdir>/<label>_corner.png``.
        :param dpi: Figure resolution (default 300).
        :return: The figure returned by bilby's ``plot_corner`` (may be ``None``).

        **Example**::

            result.plot_corner(
                parameters=['mej', 'f_nickel', 'kappa', 'vej', 'av_host'],
                labels=[r'$M_{\\rm ej}~(M_\\odot)$', r'$f_{\\rm Ni}$',
                        r'$\\kappa$ (cm$^2$/g)', r'$v_{\\rm ej}$ (km/s)',
                        r'$A_{\\rm v, host}$'],
                filename='my_corner.png',
                smooth=1.8,
                title_kwargs={'fontsize': 20},
                label_kwargs={'fontsize': 20},
            )
        """
        # bilby's own titles and saving are switched off: titles are added
        # here with adaptive precision and the figure is saved afterwards.
        fig = super().plot_corner(parameters=parameters, priors=priors, titles=False,
                                  save=False, filename=filename, dpi=dpi, **kwargs)
        if fig is None:
            return fig

        quantiles = kwargs.get('quantiles', [0.16, 0.84])
        if titles and quantiles is not None:
            self._add_smart_corner_titles(
                fig=fig, parameters=parameters, quantiles=quantiles,
                title_kwargs=kwargs.get('title_kwargs', dict(fontsize=16)))
        if save:
            self._save_corner_figure(fig=fig, filename=filename, dpi=dpi, outdir=kwargs.get('outdir'))
        return fig

    def _add_smart_corner_titles(self, fig, parameters, quantiles, title_kwargs):
        """Set a `_smart_corner_title` on every untitled diagonal panel of ``fig``."""
        # Work out which parameters were plotted, in plotting order.
        if isinstance(parameters, dict):
            plot_parameter_keys = list(parameters.keys())
        elif parameters is None:
            plot_parameter_keys = self.search_parameter_keys
        else:
            plot_parameter_keys = list(parameters)

        axes = fig.get_axes()
        n_parameters = len(plot_parameter_keys)
        for i, par in enumerate(plot_parameter_keys):
            # Diagonal panel i sits at flat index i * (n + 1) of the n x n grid.
            ax = axes[i + i * n_parameters]
            if ax.title.get_text() != '':
                continue
            summary = self.get_one_dimensional_median_and_error_bar(par, quantiles=quantiles)
            ax.set_title(_smart_corner_title(summary.median, summary.minus, summary.plus), **title_kwargs)

    def _save_corner_figure(self, fig, filename, dpi, outdir=None):
        """Save ``fig`` (default ``<outdir>/<label>_corner.png``) and close it."""
        import matplotlib.pyplot as plt
        from bilby.core.result import safe_save_figure

        if filename is None:
            outdir = self._safe_outdir_creation(outdir, self.plot_corner)
            filename = '{}/{}_corner.png'.format(outdir, self.label)
        safe_save_figure(fig=fig, filename=filename, dpi=dpi)
        plt.close(fig)

    def plot_lightcurve(self, model: Union[callable, str] = None, **kwargs) -> matplotlib.axes.Axes:
        """ Reconstructs the transient and calls the specific `plot_lightcurve` method.
        Detailed documentation appears below by running `print(plot_lightcurve.__doc__)` """
        if model is None:
            # Fall back to the model recorded in the metadata.
            model = model_library.all_models_dict[self.model]
            logger.debug(f"Using stored model '{self.model}' for lightcurve plot")
        else:
            logger.debug("Using provided model for lightcurve plot")
        return self.transient.plot_lightcurve(model=model, posterior=self.posterior,
                                              model_kwargs=self.model_kwargs, **kwargs)

    def plot_spectrum(self, model: Union[callable, str] = None, **kwargs) -> matplotlib.axes.Axes:
        """ Reconstructs the transient and calls the specific `plot_spectrum` method.
        Detailed documentation appears below by running `print(plot_spectrum.__doc__)` """
        if model is None:
            model = model_library.all_models_dict[self.model]
        return self.transient.plot_spectrum(model=model, posterior=self.posterior,
                                            model_kwargs=self.model_kwargs, **kwargs)

    def plot_residual(self, model: Union[callable, str] = None, **kwargs) -> matplotlib.axes.Axes:
        """Reconstructs the transient and calls the specific `plot_residual` method.
        Detailed documentation appears below by running `print(plot_residual.__doc__)` """
        if model is None:
            model = model_library.all_models_dict[self.model]
        return self.transient.plot_residual(model=model, posterior=self.posterior,
                                            model_kwargs=self.model_kwargs, **kwargs)

    def plot_multiband_lightcurve(self, model: Union[callable, str] = None, **kwargs) -> matplotlib.axes.Axes:
        """Reconstructs the transient and calls the specific `plot_multiband_lightcurve` method.
        Detailed documentation appears below by running `print(plot_multiband_lightcurve.__doc__)` """
        if model is None:
            model = model_library.all_models_dict[self.model]
        return self.transient.plot_multiband_lightcurve(
            model=model, posterior=self.posterior, model_kwargs=self.model_kwargs, **kwargs)

    def plot_data(self, **kwargs) -> matplotlib.axes.Axes:
        """Reconstructs the transient and calls the specific `plot_data` method.
        Detailed documentation appears below by running `print(plot_data.__doc__)` """
        return self.transient.plot_data(**kwargs)

    def plot_multiband(self, **kwargs) -> matplotlib.axes.Axes:
        """Reconstructs the transient and calls the specific `plot_multiband` method.
        Detailed documentation appears below by running `print(plot_multiband.__doc__)` """
        return self.transient.plot_multiband(**kwargs)

    # Append the documentation of the wrapped transient methods. `__doc__` is
    # None when docstrings are stripped (python -OO), hence the `or ''` guards.
    plot_data.__doc__ = (plot_data.__doc__ or '') + (redback.transient.Transient.plot_data.__doc__ or '')
    plot_lightcurve.__doc__ = \
        (plot_lightcurve.__doc__ or '') + (redback.transient.Transient.plot_lightcurve.__doc__ or '')
    plot_residual.__doc__ = \
        (plot_residual.__doc__ or '') + (redback.transient.Transient.plot_residual.__doc__ or '')
    plot_multiband.__doc__ = \
        (plot_multiband.__doc__ or '') + (redback.transient.Transient.plot_multiband.__doc__ or '')
    plot_multiband_lightcurve.__doc__ = \
        (plot_multiband_lightcurve.__doc__ or '') + \
        (redback.transient.Transient.plot_multiband_lightcurve.__doc__ or '')
    plot_spectrum.__doc__ = \
        (plot_spectrum.__doc__ or '') + (redback.transient.Spectrum.plot_spectrum.__doc__ or '')
[docs]
class MultiMessengerResult(RedbackResult):
    """
    Result container for joint multi-messenger fits.

    Posterior, evidence, and corner-plot functionality is inherited unchanged
    from RedbackResult. Everything that relies on rebuilding one transient
    (the `transient` property and the transient `plot_*` helpers) is
    deliberately overridden to raise `NotImplementedError`.
    """

    # Joint-analysis entries of `meta_data`, exposed as attributes.
    messengers = MetaDataAccessor('messengers')
    models = MetaDataAccessor('models')
    shared_params = MetaDataAccessor('shared_params')
    parameter_mappings = MetaDataAccessor('parameter_mappings')

    @property
    def transient(self):
        """Always raises `NotImplementedError`; no single transient is rebuilt."""
        message = (
            "MultiMessengerResult does not reconstruct a single transient. "
            "Use the relevant transient object with redback.analysis plotting "
            "helpers and this result's posterior samples."
        )
        raise NotImplementedError(message)

    @staticmethod
    def _raise_single_transient_plot_error():
        """Raise the shared `NotImplementedError` used by every disabled plot helper."""
        message = (
            "MultiMessengerResult cannot call single-transient plotting helpers "
            "automatically because joint analyses can contain multiple transients "
            "and external likelihoods. Use redback.analysis.plot_lightcurve, "
            "redback.analysis.plot_spectrum, or the transient plotting methods "
            "with the relevant transient, model, posterior samples, and model_kwargs."
        )
        raise NotImplementedError(message)

    def plot_lightcurve(self, *args, **kwargs):
        """Disabled for joint results; always raises `NotImplementedError`."""
        self._raise_single_transient_plot_error()

    def plot_spectrum(self, *args, **kwargs):
        """Disabled for joint results; always raises `NotImplementedError`."""
        self._raise_single_transient_plot_error()

    def plot_residual(self, *args, **kwargs):
        """Disabled for joint results; always raises `NotImplementedError`."""
        self._raise_single_transient_plot_error()

    def plot_multiband_lightcurve(self, *args, **kwargs):
        """Disabled for joint results; always raises `NotImplementedError`."""
        self._raise_single_transient_plot_error()

    def plot_data(self, *args, **kwargs):
        """Disabled for joint results; always raises `NotImplementedError`."""
        self._raise_single_transient_plot_error()

    def plot_multiband(self, *args, **kwargs):
        """Disabled for joint results; always raises `NotImplementedError`."""
        self._raise_single_transient_plot_error()
def _is_multimessenger_result(result: Result) -> bool:
    """Tell whether a result's metadata flags it as a multi-messenger analysis.

    :param result: Object whose ``meta_data`` attribute is inspected.
    :return: True only if ``meta_data`` is a dict with a truthy
        ``"multimessenger"`` entry. A missing or non-dict ``meta_data`` gives False.
    """
    meta_data = getattr(result, "meta_data", None)
    if not isinstance(meta_data, dict):
        return False
    # Coerce so the declared ``bool`` return type holds even when the stored
    # flag is some other truthy/falsy value (e.g. 1, "yes", None).
    return bool(meta_data.get("multimessenger", False))
[docs]
def read_in_result(
        filename: str = None, outdir: str = None, label: str = None,
        extension: str = 'json', gzip: bool = False) -> RedbackResult:
    """
    Load a saved result from disk, picking the reader from the file extension.

    A result whose metadata marks it as multi-messenger is loaded as a
    `MultiMessengerResult`; anything else is loaded as a `RedbackResult`.

    :param filename: Filename with entire path of result to open.
    :type filename: str, optional
    :param outdir: If filename is not given, directory of the result.
    :type outdir: str, optional
    :param label: If filename is not given, label of the result.
    :type label: str, optional
    :param extension: If filename is not given, filename extension.
        Must be in ('json', 'hdf5', 'h5', 'pkl', 'pickle', 'gz').
        (Default value = 'json')
    :type extension: str, optional
    :param gzip: If the file is compressed with gzip. Default is False.
    :type gzip: bool, optional

    :return: The loaded redback result.
    :rtype: RedbackResult
    :raises FileNotFoundError: If the resolved file does not exist.
    :raises ValueError: If the file extension is not a supported type.
    """
    filename = _determine_file_name(filename, outdir, label, extension, gzip)
    logger.info(f"Loading result from file: {filename}")

    if not os.path.exists(filename):
        logger.error(f"Result file not found: {filename}")
        raise FileNotFoundError(f"Result file not found: {filename}")

    # Trust the extension of the resolved file over the `extension` argument;
    # for a gzipped file, look at the suffix underneath the trailing ".gz".
    stem, suffix = os.path.splitext(filename)
    extension = suffix.lstrip('.')
    if extension == 'gz':
        extension = os.path.splitext(stem)[1].lstrip('.')
    logger.debug(f"Reading result file with extension: {extension}")

    try:
        if 'json' in extension:
            loader_name = 'from_json'
        elif any(tag in extension for tag in ('hdf5', 'h5')):
            loader_name = 'from_hdf5'
        elif any(tag in extension for tag in ('pkl', 'pickle')):
            loader_name = 'from_pickle'
        else:
            logger.error(f"Unsupported filetype: {extension}. Supported types: json, hdf5, h5, pkl, pickle")
            raise ValueError("Filetype {} not understood".format(extension))

        result = getattr(RedbackResult, loader_name)(filename=filename)
        if _is_multimessenger_result(result):
            # Only the pickle branch checks whether the loaded object already
            # has the joint class before reading the file a second time.
            if loader_name != 'from_pickle' or not isinstance(result, MultiMessengerResult):
                result = getattr(MultiMessengerResult, loader_name)(filename=filename)

        logger.info(f"Successfully loaded result for '{result.label}' (model: {result.model})")
        return result
    except Exception as e:
        logger.error(f"Failed to load result from {filename}: {e}")
        raise