feat: Add Microcanonical Langevin Monte Carlo (MCLMC) kernel#2124
feat: Add Microcanonical Langevin Monte Carlo (MCLMC) kernel#2124juanitorduz wants to merge 4 commits intopyro-ppl:masterfrom
Conversation
Add MCLMC inference algorithm as a new MCMCKernel that wraps blackjax's MCLMC implementation. This provides an alternative gradient-based MCMC method to NUTS/HMC. Features: - MCLMC kernel with automatic step size and trajectory length tuning - Optional blackjax dependency with informative error message - postprocess_fn for constrained/unconstrained transformations - Diagnostics string for progress bar - Comprehensive test suite References: - Microcanonical Hamiltonian Monte Carlo (arXiv:2212.08549)
|
Hey @reubenharry I tried using your branch, but I made a git mess 🙈 so I decided to open another one to get feedback (I will add you as a coauthor). Could you please check this one out and see if the implementation (and the tests) are as expected? |
fehiepsi
left a comment
There was a problem hiding this comment.
I worry that we might introduce technical debt by depending on other libraries. Could we turn this into an example/tutorial instead? It's not clear to me the benefit of using numpyro here.
|
Maybe adding a section for https://num.pyro.ai/en/stable/tutorials/other_samplers.html would be better? edit: #2035 has good discussion about the above points |
Sure! This was a first attempt at trying to see how it would fit. I agree this "optional" dependencies can be hard to maintain. So what about adding another section with this code in https://num.pyro.ai/en/stable/tutorials/other_samplers.html ? Or do we want an additional notebook with just a brief explanation of MCLMC? |
|
on the other hand the optimal outcome would be that its very easy for numpyro users to use the sampler. if it's hidden in some tutorial... |
True 😄 . I do not have any strong opinion. I just wanted to bring this one alive. That being said, I have seen blackjax making breaking changes and this would be a pain to maintain indeed (in the notebook we would put a disclaimer about these potential changes) |
|
how entangled is the blackjax implementation with the rest of blackjax? can the sampler be ripped out with minimal changes? |
|
Awesome, thanks for doing this! It looks good to me, although for peace of mind, we should take a non-trivial example and check that the efficiency is good (e.g. Stochastic Volatility or something). Perhaps I can do that by running the Numpyro version in my https://github.com/reubenharry/sampler-benchmarks repo and checking the results, when I have more time. @martinjankowiak Re entanglement, I have previously ported the implementation outside Blackjax before (even out of Jax), and it isn't insanely hard to do so (blackjax itself isn't really a very complex codebase). More a question of time. I agree that the main draw of adding the code to Numpyro is discoverability for users. (I have been a bit busy this month, but will try to be more responsive going forward) |
|
This implementation works for me. I'd say that if it was totally necessary, the relevant parts of blackjax could be pulled out and the implementation could be self-contained. But I think it would be more natural to keep it in - I think changes that break backwards compatibility of this sampler are unlikely to be frequent. |
Trying to support #2039