import matplotlib.pyplot as plt
import os
from pathlib import Path
from typing import Union
import bilby
import redback.get_data
from redback.likelihoods import GaussianLikelihood, GaussianLikelihoodWithUpperLimits, PoissonLikelihood, PoissonSpectralLikelihood, \
WStatSpectralLikelihood, ChiSquareSpectralLikelihood
from redback.model_library import all_models_dict
from redback.result import RedbackResult
from redback.utils import logger
from redback.transient.afterglow import Afterglow
from redback.transient.prompt import PromptTimeSeries
from redback.transient.transient import OpticalTransient, Transient, Spectrum
from redback.spectral.dataset import SpectralDataset
import numpy as np
# Directory containing this module; used to locate the bundled default
# ``priors/<model>.prior`` files when the caller does not supply a prior.
dirname = os.path.dirname(__file__)
def fit_model(
        transient: redback.transient.transient.Transient, model: Union[callable, str], outdir: str = None,
        label: str = None, sampler: str = "dynesty", nlive: int = 2000, prior: dict = None, walks: int = 200,
        truncate: bool = True, use_photon_index_prior: bool = False, truncate_method: str = "prompt_time_error",
        resume: bool = True, save_format: str = "json", model_kwargs: dict = None, plot=True, **kwargs) \
        -> redback.result.RedbackResult:
    """
    Fit a model to a transient, dispatching to the fitting routine for the transient's type.

    :param transient: The transient to be fitted
    :param model: Name of the model to fit to data (a key of ``all_models_dict``) or a function.
    :param outdir: Output directory. Will default to a sensible structure if not given.
    :param label: Result file labels. Will use the transient name if not given.
    :param sampler: The sampling backend. Nested samplers are encouraged to allow evidence calculation.
        (Default value = 'dynesty')
    :param nlive: Number of live points.
    :param prior: Priors to use during sampling. If not given, we use the default priors for the given model,
        read from ``priors/<model name>.prior`` next to this module.
    :param walks: Number of `dynesty` random walks.
    :param truncate: Flag to confirm whether to truncate the prompt emission data
    :param use_photon_index_prior: flag to turn off/on photon index prior and fits according to the curvature effect
    :param truncate_method: method of truncation
    :param resume: Whether to resume the run from a checkpoint if available.
    :param save_format: The format to save the result in. (Default value = 'json')
    :param model_kwargs: Additional keyword arguments for the model. For flux_density/magnitude/flux data
        its ``output_format`` must equal the transient's ``data_mode``; magnitude/flux fits also need
        ``bands`` and flux-density fits need ``frequency``.
    :param plot: If True, create corner and lightcurve plot
    :param kwargs: Additional parameters that will be passed to the sampler via bilby. ``clean`` (passed here)
        reruns the fit when True; when False (default) previous results in the output directory are reused.
    :return: Redback result object, transient specific data object
    :raises KeyError: If ``model`` is a string that is not a key of ``all_models_dict``.
    :raises ValueError: If ``model_kwargs`` is inconsistent with the transient's data mode, or the
        transient type is not supported.
    """
    if isinstance(model, str):
        modelname = model
        model = all_models_dict[model]
    else:
        # Custom callables (e.g. functools.partial) need not have a __name__.
        modelname = getattr(model, "__name__", "custom_model")
    # Directory name uses the function's own name when it has one (as before), else modelname.
    model_dirname = getattr(model, "__name__", modelname)

    # Some transient types (e.g. spectral datasets) may not define ``data_mode``.
    data_mode = getattr(transient, "data_mode", None)
    if data_mode in ["flux_density", "magnitude", "flux"]:
        _check_model_kwargs_against_data_mode(data_mode=data_mode, model_kwargs=model_kwargs)

    if prior is None:
        prior = bilby.prior.PriorDict(filename=f"{dirname}/priors/{modelname}.prior")
        logger.warning(f"No prior given. Using default priors for {modelname}")

    if isinstance(transient, SpectralDataset):
        outdir = outdir or f"high_energy_spectra/{model_dirname}"
    else:
        outdir = outdir or f"{transient.directory_structure.directory_path}/{model_dirname}"
    Path(outdir).mkdir(parents=True, exist_ok=True)
    label = label or transient.name

    # Arguments shared by every type-specific fitting routine.
    common = dict(model=model, outdir=outdir, label=label, sampler=sampler, nlive=nlive, prior=prior,
                  walks=walks, resume=resume, save_format=save_format, model_kwargs=model_kwargs, plot=plot)

    if isinstance(transient, SpectralDataset):
        return _fit_spectral_dataset(transient=transient, **common, **kwargs)

    try:
        from redback.transient.spectral import CountsSpectrumTransient
    except Exception:
        # Optional module: if it cannot be imported, this dispatch branch is simply disabled.
        CountsSpectrumTransient = None
    if CountsSpectrumTransient is not None and isinstance(transient, CountsSpectrumTransient):
        return _fit_spectral_dataset(transient=transient.dataset, **common, **kwargs)

    if isinstance(transient, Spectrum):
        return _fit_spectrum(transient=transient, **common, **kwargs)
    if isinstance(transient, Afterglow):
        return _fit_grb(transient=transient, use_photon_index_prior=use_photon_index_prior,
                        truncate=truncate, truncate_method=truncate_method, **common, **kwargs)
    if isinstance(transient, PromptTimeSeries):
        return _fit_prompt(transient=transient, **common, **kwargs)
    if isinstance(transient, (OpticalTransient, Transient)):
        return _fit_optical_transient(
            transient=transient, truncate=truncate, use_photon_index_prior=use_photon_index_prior,
            truncate_method=truncate_method, **common, **kwargs)
    raise ValueError(f'Source type {transient.__class__.__name__} not known')


def _check_model_kwargs_against_data_mode(data_mode, model_kwargs):
    """
    Validate ``model_kwargs`` against a light-curve transient's ``data_mode``.

    Missing keys are treated like ``None`` so callers get the descriptive ValueError
    instead of a bare KeyError.

    :param data_mode: One of 'flux_density', 'magnitude' or 'flux'.
    :param model_kwargs: Model keyword arguments, or None (only a warning is logged).
    :raises ValueError: If ``output_format`` differs from ``data_mode``, or the required
        ``bands``/``frequency`` entry is missing or None.
    """
    if model_kwargs is None:
        logger.warning("No model_kwargs given, assuming model works correctly for this transient by default.")
        return
    output_format = model_kwargs.get("output_format")
    if output_format != data_mode:
        raise ValueError(
            f"Transient data mode {data_mode} is inconsistent with "
            f"output format {output_format}. These should be the same.")
    if output_format in ['magnitude', 'flux'] and model_kwargs.get('bands') is None:
        raise ValueError("For magnitude or flux data, model_kwargs must specify the bands corresponding to the data.")
    if output_format == 'flux_density' and model_kwargs.get('frequency') is None:
        raise ValueError("For flux density data, model_kwargs must specify the frequency corresponding to the data.")
def _fit_spectrum(transient, model, outdir, label, likelihood=None, sampler='dynesty', nlive=3000, prior=None, walks=1000,
                  resume=True, save_format='json', model_kwargs=None, plot=True, **kwargs):
    """
    Fit a model to a ``Spectrum`` (flux density vs. wavelength) with bilby.

    A previously saved result in ``outdir`` is returned directly unless ``clean=True`` is
    passed in ``kwargs``. If pymultinest fails to assemble its nested samples, the fit is
    rerun with dynesty.

    :param transient: Spectrum providing ``angstroms``, ``flux_density`` and ``flux_density_err``.
    :param model: Model function passed to the likelihood.
    :param outdir: Output directory for bilby.
    :param label: Result file label.
    :param likelihood: Optional custom likelihood; a GaussianLikelihood is built if None.
    :param sampler: bilby sampler name.
    :param nlive: Number of live points (also passed as ``nsteps``).
    :param prior: Prior dictionary passed to bilby.
    :param walks: Number of walks (also passed as ``nwalkers``; ``maxmcmc`` is ``10 * walks``).
    :param resume: Whether to resume from a checkpoint.
    :param save_format: Format bilby saves the result in; also the default format looked for
        when reusing a saved result.
    :param model_kwargs: Additional keyword arguments for the model.
    :param plot: If True, plot the fitted spectrum.
    :param kwargs: Passed on to ``bilby.run_sampler``; ``clean``, ``extension`` and ``gzip`` also
        control reuse of a saved result.
    :return: RedbackResult
    """
    x, y, y_err = transient.angstroms, transient.flux_density, transient.flux_density_err
    if likelihood is None:
        likelihood = GaussianLikelihood(x=x, y=y, sigma=y_err, function=model, kwargs=model_kwargs)
        logger.info("No likelihood provided, using standard GaussianLikelihood")
    else:
        logger.info("Likelihood provided, using custom likelihood {}".format(likelihood.__class__.__name__))

    meta_data = dict(model=getattr(model, "__name__", "custom_model"),
                     transient_type=transient.__class__.__name__.lower())
    # Response objects (rmf/arf) are left out of the metadata; presumably they don't serialise cleanly.
    meta_data.update({k.lstrip("_"): v for k, v in transient.__dict__.items() if k not in ("rmf", "arf")})
    model_kwargs = redback.utils.check_kwargs_validity(model_kwargs)
    meta_data['model_kwargs'] = model_kwargs

    if not kwargs.get("clean", False):
        # Look for the result in the format it is saved in, unless an extension is given explicitly.
        default_extension = save_format if isinstance(save_format, str) else "json"
        try:
            result = redback.result.read_in_result(
                outdir=outdir, label=label, extension=kwargs.get("extension", default_extension),
                gzip=kwargs.get("gzip", False))
        except Exception:
            # No readable previous result: run the sampler below.
            pass
        else:
            plt.close('all')
            return result

    run_kwargs = dict(likelihood=likelihood, priors=prior, label=label, nlive=nlive,
                      outdir=outdir, plot=plot, use_ratio=False, walks=walks, resume=resume,
                      maxmcmc=10 * walks, result_class=RedbackResult, meta_data=meta_data,
                      save_bounds=False, nsteps=nlive, nwalkers=walks, save=save_format, **kwargs)
    try:
        result = bilby.run_sampler(sampler=sampler, **run_kwargs)
    except ValueError as exc:
        if not (sampler.lower() == "pymultinest" and "dead_points" in str(exc) and "live_points" in str(exc)):
            raise
        logger.warning("Pymultinest failed to assemble nested samples (%s). Rerunning with dynesty.", exc)
        result = bilby.run_sampler(sampler="dynesty", **run_kwargs)
    plt.close('all')
    if plot:
        result.plot_spectrum(model=model)
    return result
def _fit_spectral_dataset(transient, model, outdir, label, likelihood=None, sampler='dynesty', nlive=3000, prior=None,
                          walks=1000, resume=True, save_format='json', model_kwargs=None, plot=True, **kwargs):
    """
    Fit a spectral model to a ``SpectralDataset`` with bilby.

    Unless ``likelihood`` is given, it is chosen from ``kwargs['statistic']`` (removed from
    ``kwargs`` before they reach bilby): 'wstat'/'w-stat', 'cstat'/'c-stat'/'cash' or
    'chi2'/'chi-square'/'chisq'. With no statistic (or 'auto'), W-stat is used when the
    dataset has background counts (``counts_bkg``) and C-stat otherwise.
    A 'json' save format is switched to 'pkl'. A previously saved result is returned directly
    unless ``clean=True`` is passed in ``kwargs``.

    :param transient: SpectralDataset to fit.
    :param model: Spectral model; must accept ``energies_keV``/``energy_keV`` (or ``**kwargs``).
    :param outdir: Output directory for bilby.
    :param label: Result file label.
    :param likelihood: Optional custom likelihood.
    :param sampler: bilby sampler name; pymultinest sample-assembly failures fall back to dynesty.
    :param nlive: Number of live points (also passed as ``nsteps``).
    :param prior: Prior dictionary; if given, a short likelihood preflight is run on prior draws.
    :param walks: Number of walks (also passed as ``nwalkers``; ``maxmcmc`` is ``10 * walks``).
    :param resume: Whether to resume from a checkpoint.
    :param save_format: Result format ('json' is replaced by 'pkl').
    :param model_kwargs: Additional keyword arguments for the model.
    :param plot: If True, save a counts-spectrum fit plot to ``outdir``.
    :param kwargs: Passed on to ``bilby.run_sampler``; ``clean``, ``extension`` and ``gzip`` also
        control reuse of a saved result.
    :return: RedbackResult
    :raises ValueError: If the model signature cannot take energies, or the statistic is unknown.
    """
    import inspect

    try:
        parameters = inspect.signature(model).parameters
    except (TypeError, ValueError):
        # Signature not introspectable (e.g. some builtins): skip the check rather than reject the model.
        parameters = None
    if parameters is not None:
        accepts_energies = "energies_keV" in parameters or "energy_keV" in parameters
        has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
        if not (accepts_energies or has_var_keyword):
            raise ValueError(
                "Spectral models must accept `energies_keV` (or `energy_keV`) as the first argument. "
                "Update the model signature to be compatible with spectral fitting."
            )

    statistic = kwargs.pop("statistic", None)
    if statistic is None or str(statistic).lower() == "auto":
        statistic = "wstat" if transient.counts_bkg is not None else "cstat"
    logger.info("Spectral fit using statistic=%s", statistic)
    if likelihood is None:
        stat = str(statistic).lower()
        if stat in ("wstat", "w-stat"):
            likelihood_class = WStatSpectralLikelihood
        elif stat in ("cstat", "c-stat", "cash"):
            likelihood_class = PoissonSpectralLikelihood
        elif stat in ("chi2", "chi-square", "chisq"):
            likelihood_class = ChiSquareSpectralLikelihood
        else:
            raise ValueError(f"Unknown statistic '{statistic}' for spectral fitting")
        likelihood = likelihood_class(dataset=transient, function=model, kwargs=model_kwargs)
        logger.info("No likelihood provided, using spectral likelihood %s", likelihood.__class__.__name__)
    else:
        logger.info("Likelihood provided, using custom likelihood {}".format(likelihood.__class__.__name__))

    meta_data = dict(model=getattr(model, "__name__", "custom_model"),
                     transient_type=transient.__class__.__name__.lower())
    meta_data.update({k.lstrip("_"): v for k, v in transient.__dict__.items()})
    model_kwargs = redback.utils.check_kwargs_validity(model_kwargs)
    meta_data['model_kwargs'] = model_kwargs
    # Spectral datasets contain numpy arrays in response objects that cannot JSON-serialise
    if save_format == "json":
        logger.warning("JSON save not supported for spectral datasets with response objects. Using pkl instead.")
        save_format = "pkl"

    if not kwargs.get("clean", False):
        # Try the requested extension first, then the formats this routine may have written before.
        for ext in dict.fromkeys([kwargs.get("extension", save_format), "pkl", "json"]):
            try:
                result = redback.result.read_in_result(
                    outdir=outdir, label=label, extension=ext, gzip=kwargs.get("gzip", False))
            except Exception:
                continue
            plt.close('all')
            return result

    if prior is not None:
        # Best-effort sanity check on a few prior draws. Any failure, including the
        # all-non-finite ValueError raised below, is only logged as a warning.
        likelihood.parameters = dict.fromkeys(prior.keys())
        try:
            samples = [prior.sample() for _ in range(5)]
            finite = 0
            for sample in samples:
                likelihood.parameters.update(sample)
                if np.isfinite(likelihood.log_likelihood()):
                    finite += 1
            logger.info("Spectral preflight: %d/%d finite logL samples", finite, len(samples))
            if finite == 0:
                raise ValueError("Spectral likelihood preflight failed: all sampled logL are non-finite")
        except Exception as exc:
            logger.warning("Spectral preflight failed: %s", exc)

    run_kwargs = dict(likelihood=likelihood, priors=prior, label=label, nlive=nlive,
                      outdir=outdir, plot=plot, use_ratio=False, walks=walks, resume=resume,
                      maxmcmc=10 * walks, result_class=RedbackResult, meta_data=meta_data,
                      save_bounds=False, nsteps=nlive, nwalkers=walks, save=save_format, **kwargs)
    try:
        result = bilby.run_sampler(sampler=sampler, **run_kwargs)
    except ValueError as exc:
        if not (sampler.lower() == "pymultinest" and "dead_points" in str(exc) and "live_points" in str(exc)):
            raise
        logger.warning("Pymultinest failed to assemble nested samples (%s). Rerunning with dynesty.", exc)
        result = bilby.run_sampler(sampler="dynesty", **run_kwargs)
    plt.close('all')
    if plot:
        filename = f"{label}_spectrum_counts.png"
        transient.plot_spectrum_fit(model=model, posterior=result.posterior, model_kwargs=model_kwargs,
                                    filename=filename, outdir=outdir, show=False, save=True)
    return result
def _fit_grb(transient, model, outdir, label, likelihood=None, sampler='dynesty', nlive=3000, prior=None, walks=1000,
             use_photon_index_prior=False, resume=True, save_format='json', model_kwargs=None, plot=True, **kwargs):
    """
    Fit a light-curve model to a GRB afterglow with bilby.

    With ``use_photon_index_prior`` the label gets a ``_photon_index`` suffix and the
    ``alpha_1`` entry of ``prior`` is replaced in place. The new entry is a Gaussian centred
    on ``-(photon_index + 1)`` with sigma 0.1, or Uniform(-10, -0.5) when the photon index
    is negative. A previously saved result in ``outdir`` is returned directly unless
    ``clean=True`` is passed in ``kwargs``.

    :param transient: Afterglow transient to fit.
    :param model: Model function passed to the likelihood.
    :param outdir: Output directory for bilby.
    :param label: Result file label.
    :param likelihood: Optional custom likelihood; a GaussianLikelihood is built if None.
    :param sampler: bilby sampler name.
    :param nlive: Number of live points (also passed as ``nsteps``).
    :param prior: Prior dictionary; must be mutable (and not None) when ``use_photon_index_prior`` is set.
    :param walks: Number of walks (also passed as ``nwalkers``; ``maxmcmc`` is ``10 * walks``).
    :param use_photon_index_prior: Whether to set the ``alpha_1`` prior from the photon index.
    :param resume: Whether to resume from a checkpoint.
    :param save_format: Format bilby saves the result in; also the default format looked for
        when reusing a saved result.
    :param model_kwargs: Additional keyword arguments for the model.
    :param plot: If True, plot the fitted light curve.
    :param kwargs: Passed on to ``bilby.run_sampler``; ``clean``, ``extension`` and ``gzip`` also
        control reuse of a saved result.
    :return: RedbackResult
    """
    if use_photon_index_prior:
        label += '_photon_index'
        if transient.photon_index < 0.:
            # logging uses %-style args; extra positional args are not joined like print().
            logger.info('photon index for GRB %s is negative. Using default prior on alpha_1', transient.name)
            prior['alpha_1'] = bilby.prior.Uniform(-10, -0.5, 'alpha_1', latex_label=r'$\alpha_{1}$')
        else:
            prior['alpha_1'] = bilby.prior.Gaussian(mu=-(transient.photon_index + 1), sigma=0.1, name='alpha_1',
                                                    latex_label=r'$\alpha_{1}$')

    if any([transient.flux_data, transient.magnitude_data, transient.flux_density_data]):
        x, _x_err, y, y_err = transient.get_filtered_data()
    else:
        x, y, y_err = transient.x, transient.y, transient.y_err

    if likelihood is None:
        likelihood = GaussianLikelihood(x=x, y=y, sigma=y_err, function=model, kwargs=model_kwargs)
        logger.info("No likelihood provided, using standard GaussianLikelihood")
    else:
        logger.info("Likelihood provided, using custom likelihood {}".format(likelihood.__class__.__name__))

    meta_data = dict(model=getattr(model, "__name__", "custom_model"),
                     transient_type=transient.__class__.__name__.lower())
    meta_data.update({k.lstrip("_"): v for k, v in transient.__dict__.items()})
    model_kwargs = redback.utils.check_kwargs_validity(model_kwargs)
    meta_data['model_kwargs'] = model_kwargs

    if not kwargs.get("clean", False):
        # Look for the result in the format it is saved in, unless an extension is given explicitly.
        default_extension = save_format if isinstance(save_format, str) else "json"
        try:
            result = redback.result.read_in_result(
                outdir=outdir, label=label, extension=kwargs.get("extension", default_extension),
                gzip=kwargs.get("gzip", False))
        except Exception:
            # No readable previous result: run the sampler below.
            pass
        else:
            plt.close('all')
            return result

    result = bilby.run_sampler(
        likelihood=likelihood, priors=prior, label=label, sampler=sampler, nlive=nlive,
        outdir=outdir, plot=plot, use_ratio=False, walks=walks, resume=resume,
        maxmcmc=10 * walks, result_class=RedbackResult, meta_data=meta_data,
        save_bounds=False, nsteps=nlive, nwalkers=walks, save=save_format, **kwargs)
    plt.close('all')
    if plot:
        result.plot_lightcurve(model=model)
    return result
def _get_filtered_upper_limit_sigma(transient):
    """
    Return upper-limit sigma values matching the transient's active-band filtered data.

    ``transient.upper_limit_sigma`` may be:

    * a scalar, 0-d array or None, which is returned unchanged;
    * an array with one entry per data point (``len(transient.x)``), which is subset with
      ``transient.filtered_indices``;
    * an array with one entry per upper limit, which is subset to the upper limits that
      survive the filter, in data order. This case applies only when ``transient.detections``
      is not None.

    Any other length is returned as an unfiltered numpy array.

    :param transient: Object exposing ``upper_limit_sigma``, ``x``, ``filtered_indices``,
        ``detections`` and ``upper_limits``. ``upper_limits`` is assumed to be a per-point
        upper-limit flag; it is coerced to a boolean mask.
    :return: Scalar/None, or numpy array of upper-limit sigmas.
    """
    upper_limit_sigma = transient.upper_limit_sigma
    # np.ndim is 0 for Python/NumPy scalars, 0-d arrays and None; len() would fail on all of them.
    if np.ndim(upper_limit_sigma) == 0:
        return upper_limit_sigma
    upper_limit_sigma = np.asarray(upper_limit_sigma)
    filtered_indices = transient.filtered_indices
    if len(upper_limit_sigma) == len(transient.x):
        return upper_limit_sigma[filtered_indices]
    if transient.detections is not None:
        # Coerce to bool so an int 0/1 array acts as a mask, not as fancy indices.
        upper_limits = np.asarray(transient.upper_limits, dtype=bool)
        if len(upper_limit_sigma) == np.sum(upper_limits):
            # Index of each point among the upper limits only (meaningful where upper_limits is True).
            upper_limit_positions = np.cumsum(upper_limits) - 1
            filtered_upper_limits = upper_limits[filtered_indices]
            return upper_limit_sigma[upper_limit_positions[filtered_indices][filtered_upper_limits]]
    return upper_limit_sigma
def _fit_optical_transient(transient, model, outdir, label, likelihood=None, sampler='dynesty', nlive=3000, prior=None,
                           walks=1000, resume=True, save_format='json', model_kwargs=None, plot=True, **kwargs):
    """
    Fit a light-curve model to an optical (or generic) transient with bilby.

    When no likelihood is given and the transient has upper limits, a
    ``GaussianLikelihoodWithUpperLimits`` is built. If any upper limit has a NaN y-value,
    the fit instead uses a ``GaussianLikelihood`` on the detections only. A previously
    saved result in ``outdir`` is returned directly unless ``clean=True`` is passed in ``kwargs``.

    :param transient: Transient to fit.
    :param model: Model function passed to the likelihood.
    :param outdir: Output directory for bilby.
    :param label: Result file label.
    :param likelihood: Optional custom likelihood; built automatically if None.
    :param sampler: bilby sampler name.
    :param nlive: Number of live points (also passed as ``nsteps``).
    :param prior: Prior dictionary passed to bilby.
    :param walks: Number of walks (also passed as ``nwalkers``; ``maxmcmc`` is ``10 * walks``).
    :param resume: Whether to resume from a checkpoint.
    :param save_format: Format bilby saves the result in; also the default format looked for
        when reusing a saved result.
    :param model_kwargs: Additional keyword arguments for the model.
    :param plot: If True, plot the fitted light curve.
    :param kwargs: Passed on to ``bilby.run_sampler``; ``clean``, ``extension`` and ``gzip`` also
        control reuse of a saved result.
    :return: RedbackResult
    """
    if any([transient.flux_data, transient.magnitude_data, transient.flux_density_data]):
        x, _x_err, y, y_err, detections = transient.get_filtered_data_with_limits()
    else:
        x, y, y_err = transient.x, transient.y, transient.y_err
        detections = None

    if likelihood is not None:
        logger.info("Likelihood provided, using custom likelihood {}".format(likelihood.__class__.__name__))
    elif transient.has_upper_limits and detections is not None:
        # ``detections`` is a boolean mask: True for detections, False for upper limits.
        ul_data_mode = 'magnitude' if transient.magnitude_data else 'flux'
        n_ul = int(np.sum(~detections))
        n_nan_ul = int(np.sum(np.isnan(y[~detections])))
        if n_nan_ul > 0:
            logger.warning(
                f"{n_nan_ul} upper limit(s) have NaN y-values, which cannot be used in "
                f"GaussianLikelihoodWithUpperLimits. Falling back to standard GaussianLikelihood "
                f"with detection data only. Replace NaN values with the upper limit value "
                f"(e.g. limiting magnitude or flux) to use upper limit likelihood."
            )
            likelihood = GaussianLikelihood(x=x[detections], y=y[detections], sigma=y_err[detections],
                                            function=model, kwargs=model_kwargs)
        else:
            logger.info(f"Auto-detected {n_ul} upper limits, using GaussianLikelihoodWithUpperLimits "
                        f"with data_mode='{ul_data_mode}'")
            likelihood = GaussianLikelihoodWithUpperLimits(
                x=x, y=y, sigma=y_err, function=model, kwargs=model_kwargs,
                detections=detections, upper_limit_sigma=_get_filtered_upper_limit_sigma(transient),
                data_mode=ul_data_mode)
    else:
        likelihood = GaussianLikelihood(x=x, y=y, sigma=y_err, function=model, kwargs=model_kwargs)
        logger.info("No likelihood provided, using standard GaussianLikelihood")

    meta_data = dict(model=getattr(model, "__name__", "custom_model"),
                     transient_type=transient.__class__.__name__.lower())
    meta_data.update({k.lstrip("_"): v for k, v in transient.__dict__.items()})
    model_kwargs = redback.utils.check_kwargs_validity(model_kwargs)
    meta_data['model_kwargs'] = model_kwargs

    if not kwargs.get("clean", False):
        # Look for the result in the format it is saved in, unless an extension is given explicitly.
        default_extension = save_format if isinstance(save_format, str) else "json"
        try:
            result = redback.result.read_in_result(
                outdir=outdir, label=label, extension=kwargs.get("extension", default_extension),
                gzip=kwargs.get("gzip", False))
        except Exception:
            # No readable previous result: run the sampler below.
            pass
        else:
            plt.close('all')
            return result

    result = bilby.run_sampler(
        likelihood=likelihood, priors=prior, label=label, sampler=sampler, nlive=nlive,
        outdir=outdir, plot=plot, use_ratio=False, walks=walks, resume=resume,
        maxmcmc=10 * walks, result_class=RedbackResult, meta_data=meta_data,
        save_bounds=False, nsteps=nlive, nwalkers=walks, save=save_format, **kwargs)
    plt.close('all')
    if plot:
        result.plot_lightcurve(model=model)
    return result
def _fit_prompt(transient, model, outdir, label, likelihood=None, integrated_rate_function=True, sampler='dynesty', nlive=3000,
                prior=None, walks=1000, resume=True, save_format='json',
                model_kwargs=None, plot=True, **kwargs):
    """
    Fit a rate model to a prompt-emission time series (binned counts) with bilby.

    A previously saved result in ``outdir`` is returned directly unless ``clean=True`` is
    passed in ``kwargs``. bilby's own plotting is disabled here; only the light-curve plot
    is produced when ``plot`` is True.

    :param transient: PromptTimeSeries providing ``x`` (times), ``y`` (counts) and ``bin_size``.
    :param model: Rate model function passed to the likelihood.
    :param outdir: Output directory for bilby.
    :param label: Result file label.
    :param likelihood: Optional custom likelihood; a PoissonLikelihood is built if None.
    :param integrated_rate_function: Forwarded to PoissonLikelihood.
    :param sampler: bilby sampler name.
    :param nlive: Number of live points (also passed as ``nsteps``).
    :param prior: Prior dictionary passed to bilby.
    :param walks: Number of walks (also passed as ``nwalkers``; ``maxmcmc`` is ``10 * walks``).
    :param resume: Whether to resume from a checkpoint.
    :param save_format: Format bilby saves the result in; also the default format looked for
        when reusing a saved result.
    :param model_kwargs: Additional keyword arguments for the model.
    :param plot: If True, plot the fitted light curve.
    :param kwargs: Passed on to ``bilby.run_sampler``; ``clean``, ``extension`` and ``gzip`` also
        control reuse of a saved result.
    :return: RedbackResult
    """
    if likelihood is None:
        likelihood = PoissonLikelihood(
            time=transient.x, counts=transient.y, dt=transient.bin_size, function=model,
            integrated_rate_function=integrated_rate_function, kwargs=model_kwargs)

    meta_data = dict(model=getattr(model, "__name__", "custom_model"),
                     transient_type=transient.__class__.__name__.lower())
    meta_data.update({k.lstrip("_"): v for k, v in transient.__dict__.items()})
    model_kwargs = redback.utils.check_kwargs_validity(model_kwargs)
    meta_data['model_kwargs'] = model_kwargs

    if not kwargs.get("clean", False):
        # Look for the result in the format it is saved in, unless an extension is given explicitly.
        default_extension = save_format if isinstance(save_format, str) else "json"
        try:
            result = redback.result.read_in_result(
                outdir=outdir, label=label, extension=kwargs.get("extension", default_extension),
                gzip=kwargs.get("gzip", False))
        except Exception:
            # No readable previous result: run the sampler below.
            pass
        else:
            plt.close('all')
            return result

    result = bilby.run_sampler(
        likelihood=likelihood, priors=prior, label=label, sampler=sampler, nlive=nlive,
        outdir=outdir, plot=False, use_ratio=False, walks=walks, resume=resume,
        maxmcmc=10 * walks, result_class=RedbackResult, meta_data=meta_data,
        save_bounds=False, nsteps=nlive, nwalkers=walks, save=save_format, **kwargs)
    plt.close('all')
    if plot:
        result.plot_lightcurve(model=model)
    return result