# Source code for redback.sampler

import matplotlib.pyplot as plt
import os
from pathlib import Path
from typing import Union

import bilby

import redback.get_data
from redback.likelihoods import GaussianLikelihood, GaussianLikelihoodWithUpperLimits, PoissonLikelihood, PoissonSpectralLikelihood, \
    WStatSpectralLikelihood, ChiSquareSpectralLikelihood
from redback.model_library import all_models_dict
from redback.result import RedbackResult
from redback.utils import logger
from redback.transient.afterglow import Afterglow
from redback.transient.prompt import PromptTimeSeries
from redback.transient.transient import OpticalTransient, Transient, Spectrum
from redback.spectral.dataset import SpectralDataset
import numpy as np


# Directory that contains this module; fit_model looks for default prior files under "<dirname>/priors".
dirname = os.path.split(__file__)[0]


def fit_model(
        transient: redback.transient.transient.Transient, model: Union[callable, str], outdir: str = None,
        label: str = None, sampler: str = "dynesty", nlive: int = 2000, prior: dict = None, walks: int = 200,
        truncate: bool = True, use_photon_index_prior: bool = False, truncate_method: str = "prompt_time_error",
        resume: bool = True, save_format: str = "json", model_kwargs: dict = None, plot=True, **kwargs) \
        -> redback.result.RedbackResult:
    """
    Fit a model to a transient, dispatching to the fitting routine that matches the transient's type.

    :param transient: The transient to be fitted
    :param model: Name of the model to fit to data (a key of ``all_models_dict``) or a function.
    :param outdir: Output directory. Will default to a sensible structure if not given.
    :param label: Result file labels. Will use the transient name if not given.
    :param sampler: The sampling backend. Nested samplers are encouraged to allow evidence calculation.
                    (Default value = 'dynesty')
    :param nlive: Number of live points.
    :param prior: Priors to use during sampling. If not given, we use the default priors for the given model,
                  read from ``<module dir>/priors/<model name>.prior``.
    :param walks: Number of `dynesty` random walks.
    :param truncate: Flag to confirm whether to truncate the prompt emission data. Not forwarded to any of the
                     fitting routines in this module; retained for backward compatibility.
    :param use_photon_index_prior: flag to turn off/on photon index prior and fits according to the curvature
                                   effect. Only used for afterglows.
    :param truncate_method: method of truncation. Not forwarded to any of the fitting routines in this module;
                            retained for backward compatibility.
    :param resume: Whether to resume the run from a checkpoint if available.
    :param save_format: The format to save the result in. (Default value = 'json')
    :param model_kwargs: Additional keyword arguments for the model. For flux_density/magnitude/flux data this
                         should contain ``output_format`` matching ``transient.data_mode`` plus ``bands``
                         (magnitude/flux) or ``frequency`` (flux_density).
    :param plot: If True, create corner and lightcurve plot
    :param kwargs: Additional parameters that will be passed to the sampler via bilby. The fitting routines
                   also read ``clean`` (if True, rerun the fitting; if False, try to use previous results in
                   the output directory), ``extension`` and ``gzip`` (used when reading previous results).
    :return: The RedbackResult produced by the fitting routine for this transient type.
    :raises ValueError: If ``model_kwargs`` is inconsistent with the transient's data mode, or the transient
                        type is not supported.
    """
    if isinstance(model, str):
        modelname = model
        model = all_models_dict[model]
    else:
        modelname = getattr(model, "__name__", "custom_model")
    # Name used for the default output directory. Falls back to ``modelname`` for callables
    # (e.g. functools.partial objects) that do not define ``__name__``.
    model_dirname = getattr(model, "__name__", modelname)

    # getattr: spectral datasets are dispatched below and may not carry a photometric data mode.
    data_mode = getattr(transient, "data_mode", None)
    if data_mode in ["flux_density", "magnitude", "flux"]:
        if model_kwargs is None:
            logger.warning("No model_kwargs given, assuming model works correctly for this transient by default.")
        else:
            # .get so that missing keys produce the descriptive ValueErrors below instead of a KeyError.
            output_format = model_kwargs.get("output_format")
            if output_format != data_mode:
                raise ValueError(
                    f"Transient data mode {data_mode} is inconsistent with "
                    f"output format {output_format}. These should be the same.")
            if output_format in ['magnitude', 'flux'] and model_kwargs.get('bands') is None:
                raise ValueError("For magnitude or flux data, model_kwargs must specify the bands corresponding to the data.")
            if output_format == 'flux_density' and model_kwargs.get('frequency') is None:
                raise ValueError("For flux density data, model_kwargs must specify the frequency corresponding to the data.")

    if prior is None:
        prior = bilby.prior.PriorDict(filename=f"{dirname}/priors/{modelname}.prior")
        logger.warning(f"No prior given. Using default priors for {modelname}")

    if isinstance(transient, SpectralDataset):
        outdir = outdir or f"high_energy_spectra/{model_dirname}"
    else:
        outdir = outdir or f"{transient.directory_structure.directory_path}/{model_dirname}"
    Path(outdir).mkdir(parents=True, exist_ok=True)
    label = label or transient.name

    # Arguments shared by every fitting routine.
    common = dict(model=model, outdir=outdir, label=label, sampler=sampler, nlive=nlive, prior=prior,
                  walks=walks, resume=resume, save_format=save_format, model_kwargs=model_kwargs, plot=plot)

    if isinstance(transient, SpectralDataset):
        return _fit_spectral_dataset(transient=transient, **common, **kwargs)

    try:
        from redback.transient.spectral import CountsSpectrumTransient
    except Exception:
        # Optional module: keep dispatch working for every other transient type if it cannot be imported.
        CountsSpectrumTransient = None
    if CountsSpectrumTransient is not None and isinstance(transient, CountsSpectrumTransient):
        return _fit_spectral_dataset(transient=transient.dataset, **common, **kwargs)

    if isinstance(transient, Spectrum):
        return _fit_spectrum(transient=transient, **common, **kwargs)
    if isinstance(transient, Afterglow):
        return _fit_grb(transient=transient, use_photon_index_prior=use_photon_index_prior, **common, **kwargs)
    if isinstance(transient, PromptTimeSeries):
        return _fit_prompt(transient=transient, **common, **kwargs)
    if isinstance(transient, (OpticalTransient, Transient)):
        return _fit_optical_transient(transient=transient, **common, **kwargs)
    raise ValueError(f'Source type {transient.__class__.__name__} not known')
def _model_name(model):
    """Return ``model.__name__``, or ``"custom_model"`` for callables without one (e.g. functools.partial)."""
    return getattr(model, "__name__", "custom_model")


def _build_meta_data(transient, model, model_kwargs, exclude=()):
    """
    Assemble the meta-data stored alongside a sampler result.

    :param transient: The transient (or dataset) being fitted. Its ``__dict__`` is copied into the
                      meta-data with leading underscores stripped from the attribute names.
    :param model: The model function being fitted.
    :param model_kwargs: Keyword arguments for the model, validated via ``redback.utils.check_kwargs_validity``.
    :param exclude: Attribute names, as stored in ``__dict__`` (before stripping), to leave out.
    :return: Tuple of (meta_data dict, validated model_kwargs).
    """
    meta_data = dict(model=_model_name(model), transient_type=transient.__class__.__name__.lower())
    meta_data.update({key.lstrip("_"): value for key, value in transient.__dict__.items()
                      if key not in exclude})
    model_kwargs = redback.utils.check_kwargs_validity(model_kwargs)
    meta_data['model_kwargs'] = model_kwargs
    return meta_data, model_kwargs


def _load_existing_result(outdir, label, extensions, gzip=False):
    """
    Try to read a previously saved result from ``outdir``, trying each extension in order.

    Read failures are treated as "no usable result": a missing or unreadable file simply means
    the sampler has to be run.

    :return: The first result that loads, or None if none could be read.
    """
    for extension in extensions:
        try:
            result = redback.result.read_in_result(outdir=outdir, label=label, extension=extension, gzip=gzip)
        except Exception:
            continue
        plt.close('all')
        return result
    return None


def _run_sampler(likelihood, prior, label, sampler, nlive, outdir, plot, walks, resume, meta_data, save_format,
                 fallback_to_dynesty=False, **kwargs):
    """
    Call ``bilby.run_sampler`` with redback's standard sampler settings.

    :param fallback_to_dynesty: If True and ``sampler`` is pymultinest, retry once with dynesty when
                                pymultinest raises a ValueError mentioning both ``dead_points`` and
                                ``live_points`` (its failure to assemble nested samples).
    :param kwargs: Passed straight through to ``bilby.run_sampler``.
    :return: The RedbackResult produced by bilby.
    """
    run_kwargs = dict(
        likelihood=likelihood, priors=prior, label=label, nlive=nlive, outdir=outdir, plot=plot,
        use_ratio=False, walks=walks, resume=resume, maxmcmc=10 * walks, result_class=RedbackResult,
        meta_data=meta_data, save_bounds=False, nsteps=nlive, nwalkers=walks, save=save_format, **kwargs)
    try:
        return bilby.run_sampler(sampler=sampler, **run_kwargs)
    except ValueError as exc:
        message = str(exc)
        is_multinest_assembly_failure = (
            isinstance(sampler, str) and sampler.lower() == "pymultinest"
            and "dead_points" in message and "live_points" in message)
        if not (fallback_to_dynesty and is_multinest_assembly_failure):
            raise
        logger.warning("Pymultinest failed to assemble nested samples (%s). Rerunning with dynesty.", exc)
    return bilby.run_sampler(sampler="dynesty", **run_kwargs)


def _fit_spectrum(transient, model, outdir, label, likelihood=None, sampler='dynesty', nlive=3000, prior=None,
                  walks=1000, resume=True, save_format='json', model_kwargs=None, plot=True, **kwargs):
    """
    Fit a flux-density spectrum (``transient.angstroms`` vs ``transient.flux_density``).

    Uses a GaussianLikelihood unless ``likelihood`` is given. Unless ``kwargs['clean']`` is truthy, a previously
    saved result in ``outdir`` is returned instead of sampling. A pymultinest nested-sample assembly failure is
    retried with dynesty.

    :return: The RedbackResult.
    """
    if likelihood is None:
        likelihood = GaussianLikelihood(x=transient.angstroms, y=transient.flux_density,
                                        sigma=transient.flux_density_err, function=model, kwargs=model_kwargs)
        logger.info("No likelihood provided, using standard GaussianLikelihood")
    else:
        logger.info("Likelihood provided, using custom likelihood {}".format(likelihood.__class__.__name__))

    # NOTE(review): the likelihood above receives the raw model_kwargs; only the meta-data gets the validated
    # version. Confirm this is intended.
    meta_data, model_kwargs = _build_meta_data(transient, model, model_kwargs, exclude=("rmf", "arf"))

    if not kwargs.get("clean", False):
        result = _load_existing_result(outdir, label, [kwargs.get("extension", "json")], kwargs.get("gzip", False))
        if result is not None:
            return result

    result = _run_sampler(likelihood=likelihood, prior=prior, label=label, sampler=sampler, nlive=nlive,
                          outdir=outdir, plot=plot, walks=walks, resume=resume, meta_data=meta_data,
                          save_format=save_format, fallback_to_dynesty=True, **kwargs)
    plt.close('all')
    if plot:
        result.plot_spectrum(model=model)
    return result


def _check_spectral_model_signature(model):
    """Raise ValueError unless ``model`` accepts ``energies_keV``/``energy_keV`` or arbitrary keyword arguments."""
    import inspect
    try:
        parameters = inspect.signature(model).parameters
    except (TypeError, ValueError):
        # The signature cannot be inspected (e.g. some builtins); assume the model is compatible.
        return
    accepts_energies = "energies_keV" in parameters or "energy_keV" in parameters
    accepts_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
    if not (accepts_energies or accepts_var_keyword):
        raise ValueError(
            "Spectral models must accept `energies_keV` (or `energy_keV`) as the first argument. "
            "Update the model signature to be compatible with spectral fitting."
        )


def _spectral_preflight(likelihood, prior, n_samples=5):
    """
    Best-effort sanity check: evaluate the likelihood at a few prior draws and log how many are finite.

    Never raises once sampling from the prior has started; problems are reported as warnings so that the
    actual sampling run still goes ahead.
    """
    likelihood.parameters = dict.fromkeys(prior.keys())
    try:
        samples = [prior.sample() for _ in range(n_samples)]
        n_finite = 0
        for sample in samples:
            likelihood.parameters.update(sample)
            if np.isfinite(likelihood.log_likelihood()):
                n_finite += 1
    except Exception as exc:
        logger.warning("Spectral preflight failed: %s", exc)
        return
    logger.info("Spectral preflight: %d/%d finite logL samples", n_finite, len(samples))
    if n_finite == 0:
        logger.warning("Spectral preflight failed: %s",
                       "Spectral likelihood preflight failed: all sampled logL are non-finite")


def _fit_spectral_dataset(transient, model, outdir, label, likelihood=None, sampler='dynesty', nlive=3000,
                          prior=None, walks=1000, resume=True, save_format='json', model_kwargs=None, plot=True,
                          **kwargs):
    """
    Fit a counts spectrum held in a spectral dataset.

    ``kwargs['statistic']`` (popped, so it is not passed to bilby) selects the likelihood: 'wstat'/'w-stat',
    'cstat'/'c-stat'/'cash' or 'chi2'/'chi-square'/'chisq'. When it is None or 'auto', 'wstat' is used if
    ``transient.counts_bkg`` is not None, otherwise 'cstat'. JSON saving is replaced by pkl.

    :return: The RedbackResult.
    :raises ValueError: If the model signature is incompatible or the statistic is unknown.
    """
    _check_spectral_model_signature(model)

    statistic = kwargs.pop("statistic", None)
    if statistic is None or str(statistic).lower() == "auto":
        statistic = "wstat" if transient.counts_bkg is not None else "cstat"
    logger.info("Spectral fit using statistic=%s", statistic)

    if likelihood is None:
        statistic_name = str(statistic).lower()
        if statistic_name in ("wstat", "w-stat"):
            likelihood_class = WStatSpectralLikelihood
        elif statistic_name in ("cstat", "c-stat", "cash"):
            likelihood_class = PoissonSpectralLikelihood
        elif statistic_name in ("chi2", "chi-square", "chisq"):
            likelihood_class = ChiSquareSpectralLikelihood
        else:
            raise ValueError(f"Unknown statistic '{statistic}' for spectral fitting")
        likelihood = likelihood_class(dataset=transient, function=model, kwargs=model_kwargs)
        logger.info("No likelihood provided, using spectral likelihood %s", likelihood.__class__.__name__)
    else:
        logger.info("Likelihood provided, using custom likelihood {}".format(likelihood.__class__.__name__))

    meta_data, model_kwargs = _build_meta_data(transient, model, model_kwargs)

    # Spectral datasets contain numpy arrays in response objects that cannot JSON-serialise
    if save_format == "json":
        logger.warning("JSON save not supported for spectral datasets with response objects. Using pkl instead.")
        save_format = "pkl"

    if not kwargs.get("clean", False):
        result = _load_existing_result(outdir, label, [kwargs.get("extension", save_format), "pkl", "json"],
                                       kwargs.get("gzip", False))
        if result is not None:
            return result

    if prior is not None:
        _spectral_preflight(likelihood, prior)

    result = _run_sampler(likelihood=likelihood, prior=prior, label=label, sampler=sampler, nlive=nlive,
                          outdir=outdir, plot=plot, walks=walks, resume=resume, meta_data=meta_data,
                          save_format=save_format, fallback_to_dynesty=True, **kwargs)
    plt.close('all')
    if plot:
        filename = f"{label}_spectrum_counts.png"
        transient.plot_spectrum_fit(model=model, posterior=result.posterior, model_kwargs=model_kwargs,
                                    filename=filename, outdir=outdir, show=False, save=True)
    return result


def _fit_grb(transient, model, outdir, label, likelihood=None, sampler='dynesty', nlive=3000, prior=None,
             walks=1000, use_photon_index_prior=False, resume=True, save_format='json', model_kwargs=None, plot=True,
             **kwargs):
    """
    Fit a GRB afterglow light curve with a GaussianLikelihood unless ``likelihood`` is given.

    When ``use_photon_index_prior`` is True, ``'_photon_index'`` is appended to the label and ``prior['alpha_1']``
    is replaced in place (so the caller's prior object is modified): a Uniform(-10, -0.5) prior if
    ``transient.photon_index`` is negative, otherwise a Gaussian centred on ``-(photon_index + 1)`` with sigma 0.1.

    :return: The RedbackResult.
    """
    if use_photon_index_prior:
        label += '_photon_index'
        if transient.photon_index < 0.:
            logger.info("photon index for GRB %s is negative. Using default prior on alpha_1", transient.name)
            prior['alpha_1'] = bilby.prior.Uniform(-10, -0.5, 'alpha_1', latex_label=r'$\alpha_{1}$')
        else:
            prior['alpha_1'] = bilby.prior.Gaussian(mu=-(transient.photon_index + 1), sigma=0.1, name='alpha_1',
                                                    latex_label=r'$\alpha_{1}$')

    if any([transient.flux_data, transient.magnitude_data, transient.flux_density_data]):
        x, x_err, y, y_err = transient.get_filtered_data()
    else:
        x, x_err, y, y_err = transient.x, transient.x_err, transient.y, transient.y_err

    if likelihood is None:
        likelihood = GaussianLikelihood(x=x, y=y, sigma=y_err, function=model, kwargs=model_kwargs)
        logger.info("No likelihood provided, using standard GaussianLikelihood")
    else:
        logger.info("Likelihood provided, using custom likelihood {}".format(likelihood.__class__.__name__))

    meta_data, model_kwargs = _build_meta_data(transient, model, model_kwargs)

    if not kwargs.get("clean", False):
        result = _load_existing_result(outdir, label, [kwargs.get("extension", "json")], kwargs.get("gzip", False))
        if result is not None:
            return result

    result = _run_sampler(likelihood=likelihood, prior=prior, label=label, sampler=sampler, nlive=nlive,
                          outdir=outdir, plot=plot, walks=walks, resume=resume, meta_data=meta_data,
                          save_format=save_format, **kwargs)
    plt.close('all')
    if plot:
        result.plot_lightcurve(model=model)
    return result


def _get_filtered_upper_limit_sigma(transient):
    """
    Return upper-limit sigma values matching the transient's active-band filtered data.

    A scalar ``upper_limit_sigma`` is returned unchanged. An array covering every data point is indexed with
    ``transient.filtered_indices``. An array with one entry per upper limit (``transient.upper_limits`` is
    truthy) is mapped onto the filtered upper limits. Any other length is returned as-is.
    """
    upper_limit_sigma = transient.upper_limit_sigma
    if np.isscalar(upper_limit_sigma):
        return upper_limit_sigma

    upper_limit_sigma = np.asarray(upper_limit_sigma)
    filtered_indices = transient.filtered_indices
    if len(upper_limit_sigma) == len(transient.x):
        return upper_limit_sigma[filtered_indices]

    if transient.detections is not None and len(upper_limit_sigma) == np.sum(transient.upper_limits):
        # Position of each point within the upper-limit-only array (only meaningful where upper_limits is True).
        upper_limit_positions = np.cumsum(transient.upper_limits) - 1
        filtered_upper_limits = transient.upper_limits[filtered_indices]
        return upper_limit_sigma[upper_limit_positions[filtered_indices][filtered_upper_limits]]

    return upper_limit_sigma


def _fit_optical_transient(transient, model, outdir, label, likelihood=None, sampler='dynesty', nlive=3000,
                           prior=None, walks=1000, resume=True, save_format='json', model_kwargs=None, plot=True,
                           **kwargs):
    """
    Fit a photometric light curve.

    Without an explicit ``likelihood``: if the transient has upper limits and the filtered data reports
    detections, GaussianLikelihoodWithUpperLimits is used. If any upper limit has a NaN y-value, it falls back
    to a GaussianLikelihood on the detections only. Otherwise a GaussianLikelihood on all data is used.

    :return: The RedbackResult.
    """
    if any([transient.flux_data, transient.magnitude_data, transient.flux_density_data]):
        x, x_err, y, y_err, detections = transient.get_filtered_data_with_limits()
    else:
        x, x_err, y, y_err = transient.x, transient.x_err, transient.y, transient.y_err
        detections = None

    if likelihood is not None:
        logger.info("Likelihood provided, using custom likelihood {}".format(likelihood.__class__.__name__))
    elif transient.has_upper_limits and detections is not None:
        ul_data_mode = 'magnitude' if transient.magnitude_data else 'flux'
        # Boolean view used for masking: ``~`` on an integer 0/1 array would be a bitwise, not logical, NOT.
        is_detection = np.asarray(detections, dtype=bool)
        n_ul = int(np.sum(~is_detection))
        # Check that upper limit y-values are finite
        n_nan_ul = int(np.sum(np.isnan(y[~is_detection])))
        if n_nan_ul > 0:
            logger.warning(
                f"{n_nan_ul} upper limit(s) have NaN y-values, which cannot be used in "
                f"GaussianLikelihoodWithUpperLimits. Falling back to standard GaussianLikelihood "
                f"with detection data only. Replace NaN values with the upper limit value "
                f"(e.g. limiting magnitude or flux) to use upper limit likelihood."
            )
            likelihood = GaussianLikelihood(x=x[is_detection], y=y[is_detection], sigma=y_err[is_detection],
                                            function=model, kwargs=model_kwargs)
        else:
            logger.info(f"Auto-detected {n_ul} upper limits, using GaussianLikelihoodWithUpperLimits "
                        f"with data_mode='{ul_data_mode}'")
            likelihood = GaussianLikelihoodWithUpperLimits(
                x=x, y=y, sigma=y_err, function=model, kwargs=model_kwargs,
                detections=detections, upper_limit_sigma=_get_filtered_upper_limit_sigma(transient),
                data_mode=ul_data_mode)
    else:
        likelihood = GaussianLikelihood(x=x, y=y, sigma=y_err, function=model, kwargs=model_kwargs)
        logger.info("No likelihood provided, using standard GaussianLikelihood")

    meta_data, model_kwargs = _build_meta_data(transient, model, model_kwargs)

    if not kwargs.get("clean", False):
        result = _load_existing_result(outdir, label, [kwargs.get("extension", "json")], kwargs.get("gzip", False))
        if result is not None:
            return result

    result = _run_sampler(likelihood=likelihood, prior=prior, label=label, sampler=sampler, nlive=nlive,
                          outdir=outdir, plot=plot, walks=walks, resume=resume, meta_data=meta_data,
                          save_format=save_format, **kwargs)
    plt.close('all')
    if plot:
        result.plot_lightcurve(model=model)
    return result


def _fit_prompt(transient, model, outdir, label, likelihood=None, integrated_rate_function=True,
                sampler='dynesty', nlive=3000, prior=None, walks=1000, resume=True, save_format='json',
                model_kwargs=None, plot=True, **kwargs):
    """
    Fit a prompt-emission counts time series with a PoissonLikelihood unless ``likelihood`` is given.

    bilby's own plotting is disabled (``plot=False``); only the redback light-curve plot is produced when
    ``plot`` is True.

    :return: The RedbackResult.
    """
    if likelihood is None:
        likelihood = PoissonLikelihood(
            time=transient.x, counts=transient.y, dt=transient.bin_size, function=model,
            integrated_rate_function=integrated_rate_function, kwargs=model_kwargs)

    meta_data, model_kwargs = _build_meta_data(transient, model, model_kwargs)

    if not kwargs.get("clean", False):
        result = _load_existing_result(outdir, label, [kwargs.get("extension", "json")], kwargs.get("gzip", False))
        if result is not None:
            return result

    result = _run_sampler(likelihood=likelihood, prior=prior, label=label, sampler=sampler, nlive=nlive,
                          outdir=outdir, plot=False, walks=walks, resume=resume, meta_data=meta_data,
                          save_format=save_format, **kwargs)
    plt.close('all')
    if plot:
        result.plot_lightcurve(model=model)
    return result