diff --git a/pykeops/torch/operations.py b/pykeops/torch/operations.py index 823c7abd5..0724b9a55 100644 --- a/pykeops/torch/operations.py +++ b/pykeops/torch/operations.py @@ -36,6 +36,7 @@ def forward( rec_multVar_highdim, nx, ny, + solver_callback, *args ): @@ -92,8 +93,15 @@ def linop(var): res += alpha * var return res + def default_solver_callback(_ctx, _args, _varinvpos, _tagCPUGPU, _tag1D2D, _tagHostDevice): + return ConjugateGradientSolver("torch", linop, varinv.data, eps) + global copy - result = ConjugateGradientSolver("torch", linop, varinv.data, eps) + ctx.solver_callback = solver_callback + + if solver_callback is None: + solver_callback = default_solver_callback + result = solver_callback(ctx, args, varinvpos, tagCPUGPU, tag1D2D, tagHostDevice) # relying on the 'ctx.saved_variables' attribute is necessary if you want to be able to differentiate the output # of the backward once again. It helps pytorch to keep track of 'who is who'. @@ -161,6 +169,7 @@ def backward(ctx, G): rec_multVar_highdim, nx, ny, + ctx.solver_callback, *newargs ) @@ -169,8 +178,8 @@ def backward(ctx, G): for (var_ind, sig) in enumerate(aliases): # Run through the arguments # If the current gradient is to be discarded immediatly... if not ctx.needs_input_grad[ - var_ind + 13 - ]: # because of (formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, optional_flags, rec_multVar_highdim, nx, ny) + var_ind + 14 + ]: # because of (formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, optional_flags, rec_multVar_highdim, nx, ny, solver_callback) grads.append(None) # Don't waste time computing it. else: # Otherwise, the current gradient is really needed by the user: @@ -248,7 +257,7 @@ def backward(ctx, G): ) grads.append(grad) - # Grads wrt. formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, optional_flags, rec_multVar_highdim, *args + # Grads wrt. formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, optional_flags, rec_multVar_highdim, solver_callback, *args return ( None, None, @@ -263,6 +272,7 @@ def backward(ctx, G): None, None, None, + None, *grads, ) @@ -323,6 +333,7 @@ def __init__( sum_scheme="auto", enable_chunks=True, rec_multVar_highdim=None, + solver_callback=None, ): r""" Instantiate a new KernelSolve operation. @@ -397,6 +408,7 @@ def __init__( enable_chunks (bool, default True): enable automatic selection of special "chunked" computation mode for accelerating reductions with formulas involving large dimension variables. + solver_callback (function): custom linear solver """ if cuda_type: @@ -426,6 +438,7 @@ def __init__( self.varinvpos = varinvpos self.dtype = dtype self.rec_multVar_highdim = rec_multVar_highdim + self.solver_callback = solver_callback def __call__( self, *args, backend="auto", device_id=-1, alpha=1e-10, eps=1e-6, ranges=None @@ -497,5 +510,6 @@ def __call__( self.rec_multVar_highdim, nx, ny, + self.solver_callback, *args )