Cosmology utilities
Warning
These utilities are internal helpers used by all estimators. The API is not yet considered stable and may change.
The sum_stat.cosmology module provides JAX-native implementations of
comoving distance, comoving volume, and critical surface density integrals.
These are used internally by all estimators that require a cosmology;
end users typically pass an astropy.cosmology.FLRW object which is
converted automatically via astropy_to_jax_cosmo().
JAX-differentiable cosmological distances for a flat ΛCDM universe.
All functions operate on plain floats / jnp arrays (no astropy Quantities) so they can be used inside jax.jit / jax.grad / jax.vmap contexts.
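The integration uses 32-point Gauss-Legendre quadrature. The rule is exact for polynomials up to degree 63 and highly accurate (though not exact) for smooth integrands such as the comoving-distance integrand. Because it evaluates the integrand at a fixed set of nodes, it also avoids adaptive integrators such as scipy.integrate.quad, which cannot be traced by JAX.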
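To illustrate the technique, here is a minimal NumPy sketch of an n-point Gauss-Legendre rule on an arbitrary interval [a, b]. It assumes the nodes and weights come from numpy.polynomial.legendre.leggauss. The helper name gauss_legendre is hypothetical, and this is not the module's own code.

```python
import numpy as np


def gauss_legendre(f, a, b, n=32):
    """Approximate the integral of f over [a, b] with an n-point Gauss-Legendre rule."""
    # Nodes x_i and weights w_i for the reference interval [-1, 1].
    x, w = np.polynomial.legendre.leggauss(n)
    # Affine map [-1, 1] -> [a, b]; its Jacobian is (b - a) / 2.
    t = 0.5 * (b - a) * x + 0.5 * (b + a)
    return 0.5 * (b - a) * np.sum(w * f(t))
```

With n=32, integrating t**63 over [0, 1] recovers 1/64 to rounding error, which reflects the degree-63 exactness. A smooth non-polynomial integrand such as (1 + t)**-1.5 is reproduced to near machine precision rather than exactly.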
- sum_stat.cosmology.astropy_to_jax_cosmo(cosmo)
Extract h and omega_m from an astropy FlatLambdaCDM for use in JAX functions.
- Parameters:
cosmo (FlatLambdaCDM) – Astropy cosmology instance.
- Returns:
dict with keys h (float) and omega_m (float).
- Return type:
dict
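The sketch below shows what the conversion amounts to. It assumes only astropy's public FLRW attributes h, Om0, and Ok0. The function name is hypothetical, and the flatness check is an added assumption rather than documented behavior of astropy_to_jax_cosmo().

```python
def astropy_to_jax_cosmo_sketch(cosmo):
    """Hypothetical stand-in: pull h and Om0 out of a flat astropy cosmology."""
    # The JAX distance functions assume a flat universe; reject curved models.
    # (Assumption of this sketch, not documented behavior.)
    if abs(float(cosmo.Ok0)) > 1e-10:
        raise ValueError("expected a flat cosmology (Ok0 == 0)")
    # FLRW objects expose h = H0 / (100 km/s/Mpc) and the present-day matter
    # density Om0; casting to float strips any astropy/NumPy scalar types.
    return {"h": float(cosmo.h), "omega_m": float(cosmo.Om0)}
```

For example, passing FlatLambdaCDM(H0=70, Om0=0.3) yields h and omega_m as plain Python floats. These can then be given to the JAX distance functions inside jit-compiled code.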
- sum_stat.cosmology.comoving_distance_jax(z, h, omega_m)
Comoving distance χ(z) [Mpc] for flat ΛCDM.
Uses 32-point Gauss-Legendre quadrature on [0, z] mapped to [-1, 1]. JAX-differentiable with respect to z, h, and omega_m.
- Parameters:
z (jnp.ndarray) – Redshift (scalar or array).
h (float) – Dimensionless Hubble parameter H0 / (100 km/s/Mpc).
omega_m (float) – Matter density parameter Ω_m.
- Returns:
chi (jnp.ndarray) – Comoving distance [Mpc], same shape as z.
- Return type:
Array
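The following is a minimal sketch of how such a function can be written. It assumes the standard flat ΛCDM expansion rate with radiation neglected, E(z) = sqrt(Ω_m (1+z)^3 + 1 - Ω_m), so that χ(z) = (c / H0) ∫_0^z dz' / E(z') with H0 = 100 h km/s/Mpc. The name comoving_distance_sketch is hypothetical, and the code is not the module's implementation. It only illustrates how a fixed quadrature rule combined with jax.numpy operations works with jax.grad, jax.jit, and jax.vmap.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Speed of light in km/s (exact SI value).
C_KM_S = 299792.458

# 32-point Gauss-Legendre nodes/weights on [-1, 1], computed once with NumPy;
# JAX treats them as constants.
_GL_X, _GL_W = np.polynomial.legendre.leggauss(32)


def comoving_distance_sketch(z, h, omega_m):
    """Comoving distance chi(z) in Mpc for flat LCDM (radiation neglected)."""
    z = jnp.asarray(z)
    # Map nodes from [-1, 1] to [0, z]: z' = z (x + 1) / 2, dz' = (z / 2) dx.
    # The trailing axis holds the 32 nodes, so z may be a scalar or an array.
    zp = z[..., None] * (jnp.asarray(_GL_X) + 1.0) / 2.0
    e_z = jnp.sqrt(omega_m * (1.0 + zp) ** 3 + (1.0 - omega_m))
    integral = (z / 2.0) * jnp.sum(jnp.asarray(_GL_W) / e_z, axis=-1)
    # Hubble distance c / H0 in Mpc, with H0 = 100 h km/s/Mpc.
    return (C_KM_S / (100.0 * h)) * integral
```

The nodes and weights are fixed arrays, so the computation is a plain weighted sum. As a result, jax.grad(comoving_distance_sketch)(z, h, omega_m) returns dχ/dz ≈ c / (H0 E(z)), and argnums=2 gives the derivative with respect to omega_m. For Ω_m = 1 the result can be checked against the Einstein-de Sitter closed form χ = (2c/H0)(1 - 1/sqrt(1+z)).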
- sum_stat.cosmology.comoving_volume_jax(z, h, omega_m)
Comoving volume V_c(z) [Mpc^3] for flat ΛCDM.
V_c(z) = (4π/3) χ(z)^3 (flat universe). JAX-differentiable.
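The volume depends on z only through χ(z). This minimal sketch therefore takes χ as input, for example a value from comoving_distance_jax. Both helper names are hypothetical. The shell helper simply differences the full-sky volume at two distances.

```python
import numpy as np


def comoving_volume_from_chi(chi):
    """Full-sky comoving volume (4 pi / 3) chi^3 in Mpc^3 (flat universe)."""
    return (4.0 * np.pi / 3.0) * np.asarray(chi, dtype=float) ** 3


def shell_volume(chi_near, chi_far):
    """Comoving volume in Mpc^3 of the spherical shell between two comoving distances."""
    return comoving_volume_from_chi(chi_far) - comoving_volume_from_chi(chi_near)
```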
- sum_stat.cosmology.critical_surface_density_jax(z_l, z_s, h, omega_m, comoving=True)
Critical surface density Σ_crit [M_sun/pc^2].
Σ_crit = (c^2 / 4πG) * D_s / (D_l * D_ls)
where D_l, D_s, and D_ls are the angular diameter distances to the lens, to the source, and from the lens to the source. If comoving=True, the comoving critical surface density is returned instead.
- Parameters:
z_l (jnp.ndarray) – Lens redshift.
z_s (jnp.ndarray) – Source redshift.
h (float) – Dimensionless Hubble parameter.
omega_m (float) – Matter density parameter.
comoving (bool) – If True, return the comoving Σ_crit, i.e. the physical value divided by (1+z_l)^2.
- Returns:
sigma_crit (jnp.ndarray) – Critical surface density [M_sun/pc^2].
- Return type:
Array
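Below is a NumPy sketch of the physical Σ_crit. It assumes flat space, so that D_l = χ_l/(1+z_l), D_s = χ_s/(1+z_s), and D_ls = (χ_s - χ_l)/(1+z_s). The caller supplies the comoving distances χ, for example from comoving_distance_jax. The function name and the constant values are assumptions of this sketch, and it deliberately omits the comoving option. With distances in Mpc, the prefactor c^2/(4πG) works out to about 1.66 × 10^6 M_sun/pc^2.

```python
import numpy as np

C_M_S = 299_792_458.0           # speed of light [m/s] (exact)
G_SI = 6.67430e-11              # Newton's constant [m^3 kg^-1 s^-2] (CODATA 2018)
M_SUN_KG = 1.98841e30           # solar mass [kg] (approximate, assumed value)
PC_M = 3.0856775814913673e16    # parsec [m] (IAU definition)

# c^2 / (4 pi G) in kg/m, converted to M_sun/pc, then divided by 1e6 so that
# distances in Mpc give Sigma_crit in M_sun/pc^2.
SIGMA_CRIT_PREFACTOR = C_M_S**2 / (4.0 * np.pi * G_SI) * (PC_M / M_SUN_KG) / 1e6


def sigma_crit_physical(chi_l, chi_s, z_l, z_s):
    """Physical critical surface density [M_sun/pc^2] from comoving distances [Mpc]."""
    chi_l, chi_s = np.asarray(chi_l, dtype=float), np.asarray(chi_s, dtype=float)
    z_l, z_s = np.asarray(z_l, dtype=float), np.asarray(z_s, dtype=float)
    d_l = chi_l / (1.0 + z_l)             # angular diameter distance to the lens
    d_s = chi_s / (1.0 + z_s)             # angular diameter distance to the source
    d_ls = (chi_s - chi_l) / (1.0 + z_s)  # lens-to-source distance (flat space only)
    with np.errstate(divide="ignore", invalid="ignore"):
        sigma = SIGMA_CRIT_PREFACTOR * d_s / (d_l * d_ls)
    # Sources at or in front of the lens are not lensed: Sigma_crit -> infinity.
    return np.where(chi_s > chi_l, sigma, np.inf)
```

In terms of comoving distances this is equivalent to Σ_crit = (c^2/4πG) (1+z_l) χ_s / (χ_l (χ_s - χ_l)). That form shows why the result grows without bound as the source approaches the lens.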