Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions pykeops/torch/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def forward(
rec_multVar_highdim,
nx,
ny,
solver_callback,
*args
):

Expand Down Expand Up @@ -92,8 +93,15 @@ def linop(var):
res += alpha * var
return res

def default_solver_callback(_ctx, _args, _varinvpos, _tagCPUGPU, _tag1D2D, _tagHostDevice):
return ConjugateGradientSolver("torch", linop, varinv.data, eps)

global copy
result = ConjugateGradientSolver("torch", linop, varinv.data, eps)
ctx.solver_callback = solver_callback

if solver_callback is None:
solver_callback = default_solver_callback
result = solver_callback(ctx, args, varinvpos, tagCPUGPU, tag1D2D, tagHostDevice)

# relying on the 'ctx.saved_variables' attribute is necessary if you want to be able to differentiate the output
# of the backward once again. It helps pytorch to keep track of 'who is who'.
Expand Down Expand Up @@ -161,6 +169,7 @@ def backward(ctx, G):
rec_multVar_highdim,
nx,
ny,
ctx.solver_callback,
*newargs
)

Expand All @@ -169,8 +178,8 @@ def backward(ctx, G):
for (var_ind, sig) in enumerate(aliases): # Run through the arguments
# If the current gradient is to be discarded immediately...
if not ctx.needs_input_grad[
var_ind + 13
]: # because of (formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, optional_flags, rec_multVar_highdim, nx, ny)
var_ind + 14
]: # because of (formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, optional_flags, rec_multVar_highdim, nx, ny, solver_callback)
grads.append(None) # Don't waste time computing it.

else: # Otherwise, the current gradient is really needed by the user:
Expand Down Expand Up @@ -248,7 +257,7 @@ def backward(ctx, G):
)
grads.append(grad)

# Grads wrt. formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, optional_flags, rec_multVar_highdim, *args
# Grads wrt. formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, optional_flags, rec_multVar_highdim, solver_callback, *args
return (
None,
None,
Expand All @@ -263,6 +272,7 @@ def backward(ctx, G):
None,
None,
None,
None,
*grads,
)

Expand Down Expand Up @@ -323,6 +333,7 @@ def __init__(
sum_scheme="auto",
enable_chunks=True,
rec_multVar_highdim=None,
solver_callback=None,
):
r"""
Instantiate a new KernelSolve operation.
Expand Down Expand Up @@ -397,6 +408,7 @@ def __init__(

enable_chunks (bool, default True): enable automatic selection of special "chunked" computation mode for accelerating reductions
with formulas involving large dimension variables.
solver_callback (function): custom linear solver

"""
if cuda_type:
Expand Down Expand Up @@ -426,6 +438,7 @@ def __init__(
self.varinvpos = varinvpos
self.dtype = dtype
self.rec_multVar_highdim = rec_multVar_highdim
self.solver_callback = solver_callback

def __call__(
self, *args, backend="auto", device_id=-1, alpha=1e-10, eps=1e-6, ranges=None
Expand Down Expand Up @@ -497,5 +510,6 @@ def __call__(
self.rec_multVar_highdim,
nx,
ny,
self.solver_callback,
*args
)