# Source code for sum_stat.knn

"""k-nearest-neighbor (kNN) summary statistics.

Three estimators are provided:

knn_cdf
    CDF of distances from random query points to the k-th nearest galaxy,
    compared to a Poisson (Erlang) reference.  Primary statistic from
    Banerjee & Abel (2021).

cross_knn_cdf
    Joint kNN-CDFs from query points to two distinct galaxy populations A and B.
    Captures cross-correlations between the populations.

knn_volume_map
    Per-query-point kNN sphere volumes V_k = 4π/3 · r_k³, for 2-D and 3-D
    density maps.  Reproduces ``VolumekNN`` from the external ``kNN_CDFs``
    library used in the GAMA analysis scripts.

The public helper :func:`comoving_xyz` converts a
:class:`~sum_stat.GalaxyCatalogue` to Cartesian comoving coordinates, which
is needed when constructing custom query-point grids for :func:`knn_volume_map`.

References
----------
Banerjee & Abel (2021)     https://arxiv.org/abs/2007.13342
Banerjee et al. (2021)     https://ui.adsabs.harvard.edu/abs/2021MNRAS.504.2911B
Banerjee et al. (2023)     https://ui.adsabs.harvard.edu/abs/2023MNRAS.519.4856B
Yuan et al. (2023)         https://ui.adsabs.harvard.edu/abs/2023MNRAS.522.3935Y
Gao et al. (2025)          https://ui.adsabs.harvard.edu/abs/2025MNRAS.543.3409G
Obreschkow et al. (2025)   https://ui.adsabs.harvard.edu/abs/2025arXiv250209709O
"""

from __future__ import annotations

import jax.numpy as jnp
import numpy as np
from astropy.cosmology import FlatLambdaCDM

from ..catalogue import GalaxyCatalogue
from .cdf import _empirical_cdf_jax, knn_poisson_cdf
from .distances import comoving_xyz, knn_query

# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _survey_volume_bbox(xyz: np.ndarray) -> float:
    """Return the volume of the axis-aligned box enclosing ``xyz``.

    The box spans the per-axis minimum to the per-axis maximum of the points
    (reduction over axis 0); the volume is the product of its side lengths.
    For comoving Cartesian positions in Mpc the result is in Mpc³.
    """
    lower_corner = xyz.min(axis=0)
    upper_corner = xyz.max(axis=0)
    side_lengths = upper_corner - lower_corner
    return float(side_lengths.prod())


def _make_query_xyz(
    ref_xyz: np.ndarray,
    rand: GalaxyCatalogue | None,
    cosmo: FlatLambdaCDM,
    n_query: int,
    seed: int,
) -> np.ndarray:
    """Build the (n_query, 3) array of Cartesian query positions.

    Two sampling modes exist, selected by ``rand``:

    * ``rand is None`` -- positions are drawn uniformly inside the
      axis-aligned bounding box of ``ref_xyz``.  This ignores any survey
      mask, so it is only appropriate for simulation boxes.
    * ``rand`` given -- ``n_query`` rows are picked from the randoms
      catalogue (sampling with replacement only when more points are
      requested than the catalogue holds, i.e. ``n_query > rand.n``) and
      converted to comoving coordinates with ``comoving_xyz``.  Only the
      ``ra``, ``dec`` and ``redshift`` columns of ``rand`` are used.

    The generator is seeded with ``seed`` so repeated calls are reproducible.
    """
    generator = np.random.default_rng(seed)

    # Simulation-box mode: no randoms, fall back to the bounding box.
    if rand is None:
        box_min = ref_xyz.min(axis=0)
        box_max = ref_xyz.max(axis=0)
        return generator.uniform(box_min, box_max, size=(n_query, 3))

    # Survey mode: sub-sample the randoms so the queries trace the geometry.
    picks = generator.choice(rand.n, size=n_query, replace=n_query > rand.n)
    sampled = GalaxyCatalogue(
        ra=rand.ra[picks],
        dec=rand.dec[picks],
        redshift=rand.redshift[picks],
    )
    return comoving_xyz(sampled, cosmo)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def knn_cdf(
    gal: GalaxyCatalogue,
    cosmo: FlatLambdaCDM,
    k_values: np.ndarray,
    r_bins: np.ndarray,
    rand: GalaxyCatalogue | None = None,
    n_query: int = 100_000,
    n_bar: float | None = None,
    seed: int = 42,
    workers: int = -1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """k-nearest-neighbor CDF statistic F_k(r).

    For each of *n_query* random query points drawn from the survey volume,
    the distance d_k to the k-th nearest galaxy is recorded.  The empirical
    CDF F_k(r) = fraction of query points with d_k ≤ r is returned together
    with the Poisson (Erlang) reference F_k^Pois(r).

    Parameters
    ----------
    gal : GalaxyCatalogue
        Galaxy catalogue (``ra``, ``dec``, ``redshift`` required).
    cosmo : FlatLambdaCDM
        Cosmology for comoving distance conversion.
    k_values : array_like of int
        Neighbor orders to compute, e.g. ``np.arange(1, 6)``.  Must be
        non-empty and every entry must be ≥ 1.
    r_bins : ndarray, shape (n_r+1,)
        Separation bin edges [Mpc].
    rand : GalaxyCatalogue, optional
        Random catalogue tracing the survey geometry.  Query points are drawn
        from this catalogue.  If ``None``, uniform points inside the bounding
        box of galaxy positions are used — suitable only for simulation boxes.
    n_query : int
        Number of query points.
    n_bar : float, optional
        Effective galaxy number density [Mpc⁻³] used for the Poisson
        reference.  If ``None``, estimated as
        ``sum(weights) / V_bounding_box``.  Pass this explicitly when the
        bounding box is a poor proxy for the survey volume.
    seed : int
        RNG seed for query-point sampling.
    workers : int
        Parallel workers for the scipy fallback kd-tree (``-1`` = all CPUs).
        Ignored when pyfnntw is installed.

    Returns
    -------
    r_centres : ndarray, shape (n_r,)
        Geometric-mean separation bin centres [Mpc].
    F_k : ndarray, shape (n_k, n_r)
        Empirical kNN-CDF for each k in ``k_values``.
    F_k_poisson : ndarray, shape (n_k, n_r)
        Erlang (Poisson) reference CDF for each k.

    Raises
    ------
    ValueError
        If ``k_values`` is empty or contains an entry smaller than 1.

    References
    ----------
    Banerjee & Abel (2021), arXiv:2007.13342.
    Banerjee et al. (2021), MNRAS 504, 2911.
    Banerjee et al. (2023), MNRAS 519, 4856.
    Yuan et al. (2023), MNRAS 522, 3935.
    Gao et al. (2025), MNRAS 543, 3409.
    Obreschkow et al. (2025), arXiv:2502.09709.

    Performance
    -----------
    ~2 s/call (N_gal = 5 000, N_query = 10 000, k_max = 5, pyfnntw, 16 CPUs)
    """
    k_values = np.asarray(k_values, dtype=int)
    # Column k-1 of the distance table is read below; k < 1 would wrap around
    # via negative indexing and silently return the wrong neighbour order.
    if k_values.size == 0 or int(k_values.min()) < 1:
        raise ValueError("k_values must be a non-empty sequence of integers >= 1")
    r_bins = np.asarray(r_bins, dtype=np.float64)
    k_max = int(k_values.max())

    gal_xyz = comoving_xyz(gal, cosmo)
    query_xyz = _make_query_xyz(gal_xyz, rand, cosmo, n_query, seed)
    # One tree query up to k_max serves every requested order.
    distances = knn_query(gal_xyz, query_xyz, k_max, workers=workers)  # (N_q, k_max)

    if n_bar is None:
        # Weighted galaxy count over the bounding-box volume.
        n_bar = float(gal.weight.sum()) / _survey_volume_bbox(gal_xyz)
    n_bar = float(n_bar)

    r_bins_jax = jnp.asarray(r_bins)
    n_k = len(k_values)
    n_r = len(r_bins) - 1

    F_k = np.zeros((n_k, n_r))
    F_k_poisson = np.zeros((n_k, n_r))
    for i, k in enumerate(k_values):
        k = int(k)
        d_k = jnp.asarray(distances[:, k - 1])
        F_k[i] = np.array(_empirical_cdf_jax(d_k, r_bins_jax))
        F_k_poisson[i] = np.array(knn_poisson_cdf(k, r_bins_jax, n_bar))

    r_centres = np.sqrt(r_bins[:-1] * r_bins[1:])
    return r_centres, F_k, F_k_poisson
def cross_knn_cdf(
    gal_a: GalaxyCatalogue,
    gal_b: GalaxyCatalogue,
    cosmo: FlatLambdaCDM,
    k_values: np.ndarray,
    r_bins: np.ndarray,
    rand: GalaxyCatalogue | None = None,
    n_query: int = 100_000,
    seed: int = 42,
    workers: int = -1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Cross-kNN CDFs from shared query points to two galaxy populations.

    Computes F_k^A(r) and F_k^B(r) from the same set of query points: the
    fraction of query points whose k-th nearest neighbour in population A
    (or B) lies within distance r.  The joint distribution captures
    cross-correlations between the two populations.

    Parameters
    ----------
    gal_a : GalaxyCatalogue
        First galaxy population.
    gal_b : GalaxyCatalogue
        Second galaxy population.
    cosmo : FlatLambdaCDM
        Cosmology for comoving distance conversion.
    k_values : array_like of int
        Neighbor orders to compute.  Must be non-empty and every entry must
        be ≥ 1.
    r_bins : ndarray, shape (n_r+1,)
        Separation bin edges [Mpc].
    rand : GalaxyCatalogue, optional
        Random catalogue tracing the survey geometry.  If ``None``, query
        points are drawn uniformly from the bounding box of the combined
        A+B positions.
    n_query : int
        Number of query points.
    seed : int
        RNG seed.
    workers : int
        Parallel workers for the scipy fallback kd-tree.

    Returns
    -------
    r_centres : ndarray, shape (n_r,)
        Geometric-mean bin centres [Mpc].
    F_k_a : ndarray, shape (n_k, n_r)
        kNN-CDF towards population A.
    F_k_b : ndarray, shape (n_k, n_r)
        kNN-CDF towards population B.

    Raises
    ------
    ValueError
        If ``k_values`` is empty or contains an entry smaller than 1.

    References
    ----------
    Banerjee & Abel (2021), arXiv:2007.13342, §4 (joint kNN).
    Banerjee et al. (2021), MNRAS 504, 2911.
    """
    k_values = np.asarray(k_values, dtype=int)
    # Column k-1 of each distance table is read below; k < 1 would wrap
    # around via negative indexing and silently pick the wrong neighbour.
    if k_values.size == 0 or int(k_values.min()) < 1:
        raise ValueError("k_values must be a non-empty sequence of integers >= 1")
    r_bins = np.asarray(r_bins, dtype=np.float64)
    k_max = int(k_values.max())

    xyz_a = comoving_xyz(gal_a, cosmo)
    xyz_b = comoving_xyz(gal_b, cosmo)
    # The combined point set defines the fallback bounding box, so the same
    # query points cover both populations.
    ref_xyz = np.concatenate([xyz_a, xyz_b], axis=0)
    query_xyz = _make_query_xyz(ref_xyz, rand, cosmo, n_query, seed)

    dist_a = knn_query(xyz_a, query_xyz, k_max, workers=workers)
    dist_b = knn_query(xyz_b, query_xyz, k_max, workers=workers)

    r_bins_jax = jnp.asarray(r_bins)
    n_k = len(k_values)
    n_r = len(r_bins) - 1

    F_k_a = np.zeros((n_k, n_r))
    F_k_b = np.zeros((n_k, n_r))
    for i, k in enumerate(k_values):
        col = int(k) - 1
        F_k_a[i] = np.array(_empirical_cdf_jax(jnp.asarray(dist_a[:, col]), r_bins_jax))
        F_k_b[i] = np.array(_empirical_cdf_jax(jnp.asarray(dist_b[:, col]), r_bins_jax))

    r_centres = np.sqrt(r_bins[:-1] * r_bins[1:])
    return r_centres, F_k_a, F_k_b
def knn_volume_map(
    gal: GalaxyCatalogue,
    query_xyz: np.ndarray,
    k_values: list[int],
    cosmo: FlatLambdaCDM,
    workers: int = -1,
) -> np.ndarray:
    """kNN sphere volumes V_k(x) = 4π/3 · r_k(x)³ at each query point.

    Reproduces the ``VolumekNN`` function from the external ``kNN_CDFs``
    library used in the GAMA analysis scripts.  The volumes encode the local
    galaxy density field and are useful for 2-D and 3-D density maps.

    Parameters
    ----------
    gal : GalaxyCatalogue
        Galaxy catalogue (``ra``, ``dec``, ``redshift`` required).
    query_xyz : ndarray, shape (N_q, 3)
        Cartesian comoving coordinates of the query points [Mpc].  Build
        these from a regular grid using :func:`comoving_xyz` or a custom
        lattice (see GAMA_gal_all_volumekNN.py for an example).
    k_values : list of int
        Neighbor orders, e.g. ``[1, 2, 3, 4]``.  Must be non-empty and every
        entry must be ≥ 1.
    cosmo : FlatLambdaCDM
        Cosmology for comoving distance conversion.
    workers : int
        Parallel workers for the scipy fallback kd-tree.

    Returns
    -------
    volumes : ndarray, shape (N_q, len(k_values))
        kNN sphere volume at each query point for each k [Mpc³].

    Raises
    ------
    ValueError
        If ``k_values`` is empty or contains an entry smaller than 1.

    Examples
    --------
    >>> from sum_stat.knn import comoving_xyz, knn_volume_map
    >>> gal_xyz = comoving_xyz(gal, cosmo)  # build a lattice grid
    >>> vols = knn_volume_map(gal, gal_xyz, [1, 2, 3, 4], cosmo)
    >>> print(vols.shape)  # (N_gal, 4)

    References
    ----------
    Banerjee & Abel (2021), arXiv:2007.13342.
    """
    # Zero-based column of each requested order in the distance table.
    k_columns = np.asarray(list(k_values), dtype=int) - 1
    # A negative column would wrap around and silently return the volume of
    # a different neighbour order, so reject k < 1 up front.
    if k_columns.size == 0 or int(k_columns.min()) < 0:
        raise ValueError("k_values must be a non-empty sequence of integers >= 1")
    k_max = int(k_columns.max()) + 1

    gal_xyz = comoving_xyz(gal, cosmo)
    distances = knn_query(gal_xyz, query_xyz, k_max, workers=workers)  # (N_q, k_max)

    # Select all requested columns at once and convert radii to sphere volumes.
    radii = np.asarray(distances)[:, k_columns]
    volumes = (4.0 * np.pi / 3.0) * radii**3
    return volumes.astype(np.float64, copy=False)
# Public API of this module: the three estimators, plus the helpers imported
# above from ``.cdf`` and ``.distances`` that are re-exported for convenience.
__all__ = [
    "knn_cdf",
    "cross_knn_cdf",
    "knn_volume_map",
    "knn_poisson_cdf",
    "comoving_xyz",
]