Hacking pymc/jax

`pymc` is great, but with large models on hardware with limited RAM it could do better. The problem is that it requires all samples to fit in the RAM of the GPU. I have patched `pymc` so that this is no longer required. This post has two aims: to remind me of what I did, and to document how to use the resulting code, i.e. how to install a locally patched Python package.

Create a fork on github

To have your version conveniently accessible, fork the original repository on GitHub.

Creating a fork is easy: just click "Fork" and then "+ Create a new fork".

Clone the fork to your local computer

I already had a directory for pymc (a direct clone of upstream, which is not recommended), so I used a custom name, "pymc-lowmem", for the directory holding the local repository:

cd ~/src
git clone git@github.com:hans-ekbrand/pymc.git pymc-lowmem

Create new branch for your work

cd pymc-lowmem
git branch batching_samples
git checkout batching_samples

Modify the code

I edited pymc/sampling/jax.py
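The patch itself is not reproduced here. As a rough illustration of the general idea the branch name points at, the sketch below draws samples in batches and moves each batch to host memory before drawing the next, so only one batch has to fit on the device at a time. This is not the code in pymc/sampling/jax.py: `batch_bounds`, `sample_in_batches` and `draw_batch` are hypothetical names made up for this example, plain Python lists stand in for device arrays, and the sketch ignores the sampler state that a real MCMC run would have to carry from one batch to the next.

```python
from typing import Callable, Iterator, List, Tuple


def batch_bounds(n_draws: int, batch_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, stop) index pairs that cover n_draws in batches."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, n_draws, batch_size):
        yield start, min(start + batch_size, n_draws)


def sample_in_batches(
    draw_batch: Callable[[int], List[float]],
    n_draws: int,
    batch_size: int,
) -> List[float]:
    """Collect n_draws samples while holding only one batch at a time.

    draw_batch is a hypothetical stand-in for a sampler call that runs on
    the GPU and returns the requested number of draws.
    """
    host_samples: List[float] = []
    for start, stop in batch_bounds(n_draws, batch_size):
        batch = draw_batch(stop - start)
        # Stand-in for copying the batch to host memory, after which the
        # device copy can be released.
        host_samples.extend(batch)
    return host_samples
```

With 10 draws and a batch size of 4, the sampler stand-in is called three times, for 4, 4 and 2 draws.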

Testing the modified code

See "Install a locally hacked pymc" below.

Installing a locally patched python package

Install a locally compiled jaxlib

pip install ~/src/jax/dist/jaxlib-0.4.9-cp310-cp310-manylinux2014_x86_64.whl

Install a locally hacked pymc

It is actually rather simple:

pip install ~/src/pymc-lowmem

If you want to install for a specific version of Python, run that version's interpreter with -m pip:

python3.10 -m pip install ~/src/pymc-lowmem
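After installing, it can be worth checking which copy of the package the interpreter actually imports. A minimal sketch using only the standard library; the helper name `describe_install` is made up for this example:

```python
import importlib.metadata
import importlib.util
from typing import Dict, Optional


def describe_install(package: str) -> Dict[str, Optional[str]]:
    """Return the installed version and import location of a package.

    Hypothetical helper for illustration; values are None when unknown.
    """
    try:
        version = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        version = None
    # find_spec locates the package without importing it.
    spec = importlib.util.find_spec(package)
    location = spec.origin if spec is not None else None
    return {"version": version, "location": location}


if __name__ == "__main__":
    print(describe_install("pymc"))
```

Run it with the same interpreter that was used for the install (python3.10 in the example above). Note that a plain pip install copies the package into site-packages, so the reported location will be there rather than under ~/src/pymc-lowmem.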

Last modified: February 11, 2024