# Source code for redback.transient_models.learned_models

"""This file holds the functions to call the LearnedSurrogateModel models.

LearnedSurrogateModel instances are models from the redback_surrogates package that
have been trained to emulate more complex transient models and saved in ONNX format.

The workflow for these models is to load the model from an ONNX file using
the LearnedSurrogateModel.from_onnx_file() method, then use the
make_learned_model_callable() function to create a callable function that can be used
to evaluate the model given time and parameters.
"""
import astropy.units as uu
import numpy as np
import re

from astropy.cosmology import Planck18 as cosmo  # noqa
from collections import namedtuple
from scipy.interpolate import RegularGridInterpolator

import redback.sed as sed
from redback.utils import calc_kcorrected_properties, lambda_to_nu


def make_learned_model_callable(model):
    """
    This function takes in a LearnedSurrogateModel instance and returns a callable function
    that can be used to evaluate the model given time and parameters. The function's
    signature will match the expected format for redback with time as the first argument,
    followed by each of the model parameters and then any additional keyword arguments.

    :param model: LearnedSurrogateModel instance. Only ``model.param_names`` is read here;
        the model itself is bound to the returned function as a keyword-only default.
    :return: Callable function to evaluate the model, with signature
        ``(time, <param_names...>, *, model=model, **kwargs)``.
    :raises ValueError: If a parameter name is not a plain ASCII Python identifier, is a
        Python keyword, is duplicated, clashes with a name used by the generated function,
        or if 'redshift' is not among the parameter names.
    """
    import keyword

    # Materialise once: the names are iterated several times below.
    param_names = list(model.param_names)

    # Names the generated function already uses. A model parameter with one of these names
    # would be a duplicate argument (SyntaxError at exec time) or would shadow the helper
    # the generated body calls.
    reserved_names = {"time", "model", "kwargs", "_eval_learned_surrogate"}

    # Make sure all the model's parameter names are safe to use as function arguments.
    # fullmatch (not match + '$') so that a trailing newline cannot slip through.
    identifier_re = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
    seen = set()
    for name in param_names:
        if not isinstance(name, str) or not identifier_re.fullmatch(name):
            raise ValueError(
                f"Parameter name '{name}' is invalid. Parameter names can "
                "only contain alphanumeric characters and underscores."
            )
        if keyword.iskeyword(name):
            raise ValueError(f"Parameter name '{name}' is a reserved Python keyword.")
        if name in reserved_names:
            raise ValueError(
                f"Parameter name '{name}' clashes with a name used by the generated function."
            )
        if name in seen:
            raise ValueError(f"Parameter name '{name}' is duplicated.")
        seen.add(name)

    # Redshift is a required parameter for cosmological calculations.
    if "redshift" not in param_names:
        raise ValueError("Model parameter names must include 'redshift'.")

    # Build the complete function source. This is safe only because every name was
    # validated above as a non-keyword, non-reserved, unique ASCII identifier.
    param_str = ", ".join(param_names)
    param_dict_str = "{" + ", ".join(f"'{name}': {name}" for name in param_names) + "}"
    function_code = (
        f"def _dynamic_predict_grid(time, {param_str}, *, model=model, **kwargs):\n"
        f"    param_dict = {param_dict_str}\n"
        f"    return _eval_learned_surrogate(model, time, param_dict, **kwargs)\n"
    )

    # The default ``model=model`` is evaluated in local_namespace, so the model is bound to
    # the returned function and callers do not have to pass it in. Module globals are used so
    # the generated body can resolve _eval_learned_surrogate.
    local_namespace = {"model": model}
    exec(function_code, globals(), local_namespace)
    return local_namespace["_dynamic_predict_grid"]
def _eval_learned_surrogate(model, time, params, **kwargs):
    """
    This is a common evaluation function for LearnedSurrogateModel models that can be called
    from each model's wrapper function.

    :param model: LearnedSurrogateModel instance. Must provide ``predict_spectra_grid(**params)``,
        ``wavelengths`` (rest frame, Angstrom) and ``times`` (rest frame, days).
    :param time: Time in days in observer frame
    :param params: Dictionary of model parameters. Must include 'redshift' key.
    :param kwargs: Additional parameters for the model, such as:
    :param frequency: Required if output_format is 'flux_density'. Frequency to calculate;
        a scalar or an array broadcastable against ``time``.
    :param bands: Required if output_format is 'magnitude' or 'flux'.
    :param output_format: 'flux_density', 'magnitude', 'spectra', 'flux', 'sncosmo_source'
    :param lambda_array: Optional observer-frame wavelength array (in Angstroms) to evaluate
        the SED on. Defaults to 200 log-spaced points spanning the model's wavelength range
        (redshifted to the observer frame). Wavelengths outside the model range evaluate to 0.
    :param time_spline_degree: Optional spline degree forwarded to
        ``sed.get_correct_output_format_from_spectra``. Defaults to 1.
    :param cosmology: Cosmology to use for luminosity distance calculation. Defaults to
        Planck18. Must be an astropy.cosmology object.
    :return: set by output format - 'flux_density' (mJy, observer frame), 'magnitude',
        'spectra' (namedtuple with time, lambdas, spectra), 'flux', 'sncosmo_source'
    :raises ValueError: If the redshift is missing or not strictly positive.
    :raises KeyError: If output_format is 'flux_density' and 'frequency' is not given.
    """
    cosmology = kwargs.get('cosmology', cosmo)
    redshift = params.get('redshift', 0.0)
    if redshift <= 0.0:
        raise ValueError(f"Redshift must be positive and non-zero. Got {redshift}.")
    dl = cosmology.luminosity_distance(redshift).cgs

    # Rest-frame spectral luminosity density L_nu on the model's (time, wavelength) grid.
    # A bare array returned by the model is interpreted as erg/s/Hz.
    luminosity_density = model.predict_spectra_grid(**params)
    if not hasattr(luminosity_density, 'unit'):
        luminosity_density = luminosity_density * uu.erg / uu.s / uu.Hz
    lambda_rest = model.wavelengths  # Angstrom in rest frame
    time_rest = model.times  # days in rest frame

    # Apply cosmological dimming: L_nu / (4*pi*d_L^2), in erg/s/Hz/cm^2 for the default
    # units. This still needs the (1+z) correction, which each output branch applies.
    flux_density = luminosity_density / (4 * np.pi * dl ** 2)

    if kwargs.get('output_format') == 'flux_density':
        # Use redback's K-correction utilities to move the requested observer-frame
        # frequencies and times onto the model's rest frame.
        frequency, time = calc_kcorrected_properties(
            frequency=kwargs['frequency'], time=time, redshift=redshift
        )

        # Rest-frame frequency axis of the model grid. An ascending wavelength axis maps to a
        # descending frequency axis, so sort it (and the flux columns with it) to hand
        # RegularGridInterpolator a strictly ascending grid.
        nu_rest = np.asarray(lambda_to_nu(lambda_rest))
        order = np.argsort(nu_rest)

        # Convert via astropy so the interpolated values are plain floats in mJy.
        fmjy = np.asarray(flux_density.to(uu.mJy).value)

        flux_interpolator = RegularGridInterpolator(
            (time_rest, nu_rest[order]), fmjy[:, order], bounds_error=False, fill_value=0.0
        )

        # Broadcast so any mix of scalar / array time and frequency is accepted.
        time, frequency = np.broadcast_arrays(np.atleast_1d(time), np.atleast_1d(frequency))
        points = np.column_stack((time.ravel(), frequency.ravel()))

        # Return interpolated flux density with (1+z) correction for observer frame.
        return flux_interpolator(points).reshape(time.shape) * (1 + redshift)

    # Wavelength grid: the caller's observer-frame lambda_array if given, otherwise a denser
    # log-spaced grid spanning the model's rest-frame wavelength range.
    lambda_array = kwargs.get('lambda_array')
    if lambda_array is not None:
        lambda_observer_frame = np.asarray(lambda_array, dtype=float)
        lambda_rest_dense = lambda_observer_frame / (1 + redshift)
    else:
        lambda_rest_dense = np.geomspace(np.min(lambda_rest), np.max(lambda_rest), 200)
        lambda_observer_frame = lambda_rest_dense * (1 + redshift)
    time_rest_dense = np.geomspace(np.min(time_rest), np.max(time_rest), 200)
    time_observer_frame = time_rest_dense * (1 + redshift)

    # Interpolate the rest-frame flux density onto the (dense time, wavelength) grid.
    flux_interpolator = RegularGridInterpolator(
        (time_rest, lambda_rest), flux_density.value, bounds_error=False, fill_value=0.0
    )
    tt_mesh, ll_mesh = np.meshgrid(time_rest_dense, lambda_rest_dense, indexing='ij')
    points_to_evaluate = np.column_stack((tt_mesh.ravel(), ll_mesh.ravel()))
    interpolated_flux = (
        flux_interpolator(points_to_evaluate).reshape(tt_mesh.shape) * flux_density.unit
    )

    # Move to the observer frame and switch to spectra units.
    spectra = sed.flux_density_to_spectrum(interpolated_flux, redshift, lambda_observer_frame)

    if kwargs.get('output_format') == 'spectra':
        return namedtuple('output', ['time', 'lambdas', 'spectra'])(
            time=time_observer_frame, lambdas=lambda_observer_frame, spectra=spectra
        )

    # lambda_array and time_spline_degree are passed explicitly below; forwarding them again
    # through **kwargs would raise "got multiple values for keyword argument".
    forwarded_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key not in ('lambda_array', 'time_spline_degree')
    }
    # Get correct output format using redback utility.
    return sed.get_correct_output_format_from_spectra(
        time=time,  # Original observer frame time for evaluation
        time_eval=time_observer_frame,
        spectra=spectra,
        lambda_array=lambda_observer_frame,
        time_spline_degree=kwargs.get('time_spline_degree', 1),
        **forwarded_kwargs,
    )