Installation

Requirements

Create and activate the environment

# from the root of the cloned sum_stat repository
mamba env create -f environment.yml
mamba activate sum_stat

Install the package in editable mode

pip install -e ".[dev,docs]"

Verify the installation

import sum_stat
print(sum_stat.__version__)

import jax
print(jax.devices())  # lists CPU devices, or GPU devices if a CUDA-enabled jax is installed
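The two checks above can also be combined into one script, for example to run at the start of a batch job. The sketch below uses a hypothetical helper, `check_install`, which is not part of sum_stat. It reports each package's `__version__`, or None when the import fails, and lists JAX devices only when JAX imports cleanly.

```python
import importlib


def check_install(packages=("sum_stat", "jax")):
    """Return {package name: version string, or None if the import fails}."""
    report = {}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except ImportError:  # also catches ModuleNotFoundError
            report[name] = None
            continue
        # Some modules define no __version__; report "unknown" rather than failing
        report[name] = getattr(module, "__version__", "unknown")
    return report


if __name__ == "__main__":
    report = check_install()
    for name, version in report.items():
        print(f"{name}: {version if version else 'NOT INSTALLED'}")
    if report.get("jax"):
        import jax
        print("devices:", jax.devices())
```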

Build the documentation

cd docs && make html
# Open docs/_build/html/index.html

Run the tests

pytest tests/ -v

# Benchmarks only (prints timing per function):
pytest tests/test_benchmarks.py -v -s
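Per-function timings of JAX code depend on two details. JAX dispatches work asynchronously, so each result must be waited on (`block_until_ready()`) before the clock stops. The first call of a jitted function also includes compilation. Below is a minimal, stdlib-only sketch of a timer that handles both. `time_function` is a hypothetical name and is not taken from tests/test_benchmarks.py.

```python
import statistics
import time


def time_function(fn, *args, repeats=5, warmup=1, sync=None):
    """Return the median wall-clock time in seconds of fn(*args).

    warmup: untimed calls made first; for jitted JAX code these absorb compilation.
    sync:   optional callable applied to each result to wait until it is ready,
            e.g. lambda out: out.block_until_ready() for a JAX array.
    """
    for _ in range(warmup):
        out = fn(*args)
        if sync is not None:
            sync(out)

    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(*args)
        if sync is not None:
            sync(out)  # stop the clock only once the work has actually finished
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


if __name__ == "__main__":
    print(f"sum over 1e6 ints: {time_function(sum, range(1_000_000)):.4f} s")
```

With JAX, pass `sync=lambda out: out.block_until_ready()` and keep `warmup >= 1` so that compilation time is excluded. The median is used instead of the mean to damp outliers caused by other processes.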

GPU / JAX configuration

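The default jax installation is CPU-only. To run on an NVIDIA GPU, install the CUDA-enabled build into the activated environment. This requires an NVIDIA driver recent enough for CUDA 12:

pip install --upgrade "jax[cuda12]"   # see the JAX installation guide for other CUDA versions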
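After installing the CUDA build, you can confirm that JAX actually picked up the GPU by querying its default backend. The minimal sketch below uses a hypothetical helper, `describe_backend`. `jax.default_backend()` and `jax.devices()` are standard JAX functions.

```python
import jax


def describe_backend():
    """Return (default backend name, device descriptions) for this JAX install."""
    backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    devices = [str(d) for d in jax.devices()]
    return backend, devices


backend, devices = describe_backend()
print("backend:", backend)
print("devices:", devices)
if backend == "cpu":
    print("No accelerator found; check that the CUDA-enabled jax build is installed.")
```

To compare results against the CPU on a GPU machine, set the `JAX_PLATFORMS=cpu` environment variable before starting Python. This restricts JAX to the CPU backend.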

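The package sets jax_enable_x64 automatically on import, so JAX floating-point arrays default to 64-bit (float64) precision. Import sum_stat before creating any JAX arrays, because arrays created earlier keep their 32-bit dtype.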
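You can check the effect of the flag directly. This sketch sets `jax_enable_x64` by hand through `jax.config.update` so that it runs without sum_stat installed. In a script that imports sum_stat first, the explicit update should be unnecessary, assuming the package behaves as described above.

```python
import jax
import jax.numpy as jnp

# Stand-in for what sum_stat is described as doing on import
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)  # float64 with x64 enabled (float32 without it)

# An explicitly requested dtype is still honoured
y = jnp.ones(3, dtype=jnp.float32)
print(y.dtype)  # float32
```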

Remote / HPC cluster setup

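Use conda-pack to ship the already-solved environment to a cluster, so the dependency solver never has to run remotely. The packed environment only works on a machine with the same operating system and CPU architecture (e.g. linux-64 on both ends).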

Step 1 — pack the environment locally:

mamba activate sum_stat
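# conda-pack may refuse environments with editable (pip install -e) packages; this flag
# skips that check, and sum_stat itself is reinstalled on the remote side in Step 3
conda-pack -o sum_stat_env.tar.gz --ignore-editable-packages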

Step 2 — transfer and unpack on the remote server:

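# $SCRATCH would be expanded by the local shell, so read the remote value first
# (if it prints nothing, set REMOTE_SCRATCH to the remote scratch path by hand)
REMOTE_SCRATCH=$(ssh remote-server 'echo $SCRATCH')
scp sum_stat_env.tar.gz "remote-server:$REMOTE_SCRATCH/"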
ssh remote-server
mkdir -p $SCRATCH/sum_stat_env
tar -xzf $SCRATCH/sum_stat_env.tar.gz -C $SCRATCH/sum_stat_env
source $SCRATCH/sum_stat_env/bin/activate
conda-unpack   # rewrite the local machine's paths baked into scripts (e.g. pip's shebang)

Step 3 — clone the repository and install in editable mode:

git clone https://github.com/JohanComparat/sum_stat.git $SCRATCH/sum_stat
pip install -e $SCRATCH/sum_stat --no-build-isolation

Step 4 — verify:

python -c "import sum_stat; print(sum_stat.__version__)"
python -c "import jax; print(jax.devices())"
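On clusters, a module system or login profile can put a different python ahead of the unpacked environment on PATH. A small extra check is to confirm that the running interpreter lives under `$SCRATCH/sum_stat_env`, assuming the unpack location from Step 2. `env_is_active` is a hypothetical helper. Run the script with the environment's python.

```python
import os
import sys


def env_is_active(expected_prefix):
    """Return True if the running interpreter's prefix is expected_prefix or lies inside it."""
    running = os.path.realpath(sys.prefix)
    expected = os.path.realpath(os.path.expandvars(expected_prefix))
    try:
        return os.path.commonpath([running, expected]) == expected
    except ValueError:  # e.g. paths on different drives on Windows
        return False


if __name__ == "__main__":
    prefix = "$SCRATCH/sum_stat_env"  # unpack location used in Step 2
    status = "OK" if env_is_active(prefix) else "not the unpacked environment"
    print(sys.prefix, "->", status)
```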

See "Running the measurement scripts" for the full script reference.