# Source code for sum_stat.cosmology

"""JAX-differentiable cosmological distances for a flat ΛCDM universe.

All functions operate on plain floats / jnp arrays (no astropy Quantities)
so they can be used inside jax.jit / jax.grad / jax.vmap contexts.

The integration uses 32-point Gauss-Legendre quadrature, which is exact for
polynomials of degree up to 63 and highly accurate for the smooth integrands
used here; it also avoids adaptive ODE solvers that are not JAX-traceable.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np
from astropy.cosmology import FlatLambdaCDM

# Physical constants
_C_KM_S = 299792.458  # speed of light in vacuum [km/s]

# Gauss-Legendre rule of order 32 on the reference interval [-1, 1].
# The nodes/weights are computed once with NumPy at import time and stored as
# JAX arrays so that jitted integrators can close over them as constants.
_GL_NODES, _GL_WEIGHTS = (
    jnp.array(arr) for arr in np.polynomial.legendre.leggauss(32)
)


@jax.jit
def _E_inv(z: jnp.ndarray, omega_m: float) -> jnp.ndarray:
    """1 / E(z) where E(z) = sqrt(Om * (1+z)^3 + 1 - Om) for flat ΛCDM."""
    return 1.0 / jnp.sqrt(omega_m * (1.0 + z) ** 3 + (1.0 - omega_m))


@jax.jit
def comoving_distance_jax(
    z: jnp.ndarray,
    h: float,
    omega_m: float,
) -> jnp.ndarray:
    """Comoving distance χ(z) [Mpc] for flat ΛCDM.

    Evaluates χ(z) = (c / H0) ∫_0^z dz' / E(z') with 32-point Gauss-Legendre
    quadrature on [0, z] mapped to [-1, 1]. The quadrature nodes live on a new
    trailing axis that broadcasts against ``z``, so redshifts of any shape
    (0-d scalars, length-1 arrays, multi-dimensional grids) are supported and
    the output shape always equals the input shape. JAX-differentiable with
    respect to z, h, and omega_m.

    Parameters
    ----------
    z : jnp.ndarray
        Redshift (scalar or array of any shape).
    h : float
        Dimensionless Hubble parameter H0 / (100 km/s/Mpc).
    omega_m : float
        Matter density parameter Ω_m.

    Returns
    -------
    chi : jnp.ndarray
        Comoving distance [Mpc], same shape as z.
    """
    z = jnp.asarray(z, dtype=jnp.float64)
    dh = _C_KM_S / (100.0 * h)  # Hubble distance c / H0 [Mpc]

    # Substitution t = z/2 * (1 + x) maps GL nodes x in [-1, 1] onto [0, z];
    # its Jacobian dt/dx = z/2 is applied below. Shape of t: z.shape + (32,).
    t = 0.5 * z[..., None] * (_GL_NODES + 1.0)
    integrand = _E_inv(t, omega_m)

    # Weighted sum over the node axis. An elementwise product + sum is used
    # instead of a dot product so the result does not depend on backend
    # reduced-precision matmul defaults.
    quad = jnp.sum(integrand * _GL_WEIGHTS, axis=-1)
    return dh * 0.5 * z * quad


@jax.jit
def comoving_volume_jax(
    z: jnp.ndarray,
    h: float,
    omega_m: float,
) -> jnp.ndarray:
    """Comoving volume V_c(z) [Mpc^3] enclosed within redshift z, flat ΛCDM.

    In a spatially flat universe the volume is that of a Euclidean sphere whose
    radius is the comoving distance: V_c(z) = (4π/3) χ(z)^3. JAX-differentiable.

    Parameters
    ----------
    z : jnp.ndarray
        Redshift (scalar or array).
    h : float
        Dimensionless Hubble parameter.
    omega_m : float
        Matter density parameter Ω_m.

    Returns
    -------
    vc : jnp.ndarray
        Comoving volume [Mpc^3], with the shape returned by
        ``comoving_distance_jax`` for the same ``z``.
    """
    sphere_prefactor = 4.0 * jnp.pi / 3.0
    radius = comoving_distance_jax(z, h, omega_m)
    return sphere_prefactor * radius**3


@jax.jit
def critical_surface_density_jax(
    z_l: jnp.ndarray,
    z_s: jnp.ndarray,
    h: float,
    omega_m: float,
    comoving: bool = True,
) -> jnp.ndarray:
    """Critical surface density Σ_crit [M_sun/pc^2] for flat ΛCDM.

    Physical convention::

        Σ_crit = (c^2 / 4πG) * D_s / (D_l * D_ls)

    where D_l, D_s, D_ls are angular diameter distances.

    With ``comoving=True`` the comoving critical surface density
    Σ_crit / (1 + z_l)^2 is returned. Mass per comoving area is smaller than
    mass per physical area by (1 + z_l)^2. The convergence κ = Σ / Σ_crit is
    frame-independent, so the critical density must scale the same way.

    Source-lens pairs with z_s <= z_l cannot be lensed. For those pairs Σ_crit
    is set to +inf, which means zero lensing efficiency 1 / Σ_crit, and the
    gradients stay finite.

    Parameters
    ----------
    z_l : jnp.ndarray
        Lens redshift.
    z_s : jnp.ndarray
        Source redshift.
    h : float
        Dimensionless Hubble parameter.
    omega_m : float
        Matter density parameter.
    comoving : bool
        If True (default), return the comoving Σ_crit, i.e. the physical value
        divided by (1 + z_l)^2. If False, return the physical Σ_crit.

    Returns
    -------
    sigma_crit : jnp.ndarray
        Critical surface density [M_sun/pc^2], broadcast over z_l and z_s.
    """
    c2_over_4pi_g = 1.6625e18  # c^2 / (4πG) in M_sun / Mpc
    mpc2_to_pc2 = 1.0e-12  # M_sun/Mpc^2 -> M_sun/pc^2 (1 Mpc = 1e6 pc)

    chi_l = comoving_distance_jax(z_l, h, omega_m)
    chi_s = comoving_distance_jax(z_s, h, omega_m)
    chi_ls = chi_s - chi_l  # flat universe: comoving distances are additive

    # Only sources strictly behind the lens are lensed. Substitute a dummy
    # positive separation in the masked-out entries so that the unselected
    # branch stays finite; otherwise jnp.where would produce NaN gradients.
    behind = chi_ls > 0.0
    chi_ls_safe = jnp.where(behind, chi_ls, 1.0)

    # Angular diameter distances [Mpc]
    d_l = chi_l / (1.0 + z_l)
    d_s = chi_s / (1.0 + z_s)
    d_ls = chi_ls_safe / (1.0 + z_s)

    # (M_sun/Mpc) * Mpc^-1 = M_sun/Mpc^2, then converted to M_sun/pc^2.
    sigma_crit = c2_over_4pi_g * d_s / (d_l * d_ls) * mpc2_to_pc2
    sigma_crit = jnp.where(comoving, sigma_crit / (1.0 + z_l) ** 2, sigma_crit)
    return jnp.where(behind, sigma_crit, jnp.inf)


def astropy_to_jax_cosmo(cosmo: FlatLambdaCDM) -> dict[str, float]:
    """Convert an astropy FlatLambdaCDM into the plain-float parameters used here.

    The JAX functions in this module take ``h`` and ``omega_m`` as bare
    floats; this helper reads ``cosmo.H0.value`` (in km/s/Mpc) and
    ``cosmo.Om0`` and converts them to Python floats.

    Parameters
    ----------
    cosmo : FlatLambdaCDM
        Astropy cosmology instance.

    Returns
    -------
    dict with keys ``h`` (H0 / 100, float) and ``omega_m`` (float).
    """
    hubble_h = float(cosmo.H0.value / 100.0)
    matter_density = float(cosmo.Om0)
    return {"h": hubble_h, "omega_m": matter_density}


# Default flat ΛCDM cosmology, labelled as an approximation to Planck 2018.
# NOTE(review): H0=67.36 appears to be the Planck 2018 TT,TE,EE+lowE+lensing
# value, while Om0=0.3111 appears to be the +BAO value (which pairs with
# H0=67.66). Confirm the intended pairing before relying on this for
# precision work.
DEFAULT_COSMO = FlatLambdaCDM(
    H0=67.36,
    Om0=0.3111,
    name="Planck18_approx",
)