# Source code for redback.result

import bilby.core.prior
import numpy as np
import os
from typing import Union
import warnings
import matplotlib

import pandas as pd
from bilby.core.result import Result
from bilby.core.result import _determine_file_name # noqa

import redback.transient.transient
from redback import model_library
from redback.transient import TRANSIENT_DICT
from redback.utils import MetaDataAccessor

# NOTE: installs a process-wide "ignore" filter at import time, so importing
# this module silences every warning category for the whole interpreter.
warnings.simplefilter("ignore")


def _combine_docs(own_doc, transient_doc):
    """Concatenate a wrapper docstring with the wrapped `Transient` method's docstring.

    Either part may be ``None`` (for example when Python runs with ``-OO``, which
    strips docstrings). A missing part is treated as an empty string, so class
    creation never fails on ``None + str``.

    :param own_doc: Docstring of the `RedbackResult` wrapper method.
    :type own_doc: str or None
    :param transient_doc: Docstring of the corresponding `Transient` method.
    :type transient_doc: str or None
    :return: The combined docstring, or ``None`` if both parts are missing.
    :rtype: str or None
    """
    if own_doc is None and transient_doc is None:
        return None
    return (own_doc or '') + (transient_doc or '')


class RedbackResult(Result):
    """Extension of the bilby `Result` that can rebuild the `Transient` used during sampling.

    The information needed to reconstruct the transient is stored in ``meta_data``.
    This lets the `Transient` plotting methods be called directly on the result.
    """

    # Attribute-style accessors built with redback.utils.MetaDataAccessor;
    # presumably each proxies the same-named key of ``meta_data`` (verify in redback.utils).
    model = MetaDataAccessor('model')
    transient_type = MetaDataAccessor('transient_type')
    model_kwargs = MetaDataAccessor('model_kwargs')
    name = MetaDataAccessor('name')
    path = MetaDataAccessor('path')

    def __init__(
            self, label: str = 'no_label', outdir: str = '.', sampler: str = None,
            search_parameter_keys: list = None, fixed_parameter_keys: list = None,
            constraint_parameter_keys: list = None,
            priors: Union[dict, bilby.core.prior.PriorDict] = None, sampler_kwargs: dict = None,
            injection_parameters: dict = None, meta_data: dict = None, posterior: pd.DataFrame = None,
            samples: pd.DataFrame = None, nested_samples: pd.DataFrame = None,
            log_evidence: float = np.nan, log_evidence_err: float = np.nan,
            information_gain: float = np.nan, log_noise_evidence: float = np.nan,
            log_bayes_factor: float = np.nan, log_likelihood_evaluations: np.ndarray = None,
            log_prior_evaluations: int = None, sampling_time: float = None, nburn: int = None,
            num_likelihood_evaluations: int = None, walkers: int = None,
            max_autocorrelation_time: float = None, use_ratio: bool = None,
            parameter_labels: list = None, parameter_labels_with_unit: list = None,
            version: str = None) -> None:
        """Constructor for an extension of the regular bilby `Result`.

        This result can use the plotting methods of the `Transient`, such as
        `plot_lightcurve`. It does this by reconstructing the `Transient` object used
        during the run from the information saved in `meta_data`.

        :param label: Labels of files produced by this class.
        :type label: str, optional
        :param outdir: Output directory of the result. Default is the current directory.
        :type outdir: str, optional
        :param sampler: The sampler used during the run.
        :type sampler: str, optional
        :param search_parameter_keys: The parameters that were sampled in.
        :type search_parameter_keys: list, optional
        :param fixed_parameter_keys: Parameters that had a `DeltaFunction` prior.
        :type fixed_parameter_keys: list, optional
        :param constraint_parameter_keys: Parameters that had a `Constraint` prior.
        :type constraint_parameter_keys: list, optional
        :param priors: Dictionary of priors.
        :type priors: Union[dict, bilby.core.prior.PriorDict]
        :param sampler_kwargs: Any keyword arguments passed to the sampling package.
        :type sampler_kwargs: dict, optional
        :param injection_parameters: True parameters if the dataset is simulated.
        :type injection_parameters: dict, optional
        :param meta_data: Additional dictionary. Contains the data used during the run and
                          is used to reconstruct the `Transient` object used during the run.
        :type meta_data: dict, optional
        :param posterior: Posterior samples with log likelihood and log prior values.
        :type posterior: pd.DataFrame, optional
        :param samples: The output posterior samples.
        :type samples: pd.DataFrame, optional
        :param nested_samples: The unweighted samples.
        :type nested_samples: pd.DataFrame, optional
        :param log_evidence: The log evidence value if provided.
        :type log_evidence: float, optional
        :param log_evidence_err: The log evidence error value if provided.
        :type log_evidence_err: float, optional
        :param information_gain: The information gain calculated.
        :type information_gain: float, optional
        :param log_noise_evidence: The log noise evidence.
        :type log_noise_evidence: float, optional
        :param log_bayes_factor: The log Bayes factor if we sampled using the likelihood ratio.
        :type log_bayes_factor: float, optional
        :param log_likelihood_evaluations: The evaluations of the likelihood for each sample point.
        :type log_likelihood_evaluations: np.ndarray, optional
        :param log_prior_evaluations: Number of log prior evaluations.
        :type log_prior_evaluations: int, optional
        :param sampling_time: The time taken to complete the sampling in seconds.
        :type sampling_time: float, optional
        :param nburn: The number of burn-in steps discarded for MCMC samplers.
        :type nburn: int, optional
        :param num_likelihood_evaluations: Number of total likelihood evaluations.
        :type num_likelihood_evaluations: int, optional
        :param walkers: The samples taken by an ensemble MCMC sampler.
        :type walkers: array_like, optional
        :param max_autocorrelation_time: The estimated maximum autocorrelation time for MCMC samplers.
        :type max_autocorrelation_time: float, optional
        :param use_ratio: Whether the likelihood ratio, as opposed to the likelihood, was used
                          during sampling.
        :type use_ratio: bool, optional
        :param parameter_labels: List of the latex-formatted parameter labels.
        :type parameter_labels: list, optional
        :param parameter_labels_with_unit: List of the latex-formatted parameter labels with units.
        :type parameter_labels_with_unit: list, optional
        :param version: Version information for software used to generate the result. Note that
                        this information is generated when the result object is initialized.
        :type version: str, optional
        """
        super(RedbackResult, self).__init__(
            label=label, outdir=outdir, sampler=sampler,
            search_parameter_keys=search_parameter_keys, fixed_parameter_keys=fixed_parameter_keys,
            constraint_parameter_keys=constraint_parameter_keys, priors=priors,
            sampler_kwargs=sampler_kwargs, injection_parameters=injection_parameters,
            meta_data=meta_data, posterior=posterior, samples=samples,
            nested_samples=nested_samples, log_evidence=log_evidence,
            log_evidence_err=log_evidence_err, information_gain=information_gain,
            log_noise_evidence=log_noise_evidence, log_bayes_factor=log_bayes_factor,
            log_likelihood_evaluations=log_likelihood_evaluations,
            log_prior_evaluations=log_prior_evaluations, sampling_time=sampling_time, nburn=nburn,
            num_likelihood_evaluations=num_likelihood_evaluations, walkers=walkers,
            max_autocorrelation_time=max_autocorrelation_time, use_ratio=use_ratio,
            parameter_labels=parameter_labels,
            parameter_labels_with_unit=parameter_labels_with_unit, version=version)

    @property
    def transient(self) -> redback.transient.transient.Transient:
        """Reconstruct the transient used during sampling from the metadata.

        :return: The reconstructed Transient.
        :rtype: redback.transient.transient.Transient
        """
        # The class is chosen by the stored transient type. The whole meta_data
        # dict is forwarded to its constructor as keyword arguments.
        return TRANSIENT_DICT[self.transient_type](**self.meta_data)

    def _resolve_model(self, model):
        """Return ``model``, or the model recorded in the metadata when ``model`` is None.

        :param model: Explicit model (callable or name), or None.
        :return: ``model`` unchanged, or ``model_library.all_models_dict[self.model]``.
        :raises KeyError: If ``model`` is None and ``self.model`` is not a known model name.
        """
        if model is None:
            return model_library.all_models_dict[self.model]
        return model

    def plot_lightcurve(self, model: Union[callable, str] = None, **kwargs) -> 'matplotlib.axes.Axes':
        """Reconstructs the transient and calls the specific `plot_lightcurve` method.

        If `model` is None, the model stored in the result's metadata is used.
        Detailed documentation appears below by running `print(plot_lightcurve.__doc__)`
        """
        return self.transient.plot_lightcurve(
            model=self._resolve_model(model), posterior=self.posterior,
            model_kwargs=self.model_kwargs, **kwargs)

    def plot_residual(self, model: Union[callable, str] = None, **kwargs) -> 'matplotlib.axes.Axes':
        """Reconstructs the transient and calls the specific `plot_residual` method.

        If `model` is None, the model stored in the result's metadata is used.
        Detailed documentation appears below by running `print(plot_residual.__doc__)`
        """
        return self.transient.plot_residual(
            model=self._resolve_model(model), posterior=self.posterior,
            model_kwargs=self.model_kwargs, **kwargs)

    def plot_multiband_lightcurve(
            self, model: Union[callable, str] = None, **kwargs) -> 'matplotlib.axes.Axes':
        """Reconstructs the transient and calls the specific `plot_multiband_lightcurve` method.

        If `model` is None, the model stored in the result's metadata is used.
        Detailed documentation appears below by running `print(plot_multiband_lightcurve.__doc__)`
        """
        return self.transient.plot_multiband_lightcurve(
            model=self._resolve_model(model), posterior=self.posterior,
            model_kwargs=self.model_kwargs, **kwargs)

    def plot_data(self, **kwargs) -> 'matplotlib.axes.Axes':
        """Reconstructs the transient and calls the specific `plot_data` method.

        Detailed documentation appears below by running `print(plot_data.__doc__)`
        """
        return self.transient.plot_data(**kwargs)

    def plot_multiband(self, **kwargs) -> 'matplotlib.axes.Axes':
        """Reconstructs the transient and calls the specific `plot_multiband` method.

        Detailed documentation appears below by running `print(plot_multiband.__doc__)`
        """
        return self.transient.plot_multiband(**kwargs)

    # Append the wrapped Transient documentation to each wrapper. _combine_docs
    # tolerates stripped (None) docstrings, e.g. under `python -OO`.
    plot_data.__doc__ = _combine_docs(
        plot_data.__doc__, redback.transient.Transient.plot_data.__doc__)
    plot_lightcurve.__doc__ = _combine_docs(
        plot_lightcurve.__doc__, redback.transient.Transient.plot_lightcurve.__doc__)
    plot_residual.__doc__ = _combine_docs(
        plot_residual.__doc__, redback.transient.Transient.plot_residual.__doc__)
    plot_multiband.__doc__ = _combine_docs(
        plot_multiband.__doc__, redback.transient.Transient.plot_multiband.__doc__)
    plot_multiband_lightcurve.__doc__ = _combine_docs(
        plot_multiband_lightcurve.__doc__,
        redback.transient.Transient.plot_multiband_lightcurve.__doc__)
def read_in_result(
        filename: str = None, outdir: str = None, label: str = None,
        extension: str = 'json', gzip: bool = False) -> RedbackResult:
    """Read a stored `RedbackResult` from disk.

    :param filename: Filename with entire path of result to open.
    :type filename: str, optional
    :param outdir: If filename is not given, directory of the result.
    :type outdir: str, optional
    :param label: If filename is not given, label of the result.
    :type label: str, optional
    :param extension: If filename is not given, filename extension.
                      Must be in ('json', 'hdf5', 'h5', 'pkl', 'pickle', 'gz').
                      (Default value = 'json')
    :type extension: str, optional
    :param gzip: If the file is compressed with gzip. Default is False.
    :type gzip: bool, optional
    :return: The loaded redback result.
    :rtype: RedbackResult
    :raises ValueError: If the resolved filename has no extension or an unsupported one.
    """
    filename = _determine_file_name(filename, outdir, label, extension, gzip)

    # Take the extension from the resolved filename. It may differ from the
    # `extension` argument when `filename` was given explicitly.
    extension = os.path.splitext(filename)[1].lstrip('.')
    if extension == 'gz':
        # For compressed files such as "*.json.gz", the real format is the inner extension.
        extension = os.path.splitext(os.path.splitext(filename)[0])[1].lstrip('.')

    # os.path.splitext never returns None; a missing extension appears as ''.
    if not extension:
        raise ValueError("No filetype extension provided")
    if 'json' in extension:
        result = RedbackResult.from_json(filename=filename)
    elif ('hdf5' in extension) or ('h5' in extension):
        result = RedbackResult.from_hdf5(filename=filename)
    elif ("pkl" in extension) or ("pickle" in extension):
        result = RedbackResult.from_pickle(filename=filename)
    else:
        raise ValueError("Filetype {} not understood".format(extension))
    return result