Conversation
|
|
||
| def icdf(self, q): | ||
| # https://github.com/pyro-ppl/numpyro/issues/969 | ||
| from numpyro.distributions.util import gammaincinv |
There was a problem hiding this comment.
I think you can move this import to the top.
| return cls(batch_shape=aux_data) | ||
|
|
||
|
|
||
| def TruncatedGamma(base_gamma, low=None, high=None, validate_args=None): |
There was a problem hiding this comment.
I think it is better to expose the parameters of Gamma here (TruncatedGamma(concentration, rate, low=..., high=...), rather than using a nested pattern. There are a couple of benefits with that:
- parameters of the distribution is defined probably in
args_constraints - it is easier to test
- no need to have flatten/unflatten logic
| base_gamma = Gamma.tree_unflatten(base_aux, base_flatten) | ||
| return cls(base_gamma, low=low) | ||
|
|
||
| @validate_sample |
There was a problem hiding this comment.
Unfortunately, currently validate_sample logic does not work with cdf :(
| # until jax/lax has direct implementation we'll need to rely on tfp | ||
| # https://github.com/pyro-ppl/numpyro/issues/969 | ||
| try: | ||
| import tensorflow_probability as tfpm |
There was a problem hiding this comment.
I think you can import tensorflow_probability.substrates.jax directly, to make sure that jax backend is installed.
| return lprob - jnp.log(1.0 - lscale) | ||
|
|
||
| def _scale_moment(self, t): | ||
| assert t > -self.base_gamma.concentration |
There was a problem hiding this comment.
This won't work for jax arrays (which might have abstract values under jit compiling). You can use jnp.where to mask out the invalid cases like this.
| def log_prob(self, value): | ||
| lprob = self.base_gamma.log_prob(value) | ||
| lscale = self.base_gamma.cdf(self.low) | ||
| return lprob - jnp.log(1.0 - lscale) |
There was a problem hiding this comment.
You can use log1p(-lscale) for a better numerical result
|
@quattro Looking the the PR is is the good shape - just have small comments above. Any chance we can have this in the next numpyro release? |
|
Will try my best. Should have some time closer to Thanksgiving holidays, does that fall before next release schedule? |
|
Absolutely, there is no plan for the release date yet. Thank you! |
|
Will we have this feature in the future? |
PR for issue #969 . Contains initial implementation that performs uniform sampling + inverse CDF of Left/Right/Doubly truncated Gamma. Relies on tensorflow functionality for igammainv function, which is not yet implemented at the lax/jax level (see jax-ml/jax#5350).
There is a test that fails, but it is not clear to me if this is purely a numerical issue with the uniform + iCDF sampling, or a larger issue that I missed at the time I implemented things.