"""Redshift-space multipoles of the two-point correlation function.
Computes ξ_ℓ(s) for ℓ = 0 (monopole), 2 (quadrupole), 4 (hexadecapole).
Provides:
- treecorr-based measurement via :func:`xi_multipoles`
- JAX Legendre decomposition kernel via :func:`_legendre_decompose_jax`
"""
from __future__ import annotations
import jax.numpy as jnp
import numpy as np
import treecorr
from astropy.cosmology import FlatLambdaCDM
from ..catalogue import GalaxyCatalogue
def _legendre_decompose_jax(
    xi_2d: jnp.ndarray,
    mu_centres: jnp.ndarray,
    ells: tuple[int, ...] = (0, 2, 4),
) -> dict:
    """Decompose ξ(s, μ) into Legendre multipoles.

    The multipoles are defined over the full angular range,

        ξ_ℓ(s) = (2ℓ+1)/2 ∫₋₁¹ ξ(s, μ) P_ℓ(μ) dμ ,

    which for a ξ that is even in μ (and even ℓ) equals
    ``(2ℓ+1) ∫₀¹ ξ(s, μ) P_ℓ(μ) dμ``.  Both forms reduce to

        ξ_ℓ(s) = (2ℓ+1) · ⟨ξ(s, μ) P_ℓ(μ)⟩_μ

    when the integral is evaluated with the midpoint rule on equal-width
    μ bins, so that is what is computed here.  Consequently an isotropic
    ξ(s, μ) = 1 yields ξ₀ = 1 and ξ₂ = ξ₄ ≈ 0.

    The midpoint rule is used (rather than a trapezoid through the bin
    centres) because a trapezoid over centres only spans
    [μ₀, μ_{n-1}] and silently drops half a bin at each end of the range.

    Parameters
    ----------
    xi_2d : jnp.ndarray, shape (n_s, n_mu)
        Two-dimensional correlation function on a (s, μ) grid.  May be a
        traced JAX value.
    mu_centres : jnp.ndarray, shape (n_mu,)
        Centres of equal-width μ = cos(θ) bins that tile [0, 1] (or
        [-1, 1]).  Must be a concrete array: it is converted with
        ``np.asarray`` to evaluate the Legendre polynomials, so it cannot
        be a JAX tracer.
    ells : tuple of int
        Non-negative multipole orders to compute.

    Returns
    -------
    dict
        Keys are the entries of ``ells``; values are jnp.ndarray of shape
        (n_s,).  Empty if ``ells`` is empty.

    Raises
    ------
    ValueError
        If ``mu_centres`` is not a non-empty 1-D array, if ``xi_2d`` is
        not 2-D with a trailing axis matching ``mu_centres``, or if any
        requested ℓ is negative.

    Notes
    -----
    The function is not decorated with ``@jax.jit`` because it loops over
    ``ells`` in Python; wrap the call site if JIT compilation is wanted
    (closing over ``mu_centres`` and ``ells`` as static values).
    """
    ells = tuple(ells)
    if not ells:
        return {}
    if any(ell < 0 for ell in ells):
        raise ValueError(f"multipole orders must be non-negative, got {ells}")

    mu_np = np.asarray(mu_centres, dtype=np.float64)
    if mu_np.ndim != 1 or mu_np.size == 0:
        raise ValueError("mu_centres must be a non-empty 1-D array")

    xi_2d = jnp.asarray(xi_2d)
    if xi_2d.ndim != 2 or xi_2d.shape[1] != mu_np.size:
        raise ValueError(
            f"xi_2d must have shape (n_s, {mu_np.size}), got {xi_2d.shape}"
        )

    # Column j of the pseudo-Vandermonde matrix is P_j(μ); evaluated once
    # with numpy so that only xi_2d takes part in any JAX trace.
    basis = np.polynomial.legendre.legvander(mu_np, max(ells))

    result = {}
    for ell in ells:
        p_ell = jnp.asarray(basis[:, ell])
        # Midpoint rule on equal-width bins: ∫ f dμ / (range width) = mean(f).
        result[ell] = (2.0 * ell + 1.0) * jnp.mean(
            xi_2d * p_ell[jnp.newaxis, :], axis=1
        )
    return result
[docs]
def xi_multipoles(
    gal: GalaxyCatalogue,
    rand: GalaxyCatalogue,
    s_bins: np.ndarray,
    cosmo: FlatLambdaCDM,
    ells: tuple[int, ...] = (0, 2, 4),
    n_mu_bins: int = 100,
    n_pi_bins: int = 50,
    n_threads: int = 4,
) -> tuple[np.ndarray, dict[int, np.ndarray]]:
    """Redshift-space multipoles ξ_ℓ(s).

    Pair counts are accumulated on a 2D (rp, π) grid using treecorr's
    ``FisherRperp`` metric (one call per π slice), then rebinned to (s, μ)
    and decomposed into Legendre multipoles by :func:`_legendre_decompose_jax`.

    Notes
    -----
    1. Loop over ``n_pi_bins`` linear π slices from 0 to ``s_bins[-1]``.
    2. For each slice run separate DD, DR and RR ``NNCorrelation`` objects
       with ``FisherRperp`` and ``min_rpar``/``max_rpar`` set to the slice
       edges, and combine them with ``calculateXi`` to get ξ(rp, π).
       ``n_bins`` rp bins between ``s_bins[0]`` and ``s_bins[-1]`` are
       requested without a ``bin_type``, i.e. treecorr's default binning
       (logarithmic per the treecorr documentation).  The rp cell centres
       are therefore taken from log-spaced edges, not from ``s_bins``
       itself, so that linearly spaced ``s_bins`` are handled correctly.
    3. Compute s = √(rp² + π²) and μ = π/s for each cell centre and
       average the finite ξ values falling in each cell of the
       ``(n_s, n_mu_bins)`` (s, μ) grid (unweighted mean).  (s, μ) cells
       that receive no (rp, π) cell are filled by linear interpolation
       along μ within the same s bin (end values held constant), instead
       of being integrated as ξ = 0.
    4. Apply Legendre decomposition.

    The random catalogue is replaced by ``rand.subsample(5 * gal.n)``
    before pair counting (presumably a draw of 5·N_gal randoms — see
    ``GalaxyCatalogue.subsample`` for the exact semantics).

    Parameters
    ----------
    gal : GalaxyCatalogue
        Galaxy catalogue (ra, dec, redshift required).
    rand : GalaxyCatalogue
        Random catalogue (ra, dec, redshift required).
    s_bins : ndarray, shape (n_bins+1,)
        Redshift-space separation bin edges [Mpc]; finite, strictly
        positive and strictly increasing.
    cosmo : FlatLambdaCDM
        Cosmology for distance conversion.
    ells : tuple of int
        Multipole orders, default (0, 2, 4).
    n_mu_bins : int
        Number of μ bins in the intermediate (s, μ) grid (default 100).
    n_pi_bins : int
        Number of linear π slices from 0 to s_max (default 50).
    n_threads : int
        Number of OpenMP threads for treecorr.

    Returns
    -------
    s_centres : ndarray, shape (n_bins,)
        Geometric mean of the two edges of each ``s_bins`` bin [Mpc].
    xi_dict : dict[int, ndarray]
        Keys are ell, values shape (n_bins,).  An s bin that received no
        finite ξ(rp, π) cell at all is reported as NaN.

    Raises
    ------
    ValueError
        If ``s_bins`` is not a finite, strictly positive, strictly
        increasing 1-D array with at least two edges, or if
        ``n_mu_bins`` / ``n_pi_bins`` is smaller than 1.

    Performance
    -----------
    Original author's figure (not re-measured): ~60 s/call
    (N_gal=5000, N_rand=50000, n_s=20, n_pi=50, CPU x86-64).
    """
    s_bins = np.asarray(s_bins, dtype=np.float64)
    if s_bins.ndim != 1 or s_bins.size < 2:
        raise ValueError("s_bins must be a 1-D array with at least two bin edges")
    if (
        not np.all(np.isfinite(s_bins))
        or s_bins[0] <= 0.0
        or np.any(np.diff(s_bins) <= 0.0)
    ):
        raise ValueError(
            "s_bins must be finite, strictly positive and strictly increasing"
        )
    if n_mu_bins < 1 or n_pi_bins < 1:
        raise ValueError("n_mu_bins and n_pi_bins must both be >= 1")

    n_s = s_bins.size - 1
    s_min = float(s_bins[0])
    s_max = float(s_bins[-1])

    rand = rand.subsample(5 * gal.n)
    r_gal = gal.comoving_distance(cosmo)
    r_rand = rand.comoving_distance(cosmo)
    cat_g = treecorr.Catalog(
        ra=gal.ra,
        dec=gal.dec,
        r=r_gal,
        w=gal.weight,
        ra_units="degrees",
        dec_units="degrees",
    )
    cat_r = treecorr.Catalog(
        ra=rand.ra,
        dec=rand.dec,
        r=r_rand,
        w=rand.weight,
        ra_units="degrees",
        dec_units="degrees",
    )

    # Centres of the rp bins treecorr actually uses: n_s logarithmic bins
    # between min_sep and max_sep (identical to the s_bins geometric means
    # only when s_bins is itself log-spaced).
    rp_edges = np.geomspace(s_min, s_max, n_s + 1)
    rp_centres = np.sqrt(rp_edges[:-1] * rp_edges[1:])

    # Linear π grid from 0 to s_max.
    pi_bins = np.linspace(0.0, s_max, n_pi_bins + 1)
    pi_centres = 0.5 * (pi_bins[:-1] + pi_bins[1:])

    base_kw = dict(
        min_sep=s_min,
        max_sep=s_max,
        nbins=n_s,
        metric="FisherRperp",
        num_threads=n_threads,
    )

    # ξ(rp, π): one DD/DR/RR measurement per π slice.
    xi_rppi = np.zeros((n_s, n_pi_bins))
    for i_pi, (pi_lo, pi_hi) in enumerate(zip(pi_bins[:-1], pi_bins[1:])):
        kw = dict(**base_kw, min_rpar=float(pi_lo), max_rpar=float(pi_hi))
        dd = treecorr.NNCorrelation(**kw)
        dr = treecorr.NNCorrelation(**kw)
        rr = treecorr.NNCorrelation(**kw)
        dd.process(cat_g)
        dr.process(cat_g, cat_r)
        rr.process(cat_r)
        xi_slice, _ = dd.calculateXi(rr=rr, dr=dr)
        xi_rppi[:, i_pi] = xi_slice

    xi_smu, mu_centres = _rebin_rppi_to_smu(
        xi_rppi, rp_centres, pi_centres, s_bins, n_mu_bins
    )

    xi_dict_jax = _legendre_decompose_jax(
        jnp.array(xi_smu),
        jnp.array(mu_centres),
        ells,
    )
    xi_dict = {ell: np.array(v) for ell, v in xi_dict_jax.items()}

    s_centres = np.sqrt(s_bins[:-1] * s_bins[1:])
    return s_centres, xi_dict


def _rebin_rppi_to_smu(
    xi_rppi: np.ndarray,
    rp_centres: np.ndarray,
    pi_centres: np.ndarray,
    s_bins: np.ndarray,
    n_mu_bins: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Average ξ(rp, π) cells onto an (s, μ) grid and fill unsampled μ bins.

    Parameters
    ----------
    xi_rppi : ndarray, shape (n_rp, n_pi)
        Correlation function on the (rp, π) grid.
    rp_centres : ndarray, shape (n_rp,)
        rp value assigned to each row of ``xi_rppi``.
    pi_centres : ndarray, shape (n_pi,)
        π value assigned to each column of ``xi_rppi``.
    s_bins : ndarray, shape (n_s+1,)
        Increasing s bin edges.
    n_mu_bins : int
        Number of equal-width μ bins on [0, 1].

    Returns
    -------
    xi_smu : ndarray, shape (n_s, n_mu_bins)
        Mean ξ per (s, μ) cell.  Cells without samples are linearly
        interpolated along μ from sampled cells of the same s bin; rows
        with no sample at all are NaN.
    mu_centres : ndarray, shape (n_mu_bins,)
        Centres of the μ bins.
    """
    n_s = len(s_bins) - 1

    # s = sqrt(rp² + π²), μ = π/s at every (rp, π) cell centre.
    s_2d = np.hypot(rp_centres[:, np.newaxis], pi_centres[np.newaxis, :])
    mu_2d = pi_centres[np.newaxis, :] / np.maximum(s_2d, 1e-30)

    i_s = np.searchsorted(s_bins, s_2d.ravel(), side="right") - 1
    i_mu = np.floor(mu_2d.ravel() * n_mu_bins).astype(int)
    i_mu = np.clip(i_mu, 0, n_mu_bins - 1)
    xi_flat = xi_rppi.ravel()

    # Drop cells outside the s range and cells whose ξ is NaN/inf (e.g. a
    # cell with no random pairs) so they cannot poison a bin average.
    keep = (i_s >= 0) & (i_s < n_s) & np.isfinite(xi_flat)
    target = (i_s[keep], i_mu[keep])

    xi_sum = np.zeros((n_s, n_mu_bins))
    counts = np.zeros((n_s, n_mu_bins), dtype=int)
    # np.add.at accumulates correctly when several cells hit the same bin.
    np.add.at(xi_sum, target, xi_flat[keep])
    np.add.at(counts, target, 1)

    sampled = counts > 0
    xi_smu = np.full((n_s, n_mu_bins), np.nan)
    xi_smu[sampled] = xi_sum[sampled] / counts[sampled]

    mu_centres = (np.arange(n_mu_bins) + 0.5) / n_mu_bins

    # Unsampled μ bins must not enter the Legendre integral as ξ = 0:
    # fill them from their sampled neighbours in the same s bin.
    for row, have in zip(xi_smu, sampled):
        if have.any() and not have.all():
            row[~have] = np.interp(mu_centres[~have], mu_centres[have], row[have])

    return xi_smu, mu_centres