# Source code for sum_stat.nz.kde

"""Kernel density estimate of the normalized redshift distribution n(z).

Uses a Gaussian kernel.  The default bandwidth follows the weighted
Silverman rule-of-thumb.

References
----------
Silverman (1986), Density Estimation for Statistics and Data Analysis, §3.4.2.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np

from ..catalogue import GalaxyCatalogue


@jax.jit
def _nz_kde_jax(
    z_grid: jnp.ndarray,
    redshift: jnp.ndarray,
    weight: jnp.ndarray,
    bandwidth: float,
) -> jnp.ndarray:
    """Evaluate a weighted Gaussian kernel density estimate on a grid.

    Computes, for every grid point z_j,

        n(z_j) = sum_i w_i * phi((z_j - z_i) / h) / (h * sum_i w_i)

    where phi is the standard normal density.

    Parameters
    ----------
    z_grid : jnp.ndarray, shape (M,)
        Redshifts at which the density is evaluated.
    redshift : jnp.ndarray, shape (N,)
        Galaxy redshifts (kernel centres).
    weight : jnp.ndarray, shape (N,)
        Per-galaxy weights.
    bandwidth : float
        Gaussian kernel bandwidth h.

    Returns
    -------
    nz : jnp.ndarray, shape (M,)
        Density estimate at each grid point [dz^-1].
    """
    gauss_norm = 1.0 / jnp.sqrt(2.0 * jnp.pi)
    scale = 1.0 / (jnp.sum(weight) * bandwidth)

    # (M, N) matrix of standardized separations: rows are grid points,
    # columns are galaxies.
    sep = (z_grid[:, None] - redshift[None, :]) / bandwidth
    kernel = gauss_norm * jnp.exp(-0.5 * sep**2)

    # Weighted sum over galaxies (axis 1), then normalize by h * sum(w).
    return jnp.sum(weight[None, :] * kernel, axis=1) * scale


def _silverman_bandwidth(redshift: np.ndarray, weight: np.ndarray) -> float:
    """Weighted Silverman rule-of-thumb bandwidth.

    h = 1.06 * sigma_w * N_eff^{-1/5}

    where sigma_w is the weighted (population, i.e. divide-by-sum-of-weights)
    standard deviation of the redshifts and N_eff = (sum w)^2 / sum(w^2) is
    the Kish effective sample size.

    Parameters
    ----------
    redshift : array_like, shape (N,)
        Galaxy redshifts.
    weight : array_like, shape (N,)
        Per-galaxy weights.

    Returns
    -------
    float
        Bandwidth h, guaranteed positive and finite.

    Raises
    ------
    ZeroDivisionError
        If the weights sum to zero (raised by ``np.average``).
    ValueError
        If the resulting bandwidth is zero or not finite, e.g. when every
        galaxy sits at the same redshift (zero spread) or weights contain
        NaN. Such a bandwidth would make the KDE divide by zero.
    """
    # Accept any array-like (lists, float32 arrays) and compute in float64.
    redshift = np.asarray(redshift, dtype=np.float64)
    weight = np.asarray(weight, dtype=np.float64)

    z_mean = np.average(redshift, weights=weight)
    sigma_w = np.sqrt(np.average((redshift - z_mean) ** 2, weights=weight))
    n_eff = weight.sum() ** 2 / np.sum(weight**2)
    h = 1.06 * sigma_w * n_eff ** (-0.2)

    # A degenerate bandwidth would otherwise propagate as silent NaN/inf
    # values in the density estimate.
    if not (np.isfinite(h) and h > 0.0):
        raise ValueError(
            f"Silverman bandwidth is degenerate (h={h!r}); the weighted "
            "redshift spread must be positive and finite."
        )
    return float(h)


def nz_kde(
    cat: GalaxyCatalogue,
    z_grid: np.ndarray,
    bandwidth: float | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Weighted Gaussian KDE estimate of the redshift PDF n(z).

    Parameters
    ----------
    cat : GalaxyCatalogue
        Galaxy catalogue with redshift and weight arrays.
    z_grid : array_like, shape (M,)
        Redshift values at which to evaluate n(z).
    bandwidth : float, optional
        Gaussian kernel bandwidth in redshift units; must be positive and
        finite. Defaults to the weighted Silverman rule-of-thumb:
        h = 1.06 * sigma_w * N_eff^{-1/5}.

    Returns
    -------
    z_grid : ndarray, shape (M,)
        The input grid as a float64 array.
    nz : ndarray, shape (M,)
        KDE estimate of n(z) [dz^-1]. The estimate integrates to exactly 1
        over the whole real line. A grid that spans only the catalogue's
        redshift range captures slightly less than unit mass, because the
        kernels of galaxies near the edges extend past it. The dtype follows
        JAX's default precision (float32 unless 64-bit mode is enabled).

    Raises
    ------
    ValueError
        If the bandwidth (supplied or estimated) is not positive and finite,
        or if the catalogue weights do not have a finite, non-zero sum.
    ZeroDivisionError
        When ``bandwidth`` is None and the weights sum to zero, the Silverman
        estimate (via ``np.average``) raises this before the check above.

    References
    ----------
    Silverman (1986), Density Estimation for Statistics and Data Analysis,
    §3.4.2.
    """
    z_grid = np.asarray(z_grid, dtype=np.float64)
    redshift = np.asarray(cat.redshift)
    weight = np.asarray(cat.weight)

    if bandwidth is None:
        bandwidth = _silverman_bandwidth(redshift, weight)
    bandwidth = float(bandwidth)
    # A zero/negative/NaN bandwidth (e.g. Silverman on a catalogue whose
    # galaxies all share one redshift) would otherwise yield NaN/inf output.
    if not (np.isfinite(bandwidth) and bandwidth > 0.0):
        raise ValueError(f"bandwidth must be positive and finite, got {bandwidth!r}")

    # The kernel normalizes by the total weight; zero would produce NaNs.
    w_total = float(np.sum(weight))
    if not np.isfinite(w_total) or w_total == 0.0:
        raise ValueError(
            f"catalogue weights must have a finite, non-zero sum, got {w_total!r}"
        )

    nz = _nz_kde_jax(
        jnp.array(z_grid),
        jnp.array(redshift),
        jnp.array(weight),
        bandwidth,
    )
    # np.array (not np.asarray) returns a writable host copy of the JAX result.
    return z_grid, np.array(nz)