# Source code for sum_stat.knn.distances

"""kNN distance computation with pyfnntw (primary) and scipy.cKDTree (fallback).

The primary backend is ``pyfnntw.Treef64`` — a Rust-backed kd-tree that builds
and queries in parallel across all available CPUs.  If pyfnntw is not installed,
``scipy.spatial.cKDTree`` is used transparently.

References
----------
fnntw crate  https://docs.rs/fnntw/0.4.1/fnntw/
pyfnntw      https://pypi.org/project/pyfnntw/
"""

from __future__ import annotations

import numpy as np
from astropy.cosmology import FlatLambdaCDM

from ..catalogue import GalaxyCatalogue

# Optional-dependency probe: pyfnntw is the preferred kd-tree backend, but the
# module must still import without it.  ``_HAS_PYFNNTW`` records the outcome so
# ``knn_query`` can pick a backend at call time.  Note that ``_pyfnntw`` is only
# bound when the import succeeds, so it must never be touched unless the flag
# is True.
try:
    import pyfnntw as _pyfnntw

    _HAS_PYFNNTW = True
except ImportError:  # pragma: no cover
    _HAS_PYFNNTW = False


def comoving_xyz(cat: GalaxyCatalogue, cosmo: FlatLambdaCDM) -> np.ndarray:
    """Convert RA/Dec/z to Cartesian comoving coordinates.

    Parameters
    ----------
    cat : GalaxyCatalogue
        Catalogue with ``ra``, ``dec``, ``redshift`` attributes.  ``ra`` and
        ``dec`` are interpreted as degrees.
    cosmo : FlatLambdaCDM
        Cosmology for redshift → comoving distance.  It is only forwarded to
        ``cat.comoving_distance``.

    Returns
    -------
    xyz : ndarray, shape (N, 3)
        Cartesian comoving coordinates [Mpc].
    """
    # Spherical angles in radians: azimuth from RA, colatitude from Dec.
    phi = np.deg2rad(cat.ra)
    theta = np.deg2rad(90.0 - cat.dec)  # colatitude

    # Radial coordinate: comoving distance supplied by the catalogue itself.
    rr = cat.comoving_distance(cosmo)

    # Standard spherical → Cartesian transform; the projection onto the
    # equatorial plane is shared by x and y.
    rho = rr * np.sin(theta)
    x = rho * np.cos(phi)
    y = rho * np.sin(phi)
    z = rr * np.cos(theta)
    return np.stack([x, y, z], axis=1)
def knn_query(
    gal_xyz: np.ndarray,
    query_xyz: np.ndarray,
    k_max: int,
    workers: int = -1,
) -> np.ndarray:
    """Distances from each query point to the k_max nearest galaxies.

    Uses ``pyfnntw.Treef64`` when available (parallel Rust kd-tree), otherwise
    falls back to ``scipy.spatial.cKDTree``.

    Parameters
    ----------
    gal_xyz : ndarray, shape (N_gal, 3)
        Cartesian comoving coordinates of galaxies [Mpc], dtype float64.
    query_xyz : ndarray, shape (N_q, 3)
        Cartesian comoving coordinates of query points [Mpc], dtype float64.
    k_max : int
        Neighbour order.  Distances to the 1st through k_max-th nearest
        galaxies are returned.  Must be at least 1.
    workers : int
        Parallel workers for the scipy fallback only (``-1`` = all CPUs).
        pyfnntw uses all CPUs automatically.

    Returns
    -------
    distances : ndarray, shape (N_q, k_max)
        Sorted Euclidean distances to the k_max nearest galaxies [Mpc].

    Raises
    ------
    ValueError
        If ``k_max`` is smaller than 1.

    Notes
    -----
    The pyfnntw backend builds the tree with ``leafsize=32`` which is near-
    optimal for random 3-D data and typical cosmological sample sizes.
    See https://docs.rs/fnntw/0.4.1/fnntw/ for the ``par_split_level``
    discussion of parallelism.
    """
    # Reject a meaningless neighbour order up front so both backends fail the
    # same way instead of each surfacing its own backend-specific error.
    if k_max < 1:
        raise ValueError(f"k_max must be a positive integer, got {k_max!r}")

    # Both backends are fed C-contiguous float64 arrays.
    gal_xyz = np.ascontiguousarray(gal_xyz, dtype=np.float64)
    query_xyz = np.ascontiguousarray(query_xyz, dtype=np.float64)

    if _HAS_PYFNNTW:
        # NOTE(review): the return contract assumes pyfnntw reports Euclidean
        # (not squared) distances, matching scipy — confirm against pyfnntw.
        tree = _pyfnntw.Treef64(gal_xyz, leafsize=32)
        distances, _ = tree.query(query_xyz, k_max)
    else:
        # scipy fallback; imported lazily so scipy is only required when
        # pyfnntw is missing.
        from scipy.spatial import cKDTree  # noqa: PLC0415

        tree = cKDTree(gal_xyz)
        distances, _ = tree.query(query_xyz, k=k_max, workers=workers)

    # k_max == 1 edge case: a backend may hand back a flat (N_q,) array;
    # promote it so callers always receive shape (N_q, k_max).
    if distances.ndim == 1:
        distances = distances[:, np.newaxis]
    return distances