Installation
Requirements
Python ≥ 3.11
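A quick way to confirm that the interpreter meets this requirement before creating anything else. This is a minimal standard-library sketch; `MIN_PYTHON` and `python_is_supported` are illustrative names, not part of sum_stat.

```python
import sys

# Minimum version stated in the requirements above.
MIN_PYTHON = (3, 11)


def python_is_supported(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the running (or given) interpreter is new enough."""
    return tuple(version_info[:2]) >= minimum


if __name__ == "__main__":
    found = ".".join(map(str, sys.version_info[:3]))
    status = "OK" if python_is_supported() else "too old"
    print(f"Python {found}: {status} (need >= 3.11)")
```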
Create and activate the environment
mamba env create -f environment.yml
mamba activate sum_stat
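To check that `python` now resolves to the new environment rather than a base or system interpreter, you can inspect `sys.prefix`. Named conda/mamba environments live in a directory such as `.../envs/sum_stat`, so its last path component is the environment name. The sketch below assumes that layout; `active_env_name` is a hypothetical helper.

```python
import sys
from pathlib import Path


def active_env_name(prefix=sys.prefix):
    """Return the directory name of the environment the interpreter runs from."""
    return Path(prefix).name


if __name__ == "__main__":
    name = active_env_name()
    print("Interpreter prefix:", sys.prefix)
    if name != "sum_stat":
        print(f"Warning: expected the 'sum_stat' environment, found '{name}'")
```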
Install the package in editable mode
pip install -e ".[dev,docs]"
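pip records how a package was installed in a `direct_url.json` metadata file (PEP 610). For an editable install that file contains `"dir_info": {"editable": true}`. The following minimal sketch reads it with `importlib.metadata`. It assumes the distribution is registered under the name `sum_stat`, and `is_editable` is an illustrative helper.

```python
import json
from importlib.metadata import PackageNotFoundError, distribution


def is_editable(dist_name):
    """Return True/False for an installed distribution, None if not installed.

    Relies on the PEP 610 ``direct_url.json`` file written by pip; packages
    installed from an index have no such file and are reported as False.
    """
    try:
        dist = distribution(dist_name)
    except PackageNotFoundError:
        return None
    raw = dist.read_text("direct_url.json")
    if raw is None:
        return False
    return bool(json.loads(raw).get("dir_info", {}).get("editable", False))


if __name__ == "__main__":
    print("sum_stat editable install:", is_editable("sum_stat"))
```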
Verify the installation
import sum_stat
print(sum_stat.__version__)
import jax
print(jax.devices())  # should list at least one CPU or GPU device
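In batch jobs or CI an interactive session is inconvenient, so the same check can be written to report every module instead of stopping at the first failed import. This is a standard-library sketch; `report_modules` is a hypothetical helper, not part of sum_stat.

```python
import importlib


def report_modules(names=("sum_stat", "jax")):
    """Try to import each module and return {name: version or error text}."""
    report = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except ImportError as exc:  # also catches ModuleNotFoundError
            report[name] = f"NOT IMPORTABLE ({exc})"
        else:
            report[name] = getattr(module, "__version__", "imported (no __version__)")
    return report


if __name__ == "__main__":
    for name, status in report_modules().items():
        print(f"{name:10s} {status}")
```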
Build the documentation
cd docs && make html
# Open docs/_build/html/index.html
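To open the built pages from a script, you can use the standard `webbrowser` module. This sketch assumes it is run from the repository root, and `docs_index_path` is an illustrative helper. The browser is only launched if the docs have actually been built.

```python
import webbrowser
from pathlib import Path


def docs_index_path(repo_root="."):
    """Return the absolute path of the built HTML index under ``repo_root``."""
    return Path(repo_root).resolve() / "docs" / "_build" / "html" / "index.html"


if __name__ == "__main__":
    index = docs_index_path()
    if index.exists():
        webbrowser.open(index.as_uri())  # file:// URI for the local page
    else:
        print("Docs not built yet; run `make html` in docs/ first:", index)
```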
Run the tests
pytest tests/ -v
# Benchmarks only (prints timing per function):
pytest tests/test_benchmarks.py -v -s
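The `-s` flag turns off pytest's output capturing, which is why the timing lines become visible. A benchmark test of that kind might look like the hypothetical sketch below; it does not reproduce the repository's tests/test_benchmarks.py. JAX dispatches work asynchronously, so a fair timing waits for the result with `block_until_ready()` before stopping the clock.

```python
import time


def best_time(func, *args, repeat=5):
    """Return the fastest wall-clock time (seconds) over ``repeat`` calls.

    Taking the minimum also discards the first call, which for a jitted
    JAX function includes compilation time.
    """
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        out = func(*args)
        # JAX arrays are computed asynchronously; wait for the result so the
        # timing covers the computation, not just the dispatch.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


def test_sum_timing():
    # Placeholder workload; a real benchmark would call a sum_stat function.
    elapsed = best_time(sum, range(100_000))
    print(f"sum: {elapsed * 1e3:.3f} ms")  # shown only with `pytest -s`
    assert elapsed >= 0.0
```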
GPU / JAX configuration
Unless a CUDA-enabled jaxlib is installed, JAX runs on the CPU. To use an NVIDIA GPU, install the CUDA build of JAX into the active environment:
pip install --upgrade "jax[cuda12]"  # CUDA 12; cuda11 wheels exist only for older JAX releases
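After installing the CUDA build, `jax.default_backend()` should report `gpu`. If it still reports `cpu`, JAX did not find a usable GPU and fell back to the CPU (recent versions usually print a warning when this happens). This sketch also handles an environment where JAX is missing; `jax_backend` is an illustrative helper.

```python
def jax_backend():
    """Return JAX's default backend ('cpu', 'gpu', ...) or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("JAX is not installed in this environment")
    else:
        print("Default JAX backend:", backend)
```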
The package enables jax_enable_x64 automatically on import, so JAX's default
floating-point dtype becomes float64 (64-bit precision) instead of float32.
Arrays created with an explicit float32 dtype remain 32-bit.
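The sketch below shows what the flag changes. It assumes sum_stat uses the standard `jax.config.update` switch, since the text only says that the flag is enabled on import. The function is illustrative and returns None when JAX is not installed.

```python
def float_dtypes_with_x64():
    """Enable x64 and return (default dtype, explicit float32 dtype) as strings."""
    try:
        import jax
    except ImportError:
        return None
    # Standard JAX switch; best set before any arrays are created.
    jax.config.update("jax_enable_x64", True)
    import jax.numpy as jnp

    default = jnp.zeros(3).dtype  # float64 once x64 is enabled
    explicit = jnp.zeros(3, dtype=jnp.float32).dtype  # explicit dtypes are kept
    return str(default), str(explicit)


if __name__ == "__main__":
    print(float_dtypes_with_x64())
```

Arrays created before the flag is switched on keep their 32-bit dtype, so it is safest to import sum_stat before building any JAX arrays.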
Remote / HPC cluster setup
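Use conda-pack (installable from conda-forge) to ship the already-solved
environment to a cluster instead of running the solver remotely. The cluster
must use the same operating system and CPU architecture as the machine where
the environment is packed.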
Step 1 — pack the environment locally:
mamba activate sum_stat
conda-pack -o sum_stat_env.tar.gz --ignore-editable-packages  # sum_stat itself is reinstalled remotely in Step 3
Step 2 — transfer and unpack on the remote server:
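scp sum_stat_env.tar.gz remote-server:  # lands in the remote home directory; a local $SCRATCH would be expanded by your own shell
ssh remote-server
mv ~/sum_stat_env.tar.gz $SCRATCH/
mkdir -p $SCRATCH/sum_stat_env
tar -xzf $SCRATCH/sum_stat_env.tar.gz -C $SCRATCH/sum_stat_env
source $SCRATCH/sum_stat_env/bin/activate
conda-unpack  # rewrites hard-coded paths for the new location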
Step 3 — clone the repository and install in editable mode:
git clone https://github.com/JohanComparat/sum_stat.git $SCRATCH/sum_stat
pip install -e $SCRATCH/sum_stat --no-build-isolation
Step 4 — verify:
python -c "import sum_stat; print(sum_stat.__version__)"
python -c "import jax; print(jax.devices())"
See Running the measurement scripts for the full script reference.