"""JAX-JIT kernels for kNN empirical CDF and Poisson (Erlang) reference CDF."""
from __future__ import annotations
import jax
import jax.numpy as jnp
@jax.jit
def _empirical_cdf_jax(distances_k: jax.Array, r_edges: jax.Array) -> jax.Array:
    """Empirical CDF F_k(r) of k-th-nearest-neighbor distances.

    Parameters
    ----------
    distances_k : jax.Array, shape (N_q,)
        Distance from every query point to its k-th nearest galaxy [Mpc].
    r_edges : jax.Array, shape (n_r+1,)
        Bin edges [Mpc]; the CDF is sampled at the right edges ``r_edges[1:]``.

    Returns
    -------
    F_k : jax.Array, shape (n_r,)
        For each right edge r, the fraction of query points with d_k ≤ r.
    """
    samples = jnp.ravel(distances_k)
    thresholds = r_edges[1:]
    # Boolean table of shape (n_r, N_q): entry [i, j] is True when query j's
    # distance does not exceed thresholds[i]. Averaging each row yields F_k at
    # that edge (ties d_k == r count as "inside").
    within = samples[jnp.newaxis, :] <= thresholds[:, jnp.newaxis]
    return jnp.mean(within, axis=1)
def knn_poisson_cdf(k: int, r_edges: jax.Array, n_bar: float) -> jax.Array:
    """Poisson (Erlang) reference CDF for the k-th nearest neighbor distance.

    For a homogeneous Poisson point process with number density *n_bar*, the
    CDF of the distance from a random query point to its k-th nearest neighbor
    is the regularised lower incomplete gamma function:

    .. math::

        F_k^{\\mathrm{Pois}}(r) =
        \\frac{\\gamma(k,\\; \\bar{n} \\cdot \\tfrac{4\\pi}{3} r^3)}{\\Gamma(k)}

    Parameters
    ----------
    k : int
        Neighbor order (≥ 1). Must be a concrete (non-traced) value because it
        is checked in Python and converted with ``float(k)``.
    r_edges : array_like, shape (n_r+1,)
        Bin edges [Mpc]. Evaluated at right edges ``r_edges[1:]``. Anything
        accepted by ``jnp.asarray`` (JAX/NumPy array, list) may be passed.
    n_bar : float
        Mean galaxy number density [Mpc⁻³].

    Returns
    -------
    F_k_poisson : jax.Array, shape (n_r,)
        Poisson reference CDF at each bin right-edge.

    Raises
    ------
    ValueError
        If ``k < 1``; there is no zeroth (or negative-order) neighbor.

    References
    ----------
    Banerjee & Abel (2021), arXiv:2007.13342, eq. (2).
    """
    # Imported here so this function does not depend on the ``jax.scipy``
    # submodule having already been loaded by some other import.
    from jax.scipy.special import gammainc

    if k < 1:
        raise ValueError(f"k must be >= 1 (neighbor order), got {k!r}")

    r = jnp.asarray(r_edges)[1:]
    # Expected number of Poisson points inside a sphere of radius r.
    lam = n_bar * (4.0 * jnp.pi / 3.0) * r**3
    # P(d_k <= r) = P(at least k points in the sphere) = P(k, lam), the
    # regularised lower incomplete gamma function.
    return gammainc(float(k), lam)