"""Histogram estimator for the normalized redshift distribution n(z).
Normalizes the weighted counts so that
integral n(z) dz = 1
References
----------
Sánchez et al. (2014), MNRAS 441, 2725.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
import numpy as np
from ..catalogue import GalaxyCatalogue
@jax.jit
def _nz_histogram_jax(
    redshift: jnp.ndarray,
    weight: jnp.ndarray,
    z_bins: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """JAX kernel: weighted histogram n(z) and Poisson-like error per bin.

    Bin membership follows the ``numpy.histogram`` convention. Every bin is
    half-open, ``[lo, hi)``, except the last one, which is closed,
    ``[lo, hi]``, so a galaxy sitting exactly on the upper edge of the range
    is counted. Galaxies outside ``[z_bins[0], z_bins[-1]]`` are ignored.

    The weighted counts are normalized by ``W``, the total weight of the
    galaxies that fall inside the binning range, so that
    ``sum(nz * dz) == 1``.

    Parameters
    ----------
    redshift : jnp.ndarray, shape (N,)
        Galaxy redshifts.
    weight : jnp.ndarray, shape (N,)
        Per-galaxy weights; must have the same shape as ``redshift``.
    z_bins : jnp.ndarray, shape (n_bins+1,)
        Bin edges; assumed to be monotonically increasing.

    Returns
    -------
    nz : jnp.ndarray, shape (n_bins,)
        Normalized redshift PDF [dz^-1].
    nz_err : jnp.ndarray, shape (n_bins,)
        Poisson-like uncertainty [dz^-1]: sqrt(sum w_i^2) / (W * dz).

    Notes
    -----
    If no weight falls inside the binning range, ``W`` is zero and both
    outputs are NaN (0 / 0).
    """
    # jnp.histogram assigns galaxies to bins with searchsorted plus a
    # scatter-add, which needs O(N + n_bins) memory. It avoids materializing
    # one length-N boolean mask per bin, and it closes the last bin on the
    # right.
    w_sum, _ = jnp.histogram(redshift, bins=z_bins, weights=weight)
    w2_sum, _ = jnp.histogram(redshift, bins=z_bins, weights=weight**2)
    dz = jnp.diff(z_bins)

    # Normalize by the in-range weight only. Dividing by the weight of the
    # whole catalogue would give sum(nz * dz) < 1 whenever some galaxies lie
    # outside the bins.
    w_in = jnp.sum(w_sum)
    norm = w_in * dz
    nz = w_sum / norm
    nz_err = jnp.sqrt(w2_sum) / norm
    return nz, nz_err
def nz_histogram(
    cat: GalaxyCatalogue,
    z_bins: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Weighted histogram estimate of the redshift PDF n(z).

    Normalizes the distribution so that sum(nz * dz) = 1.

    Parameters
    ----------
    cat : GalaxyCatalogue
        Galaxy catalogue. Its ``redshift`` and ``weight`` attributes are
        read and must have the same shape.
    z_bins : array_like, shape (n_bins+1,)
        Bin edges in redshift: at least two, strictly increasing.

    Returns
    -------
    z_centres : ndarray, shape (n_bins,)
        Redshift bin centres (midpoints of consecutive edges).
    nz : ndarray, shape (n_bins,)
        Normalized redshift PDF [dz^-1]. Satisfies sum(nz * dz) = 1.
    nz_err : ndarray, shape (n_bins,)
        Poisson-like uncertainty [dz^-1]: sqrt(sum w_i^2) / (W * dz), where
        W is the normalizing weight used for ``nz``.

    Raises
    ------
    ValueError
        If ``z_bins`` is not a 1-D sequence of at least two strictly
        increasing edges, or if ``cat.redshift`` and ``cat.weight`` differ
        in shape.

    References
    ----------
    Sánchez et al. (2014), MNRAS 441, 2725.
    """
    z_bins = np.asarray(z_bins, dtype=np.float64)
    if z_bins.ndim != 1 or z_bins.size < 2:
        raise ValueError(
            "z_bins must be a 1-D array of at least 2 bin edges, "
            f"got shape {z_bins.shape}"
        )
    # The JAX kernel assumes sorted edges. NaN edges also fail this check.
    if not np.all(np.diff(z_bins) > 0):
        raise ValueError("z_bins must be strictly increasing")

    # np.shape reads .shape directly, so device arrays are not copied to host.
    if np.shape(cat.redshift) != np.shape(cat.weight):
        raise ValueError(
            "cat.redshift and cat.weight must have the same shape, got "
            f"{np.shape(cat.redshift)} and {np.shape(cat.weight)}"
        )

    z_centres = 0.5 * (z_bins[:-1] + z_bins[1:])
    nz, nz_err = _nz_histogram_jax(
        jnp.asarray(cat.redshift),
        jnp.asarray(cat.weight),
        jnp.asarray(z_bins),
    )
    # np.array (not asarray) returns writable host copies to the caller.
    return z_centres, np.array(nz), np.array(nz_err)