`pymc` is great, but with large models on hardware with limited memory it runs into a hard limit: it requires all samples to fit in the GPU's memory at once. I have patched `pymc` so that this is no longer necessary. This post has two aims: to remind me of what I did, and to document how to use the resulting code, i.e. how to install a locally patched Python package.
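The patch itself is not shown here, so the following is only a minimal sketch of the general idea, not the actual change to `pymc`. Since the post installs `jaxlib`, it assumes JAX-based sampling. The chain is advanced in fixed-size chunks, and each chunk of draws is copied to host RAM with `jax.device_get` before the next one is produced. GPU memory then holds roughly one chunk plus the sampler state instead of the whole trace. `run_chunk` (a Gaussian random walk standing in for a real sampler), `sample_offloaded` and `chunk_size` are hypothetical names chosen for illustration.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


@functools.partial(jax.jit, static_argnums=2)
def run_chunk(key, state, chunk_size):
    """Hypothetical stand-in for a sampler: a Gaussian random walk.

    Returns the final state (to continue the chain) and the chunk of draws,
    both still on the default device (the GPU, if one is available).
    """
    def step(x, k):
        x_new = x + 0.1 * jax.random.normal(k, x.shape)
        return x_new, x_new

    keys = jax.random.split(key, chunk_size)
    final_state, draws = jax.lax.scan(step, state, keys)
    return final_state, draws


def sample_offloaded(seed, init_state, num_samples, chunk_size):
    """Collect num_samples draws while keeping only one chunk on the device."""
    key = jax.random.PRNGKey(seed)
    state = jnp.asarray(init_state)
    host_chunks = []
    remaining = num_samples
    while remaining > 0:
        n = min(chunk_size, remaining)
        key, subkey = jax.random.split(key)
        state, draws = run_chunk(subkey, state, n)
        host_chunks.append(jax.device_get(draws))  # copy this chunk to host RAM
        del draws  # drop the device reference so the buffer can be reused
        remaining -= n
    return np.concatenate(host_chunks, axis=0)


samples = sample_offloaded(0, np.zeros(3, dtype=np.float32),
                           num_samples=10_000, chunk_size=1_000)
print(samples.shape)
```

The trade-off is that the full trace must now fit in host RAM, which is usually much larger than GPU memory. The extra device-to-host copies cost some time per chunk.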
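It is actually rather simple. `pip install` accepts a path to a local wheel file or to a local source directory, so the `jaxlib` wheel in `~/src/jax/dist` and the patched `pymc` checkout are installed with:

```sh
pip install ~/src/jax/dist/jaxlib-0.4.9-cp310-cp310-manylinux2014_x86_64.whl
pip install ~/src/pymc
```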
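To confirm that the environment really uses the local builds and not copies from PyPI, one option is to read the `direct_url.json` metadata file. Recent versions of pip write this file (specified in PEP 610) when a package is installed from a URL or local path, and not when it comes from an index. The helper `install_source` below is a hypothetical name; it uses only the standard library's `importlib.metadata`.

```python
import json
from importlib import metadata


def install_source(dist_name):
    """Return the URL pip recorded for a direct (e.g. local path) install, or None.

    None means the package is not installed, or it was installed from an
    index such as PyPI (no direct_url.json is written in that case).
    """
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None
    raw = dist.read_text("direct_url.json")
    if raw is None:
        return None
    return json.loads(raw).get("url")


for name in ("pymc", "jaxlib"):
    print(name, install_source(name))
```

After the two `pip install` commands above, both names should print a `file://` URL pointing into `~/src`. A `None` suggests that an index version has shadowed the local build.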