Source code for redback.multimessenger

"""
Multi-messenger analysis framework for joint fitting of transient data across multiple messengers.

This module provides infrastructure for jointly analyzing transients observed through different messengers
(optical, X-ray, radio, gravitational waves, neutrinos, etc.) with shared physical parameters.
"""

import numpy as np
from typing import Dict, List, Union, Optional, Any
from pathlib import Path
import functools
import inspect
import bilby

import redback
from redback.likelihoods import (
    GaussianLikelihood, GaussianLikelihoodQuadratureNoise, GaussianLikelihoodUniformXErrors,
    GaussianLikelihoodWithUpperLimits, _RedbackParameterStore
)
from redback.model_library import all_models_dict
from redback.result import MultiMessengerResult
from redback.utils import logger
from redback.transient.transient import Transient


def _get_model_function(model: Union[str, callable]) -> callable:
    """
    Resolve a model specification into a callable.

    Parameters
    ----------
    model : str or callable
        Either a key of ``all_models_dict`` or an already-callable model.

    Returns
    -------
    callable
        The library entry for a string, or ``model`` itself when it is not a string.

    Raises
    ------
    ValueError
        If ``model`` is a string that is not a key of ``all_models_dict``.
    """
    # Anything that is not a name is handed back untouched.
    if not isinstance(model, str):
        return model
    if model in all_models_dict:
        return all_models_dict[model]
    raise ValueError(f"Model '{model}' not found in redback model library")


def _get_model_parameter_names(function: callable) -> List[str]:
    """Return the parameter names that bilby's introspection infers for ``function``."""
    infer_parameters = bilby.core.utils.introspection.infer_parameters_from_function
    return infer_parameters(func=function)


def _get_independent_variable_name(function: callable) -> str:
    """
    Return the name of the first parameter in ``function``'s signature.

    Falls back to ``"x"`` when the signature cannot be inspected or when the
    callable declares no parameters at all.
    """
    try:
        parameter_names = list(inspect.signature(function).parameters)
    except (TypeError, ValueError):
        # inspect could not produce a signature for this object.
        return "x"
    return parameter_names[0] if parameter_names else "x"


def _make_parameter_mapped_model(model_func: callable, parameter_mapping: Optional[Dict[str, str]] = None) -> callable:
    """
    Wrap a model so joint-analysis parameter names can differ from the model's native names.

    parameter_mapping maps joint parameter names to native model parameter names, e.g.
    {'viewing_angle': 'thv'} exposes viewing_angle to the sampler and passes it as thv.

    Parameters
    ----------
    model_func : callable
        Model to wrap. It is called as ``model_func(x, **native_parameters)``.
    parameter_mapping : dict, optional
        Mapping ``{joint_name: native_name}``. ``None`` or an empty dict leaves
        every parameter name unchanged.

    Returns
    -------
    callable
        Wrapper carrying an explicit ``__signature__`` made of the model's
        independent variable followed by the exposed (joint) parameter names.

    Raises
    ------
    ValueError
        If the mapping targets a native parameter the model does not have, or
        if applying the mapping would expose the same sampled name twice.
    """
    # Snapshot the mapping so later mutation by the caller cannot change the
    # behaviour of an already-built wrapper.
    parameter_mapping = dict(parameter_mapping or {})
    model_parameters = _get_model_parameter_names(model_func)
    unknown_native_parameters = sorted(set(parameter_mapping.values()) - set(model_parameters))
    if unknown_native_parameters:
        raise ValueError(
            "Parameter mapping refers to model parameter(s) not present in the model "
            f"signature: {unknown_native_parameters}"
        )
    native_to_joint = {native: joint for joint, native in parameter_mapping.items()}
    exposed_parameters = [native_to_joint.get(parameter, parameter) for parameter in model_parameters]

    # Single pass duplicate detection (the names are plain strings, so hashable).
    seen = set()
    duplicates = set()
    for parameter in exposed_parameters:
        if parameter in seen:
            duplicates.add(parameter)
        seen.add(parameter)
    if duplicates:
        raise ValueError(f"Parameter mapping creates duplicate sampled parameters: {sorted(duplicates)}")

    joint_names = frozenset(parameter_mapping)

    @functools.wraps(model_func)
    def mapped_model(x, **parameters):
        # Translate in two steps: first drop every joint name, then write the
        # native names. Doing "set native, pop joint" on one dict deletes the
        # value again whenever a joint name equals a native name (an identity
        # entry such as {'thv': 'thv'}, or a swap {'a': 'b', 'b': 'a'}).
        native_parameters = {
            name: value for name, value in parameters.items() if name not in joint_names
        }
        for joint_name, native_name in parameter_mapping.items():
            if joint_name in parameters:
                native_parameters[native_name] = parameters[joint_name]
        return model_func(x, **native_parameters)

    # Advertise the joint names so signature-based introspection sees the
    # sampled parameters rather than the generic (x, **parameters) wrapper.
    independent_variable = inspect.Parameter(
        _get_independent_variable_name(model_func), inspect.Parameter.POSITIONAL_OR_KEYWORD)
    signature_parameters = [independent_variable] + [
        inspect.Parameter(parameter, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for parameter in exposed_parameters
    ]
    mapped_model.__signature__ = inspect.Signature(parameters=signature_parameters)
    return mapped_model


def _get_transient_data_for_likelihood(transient: Transient) -> tuple:
    """
    Fetch ``(x, x_err, y, y_err)`` for building a likelihood.

    When the transient has an ``active_bands`` attribute that is not ``None``
    the result of ``transient.get_filtered_data()`` is returned. Otherwise the
    unfiltered arrays are used directly, so that flux-density transients that
    carry frequency arrays but no photometric band labels do not fail inside
    ``Transient.get_filtered_data()``.
    """
    band_filter_configured = getattr(transient, "active_bands", None) is not None
    if band_filter_configured:
        return transient.get_filtered_data()
    return transient.x, transient.x_err, transient.y, transient.y_err


def _get_transient_data_with_limits_for_likelihood(transient: Transient) -> tuple:
    """
    Fetch ``(x, x_err, y, y_err, detections)`` for building a likelihood.

    Counterpart of :meth:`Transient.get_filtered_data_with_limits` that keeps
    the no-active-band path: when ``active_bands`` is missing or ``None`` the
    transient's unfiltered arrays and ``detections`` are returned as they are.
    """
    band_filter_configured = getattr(transient, "active_bands", None) is not None
    if band_filter_configured:
        return transient.get_filtered_data_with_limits()
    return (
        transient.x,
        transient.x_err,
        transient.y,
        transient.y_err,
        transient.detections,
    )


def _get_filtered_upper_limit_sigma(transient: Transient) -> Union[float, np.ndarray]:
    """
    Return upper-limit sigma values matching the transient's likelihood data.

    Scalars are returned unchanged, and so is any array when no band filter is
    active. With an active filter, an array as long as ``transient.x`` is
    indexed with ``transient.filtered_indices``; an array with one entry per
    upper limit is reduced to the upper limits that survive the filter. Any
    other length is returned as-is.
    """
    sigma = transient.upper_limit_sigma
    if np.isscalar(sigma):
        return sigma

    sigma = np.asarray(sigma)
    if getattr(transient, "active_bands", None) is None:
        return sigma

    selection = transient.filtered_indices
    n_sigma = len(sigma)

    # One sigma per data point: apply the same selection as the data.
    if n_sigma == len(transient.x):
        return sigma[selection]

    one_per_upper_limit = (
        transient.detections is not None and n_sigma == np.sum(transient.upper_limits)
    )
    if not one_per_upper_limit:
        return sigma

    # cumsum - 1 gives, at each upper-limit data point, that point's position
    # inside the compact per-upper-limit sigma array.
    is_upper_limit = transient.upper_limits
    compact_position = np.cumsum(is_upper_limit) - 1
    selected_upper_limits = is_upper_limit[selection]
    return sigma[compact_position[selection][selected_upper_limits]]


def _get_upper_limit_data_mode(transient: Transient) -> str:
    """
    Return the data-mode label for a transient: "magnitude", "flux_density" or "flux".

    The transient's ``magnitude_data`` flag is checked first, then
    ``flux_density_data``; a missing or falsy flag counts as not set, and
    "flux" is the fallback.
    """
    flag_to_mode = (
        ("magnitude_data", "magnitude"),
        ("flux_density_data", "flux_density"),
    )
    for flag, mode in flag_to_mode:
        if getattr(transient, flag, False):
            return mode
    return "flux"


def _has_positive_x_errors(x_err: Optional[np.ndarray]) -> bool:
    """Tell whether ``x_err`` is given and contains at least one value greater than zero."""
    if x_err is None:
        return False
    return np.any(np.asarray(x_err) > 0)


def _get_x_error_bin_size(x_err: np.ndarray) -> np.ndarray:
    """
    Convert x errors into one bin size per data point.

    A 2-D input with exactly two rows is collapsed to the column-wise sum of
    absolute values; every other shape is returned as an array unchanged.
    """
    errors = np.asarray(x_err)
    two_rows = errors.ndim == 2 and errors.shape[0] == 2
    if not two_rows:
        return errors
    return np.abs(errors).sum(axis=0)


def _validate_shared_parameters(likelihoods: List[bilby.Likelihood], shared_params: Optional[List[str]]) -> None:
    """
    Check that every declared shared parameter really is shared.

    Nothing is checked when ``shared_params`` is empty or ``None``.

    Raises
    ------
    ValueError
        If a shared parameter appears in no likelihood, or, when more than one
        likelihood is given, appears in exactly one of them.
    """
    if not shared_params:
        return

    several_likelihoods = len(likelihoods) > 1
    missing_params = []
    singly_present_params = []
    for parameter in shared_params:
        holders = [
            likelihood for likelihood in likelihoods
            if parameter in _get_likelihood_parameters(likelihood)
        ]
        if not holders:
            missing_params.append(parameter)
        elif several_likelihoods and len(holders) == 1:
            singly_present_params.append(parameter)

    # Completely absent parameters are reported first.
    if missing_params:
        raise ValueError(
            "Shared parameter(s) are not present in any likelihood after applying "
            f"parameter mappings: {missing_params}"
        )
    if singly_present_params:
        raise ValueError(
            "Shared parameter(s) are present in only one likelihood and therefore are not "
            f"actually shared: {singly_present_params}. Use parameter_mappings or model wrappers "
            "so each relevant likelihood exposes the same sampled parameter name."
        )


def _log_likelihood_accepts_parameters(likelihood: bilby.Likelihood) -> bool:
    """
    Tell whether ``likelihood.log_likelihood`` declares a parameter named "parameters".

    Returns ``False`` when the signature cannot be inspected.
    """
    try:
        accepted_names = inspect.signature(likelihood.log_likelihood).parameters
    except (TypeError, ValueError):
        return False
    return "parameters" in accepted_names


def _get_likelihood_parameters(likelihood: bilby.Likelihood) -> Dict[str, Any]:
    """
    Return a likelihood's parameter dictionary without triggering Bilby state warnings.

    The private ``_parameters`` attribute is used when the likelihood has one;
    only otherwise is the public ``parameters`` attribute read.
    """
    not_set = object()
    stored = getattr(likelihood, "_parameters", not_set)
    if stored is not_set:
        return likelihood.parameters
    return stored


def _update_likelihood_parameters(likelihood: bilby.Likelihood, parameters: Dict[str, Any]) -> None:
    """
    Update child likelihood parameters without relying on deprecated Bilby state access.

    The update is applied in place to ``likelihood._parameters`` when that
    attribute exists, and to ``likelihood.parameters`` otherwise.
    """
    if hasattr(likelihood, "_parameters"):
        target = likelihood._parameters
    else:
        target = likelihood.parameters
    target.update(parameters)


class MultiMessengerLikelihood(_RedbackParameterStore, bilby.Likelihood):
    """
    A sampler-compatible likelihood product for redback and bilby likelihoods.

    The joint log-likelihood is the sum of the child log-likelihoods. The
    joint parameter dictionary is the union of the children's parameter
    dictionaries, so a parameter name that occurs in several children is a
    single shared entry.
    """

    def __init__(self, *likelihoods: bilby.Likelihood):
        """
        Parameters
        ----------
        *likelihoods : bilby.Likelihood
            Child likelihoods to combine; at least one is required.

        Raises
        ------
        ValueError
            If no likelihood is given.
        """
        if len(likelihoods) == 0:
            raise ValueError("At least one likelihood is required.")
        self.likelihoods = likelihoods
        # Parameter names each child had at construction time; used to hand a
        # child only the subset of joint parameters that belongs to it.
        self._parameter_names_by_likelihood = [
            (likelihood, set(_get_likelihood_parameters(likelihood)))
            for likelihood in likelihoods
        ]
        # Whether each child's log_likelihood takes a ``parameters`` argument.
        # This depends only on the child's signature, so it is inspected once
        # here instead of on every likelihood evaluation during sampling.
        self._accepts_parameters_by_likelihood = [
            _log_likelihood_accepts_parameters(likelihood) for likelihood in likelihoods
        ]
        parameters = {}
        for likelihood in likelihoods:
            parameters.update(_get_likelihood_parameters(likelihood))
        super().__init__(parameters=parameters)

    def _get_child_parameters(self, likelihood: bilby.Likelihood) -> Dict[str, Any]:
        """
        Return the current joint values of the parameters owned by ``likelihood``.

        Raises
        ------
        ValueError
            If ``likelihood`` is not one of the children (compared by identity).
        """
        child_parameter_names = None
        for child_likelihood, parameter_names in self._parameter_names_by_likelihood:
            if child_likelihood is likelihood:
                child_parameter_names = parameter_names
                break
        if child_parameter_names is None:
            raise ValueError("Likelihood is not part of this MultiMessengerLikelihood.")
        return {
            parameter: self.parameters[parameter]
            for parameter in child_parameter_names
            if parameter in self.parameters
        }

    def _update_parameters(self, parameters: Optional[Dict[str, Any]] = None) -> None:
        """Merge ``parameters`` into the joint parameter dictionary; ``None`` is a no-op."""
        if parameters is not None:
            self.parameters.update(parameters)

    def log_likelihood(self, parameters: Optional[Dict[str, Any]] = None) -> float:
        """
        Evaluate the joint log-likelihood.

        Parameters
        ----------
        parameters : dict, optional
            Values merged into the joint parameters before evaluation.

        Returns
        -------
        float
            Sum of the child log-likelihoods.
        """
        self._update_parameters(parameters=parameters)
        log_likelihood = 0.0
        for likelihood, accepts_parameters in zip(
                self.likelihoods, self._accepts_parameters_by_likelihood):
            child_parameters = self._get_child_parameters(likelihood=likelihood)
            # Push the values into the child's own state as well, so children
            # that read their stored parameters see the current sample.
            _update_likelihood_parameters(likelihood=likelihood, parameters=child_parameters)
            if accepts_parameters:
                log_likelihood += likelihood.log_likelihood(parameters=child_parameters)
            else:
                log_likelihood += likelihood.log_likelihood()
        return log_likelihood

    def noise_log_likelihood(self) -> float:
        """Return the sum of the children's noise log-likelihoods."""
        return sum(likelihood.noise_log_likelihood() for likelihood in self.likelihoods)
class MultiMessengerTransient:
    """
    Joint analysis of multiple messengers for transient events.

    This class enables multi-messenger analysis by combining data from different
    observational channels (electromagnetic, gravitational wave, neutrino) and
    performing joint parameter estimation with shared physical parameters.

    Examples
    --------
    Basic usage for a kilonova + GRB afterglow analysis:

    >>> import redback
    >>> mm_transient = MultiMessengerTransient(
    ...     optical_transient=kilonova_transient,
    ...     xray_transient=xray_transient,
    ...     radio_transient=radio_transient
    ... )
    >>> result = mm_transient.fit_joint(
    ...     models={'optical': 'two_component_kilonova_model',
    ...             'xray': 'tophat',
    ...             'radio': 'tophat'},
    ...     shared_params=['viewing_angle', 'luminosity_distance'],
    ...     model_kwargs={'optical': {'output_format': 'magnitude'},
    ...                   'xray': {'output_format': 'flux_density'},
    ...                   'radio': {'output_format': 'flux_density'}},
    ...     priors=priors
    ... )

    Advanced usage with custom likelihoods and GW data:

    >>> mm_transient = MultiMessengerTransient(
    ...     optical_transient=optical_lc,
    ...     gw_likelihood=gw_likelihood  # Pre-constructed bilby GW likelihood
    ... )
    >>> result = mm_transient.fit_joint(
    ...     models={'optical': 'two_component_kilonova_model'},
    ...     shared_params=['viewing_angle', 'luminosity_distance'],
    ...     priors=priors
    ... )
    """

    def __init__(
            self,
            optical_transient: Optional[Transient] = None,
            xray_transient: Optional[Transient] = None,
            radio_transient: Optional[Transient] = None,
            uv_transient: Optional[Transient] = None,
            infrared_transient: Optional[Transient] = None,
            gw_likelihood: Optional[bilby.Likelihood] = None,
            neutrino_likelihood: Optional[bilby.Likelihood] = None,
            custom_likelihoods: Optional[Dict[str, bilby.Likelihood]] = None,
            name: str = 'multimessenger_transient'
    ):
        """
        Initialize a MultiMessengerTransient object.

        Parameters
        ----------
        optical_transient : redback.transient.Transient, optional
            Optical/NIR data as a Redback transient object
        xray_transient : redback.transient.Transient, optional
            X-ray data as a Redback transient object
        radio_transient : redback.transient.Transient, optional
            Radio data as a Redback transient object
        uv_transient : redback.transient.Transient, optional
            UV data as a Redback transient object
        infrared_transient : redback.transient.Transient, optional
            Infrared data as a Redback transient object
        gw_likelihood : bilby.Likelihood, optional
            Pre-constructed gravitational wave likelihood (e.g., from bilby.gw)
        neutrino_likelihood : bilby.Likelihood, optional
            Pre-constructed neutrino likelihood
        custom_likelihoods : dict, optional
            Dictionary of custom likelihood objects with messenger names as keys
        name : str, optional
            Name for this multi-messenger transient (default: 'multimessenger_transient')
        """
        self.name = name

        # Store transient data objects, dropping the channels that were not supplied.
        candidates = {
            'optical': optical_transient,
            'xray': xray_transient,
            'radio': radio_transient,
            'uv': uv_transient,
            'infrared': infrared_transient
        }
        self.messengers = {k: v for k, v in candidates.items() if v is not None}

        # Store pre-constructed likelihoods (e.g., for GW or neutrinos)
        self.external_likelihoods = {}
        if gw_likelihood is not None:
            self.external_likelihoods['gw'] = gw_likelihood
        if neutrino_likelihood is not None:
            self.external_likelihoods['neutrino'] = neutrino_likelihood
        if custom_likelihoods is not None:
            self.external_likelihoods.update(custom_likelihoods)

        logger.info(f"Initialized MultiMessengerTransient '{name}' with {len(self.messengers)} "
                    f"transient data objects and {len(self.external_likelihoods)} external likelihoods")

    def _warn_about_unmatched_models(self, models: Dict[str, Union[str, callable]]) -> None:
        """Log a warning for ``models`` keys that match no registered transient.

        Such entries never produce a likelihood (for example a misspelled
        messenger name), so they are reported instead of being ignored silently.
        """
        unmatched = [messenger for messenger in models if messenger not in self.messengers]
        if unmatched:
            logger.warning(
                f"Model(s) given for {unmatched} do not correspond to any transient data "
                "object and will not be used to build a likelihood"
            )

    def _build_likelihood_for_messenger(
            self,
            messenger: str,
            transient: Transient,
            model: Union[str, callable],
            model_kwargs: Optional[Dict] = None,
            likelihood_type: str = 'GaussianLikelihood',
            parameter_mapping: Optional[Dict[str, str]] = None
    ) -> bilby.Likelihood:
        """
        Build a likelihood for a single messenger.

        Parameters
        ----------
        messenger : str
            Name of the messenger (e.g., 'optical', 'xray', 'radio')
        transient : redback.transient.Transient
            Transient data object
        model : str or callable
            Model name (string) or callable function
        model_kwargs : dict, optional
            Additional keyword arguments for the model
        likelihood_type : str, optional
            Type of likelihood to use (default: 'GaussianLikelihood')
            Options: 'GaussianLikelihood', 'GaussianLikelihoodWithUpperLimits',
            'GaussianLikelihoodQuadratureNoise'
        parameter_mapping : dict, optional
            Mapping from joint sampled parameter names to native model parameter
            names for this messenger. For example {'viewing_angle': 'thv'} shares
            the sampled viewing_angle parameter with a model that expects thv.

        Returns
        -------
        bilby.Likelihood
            Constructed likelihood object

        Raises
        ------
        ValueError
            If ``likelihood_type`` is unsupported, if it cannot be combined with
            upper limits, or if x/time errors are present together with upper
            limits or with 'GaussianLikelihoodQuadratureNoise'.
        """
        if model_kwargs is None:
            model_kwargs = {}
        model_func = _make_parameter_mapped_model(
            _get_model_function(model), parameter_mapping=parameter_mapping)

        # Get data from transient
        detections = None
        if getattr(transient, "has_upper_limits", False) is True:
            x, x_err, y, y_err, detections = _get_transient_data_with_limits_for_likelihood(transient)
        else:
            x, x_err, y, y_err = _get_transient_data_for_likelihood(transient)

        if detections is not None:
            # ``~`` is only a logical NOT on boolean arrays. On integer or object
            # flags it is a bitwise NOT, which would make every point look like
            # an upper limit and turn the masks below into integer indices.
            detections = np.asarray(detections, dtype=bool)

        # Select likelihood class
        has_upper_limits = detections is not None and np.any(~detections)
        if has_upper_limits:
            if likelihood_type not in ('GaussianLikelihood', 'GaussianLikelihoodWithUpperLimits'):
                raise ValueError(
                    f"{likelihood_type} does not support upper limits in MultiMessengerTransient. "
                    "Use GaussianLikelihood/GaussianLikelihoodWithUpperLimits or provide a custom likelihood."
                )
            if _has_positive_x_errors(x_err):
                raise ValueError(
                    "Upper-limit likelihoods with x/time errors are not supported in "
                    "MultiMessengerTransient. Provide a custom likelihood for this case."
                )
            n_upper_limits = int(np.sum(~detections))
            upper_limit_y = y[~detections]
            n_nan_upper_limits = int(np.sum(np.isnan(upper_limit_y)))
            if n_nan_upper_limits > 0:
                logger.warning(
                    f"{n_nan_upper_limits} upper limit(s) for {messenger} have NaN y-values and "
                    "cannot be used in GaussianLikelihoodWithUpperLimits. Falling back to a "
                    "GaussianLikelihood using detection data only."
                )
                # Keep detections only; the likelihood is built further below.
                detection_mask = detections
                likelihood_class = GaussianLikelihood
                x, y, y_err = x[detection_mask], y[detection_mask], y_err[detection_mask]
                if x_err is not None:
                    x_err = x_err[detection_mask]
            else:
                logger.info(
                    f"Building GaussianLikelihoodWithUpperLimits for {messenger} with "
                    f"{n_upper_limits} upper limits"
                )
                return GaussianLikelihoodWithUpperLimits(
                    x=x, y=y, sigma=y_err, function=model_func, kwargs=model_kwargs,
                    detections=detections,
                    upper_limit_sigma=_get_filtered_upper_limit_sigma(transient),
                    data_mode=_get_upper_limit_data_mode(transient)
                )
        elif likelihood_type in ('GaussianLikelihood', 'GaussianLikelihoodWithUpperLimits'):
            # Without upper limits both names resolve to a plain Gaussian
            # likelihood, or to its x-error variant when x errors are present.
            likelihood_class = GaussianLikelihoodUniformXErrors if _has_positive_x_errors(x_err) else GaussianLikelihood
        elif likelihood_type == 'GaussianLikelihoodQuadratureNoise':
            if _has_positive_x_errors(x_err):
                raise ValueError(
                    "GaussianLikelihoodQuadratureNoise does not support x/time errors in "
                    "MultiMessengerTransient. Use GaussianLikelihood or provide a custom likelihood."
                )
            likelihood_class = GaussianLikelihoodQuadratureNoise
        else:
            raise ValueError(f"Unsupported likelihood type: {likelihood_type}")

        # Construct likelihood
        if likelihood_class is GaussianLikelihoodUniformXErrors:
            logger.info(f"Building {likelihood_type} for {messenger} with time errors")
            likelihood = likelihood_class(
                x=x, y=y, sigma=y_err, bin_size=_get_x_error_bin_size(x_err),
                function=model_func, kwargs=model_kwargs
            )
        else:
            if likelihood_type == 'GaussianLikelihoodQuadratureNoise':
                # This class takes the per-point uncertainty as ``sigma_i``.
                likelihood = likelihood_class(
                    x=x, y=y, sigma_i=y_err, function=model_func, kwargs=model_kwargs
                )
            else:
                likelihood = likelihood_class(
                    x=x, y=y, sigma=y_err, function=model_func, kwargs=model_kwargs
                )

        logger.info(f"Built likelihood for {messenger} messenger with model {model_func.__name__}")
        return likelihood

    def fit_joint(
            self,
            models: Dict[str, Union[str, callable]],
            priors: Union[bilby.core.prior.PriorDict, dict],
            shared_params: Optional[List[str]] = None,
            model_kwargs: Optional[Dict[str, Dict]] = None,
            likelihood_types: Optional[Dict[str, str]] = None,
            parameter_mappings: Optional[Dict[str, Dict[str, str]]] = None,
            sampler: str = 'dynesty',
            nlive: int = 2000,
            walks: int = 200,
            outdir: Optional[str] = None,
            label: Optional[str] = None,
            resume: bool = True,
            plot: bool = True,
            save_format: str = 'json',
            **kwargs
    ) -> bilby.core.result.Result:
        """
        Perform joint multi-messenger analysis.

        This method builds individual likelihoods for each messenger, combines them
        into a joint likelihood, and runs parameter estimation with the specified sampler.

        Parameters
        ----------
        models : dict
            Dictionary mapping messenger names to model names/functions.
            Example: {'optical': 'two_component_kilonova_model', 'xray': 'tophat'}
            Keys that match no registered transient are reported with a warning
            and otherwise ignored.
        priors : bilby.core.prior.PriorDict or dict
            Prior distributions for all parameters. For shared parameters,
            the same prior will be used across all messengers.
        shared_params : list of str, optional
            List of parameter names that are shared across messengers.
            Example: ['viewing_angle', 'luminosity_distance', 'time_of_merger']
            If None, parameters are assumed independent unless they have the same name.
        model_kwargs : dict of dict, optional
            Dictionary mapping messenger names to their model keyword arguments.
            Example: {'optical': {'output_format': 'magnitude'},
                      'xray': {'output_format': 'flux_density', 'frequency': freq_array}}
        likelihood_types : dict of str, optional
            Dictionary mapping messenger names to likelihood types.
            Example: {'optical': 'GaussianLikelihood', 'xray': 'GaussianLikelihoodQuadratureNoise'}
            Default: 'GaussianLikelihood' for all messengers
        parameter_mappings : dict of dict, optional
            Dictionary mapping messenger names to parameter maps. Each map should
            map joint sampled parameter names to that messenger model's native
            parameter names. Example: {'xray': {'viewing_angle': 'thv'}}.
        sampler : str, optional
            Sampler to use (default: 'dynesty'). See bilby documentation for options.
        nlive : int, optional
            Number of live points for nested sampling (default: 2000)
        walks : int, optional
            Number of random walks for dynesty (default: 200)
        outdir : str, optional
            Output directory for results (default: './outdir_multimessenger')
        label : str, optional
            Label for output files (default: self.name)
        resume : bool, optional
            Whether to resume from checkpoint if available (default: True)
        plot : bool, optional
            Whether to create corner plots (default: True)
        save_format : str, optional
            Format for saving results (default: 'json')
        **kwargs
            Additional keyword arguments passed to bilby.run_sampler.
            ``maxmcmc`` defaults to ``10 * walks`` when not given.

        Returns
        -------
        bilby.core.result.Result
            Result object containing posterior samples and evidence

        Raises
        ------
        ValueError
            If no likelihood could be constructed, or if ``shared_params`` names a
            parameter that is absent from, or present in only one of, the likelihoods.

        Notes
        -----
        The joint likelihood is constructed as the product of individual messenger likelihoods:
        L_joint = L_optical × L_xray × L_radio × ...

        For shared parameters, the same parameter value is used across all relevant models,
        allowing the data from different messengers to jointly constrain these parameters.

        Examples
        --------
        >>> result = mm_transient.fit_joint(
        ...     models={'optical': 'two_component_kilonova_model',
        ...             'xray': 'tophat',
        ...             'radio': 'tophat'},
        ...     shared_params=['viewing_angle', 'luminosity_distance'],
        ...     priors=my_priors,
        ...     nlive=2000
        ... )
        """
        if model_kwargs is None:
            model_kwargs = {}
        if likelihood_types is None:
            likelihood_types = {}
        if parameter_mappings is None:
            parameter_mappings = {}

        # Set default output directory and label
        outdir = outdir or './outdir_multimessenger'
        label = label or self.name
        Path(outdir).mkdir(parents=True, exist_ok=True)

        self._warn_about_unmatched_models(models)

        # Build likelihoods for each messenger
        likelihoods = []

        # Build EM likelihoods from transient objects
        for messenger, transient in self.messengers.items():
            if messenger in models:
                model = models[messenger]
                mkwargs = model_kwargs.get(messenger, {})
                ltype = likelihood_types.get(messenger, 'GaussianLikelihood')
                parameter_mapping = parameter_mappings.get(messenger, {})
                likelihood = self._build_likelihood_for_messenger(
                    messenger, transient, model, mkwargs, ltype, parameter_mapping
                )
                likelihoods.append(likelihood)
            else:
                logger.warning(f"No model specified for messenger '{messenger}', skipping")

        # Add external likelihoods (GW, neutrino, etc.)
        for messenger, likelihood in self.external_likelihoods.items():
            logger.info(f"Adding external likelihood for {messenger}")
            likelihoods.append(likelihood)

        if len(likelihoods) == 0:
            raise ValueError("No likelihoods were constructed. Please provide models or external likelihoods.")

        # Construct joint likelihood
        if len(likelihoods) == 1:
            logger.warning("Only one likelihood present. Joint analysis reduces to single-messenger analysis.")
        else:
            logger.info(f"Combining {len(likelihoods)} likelihoods into joint likelihood")
        joint_likelihood = MultiMessengerLikelihood(*likelihoods)

        # Ensure priors is a PriorDict
        if not isinstance(priors, bilby.core.prior.PriorDict):
            priors = bilby.core.prior.PriorDict(priors)

        # Log and validate shared parameters
        if shared_params:
            logger.info(f"Shared parameters across messengers: {', '.join(shared_params)}")
            _validate_shared_parameters(likelihoods=likelihoods, shared_params=shared_params)

        # Prepare metadata
        meta_data = {
            'multimessenger': True,
            'messengers': list(self.messengers.keys()) + list(self.external_likelihoods.keys()),
            'models': {k: v if isinstance(v, str) else v.__name__ for k, v in models.items()},
            'shared_params': shared_params or [],
            'parameter_mappings': parameter_mappings,
            'name': self.name
        }

        # Run sampler
        logger.info(f"Starting joint analysis with {sampler} sampler")
        result = bilby.run_sampler(
            likelihood=joint_likelihood,
            priors=priors,
            sampler=sampler,
            nlive=nlive,
            walks=walks,
            outdir=outdir,
            label=label,
            resume=resume,
            use_ratio=False,
            maxmcmc=kwargs.pop('maxmcmc', 10 * walks),
            result_class=MultiMessengerResult,
            meta_data=meta_data,
            save=save_format,
            plot=plot,
            **kwargs
        )

        logger.info("Joint analysis complete")
        return result

    def fit_individual(
            self,
            models: Dict[str, Union[str, callable]],
            priors: Dict[str, Union[bilby.core.prior.PriorDict, dict]],
            model_kwargs: Optional[Dict[str, Dict]] = None,
            parameter_mappings: Optional[Dict[str, Dict[str, str]]] = None,
            sampler: str = 'dynesty',
            nlive: int = 2000,
            walks: int = 200,
            outdir: Optional[str] = None,
            resume: bool = True,
            plot: bool = True,
            **kwargs
    ) -> Dict[str, redback.result.RedbackResult]:
        """
        Fit each messenger independently (for comparison with joint analysis).

        Parameters
        ----------
        models : dict
            Dictionary mapping messenger names to model names/functions.
            Keys that match no registered transient are reported with a warning
            and otherwise ignored.
        priors : dict
            Dictionary mapping messenger names to their prior distributions
        model_kwargs : dict of dict, optional
            Dictionary mapping messenger names to their model keyword arguments
        parameter_mappings : dict of dict, optional
            Dictionary mapping messenger names to parameter maps. Each map should
            map sampled parameter names to that messenger model's native parameter
            names, matching :meth:`fit_joint`.
        sampler : str, optional
            Sampler to use (default: 'dynesty')
        nlive : int, optional
            Number of live points (default: 2000)
        walks : int, optional
            Number of random walks (default: 200)
        outdir : str, optional
            Output directory (default: './outdir_individual'); each messenger
            is written to its own sub-directory named after the messenger.
        resume : bool, optional
            Whether to resume from checkpoint (default: True)
        plot : bool, optional
            Whether to create plots (default: True)
        **kwargs
            Additional arguments forwarded to redback.fit_model

        Returns
        -------
        dict
            Dictionary mapping messenger names to their individual fit results.
            Messengers without a model or without a prior are skipped.

        Examples
        --------
        >>> individual_results = mm_transient.fit_individual(
        ...     models={'optical': 'two_component_kilonova_model', 'xray': 'tophat'},
        ...     priors={'optical': optical_priors, 'xray': xray_priors}
        ... )
        >>> optical_result = individual_results['optical']
        """
        if model_kwargs is None:
            model_kwargs = {}
        if parameter_mappings is None:
            parameter_mappings = {}

        outdir = outdir or './outdir_individual'
        Path(outdir).mkdir(parents=True, exist_ok=True)

        self._warn_about_unmatched_models(models)

        results = {}
        for messenger, transient in self.messengers.items():
            if messenger not in models:
                logger.warning(f"No model specified for messenger '{messenger}', skipping")
                continue
            if messenger not in priors:
                logger.warning(f"No prior specified for messenger '{messenger}', skipping")
                continue

            model = models[messenger]
            prior = priors[messenger]
            mkwargs = model_kwargs.get(messenger, {})
            parameter_mapping = parameter_mappings.get(messenger, {})
            # Only wrap the model when a non-empty mapping was supplied.
            if parameter_mapping:
                model = _make_parameter_mapped_model(
                    _get_model_function(model), parameter_mapping=parameter_mapping)

            logger.info(f"Fitting {messenger} messenger independently")
            messenger_outdir = f"{outdir}/{messenger}"
            result = redback.fit_model(
                transient=transient,
                model=model,
                prior=prior,
                model_kwargs=mkwargs,
                sampler=sampler,
                nlive=nlive,
                walks=walks,
                outdir=messenger_outdir,
                label=f"{self.name}_{messenger}",
                resume=resume,
                plot=plot,
                **kwargs
            )
            results[messenger] = result
            logger.info(f"Completed fit for {messenger}")

        return results

    def add_messenger(self, messenger_name: str,
                      transient: Optional[Transient] = None,
                      likelihood: Optional[bilby.Likelihood] = None):
        """
        Add a new messenger to the analysis.

        Parameters
        ----------
        messenger_name : str
            Name for the messenger
        transient : redback.transient.Transient, optional
            Transient data object
        likelihood : bilby.Likelihood, optional
            Pre-constructed likelihood object

        Raises
        ------
        ValueError
            If both or neither of ``transient`` and ``likelihood`` are given.

        Notes
        -----
        Either transient or likelihood must be provided, but not both.
        An existing entry with the same name in the same registry is replaced.
        """
        if transient is not None and likelihood is not None:
            raise ValueError("Provide either transient or likelihood, not both")
        if transient is None and likelihood is None:
            raise ValueError("Must provide either transient or likelihood")

        if transient is not None:
            self.messengers[messenger_name] = transient
            logger.info(f"Added transient data for {messenger_name}")
        else:
            self.external_likelihoods[messenger_name] = likelihood
            logger.info(f"Added external likelihood for {messenger_name}")

    def remove_messenger(self, messenger_name: str):
        """
        Remove a messenger from the analysis.

        The transient registry is checked first, then the external likelihoods;
        only the first match is removed. An unknown name logs a warning.

        Parameters
        ----------
        messenger_name : str
            Name of the messenger to remove
        """
        if messenger_name in self.messengers:
            del self.messengers[messenger_name]
            logger.info(f"Removed {messenger_name} from messengers")
        elif messenger_name in self.external_likelihoods:
            del self.external_likelihoods[messenger_name]
            logger.info(f"Removed {messenger_name} from external likelihoods")
        else:
            logger.warning(f"Messenger '{messenger_name}' not found")

    def __repr__(self):
        """Return a summary listing the name and the registered messenger names."""
        transient_messengers = list(self.messengers.keys())
        external_messengers = list(self.external_likelihoods.keys())
        return (f"MultiMessengerTransient(name='{self.name}', "
                f"transients={transient_messengers}, "
                f"external_likelihoods={external_messengers})")
def create_joint_prior(
        individual_priors: Dict[str, bilby.core.prior.PriorDict],
        shared_params: List[str],
        shared_param_priors: Optional[Dict[str, bilby.core.prior.Prior]] = None
) -> bilby.core.prior.PriorDict:
    """
    Merge per-messenger priors into one prior dictionary for a joint analysis.

    Shared parameters are added first, then every remaining parameter of every
    messenger is copied under its original name.

    Parameters
    ----------
    individual_priors : dict
        Dictionary mapping messenger names to their PriorDict objects
    shared_params : list of str
        List of parameter names that are shared across messengers
    shared_param_priors : dict, optional
        Dictionary of prior objects for shared parameters. For a shared
        parameter not listed here, the prior of the first messenger that
        defines it is used.

    Returns
    -------
    bilby.core.prior.PriorDict
        Combined prior dictionary for joint analysis

    Raises
    ------
    ValueError
        If a shared parameter has no prior anywhere, or if a parameter that is
        not marked shared is defined by more than one messenger.

    Examples
    --------
    >>> optical_priors = bilby.core.prior.PriorDict({
    ...     'viewing_angle': bilby.core.prior.Uniform(0, np.pi/2),
    ...     'kappa': bilby.core.prior.Uniform(0.1, 10)
    ... })
    >>> xray_priors = bilby.core.prior.PriorDict({
    ...     'viewing_angle': bilby.core.prior.Uniform(0, np.pi/2),
    ...     'log_n0': bilby.core.prior.Uniform(-5, 2)
    ... })
    >>> joint_priors = create_joint_prior(
    ...     {'optical': optical_priors, 'xray': xray_priors},
    ...     shared_params=['viewing_angle']
    ... )
    """
    joint_prior = bilby.core.prior.PriorDict()
    explicit_priors = shared_param_priors or {}

    # Shared parameters: an explicitly supplied prior wins, otherwise take the
    # first messenger that defines one.
    for param in shared_params:
        if param in explicit_priors:
            joint_prior[param] = explicit_priors[param]
            continue
        for messenger, messenger_priors in individual_priors.items():
            if param in messenger_priors:
                joint_prior[param] = messenger_priors[param]
                logger.info(f"Using {messenger} prior for shared parameter '{param}'")
                break
        else:
            raise ValueError(
                f"Shared parameter '{param}' is not present in any individual prior. "
                "Add it to at least one prior dictionary or pass shared_param_priors."
            )

    # Messenger-specific parameters keep their names so they match the
    # likelihood parameters built by MultiMessengerTransient.
    for messenger_priors in individual_priors.values():
        for param, prior in messenger_priors.items():
            if param in shared_params:
                # Already handled above.
                continue
            if param in joint_prior:
                raise ValueError(
                    f"Parameter '{param}' appears in multiple priors but is not marked shared. "
                    "Use distinct model parameter names, a custom model wrapper, or include it "
                    "in shared_params if it should be common."
                )
            joint_prior[param] = prior

    return joint_prior