"""kNN distance computation with pyfnntw (primary) and scipy.cKDTree (fallback).
The primary backend is ``pyfnntw.Treef64`` — a Rust-backed kd-tree that builds
and queries in parallel across all available CPUs. If pyfnntw is not installed,
``scipy.spatial.cKDTree`` is used transparently.
References
----------
fnntw crate https://docs.rs/fnntw/0.4.1/fnntw/
pyfnntw https://pypi.org/project/pyfnntw/
"""
from __future__ import annotations
import numpy as np
from astropy.cosmology import FlatLambdaCDM
from ..catalogue import GalaxyCatalogue
try:
import pyfnntw as _pyfnntw
_HAS_PYFNNTW = True
except ImportError: # pragma: no cover
_HAS_PYFNNTW = False
[docs]
def comoving_xyz(cat: GalaxyCatalogue, cosmo: FlatLambdaCDM) -> np.ndarray:
    """Map sky positions onto Cartesian comoving coordinates.

    Right ascension is used as the azimuthal angle and ``90 - dec`` as the
    polar angle (colatitude) of a standard spherical-to-Cartesian transform.
    The radius is whatever ``cat.comoving_distance(cosmo)`` returns.

    Parameters
    ----------
    cat : GalaxyCatalogue
        Catalogue exposing ``ra`` and ``dec`` (degrees) and a
        ``comoving_distance(cosmo)`` method.
    cosmo : FlatLambdaCDM
        Cosmology forwarded to ``cat.comoving_distance``.

    Returns
    -------
    xyz : ndarray, shape (N, 3)
        Cartesian comoving coordinates [Mpc], one row per catalogue entry.
    """
    azimuth = np.deg2rad(cat.ra)
    colatitude = np.deg2rad(90.0 - cat.dec)
    radius = cat.comoving_distance(cosmo)

    # Length of the radius vector projected onto the equatorial (x-y) plane.
    in_plane = radius * np.sin(colatitude)

    components = [
        in_plane * np.cos(azimuth),
        in_plane * np.sin(azimuth),
        radius * np.cos(colatitude),
    ]
    return np.stack(components, axis=1)
def knn_query(
    gal_xyz: np.ndarray,
    query_xyz: np.ndarray,
    k_max: int,
    workers: int = -1,
) -> np.ndarray:
    """Distances from each query point to the k_max nearest galaxies.

    Uses ``pyfnntw.Treef64`` when available (parallel Rust kd-tree), otherwise
    falls back to ``scipy.spatial.cKDTree``.

    Parameters
    ----------
    gal_xyz : ndarray, shape (N_gal, 3)
        Cartesian comoving coordinates of galaxies [Mpc]. Converted to a
        C-contiguous float64 array before the tree is built.
    query_xyz : ndarray, shape (N_q, 3)
        Cartesian comoving coordinates of query points [Mpc]. Converted to a
        C-contiguous float64 array. A single point given as a 1-D array of
        shape ``(3,)`` is treated as ``N_q = 1``.
    k_max : int
        Neighbour order (must be >= 1). Distances to the 1st through
        k_max-th nearest galaxies are returned.
    workers : int
        Parallel workers for the scipy fallback only (``-1`` = all CPUs).
        It is not passed to pyfnntw.

    Returns
    -------
    distances : ndarray, shape (N_q, k_max)
        Sorted Euclidean distances to the k_max nearest galaxies [Mpc].

    Raises
    ------
    ValueError
        If ``k_max`` is smaller than 1.

    Notes
    -----
    The pyfnntw backend builds the tree with ``leafsize=32`` which is near-
    optimal for random 3-D data and typical cosmological sample sizes.
    See https://docs.rs/fnntw/0.4.1/fnntw/ for the ``par_split_level``
    discussion of parallelism.

    When ``k_max`` exceeds the number of galaxies, the scipy fallback marks
    the missing neighbours with infinite distances. How pyfnntw behaves in
    that case is not checked here -- TODO confirm against the pyfnntw docs.
    """
    if k_max < 1:
        raise ValueError(f"k_max must be a positive integer, got {k_max!r}")

    gal_xyz = np.ascontiguousarray(gal_xyz, dtype=np.float64)
    # Promote a lone (3,) point to (1, 3) so both backends see a 2-D batch
    # and the result can always be shaped as (N_q, k_max).
    query_xyz = np.atleast_2d(np.ascontiguousarray(query_xyz, dtype=np.float64))
    n_query = query_xyz.shape[0]

    if _HAS_PYFNNTW:
        tree = _pyfnntw.Treef64(gal_xyz, leafsize=32)
        distances, _ = tree.query(query_xyz, k_max)
    else:
        # scipy fallback; imported lazily so scipy is only needed when used.
        from scipy.spatial import cKDTree  # noqa: PLC0415

        tree = cKDTree(gal_xyz)
        distances, _ = tree.query(query_xyz, k=k_max, workers=workers)

    distances = np.asarray(distances)
    if distances.ndim != 2:
        # Backends drop the neighbour axis when k_max == 1; restore it so the
        # return shape does not depend on k_max.
        distances = distances.reshape(n_query, k_max)
    return distances