Source code for sum_stat.nz.histogram

"""Histogram estimator for the normalized redshift distribution n(z).

Normalizes the weighted counts so that

    integral n(z) dz = 1

References
----------
Sánchez et al. (2014), MNRAS 441, 2725.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np

from ..catalogue import GalaxyCatalogue


@jax.jit
def _nz_histogram_jax(
    redshift: jnp.ndarray,
    weight: jnp.ndarray,
    z_bins: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """JAX kernel: weighted histogram n(z) and Poisson-like error per bin.

    Parameters
    ----------
    redshift : jnp.ndarray, shape (N,)
    weight : jnp.ndarray, shape (N,)
    z_bins : jnp.ndarray, shape (n_bins+1,)

    Returns
    -------
    nz : jnp.ndarray, shape (n_bins,)
        Normalized redshift PDF [dz^-1].
    nz_err : jnp.ndarray, shape (n_bins,)
        Poisson-like uncertainty [dz^-1]: sqrt(sum w_i^2) / (W_total * dz).
    """
    n_bins = z_bins.size - 1
    W_total = jnp.sum(weight)

    def _bin(i):
        lo = z_bins[i]
        hi = z_bins[i + 1]
        dz = hi - lo
        mask = (redshift >= lo) & (redshift < hi)
        w_sum = jnp.sum(jnp.where(mask, weight, 0.0))
        w2_sum = jnp.sum(jnp.where(mask, weight**2, 0.0))
        nz_i = w_sum / (W_total * dz)
        nz_err_i = jnp.sqrt(w2_sum) / (W_total * dz)
        return nz_i, nz_err_i

    nz, nz_err = jax.vmap(_bin)(jnp.arange(n_bins))
    return nz, nz_err


def nz_histogram(
    cat: GalaxyCatalogue,
    z_bins: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Weighted histogram estimate of the redshift PDF n(z).

    Normalizes the distribution so that sum(nz * dz) = 1.

    Parameters
    ----------
    cat : GalaxyCatalogue
        Galaxy catalogue with ``redshift`` and ``weight`` arrays of equal
        shape.
    z_bins : array_like, shape (n_bins+1,)
        Strictly increasing bin edges in redshift.

    Returns
    -------
    z_centres : ndarray, shape (n_bins,)
        Redshift bin centres.
    nz : ndarray, shape (n_bins,)
        Normalized redshift PDF [dz^-1]. Satisfies sum(nz * dz) = 1.
    nz_err : ndarray, shape (n_bins,)
        Poisson-like uncertainty [dz^-1]: sqrt(sum w_i^2) / (W_total * dz).

    Raises
    ------
    ValueError
        If ``z_bins`` is not a 1-D array of at least two strictly increasing
        (finite-comparable) edges, or if ``cat.redshift`` and ``cat.weight``
        differ in shape.

    Notes
    -----
    If the total weight used for normalization is zero, the returned
    ``nz`` and ``nz_err`` are non-finite.

    References
    ----------
    Sánchez et al. (2014), MNRAS 441, 2725.
    """
    z_bins = np.asarray(z_bins, dtype=np.float64)
    if z_bins.ndim != 1 or z_bins.size < 2:
        raise ValueError(
            "z_bins must be a 1-D array with at least two edges, "
            f"got shape {z_bins.shape}"
        )
    # NaN edges also fail this test, which is intended.
    if not np.all(np.diff(z_bins) > 0):
        raise ValueError("z_bins must be strictly increasing")

    redshift = jnp.asarray(cat.redshift)
    weight = jnp.asarray(cat.weight)
    if redshift.shape != weight.shape:
        raise ValueError(
            "cat.redshift and cat.weight must have the same shape, "
            f"got {redshift.shape} and {weight.shape}"
        )

    z_centres = 0.5 * (z_bins[:-1] + z_bins[1:])
    nz, nz_err = _nz_histogram_jax(redshift, weight, jnp.asarray(z_bins))
    # np.array (not asarray) returns writable host copies of the JAX outputs.
    return z_centres, np.array(nz), np.array(nz_err)