`pymc` is great, but with large models on hardware with limited RAM it could do better: it requires all samples to fit in the GPU's RAM. I have patched `pymc` so that this is no longer required. This post is partly a reminder to myself of what I did, and partly documentation of how to use the resulting code, i.e. how to install a locally patched python package.
To have your own version conveniently accessible, create a fork of the original repository on your GitHub account. Creating a fork is easy: just click "Fork" and then "+ Create a new fork".
I already had a directory named pymc (a direct clone of upstream, which I do not recommend), so I used a custom name for the directory holding the local repository: pymc-lowmem.
cd ~/src
git clone git@github.com:hans-ekbrand/pymc.git pymc-lowmem
cd pymc-lowmem
git branch batching_samples
git checkout batching_samples
I edited pymc/sampling/jax.py so that the samples are handled in batches instead of all at once.
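The details are in the batching_samples branch; what follows is only a rough sketch of the idea, not the actual diff. The function name postprocess_in_batches, the n_batches parameter, and the assumption that jax_fn maps a single raw draw to a single array are made up for illustration. The point is that transformed draws are produced one chunk at a time and copied to host memory, so the full set of transformed samples never has to sit on the GPU at once.

import numpy as np
import jax

def postprocess_in_batches(raw_samples, jax_fn, n_batches=10):
    # Hypothetical sketch: transform the raw draws in chunks on the device and
    # copy each transformed chunk to ordinary host (CPU) memory right away.
    host_chunks = []
    for idx in np.array_split(np.arange(raw_samples.shape[0]), n_batches):
        chunk = jax.vmap(jax_fn)(raw_samples[idx])   # only this chunk lives on the GPU
        host_chunks.append(np.asarray(chunk))        # np.asarray pulls the result off the device
    return np.concatenate(host_chunks, axis=0)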
To use the patched code, see the next section on how to install a locally hacked pymc. Installing a locally built jaxlib from its wheel works the same way, for example:
pip install ~/src/jax/dist/jaxlib-0.4.9-cp310-cp310-manylinux2014_x86_64.whl
Installing a locally hacked pymc
It is actually rather simple:
pip install ~/src/pymc-lowmem
If you want to install for a specific version of python, run that version with -m pip:
python3.10 -m pip install ~/src/pymc-lowmem
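If you expect to keep editing the source, pip's editable mode (a standard pip feature, nothing specific to this patch) saves you from reinstalling after every change:
python3.10 -m pip install -e ~/src/pymc-lowmem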