"""Kernel density estimate of the normalized redshift distribution n(z).
Uses a Gaussian kernel. The default bandwidth follows the weighted
Silverman rule-of-thumb.
References
----------
Silverman (1986), Density Estimation for Statistics and Data Analysis, §3.4.2.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
import numpy as np
from ..catalogue import GalaxyCatalogue
@jax.jit
def _nz_kde_jax(
    z_grid: jnp.ndarray,
    redshift: jnp.ndarray,
    weight: jnp.ndarray,
    bandwidth: float,
) -> jnp.ndarray:
    """Evaluate a weighted Gaussian KDE on ``z_grid`` (JIT-compiled).

    For every grid point ``z_j`` the estimate is::

        n(z_j) = 1 / (W * h) * sum_i w_i * phi((z_j - z_i) / h)

    where ``phi`` is the standard normal PDF, ``W = sum_i w_i`` and ``h`` is
    the bandwidth.

    Parameters
    ----------
    z_grid : jnp.ndarray, shape (M,)
        Points at which the density is evaluated.
    redshift : jnp.ndarray, shape (N,)
        Sample redshifts (kernel centres).
    weight : jnp.ndarray, shape (N,)
        Per-sample weights.
    bandwidth : float
        Gaussian kernel bandwidth.

    Returns
    -------
    nz : jnp.ndarray, shape (M,)
        n(z) evaluated at z_grid [dz^-1].
    """
    # Constants shared by every grid point: the Gaussian prefactor
    # 1/sqrt(2*pi) and the overall 1/(W*h) normalisation.
    gauss_prefactor = 1.0 / jnp.sqrt(2.0 * jnp.pi)
    inv_norm = 1.0 / (jnp.sum(weight) * bandwidth)

    def density_at(z_eval: jnp.ndarray) -> jnp.ndarray:
        # Offsets of all samples from this grid point, in bandwidth units.
        scaled_offset = (z_eval - redshift) / bandwidth
        kernel_values = gauss_prefactor * jnp.exp(-0.5 * scaled_offset**2)
        return jnp.sum(weight * kernel_values) * inv_norm

    # Map the single-point evaluator over the leading axis of the grid.
    return jax.vmap(density_at)(z_grid)
def _silverman_bandwidth(redshift: np.ndarray, weight: np.ndarray) -> float:
    """Return the weighted Silverman rule-of-thumb bandwidth.

    The bandwidth is ``h = 1.06 * sigma_w * N_eff^{-1/5}``, where ``sigma_w``
    is the weighted standard deviation of ``redshift`` about its weighted
    mean and ``N_eff = (sum w)^2 / sum(w^2)`` is the effective sample size.

    Parameters
    ----------
    redshift : np.ndarray
        Sample redshifts.
    weight : np.ndarray
        Per-sample weights, same length as ``redshift``.

    Returns
    -------
    float
        Rule-of-thumb bandwidth. It is zero when ``sigma_w`` is zero, i.e.
        when every redshift equals the weighted mean.
    """
    # Effective sample size from the first and second weight moments.
    total_weight = weight.sum()
    sum_sq_weight = np.sum(np.square(weight))
    effective_n = total_weight**2 / sum_sq_weight

    # Weighted mean and weighted standard deviation (no bias correction).
    centre = np.average(redshift, weights=weight)
    squared_dev = (redshift - centre) ** 2
    spread = np.sqrt(np.average(squared_dev, weights=weight))

    return 1.06 * spread * effective_n ** (-0.2)
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
def nz_kde(
    cat: GalaxyCatalogue,
    z_grid: np.ndarray,
    bandwidth: float | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Weighted Gaussian KDE estimate of the redshift PDF n(z).

    Parameters
    ----------
    cat : GalaxyCatalogue
        Galaxy catalogue with redshift and weight arrays.
    z_grid : array_like, shape (M,)
        Redshift values at which to evaluate n(z).
    bandwidth : float, optional
        Gaussian kernel bandwidth in redshift units. Must be finite and
        strictly positive. Defaults to the weighted Silverman
        rule-of-thumb: h = 1.06 * sigma_w * N_eff^{-1/5}.

    Returns
    -------
    z_grid : ndarray, shape (M,)
        Same redshift grid as input, converted to a float64 array.
    nz : ndarray, shape (M,)
        KDE estimate of n(z) [dz^-1]. Integrates approximately to 1 over
        the full redshift range of the catalogue.

    Raises
    ------
    ValueError
        If the bandwidth (explicit or derived) is not a finite, strictly
        positive number. The Silverman default is zero when the weighted
        redshift scatter is zero, e.g. when all redshifts are identical.

    Notes
    -----
    The density is computed by JAX, so its precision follows the active JAX
    configuration (float32 unless 64-bit mode is enabled); the returned
    ``z_grid`` is float64 regardless.

    References
    ----------
    Silverman (1986), Density Estimation for Statistics and Data Analysis, §3.4.2.
    """
    z_grid = np.asarray(z_grid, dtype=np.float64)
    # Coerce once so list-like catalogue columns also work with the NumPy
    # bandwidth helper, which relies on ndarray methods.
    redshift = np.asarray(cat.redshift)
    weight = np.asarray(cat.weight)

    if bandwidth is None:
        bandwidth = _silverman_bandwidth(redshift, weight)
    bandwidth = float(bandwidth)

    # A zero bandwidth divides by zero in the kernel (NaN output) and a
    # negative one flips the sign of the density; fail loudly instead of
    # returning a silently corrupt n(z).
    if not (np.isfinite(bandwidth) and bandwidth > 0.0):
        raise ValueError(
            f"bandwidth must be a finite positive number, got {bandwidth!r}; "
            "the Silverman default is zero when all redshifts are identical"
        )

    nz = _nz_kde_jax(
        jnp.array(z_grid),
        jnp.array(redshift),
        jnp.array(weight),
        bandwidth,
    )
    return z_grid, np.array(nz)