"""Shear calibration corrections for galaxy-galaxy lensing.
All correction functions are JAX-differentiable where inputs are arrays.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
import numpy as np
from ..catalogue import PhotoZCalibTable
@jax.jit
def apply_multiplicative_bias(
    delta_sigma: jnp.ndarray,
    mean_m: float,
) -> jnp.ndarray:
    """Remove a multiplicative shear bias from a ΔΣ profile.

    The profile is rescaled as

        ΔΣ_corr = ΔΣ / (1 + m̄)

    Parameters
    ----------
    delta_sigma : jnp.ndarray, shape (n_bins,)
        Raw excess surface density [M_sun/pc^2].
    mean_m : float
        Weighted mean multiplicative bias ⟨m⟩.

    Returns
    -------
    delta_sigma_corr : jnp.ndarray, shape (n_bins,)
        Bias-corrected ΔΣ [M_sun/pc^2].
    """
    response = 1.0 + mean_m
    return delta_sigma / response
@jax.jit
def apply_shear_response(
    e1: jnp.ndarray,
    e2: jnp.ndarray,
    R11: jnp.ndarray,
    R22: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Divide each ellipticity component by its diagonal shear response.

        γ_i = e_i / R_ii

    Where a response element is exactly zero the divisor is replaced by 1,
    so that component is passed through unchanged instead of producing inf/NaN.

    Parameters
    ----------
    e1, e2 : jnp.ndarray, shape (N,)
        Measured ellipticity components.
    R11, R22 : jnp.ndarray, shape (N,)
        Diagonal elements of the per-galaxy shear response matrix.

    Returns
    -------
    gamma1, gamma2 : jnp.ndarray, shape (N,)
        Response-corrected shear estimates.
    """

    def _nonzero(response):
        # Swap exact zeros for 1.0 so the division below stays finite.
        return jnp.where(response != 0, response, 1.0)

    gamma1 = e1 / _nonzero(R11)
    gamma2 = e2 / _nonzero(R22)
    return gamma1, gamma2
def nz_from_calibration_table(
    calib: PhotoZCalibTable,
    z_bins: np.ndarray,
    lens_z: float | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Build a normalised source n(z) histogram from a photo-z calibration table.

    Each calibration object is binned by its true redshift with weight
    ``w * w_sys``. When ``lens_z`` is given, only objects with
    ``z_true > lens_z`` are kept, so galaxies at or in front of the lens do not
    contaminate the source distribution.

    Parameters
    ----------
    calib : PhotoZCalibTable
        Photo-z calibration table; its ``n``, ``z_true``, ``w`` and ``w_sys``
        attributes are read.
    z_bins : array_like, shape (n_bins+1,)
        Redshift bin edges for the output n(z).
    lens_z : float, optional
        If provided, only include calibration objects with z_true > lens_z.

    Returns
    -------
    z_centres : ndarray, shape (n_bins,)
        Midpoints of the redshift bins.
    n_z : ndarray, shape (n_bins,)
        Weighted counts divided by ``sum(counts * dz)`` so that n(z) integrates
        to 1 over ``z_bins``. Objects outside the bin range are not counted.
        If that integral is not positive, the raw weighted counts are returned.
    """
    selected = np.ones(calib.n, dtype=bool)
    if lens_z is not None:
        selected &= calib.z_true > lens_z

    edges = np.asarray(z_bins, dtype=np.float64)
    weights = calib.w * calib.w_sys
    counts, _ = np.histogram(calib.z_true[selected], bins=edges, weights=weights[selected])

    widths = np.diff(edges)
    z_centres = 0.5 * (edges[:-1] + edges[1:])

    integral = np.sum(counts * widths)
    n_z = counts / integral if integral > 0 else counts
    return z_centres, n_z
def effective_sigma_crit_inv(
    z_l: float,
    z_s_grid: np.ndarray,
    n_z: np.ndarray,
    h: float,
    omega_m: float,
    comoving: bool = True,
    *,
    z_buffer: float = 0.01,
) -> float:
    """Effective inverse critical surface density ⟨Σ_crit^{-1}⟩.

    ⟨Σ_crit^{-1}⟩ = ∫ Σ_crit^{-1}(z_l, z_s) n(z_s) dz_s

    The integral is approximated by a sum over the grid nodes with
    ``z_s > z_l + z_buffer``, each weighted by its local grid spacing
    (``jnp.gradient`` of the retained nodes).

    Parameters
    ----------
    z_l : float
        Lens redshift.
    z_s_grid : array_like, shape (n_z,)
        Source redshift grid.
    n_z : array_like, shape (n_z,)
        Source redshift distribution per unit redshift, sampled on
        ``z_s_grid``. It is used as given (not renormalised here).
    h, omega_m : float
        Cosmological parameters for distance computation.
    comoving : bool
        Use comoving Σ_crit if True.
    z_buffer : float, keyword-only
        Minimum separation ``z_s - z_l`` for a grid node to be included
        (default 0.01). Excludes sources at or just behind the lens.

    Returns
    -------
    sigma_crit_inv_eff : float
        ⟨Σ_crit^{-1}⟩ [pc^2/M_sun]; 0.0 when no grid node lies behind the lens.

    Raises
    ------
    ValueError
        If ``z_s_grid`` has a single point and it lies behind the lens, so no
        redshift spacing can be defined.
    """
    z_grid = np.asarray(z_s_grid, dtype=np.float64)
    n_grid = np.asarray(n_z)

    # Evaluate the lens cut once and reuse it for both arrays.
    behind = z_grid > z_l + z_buffer
    if not np.any(behind):
        return 0.0

    z_src = z_grid[behind]
    if z_src.size >= 2:
        dz = jnp.gradient(jnp.asarray(z_src))
    elif z_grid.size >= 2:
        # Only one node survives the cut: jnp.gradient needs >= 2 points, so
        # take that node's spacing from the full grid instead of crashing.
        dz = jnp.asarray(np.gradient(z_grid)[behind])
    else:
        raise ValueError("z_s_grid must contain at least 2 points to define dz")

    # Deferred so the early-exit paths above do not need the cosmology module.
    from ..cosmology import critical_surface_density_jax

    z_s = jnp.asarray(z_src)
    sigma_crit = critical_surface_density_jax(jnp.full_like(z_s, z_l), z_s, h, omega_m, comoving)
    # Nodes with non-positive Σ_crit contribute nothing (1 / inf == 0).
    sigma_crit_inv = 1.0 / jnp.where(sigma_crit > 0, sigma_crit, jnp.inf)
    return float(jnp.sum(sigma_crit_inv * jnp.asarray(n_grid[behind]) * dz))