# Source code for sum_stat.twopcf.multipoles

"""Redshift-space multipoles of the two-point correlation function.

Computes ξ_ℓ(s) for ℓ = 0 (monopole), 2 (quadrupole), 4 (hexadecapole).

Provides:
- treecorr-based measurement via :func:`xi_multipoles`
- JAX Legendre decomposition kernel via :func:`_legendre_decompose_jax`
"""

from __future__ import annotations

import jax.numpy as jnp
import numpy as np
import treecorr
from astropy.cosmology import FlatLambdaCDM

from ..catalogue import GalaxyCatalogue


def _legendre_decompose_jax(
    xi_2d: jnp.ndarray,
    mu_centres: jnp.ndarray,
    ells: tuple[int, ...] = (0, 2, 4),
) -> dict:
    """Decompose ξ(s, μ) into Legendre multipoles ξ_ℓ(s).

    The multipoles are defined by

        ξ_ℓ(s) = (2ℓ+1)/2 ∫₋₁¹ ξ(s, μ) P_ℓ(μ) dμ

    which, for even ℓ and ξ symmetric in μ, equals

        ξ_ℓ(s) = (2ℓ+1) ∫₀¹ ξ(s, μ) P_ℓ(μ) dμ.

    Both cases are evaluated as ``(2ℓ+1) × <ξ P_ℓ>`` where ``<·>`` is the
    average over the integration domain: [0, 1] when every μ sample is
    non-negative, [-1, 1] otherwise.  With this normalisation an isotropic
    ξ(s, μ) = ξ(s) yields ξ_0(s) = ξ(s).

    The integral uses cell-width quadrature weights: each μ sample owns the
    interval bounded by the midpoints to its neighbours, and the first/last
    intervals are extended to the domain limits.  For bin centres this is
    the midpoint rule covering the whole domain; for a node grid that
    includes the domain end points it is the trapezoidal rule.  Non-uniform
    grids and a single μ sample are supported.

    Legendre polynomials and quadrature weights are evaluated with numpy, so
    ``mu_centres`` must be a concrete (non-traced) array.  ``xi_2d`` is only
    touched by ``jax.numpy`` operations.  The function itself is not
    decorated with ``@jax.jit`` because it uses a Python loop over ``ells``.

    Parameters
    ----------
    xi_2d : jnp.ndarray, shape (n_s, n_mu)
        Two-dimensional correlation function on a (s, μ) grid.
    mu_centres : jnp.ndarray, shape (n_mu,)
        μ = cos(θ) sample positions, strictly increasing and within [-1, 1].
    ells : tuple of int
        Multipole orders to compute.

    Returns
    -------
    dict
        Keys are ell values, values are jnp.ndarray of shape (n_s,).

    Raises
    ------
    ValueError
        If ``mu_centres`` is empty, not strictly increasing, or has values
        outside [-1, 1].

    Performance
    -----------
    ~1 ms/call  (n_s=20, n_mu=100, CPU x86-64)
    """
    # Static (numpy) part: the μ grid, its quadrature weights and P_ℓ(μ).
    mu_np = np.asarray(mu_centres, dtype=np.float64).ravel()
    if mu_np.size == 0:
        raise ValueError("mu_centres must contain at least one value")
    if np.any(np.diff(mu_np) <= 0.0):
        raise ValueError("mu_centres must be strictly increasing")
    if mu_np[0] < -1.0 or mu_np[-1] > 1.0:
        raise ValueError("mu_centres must lie within [-1, 1]")

    # Half range [0, 1] when all samples are non-negative, else [-1, 1].
    mu_lo = 0.0 if mu_np[0] >= 0.0 else -1.0
    mu_hi = 1.0

    # Each sample owns the cell between the midpoints to its neighbours; the
    # outermost cells reach the domain limits so no part of the domain is
    # dropped from the integral.
    cell_edges = np.concatenate(
        ([mu_lo], 0.5 * (mu_np[:-1] + mu_np[1:]), [mu_hi])
    )
    weights = np.diff(cell_edges)
    inv_width = 1.0 / (mu_hi - mu_lo)

    xi = jnp.asarray(xi_2d)
    result = {}
    for ell in ells:
        # Unit coefficient vector selects P_ell in numpy's Legendre basis.
        coeffs = np.zeros(ell + 1)
        coeffs[ell] = 1.0
        p_ell = np.polynomial.legendre.legval(mu_np, coeffs)
        # (2ℓ+1) / width × P_ℓ(μ_i) × Δμ_i, folded into one static kernel.
        kernel = (2.0 * ell + 1.0) * inv_width * p_ell * weights
        result[ell] = jnp.sum(xi * jnp.asarray(kernel)[jnp.newaxis, :], axis=1)

    return result


def _rebin_rppi_to_smu(
    xi_rppi: np.ndarray,
    rp_centres: np.ndarray,
    pi_centres: np.ndarray,
    s_bins: np.ndarray,
    n_mu_bins: int,
) -> np.ndarray:
    """Average ξ(rp, π) cells onto a regular (s, μ) grid.

    Each (rp, π) cell centre is mapped to s = √(rp² + π²) and μ = π/s and
    assigned to one (s, μ) bin; every bin holds the unweighted mean of the
    cells that landed in it.  Cells whose s falls outside ``s_bins`` are
    dropped, and bins that receive no cell stay at 0.

    Returns an array of shape ``(len(s_bins) - 1, n_mu_bins)``.
    """
    n_s = len(s_bins) - 1

    rp_2d = rp_centres[:, np.newaxis]  # (n_rp, 1)
    pi_2d = pi_centres[np.newaxis, :]  # (1, n_pi)
    s_flat = np.sqrt(rp_2d**2 + pi_2d**2).ravel()
    # The floor on s guards the division when a cell centre sits at s = 0.
    mu_flat = (pi_2d / np.maximum(np.sqrt(rp_2d**2 + pi_2d**2), 1e-30)).ravel()

    s_idx = np.searchsorted(s_bins, s_flat, side="right") - 1
    mu_idx = np.clip(
        np.floor(mu_flat * n_mu_bins).astype(int), 0, n_mu_bins - 1
    )

    inside = (s_idx >= 0) & (s_idx < n_s)
    target = (s_idx[inside], mu_idx[inside])

    xi_smu = np.zeros((n_s, n_mu_bins))
    cnt_smu = np.zeros((n_s, n_mu_bins), dtype=int)
    # ufunc.at accumulates unbuffered, so repeated indices add up correctly.
    np.add.at(xi_smu, target, xi_rppi.ravel()[inside])
    np.add.at(cnt_smu, target, 1)

    filled = cnt_smu > 0
    xi_smu[filled] /= cnt_smu[filled]
    return xi_smu


def xi_multipoles(
    gal: GalaxyCatalogue,
    rand: GalaxyCatalogue,
    s_bins: np.ndarray,
    cosmo: FlatLambdaCDM,
    ells: tuple[int, ...] = (0, 2, 4),
    n_mu_bins: int = 100,
    n_pi_bins: int = 50,
    n_threads: int = 4,
) -> tuple[np.ndarray, dict[int, np.ndarray]]:
    """Redshift-space multipoles ξ_ℓ(s).

    Pair counts are accumulated on a 2D (rp, π) grid using treecorr's
    ``FisherRperp`` metric (one call per π slice), then rebinned to (s, μ)
    and decomposed into Legendre multipoles by
    :func:`_legendre_decompose_jax`.

    Notes
    -----
    0. The random catalogue is first replaced by
       ``rand.subsample(5 * gal.n)``; the factor 5 is hard-coded.
    1. Loop over ``n_pi_bins`` linear π slices from 0 to ``s_bins[-1]``.
    2. For each slice run separate DD, DR and RR ``NNCorrelation`` objects
       with ``FisherRperp`` and ``min_rpar``/``max_rpar`` set to the slice
       limits, and take ξ from ``dd.calculateXi(rr=rr, dr=dr)``.
    3. Compute s = √(rp² + π²) and μ = π/s for each cell centre and
       redistribute ξ into the ``(n_s, n_mu_bins)`` (s, μ) grid by
       unweighted averaging of the cells falling in each bin.
    4. Apply Legendre decomposition.

    (s, μ) bins that receive no (rp, π) cell are left at 0 and enter the
    Legendre integral as ξ = 0.
    NOTE(review): rp centres are taken as geometric means of ``s_bins``,
    which presumably matches treecorr's rp bins only for log-spaced
    ``s_bins`` — confirm against the treecorr binning in use.

    Parameters
    ----------
    gal : GalaxyCatalogue
        Galaxy catalogue (ra, dec, redshift required).
    rand : GalaxyCatalogue
        Random catalogue (ra, dec, redshift required).
    s_bins : ndarray, shape (n_bins+1,)
        Redshift-space separation bin edges [Mpc].
    cosmo : FlatLambdaCDM
        Cosmology for distance conversion.
    ells : tuple of int
        Multipole orders, default (0, 2, 4).
    n_mu_bins : int
        Number of μ bins in the intermediate (s, μ) grid (default 100).
    n_pi_bins : int
        Number of linear π slices from 0 to s_max (default 50).
    n_threads : int
        Number of OpenMP threads for treecorr.

    Returns
    -------
    s_centres : ndarray, shape (n_bins,)
        Geometric mean of the two edges of each separation bin [Mpc].
    xi_dict : dict[int, ndarray]
        Keys are ell, values shape (n_bins,).

    Raises
    ------
    ValueError
        If ``s_bins`` is not one-dimensional with at least two edges.

    Performance
    -----------
    ~60 s/call  (N_gal=5000, N_rand=50000, n_s=20, n_pi=50, CPU x86-64)
    """
    s_bins = np.asarray(s_bins, dtype=np.float64)
    if s_bins.ndim != 1 or s_bins.size < 2:
        raise ValueError("s_bins must be a 1-D array of at least two bin edges")

    rand = rand.subsample(5 * gal.n)

    n_s = len(s_bins) - 1
    s_max = float(s_bins[-1])

    r_gal = gal.comoving_distance(cosmo)
    r_rand = rand.comoving_distance(cosmo)

    cat_g = treecorr.Catalog(
        ra=gal.ra,
        dec=gal.dec,
        r=r_gal,
        w=gal.weight,
        ra_units="degrees",
        dec_units="degrees",
    )
    cat_r = treecorr.Catalog(
        ra=rand.ra,
        dec=rand.dec,
        r=r_rand,
        w=rand.weight,
        ra_units="degrees",
        dec_units="degrees",
    )

    # Geometric rp centres for the FisherRperp grid (same edges as s_bins)
    rp_centres = np.sqrt(s_bins[:-1] * s_bins[1:])

    # Linear π grid
    pi_bins = np.linspace(0.0, s_max, n_pi_bins + 1)
    pi_centres = 0.5 * (pi_bins[:-1] + pi_bins[1:])

    # ξ(rp, π) grid — one column per π slice
    xi_rppi = np.zeros((n_s, n_pi_bins))
    base_kw = dict(
        min_sep=s_bins[0],
        max_sep=s_bins[-1],
        nbins=n_s,
        metric="FisherRperp",
        num_threads=n_threads,
    )

    for i_pi, (pi_lo, pi_hi) in enumerate(zip(pi_bins[:-1], pi_bins[1:])):
        slice_kw = dict(**base_kw, min_rpar=pi_lo, max_rpar=pi_hi)
        dd = treecorr.NNCorrelation(**slice_kw)
        dr = treecorr.NNCorrelation(**slice_kw)
        rr = treecorr.NNCorrelation(**slice_kw)
        dd.process(cat_g)
        dr.process(cat_g, cat_r)
        rr.process(cat_r)
        xi_slice, _ = dd.calculateXi(rr=rr, dr=dr)
        xi_rppi[:, i_pi] = xi_slice

    # Convert (rp, π) cell centres to (s, μ) and rebin onto (s_bins, μ_bins)
    xi_smu = _rebin_rppi_to_smu(
        xi_rppi, rp_centres, pi_centres, s_bins, n_mu_bins
    )

    mu_centres = (
        np.linspace(0.0, 1.0, n_mu_bins, endpoint=False) + 0.5 / n_mu_bins
    )

    xi_dict_jax = _legendre_decompose_jax(
        jnp.array(xi_smu),
        jnp.array(mu_centres),
        ells,
    )
    xi_dict = {ell: np.array(v) for ell, v in xi_dict_jax.items()}

    s_centres = rp_centres  # geometric means of s_bins edges
    return s_centres, xi_dict