Source code for redback.analysis

import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
from typing import Union, Optional
from pathlib import Path

import redback.model_library
from redback.utils import logger, find_nearest, bands_to_frequency
from redback.result import RedbackResult
from redback.constants import day_to_s
import matplotlib


def _setup_plotting_result(model, model_kwargs, parameters, transient):
    """
    Helper function to setup the plotting result

    Wraps the supplied parameter sets in a minimal ``RedbackResult`` (every
    sampler-related field is a placeholder) so that the plotting methods of the
    result object can be reused.

    :param model: model string or model function
    :param model_kwargs: keyword arguments passed to the model
    :param parameters: parameters to plot, either a dictionary or a pandas DataFrame.
        A DataFrame supplied by the caller is copied and never modified.
    :param transient: transient object
    :return: a tuple of model, parameters, and result
    """
    if isinstance(parameters, dict):
        parameters = pd.DataFrame.from_dict(parameters)
    elif isinstance(parameters, pd.DataFrame):
        # Copy first: the column added below must not leak into the caller's frame.
        parameters = parameters.copy()
    # Placeholder values (0, 1, 2, ...); these are not real likelihoods.
    parameters["log_likelihood"] = np.arange(len(parameters))
    if isinstance(model, str):
        model = redback.model_library.all_models_dict[model]
    meta_data = dict(model=model.__name__, transient_type=transient.__class__.__name__.lower())
    # Expose the transient's attributes as metadata, dropping leading underscores
    # from private attribute names.
    transient_kwargs = {k.lstrip("_"): v for k, v in transient.__dict__.items()}
    meta_data.update(transient_kwargs)
    meta_data['model_kwargs'] = model_kwargs or dict()
    res = RedbackResult(label="None", outdir="None",
                        search_parameter_keys=None,
                        fixed_parameter_keys=None,
                        constraint_parameter_keys=None, priors=None,
                        sampler_kwargs=dict(), injection_parameters=None,
                        meta_data=meta_data, posterior=parameters, samples=None,
                        nested_samples=None, log_evidence=0,
                        log_evidence_err=0, information_gain=0,
                        log_noise_evidence=0, log_bayes_factor=0,
                        log_likelihood_evaluations=0,
                        log_prior_evaluations=0, sampling_time=0, nburn=0,
                        num_likelihood_evaluations=0, walkers=0,
                        max_autocorrelation_time=0, use_ratio=False,
                        version=None)
    return model, parameters, res


def plot_lightcurve(transient, parameters, model, model_kwargs=None, show=True, save=False, **kwargs: None):
    """
    Plot a lightcurve for a given model and parameters

    One model curve is drawn per supplied parameter set; the maximum likelihood
    curve is switched off.

    :param transient: transient object
    :param parameters: parameters to plot
    :param model: model string or model function
    :param model_kwargs: keyword arguments passed to the model
    :param show: forwarded to the result's ``plot_lightcurve`` method
    :param save: forwarded to the result's ``plot_lightcurve`` method
    :param kwargs: any further keyword arguments, forwarded unchanged
    :return: plot_lightcurve
    """
    model_function, parameter_frame, result = _setup_plotting_result(
        model, model_kwargs, parameters, transient)
    n_curves = len(parameter_frame)
    return result.plot_lightcurve(
        model=model_function,
        random_models=n_curves,
        plot_max_likelihood=False,
        save=save,
        show=show,
        **kwargs,
    )
def plot_multiband_lightcurve(transient, parameters, model, model_kwargs=None, show=True, save=False, **kwargs: None):
    """
    Plot a multiband lightcurve for a given model and parameters

    One model curve is drawn per supplied parameter set; the maximum likelihood
    curve is switched off.

    :param transient: transient object
    :param parameters: parameters to plot
    :param model: model string or model function
    :param model_kwargs: keyword arguments passed to the model
    :param show: forwarded to the result's ``plot_multiband_lightcurve`` method
    :param save: forwarded to the result's ``plot_multiband_lightcurve`` method
    :param kwargs: any further keyword arguments, forwarded unchanged
    :return: plot_multiband_lightcurve
    """
    model_function, parameter_frame, result = _setup_plotting_result(
        model, model_kwargs, parameters, transient)
    n_curves = len(parameter_frame)
    return result.plot_multiband_lightcurve(
        model=model_function,
        random_models=n_curves,
        plot_max_likelihood=False,
        save=save,
        show=show,
        **kwargs,
    )
def plot_evolution_parameters(result, random_models=100):
    """
    Plot evolution parameters for a given evolving_magnetar result

    Draws ``random_models`` posterior samples (with replacement) and, for each,
    plots the braking index, the inclination angle (converted with
    ``np.rad2deg``) and the magnetic moment returned by the
    ``evolving_magnetar_only`` model on three stacked log-log panels.

    :param result: redback result
    :param random_models: number of random models to plot
    :return: fig and axes
    """
    logger.warning("This type of plot is only valid for evolving magnetar models")
    observed_times = result.metadata['time']
    log_start = np.log10(np.min(observed_times))
    log_stop = np.log10(np.max(observed_times))
    time = np.logspace(log_start, log_stop, 100)

    fig, ax = plt.subplots(3, 1, sharex=True, figsize=(5, 10))
    line_style = dict(lw=1, color='red', zorder=-1)
    for _ in range(random_models):
        draw = np.random.randint(len(result.posterior))
        sample = dict(result.posterior.iloc[draw])
        sample["output"] = "namedtuple"
        evolving_magnetar = redback.model_library.all_models_dict["evolving_magnetar_only"]
        output = evolving_magnetar(time, **sample)
        braking_index = output.nn
        magnetic_moment = output.mu
        inclination = output.alpha
        ax[0].plot(time, braking_index, "--", **line_style)
        ax[1].plot(time, np.rad2deg(inclination), "--", **line_style)
        ax[2].plot(time, magnetic_moment, "--", **line_style)

    panel_labels = ('braking index', 'inclination angle', 'magnetic moment')
    for panel, text in zip(ax, panel_labels):
        panel.set_ylabel(text)
    for panel in ax:
        panel.set_yscale('log')
        panel.set_xscale('log')
    fig.supxlabel(r"Time since burst [s]")
    return fig, ax
def plot_spectrum(model, parameters, time_to_plot, axes=None, **kwargs):
    """
    Plot a spectrum for a given model and parameters

    The model is evaluated with ``output_format='spectra'`` and ``bands='lsstg'``
    (both override any value of the same name in ``parameters``).

    :param model: Model string for a redback model
    :param parameters: dictionary of parameters/alongside model specific keyword arguments.
        Must be one set of parameters. If you want to plot a posterior prediction of the spectrum,
        call this function in a loop.
    :param time_to_plot: Times to plot (in days) the spectrum at.
        The spectrum plotted will be at the nearest neighbour to this value
    :param axes: None or matplotlib axes object if you want to plot on an existing set of axes
    :param kwargs: Additional keyword arguments used by this function.
    :param colors_list: List of colors to use for each time to plot.
        Defaults to the first ``len(time_to_plot)`` colours of matplotlib's tab20 colormap.
    :return: matplotlib axes
    """
    spectrum_model = redback.model_library.all_models_dict[model]
    call_kwargs = dict(parameters)
    call_kwargs['output_format'] = 'spectra'
    call_kwargs['bands'] = 'lsstg'
    output = spectrum_model(time_to_plot, **call_kwargs)

    wavelengths = output.lambdas
    # output.time is divided by day_to_s so it can be compared with time_to_plot (days).
    output_days = output.time / day_to_s

    # Pick, for every requested time, the spectrum at the nearest output time.
    nearest_spectra = {}
    for requested in time_to_plot:
        nearest_index = find_nearest(output_days, requested)[1]
        nearest_spectra[requested] = output.spectra[nearest_index]

    try:
        colors_list = kwargs.pop('colors_list')
    except KeyError:
        colors_list = matplotlib.cm.tab20(range(len(time_to_plot)))

    ax = axes if axes else plt.gca()
    for index, requested in enumerate(time_to_plot):
        ax.semilogx(wavelengths, nearest_spectra[requested],
                    color=colors_list[index], label=f"{requested:.1f} days")
    ax.set_xlabel(r'Wavelength ($\mathrm{\AA}$)')
    ax.set_ylabel(r'Flux ($10^{-17}$ erg s$^{-1}$ cm$^{-2}$ $\mathrm{\AA}$)')
    ax.legend(loc='upper left')
    return ax
def plot_gp_lightcurves(transient, gp_output, axes=None, band_colors=None, band_scaling=None):
    """
    Plot the Gaussian Process lightcurves

    The GP is evaluated on 100 times spanning 10 before the first to 20 after the
    last data point. The shaded region is the prediction +/- half of the GP
    standard deviation.

    :param transient: A transient object
    :param gp_output: The output of the fit_gp function
    :param axes: axes, ideally you should be passing the axes from the plot_data methods
    :param band_colors: a dictionary of band colors;
        again ideally you should be passing the band_colors from the plot_data methods
    :param band_scaling: optional dictionary of per-band offsets added to the scaled
        prediction; when given it must contain every band in ``transient.unique_bands``
    :return: axes object with the GP lightcurves plotted
    """

    def _predict_envelope(gp, y_train, x_new):
        # Mean prediction together with the +/- 0.5 sigma envelope.
        mean, cov = gp.predict(y_train, x_new, return_cov=True)
        sigma = np.sqrt(np.diag(cov))
        return mean, mean - 0.5 * sigma, mean + 0.5 * sigma

    ax = axes or plt.gca()
    ref_date = transient.x[0] if transient.use_phase_model else 0
    t_new = np.linspace(transient.x.min() - 10, transient.x.max() + 20, 100)
    y_scaler = gp_output.y_scaler

    if transient.data_mode not in ['flux_density', 'flux', 'magnitude']:
        # Single GP without bands: no reference date shift and no band offset.
        mean, lower, upper = _predict_envelope(gp_output.gp, gp_output.scaled_y, t_new)
        ax.plot(t_new, mean * y_scaler, color='red')
        ax.fill_between(t_new, lower * y_scaler, upper * y_scaler, alpha=0.5, color='red')
        return ax

    if band_colors is None:
        band_colors = dict(zip(transient.unique_bands,
                               plt.cm.tab20(range(len(transient.unique_bands)))))

    for band in transient.unique_bands:
        scaling = band_scaling[band] if band_scaling else 0
        if gp_output.use_frequency:
            # One 2D GP: evaluate it along the time grid at this band's frequency.
            f_new = np.ones_like(t_new) * bands_to_frequency([band])
            X_new = np.column_stack((f_new, t_new))
            mean, lower, upper = _predict_envelope(gp_output.gp, gp_output.scaled_y, X_new)
        else:
            # One independent 1D GP per band.
            mean, lower, upper = _predict_envelope(
                gp_output.gp[band], gp_output.scaled_y[band], t_new)
        color = band_colors[band]
        ax.plot(t_new - ref_date, (mean * y_scaler) + scaling, color=color)
        ax.fill_between(t_new - ref_date, (lower * y_scaler) + scaling,
                        (upper * y_scaler) + scaling, alpha=0.5, color=color)
    return ax
def _fit_george_gp(kernel, x, y, yerr, quantity):
    """
    Build a george GP with the given kernel and optimise its parameter vector.

    :param kernel: george kernel
    :param x: independent variable the GP is computed on
    :param y: values whose GP log-likelihood is maximised
    :param yerr: uncertainties on ``y``
    :param quantity: name used in the log messages (e.g. "temperature")
    :return: the george GP with the optimised parameter vector set
    """
    import george
    from scipy.optimize import minimize

    gp = george.GP(kernel)
    # A tiny constant is added to the errors, presumably to guard against exactly-zero uncertainties.
    gp.compute(x, yerr + 1e-8)

    def neg_ln_like(p):
        gp.set_parameter_vector(p)
        return -gp.log_likelihood(y)

    def grad_neg_ln_like(p):
        gp.set_parameter_vector(p)
        return -gp.grad_log_likelihood(y)

    result = minimize(neg_ln_like, gp.get_parameter_vector(), jac=grad_neg_ln_like)
    if not result.success:
        # The final parameters are still used, but the user should know they may be poor.
        logger.warning(f"GP optimisation for {quantity} did not converge: {result.message}")
    gp.set_parameter_vector(result.x)
    logger.info(f"Finished GP fit for {quantity}")
    logger.info(f"GP final parameters: {gp.get_parameter_dict()}")
    return gp


def fit_temperature_and_radius_gp(data, kernelT, kernelR, plot=False, **kwargs):
    """
    Fit a Gaussian Process to the temperature and radius data

    :param data: DataFrame containing the temperature and radius data output of the
        transient.estimate_bb_params method. The columns 'temperature', 'radius',
        'epoch_times', 'temp_err' and 'radius_err' are read.
    :param kernelT: george kernel for the temperature
    :param kernelR: george kernel for the radius
    :param plot: Whether to make a two-panel plot of the temperature and radius GP evolution and the data
    :param kwargs: Additional keyword arguments
    :param inflate_errors: If True (the default), multiply the temperature and radius errors by ``error``
        before fitting; if False the errors are used as given.
    :param error: Multiplicative error inflation factor used when ``inflate_errors`` is True, default is 1.5
    :param fit_in_log: If True, fit the temperature GP to log10(temperature); the radius is always fit
        in linear space. Default is False.
    :param sigma_to_plot: Half-width of the plotted uncertainty band in GP standard deviations, default is 1
    :return: Temperature and radius GP objects and plot fig and axes if requested
    """
    temperature = data['temperature']
    radius = data['radius']
    t_data = data['epoch_times']
    T_err = data['temp_err']
    R_err = data['radius_err']

    error = kwargs.get('error', 1.5) if kwargs.get('inflate_errors', True) else 1
    gp_T_err = T_err * error
    gp_R_err = R_err * error

    fit_in_log = kwargs.get("fit_in_log", False)
    if fit_in_log:
        # In log space, use: log10(T); propagate errors via: δ(log10T)=δT/(T*ln(10))
        temperature_fit = np.log10(temperature)
        gp_T_err = gp_T_err / (temperature * np.log(10))
    else:
        temperature_fit = temperature

    gp_T = _fit_george_gp(kernelT, t_data, temperature_fit, gp_T_err, "temperature")
    gp_R = _fit_george_gp(kernelR, t_data, radius, gp_R_err, "radius")

    if not plot:
        return gp_T, gp_R

    sigma_to_plot = kwargs.get('sigma_to_plot', 1)
    label = r"${}\sigma$ GP uncertainty".format(str(int(sigma_to_plot)))
    t_pred = np.linspace(t_data.min(), t_data.max(), 100)

    # Temperature prediction
    T_pred, T_pred_var = gp_T.predict(temperature_fit, t_pred, return_var=True)
    T_pred_std = np.sqrt(T_pred_var)
    if fit_in_log:
        # Convert the prediction back to linear units and propagate the
        # uncertainty approximately: dT ≈ 10^x * ln(10) * sigma_x.
        T_pred_lin = 10**T_pred
        T_pred_std_lin = 10**T_pred * np.log(10) * T_pred_std
    else:
        T_pred_lin = T_pred
        T_pred_std_lin = T_pred_std

    # Radius prediction
    R_pred, R_pred_var = gp_R.predict(radius, t_pred, return_var=True)
    R_pred_std = np.sqrt(R_pred_var)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 3))
    # The data points are always drawn with the original (un-inflated) errors.
    ax1.errorbar(t_data, temperature, yerr=T_err, fmt='o', label='Data', color='blue')
    ax1.plot(t_pred, T_pred_lin, label='GP Prediction', color='red')
    ax1.fill_between(t_pred,
                     T_pred_lin - sigma_to_plot*T_pred_std_lin,
                     T_pred_lin + sigma_to_plot*T_pred_std_lin,
                     alpha=0.2, color='red', label=label)
    ax2.errorbar(t_data, radius, yerr=R_err, fmt='o', label='Data', color='blue')
    ax2.plot(t_pred, R_pred, label='GP Prediction', color='red')
    ax2.fill_between(t_pred,
                     R_pred - sigma_to_plot*R_pred_std,
                     R_pred + sigma_to_plot*R_pred_std,
                     alpha=0.2, color='red', label=label)
    ax1.set_xlabel("Time", fontsize=15)
    ax1.set_ylabel("Temperature [K]", fontsize=15)
    ax1.set_title("Temperature Evolution", fontsize=15)
    ax2.set_xlabel("Time", fontsize=15)
    ax2.set_ylabel("Radius [cm]", fontsize=15)
    ax2.set_title("Radius Evolution from GP", fontsize=15)
    ax1.set_yscale('log')
    ax2.set_yscale('log')
    ax1.legend()
    ax2.legend()
    plt.subplots_adjust(wspace=0.3)
    return gp_T, gp_R, fig, (ax1, ax2)
def _predict_gp_per_key(gp_out, keys, t_new, error):
    """
    Evaluate one 1D GP per key on ``t_new`` and return a flat, time-sorted DataFrame.

    :param gp_out: GP output object; ``gp_out.gp`` and ``gp_out.scaled_y`` are indexed by key
    :param keys: band names or frequencies identifying the individual GPs
    :param t_new: times at which to evaluate every GP
    :param error: multiplicative factor applied to the GP standard deviation
    :return: DataFrame with one row per (key, time) and columns 'time', 'ys', 'yerr', 'band'
    """
    t_new = np.asarray(t_new)
    times, values, errors, labels = [], [], [], []
    for key in keys:
        y_pred, y_cov = gp_out.gp[key].predict(gp_out.scaled_y[key], t_new, return_cov=True)
        # NOTE(review): as in the original, only the prediction is multiplied by y_scaler;
        # the error is left in scaled units - confirm whether that is intended.
        values.append(y_pred * gp_out.y_scaler)
        errors.append(np.sqrt(np.diag(y_cov)) * error)
        times.append(t_new)
        labels.append(np.repeat(key, len(t_new)))
    if not times:
        raise ValueError("No Gaussian Process available to generate new data from")
    # Flatten before building the frame: a dict of lists-of-arrays would create one
    # row per band holding whole arrays, which cannot be sorted by time.
    frame = pd.DataFrame({'time': np.concatenate(times), 'ys': np.concatenate(values),
                          'yerr': np.concatenate(errors), 'band': np.concatenate(labels)})
    frame.sort_values('time', inplace=True, kind='stable')
    return frame


def generate_new_transient_data_from_gp(gp_out, t_new, transient, **kwargs):
    """
    Generates new transient data based on Gaussian Process (GP) predictions for the
    given time array and transient object.

    Depending on the data mode of the transient object (e.g., 'flux_density', 'flux',
    'magnitude', or 'luminosity'), the GP prediction at ``t_new`` becomes the data of a new
    ``OpticalTransient`` named ``<transient.name>_gp``. For per-band (1D) GPs the predictions of
    all bands are concatenated and sorted by time.

    :param gp_out: The GP output object containing the Gaussian Process model, scaled data,
        and other related attributes.
    :type gp_out: object
    :param t_new: Array of new time values for which GP predictions are to be generated.
    :type t_new: array-like
    :param transient: The transient object containing the original observation data and
        related properties such as data mode and unique frequencies or bands.
    :type transient: object
    :param kwargs: Additional parameters to modify behavior, such as:
        - **inflate_y_err** (bool): Whether to multiply the GP standard deviation by ``error``.
          Default is True.
        - **error** (float): Multiplier applied to the GP standard deviation when
          ``inflate_y_err`` is True. Default is 10.
    :return: A new transient object with data updated using GP predictions.
    :rtype: object
    :raises ValueError: If the data mode is not understood, or if there is no per-band GP
        to predict from.
    """
    data_mode = transient.data_mode
    logger.info(f"Data mode: {data_mode}")
    logger.info("Creating new {} data".format(data_mode))
    if data_mode not in ['flux_density', 'flux', 'magnitude', 'luminosity']:
        raise ValueError("Data mode {} not understood".format(data_mode))

    if kwargs.get('inflate_y_err', True):
        error = kwargs.get('error', 10)
    else:
        logger.info("Using GP predicted errors, this is likely being too conservative")
        error = 1.

    if gp_out.use_frequency:
        logger.info("GP is a 2D kernel with effective frequency")
        freqs = transient.unique_frequencies
        T, F = np.meshgrid(t_new, freqs)
        try:
            bands = redback.utils.frequency_to_bandname(F.flatten())
        except Exception:
            # Best effort: fall back to the frequency values as band labels.
            bands = F.flatten().astype(str)
        X_new = np.column_stack((F.flatten(), T.flatten()))
        y_pred, y_var = gp_out.gp.predict(gp_out.scaled_y, X_new, return_var=True)
        y_err = np.sqrt(y_var) * error
        y_pred = y_pred * gp_out.y_scaler
        tts = T.flatten()
        freqs = F.flatten()
    else:
        logger.info("GP is a 1D kernel")
        if data_mode == 'luminosity':
            y_pred, y_cov = gp_out.gp.predict(gp_out.scaled_y, t_new, return_cov=True)
            y_err = np.sqrt(np.diag(y_cov)) * error
            y_pred = y_pred * gp_out.y_scaler
            tts = t_new
        else:
            if data_mode == 'flux_density':
                logger.warning("Bandnames/frequency attributes for the transient object may be weird, "
                               "Please check for yourself")
                keys = list(gp_out.gp.keys())
            else:
                keys = transient.unique_bands
            frame = _predict_gp_per_key(gp_out, keys, t_new, error)
            # Plain arrays are passed on so that no shuffled pandas index leaks downstream.
            y_pred = frame['ys'].to_numpy()
            y_err = frame['yerr'].to_numpy()
            tts = frame['time'].to_numpy()
            bands = frame['band'].to_numpy()
            # In flux_density mode the GP keys serve as both band and frequency labels.
            freqs = bands

    logger.info("Creating new transient object with GP data")
    if data_mode == 'flux_density':
        new_transient = redback.transient.OpticalTransient(name=transient.name + '_gp',
                                                           flux_density=y_pred, flux_density_err=y_err,
                                                           time=tts, bands=bands, frequency=freqs,
                                                           data_mode=data_mode, redshift=transient.redshift)
    elif data_mode == 'flux':
        new_transient = redback.transient.OpticalTransient(name=transient.name + '_gp',
                                                           flux=y_pred, flux_err=y_err,
                                                           time=tts, bands=bands,
                                                           data_mode=data_mode, redshift=transient.redshift)
    elif data_mode == 'magnitude':
        new_transient = redback.transient.OpticalTransient(name=transient.name + '_gp',
                                                           magnitude=y_pred, magnitude_err=y_err,
                                                           time=tts, bands=bands,
                                                           data_mode=data_mode, redshift=transient.redshift)
    else:  # 'luminosity'; other modes were rejected above
        new_transient = redback.transient.OpticalTransient(name=transient.name + '_gp',
                                                           Lum50=y_pred, Lum50_err=y_err,
                                                           time_rest_frame=tts, data_mode=data_mode)
    return new_transient
class SpectralVelocityFitter:
    """
    Measure expansion velocities from spectral line profiles

    Used for:
    - Photospheric velocity evolution
    - High-velocity features (HVF)
    - Velocity gradients (dv/dt)

    Parameters
    ----------
    wavelength : array
        Wavelength array in Angstroms
    flux : array
        Flux density array
    flux_err : array, optional
        Flux density uncertainties

    Examples
    --------
    >>> fitter = SpectralVelocityFitter(wavelength, flux)
    >>> v_Si, v_err = fitter.measure_line_velocity(6355)
    >>> print(f"Si II velocity: {v_Si:.0f} +/- {v_err:.0f} km/s")
    """

    # Speed of light in km/s
    _C_KMS = 299792.458

    def __init__(self, wavelength, flux, flux_err=None):
        """
        Initialize SpectralVelocityFitter

        Parameters
        ----------
        wavelength : array
            Wavelength array in Angstroms
        flux : array
            Flux density array
        flux_err : array, optional
            Flux density uncertainties
        """
        self.wavelength = np.asarray(wavelength)
        self.flux = np.asarray(flux)
        self.flux_err = np.asarray(flux_err) if flux_err is not None else None

    @classmethod
    def from_spectrum_object(cls, spectrum):
        """
        Create fitter from a redback Spectrum object

        Parameters
        ----------
        spectrum : object
            Object with .angstroms and .flux_density attributes; an optional
            .flux_density_err attribute is used for the uncertainties

        Returns
        -------
        fitter : SpectralVelocityFitter
            Initialized fitter object
        """
        wavelength = spectrum.angstroms
        flux = spectrum.flux_density
        flux_err = getattr(spectrum, 'flux_density_err', None)
        return cls(wavelength, flux, flux_err)

    def measure_line_velocity(self, line_rest_wavelength, method='min', **kwargs):
        """
        Measure velocity from single absorption line

        Parameters
        ----------
        line_rest_wavelength : float
            Rest wavelength in Angstroms (e.g., 6355 for Si II)
        method : str
            'min' - use minimum flux (standard)
            'centroid' - use flux-weighted centroid
            'fit' - fit P-Cygni profile
            'gaussian' - fit Gaussian to absorption trough
        kwargs : dict
            Additional parameters:
            - v_window : float
                Velocity window for search (km/s, default 5000)
            - continuum_percentile : float
                Percentile for continuum estimation (default 90)
            For method 'fit' all kwargs are also forwarded to ``p_cygni_profile``.

        Returns
        -------
        velocity : float
            Measured velocity in km/s (negative = blueshift)
        velocity_err : float
            Uncertainty in km/s

        Both values are NaN when fewer than 5 data points fall inside the window.
        If a 'gaussian' or 'fit' optimisation fails, the 'min' measurement (with the
        same kwargs) is returned instead.

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported names.

        Examples
        --------
        >>> fitter = SpectralVelocityFitter(wavelength, flux)
        >>> v_Si, verr = fitter.measure_line_velocity(6355, method='min')
        >>> print(f"Si II velocity: {v_Si:.0f} +/- {verr:.0f} km/s")
        """
        c_kms = self._C_KMS
        v_window = kwargs.get('v_window', 5000)  # km/s

        # Extract region around line
        lambda_window = line_rest_wavelength * v_window / c_kms
        mask = ((self.wavelength > line_rest_wavelength - lambda_window) &
                (self.wavelength < line_rest_wavelength + lambda_window))
        if np.sum(mask) < 5:
            logger.warning(f"Insufficient data points around {line_rest_wavelength} A")
            return np.nan, np.nan

        wave_line = self.wavelength[mask]
        flux_line = self.flux[mask]

        if method == 'min':
            # Wavelength of the minimum flux (absorption trough), converted to velocity
            lambda_min = wave_line[np.argmin(flux_line)]
            velocity = c_kms * (lambda_min - line_rest_wavelength) / line_rest_wavelength
            # The error is one wavelength resolution element; abs() keeps it positive
            # for spectra stored with decreasing wavelength.
            dlambda = np.abs(np.median(np.diff(wave_line)))
            velocity_err = c_kms * dlambda / line_rest_wavelength

        elif method == 'centroid':
            continuum_pct = kwargs.get('continuum_percentile', 90)
            continuum = np.percentile(flux_line, continuum_pct)
            # Absorption depth below the continuum is used as the centroid weight
            absorption = np.clip(continuum - flux_line, 0, None)
            total_absorption = np.sum(absorption)
            if total_absorption > 0:
                lambda_centroid = np.sum(wave_line * absorption) / total_absorption
                velocity = c_kms * (lambda_centroid - line_rest_wavelength) / line_rest_wavelength
                # Error from the absorption-weighted scatter around the centroid
                variance = np.sum(absorption * (wave_line - lambda_centroid)**2) / total_absorption
                lambda_err = np.sqrt(variance / np.sum(absorption > 0))
                velocity_err = c_kms * lambda_err / line_rest_wavelength
            else:
                # No flux below the continuum estimate: nothing to measure
                velocity = 0.0
                velocity_err = 500.0

        elif method == 'gaussian':
            from scipy.optimize import curve_fit

            continuum_pct = kwargs.get('continuum_percentile', 90)
            continuum = np.percentile(flux_line, continuum_pct)

            def gaussian_absorption(wave, center, depth, sigma):
                return continuum * (1 - depth * np.exp(-0.5 * ((wave - center) / sigma)**2))

            depth_bounds = (0.01, 1.0)
            imin = np.argmin(flux_line)
            center_guess = wave_line[imin]
            # Keep the initial depth inside the fit bounds; curve_fit rejects an
            # infeasible starting point outright.
            depth_guess = np.clip((continuum - flux_line[imin]) / continuum, *depth_bounds)
            sigma_guess = 10.0  # Angstroms

            try:
                popt, pcov = curve_fit(
                    gaussian_absorption, wave_line, flux_line,
                    p0=[center_guess, depth_guess, sigma_guess],
                    bounds=([wave_line.min(), depth_bounds[0], 1.0],
                            [wave_line.max(), depth_bounds[1], 200.0])
                )
                lambda_center = popt[0]
                velocity = c_kms * (lambda_center - line_rest_wavelength) / line_rest_wavelength
                velocity_err = c_kms * np.sqrt(pcov[0, 0]) / line_rest_wavelength
            except Exception as e:
                logger.warning(f"Gaussian fit failed: {e}")
                # Fall back on the same window/settings the caller asked for
                return self.measure_line_velocity(line_rest_wavelength, method='min', **kwargs)

        elif method == 'fit':
            from scipy.optimize import curve_fit
            from redback.transient_models.spectral_models import p_cygni_profile

            continuum_pct = kwargs.get('continuum_percentile', 90)
            continuum = np.percentile(flux_line, continuum_pct)

            def pcygni_model(wave, tau, v_phot):
                return p_cygni_profile(
                    wave, line_rest_wavelength, tau, v_phot, continuum, **kwargs
                )

            # Initial velocity guess from the flux minimum
            lambda_min = wave_line[np.argmin(flux_line)]
            v_guess = np.abs(c_kms * (lambda_min - line_rest_wavelength) / line_rest_wavelength)

            try:
                popt, pcov = curve_fit(
                    pcygni_model, wave_line, flux_line,
                    p0=[3.0, v_guess],
                    bounds=([0.1, 1000], [100, 50000])
                )
                velocity = -popt[1]  # blueshifted, so negative
                velocity_err = np.sqrt(pcov[1, 1])
            except Exception as e:
                logger.warning(f"P-Cygni fit failed: {e}")
                # Fall back on the same window/settings the caller asked for
                return self.measure_line_velocity(line_rest_wavelength, method='min', **kwargs)

        else:
            raise ValueError(f"Unknown method: {method}")

        return velocity, velocity_err

    def measure_multiple_lines(self, line_dict, method='min', **kwargs):
        """
        Measure velocities for multiple lines

        Parameters
        ----------
        line_dict : dict
            {'Si II 6355': 6355, 'Fe II 5169': 5169, ...}
        method : str
            Method for velocity measurement (default 'min')
        kwargs : dict
            Additional parameters passed to measure_line_velocity

        Returns
        -------
        velocities : dict
            {'Si II 6355': (v, v_err), ...}; a line whose measurement raises is
            reported as (nan, nan) and a warning is logged

        Examples
        --------
        >>> lines = {
        ...     'Si II 6355': 6355,
        ...     'Ca II H&K': 3934,
        ...     'Fe II 5169': 5169
        ... }
        >>> velocities = fitter.measure_multiple_lines(lines)
        >>> for ion, (v, verr) in velocities.items():
        ...     print(f"{ion}: {v:.0f} +/- {verr:.0f} km/s")
        """
        velocities = {}
        for ion_name, rest_wave in line_dict.items():
            try:
                velocities[ion_name] = tuple(
                    self.measure_line_velocity(rest_wave, method=method, **kwargs))
            except Exception as e:
                # Deliberately best effort: one bad line must not abort the others
                logger.warning(f"Could not measure {ion_name}: {e}")
                velocities[ion_name] = (np.nan, np.nan)
        return velocities

    @staticmethod
    def photospheric_velocity_evolution(wavelength_list, flux_list, times,
                                        line_wavelength=6355, method='min', **kwargs):
        """
        Track photospheric velocity evolution over time

        Parameters
        ----------
        wavelength_list : list of arrays
            Wavelength arrays for each spectrum
        flux_list : list of arrays
            Flux arrays for each spectrum
        times : array
            Observation times (days)
        line_wavelength : float
            Which line to use (default Si II 6355)
        method : str
            Velocity measurement method
        kwargs : dict
            Additional parameters passed to measure_line_velocity

        Returns
        -------
        times : array
            Observation times
        velocities : array
            Measured velocities (km/s)
        errors : array
            Velocity uncertainties (km/s)

        Examples
        --------
        >>> times, vels, errs = SpectralVelocityFitter.photospheric_velocity_evolution(
        ...     wavelength_list, flux_list, obs_times, line_wavelength=6355
        ... )
        >>> plt.errorbar(times, -vels/1000, yerr=errs/1000)
        >>> plt.xlabel('Days since explosion')
        >>> plt.ylabel('Photospheric velocity (1000 km/s)')
        """
        velocities = []
        errors = []
        for wave, flux in zip(wavelength_list, flux_list):
            fitter = SpectralVelocityFitter(wave, flux)
            v, verr = fitter.measure_line_velocity(line_wavelength, method=method, **kwargs)
            velocities.append(v)
            errors.append(verr)
        return np.array(times), np.array(velocities), np.array(errors)

    def identify_high_velocity_features(self, line_rest_wavelength, v_phot_expected,
                                        threshold_factor=1.3):
        """
        Identify high-velocity features (HVF) in the spectrum

        HVFs are absorption features at higher velocities than the photosphere,
        often associated with circumstellar material or density enhancements.
        The search covers blueshifts between ``threshold_factor`` and 2 times
        ``v_phot_expected``; the deepest point in that range counts as an HVF when
        it lies more than 5% below the 90th percentile of the whole spectrum.

        Parameters
        ----------
        line_rest_wavelength : float
            Rest wavelength of the line in Angstroms
        v_phot_expected : float
            Expected photospheric velocity in km/s
        threshold_factor : float
            Factor above v_phot to classify as HVF (default 1.3)

        Returns
        -------
        has_hvf : bool
            Whether HVF is detected
        v_hvf : float or None
            Velocity of HVF if detected (km/s)
        v_hvf_err : float or None
            Uncertainty in HVF velocity

        Examples
        --------
        >>> has_hvf, v_hvf, v_err = fitter.identify_high_velocity_features(
        ...     6355, v_phot_expected=11000
        ... )
        >>> if has_hvf:
        ...     print(f"HVF detected at {-v_hvf:.0f} km/s")
        """
        c_kms = self._C_KMS

        # Search for features at higher velocities
        v_search_max = v_phot_expected * 2.0  # Search up to 2x photospheric velocity
        v_search_min = v_phot_expected * threshold_factor
        search_blue = line_rest_wavelength * (1 - v_search_max / c_kms)
        search_red = line_rest_wavelength * (1 - v_search_min / c_kms)

        mask = (self.wavelength > search_blue) & (self.wavelength < search_red)
        if np.sum(mask) < 3:
            return False, None, None

        wave_hvf = self.wavelength[mask]
        flux_hvf = self.flux[mask]

        # Deepest point in the search range
        imin = np.argmin(flux_hvf)
        lambda_min = wave_hvf[imin]

        # Check if it's a significant absorption relative to the global continuum
        continuum = np.percentile(self.flux, 90)
        absorption_depth = (continuum - flux_hvf[imin]) / continuum
        if absorption_depth > 0.05:  # At least 5% absorption
            v_hvf = c_kms * (lambda_min - line_rest_wavelength) / line_rest_wavelength
            # One resolution element; abs() keeps it positive for descending grids
            dlambda = np.abs(np.median(np.diff(wave_hvf)))
            v_hvf_err = c_kms * dlambda / line_rest_wavelength
            return True, v_hvf, v_hvf_err

        return False, None, None

    def measure_velocity_gradient(self, wavelength_list, flux_list, times,
                                  line_wavelength=6355, **kwargs):
        """
        Measure velocity gradient dv/dt from time series of spectra

        Parameters
        ----------
        wavelength_list : list of arrays
            Wavelength arrays for each spectrum
        flux_list : list of arrays
            Flux arrays for each spectrum
        times : array
            Observation times (days)
        line_wavelength : float
            Which line to use
        kwargs : dict
            Additional parameters passed to measure_line_velocity (e.g., v_window, method)

        Returns
        -------
        gradient : float
            Velocity gradient in km/s/day (NaN with fewer than 2 valid velocities)
        gradient_err : float
            Uncertainty in gradient (NaN with fewer than 3 valid velocities)

        Notes
        -----
        The velocity gradient is typically negative (decelerating) for normal SNe Ia
        (around -50 to -100 km/s/day), but can be different for peculiar objects.
        """
        times, velocities, errors = self.photospheric_velocity_evolution(
            wavelength_list, flux_list, times, line_wavelength, **kwargs
        )

        # Remove NaN values
        valid = ~np.isnan(velocities)
        if np.sum(valid) < 2:
            return np.nan, np.nan

        times_valid = times[valid]
        vel_valid = velocities[valid]
        err_valid = errors[valid]

        # Weighted linear fit if usable errors are available
        if np.all(err_valid > 0) and np.all(~np.isnan(err_valid)):
            # np.polyfit multiplies the residuals by w before squaring them, so the
            # inverse-variance fit needs w = 1/sigma (not 1/sigma**2).
            weights = 1 / err_valid
            coeffs = np.polyfit(times_valid, vel_valid, deg=1, w=weights)
        else:
            coeffs = np.polyfit(times_valid, vel_valid, deg=1)

        gradient = coeffs[0]  # km/s/day

        # Error estimate from the scatter of the residuals around the fit
        residuals = vel_valid - np.polyval(coeffs, times_valid)
        if len(times_valid) > 2:
            gradient_err = np.std(residuals) / np.sqrt(np.sum((times_valid - np.mean(times_valid))**2))
        else:
            gradient_err = np.nan

        return gradient, gradient_err
class ClassificationResult(dict):
    """
    Result of spectral or photometric transient classification.

    Behaves as a plain dict (for backward compatibility) while also providing
    convenience attributes and methods.

    The dict contains the keys: ``best_type``, ``best_phase``, ``best_redshift``,
    ``correlation`` (= rlap), ``rlap``, ``type_probabilities``, ``top_matches``,
    plus ``confidence``, ``best_template_name``, ``best_template_source``,
    ``method``, ``warnings``.

    Quality interpretation for rlap (spectral matching):
    - rlap > 8: high confidence match
    - rlap 5–8: medium confidence
    - rlap < 5: low confidence, treat with caution
    """

    def __init__(self, best_type: str, best_phase: float, best_redshift: float,
                 rlap: float, confidence: str, type_probabilities: dict,
                 top_matches: list, best_template_name: str,
                 best_template_source: Optional[str] = None,
                 method: str = 'rlap', warnings: Optional[list] = None):
        """Store every argument under its dict key; ``rlap`` is also stored as ``correlation``."""
        if warnings is None:
            warnings = []
        entries = {
            'best_type': best_type,
            'best_phase': best_phase,
            'best_redshift': best_redshift,
            'correlation': rlap,  # alias for backward compat
            'rlap': rlap,
            'confidence': confidence,
            'type_probabilities': type_probabilities,
            'top_matches': top_matches,
            'best_template_name': best_template_name,
            'best_template_source': best_template_source,
            'method': method,
            'warnings': warnings,
        }
        super().__init__(entries)

    # Read-only attribute access; each property simply looks up the dict key of the same name.
    @property
    def best_type(self) -> str:
        return self['best_type']

    @property
    def best_phase(self) -> float:
        return self['best_phase']

    @property
    def best_redshift(self) -> float:
        return self['best_redshift']

    @property
    def rlap(self) -> float:
        return self['rlap']

    @property
    def confidence(self) -> str:
        return self['confidence']

    @property
    def type_probabilities(self) -> dict:
        return self['type_probabilities']

    @property
    def top_matches(self) -> list:
        return self['top_matches']

    @property
    def best_template_name(self) -> str:
        return self['best_template_name']

    @property
    def best_template_source(self) -> Optional[str]:
        return self['best_template_source']

    @property
    def method(self) -> str:
        return self['method']

    @property
    def warnings(self) -> list:
        return self['warnings']

    def summary(self) -> str:
        """Return a human-readable classification summary."""
        header = [
            f"Classification: Type {self.best_type}",
            f"Phase: {self.best_phase:+.1f} days from maximum",
            f"Redshift: {self.best_redshift:.4f}",
            f"Quality (rlap): {self.rlap:.1f} [{self.confidence} confidence]",
            "",
            "Type probabilities:",
        ]
        # Most probable type first
        ranked = sorted(self.type_probabilities.items(), key=lambda item: -item[1])
        probability_lines = [f" {name:10s}: {prob * 100:5.1f}%" for name, prob in ranked]
        warning_lines = []
        if self.warnings:
            warning_lines.append("\nWarnings:")
            warning_lines.extend(f" - {message}" for message in self.warnings)
        return "\n".join(header + probability_lines + warning_lines)

    def to_dict(self) -> dict:
        """Return a plain dict copy (for explicit serialisation)."""
        return dict(self)
[docs] class SpectralTemplateMatcher(object): """ Match spectra to template library (similar to SNID). Supports Pearson correlation, chi-squared, and the recommended SNID-style rlap cross-correlation metric. Templates can be loaded from SNID .lnw files, CSV/DAT libraries, downloaded from OSC or GitHub, or generated from sncosmo spectral models. The default template library uses sncosmo models (SALT2 for Type Ia, v19-1998bw for Ic-BL, nugent templates for Ib/c / IIP / IIn, and s11-2004hx for generic Type II), providing realistic spectral shapes at multiple phases for each type. The default matching method is 'rlap', which cross-correlates in log-wavelength space (= velocity space) and is shift-invariant — a small redshift error does not degrade the match quality. A good match has rlap > 5; an excellent match has rlap > 10. """ # sncosmo source names and corresponding SN types for the default template library. # Each entry: (sncosmo_source_name, sn_type_label, phases_to_sample) _SNCOSMO_TEMPLATE_SOURCES = [ ('salt2', 'Ia', [-10, -5, 0, 5, 10, 15, 20]), ('v19-1998bw', 'Ic-BL', [-5, 0, 5, 10, 15, 20]), ('nugent-sn1bc', 'Ib/c', [0, 5, 10, 15, 20, 30]), ('nugent-sn2p', 'IIP', [0, 10, 20, 30, 50, 80]), ('nugent-sn2n', 'IIn', [0, 10, 30, 60]), ('s11-2004hx', 'II', [0, 10, 20, 30, 50]), ]
def __init__(self, template_library_path: Optional[Union[str, Path]] = None, templates: Optional[list] = None) -> None:
    """
    Initialize the SpectralTemplateMatcher with a template library.

    Explicitly supplied ``templates`` take precedence over ``template_library_path``;
    the built-in templates are only loaded when both are None.

    :param template_library_path: Path to a directory containing template files
        (CSV/DAT format). If None and templates is None, uses built-in sncosmo
        templates (SALT2, 1998bw, Nugent templates, etc.).
    :param templates: List of template dictionaries to use directly. Each template
        should have keys: 'wavelength', 'flux', 'type', 'phase', and optionally 'name'.
    """
    if templates is None and template_library_path is None:
        chosen = self._load_default_templates()
    elif templates is None:
        chosen = self._load_templates(template_library_path)
    else:
        chosen = templates
    self.templates = chosen
    logger.info(f"Loaded {len(self.templates)} templates into the matcher")
def _load_default_templates(self) -> list:
    """
    Return the built-in template set, generated from sncosmo spectral models.

    Delegates to :meth:`generate_sncosmo_templates` with its default sources:
    SALT2 (Type Ia), v19-1998bw (Ic-BL / 1998bw-like), nugent-sn1bc (Ib/c),
    nugent-sn2p (IIP), nugent-sn2n (IIn), and s11-2004hx (II).

    :return: List of template dictionaries
    """
    logger.info("Loading default spectral templates from sncosmo")
    default_templates = self.generate_sncosmo_templates()
    return default_templates
@classmethod
def generate_sncosmo_templates(cls, sources: Optional[list] = None,
                               wavelength_range: tuple = (3500, 9000),
                               n_wavelength: int = 1000) -> list:
    """
    Generate spectral templates from sncosmo source models.

    By default uses SALT2, v19-1998bw (SN 1998bw / Ic-BL), nugent-sn1bc,
    nugent-sn2p, nugent-sn2n, and s11-2004hx. Each source is sampled at a set of
    representative phases. Sources that cannot be loaded, phases outside a
    model's valid range, and templates with non-finite or all-zero flux are
    skipped (with a warning where something went wrong).

    :param sources: Optional list of ``(source_name, type_label, phases)`` tuples
        to override the default set (``_SNCOSMO_TEMPLATE_SOURCES``). Phases may
        be integers or floats.
    :param wavelength_range: (min, max) wavelength in Angstroms; clipped to each
        model's valid wavelength range.
    :param n_wavelength: Number of wavelength points
    :return: List of template dicts with keys 'wavelength', 'flux', 'type',
        'phase', 'name', 'source'. Flux is normalised to a peak absolute value of 1.
    """
    import sncosmo

    source_list = sources if sources is not None else cls._SNCOSMO_TEMPLATE_SOURCES
    templates = []
    for source_name, sn_type, phases in source_list:
        try:
            src = sncosmo.get_source(source_name)
        except Exception as e:
            logger.warning(f"Could not load sncosmo source '{source_name}': {e}")
            continue

        # The wavelength grid depends only on the source, so build it once
        # rather than once per phase. Clip to the model's validity range.
        wave_lo = max(wavelength_range[0], src.minwave())
        wave_hi = min(wavelength_range[1], src.maxwave())
        if wave_hi <= wave_lo:
            logger.warning(
                f"Source '{source_name}' has no wavelength coverage inside "
                f"{wavelength_range}; skipping."
            )
            continue
        wave = np.linspace(wave_lo, wave_hi, n_wavelength)

        for phase in phases:
            # Skip phases outside the model's valid range
            if phase < src.minphase() or phase > src.maxphase():
                continue
            try:
                flux = np.asarray(src.flux(phase, wave), dtype=float)
                if not np.all(np.isfinite(flux)):
                    logger.warning(
                        f"Template {source_name} phase {phase} contains non-finite "
                        "values; skipping."
                    )
                    continue
                max_flux = np.max(np.abs(flux))
                if max_flux <= 0:
                    continue
                flux = flux / max_flux

                safe_type = sn_type.replace('/', '-')
                # '+d' only accepts integers; fall back to '+g' so that
                # fractional phases (e.g. 2.5) still produce a template.
                if float(phase).is_integer():
                    phase_label = f'{int(phase):+d}'
                else:
                    phase_label = f'{float(phase):+g}'

                templates.append({
                    # Copy so templates of one source do not share a mutable array.
                    'wavelength': wave.copy(),
                    'flux': flux,
                    'type': sn_type,
                    'phase': float(phase),
                    'name': f'{source_name}_{safe_type}_phase{phase_label}',
                    'source': source_name,
                })
            except Exception as e:
                logger.warning(
                    f"Failed to generate template {source_name} phase {phase}: {e}"
                )

    logger.info(f"Generated {len(templates)} sncosmo templates")
    return templates
@classmethod
def generate_synthetic_templates(cls, sn_types: Optional[list] = None,
                                 wavelength_range: tuple = (3500, 9000),
                                 n_wavelength: int = 1000,
                                 r_photosphere: float = 1e15) -> list:
    """
    Legacy alias for :meth:`generate_sncosmo_templates`.

    Kept so that older call sites keep working; new code should call
    :meth:`generate_sncosmo_templates` directly. Only the wavelength settings
    are forwarded.

    :param sn_types: Ignored (kept for API compatibility).
    :param wavelength_range: (min, max) wavelength in Angstroms
    :param n_wavelength: Number of wavelength points
    :param r_photosphere: Ignored (kept for API compatibility).
    :return: List of template dicts
    """
    forwarded = {
        'wavelength_range': wavelength_range,
        'n_wavelength': n_wavelength,
    }
    return cls.generate_sncosmo_templates(**forwarded)
@staticmethod
def _blackbody_flux(wavelengths: np.ndarray, temperature: float) -> np.ndarray:
    """
    Compute a simple Planck blackbody flux (arbitrary units).

    :param wavelengths: Wavelength array in Angstroms
    :param temperature: Temperature in Kelvin
    :return: Flux array proportional to B_lambda(T), same shape as wavelengths
    """
    # h*c/k_B in Angstrom*K
    hc_over_k = 1.43878e8  # Angstrom * K
    wave = np.asarray(wavelengths, dtype=float)
    exponent = hc_over_k / (wave * temperature)
    # Clip exponent to avoid overflow in np.exp
    exponent = np.clip(exponent, 0, 700)
    flux = wave ** (-5) / (np.exp(exponent) - 1.0)
    return flux

@staticmethod
def _flatten_spectrum(flux: np.ndarray, smooth_sigma: int = 30) -> np.ndarray:
    """
    Remove continuum by dividing by a Gaussian-smoothed version, returning
    fractional deviations about the continuum.

    This isolates spectral features from the broad continuum shape —
    equivalent to the 'flattening' step in SNID.

    :param flux: Flux array (on a uniform log-wavelength grid)
    :param smooth_sigma: Gaussian smoothing width in pixels. 30 pixels at
        dlog_lambda=0.001 corresponds to ~7000 km/s — removes broad
        continuum but preserves spectral lines.
    :return: Flattened flux array (dimensionless, ``flux / continuum - 1``)
    """
    from scipy.ndimage import gaussian_filter1d

    flux = np.asarray(flux, dtype=float)
    continuum = gaussian_filter1d(flux, sigma=smooth_sigma)
    # Replace exact zeros in the continuum so the division below is defined.
    continuum = np.where(np.abs(continuum) > 0, continuum, 1e-30)
    return flux / continuum - 1.0

@staticmethod
def _compute_rlap(obs_wave: np.ndarray, obs_flux: np.ndarray,
                  tmpl_wave: np.ndarray, tmpl_flux: np.ndarray,
                  dlog_lambda: float = 0.001,
                  smooth_sigma: int = 30,
                  tmpl_pre_flattened: bool = False) -> tuple:
    """
    Compute the SNID-style rlap quality metric and best-fit redshift via
    cross-correlation in log-wavelength (= velocity) space.

    The algorithm:

    1. Build a common log-lambda grid (spacing exactly ``dlog_lambda``) over
       the wavelength overlap.
    2. Interpolate both spectra onto the grid.
    3. Flatten both (remove continuum) using ``_flatten_spectrum``; the
       template is left untouched when ``tmpl_pre_flattened`` is True.
    4. Apply a Hanning window to suppress edge ringing.
    5. Cross-correlate via FFT; normalise by geometric mean of auto-correlations.
    6. The CCF peak position gives the best-fit redshift offset.
    7. rlap = |peak_CCF| * lap, where lap is the fractional overlap of the two
       log-wavelength ranges scaled to [0, 10] (after Blondin & Tonry 2007).

    :param obs_wave: Observed wavelength array (Angstroms, ascending)
    :param obs_flux: Observed flux array
    :param tmpl_wave: Template wavelength array (Angstroms, rest frame)
    :param tmpl_flux: Template flux array
    :param dlog_lambda: Log-wavelength grid spacing (default 0.001 ≈ 230 km/s/pixel)
    :param smooth_sigma: Smoothing sigma for continuum removal (pixels)
    :param tmpl_pre_flattened: If True, the template flux is already
        continuum-subtracted and is not flattened again.
    :return: (rlap, z_best, ccf_array, z_lag_array).
             Returns (0.0, 0.0, None, None) when the spectra do not overlap or
             the overlap spans fewer than 20 grid pixels.
    """
    from scipy.interpolate import interp1d

    log_obs = np.log10(obs_wave)
    log_tmpl = np.log10(tmpl_wave)

    log_min = max(log_obs.min(), log_tmpl.min())
    log_max = min(log_obs.max(), log_tmpl.max())
    if log_max <= log_min:
        return 0.0, 0.0, None, None

    n_grid = int((log_max - log_min) / dlog_lambda)
    if n_grid < 20:
        return 0.0, 0.0, None, None

    # Use an explicit step of dlog_lambda. np.linspace(log_min, log_max, n_grid)
    # would give a spacing of (log_max - log_min) / (n_grid - 1), which does not
    # match the dlog_lambda used below to convert pixel lags into redshifts.
    log_grid = log_min + dlog_lambda * np.arange(n_grid)
    wave_grid = 10.0 ** log_grid

    f_obs = interp1d(obs_wave, obs_flux, bounds_error=False, fill_value=0.0)
    f_tmpl = interp1d(tmpl_wave, tmpl_flux, bounds_error=False, fill_value=0.0)
    obs_resampled = f_obs(wave_grid)
    tmpl_resampled = f_tmpl(wave_grid)

    obs_flat = SpectralTemplateMatcher._flatten_spectrum(obs_resampled, smooth_sigma)
    if tmpl_pre_flattened:
        tmpl_flat = tmpl_resampled
    else:
        tmpl_flat = SpectralTemplateMatcher._flatten_spectrum(tmpl_resampled, smooth_sigma)

    # Hanning taper to suppress edge ringing
    taper = np.hanning(n_grid)
    obs_flat = obs_flat * taper
    tmpl_flat = tmpl_flat * taper

    # Cross-correlation via FFT
    fft_obs = np.fft.rfft(obs_flat)
    fft_tmpl = np.fft.rfft(tmpl_flat)
    ccf = np.fft.irfft(fft_obs * np.conj(fft_tmpl), n=n_grid)

    # Normalise by geometric mean of auto-correlations
    ac_obs = float(np.sum(obs_flat ** 2))
    ac_tmpl = float(np.sum(tmpl_flat ** 2))
    norm = np.sqrt(ac_obs * ac_tmpl) if (ac_obs > 0 and ac_tmpl > 0) else 1.0
    ccf_norm = ccf / norm

    # Map lags to redshift offsets: lag in pixels → delta(log_lambda) → delta_z
    lags = np.fft.fftfreq(n_grid, d=1.0 / n_grid)  # pixel lags, unshifted
    log_lags = lags * dlog_lambda
    z_offsets = 10.0 ** log_lags - 1.0

    # Shift to centre (lag=0 in the middle)
    ccf_shifted = np.fft.fftshift(ccf_norm)
    z_shifted = np.fft.fftshift(z_offsets)

    i_peak = int(np.argmax(np.abs(ccf_shifted)))
    r_peak = float(ccf_shifted[i_peak])
    z_best = float(z_shifted[i_peak])

    # rlap follows Blondin & Tonry (2007):
    #   rlap = r * lap
    # where r is the normalised CCF peak (in [-1, 1]) and lap is the
    # fractional overlap of the log-wavelength range, scaled to [0, 10].
    # A good match has rlap > 5; an excellent match has rlap > 8.
    log_union_min = min(log_obs.min(), log_tmpl.min())
    log_union_max = max(log_obs.max(), log_tmpl.max())
    n_union = max(int((log_union_max - log_union_min) / dlog_lambda), 1)
    lap = (n_grid / n_union) * 10.0  # fractional overlap scaled to [0, 10]

    rlap = abs(r_peak) * lap
    return rlap, z_best, ccf_shifted, z_shifted

def _load_templates(self, library_path: Union[str, Path]) -> list:
    """
    Load templates from a directory of files.

    Expected file format: CSV or whitespace-separated with columns:
    wavelength (Angstroms), flux

    File naming convention: {type}_{phase}.csv or {type}_{phase}.dat
    e.g., Ia_+5.csv, II_10.dat. ``# Type: ...`` / ``# Phase: ...`` comment
    lines inside a file override the values parsed from its name.

    Files that cannot be parsed are skipped with a warning. Flux is normalised
    by its peak absolute value (left unchanged if that peak is zero), matching
    the other template loaders in this class.

    :param library_path: Path to template library directory
    :return: List of template dictionaries
    :raises FileNotFoundError: If ``library_path`` does not exist
    :raises ValueError: If the directory contains no .csv or .dat files
    """
    library_path = Path(library_path)
    if not library_path.exists():
        raise FileNotFoundError(f"Template library path not found: {library_path}")

    # Sorted so the template order does not depend on filesystem listing order.
    template_files = sorted(library_path.glob("*.csv")) + sorted(library_path.glob("*.dat"))
    if not template_files:
        raise ValueError(f"No template files found in {library_path}")

    templates = []
    for file_path in template_files:
        try:
            # Try to parse filename for type and phase
            stem = file_path.stem
            parts = stem.split('_')
            if len(parts) >= 2:
                sn_type = parts[0]
                try:
                    phase = float(parts[1].replace('+', ''))
                except ValueError:
                    phase = 0.0
            else:
                sn_type = stem
                phase = 0.0

            # Read the text once; it is used for both the metadata scan and
            # the CSV header count.
            with open(file_path, 'r') as f:
                raw_lines = f.readlines()

            # Metadata in comments overrides the filename-derived values
            for raw in raw_lines:
                line = raw.strip()
                if not line.startswith('#'):
                    continue
                if 'Type:' in line or 'type:' in line:
                    try:
                        sn_type = line.split(':')[1].strip()
                    except IndexError:
                        pass
                if 'Phase:' in line or 'phase:' in line:
                    try:
                        phase = float(line.split(':')[1].strip())
                    except (IndexError, ValueError):
                        pass

            if file_path.suffix == '.csv':
                # Count leading comment lines and the header row to skip
                skip_count = 0
                for raw in raw_lines:
                    if raw.strip().startswith(('#', 'wavelength')):
                        skip_count += 1
                    else:
                        break
                data = np.loadtxt(file_path, delimiter=',', skiprows=skip_count)
            else:
                data = np.loadtxt(file_path, comments='#')

            wavelength = data[:, 0]
            flux = data[:, 1]

            # Normalise by peak |flux|; guard against an all-zero spectrum
            max_flux = np.max(np.abs(flux))
            if max_flux > 0:
                flux = flux / max_flux

            templates.append({
                'wavelength': wavelength,
                'flux': flux,
                'type': sn_type,
                'phase': phase,
                'name': stem
            })
            logger.info(f"Loaded template: {stem}")
        except Exception as e:
            logger.warning(f"Failed to load template {file_path}: {e}")
            continue

    return templates
def add_template(self, wavelength: np.ndarray, flux: np.ndarray,
                 sn_type: str, phase: float, name: Optional[str] = None) -> None:
    """
    Add a single template to the library.

    The flux is normalised by its peak absolute value, matching the other
    template loaders in this class. A flux that is empty or identically zero
    is stored unnormalised rather than divided by zero.

    :param wavelength: Wavelength array (or sequence) in Angstroms
    :param flux: Flux array (or sequence); will be normalized
    :param sn_type: Type classification (e.g., 'Ia', 'II', 'Ib/c')
    :param phase: Phase in days from maximum light
    :param name: Optional name for the template; defaults to
        ``"{sn_type}_phase_{phase}"``
    """
    if name is None:
        name = f"{sn_type}_phase_{phase}"

    # Accept plain sequences as well as arrays.
    wavelength = np.asarray(wavelength, dtype=float)
    flux = np.asarray(flux, dtype=float)

    peak = np.max(np.abs(flux)) if flux.size else 0.0
    flux_normalized = flux / peak if peak > 0 else flux

    self.templates.append({
        'wavelength': wavelength,
        'flux': flux_normalized,
        'type': sn_type,
        'phase': phase,
        'name': name
    })
    logger.info(f"Added template: {name}")
def match_spectrum(self, spectrum, redshift_range: tuple = (0, 0.5),
                   n_redshift_points: int = 50, method: str = 'rlap',
                   return_all_matches: bool = False,
                   rlap_threshold: float = 0.0) -> Union[dict, list, None]:
    """
    Find the best-matching template for an observed spectrum.

    :param spectrum: Spectrum object with angstroms and flux_density attributes
        (and optionally flux_density_err)
    :param redshift_range: (z_min, z_max) to restrict the redshift search.
        For method='rlap', the best-fit redshift comes directly from the CCF
        peak and is clipped to this range. For 'correlation' and 'chi2', a
        grid of n_redshift_points values is searched.
    :param n_redshift_points: Grid points for 'correlation'/'chi2'/'both'
        methods. Ignored for method='rlap'.
    :param method: Matching method:

        - 'rlap' (default): SNID-style cross-correlation in log-wavelength
          space. Shift-invariant. Returns rlap quality metric (>5 good,
          >8 excellent). Recommended for all real use.
        - 'correlation': Pearson correlation on a redshift grid. Legacy method.
        - 'chi2': Chi-squared on a redshift grid (requires flux errors for
          meaningful values).
        - 'both': Pearson + chi2 with combined normalised score.

    :param return_all_matches: If True, return the full sorted list of match dicts.
    :param rlap_threshold: Minimum rlap to include in results (default 0 = no
        filter). Only applied for method='rlap'.
    :return: Best match dict (or sorted list if return_all_matches=True), or
        None when nothing could be matched. Match dict keys: 'type', 'phase',
        'redshift', 'rlap', 'correlation', 'template_name', and (if
        applicable) 'chi2', 'reduced_chi2', 'scale_factor', 'combined_score'.
    :raises ValueError: If no templates are loaded or ``method`` is not one of
        the supported names.
    """
    from scipy.interpolate import interp1d
    from scipy.stats import pearsonr

    if len(self.templates) == 0:
        raise ValueError("No templates loaded. Add templates before matching.")
    if method not in ('rlap', 'correlation', 'chi2', 'both'):
        # Previously an unknown name silently ran the 'both' ranking with no
        # scores computed.
        raise ValueError(
            f"Unknown method '{method}'. Expected one of "
            "'rlap', 'correlation', 'chi2', 'both'."
        )

    obs_wavelength = np.asarray(spectrum.angstroms, dtype=float)
    obs_flux = np.asarray(spectrum.flux_density, dtype=float)

    # Normalise by the peak finite |flux| so a single NaN pixel does not turn
    # the whole spectrum into NaN, and bail out if there is nothing to scale by.
    finite_abs = np.abs(obs_flux[np.isfinite(obs_flux)])
    norm_factor = float(finite_abs.max()) if finite_abs.size else 0.0
    if norm_factor <= 0:
        logger.warning(
            "Observed spectrum has no finite, non-zero flux values; "
            "cannot match templates."
        )
        return None
    obs_flux_norm = obs_flux / norm_factor

    has_errors = hasattr(spectrum, 'flux_density_err') and spectrum.flux_density_err is not None
    if has_errors:
        obs_flux_err_norm = np.asarray(spectrum.flux_density_err, dtype=float) / norm_factor

    all_matches = []
    obs_min = np.min(obs_wavelength)
    obs_max = np.max(obs_wavelength)

    # --- rlap path: one CCF per template, no redshift grid needed ---
    if method == 'rlap':
        # If a non-trivial redshift range is specified, first check that at
        # least one template overlaps the spectrum at the requested redshift.
        # If none do, return None immediately (e.g. z_min=2 pushes all
        # templates far out of the observed wavelength range).
        z_lo, z_hi = redshift_range
        if z_lo > 0:
            has_overlap_at_z = any(
                np.min(t['wavelength']) * (1.0 + z_lo) < obs_max
                and np.max(t['wavelength']) * (1.0 + z_lo) > obs_min
                for t in self.templates
            )
            if not has_overlap_at_z:
                logger.warning(
                    f"No template overlaps the observed wavelength range at "
                    f"redshift_range={redshift_range}. Returning None."
                )
                return None

        for template in self.templates:
            rlap, z_best, ccf, z_arr = self._compute_rlap(
                obs_wavelength, obs_flux_norm,
                template['wavelength'], template['flux'],
                tmpl_pre_flattened=template.get('pre_flattened', False),
            )
            # ccf is None when there is no wavelength overlap — skip entirely
            if ccf is None:
                continue

            # Clip z_best to the requested range
            z_best = float(np.clip(z_best, z_lo, z_hi))

            if rlap < rlap_threshold:
                continue

            # Compute Pearson correlation at the best-fit redshift for the
            # 'correlation' key (used by downstream tests / backward compat)
            pearson_corr = rlap  # default fallback
            try:
                tmpl_wave_z = template['wavelength'] * (1.0 + z_best)
                f_tmpl = interp1d(tmpl_wave_z, template['flux'],
                                  bounds_error=False, fill_value=np.nan)
                tmpl_interp = f_tmpl(obs_wavelength)
                valid = (~np.isnan(tmpl_interp) & np.isfinite(obs_flux_norm))
                if np.sum(valid) >= 5:
                    corr_val, _ = pearsonr(obs_flux_norm[valid], tmpl_interp[valid])
                    pearson_corr = float(corr_val)
            except Exception:
                # Best effort only: keep the rlap fallback for 'correlation'.
                pass

            all_matches.append({
                'type': template['type'],
                'phase': template['phase'],
                'redshift': z_best,
                'rlap': rlap,
                'correlation': pearson_corr,
                'template_name': template.get('name', f"{template['type']}_p{template['phase']}"),
                'template_source': template.get('source', 'unknown'),
                'n_valid_points': len(ccf),
            })

    # --- Pearson / chi2 / both: grid search over redshift ---
    else:
        # The redshift grid is the same for every template.
        z_grid = np.linspace(redshift_range[0], redshift_range[1], n_redshift_points)
        for template in self.templates:
            for z in z_grid:
                template_wave_obs = template['wavelength'] * (1.0 + z)

                min_overlap = max(np.min(template_wave_obs), obs_min)
                max_overlap = min(np.max(template_wave_obs), obs_max)
                if max_overlap <= min_overlap:
                    continue

                interp_func = interp1d(template_wave_obs, template['flux'],
                                       bounds_error=False, fill_value=np.nan)
                template_flux_interp = interp_func(obs_wavelength)

                valid_mask = (~np.isnan(template_flux_interp)
                              & ~np.isnan(obs_flux_norm)
                              & (template_flux_interp != 0))
                n_valid = int(np.sum(valid_mask))
                if n_valid < 10:
                    continue

                obs_valid = obs_flux_norm[valid_mask]
                template_valid = template_flux_interp[valid_mask]

                match_result = {
                    'type': template['type'],
                    'phase': template['phase'],
                    'redshift': float(z),
                    'rlap': 0.0,
                    'template_name': template.get('name', f"{template['type']}_p{template['phase']}"),
                    'n_valid_points': n_valid,
                }

                if method in ('correlation', 'both'):
                    try:
                        corr, p_value = pearsonr(obs_valid, template_valid)
                        match_result['correlation'] = float(corr)
                        match_result['p_value'] = float(p_value)
                    except Exception:
                        match_result['correlation'] = -1.0
                        match_result['p_value'] = 1.0

                if method in ('chi2', 'both'):
                    dof = max(len(obs_valid) - 1, 1)
                    if has_errors:
                        err_valid = obs_flux_err_norm[valid_mask]
                        scale = (np.sum(obs_valid * template_valid / err_valid ** 2)
                                 / np.sum(template_valid ** 2 / err_valid ** 2))
                        residuals = obs_valid - scale * template_valid
                        chi2 = float(np.sum((residuals / err_valid) ** 2))
                    else:
                        # No errors: use the variance of the data as a stand-in
                        # so chi2 is at least scale-free.
                        scale = (np.sum(obs_valid * template_valid)
                                 / np.sum(template_valid ** 2))
                        residuals = obs_valid - scale * template_valid
                        chi2 = float(np.sum(residuals ** 2) / max(np.var(obs_valid), 1e-30))
                    match_result['chi2'] = chi2
                    # Always record reduced_chi2: the 'both' ranking reads it,
                    # and without it every combined score came out as NaN when
                    # the spectrum had no flux errors.
                    match_result['reduced_chi2'] = chi2 / dof
                    match_result['scale_factor'] = float(scale)

                all_matches.append(match_result)

    if len(all_matches) == 0:
        logger.warning("No valid matches found. Check wavelength coverage and templates.")
        return None

    if method == 'rlap':
        # Sort by correlation (Pearson) for consistent ordering with other methods
        all_matches.sort(key=lambda x: -x['correlation'])
    elif method == 'chi2':
        all_matches.sort(key=lambda x: x.get('chi2', np.inf))
    elif method == 'correlation':
        all_matches.sort(key=lambda x: -x.get('correlation', -1.0))
    else:  # both — combined normalised score
        corr_vals = np.array([m.get('correlation', 0.0) for m in all_matches])
        chi2_vals = np.array([m.get('reduced_chi2', np.inf) for m in all_matches])
        c_range = np.ptp(corr_vals)
        q_range = np.ptp(chi2_vals)
        corr_norm = (corr_vals - corr_vals.min()) / (c_range if c_range > 0 else 1.0)
        chi2_norm = (chi2_vals - chi2_vals.min()) / (q_range if q_range > 0 else 1.0)
        for i, m in enumerate(all_matches):
            m['combined_score'] = float(corr_norm[i] - 0.3 * chi2_norm[i])
        all_matches.sort(key=lambda x: -x.get('combined_score', 0.0))

    return all_matches if return_all_matches else all_matches[0]
def classify_spectrum(self, spectrum, redshift_range: tuple = (0, 0.5),
                      n_redshift_points: int = 50,
                      top_n: int = 10,
                      rlap_threshold: float = 3.0) -> ClassificationResult:
    """
    Classify a spectrum and return a :class:`ClassificationResult`.

    Type probabilities are computed via softmax over the mean rlap per type
    across the top_n matches. Using the mean (rather than sum) ensures that
    types with more templates in the library do not dominate.

    :param spectrum: Spectrum object with angstroms and flux_density attributes
    :param redshift_range: (z_min, z_max) redshift search range
    :param n_redshift_points: Currently unused. Matching is always done with
        method='rlap', which does not search a redshift grid; the parameter
        is retained for API compatibility.
    :param top_n: Number of top matches to use for probability estimation.
        Values below 1 are treated as 1.
    :param rlap_threshold: Matches below this rlap are excluded from the
        probability estimate. Set to 0 to include all matches. If no match
        passes, all matches are used and a warning is recorded.
    :return: :class:`ClassificationResult` instance
    """
    from collections import defaultdict

    all_matches = self.match_spectrum(
        spectrum,
        redshift_range=redshift_range,
        method='rlap',
        return_all_matches=True,
    )

    warnings_list = []

    if all_matches is None or len(all_matches) == 0:
        return ClassificationResult(
            best_type=None,
            best_phase=0.0,
            best_redshift=0.0,
            rlap=0.0,
            confidence='low',
            type_probabilities={},
            top_matches=[],
            best_template_name='',
            method='rlap',
            warnings=['No valid matches found'],
        )

    # For classification, rank by rlap (regardless of match_spectrum sort order)
    all_matches_by_rlap = sorted(all_matches, key=lambda x: -x.get('rlap', 0.0))

    # Apply rlap threshold; fall back to all matches if nothing passes
    good_matches = [m for m in all_matches_by_rlap if m.get('rlap', 0) >= rlap_threshold]
    if len(good_matches) == 0:
        good_matches = all_matches_by_rlap
        warnings_list.append(
            f"No matches exceeded rlap_threshold={rlap_threshold:.1f}. "
            f"Best rlap was {all_matches_by_rlap[0].get('rlap', 0):.2f}. "
            "Classification may be unreliable."
        )

    # Always keep at least the best match: top_n <= 0 used to leave this list
    # empty and crash in max() below.
    n_keep = max(int(top_n), 1)
    top_matches = good_matches[:n_keep]

    # Aggregate mean rlap per type
    type_rlap = defaultdict(list)
    for m in top_matches:
        type_rlap[m['type']].append(m.get('rlap', 0.0))
    type_mean_rlap = {t: float(np.mean(v)) for t, v in type_rlap.items()}

    # Softmax normalisation (numerically stable: subtract the maximum first)
    max_rlap = max(type_mean_rlap.values())
    exp_scores = {t: np.exp(s - max_rlap) for t, s in type_mean_rlap.items()}
    total = sum(exp_scores.values())
    type_probabilities = {t: float(v / total) for t, v in exp_scores.items()}

    best_match = top_matches[0]
    best_rlap = float(best_match.get('rlap', 0.0))

    if best_rlap > 8:
        confidence = 'high'
    elif best_rlap > 5:
        confidence = 'medium'
    else:
        confidence = 'low'

    if best_rlap < 3.0:
        warnings_list.append(
            f"Best rlap={best_rlap:.2f} is below 3.0. "
            "Consider loading a larger or more appropriate template library."
        )

    return ClassificationResult(
        best_type=best_match['type'],
        best_phase=float(best_match['phase']),
        best_redshift=float(best_match['redshift']),
        rlap=best_rlap,
        confidence=confidence,
        type_probabilities=type_probabilities,
        top_matches=top_matches,
        best_template_name=best_match.get('template_name', ''),
        best_template_source=best_match.get('template_source'),
        method='rlap',
        warnings=warnings_list,
    )
def plot_match(self, spectrum, match_result: dict,
               axes=None, **kwargs) -> matplotlib.axes.Axes:
    """
    Plot observed spectrum against best-matching template.

    The observed flux is normalised by its peak absolute value — the same
    convention :meth:`match_spectrum` uses — so a 'scale_factor' taken from a
    match result applies to the plotted data.

    :param spectrum: Observed Spectrum object
    :param match_result: Result dictionary from match_spectrum
    :param axes: Optional matplotlib axes to plot on; the current axes are
        used when None
    :param kwargs: Accepted for API compatibility; not currently forwarded to
        any plotting call
    :return: Matplotlib axes object
    :raises ValueError: If no loaded template matches ``match_result``
    """
    from scipy.interpolate import interp1d

    ax = axes if axes is not None else plt.gca()

    # Get observed spectrum
    obs_wavelength = spectrum.angstroms
    obs_flux_raw = np.asarray(spectrum.flux_density, dtype=float)

    # Find matching template: by name first, then by (type, phase)
    wanted_name = match_result.get('template_name')
    template = next((t for t in self.templates if t.get('name') == wanted_name), None)
    if template is None:
        template = next(
            (t for t in self.templates
             if t['type'] == match_result['type'] and t['phase'] == match_result['phase']),
            None,
        )
    if template is None:
        raise ValueError("Could not find matching template")

    # Peak finite |flux|; 0.0 when there is nothing usable to scale by.
    finite_abs = np.abs(obs_flux_raw[np.isfinite(obs_flux_raw)])
    peak = float(finite_abs.max()) if finite_abs.size else 0.0
    obs_flux_norm = obs_flux_raw / peak if peak > 0 else obs_flux_raw

    # If template is already continuum-subtracted, the observed spectrum must
    # be put on the same scale. We flatten it (remove continuum via Gaussian
    # division) and then normalise by the RMS so the amplitudes are comparable
    # regardless of whether the input is raw or already continuum-subtracted.
    if template.get('pre_flattened', False):
        obs_flux_flat = SpectralTemplateMatcher._flatten_spectrum(obs_flux_norm)
        rms = np.sqrt(np.nanmean(obs_flux_flat ** 2))
        obs_flux_plot = obs_flux_flat / rms if rms > 0 else obs_flux_flat
        ylabel = 'Continuum-subtracted Flux'
    else:
        obs_flux_plot = obs_flux_norm
        ylabel = 'Normalized Flux'

    # Redshift template
    z = match_result['redshift']
    template_wave_obs = template['wavelength'] * (1 + z)

    # Interpolate template to observed wavelengths for comparison
    interp_func = interp1d(template_wave_obs, template['flux'],
                           bounds_error=False, fill_value=np.nan)
    template_flux_interp = interp_func(obs_wavelength)

    # Scale template to match observed flux
    if 'scale_factor' in match_result:
        scale = match_result['scale_factor']
    else:
        # Least-squares amplitude, ignoring pixels the template does not cover
        denom = np.nansum(template_flux_interp ** 2)
        scale = (np.nansum(obs_flux_plot * template_flux_interp) / denom
                 if denom > 0 else 1.0)

    # Plot
    ax.plot(obs_wavelength, obs_flux_plot, 'k-', label='Observed', alpha=0.8, lw=1.5)
    ax.plot(obs_wavelength, scale * template_flux_interp, 'r--',
            label=f"Template: {match_result['type']} (phase={match_result['phase']:.0f}d, z={z:.3f})",
            alpha=0.8, lw=1.5)

    ax.set_xlabel(r'Wavelength ($\mathrm{\AA}$)')
    ax.set_ylabel(ylabel)
    rlap_val = match_result.get('rlap', match_result.get('correlation', 0))
    title = f"Best Match: {match_result['type']}, rlap={rlap_val:.2f}"
    ax.set_title(title)
    ax.legend(loc='best')

    return ax
@staticmethod
def get_available_template_sources() -> dict:
    """
    Describe the external template collections this matcher knows about.

    :return: Dictionary keyed by source name; each value is a dict of
        descriptive fields (description, URLs, citation, ...). A fresh
        dictionary is built on every call.
    """
    catalogue = {}
    catalogue['snid_templates_2.0'] = dict(
        description='Official SNID templates v2.0 from Blondin & Tonry',
        url='https://people.lam.fr/blondin.stephane/software/snid/',
        download_url='https://people.lam.fr/blondin.stephane/software/snid/templates-2.0.tgz',
        citation='Blondin & Tonry 2007, ApJ, 666, 1024',
    )
    catalogue['super_snid'] = dict(
        description='Super-SNID expanded templates (841 spectra, 161 objects)',
        url='https://github.com/dkjmagill/QUB-SNID-Templates',
        zenodo_doi='10.5281/zenodo.15167198',
        citation='Magill et al. 2025',
    )
    catalogue['sesn_templates'] = dict(
        description='Stripped-envelope SN templates from METAL collaboration',
        url='https://github.com/metal-sn/SESNtemple',
        citation='Williamson et al. 2023, Yesmin et al. 2024',
    )
    return catalogue
@staticmethod
def parse_snid_template_file(file_path: Union[str, Path]):
    """
    Parse a SNID template file (.lnw format) or two-column ASCII template.

    Three layouts are tried in order; the first that parses wins:

    1. SNID .lnw (Blondin & Tonry 2007, Appendix B) — returns a list of dicts,
       one per epoch.
    2. Super-SNID matrix format — returns a list of dicts, one per non-empty
       epoch, flagged ``'pre_flattened': True``.
    3. Simple two-column ASCII (wavelength, flux) — returns a single dict.

    The SNID .lnw format:

    **Line 1 — object header (8 tokens):**
    ``nwave nspec type_code type_string redshift age_of_max dm15 name``

    **Next nwave tokens — log10(wavelength) array** (may span multiple lines).
    wavelength = 10^token (Angstroms). The grid is log-spaced.

    **Then nspec epoch blocks, each with:**

    - One header line: ``phase_days <ignored>``
    - nwave flux tokens (may span multiple lines)

    For two-column ASCII files, metadata can be provided via header comments::

        # Type: IIn
        # Phase: -3.5

    Comments are parsed case-insensitively. If a comment key is present but
    has no valid value (empty type, non-numeric phase), the value inferred
    from the filename is used instead.

    :param file_path: Path to a SNID .lnw template file or two-column ASCII
    :return: For .lnw / Super-SNID files: list of template dicts, one per epoch.
             For ASCII files: a single template dict.
             Each dict has keys: 'wavelength', 'flux', 'type', 'phase', 'name', 'source'
    :raises ValueError: If no layout matches and fewer than 2 numeric data
        rows are found
    """
    file_path = Path(file_path)
    name = file_path.stem

    # Read the file once; every parsing stage below works on these lines.
    with open(file_path, 'r') as fh:
        stripped_lines = [line.strip() for line in fh]

    # --- Parse comment metadata (case-insensitive) from header lines ---
    comment_type = None
    comment_phase = None
    for stripped in stripped_lines:
        if not stripped:
            continue
        if not stripped.startswith('#'):
            break  # stop at first data line
        lower = stripped.lower()
        if 'type:' in lower:
            comment_type = stripped.split(':', 1)[1].strip()
        if 'phase:' in lower:
            try:
                comment_phase = float(stripped.split(':', 1)[1].strip())
            except ValueError:
                comment_phase = None  # mark as failed, fall back to filename

    # Non-empty, non-comment lines and their whitespace-separated tokens
    data_lines = [s for s in stripped_lines if s and not s.startswith('#')]
    tokens = [tok for s in data_lines for tok in s.split()]

    # --- Attempt .lnw header parse (requires ≥8 tokens and valid header) ---
    if len(tokens) >= 8:
        try:
            nwave = int(tokens[0])
            nspec = int(tokens[1])
            # tokens[2] is integer type code — skip
            sn_type = tokens[3]
            float(tokens[4])  # source redshift: unused, but must be numeric
            age_of_max = float(tokens[5])
            # tokens[6] is dm15 — skip
            obj_name = tokens[7]

            if nwave < 10 or nspec < 1:
                raise ValueError("Implausible nwave/nspec")

            required = 8 + nwave + nspec * (1 + nwave)
            if len(tokens) < required:
                raise ValueError(
                    f"Not enough tokens: need {required}, have {len(tokens)}"
                )

            # Read log-wavelength array
            pos = 8
            log_wave = np.array([float(tok) for tok in tokens[pos:pos + nwave]])
            wavelengths = 10.0 ** log_wave  # Angstroms
            pos += nwave

            templates = []
            for epoch_idx in range(nspec):
                epoch_phase = float(tokens[pos]) - age_of_max
                pos += 2  # phase + one ignored token
                epoch_tokens = tokens[pos:pos + nwave]
                if len(epoch_tokens) < nwave:
                    raise IndexError("Truncated epoch block")
                flux = np.array([float(tok) for tok in epoch_tokens])
                pos += nwave

                max_flux = np.max(np.abs(flux))
                if max_flux > 0:
                    flux = flux / max_flux

                templates.append({
                    'wavelength': wavelengths,
                    'flux': flux,
                    'type': sn_type,
                    'phase': float(epoch_phase),
                    'name': f"{obj_name}_epoch{epoch_idx}",
                    'source': 'snid',
                })
            return templates

        except (ValueError, IndexError):
            pass  # Fall through to Super-SNID format attempt

    # --- Attempt Super-SNID matrix format ---
    # Header: nspec nwave wmin wmax nfeatures name redshift type ...
    # Followed by nfeatures rows of feature data, then one epoch-header line
    # starting with 0 containing all nspec phases, then nwave rows each with
    # wavelength followed by nspec flux values (already continuum-subtracted).
    if len(data_lines) >= 3:
        try:
            header = data_lines[0].split()
            nspec = int(header[0])
            nwave = int(header[1])
            sn_type = header[7]
            obj_name = header[5]

            if nspec < 1 or nwave < 10:
                raise ValueError("Implausible nspec/nwave")

            # The nfeatures metadata rows follow the header.
            # After that comes the epoch-header line (starts with '0') and
            # then nwave data rows. We scan forward to find the epoch line.
            epoch_line_idx = None
            for i in range(1, len(data_lines)):
                parts = data_lines[i].split()
                if parts[0] == '0' and len(parts) == nspec + 1:
                    epoch_line_idx = i
                    break

            if epoch_line_idx is None:
                raise ValueError("Could not find epoch header line")

            phases = [float(p) for p in data_lines[epoch_line_idx].split()[1:]]

            # Read nwave data rows (wavelength + nspec flux values)
            data_start = epoch_line_idx + 1
            if len(data_lines) < data_start + nwave:
                raise ValueError("Not enough data rows")

            wavelengths = np.zeros(nwave)
            flux_matrix = np.zeros((nwave, nspec))
            for i in range(nwave):
                row = data_lines[data_start + i].split()
                wavelengths[i] = float(row[0])
                for j in range(nspec):
                    flux_matrix[i, j] = float(row[j + 1])

            templates = []
            for j, phase in enumerate(phases):
                flux = flux_matrix[:, j]
                # Flux is already continuum-subtracted; skip zero-only epochs
                if np.max(np.abs(flux)) == 0:
                    continue
                templates.append({
                    'wavelength': wavelengths,
                    'flux': flux,
                    'type': sn_type,
                    'phase': float(phase),
                    'name': f"{obj_name}_p{phase:+.1f}",
                    'source': 'super_snid',
                    'pre_flattened': True,
                })

            if templates:
                return templates

        except (ValueError, IndexError):
            pass  # Fall through to two-column ASCII fallback

    # --- Fallback: two-column ASCII (wavelength, flux) ---
    # Manual parser: skips comment lines and malformed rows, uses only
    # the first two numeric columns (extra columns are ignored).
    wave_list, flux_list = [], []
    for stripped in data_lines:
        parts = stripped.split()
        if len(parts) < 2:
            continue  # skip single-value lines
        try:
            w = float(parts[0])
            f = float(parts[1])
        except ValueError:
            continue  # skip non-numeric lines
        wave_list.append(w)
        flux_list.append(f)

    if len(wave_list) < 2:
        raise ValueError(
            f"Could not parse {file_path}: fewer than 2 valid data rows found"
        )

    wavelengths = np.array(wave_list)
    flux = np.array(flux_list)
    max_flux = np.max(np.abs(flux))
    if max_flux > 0:
        flux = flux / max_flux

    # Infer type/phase from filename (e.g. sn1999aa_Ia_+5.dat)
    sn_type = 'Unknown'
    phase = 0.0
    known_types = ('Ia', 'Ib', 'Ic', 'II', 'IIn', 'IIP', 'IIL', 'Ic-BL', 'Ia-pec')
    for part in name.split('_')[1:]:
        if part in known_types:
            sn_type = part
        else:
            try:
                phase = float(part)
            except ValueError:
                pass

    # Override with comment metadata if present. An empty "# Type:" value is
    # not a valid override, so the filename-derived type is kept in that case.
    if comment_type:
        sn_type = comment_type
    if comment_phase is not None:
        phase = comment_phase

    return {
        'wavelength': wavelengths,
        'flux': flux,
        'type': sn_type,
        'phase': phase,
        'name': name,
        'source': 'ascii',
    }
@staticmethod
def download_github_templates(repo_url: str,
                              branch: str = "master",
                              subdirectory: str = "",
                              cache_dir: Optional[Union[str, Path]] = None) -> Path:
    """
    Download template files from a GitHub repository.

    The repository is fetched as a zip archive of ``branch`` and unpacked into
    ``<cache_dir>/<owner>_<repo>``. If that directory already exists and is
    non-empty it is reused and nothing is downloaded.

    :param repo_url: GitHub repository URL (e.g., 'https://github.com/metal-sn/SESNtemple')
    :param branch: Branch name (default: 'master')
    :param subdirectory: Subdirectory within repo containing templates
    :param cache_dir: Local directory to cache files. If None, uses ~/.redback/spectral_templates/
    :return: Path to downloaded template directory
    :raises FileNotFoundError: If the archive does not unpack to a usable directory
    """
    import shutil
    import tempfile
    import urllib.request
    import zipfile

    if cache_dir is None:
        cache_dir = Path.home() / '.redback' / 'spectral_templates'
    else:
        cache_dir = Path(cache_dir).expanduser()
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Parse repo URL to get owner and repo name
    parts = repo_url.rstrip('/').split('/')
    repo_name = parts[-1]
    owner = parts[-2]

    # Create unique cache directory for this repo
    repo_cache = cache_dir / f"{owner}_{repo_name}"
    result_dir = repo_cache / subdirectory if subdirectory else repo_cache

    if repo_cache.exists() and any(repo_cache.iterdir()):
        logger.info(f"Using cached templates from {repo_cache}")
        return result_dir

    # Download zip archive
    zip_url = f"https://github.com/{owner}/{repo_name}/archive/refs/heads/{branch}.zip"
    logger.info(f"Downloading templates from {zip_url}")

    try:
        # The archive lives in a temporary directory so it is removed even
        # when the download or the extraction fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            zip_path = Path(tmp_dir) / 'archive.zip'
            urllib.request.urlretrieve(zip_url, str(zip_path))

            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                # Read the top-level folder from the archive itself instead of
                # guessing it: the guess breaks for branch names containing '/'
                # or when repo_url differs in case from the real repo name.
                roots = {member.split('/', 1)[0]
                         for member in zip_ref.namelist() if member.strip('/')}
                zip_ref.extractall(cache_dir)

        root_name = roots.pop() if len(roots) == 1 else f"{repo_name}-{branch}"
        extracted_dir = cache_dir / root_name
        if not extracted_dir.is_dir():
            raise FileNotFoundError(
                f"Archive from {zip_url} did not unpack to {extracted_dir}"
            )

        # Rename extracted directory to the stable cache name
        if extracted_dir != repo_cache:
            if repo_cache.exists():
                shutil.rmtree(repo_cache)
            extracted_dir.rename(repo_cache)

        logger.info(f"Templates downloaded to {repo_cache}")
        return result_dir

    except Exception as e:
        logger.error(f"Failed to download templates: {e}")
        raise
@classmethod
def from_sesn_templates(cls, cache_dir: Optional[Union[str, Path]] = None) -> 'SpectralTemplateMatcher':
    """
    Build a matcher from the METAL/SESNtemple stripped-envelope SN templates.

    The templates are fetched (or taken from the local cache) from
    https://github.com/metal-sn/SESNtemple and loaded recursively from its
    ``SNIDtemplates`` subdirectory.

    :param cache_dir: Local cache directory (default: ~/.redback/spectral_templates/)
    :return: SpectralTemplateMatcher instance
    """
    download_options = dict(subdirectory='SNIDtemplates', cache_dir=cache_dir)
    sesn_dir = cls.download_github_templates(
        'https://github.com/metal-sn/SESNtemple', **download_options
    )
    return cls.from_snid_template_directory(sesn_dir, recursive=True)
@classmethod
def from_super_snid_templates(cls, cache_dir: Optional[Union[str, Path]] = None) -> 'SpectralTemplateMatcher':
    """
    Build a matcher from the Super-SNID template library (Magill et al. 2025).

    The repository https://github.com/dkjmagill/QUB-SNID-Templates is fetched
    (branch ``main``). If it does not already contain a ``templates``
    directory, the bundled ``templates.zip`` is unpacked in place first. The
    ``.lnw`` files in that directory are then loaded.

    Please cite: Magill et al. 2025 (Zenodo DOI: 10.5281/zenodo.15167198)

    :param cache_dir: Local cache directory (default: ~/.redback/spectral_templates/)
    :return: SpectralTemplateMatcher instance
    :raises FileNotFoundError: If neither ``templates/`` nor ``templates.zip`` is
        present, or extraction does not produce ``templates/``
    """
    import zipfile

    repo_dir = cls.download_github_templates(
        'https://github.com/dkjmagill/QUB-SNID-Templates',
        branch='main',
        cache_dir=cache_dir
    )
    templates_dir = repo_dir / 'templates'

    # Fast path: a previous call already unpacked the inner archive.
    if templates_dir.exists():
        return cls.from_snid_template_directory(templates_dir)

    inner_zip = repo_dir / 'templates.zip'
    if not inner_zip.exists():
        raise FileNotFoundError(
            f"Could not find templates.zip in {repo_dir}. "
            "The repository structure may have changed."
        )

    logger.info(f"Extracting {inner_zip} ...")
    with zipfile.ZipFile(inner_zip, 'r') as archive:
        archive.extractall(repo_dir)

    if not templates_dir.exists():
        raise FileNotFoundError(
            f"Expected a 'templates/' directory in {repo_dir} after extraction."
        )
    return cls.from_snid_template_directory(templates_dir)
@classmethod
def from_snid_template_directory(cls,
                                 directory: Union[str, Path],
                                 file_pattern: str = "*.lnw",
                                 recursive: bool = False) -> 'SpectralTemplateMatcher':
    """
    Build a SpectralTemplateMatcher from every template file in a directory.

    Files matching ``file_pattern`` are used; if none match, the common
    extensions ``*.lnw``, ``*.dat`` and ``*.txt`` are tried instead. A file
    that fails to parse is logged as a warning and skipped.

    :param directory: Path to directory containing SNID template files
    :param file_pattern: Glob pattern for template files (default: "*.lnw")
    :param recursive: If True, search subdirectories recursively (default: False)
    :return: SpectralTemplateMatcher instance
    :raises FileNotFoundError: If ``directory`` does not exist
    :raises ValueError: If no template files are found
    """
    root = Path(directory)
    if not root.exists():
        raise FileNotFoundError(f"Directory not found: {root}")

    search = root.rglob if recursive else root.glob
    candidates = list(search(file_pattern))
    if not candidates:
        # Try other common extensions
        candidates = [path
                      for pattern in ("*.lnw", "*.dat", "*.txt")
                      for path in search(pattern)]
    if not candidates:
        raise ValueError(f"No template files found in {root}")

    loaded = []
    for path in candidates:
        try:
            parsed = cls.parse_snid_template_file(path)
            # A list means a multi-epoch .lnw file; anything else is a
            # single template dict from the ASCII fallback.
            if not isinstance(parsed, list):
                loaded.append(parsed)
                logger.info(f"Loaded 1 template from {path.name}")
            else:
                loaded.extend(parsed)
                logger.info(f"Loaded {len(parsed)} epoch(s) from {path.name}")
        except Exception as e:
            logger.warning(f"Failed to load {path}: {e}")
            continue

    logger.info(f"Loaded {len(loaded)} total template epochs from {root}")
    return cls(templates=loaded)
def save_templates(self, output_dir: Union[str, Path], format: str = 'csv') -> None:
    """
    Write every template held by this matcher to ``output_dir``.

    One file is written per template, named after the template with path
    separators replaced by '-'. For ``format='csv'`` the file holds two
    comment lines (type and phase), a ``wavelength,flux`` header and one
    row per sample. Any other value is used as the file extension and the
    data are written with ``np.savetxt``, with type and phase in its header.

    :param output_dir: Directory to save templates
    :param format: Output format ('csv' or 'dat')
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    for entry in self.templates:
        stem = entry['name'].replace('/', '-').replace('\\', '-')
        destination = target_dir / f"{stem}.{format}"
        table = np.column_stack([entry['wavelength'], entry['flux']])

        if format != 'csv':
            np.savetxt(destination, table,
                       header=f"Type: {entry['type']}\nPhase: {entry['phase']}")
            continue

        # Metadata as comments, then the column header, then the samples
        with open(destination, 'w') as handle:
            handle.write(f"# Type: {entry['type']}\n")
            handle.write(f"# Phase: {entry['phase']}\n")
            handle.write("wavelength,flux\n")
            handle.writelines(f"{row[0]},{row[1]}\n" for row in table)

    logger.info(f"Saved {len(self.templates)} templates to {target_dir}")
def filter_templates(self, types: Optional[list] = None,
                     phase_range: Optional[tuple] = None) -> 'SpectralTemplateMatcher':
    """
    Return a new matcher holding only the templates that pass the filters.

    Both filters are optional; when both are given, the type filter is
    applied first and the phase filter second. The phase bounds are inclusive.

    :param types: List of SN types to include (e.g., ['Ia', 'Ib'])
    :param phase_range: Tuple of (min_phase, max_phase) in days
    :return: New SpectralTemplateMatcher with filtered templates
    """
    kept = self.templates.copy()

    if types is not None:
        kept = list(filter(lambda entry: entry['type'] in types, kept))

    if phase_range is not None:
        earliest, latest = phase_range
        kept = list(filter(lambda entry: earliest <= entry['phase'] <= latest, kept))

    logger.info(f"Filtered to {len(kept)} templates")
    return SpectralTemplateMatcher(templates=kept)
class PhotometricClassifier:
    """
    Classify transients from light curve shape using redback photometric models.

    Compares an observed normalised light curve against a set of representative
    model light curves using dynamic time warping (DTW), which is robust to
    10–20 day timing offsets between objects of the same type.

    Returns a :class:`ClassificationResult` with method='photometric'.
    """

    # Default model templates: (model_name, parameters, label)
    _DEFAULT_MODEL_PARAMS = [
        ('arnett', dict(f_nickel=0.6, mej=1.2, vej=10000, kappa=0.1, kappa_gamma=10.0,
                        temperature_floor=3000, redshift=0.01), 'Ia'),
        ('arnett', dict(f_nickel=0.05, mej=5.0, vej=5000, kappa=0.07, kappa_gamma=10.0,
                        temperature_floor=3500, redshift=0.01), 'IIP'),
        ('basic_magnetar_powered', dict(P0=2.0, Bp=1e14, Mns=1.4, chi=90.0, mej=5.0,
                                        vej=8000, kappa=0.1, kappa_gamma=10.0,
                                        redshift=0.05), 'SLSN-I'),
        ('arnett', dict(f_nickel=0.2, mej=3.0, vej=15000, kappa=0.08, kappa_gamma=10.0,
                        temperature_floor=3000, redshift=0.02), 'Ic-BL'),
    ]

    # Attributes probed, in this order, for the observation times / brightness
    _TIME_ATTRIBUTES = ('time', 'time_days')
    _FLUX_ATTRIBUTES = ('flux_density', 'Lum50', 'magnitude', 'counts')

    def __init__(self, model_templates: Optional[list] = None) -> None:
        """
        :param model_templates: List of (model_name, parameters_dict, label) tuples.
            If None, uses built-in defaults.
        """
        # Copy the defaults so that editing one instance's template list can
        # never change the class-level defaults seen by other instances.
        self.model_templates = model_templates or list(self._DEFAULT_MODEL_PARAMS)
        self._lc_cache = {}

    @staticmethod
    def _dtw_distance(a: np.ndarray, b: np.ndarray) -> float:
        """
        Compute Dynamic Time Warping distance between two 1-D sequences.

        Uses a simple O(N*M) cumulative distance matrix without Sakoe-Chiba band.
        Both sequences should already be normalised (peak = 1).

        :param a: First sequence
        :param b: Second sequence
        :return: DTW distance (lower = more similar)
        """
        a = np.asarray(a, dtype=float).ravel()
        b = np.asarray(b, dtype=float).ravel()
        n, m = len(a), len(b)

        # Pairwise point costs, computed once instead of inside the double loop
        cost = np.abs(a[:, None] - b[None, :])

        dtw = np.full((n + 1, m + 1), np.inf)
        dtw[0, 0] = 0.0
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                dtw[i, j] = cost[i - 1, j - 1] + min(dtw[i - 1, j],
                                                     dtw[i, j - 1],
                                                     dtw[i - 1, j - 1])
        return float(dtw[n, m])

    @staticmethod
    def _cache_key(model_name: str, params: dict, time_grid: np.ndarray) -> tuple:
        """
        Build the light-curve cache key for a model evaluation.

        The time grid is part of the key: a light curve evaluated on one grid
        must not be returned for a different grid. Parameters are keyed by
        ``repr`` so unhashable values (lists, arrays) do not raise.
        """
        grid = np.ascontiguousarray(time_grid, dtype=float)
        return model_name, repr(sorted(params.items())), grid.shape, grid.tobytes()

    def _evaluate_model_lc(self, model_name: str, params: dict,
                           time_grid: np.ndarray) -> Optional[np.ndarray]:
        """
        Evaluate a redback model on time_grid; returns normalised flux or None.

        Results (including failures, stored as None) are cached per
        (model, parameters, time grid).
        """
        key = self._cache_key(model_name, params, time_grid)
        if key in self._lc_cache:
            return self._lc_cache[key]
        try:
            from redback.model_library import all_models_dict
            func = all_models_dict[model_name]
            flux = func(time_grid, **params)
            flux = np.asarray(flux, dtype=float)
            peak = np.max(flux)
            if peak > 0:
                flux = flux / peak
            self._lc_cache[key] = flux
            return flux
        except Exception as e:
            logger.warning(f"PhotometricClassifier: failed to evaluate model '{model_name}': {e}")
            self._lc_cache[key] = None
            return None

    @classmethod
    def _extract_light_curve(cls, transient) -> tuple:
        """
        Pull the observation times and peak-normalised brightness off a transient.

        :param transient: Object exposing one of the time attributes and one of
            the flux attributes probed by this class.
        :return: ``(time, normalised_flux)`` arrays, or ``(None, None)`` when
            either cannot be found or is empty.
        """
        obs_time = None
        for attr in cls._TIME_ATTRIBUTES:
            value = getattr(transient, attr, None)
            # Explicit None/size checks: ``a or b`` raises on numpy arrays.
            if value is not None and np.size(value) > 0:
                obs_time = np.asarray(value, dtype=float)
                break

        obs_flux = None
        is_magnitude = False
        for attr in cls._FLUX_ATTRIBUTES:
            value = getattr(transient, attr, None)
            if value is not None:
                obs_flux = np.asarray(value, dtype=float)
                is_magnitude = attr == 'magnitude'
                break

        if obs_time is None or obs_flux is None or obs_flux.size == 0:
            return None, None

        if is_magnitude:
            # Magnitudes are logarithmic and inverted (brighter = smaller), so
            # convert to flux relative to the brightest point before comparing
            # against peak-normalised model fluxes.
            obs_flux = 10.0 ** (-0.4 * (obs_flux - np.min(obs_flux)))

        peak = np.max(np.abs(obs_flux))
        obs_norm = obs_flux / peak if peak > 0 else obs_flux
        return obs_time, obs_norm

    @staticmethod
    def _unknown_result(warning: str) -> 'ClassificationResult':
        """Return the 'Unknown' photometric result carrying a single warning."""
        return ClassificationResult(
            best_type='Unknown', best_phase=0.0, best_redshift=0.0,
            rlap=0.0, confidence='low', type_probabilities={},
            top_matches=[], best_template_name='',
            method='photometric',
            warnings=[warning],
        )

    def classify_from_lightcurve(self, transient,
                                 time_grid: Optional[np.ndarray] = None,
                                 top_n: int = 5) -> 'ClassificationResult':
        """
        Classify a transient from its bolometric or single-band light curve shape.

        :param transient: A redback transient object with ``time`` and a flux/
            luminosity attribute, or any object with ``time`` and
            ``flux_density`` arrays.
        :param time_grid: Time grid (days) on which to evaluate models.
            If None, uses the transient's own time array.
        :param top_n: Number of top matches to use for probability estimation
            (values below 1 are treated as 1).
        :return: :class:`ClassificationResult` with method='photometric'.
        """
        from collections import defaultdict
        from scipy.interpolate import interp1d

        obs_time, obs_norm = self._extract_light_curve(transient)
        if obs_time is None or obs_norm is None:
            return self._unknown_result('Could not extract time/flux from transient')

        if time_grid is None:
            time_grid = obs_time
        else:
            time_grid = np.asarray(time_grid, dtype=float)

        all_matches = []
        for model_name, params, label in self.model_templates:
            model_lc = self._evaluate_model_lc(model_name, params, time_grid)
            if model_lc is None:
                continue

            # Interpolate model onto observed time points
            f_interp = interp1d(time_grid, model_lc, bounds_error=False,
                                fill_value=(model_lc[0], model_lc[-1]))
            model_at_obs = f_interp(obs_time)

            dist = self._dtw_distance(obs_norm, model_at_obs)
            score = float(1.0 / (dist + 1e-6))  # invert distance to rlap-like score
            all_matches.append({
                'type': label,
                'phase': 0.0,
                'redshift': 0.0,
                'rlap': score,
                'correlation': score,
                'template_name': f'{model_name}_{label}',
                'dtw_distance': dist,
            })

        if not all_matches:
            return self._unknown_result('No model templates could be evaluated')

        all_matches.sort(key=lambda x: x['dtw_distance'])
        # Always keep at least the best match so the softmax below has input
        n_top = max(1, min(top_n, len(all_matches)))
        top_matches = all_matches[:n_top]

        # Softmax over negative DTW distances (lower dist = better)
        type_dists = defaultdict(list)
        for m in top_matches:
            type_dists[m['type']].append(m['dtw_distance'])
        type_mean_dist = {t: float(np.mean(v)) for t, v in type_dists.items()}
        min_dist = min(type_mean_dist.values())
        exp_scores = {t: np.exp(-(d - min_dist)) for t, d in type_mean_dist.items()}
        total = sum(exp_scores.values())
        type_probabilities = {t: float(v / total) for t, v in exp_scores.items()}

        best = top_matches[0]
        if best['dtw_distance'] < 0.5:
            confidence = 'high'
        elif best['dtw_distance'] < 2.0:
            confidence = 'medium'
        else:
            confidence = 'low'

        return ClassificationResult(
            best_type=best['type'],
            best_phase=best['phase'],
            best_redshift=best['redshift'],
            rlap=best['rlap'],
            confidence=confidence,
            type_probabilities=type_probabilities,
            top_matches=top_matches,
            best_template_name=best['template_name'],
            method='photometric',
        )
def combine_classifications(spectral_result: ClassificationResult,
                            photometric_result: ClassificationResult,
                            spectral_weight: float = 0.7) -> ClassificationResult:
    """
    Combine spectral and photometric classification results into a single estimate.

    Type probabilities are computed as a weighted average:
    ``p_combined = spectral_weight * p_spectral + (1 - spectral_weight) * p_photometric``
    and then renormalised. Phase, redshift, top matches and template name are
    taken from the spectral result.

    :param spectral_result: :class:`ClassificationResult` from
        :meth:`SpectralTemplateMatcher.classify_spectrum`
    :param photometric_result: :class:`ClassificationResult` from
        :meth:`PhotometricClassifier.classify_from_lightcurve`
    :param spectral_weight: Weight given to spectral classification (0–1).
        Default 0.7 reflects that spectral features are more discriminating.
    :return: Combined :class:`ClassificationResult` with method='combined'
    :raises ValueError: If ``spectral_weight`` is outside [0, 1]
    """
    if not 0.0 <= spectral_weight <= 1.0:
        raise ValueError(f"spectral_weight must be between 0 and 1, got {spectral_weight}")
    photo_weight = 1.0 - spectral_weight

    spec_probs = spectral_result.type_probabilities or {}
    phot_probs = photometric_result.type_probabilities or {}

    # Merge type sets; sorted so that ties are broken the same way on every run
    combined_probs = {
        t: spectral_weight * spec_probs.get(t, 0.0) + photo_weight * phot_probs.get(t, 0.0)
        for t in sorted(set(spec_probs) | set(phot_probs))
    }

    # Normalise (in case the two probability dicts don't cover the same types)
    total = sum(combined_probs.values())
    if total > 0:
        combined_probs = {t: v / total for t, v in combined_probs.items()}

    if combined_probs:
        best_type = max(combined_probs, key=combined_probs.get)
    else:
        # Both classifiers returned no probabilities (e.g. both failed):
        # keep the spectral best type instead of crashing on an empty max().
        best_type = spectral_result.best_type

    # Take the best-redshift and best-phase from the spectral result (more precise)
    combined_rlap = (spectral_weight * spectral_result.rlap +
                     photo_weight * photometric_result.rlap)
    if combined_rlap > 8:
        confidence = 'high'
    elif combined_rlap > 5:
        confidence = 'medium'
    else:
        confidence = 'low'

    # Tolerate results whose warnings field is None
    warnings = list(spectral_result.warnings or []) + list(photometric_result.warnings or [])

    return ClassificationResult(
        best_type=best_type,
        best_phase=spectral_result.best_phase,
        best_redshift=spectral_result.best_redshift,
        rlap=combined_rlap,
        confidence=confidence,
        type_probabilities=combined_probs,
        top_matches=spectral_result.top_matches,
        best_template_name=spectral_result.best_template_name,
        best_template_source=spectral_result.best_template_source,
        method='combined',
        warnings=warnings,
    )