Hacking pymc/jax

`pymc` is great, but with large models on hardware with limited memory it falls short: it requires all samples to fit in the RAM of the GPU at once. I have patched `pymc` so that this is no longer required. This post has two aims: to remind me of what I did, and to document how to use the resulting code, i.e. how to install a locally patched Python package.
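
To get a feeling for when the samples stop fitting, a back-of-the-envelope estimate of the trace size helps. This is a minimal sketch, assuming every chain keeps every draw of every scalar parameter as float32 (4 bytes) and ignoring sampler statistics and other overhead; the helper `trace_bytes` and the model size in the example are hypothetical.

```python
def trace_bytes(n_chains: int, n_draws: int, n_params: int, bytes_per_value: int = 4) -> int:
    """Rough size of a posterior trace held in memory all at once.

    Assumes one value per chain, per draw, per scalar parameter, stored as
    float32 (4 bytes) by default. Sampler statistics are not counted.
    """
    return n_chains * n_draws * n_params * bytes_per_value


# Hypothetical model: 4 chains, 2000 draws each, one million scalar parameters.
size = trace_bytes(n_chains=4, n_draws=2000, n_params=1_000_000)
print(f"{size / 2**30:.1f} GiB")  # prints "29.8 GiB"
```

With numbers like these the trace alone exceeds the memory of many GPUs, before the model itself has claimed any.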

Installing a locally patched Python package

Install a locally compiled jaxlib

pip install ~/src/jax/dist/jaxlib-0.4.9-cp310-cp310-manylinux2014_x86_64.whl
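
The wheel's file name encodes which interpreter it was built for (`cp310` means CPython 3.10), so it only installs into a matching Python. A small sketch that splits the name into its fields following the wheel naming convention and compares the Python tag with the running interpreter; the function names are hypothetical.

```python
import sys
from pathlib import Path


def parse_wheel_name(filename):
    """Split a wheel file name into name, version, python, abi and platform tags."""
    stem = Path(filename).name.removesuffix(".whl")
    parts = stem.split("-")
    if len(parts) == 6:  # an optional build tag sits in position 2; drop it
        parts.pop(2)
    name, ver, py_tag, abi_tag, platform_tag = parts
    return {"name": name, "version": ver, "python": py_tag,
            "abi": abi_tag, "platform": platform_tag}


def matches_interpreter(filename):
    """True if the wheel's CPython tag equals the running interpreter's."""
    tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    return parse_wheel_name(filename)["python"] == tag


wheel = "jaxlib-0.4.9-cp310-cp310-manylinux2014_x86_64.whl"
print(parse_wheel_name(wheel))
print(matches_interpreter(wheel))
```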

Install a locally hacked pymc

It is actually rather simple:

pip install ~/src/pymc
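
To confirm which builds ended up in the environment, one can print the installed versions of the relevant packages. This is a minimal sketch using the standard-library `importlib.metadata`; the helper `installed_version` is hypothetical, and it returns `None` for a package that is not installed instead of raising.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


for name in ("jaxlib", "jax", "pymc"):
    print(name, installed_version(name))
```

For the jaxlib wheel above, the expected version is 0.4.9.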

Last modified: May 9, 2023