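"""3D power spectrum P(k) and multipoles P_ℓ(k).
Uses a direct FFT estimator on a nearest-grid-point (NGP) density grid; the
FFT treats the box as periodic. A random catalogue may be passed, but FKP
weighting is not yet implemented and the randoms are currently ignored: shot
noise is always the Poisson term 1/n̄.
"""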
from __future__ import annotations
import jax.numpy as jnp
import numpy as np
from astropy.cosmology import FlatLambdaCDM
from ..catalogue import GalaxyCatalogue
from ..cosmology import astropy_to_jax_cosmo
def _build_density_grid(
    cat: GalaxyCatalogue,
    cosmo: FlatLambdaCDM,
    n_grid: int,
    box_size: float,
) -> tuple[np.ndarray, float]:
    """Paint weighted galaxies onto a cubic 3D grid and return the overdensity.

    Galaxies are placed with nearest-grid-point (NGP) assignment: each galaxy
    adds its full weight to the single cell that contains it.

    Parameters
    ----------
    cat : GalaxyCatalogue
        Catalogue providing ``ra`` and ``dec`` [deg], per-galaxy ``weight``
        and a ``comoving_distance(cosmo)`` method.
    cosmo : FlatLambdaCDM
        Cosmology passed to ``cat.comoving_distance``.
    n_grid : int
        Number of cells per dimension.
    box_size : float
        Side length of the cubic box, in the same length unit as
        ``cat.comoving_distance`` (Mpc according to the callers in this module).

    Returns
    -------
    delta : ndarray, shape (n_grid, n_grid, n_grid)
        Overdensity field δ = n / n̄ − 1 on the grid (sums to 0 up to round-off).
    n_bar : float
        Mean weighted number density, ``sum(weight) / box_size**3``.

    Raises
    ------
    ValueError
        If the catalogue is empty or its total weight is not positive, in
        which case the overdensity is undefined.
    """
    chi = np.asarray(cat.comoving_distance(cosmo), dtype=np.float64)  # Mpc
    weight = np.asarray(cat.weight, dtype=np.float64)
    if chi.size == 0:
        raise ValueError("cannot build a density grid from an empty catalogue")

    # Full spherical -> Cartesian conversion (no flat-sky approximation).
    ra_r = np.deg2rad(cat.ra)
    dec_r = np.deg2rad(cat.dec)
    pos = np.stack(
        [
            chi * np.cos(dec_r) * np.cos(ra_r),
            chi * np.cos(dec_r) * np.sin(ra_r),
            chi * np.sin(dec_r),
        ]
    )
    # Shift each axis independently so the catalogue starts at the box origin.
    pos -= pos.min(axis=1, keepdims=True)

    cell = box_size / n_grid
    # NGP cell indices. Galaxies beyond box_size are clipped into the last
    # cell rather than dropped, so box_size should cover the full extent.
    idx = np.clip((pos / cell).astype(int), 0, n_grid - 1)

    grid = np.zeros((n_grid, n_grid, n_grid), dtype=np.float64)
    # Unbuffered scatter-add: galaxies sharing a cell all contribute
    # (plain fancy-index `+=` would keep only one of them).
    np.add.at(grid, (idx[0], idx[1], idx[2]), weight)

    n_bar = float(np.sum(grid)) / box_size**3
    if not n_bar > 0:
        raise ValueError(
            "total galaxy weight must be positive to compute an overdensity"
        )
    delta = grid / (n_bar * cell**3) - 1.0
    return delta, n_bar
def pk3d(
    gal: GalaxyCatalogue,
    cosmo: FlatLambdaCDM,
    k_bins: np.ndarray,
    rand: GalaxyCatalogue | None = None,
    n_grid: int = 256,
    box_size: float | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """3D power spectrum P(k) from a galaxy catalogue.

    Galaxies are painted onto an ``n_grid``³ mesh by ``_build_density_grid``,
    Fourier transformed (the box is treated as periodic) and |δ_k|² is
    averaged in spherical shells of |k|. The Poisson shot noise 1/n̄ is
    subtracted from every non-empty shell.

    Parameters
    ----------
    gal : GalaxyCatalogue
        Galaxy catalogue.
    cosmo : FlatLambdaCDM
        Cosmology for redshift-distance conversion; its ``h`` converts grid
        wavenumbers [1/Mpc] and powers [Mpc³] to h-units.
    k_bins : ndarray, shape (n_k+1,)
        Strictly increasing wavenumber bin edges [h/Mpc]. Bins are half-open:
        ``k_bins[i] <= k < k_bins[i+1]``.
    rand : GalaxyCatalogue, optional
        Currently unused; accepted for API compatibility. FKP weighting is not
        implemented, so the Poisson shot noise 1/n̄ is always subtracted.
    n_grid : int
        FFT grid size per dimension.
    box_size : float, optional
        Box side length [Mpc]. If None, estimated heuristically as 1.1 × the
        largest of the radial extent and the RA / Dec arc lengths at the mean
        comoving distance.

    Returns
    -------
    k_centres : ndarray, shape (n_k,)
        Mean |k| of the modes in each bin [h/Mpc]; 0 for empty bins.
    pk : ndarray, shape (n_k,)
        Shot-noise-subtracted power spectrum [(Mpc/h)³]; 0 for empty bins.
    n_modes : ndarray, shape (n_k,)
        Number of Fourier modes per bin on the half-complex (rfft) grid,
        excluding the k = 0 mode.

    Raises
    ------
    ValueError
        If ``k_bins`` is not a 1-D array of at least two strictly increasing
        edges.

    Performance
    -----------
    Dominated by one ``rfftn`` of the grid plus O(n_grid³) vectorised binning.
    """
    k_bins = np.asarray(k_bins, dtype=np.float64)
    if k_bins.ndim != 1 or k_bins.size < 2 or np.any(np.diff(k_bins) <= 0):
        raise ValueError(
            "k_bins must be a 1-D array of at least two strictly increasing edges"
        )
    jax_cosmo = astropy_to_jax_cosmo(cosmo)
    # Plain float so the outputs stay NumPy arrays even if "h" is a JAX scalar.
    h = float(jax_cosmo["h"])

    # Estimate box size from the data extent (heuristic; galaxies falling
    # outside the box are clipped into the edge cells by the gridding step).
    if box_size is None:
        chi = gal.comoving_distance(cosmo)
        box_size = (
            float(
                max(
                    chi.max() - chi.min(),
                    np.deg2rad(gal.ra.max() - gal.ra.min()) * chi.mean(),
                    np.deg2rad(gal.dec.max() - gal.dec.min()) * chi.mean(),
                )
            )
            * 1.1
        )

    delta, n_bar = _build_density_grid(gal, cosmo, n_grid, box_size)
    cell = box_size / n_grid

    # Continuous-FT normalisation: δ(k) ≈ Σ δ(x) cell³ and P = |δ(k)|² / V [Mpc³].
    delta_k = np.fft.rfftn(delta) * cell**3  # (n_grid, n_grid, n_grid//2+1)
    pk_3d = np.abs(delta_k) ** 2 / box_size**3
    # NOTE(review): 1/n̄ is the Poisson shot noise for unit weights; with
    # non-unit weights the usual term is V Σw² / (Σw)² — confirm intent.
    shot_noise = 1.0 / n_bar if n_bar > 0 else 0.0  # Mpc³

    # |k| on the rfft grid, converted from 1/Mpc to h/Mpc to match k_bins.
    k_full = 2.0 * np.pi * np.fft.fftfreq(n_grid, d=cell)
    k_half = 2.0 * np.pi * np.fft.rfftfreq(n_grid, d=cell)
    kx, ky, kz = np.meshgrid(k_full, k_full, k_half, indexing="ij")
    k_abs = np.sqrt(kx**2 + ky**2 + kz**2) / h

    # Half-open bins k_bins[i] <= |k| < k_bins[i+1]. The k = 0 (DC) mode is
    # excluded: δ sums to zero over the box, so its power carries no signal
    # and would only drag the lowest bin down by the shot noise.
    n_k = k_bins.size - 1
    bin_idx = np.digitize(k_abs, k_bins) - 1
    sel = (bin_idx >= 0) & (bin_idx < n_k) & (k_abs > 0)
    bin_idx = bin_idx[sel]

    n_modes = np.bincount(bin_idx, minlength=n_k)
    k_sum = np.bincount(bin_idx, weights=k_abs[sel], minlength=n_k)
    p_sum = np.bincount(bin_idx, weights=pk_3d[sel], minlength=n_k)

    filled = n_modes > 0
    k_centres = np.zeros(n_k)
    pk_binned = np.zeros(n_k)
    k_centres[filled] = k_sum[filled] / n_modes[filled]
    # Subtract the shot noise in Mpc³, then convert Mpc³ -> (Mpc/h)³.
    pk_binned[filled] = (p_sum[filled] / n_modes[filled] - shot_noise) * h**3
    return k_centres, pk_binned, n_modes
def pk_multipoles(
    gal: GalaxyCatalogue,
    cosmo: FlatLambdaCDM,
    k_bins: np.ndarray,
    rand: GalaxyCatalogue | None = None,
    ells: tuple[int, ...] = (0, 2, 4),
    n_grid: int = 256,
    box_size: float | None = None,
    n_mu_bins: int = 100,
) -> tuple[np.ndarray, dict[int, np.ndarray]]:
    """Power spectrum multipoles P_ℓ(k) via Legendre decomposition.

    Computes P(k, μ) on a 2D (k, μ) grid and decomposes it into Legendre
    multipoles using the same JAX kernel as the 2PCF multipoles
    (``twopcf.multipoles._legendre_decompose_jax``).

    The line of sight is taken along the grid z-axis (plane-parallel
    approximation), so μ = |k_z| / |k| lies in [0, 1].

    Parameters
    ----------
    gal : GalaxyCatalogue
        Galaxy catalogue.
    cosmo : FlatLambdaCDM
        Cosmology for redshift-distance conversion; its ``h`` converts grid
        wavenumbers [1/Mpc] and powers [Mpc³] to h-units.
    k_bins : ndarray, shape (n_k+1,)
        Strictly increasing wavenumber bin edges [h/Mpc]. Bins are half-open:
        ``k_bins[i] <= k < k_bins[i+1]``.
    rand : GalaxyCatalogue, optional
        Currently unused; accepted for API compatibility. The Poisson shot
        noise 1/n̄ is always subtracted.
    ells : tuple of int
        Multipole orders, default (0, 2, 4).
    n_grid : int
        FFT grid size per dimension.
    box_size : float, optional
        Box side length [Mpc]. If None, estimated with the same heuristic as
        ``pk3d``.
    n_mu_bins : int
        Number of equal-width μ bins on [0, 1]; μ = 1 falls in the last bin.

    Returns
    -------
    k_centres : ndarray, shape (n_k,)
        Mean |k| of the modes in each bin [h/Mpc]; 0 for empty bins.
    pk_dict : dict[int, ndarray]
        Keys are ell; values shape (n_k,) [(Mpc/h)³].

    Raises
    ------
    ValueError
        If ``k_bins`` is not a 1-D array of at least two strictly increasing
        edges, or ``n_mu_bins < 1``.

    Notes
    -----
    (k, μ) cells that contain no Fourier modes are left at 0 and passed to
    the Legendre kernel as-is. NOTE(review): at low k with many μ bins this
    can bias the multipoles, depending on how the kernel integrates over μ.

    Performance
    -----------
    Dominated by one ``rfftn`` of the grid plus a single O(n_grid³)
    vectorised (k, μ) binning pass.
    """
    from ..twopcf.multipoles import _legendre_decompose_jax

    k_bins = np.asarray(k_bins, dtype=np.float64)
    if k_bins.ndim != 1 or k_bins.size < 2 or np.any(np.diff(k_bins) <= 0):
        raise ValueError(
            "k_bins must be a 1-D array of at least two strictly increasing edges"
        )
    if n_mu_bins < 1:
        raise ValueError("n_mu_bins must be at least 1")
    jax_cosmo = astropy_to_jax_cosmo(cosmo)
    # Plain float so the outputs stay NumPy arrays even if "h" is a JAX scalar.
    h = float(jax_cosmo["h"])

    # Same box-size heuristic as pk3d.
    if box_size is None:
        chi = gal.comoving_distance(cosmo)
        box_size = (
            float(
                max(
                    chi.max() - chi.min(),
                    np.deg2rad(gal.ra.max() - gal.ra.min()) * chi.mean(),
                    np.deg2rad(gal.dec.max() - gal.dec.min()) * chi.mean(),
                )
            )
            * 1.1
        )

    delta, n_bar = _build_density_grid(gal, cosmo, n_grid, box_size)
    cell = box_size / n_grid
    shot_noise = 1.0 / n_bar if n_bar > 0 else 0.0  # Mpc³
    delta_k = np.fft.rfftn(delta) * cell**3
    pk_3d = np.abs(delta_k) ** 2 / box_size**3  # Mpc³, shot noise not yet removed

    k_full = 2.0 * np.pi * np.fft.fftfreq(n_grid, d=cell)
    k_half = 2.0 * np.pi * np.fft.rfftfreq(n_grid, d=cell)
    kx, ky, kz = np.meshgrid(k_full, k_full, k_half, indexing="ij")
    k_mpc = np.sqrt(kx**2 + ky**2 + kz**2)  # 1/Mpc
    # μ = k_∥ / k with the line of sight along the grid z-axis; 0 at k = 0.
    mu_vals = np.divide(
        np.abs(kz), k_mpc, out=np.zeros_like(k_mpc), where=k_mpc > 0
    )
    k_h = k_mpc / h  # 1/Mpc -> h/Mpc, to match k_bins

    n_k = k_bins.size - 1
    mu_bins = np.linspace(0.0, 1.0, n_mu_bins + 1)
    mu_centres = 0.5 * (mu_bins[:-1] + mu_bins[1:])

    # Assign every non-DC mode to a k shell; the DC mode carries no signal.
    k_idx = np.digitize(k_h, k_bins) - 1
    sel = (k_idx >= 0) & (k_idx < n_k) & (k_mpc > 0)
    k_idx = k_idx[sel]
    # Clip so μ = 1 (modes along the z-axis) lands in the last μ bin instead
    # of falling off the half-open upper edge.
    mu_idx = np.clip(np.digitize(mu_vals[sel], mu_bins) - 1, 0, n_mu_bins - 1)

    k_counts = np.bincount(k_idx, minlength=n_k)
    k_sums = np.bincount(k_idx, weights=k_h[sel], minlength=n_k)
    k_centres = np.divide(
        k_sums, k_counts, out=np.zeros(n_k), where=k_counts > 0
    )

    # Mean power per (k, μ) cell via a single flattened bincount.
    n_cells = n_k * n_mu_bins
    cell_idx = k_idx * n_mu_bins + mu_idx
    cell_counts = np.bincount(cell_idx, minlength=n_cells).reshape(n_k, n_mu_bins)
    cell_sums = np.bincount(
        cell_idx, weights=pk_3d[sel], minlength=n_cells
    ).reshape(n_k, n_mu_bins)
    filled = cell_counts > 0
    pk_2d = np.zeros((n_k, n_mu_bins))
    # Subtract the shot noise in Mpc³, then convert Mpc³ -> (Mpc/h)³.
    pk_2d[filled] = (cell_sums[filled] / cell_counts[filled] - shot_noise) * h**3

    # Legendre decomposition via the shared JAX kernel
    xi_like_dict = _legendre_decompose_jax(
        jnp.array(pk_2d), jnp.array(mu_centres), ells
    )
    pk_dict = {ell: np.array(v) for ell, v in xi_like_dict.items()}
    return k_centres, pk_dict