
ENH: Delegating from JAX to CuPy under the JIT #602

@steppi


lazy_apply works very well for delegating to the NumPy backend on CPU within JAX JIT-ed functions. For a while now I've wanted to do the same thing with CuPy on GPU. lazy_apply uses pure_callback, which inherently moves data to the host, so that's not a viable approach for what I'd like to do. I've started tinkering with the JAX FFI, which allows calling out to external code from within JIT-compiled JAX functions, and I've come up with a working prototype that supports this kind of delegation. It can be found here: https://github.com/steppi/jax_cupy_bridge/tree/main. It offers jax_cupy_bridge.cupy_lazy_apply, which works much like lazy_apply but with a few limitations:

  • func must have an out argument for specifying the output array or arrays.
  • func cannot perform any dynamic CUDA allocations; if it needs workspace buffers, it must take them as inputs. (This is actually the reason behind the first limitation.)
  • func can only take array arguments. I don't think this is an inherent limitation, though.
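To make these limitations concrete, here is a minimal pure-Python sketch of the destination-passing convention they imply. The kernel scaled_cumsum and the driver run_preallocated are hypothetical names used only for illustration; they are not part of jax_cupy_bridge. The driver stands in for the JIT runtime: it sizes every output and workspace buffer before the kernel runs, and the kernel only writes into buffers it is handed. Plain Python arrays stand in for device arrays.

```python
from array import array


def scaled_cumsum(x, scale, workspace, *, out):
    """Hypothetical kernel: out[i] = scale[0] * (x[0] + ... + x[i]).

    It never allocates. Scratch space for the running sum arrives as
    `workspace` (length >= 1) and the result is written into `out` in place.
    """
    if len(out) != len(x) or len(workspace) < 1:
        raise ValueError("buffers have the wrong size")
    workspace[0] = 0.0
    for i, value in enumerate(x):
        workspace[0] += value
        out[i] = scale[0] * workspace[0]


def run_preallocated(func, *inputs, out_len, workspace_len):
    """Hypothetical driver playing the role of the JIT runtime.

    All buffers are created here, before `func` is called, so `func`
    itself needs no dynamic allocation. Only array arguments are passed.
    """
    workspace = array("d", [0.0] * workspace_len)
    out = array("d", [0.0] * out_len)
    func(*inputs, workspace, out=out)
    return out


x = array("d", [1.0, 2.0, 3.0])
scale = array("d", [2.0])  # a scalar parameter passed as a length-1 array
result = run_preallocated(scaled_cumsum, x, scale, out_len=len(x), workspace_len=1)
print(list(result))  # [2.0, 6.0, 12.0]
```

Passing the scalar as a length-1 array is one way to live with the array-only restriction; the workspace argument is why an out parameter is needed at all, since whoever allocates the scratch space must also allocate the result.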

There's still a lot that can be done within these limitations, though. For instance, every CuPy ufunc backed by a kernel from xsf can be supported in JAX under the JIT this way. I've started a discussion on the jax-ml discussion board, jax-ml/jax#34732, to get a sense of whether what I'm doing here is kosher, or whether there are better ways to accomplish this.
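One reason ufuncs fit these limitations well: a ufunc's output shape is fully determined by its input shapes through broadcasting, so the caller can size the out buffer before the kernel ever runs. JAX likewise needs result shapes and dtypes up front when it lowers a foreign call. Below is a minimal pure-Python sketch of that shape computation following NumPy's broadcasting rules. The function name is our own (NumPy ships an equivalent as numpy.broadcast_shapes), and dtype promotion is deliberately ignored.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Return the NumPy-style broadcast of the given shapes.

    Shapes are aligned from the trailing dimension; missing leading
    dimensions count as 1. In each position, all sizes other than 1
    must agree, otherwise the shapes are not broadcastable.
    """
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# For a two-argument ufunc, the out buffer's shape is known from the
# input shapes alone, before any device code runs.
out_shape = broadcast_shapes((3, 1), (4,))
print(out_shape)  # (3, 4)
```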

For now, I'm just making people aware of this. My hope, if what I've done is not a cursed abomination, is that this could be integrated into lazy_apply so that it can seamlessly handle JAX-to-CuPy delegation when requested. Or perhaps it would be better to make the delegation more explicit. I'm not settled on any particular API; I just want to be able to delegate from JAX to CuPy under the JIT.

cc @rgommers @mdhaber @ogrisel who I briefly mentioned this idea to in a project meeting earlier this week.
