"""JAX-differentiable cosmological distances for a flat ΛCDM universe.
All functions operate on plain floats / jnp arrays (no astropy Quantities)
so they can be used inside jax.jit / jax.grad / jax.vmap contexts.
The integration uses 32-point Gauss-Legendre quadrature, which is exact for
polynomials up to degree 63 and highly accurate for the smooth 1/E(z)
integrand; it avoids adaptive integrators (e.g. scipy.integrate.quad), which
are not JAX-traceable.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
import numpy as np
from astropy.cosmology import FlatLambdaCDM
# Physical constants
_C_KM_S = 299792.458  # speed of light in vacuum [km/s]

# Gauss-Legendre rule with 32 points on the reference interval [-1, 1].
# NumPy computes nodes and weights once at import time; both are converted to
# JAX arrays so the quadrature in this module can use them directly.
_GL_NODES, _GL_WEIGHTS = map(jnp.array, np.polynomial.legendre.leggauss(32))
@jax.jit
def _E_inv(z: jnp.ndarray, omega_m: float) -> jnp.ndarray:
    """Return the inverse dimensionless Hubble rate 1 / E(z) for flat ΛCDM.

    E(z) = sqrt(Om * (1+z)^3 + 1 - Om): the dark-energy density is fixed to
    1 - Om by flatness and no radiation term is included.
    """
    matter_term = omega_m * (1.0 + z) ** 3
    lambda_term = 1.0 - omega_m  # Ω_Λ = 1 - Ω_m in a flat universe
    return 1.0 / jnp.sqrt(matter_term + lambda_term)
@jax.jit
def comoving_distance_jax(
    z: jnp.ndarray,
    h: float,
    omega_m: float,
) -> jnp.ndarray:
    """Comoving distance χ(z) [Mpc] for flat ΛCDM.

    Evaluates χ(z) = (c / H0) ∫_0^z dz' / E(z') with 32-point Gauss-Legendre
    quadrature on [0, z] mapped to [-1, 1]. The quadrature is broadcast over
    every element of ``z``, so scalars and arrays of any shape (including
    arrays with singleton axes) are supported and the output keeps the
    input's shape.
    JAX-differentiable with respect to z, h, and omega_m.

    Parameters
    ----------
    z : jnp.ndarray
        Redshift (scalar or array of any shape).
    h : float
        Dimensionless Hubble parameter H0 / (100 km/s/Mpc).
    omega_m : float
        Matter density parameter Ω_m.

    Returns
    -------
    chi : jnp.ndarray
        Comoving distance [Mpc], same shape as z.

    Notes
    -----
    ``z`` is cast to JAX's default float dtype: float64 only when
    ``jax_enable_x64`` is enabled, float32 otherwise.
    """
    dh = _C_KM_S / (100.0 * h)  # Hubble distance [Mpc]
    # ``dtype=float`` selects the default float type. Requesting float64
    # explicitly would only warn and truncate to float32 when x64 is disabled.
    z = jnp.asarray(z, dtype=float)
    # Substitution t = z/2 * (1 + x) maps the nodes x ∈ [-1, 1] onto [0, z].
    # The appended trailing axis gives t the shape z.shape + (32,), so every
    # redshift gets its own set of nodes without vmap/flatten/squeeze.
    t = 0.5 * z[..., None] * (_GL_NODES + 1.0)
    integrand = _E_inv(t, omega_m)
    # Contract the node axis with the weights; z/2 is the substitution's
    # Jacobian.
    return dh * 0.5 * z * (integrand @ _GL_WEIGHTS)
@jax.jit
def comoving_volume_jax(
    z: jnp.ndarray,
    h: float,
    omega_m: float,
) -> jnp.ndarray:
    """Comoving volume V_c(z) [Mpc^3] for flat ΛCDM.

    In a flat universe the full-sky comoving volume out to redshift z is the
    volume of a Euclidean sphere of radius χ(z):

        V_c(z) = (4π/3) χ(z)^3

    JAX-differentiable.

    Parameters
    ----------
    z : jnp.ndarray
        Redshift (scalar or array).
    h : float
        Dimensionless Hubble parameter.
    omega_m : float
        Matter density parameter Ω_m.

    Returns
    -------
    vc : jnp.ndarray
        Comoving volume [Mpc^3], with the shape returned by
        ``comoving_distance_jax`` for the same ``z``.
    """
    radius = comoving_distance_jax(z, h, omega_m)
    sphere_factor = 4.0 * jnp.pi / 3.0
    return sphere_factor * radius**3
@jax.jit
def critical_surface_density_jax(
    z_l: jnp.ndarray,
    z_s: jnp.ndarray,
    h: float,
    omega_m: float,
    comoving: bool = True,
) -> jnp.ndarray:
    """Critical surface density Σ_crit [M_sun/pc^2].

    Σ_crit = (c^2 / 4πG) * D_s / (D_l * D_ls)

    where D_l, D_s and D_ls are the angular diameter distances to the lens,
    to the source, and from lens to source (flat universe).

    For comoving=True the comoving critical surface density is returned. A
    comoving area at the lens is (1 + z_l)^2 times the corresponding physical
    area, so surface densities per comoving area are smaller by that factor:
    Σ_crit,com = Σ_crit,phys / (1 + z_l)^2. The convergence
    κ = Σ / Σ_crit is the same in both conventions.

    Parameters
    ----------
    z_l : jnp.ndarray
        Lens redshift.
    z_s : jnp.ndarray
        Source redshift.
    h : float
        Dimensionless Hubble parameter.
    omega_m : float
        Matter density parameter.
    comoving : bool
        If True (default), return comoving Σ_crit (physical Σ_crit divided by
        (1 + z_l)^2); if False, return physical Σ_crit.

    Returns
    -------
    sigma_crit : jnp.ndarray
        Critical surface density [M_sun/pc^2], broadcast over z_l and z_s.

    Notes
    -----
    No guard is applied for z_s <= z_l: D_ls is then zero or negative, and the
    result is infinite or negative. Callers should mask such lens-source
    pairs.
    """
    # c^2 / (4πG) in M_sun / Mpc.
    # NOTE(review): evaluating with CODATA-2018 G and the IAU nominal solar
    # mass gives ≈1.6629e18; the commonly quoted 1.6625e18 is kept here.
    c2_over_4pig = 1.6625e18

    z_l = jnp.asarray(z_l, dtype=float)
    z_s = jnp.asarray(z_s, dtype=float)

    chi_l = comoving_distance_jax(z_l, h, omega_m)
    chi_s = comoving_distance_jax(z_s, h, omega_m)
    chi_ls = chi_s - chi_l  # lens-source comoving distance; flat universe only

    # Angular diameter distances [Mpc].
    D_l = chi_l / (1.0 + z_l)
    D_s = chi_s / (1.0 + z_s)
    D_ls = chi_ls / (1.0 + z_s)

    # D_s / (D_l D_ls) is in Mpc^-1, so this product is in M_sun / Mpc^2.
    sigma_crit_mpc2 = c2_over_4pig * D_s / (D_l * D_ls)
    sigma_crit = sigma_crit_mpc2 * 1e-12  # 1 Mpc = 1e6 pc  ->  M_sun / pc^2

    # jnp.where (rather than a Python `if`) lets `comoving` be traced by jit.
    return jnp.where(comoving, sigma_crit / (1.0 + z_l) ** 2, sigma_crit)
def astropy_to_jax_cosmo(cosmo: FlatLambdaCDM) -> dict[str, float]:
    """Convert an astropy FlatLambdaCDM into the parameters used by the JAX functions.

    The returned mapping can be unpacked as keyword arguments, e.g.
    ``comoving_distance_jax(z, **astropy_to_jax_cosmo(cosmo))``.

    Parameters
    ----------
    cosmo : FlatLambdaCDM
        Astropy cosmology instance.

    Returns
    -------
    dict with keys ``h`` (float) and ``omega_m`` (float). ``h`` is the
    numeric value of ``cosmo.H0`` divided by 100 (astropy stores H0 in
    km/s/Mpc); ``omega_m`` is ``cosmo.Om0``.
    """
    hubble_h = cosmo.H0.value / 100.0
    matter_density = cosmo.Om0
    return dict(h=float(hubble_h), omega_m=float(matter_density))
# Module-wide default cosmology: flat ΛCDM with H0 = 67.36 km/s/Mpc and
# Ω_m = 0.3111. Convert it with ``astropy_to_jax_cosmo`` to get the
# (h, omega_m) pair accepted by the JAX distance functions.
# NOTE(review): in Planck 2018 (paper VI, Table 2), H0 = 67.36 is the
# TT,TE,EE+lowE+lensing value, while Om0 = 0.3111 is the +BAO value (which
# pairs with H0 = 67.66). Confirm which combination is intended.
DEFAULT_COSMO = FlatLambdaCDM(
    name="Planck18_approx",
    H0=67.36,
    Om0=0.3111,
)