Source code for sum_stat.lensing.shear_calib

"""Shear calibration corrections for galaxy-galaxy lensing.

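The ``apply_*`` corrections are ``jax.jit``-compiled and JAX-differentiable in
their array inputs; the n(z) and ⟨Σ_crit^{-1}⟩ helpers work on NumPy arrays and
return concrete values, so they are not differentiable.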
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np

from ..catalogue import PhotoZCalibTable


@jax.jit
def apply_multiplicative_bias(
    delta_sigma: jnp.ndarray,
    mean_m: float,
) -> jnp.ndarray:
    """Remove a multiplicative shear bias from a ΔΣ profile.

    The measured profile is divided by the calibration factor ``1 + m̄``::

        ΔΣ_corr = ΔΣ / (1 + m̄)

    Parameters
    ----------
    delta_sigma : jnp.ndarray, shape (n_bins,)
        Uncorrected excess surface density [M_sun/pc^2].
    mean_m : float
        Weighted mean multiplicative bias ⟨m⟩ of the source sample.

    Returns
    -------
    delta_sigma_corr : jnp.ndarray, shape (n_bins,)
        Bias-corrected ΔΣ [M_sun/pc^2].
    """
    calibration_factor = 1.0 + mean_m
    return delta_sigma / calibration_factor
@jax.jit
def apply_shear_response(
    e1: jnp.ndarray,
    e2: jnp.ndarray,
    R11: jnp.ndarray,
    R22: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Turn measured ellipticities into shear estimates via the diagonal response.

    Each component is divided by the matching diagonal response element::

        γ_i = e_i / R_ii

    Wherever ``R_ii == 0`` the divisor is replaced by 1, so that entry's
    measured ellipticity is passed through unchanged instead of producing
    inf/NaN. The substitution happens before the division, which keeps the
    gradient finite as well.

    Parameters
    ----------
    e1, e2 : jnp.ndarray, shape (N,)
        Measured ellipticity components.
    R11, R22 : jnp.ndarray, shape (N,)
        Diagonal elements of the per-galaxy shear response matrix.

    Returns
    -------
    gamma1, gamma2 : jnp.ndarray, shape (N,)
        Response-corrected shear estimates.
    """

    def _nonzero_divisor(response):
        # Zero responses are swapped for 1 so the division is always defined.
        return jnp.where(response == 0, 1.0, response)

    gamma1 = e1 / _nonzero_divisor(R11)
    gamma2 = e2 / _nonzero_divisor(R22)
    return gamma1, gamma2
def nz_from_calibration_table(
    calib: PhotoZCalibTable,
    z_bins: np.ndarray,
    lens_z: float | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Build a normalised source redshift distribution from a calibration table.

    A weighted histogram of ``calib.z_true`` is formed over ``z_bins``, using
    the weights ``calib.w * calib.w_sys``. It is then divided by its integral
    ``sum(counts * dz)``. When ``lens_z`` is given, only calibration objects
    strictly behind the lens (``z_true > lens_z``) are kept, so that objects
    in front of the lens do not enter the source n(z).

    Parameters
    ----------
    calib : PhotoZCalibTable
        Photo-z calibration table. Its ``n``, ``z_true``, ``w`` and ``w_sys``
        attributes are read.
    z_bins : ndarray, shape (n_bins+1,)
        Redshift bin edges for the output n(z).
    lens_z : float, optional
        Lens redshift. If given, objects with ``z_true <= lens_z`` are dropped.

    Returns
    -------
    z_centres : ndarray, shape (n_bins,)
        Mid-points of the redshift bins.
    n_z : ndarray, shape (n_bins,)
        Weighted n(z), normalised so that ``sum(n_z * dz) == 1``. If the
        weighted integral is not positive, the raw weighted counts are
        returned unnormalised.
    """
    edges = np.asarray(z_bins, dtype=np.float64)

    keep = np.ones(calib.n, dtype=bool)
    if lens_z is not None:
        keep &= calib.z_true > lens_z

    weights = (calib.w * calib.w_sys)[keep]
    hist, _ = np.histogram(calib.z_true[keep], bins=edges, weights=weights)

    widths = np.diff(edges)
    centres = 0.5 * (edges[:-1] + edges[1:])

    integral = np.sum(hist * widths)
    # Guard against an empty (or non-positive) weighted sample.
    n_z = hist / integral if integral > 0 else hist
    return centres, n_z
def effective_sigma_crit_inv(
    z_l: float,
    z_s_grid: np.ndarray,
    n_z: np.ndarray,
    h: float,
    omega_m: float,
    comoving: bool = True,
    z_buffer: float = 0.01,
) -> float:
    """Effective inverse critical surface density ⟨Σ_crit^{-1}⟩.

    ⟨Σ_crit^{-1}⟩ = ∫ Σ_crit^{-1}(z_l, z_s) n(z_s) dz_s

    Only source-grid points with ``z_s > z_l + z_buffer`` contribute; sources
    at or in front of the lens get zero weight. The integral is evaluated as
    a Riemann sum. Its quadrature weights are the cell widths
    ``np.gradient(z_s_grid)`` of the *full* source grid, so the result is well
    defined even when only a single grid point lies behind the lens.

    Parameters
    ----------
    z_l : float
        Lens redshift.
    z_s_grid : array_like, shape (n_z,)
        Source redshift grid.
    n_z : array_like, shape (n_z,)
        Source n(z) sampled on ``z_s_grid``, normalised to unit integral over
        z (units dz^-1).
    h, omega_m : float
        Cosmological parameters for distance computation.
    comoving : bool
        Use comoving Σ_crit if True.
    z_buffer : float, optional
        Minimum lens–source redshift separation for a grid point to
        contribute. The default of 0.01 matches the historical behaviour.

    Returns
    -------
    sigma_crit_inv_eff : float
        ⟨Σ_crit^{-1}⟩ [pc^2/M_sun]. This is 0.0 when no grid point lies
        behind the lens.

    Raises
    ------
    ValueError
        If a grid point lies behind the lens but ``z_s_grid`` has fewer than
        two points, so no integration step can be defined.
    """
    z_s_grid = np.asarray(z_s_grid, dtype=float)
    n_z = np.asarray(n_z, dtype=float)

    behind = z_s_grid > z_l + z_buffer
    if not np.any(behind):
        return 0.0
    if z_s_grid.size < 2:
        raise ValueError(
            "z_s_grid must contain at least two points to define the "
            "integration step dz."
        )

    # Local project import, deferred until there is something to integrate.
    from ..cosmology import critical_surface_density_jax

    # Quadrature widths come from the full grid. Differencing only the
    # truncated grid fails when a single point survives the cut.
    dz = jnp.asarray(np.gradient(z_s_grid)[behind])
    z_s = jnp.asarray(z_s_grid[behind])
    n_z_s = jnp.asarray(n_z[behind])

    sigma_crit = critical_surface_density_jax(
        jnp.full_like(z_s, z_l), z_s, h, omega_m, comoving
    )
    # A non-positive Σ_crit is mapped to inf, i.e. zero lensing efficiency.
    sigma_crit_inv = 1.0 / jnp.where(sigma_crit > 0, sigma_crit, jnp.inf)
    return float(jnp.sum(sigma_crit_inv * n_z_s * dz))