"""k-nearest-neighbor (kNN) summary statistics.

Three estimators are provided:

knn_cdf
    CDF of distances from random query points to the k-th nearest galaxy,
    compared to a Poisson (Erlang) reference. Primary statistic from
    Banerjee & Abel (2021).
cross_knn_cdf
    Marginal kNN-CDFs from one shared set of query points to two distinct
    galaxy populations A and B, as used in kNN cross-correlation analyses.
knn_volume_map
    Per-query-point kNN sphere volumes V_k = 4π/3 · r_k³, for 2-D and 3-D
    density maps. Reproduces ``VolumekNN`` from the external ``kNN_CDFs``
    library used in the GAMA analysis scripts.

The public helper :func:`comoving_xyz` converts a
:class:`~sum_stat.GalaxyCatalogue` to Cartesian comoving coordinates, which
is needed when constructing custom query-point grids for :func:`knn_volume_map`.

References
----------
Banerjee & Abel (2021) https://arxiv.org/abs/2007.13342
Banerjee et al. (2021) https://ui.adsabs.harvard.edu/abs/2021MNRAS.504.2911B
Banerjee et al. (2023) https://ui.adsabs.harvard.edu/abs/2023MNRAS.519.4856B
Yuan et al. (2023) https://ui.adsabs.harvard.edu/abs/2023MNRAS.522.3935Y
Gao et al. (2025) https://ui.adsabs.harvard.edu/abs/2025MNRAS.543.3409G
Obreschkow et al. (2025) https://ui.adsabs.harvard.edu/abs/2025arXiv250209709O
"""
from __future__ import annotations
import jax.numpy as jnp
import numpy as np
from astropy.cosmology import FlatLambdaCDM
from ..catalogue import GalaxyCatalogue
from .cdf import _empirical_cdf_jax, knn_poisson_cdf
from .distances import comoving_xyz, knn_query
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _survey_volume_bbox(xyz: np.ndarray) -> float:
"""Bounding-box comoving volume [Mpc³] of a Cartesian point set."""
extents = xyz.max(axis=0) - xyz.min(axis=0)
return float(np.prod(extents))
def _make_query_xyz(
ref_xyz: np.ndarray,
rand: GalaxyCatalogue | None,
cosmo: FlatLambdaCDM,
n_query: int,
seed: int,
) -> np.ndarray:
"""Return (N_query, 3) Cartesian comoving query points [Mpc].
If ``rand`` is provided, query points are drawn (with replacement when
n_query > rand.n) from the randoms catalogue, which traces the survey
geometry. Otherwise, points are drawn uniformly from the bounding box of
``ref_xyz``, which is only appropriate for simulation boxes.
"""
rng = np.random.default_rng(seed)
if rand is not None:
replace = n_query > rand.n
idx = rng.choice(rand.n, size=n_query, replace=replace)
q = GalaxyCatalogue(ra=rand.ra[idx], dec=rand.dec[idx], redshift=rand.redshift[idx])
return comoving_xyz(q, cosmo)
lo = ref_xyz.min(axis=0)
hi = ref_xyz.max(axis=0)
return rng.uniform(lo, hi, size=(n_query, 3))
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def knn_cdf(
    gal: GalaxyCatalogue,
    cosmo: FlatLambdaCDM,
    k_values: np.ndarray,
    r_bins: np.ndarray,
    rand: GalaxyCatalogue | None = None,
    n_query: int = 100_000,
    n_bar: float | None = None,
    seed: int = 42,
    workers: int = -1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """k-nearest-neighbor CDF statistic F_k(r).

    For each of *n_query* random query points drawn from the survey volume,
    the distance d_k to the k-th nearest galaxy is recorded. The empirical
    CDF F_k(r) = fraction of query points with d_k ≤ r is returned together
    with the Poisson (Erlang) reference F_k^Pois(r).

    Parameters
    ----------
    gal : GalaxyCatalogue
        Galaxy catalogue (``ra``, ``dec``, ``redshift`` required).
    cosmo : FlatLambdaCDM
        Cosmology for comoving distance conversion.
    k_values : array_like of int
        Neighbor orders to compute, e.g. ``np.arange(1, 6)``. A scalar is
        treated as a single order. Every order must be ≥ 1.
    r_bins : ndarray, shape (n_r+1,)
        Separation bin edges [Mpc].
    rand : GalaxyCatalogue, optional
        Random catalogue tracing the survey geometry. Query points are drawn
        from this catalogue. If ``None``, uniform points inside the bounding
        box of galaxy positions are used — suitable only for simulation boxes.
    n_query : int
        Number of query points.
    n_bar : float, optional
        Effective galaxy number density [Mpc⁻³] used for the Poisson reference.
        If ``None``, estimated as ``sum(weights) / V_bounding_box``. Pass this
        explicitly when the bounding box is a poor proxy for the survey volume.
    seed : int
        RNG seed for query-point sampling.
    workers : int
        Parallel workers for the scipy fallback kd-tree (``-1`` = all CPUs).
        Ignored when pyfnntw is installed.

    Returns
    -------
    r_centres : ndarray, shape (n_r,)
        Geometric-mean separation bin centres [Mpc].
    F_k : ndarray, shape (n_k, n_r)
        Empirical kNN-CDF for each k in ``k_values``.
    F_k_poisson : ndarray, shape (n_k, n_r)
        Erlang (Poisson) reference CDF for each k.

    Raises
    ------
    ValueError
        If ``k_values`` is empty or contains an order smaller than 1.

    References
    ----------
    Banerjee & Abel (2021), arXiv:2007.13342.
    Banerjee et al. (2021), MNRAS 504, 2911.
    Banerjee et al. (2023), MNRAS 519, 4856.
    Yuan et al. (2023), MNRAS 522, 3935.
    Gao et al. (2025), MNRAS 543, 3409.
    Obreschkow et al. (2025), arXiv:2502.09709.

    Performance
    -----------
    ~2 s/call (N_gal = 5 000, N_query = 10 000, k_max = 5, pyfnntw, 16 CPUs)
    """
    # Validate neighbour orders before any expensive work: an order of 0
    # would index column -1 of the distance array below and silently return
    # the k_max statistic instead of raising.
    k_values = np.atleast_1d(np.asarray(k_values, dtype=int)).ravel()
    if k_values.size == 0:
        raise ValueError("k_values must contain at least one neighbour order")
    if k_values.min() < 1:
        raise ValueError(f"k_values must all be >= 1, got {k_values.tolist()}")
    r_bins = np.asarray(r_bins, dtype=np.float64)
    k_max = int(k_values.max())

    gal_xyz = comoving_xyz(gal, cosmo)
    query_xyz = _make_query_xyz(gal_xyz, rand, cosmo, n_query, seed)
    distances = knn_query(gal_xyz, query_xyz, k_max, workers=workers)  # (N_q, k_max)

    if n_bar is None:
        n_bar = float(gal.weight.sum()) / _survey_volume_bbox(gal_xyz)

    r_bins_jax = jnp.asarray(r_bins)
    n_k = len(k_values)
    n_r = len(r_bins) - 1
    F_k = np.zeros((n_k, n_r))
    F_k_poisson = np.zeros((n_k, n_r))
    for i, k in enumerate(k_values):
        # Column k-1 holds the distance to the k-th neighbour.
        d_k = jnp.asarray(distances[:, int(k) - 1])
        F_k[i] = np.array(_empirical_cdf_jax(d_k, r_bins_jax))
        F_k_poisson[i] = np.array(knn_poisson_cdf(int(k), r_bins_jax, float(n_bar)))

    r_centres = np.sqrt(r_bins[:-1] * r_bins[1:])
    return r_centres, F_k, F_k_poisson
def cross_knn_cdf(
    gal_a: GalaxyCatalogue,
    gal_b: GalaxyCatalogue,
    cosmo: FlatLambdaCDM,
    k_values: np.ndarray,
    r_bins: np.ndarray,
    rand: GalaxyCatalogue | None = None,
    n_query: int = 100_000,
    seed: int = 42,
    workers: int = -1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Cross-kNN CDFs from shared query points to two galaxy populations.

    Computes F_k^A(r) and F_k^B(r) from the same set of query points: the
    fraction of query points whose k-th nearest neighbour in population A (or
    B) lies within distance r. Only these two per-population (marginal) CDFs
    are returned. The cross-correlation signal lives in the joint
    distribution of (d_k^A, d_k^B) over the shared query points, which this
    function does not compute.

    Parameters
    ----------
    gal_a : GalaxyCatalogue
        First galaxy population.
    gal_b : GalaxyCatalogue
        Second galaxy population.
    cosmo : FlatLambdaCDM
        Cosmology for comoving distance conversion.
    k_values : array_like of int
        Neighbor orders to compute. A scalar is treated as a single order.
        Every order must be ≥ 1.
    r_bins : ndarray, shape (n_r+1,)
        Separation bin edges [Mpc].
    rand : GalaxyCatalogue, optional
        Random catalogue tracing the survey geometry. If ``None``, query
        points are drawn uniformly inside the bounding box of the union of
        both populations.
    n_query : int
        Number of query points.
    seed : int
        RNG seed.
    workers : int
        Parallel workers for the scipy fallback kd-tree.

    Returns
    -------
    r_centres : ndarray, shape (n_r,)
        Geometric-mean bin centres [Mpc].
    F_k_a : ndarray, shape (n_k, n_r)
        kNN-CDF towards population A.
    F_k_b : ndarray, shape (n_k, n_r)
        kNN-CDF towards population B.

    Raises
    ------
    ValueError
        If ``k_values`` is empty or contains an order smaller than 1.

    References
    ----------
    Banerjee & Abel (2021), arXiv:2007.13342, §4 (joint kNN).
    Banerjee et al. (2021), MNRAS 504, 2911.
    """
    # Reject k < 1 up front: k = 0 would otherwise read column -1 (the k_max
    # neighbour) of the distance arrays and silently give wrong results.
    k_values = np.atleast_1d(np.asarray(k_values, dtype=int)).ravel()
    if k_values.size == 0:
        raise ValueError("k_values must contain at least one neighbour order")
    if k_values.min() < 1:
        raise ValueError(f"k_values must all be >= 1, got {k_values.tolist()}")
    r_bins = np.asarray(r_bins, dtype=np.float64)
    k_max = int(k_values.max())

    xyz_a = comoving_xyz(gal_a, cosmo)
    xyz_b = comoving_xyz(gal_b, cosmo)
    # Without randoms, the shared query box spans both populations.
    ref_xyz = np.concatenate([xyz_a, xyz_b], axis=0)
    query_xyz = _make_query_xyz(ref_xyz, rand, cosmo, n_query, seed)
    dist_a = knn_query(xyz_a, query_xyz, k_max, workers=workers)
    dist_b = knn_query(xyz_b, query_xyz, k_max, workers=workers)

    r_bins_jax = jnp.asarray(r_bins)
    n_k = len(k_values)
    n_r = len(r_bins) - 1
    F_k_a = np.zeros((n_k, n_r))
    F_k_b = np.zeros((n_k, n_r))
    for i, k in enumerate(k_values):
        ki = int(k) - 1  # column of the k-th neighbour distance
        F_k_a[i] = np.array(_empirical_cdf_jax(jnp.asarray(dist_a[:, ki]), r_bins_jax))
        F_k_b[i] = np.array(_empirical_cdf_jax(jnp.asarray(dist_b[:, ki]), r_bins_jax))

    r_centres = np.sqrt(r_bins[:-1] * r_bins[1:])
    return r_centres, F_k_a, F_k_b
def knn_volume_map(
    gal: GalaxyCatalogue,
    query_xyz: np.ndarray,
    k_values: list[int],
    cosmo: FlatLambdaCDM,
    workers: int = -1,
) -> np.ndarray:
    """kNN sphere volumes V_k(x) = 4π/3 · r_k(x)³ at each query point.

    Reproduces the ``VolumekNN`` function from the external ``kNN_CDFs``
    library used in the GAMA analysis scripts. The volumes encode the local
    galaxy density field and are useful for 2-D and 3-D density maps.

    Parameters
    ----------
    gal : GalaxyCatalogue
        Galaxy catalogue (``ra``, ``dec``, ``redshift`` required).
    query_xyz : array_like, shape (N_q, 3) or (3,)
        Cartesian comoving coordinates of the query points [Mpc]. A single
        point of shape (3,) is treated as one query. Build these from a
        regular grid using :func:`comoving_xyz` or a custom lattice (see
        GAMA_gal_all_volumekNN.py for an example).
    k_values : sequence of int
        Neighbor orders, e.g. ``[1, 2, 3, 4]``. Every order must be ≥ 1.
    cosmo : FlatLambdaCDM
        Cosmology for comoving distance conversion.
    workers : int
        Parallel workers for the scipy fallback kd-tree.

    Returns
    -------
    volumes : ndarray, shape (N_q, len(k_values))
        kNN sphere volume at each query point for each k [Mpc³], float64.

    Raises
    ------
    ValueError
        If ``k_values`` is empty or contains an order smaller than 1, or if
        ``query_xyz`` does not have three coordinates per point.

    Notes
    -----
    The galaxies themselves are the kNN reference set. If the query points
    coincide with galaxy positions, the nearest neighbour of each query is
    presumably the galaxy itself, giving r_1 = 0 and V_1 = 0.

    Examples
    --------
    >>> from sum_stat.knn import comoving_xyz, knn_volume_map
    >>> gal_xyz = comoving_xyz(gal, cosmo)  # query at the galaxy positions
    >>> vols = knn_volume_map(gal, gal_xyz, [1, 2, 3, 4], cosmo)
    >>> print(vols.shape)  # (N_gal, 4)

    References
    ----------
    Banerjee & Abel (2021), arXiv:2007.13342.
    """
    # Validate neighbour orders: k = 0 would index column -1 (the k_max
    # neighbour) and silently return the wrong volumes.
    k_arr = np.atleast_1d(np.asarray(k_values, dtype=int)).ravel()
    if k_arr.size == 0:
        raise ValueError("k_values must contain at least one neighbour order")
    if k_arr.min() < 1:
        raise ValueError(f"k_values must all be >= 1, got {k_arr.tolist()}")

    query_xyz = np.atleast_2d(np.asarray(query_xyz, dtype=np.float64))
    if query_xyz.ndim != 2 or query_xyz.shape[1] != 3:
        raise ValueError(
            f"query_xyz must have shape (N_q, 3), got {query_xyz.shape}"
        )

    k_max = int(k_arr.max())
    gal_xyz = comoving_xyz(gal, cosmo)
    distances = knn_query(gal_xyz, query_xyz, k_max, workers=workers)  # (N_q, k_max)
    # Select the requested neighbour columns at once; float64 keeps the
    # output dtype independent of the kd-tree backend.
    r_k = np.asarray(distances, dtype=np.float64)[:, k_arr - 1]  # (N_q, n_k)
    return (4.0 * np.pi / 3.0) * r_k**3
# Public API of this module. ``knn_poisson_cdf`` and ``comoving_xyz`` are
# re-exported from the ``.cdf`` and ``.distances`` submodules.
__all__ = ["knn_cdf", "cross_knn_cdf", "knn_volume_map", "knn_poisson_cdf", "comoving_xyz"]