Cosmology utilities

Warning

These utilities are internal helpers used by all estimators. The API is not yet considered stable and may change.

The sum_stat.cosmology module provides JAX-native implementations of the comoving distance, the comoving volume, and the critical surface density. Every estimator that requires a cosmology uses them internally; end users typically pass an astropy.cosmology.FlatLambdaCDM object, which is converted automatically via astropy_to_jax_cosmo().

JAX-differentiable cosmological distances for a flat ΛCDM universe.

All functions operate on plain floats / jnp arrays (no astropy Quantities) so they can be used inside jax.jit / jax.grad / jax.vmap contexts.

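The integration uses 32-point Gauss-Legendre quadrature, which is exact for polynomial integrands up to degree 63 and highly accurate for the smooth integrands that arise here. A fixed-order rule also avoids adaptive ODE solvers, which are not JAX-traceable.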

sum_stat.cosmology.astropy_to_jax_cosmo(cosmo)[source]

Extract h and omega_m from an astropy FlatLambdaCDM for use in JAX functions.

Parameters:

cosmo (FlatLambdaCDM) – Astropy cosmology instance.

Returns:

dict with keys h (float) and omega_m (float).

Return type:

dict[str, float]
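A conversion of this kind might look like the sketch below. It relies only on the h and Om0 attributes that astropy's FlatLambdaCDM exposes; the function name to_jax_params and the stand-in object are hypothetical, so the example runs without astropy installed.

```python
from types import SimpleNamespace


def to_jax_params(cosmo):
    """Hypothetical stand-in: pull h and Om0 out as plain Python floats."""
    return {"h": float(cosmo.h), "omega_m": float(cosmo.Om0)}


# Stand-in exposing the same attribute names as astropy's FlatLambdaCDM
# (`.h` and `.Om0`); a real astropy instance could be passed instead.
fake_cosmo = SimpleNamespace(h=0.7, Om0=0.3)
params = to_jax_params(fake_cosmo)
```

Returning plain floats rather than astropy Quantities keeps the values usable as ordinary arguments inside jax.jit and jax.grad.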

sum_stat.cosmology.comoving_distance_jax(z, h, omega_m)[source]

Comoving distance χ(z) [Mpc] for flat ΛCDM.

Uses 32-point Gauss-Legendre quadrature on [0, z] mapped to [-1, 1]. JAX-differentiable with respect to z, h, and omega_m.

Parameters:
  • z (jnp.ndarray) – Redshift (scalar or array).

  • h (float) – Dimensionless Hubble parameter H0 / (100 km/s/Mpc).

  • omega_m (float) – Matter density parameter Ω_m.

Returns:

chi (jnp.ndarray) – Comoving distance [Mpc], same shape as z.

Return type:

Array
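The quadrature described for comoving_distance_jax can be sketched as follows. This is an illustrative reimplementation, not the library's code: the name chi_sketch is hypothetical, and the expansion rate is assumed to contain only matter and a cosmological constant (no radiation or neutrinos).

```python
import jax
import jax.numpy as jnp
import numpy as np

C_KM_S = 299792.458  # speed of light [km/s]
_X, _W = (jnp.asarray(a) for a in np.polynomial.legendre.leggauss(32))


def chi_sketch(z, h, omega_m):
    """Comoving distance [Mpc] in flat LCDM (matter + Lambda only; an assumption)."""
    z = jnp.asarray(z)
    # Map the nodes from [-1, 1] to [0, z]; the trailing axis holds the 32 nodes.
    zq = 0.5 * z[..., None] * (_X + 1.0)
    e_of_z = jnp.sqrt(omega_m * (1.0 + zq) ** 3 + (1.0 - omega_m))
    integral = 0.5 * z * jnp.sum(_W / e_of_z, axis=-1)
    return C_KM_S / (100.0 * h) * integral  # Hubble distance times the integral


zs = jnp.array([0.5, 1.0])
chi = chi_sketch(zs, 0.7, 0.3)
chi_jit = jax.jit(chi_sketch)(zs, 0.7, 0.3)
# Gradient with respect to h at z = 1; since chi scales as 1/h, it equals -chi/h.
dchi_dh = jax.grad(chi_sketch, argnums=1)(1.0, 0.7, 0.3)
```

The gradient check against -chi/h is a quick sanity test that differentiation passes through the quadrature as expected.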

sum_stat.cosmology.comoving_volume_jax(z, h, omega_m)[source]

Comoving volume V_c(z) [Mpc^3] for flat ΛCDM.

V_c(z) = (4π/3) χ(z)^3 (flat universe). JAX-differentiable.

Parameters:
  • z (jnp.ndarray) – Redshift (scalar or array).

  • h (float) – Dimensionless Hubble parameter.

  • omega_m (float) – Matter density parameter Ω_m.

Returns:

vc (jnp.ndarray) – Comoving volume [Mpc^3], same shape as z.

Return type:

Array
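The flat-universe relation V_c = (4π/3) χ^3 is simple enough to check directly. The sketch below takes χ as an input rather than calling the library, and the name volume_from_chi is hypothetical.

```python
import jax
import jax.numpy as jnp


def volume_from_chi(chi):
    """Volume [Mpc^3] of a sphere of comoving radius chi [Mpc], flat geometry."""
    return 4.0 * jnp.pi / 3.0 * chi**3


chi = 1000.0  # illustrative comoving distance [Mpc], not a library output
v = volume_from_chi(chi)
# dV/dchi should equal the area of the comoving shell, 4 * pi * chi**2.
dv_dchi = jax.grad(volume_from_chi)(chi)
```

Composing this derivative with the gradient of χ(z) by the chain rule is how a JAX-differentiable volume would propagate gradients to z, h, and omega_m.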

sum_stat.cosmology.critical_surface_density_jax(z_l, z_s, h, omega_m, comoving=True)[source]

Critical surface density Σ_crit [M_sun/pc^2].

Σ_crit = c^2 / (4πG) * D_s / (D_l * D_ls)

where D_l, D_s, and D_ls are the angular diameter distances to the lens, to the source, and between lens and source. If comoving=True, the comoving critical surface density is returned.

Parameters:
  • z_l (jnp.ndarray) – Lens redshift.

  • z_s (jnp.ndarray) – Source redshift.

  • h (float) – Dimensionless Hubble parameter.

  • omega_m (float) – Matter density parameter.

  • comoving (bool) – If True, return comoving Σ_crit (multiply physical by (1+z_l)^2).

Returns:

sigma_crit (jnp.ndarray) – Critical surface density [M_sun/pc^2].

Return type:

Array
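As a rough sketch of the physical (comoving=False) case, the formula can be evaluated from comoving distances using the flat-universe relations D = χ / (1 + z) and D_ls = (χ_s - χ_l) / (1 + z_s). The function name, the constant values, and the input distances below are assumptions for illustration; the comoving rescaling is left out.

```python
import math

import jax.numpy as jnp

C_KM_S = 299792.458  # speed of light [km/s]
G_MPC = 4.30091e-9   # gravitational constant [Mpc (km/s)^2 / M_sun]
# c^2 / (4 pi G) in M_sun / Mpc, then 1e-12 converts Mpc^-2 to pc^-2.
PREFACTOR_PC = C_KM_S**2 / (4.0 * math.pi * G_MPC) * 1e-12


def sigma_crit_physical(chi_l, chi_s, z_l, z_s):
    """Physical Sigma_crit [M_sun/pc^2] from comoving distances [Mpc], flat universe."""
    d_l = chi_l / (1.0 + z_l)             # angular diameter distance to the lens
    d_s = chi_s / (1.0 + z_s)             # angular diameter distance to the source
    d_ls = (chi_s - chi_l) / (1.0 + z_s)  # lens-to-source distance (flat case)
    return PREFACTOR_PC * d_s / (d_l * d_ls)


# Illustrative comoving distances for z_l = 0.3 and z_s = 1.0 (not library output).
sigma = sigma_crit_physical(jnp.asarray(1200.0), jnp.asarray(3300.0), 0.3, 1.0)
```

The expression is only meaningful for z_s > z_l; for sources at or in front of the lens, D_ls is zero or negative and the result should not be used.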