Description
`lazy_apply` works very well for delegating to the NumPy backend on CPU within JAX JIT-ed functions. I've had it on my wish-list for a while to be able to do the same thing with CuPy on GPU. `lazy_apply` uses `pure_callback`, which inherently brings data to the host, so that's not a viable approach for what I'd like to do. I've started tinkering with the JAX FFI, which allows calling out to external code within the JAX JIT, and have come up with a working prototype that allows for such delegation. It can be found here: https://github.com/steppi/jax_cupy_bridge/tree/main. It offers `jax_cupy_bridge.cupy_lazy_apply`, which works much like `lazy_apply`, with a few limitations:
- `func` must have an `out` argument for specifying the output array or arrays.
- `func` cannot do any dynamic CUDA allocations; if it needs workspace buffers, it must take them as inputs. (This is actually the reason behind the first limitation.)
- `func` can only take array arguments. I don't think this is an inherent limitation, though.
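To make the constraints above concrete, here is a minimal sketch of a function written in destination-passing style: it takes only array arguments, writes its result into a caller-provided `out`, and puts its intermediate into a caller-provided workspace instead of allocating one. The function name `shifted_square` and its signature are hypothetical and are not the actual calling convention of `cupy_lazy_apply`. It uses NumPy so it runs without a GPU; CuPy's ufuncs also accept `out=`, so swapping `np` for `cp` should give the CuPy equivalent.

```python
import numpy as np


def shifted_square(x, workspace, out):
    """Hypothetical kernel computing out = x**2 + 1 in destination-passing style.

    All arguments are arrays; nothing is allocated inside the function.
    """
    # Intermediate result goes into the caller-provided workspace buffer.
    np.multiply(x, x, out=workspace)
    # Final result is written into the caller-provided output buffer.
    np.add(workspace, 1.0, out=out)
    return out


# The caller owns every buffer, so the function itself never allocates.
x = np.arange(4, dtype=np.float32)
work = np.empty_like(x)
out = np.empty_like(x)
shifted_square(x, work, out)
# out now holds [1, 2, 5, 10]
```

Because the caller allocates `workspace` and `out` up front, all memory can be set up before the function runs, which is what makes the first two limitations go hand in hand.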
There's still a lot that can be done within these limitations, though. For instance, all CuPy ufuncs backed by a kernel from xsf can be supported in JAX under the JIT this way. I've opened a discussion on the jax-ml discussion board, jax-ml/jax#34732, to try to get a sense of whether what I'm doing here is kosher, or whether there are better ways to accomplish this.
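For comparison, the host-side route that `lazy_apply` builds on can be sketched directly with `jax.pure_callback`. This is a minimal illustration, not `lazy_apply`'s actual implementation; the names `host_log1p` and `f` are hypothetical, and `np.log1p` simply stands in for any NumPy ufunc. Under `jax.jit`, the callback receives host-side arrays, so for device inputs the data has to leave the accelerator and come back, which is the round trip the FFI-based CuPy approach avoids.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_log1p(x):
    # pure_callback passes host-side (NumPy-compatible) arrays to the callback.
    x = np.asarray(x)
    out = np.empty_like(x)
    # A NumPy ufunc writing into a preallocated output buffer.
    np.log1p(x, out=out)
    return out


@jax.jit
def f(x):
    # Declare the result's shape and dtype so JAX can trace f
    # without executing the callback.
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_log1p, result_spec, x)


x = jnp.linspace(0.0, 1.0, 5)
print(f(x))
```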
For now, I'm just making people aware of this. My hope, if what I've done is not a cursed abomination, is that this sort of thing could be integrated into `lazy_apply` so that it can seamlessly handle JAX-to-CuPy delegation when requested. Or perhaps it would be better to be more explicit about things. I'm not settled on any particular API; I just want to be able to delegate from JAX to CuPy under the JIT.
cc @rgommers @mdhaber @ogrisel who I briefly mentioned this idea to in a project meeting earlier this week.