diff --git a/doc/conf.py b/doc/conf.py index 29b5fd19b..72b41e7a3 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -114,6 +114,11 @@ def skip(app, what, name, obj, would_skip, options): def setup(app): app.connect("autodoc-skip-member", skip) +import warnings + +warnings.filterwarnings("ignore", category=UserWarning, + message='Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.') + # Include the example source for plots in API docs # plot_include_source = True # plot_formats = [("png", 90)] diff --git a/pykeops/benchmarks/plot_benchmark_cg.py b/pykeops/benchmarks/plot_benchmark_cg.py new file mode 100644 index 000000000..d11544a63 --- /dev/null +++ b/pykeops/benchmarks/plot_benchmark_cg.py @@ -0,0 +1,523 @@ +""" +Comparison of conjugate gradient methods +========================================== + +Different implementations of the conjugate gradient (CG) exist. Here, we compare the CG implemented in scipy which uses Fortran +against it's pythonized version and the older version of the algorithm available in pykeops. + +We want to solve the positive definite linear system :math:`(K_{x,x} + \\alpha Id)a = b` for :math:`a, b, x \in \mathbb R^N`. + +Let the Gaussian RBF kernel be defined as + +.. math:: + + K_{x,x}=\left[\exp(-\gamma \|x_i - x_j\|^2)\\right]_{i,j=1}^N. + + +Choosing :math:`x` such that :math:`x_i = i/N,\ i=1,\dots, N` makes :math:`K_{x,x}` be a highly unwell-conditioned matrix for :math:`N\geq 10`. 
+ +""" + +############################# +# Setup +# ---------- +# Imports needed + +import importlib +import os +import time +import inspect + +import numpy as np +import torch +import matplotlib.pyplot as plt + +from scipy.sparse import diags +from scipy.sparse.linalg import aslinearoperator, cg + +from pykeops.numpy import KernelSolve as KernelSolve_np, LazyTensor +from pykeops.torch import KernelSolve +from pykeops.torch.utils import squared_distances +from pykeops.torch import Genred as Genred_tch +from pykeops.numpy import Vi, Vj, Pm +from pykeops.numpy import Genred as Genred_np + +use_cuda = torch.cuda.is_available() + +device = torch.device("cuda") if use_cuda else torch.device("cpu") +print("The device used is {}.".format(device)) + +######################################## +# Gaussian radial basis function kernel +######################################## + +formula = 'Exp(- g * SqDist(x,y)) * a' # linear w.r.t a +aliases = ['x = Vi(1)', # First arg: i-variable of size 1 + 'y = Vj(1)', # Second arg: j-variable of size 1 + 'a = Vj(1)', # Third arg: j-variable of size 1 + 'g = Pm(1)'] + + +############################ +# Functions to benchmark +########################### +# +# All systems are regularized with a ridge parameter ``alpha``. 
+# +# The originals : +# + + +def keops_tch(x, b, gamma, alpha): + Kinv = KernelSolve(formula, aliases, "a", axis=1, dtype='float32') + res = Kinv(x, x, b, gamma, alpha=alpha) + return res + + +def keops_np(x, b, gamma, alpha, callback=None): + Kinv = KernelSolve_np(formula, aliases, "a", axis=1, dtype='float32') + res = Kinv(x, x, b, gamma, alpha=alpha, callback=callback) + return res + + +#################################### +# Scipy : +# +# + + +def scipy_cg(x, b, gamma, alpha, callback=None): + K_ij = (-Pm(gamma) * Vi(x).sqdist(Vj(x))).exp() + A = aslinearoperator( + diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) + A.dtype = np.dtype('float32') + res = cg(A, b, callback=callback) + return res + + +#################################### +# Pythonized scipy : +# + + +def dic_cg_np(x, b, gamma, alpha, callback=None, check_cond=False): + Kinv = KernelSolve_np(formula, aliases, "a", axis=1, dtype='float32') + ans = Kinv.cg(x, x, b, gamma, alpha=alpha, + callback=callback, check_cond=check_cond) + return ans + + +def dic_cg_tch(x, b, gamma, alpha, check_cond=False): + Kinv = KernelSolve(formula, aliases, "a", axis=1, dtype='float32') + ans = Kinv.cg(x, x, b, gamma, alpha=alpha, check_cond=check_cond) + return ans + + +######################### +# Benchmarking +######################### + +functions = [(scipy_cg, "numpy"), + (keops_np, "numpy"), (keops_tch, "torch"), + (dic_cg_np, "numpy"), (dic_cg_tch, "torch")] + +sizes = [50, 100, 500, 1000, 5000, 20000, 40000] +reps = [50 , 50 , 50, 10, 10, 5, 5] + + +def compute_error(func, pack, result, errors, x, b, alpha, gamma): + if str(func)[10:15] == "keops": + code = "a = func(x, b, gamma, alpha).reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + else: + code = "a = func(x, b, gamma, alpha)[0].reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + + if pack == 'numpy': + K = Genred_np(formula, aliases, axis=1, 
dtype='float32') + else: + K = Genred_tch(formula, aliases, axis=1, dtype='float32') + + exec(code, locals()) + return errors + + +def to_bench(funcpack, size, rep): + global use_cuda + importlib.reload(torch) + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + code = "func(x, b, gamma, alpha)" + func, pack = funcpack + + times = [] + errors = [] + + if use_cuda: + torch.cuda.synchronize() + for i in range(rep): + + x = torch.linspace(1/size, 1, size, dtype=torch.float32, + device=device).reshape(size, 1) + b = torch.randn(size, 1, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + + if i == 0: + exec(code, locals()) # Warmup run, to compile and load everything + + start = time.perf_counter() + result = func(x, b, gamma, alpha) + if use_cuda: + torch.cuda.synchronize() + + times.append(time.perf_counter() - start) + errors = compute_error(func, pack, result, errors, x, b, alpha, gamma) + + return sum(times)/rep, sum(errors)/rep + + +def global_bench(functions, sizes, reps): + list_times = [[] for _ in range(len(functions))] + list_errors = [[] for _ in range(len(functions))] + + for j, one_to_bench in enumerate(functions): + print("~~~~~~~~~~~~~Benchmarking {}~~~~~~~~~~~~~~.".format(one_to_bench)) + for i in range(len(sizes)): + try: + time, err = to_bench(one_to_bench, sizes[i], reps[i]) + list_times[j].append(time) + list_errors[j].append(err) + except: + while len(list_times[j]) != len(reps): + list_times[j].append(np.nan) + list_errors[j].append(np.nan) + break + print("Finished size {}.".format(sizes[i])) + + print("Finished", one_to_bench[0], "in a 
cumulated time of {:3.9f}s.".format( + sum(list_times[j]))) + + return list_times, list_errors + + +######################################### +# Plot the results of the benchmarking +######################################### + +list_times, list_errors = global_bench(functions, sizes, reps) +labels = ["scipy + keops", "keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] + +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes, list_times[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes, list_errors[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel(r"Error $||Ax_{k_{end}} -b||^2$") +plt.legend() +plt.tight_layout() +plt.show() + + +############################################## +# Stability +# ------------ +# +# Stability of the errors and norm of the iterated approximations of the answer + + +def norm_stability(size, funcpack): + errk_scipy, iter_scipy, x_scipy = [], [], [] + errk_dic, iter_dic, x_dic = [], [], [] + errk_keops, iter_keops, x_keops = [], [], [] + + def callback_sci(xk): + env = inspect.currentframe().f_back + iter_scipy.append(env.f_locals['iter_']) + x_scipy.append(env.f_locals['x']) + err = ( ( alpha * xk.reshape(-1, 1) + K(x, x, xk.reshape(-1, 1), gamma) - b) ** 2).sum() + errk_scipy.append(err) + + def callback_kinv_keops(xk): + env = inspect.currentframe().f_back + err = ( ( alpha * xk + K(x, x, xk, gamma) - b) ** 2).sum() + errk_keops.append(err) + iter_keops.append(env.f_locals['k']) + x_keops.append(env.f_locals['a']) + + def callback_dic(xk): + env = inspect.currentframe().f_back + err = ( ( alpha * xk + K(x, x, xk, gamma) - b) ** 2).sum() + errk_dic.append(err) + iter_dic.append(env.f_locals['iter_']) + 
x_dic.append(env.f_locals['x']) + + callback_list = [callback_sci, callback_kinv_keops, callback_dic] + + for i, funcpack in enumerate(funcpack): + fun, pack = funcpack + + global x, b, gamma, alpha, K + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + + x = torch.linspace(1/size, 1, size, dtype=torch.float32, + device=device).reshape(size, 1) + b = torch.randn(size, 1, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + K = Genred_np(formula, aliases, axis=1, dtype='float32') + else: + K = Genred_tch(formula, aliases, axis=1, dtype='float32') + + fun(x, b, gamma, alpha, callback=callback_list[i]) + del x, b, gamma, alpha, K + return errk_scipy, iter_scipy, x_scipy, errk_dic, iter_dic, x_keops, errk_keops, iter_keops, x_dic + + +######################################### +# Plot the results of the stability +######################################### + +onlynum = [(scipy_cg, "numpy"), (keops_np, "numpy"), (dic_cg_np, "numpy")] +errk_scipy, iter_scipy, x_scipy, errk_dic, iter_dic,\ + x_keops, errk_keops, iter_keops, x_dic = norm_stability( + 1000, onlynum) + +scal_dic, scal_keops, scal_scipy = [], [], [] +for i in range(1,len(iter_dic)): + scal_dic.append((x_dic[i-1].T @ x_dic[i]).flatten()) +for i in range(1, len(iter_keops)): + scal_keops.append((x_keops[i-1].T @ x_keops[i]).flatten()) +for i in range(1, len(iter_scipy)): + scal_scipy.append((x_scipy[i-1].T @ x_scipy[i]).flatten()) + +plt.figure(figsize=(20,10)) +plt.subplot(121) +plt.plot(iter_keops, errk_keops, 'o-', label=labels[1]) +plt.plot(iter_scipy, errk_scipy, '^-', label=labels[0]) 
+plt.plot(iter_dic, errk_dic, 'x-', label=labels[3]) +plt.yscale('log') +plt.xlabel(r"Iteration k") +plt.ylabel(r"$||(\alpha\ Id + K_{x,x})x_k - b||^2$") +plt.legend() + +plt.subplot(122) +plt.plot(iter_keops[1:], scal_keops, 'o-', label=labels[1]) +plt.plot(iter_scipy[1:], scal_scipy, '^-', label=labels[0]) +plt.plot(iter_dic[1:], scal_dic, 'x-', label=labels[3]) +plt.yscale('log') +plt.xlabel(r"Iteration k") +plt.ylabel(r"$\langle x_{k-1}|x_k\rangle $") +plt.legend() + +plt.tight_layout() +plt.show() + +####################################################### +# Condition number check +# ------------------------------------- +# +# Scipy's algorithm can't be used practically for large kernels in this case. The condition number can be why. +# +# +# The argument ``check_cond`` in Keops lets the user have an idea of the conditioning number of the matrix :math:`A=(K_{x,x} + \alpha Id)`. A warning appears +# if :math:`\mathrm{cond}(A)>500`. The user is also warned if the CG algorithm reached its maximum number of iterations *ie* did not converge. The idea here +# is not to estimate the condition number and let the user have another sanity check at disposal. +# +# To test the condition number :math:`\mathrm{cond}(A)=\frac{\lambda_{\max}}{\lambda_{\min}}`, we first use the +# power iteration to have a good estimation of :math:`\lambda_{\max}`. Then, wee apply the inverse power iteration +# to obtain the iterations :math:`\mu_k` of the estimated :math:`\lambda_{\min}` using the Rayleigh's quotient after having the iterations :math:`u_k` +# of the estimated eigen vector :math:`u_1`. 
The distance between the vectors :math:`v_k` and :math:`u_1` decreasing over the iterations at a rate of +# :math:`\mathcal{O}\left(\left|\frac{\lambda_{\min}}{\lambda_{submin}}\right|^k\right)`, if we don't want +# :math:`\frac{\lambda_{\max}}{\lambda_{\min}}>500` then :math:`\mu_k` must not be below the threshold :math:`\frac{\lambda_{\max}}{500}` +# If so, the system warns the user that the condition number might be too high. +# +# In practice only a few iterations are necessary to go below this threshold. Thus we fixed a maximum number of iterations for the inverse +# power method to ``50`` so that for large matrices it doesn't take too much time. + + +def test_cond(device, size, pack, alpha): + if device == 'cuda': + torch.cuda.manual_seed_all(1234) + else: + torch.manual_seed(1234) + + x = torch.linspace(1/size, 1, size, dtype=torch.float32, + device=device).reshape(size, 1) + b = torch.randn(size, 1, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones(1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + alpha = torch.ones(1, device=device, dtype=torch.float32) * alpha # regularization + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + ans = dic_cg_np(x, b, gamma, alpha, check_cond=True) + else: + ans = dic_cg_tch(x, b, gamma, alpha, check_cond=True) + return ans + + +print("Condition number warnings tests") +print("Small matrix well conditioned (nothing should appear)") +ans = test_cond(device, 20, 'numpy', alpha=1) +print("Large matrix unwell conditioned (a warning should appear)") +ans2 = test_cond(device, 1000, 'numpy', alpha=1e-6) +print("Large matrix unwell conditioned but with a large regularization (nothing should appear)") +ans3 = test_cond(device, 1000, 'numpy', alpha=100) + + +########################## +# Zoom in on Keops times +############################ +# +# Let's 
consider the Keops conjugate gradients for large kernels. Scipy's algorithm explodes in time for +# :math:`n\geq 50000` so we only consider the keops implementations here. +# + + +functions = functions[1:] +sizes = [30000, 50000, 100000, 200000] +reps = [5, 5, 5, 2] +list_times, list_errors = global_bench(functions, sizes, reps) +labels = ["keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes, list_times[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes, list_errors[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel(r"Error $||Ax_{k_{end}} -b||^2$") +plt.legend() +plt.show() + +########################### +# Random points +########################### +# +# Let's now use random values for :math:`x_i`. 
+ +def to_bench(funcpack, size, rep): + global use_cuda + importlib.reload(torch) + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + code = "func(x, b, gamma, alpha)" + func, pack = funcpack + + times = [] + errors = [] + + if use_cuda: + torch.cuda.synchronize() + for i in range(rep): + + x = torch.linspace(1/size, 1, size, dtype=torch.float32, + device=device).reshape(size, 1) + b = torch.randn(size, 1, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + + if i == 0: + exec(code, locals()) # Warmup run, to compile and load everything + + start = time.perf_counter() + result = func(x, b, gamma, alpha) + if use_cuda: + torch.cuda.synchronize() + + times.append(time.perf_counter() - start) + errors = compute_error(func, pack, result, errors, x, b, alpha, gamma) + + return sum(times)/rep, sum(errors)/rep + + +functions = [(scipy_cg, "numpy"), + (keops_np, "numpy"), (keops_tch, "torch"), + (dic_cg_np, "numpy"), (dic_cg_tch, "torch")] + + +sizes = [50, 100, 500, 1000, 5000, 20000, 40000] +reps = [50 , 50 , 50, 10, 10, 5, 5] + +list_times, list_errors = global_bench(functions, sizes, reps) +labels = ["scipy + keops", "keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] + +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes, list_times[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes, list_errors[i], 
label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel(r"Error $||Ax_{k_{end}} -b||^2$") +plt.legend() +plt.tight_layout() +plt.show() \ No newline at end of file diff --git a/pykeops/benchmarks/plot_benchmark_cg_dimensions.py b/pykeops/benchmarks/plot_benchmark_cg_dimensions.py new file mode 100644 index 000000000..aeb609346 --- /dev/null +++ b/pykeops/benchmarks/plot_benchmark_cg_dimensions.py @@ -0,0 +1,400 @@ +""" +Conjugate gradient method in arbitrary dimension +==================================================== + +Different implementations of the conjugate gradient (CG) exist. Here, we compare the CG implemented in scipy which uses Fortran +against it's pythonized version and the older version of the algorithm available in pykeops. + +We want to solve the positive definite linear system :math:`(K_{x,x} + \\alpha Id)a = b` for :math:`a, b\in \mathbb R^N` and :math:`x\in\mathbb R^{N\\times d}`. +We will use :math:`N=100000` points. + +Let the Gaussian RBF kernel be defined as + +.. math:: + + K_{x,x}=\left[\exp(-\gamma \|x_i - x_j\|^2)\\right]_{i,j=1}^N. + + +The case where :math:`d=1` is already benchmarked in a very ill-conditioned situation, now let's compare when :math:`d` increases. 
+""" + +############################# +# Setup +# ---------- +# Imports needed + +import importlib +import os +import time +import inspect + +import numpy as np +import torch +import matplotlib.pyplot as plt + +from scipy.sparse import diags +from scipy.sparse.linalg import aslinearoperator, cg + +from pykeops.numpy import KernelSolve as KernelSolve_np, LazyTensor +from pykeops.torch import KernelSolve +from pykeops.torch.utils import squared_distances +from pykeops.torch import Genred as Genred_tch +from pykeops.numpy import Vi, Vj, Pm +from pykeops.numpy import Genred as Genred_np + +use_cuda = torch.cuda.is_available() + +device = torch.device("cuda") if use_cuda else torch.device("cpu") +print("The device used is {}.".format(device)) + +######################################## +# Gaussian radial basis function kernel +######################################## + +n = 100000 +dv = 1 # number of systems to solve +formula = 'Exp(- g * SqDist(x,y)) * a' # linear w.r.t a + + +############################ +# Functions to benchmark +########################### +# +# All systems are regularized with a ridge parameter ``alpha``. 
+# +# The originals : +# + + +def keops_tch(x, b, gamma, alpha, aliases, callback=None): + Kinv = KernelSolve(formula, aliases, "a", axis=1, dtype='float32') + res = Kinv(x, x, b, gamma, alpha=alpha) + return res + + +def keops_np(x, b, gamma, alpha, aliases, callback=None): + Kinv = KernelSolve_np(formula, aliases, "a", axis=1, dtype='float32') + res = Kinv(x, x, b, gamma, alpha=alpha, callback=callback) + return res + + +#################################### +# Scipy : +# +# + + +def scipy_cg(x, b, gamma, alpha, aliases, callback=None): + K_ij = (-Pm(gamma) * Vi(x).sqdist(Vj(x))).exp() + A = aslinearoperator( + diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) + A.dtype = np.dtype('float32') + res = cg(A, b, callback=callback) + return res + + +#################################### +# Pythonized scipy : +# + + +def dic_cg_np(x, b, gamma, alpha, aliases, callback=None, check_cond=False): + Kinv = KernelSolve_np(formula, aliases, "a", axis=1, dtype='float32') + ans = Kinv.cg(x, x, b, gamma, alpha=alpha, + callback=callback, check_cond=check_cond) + return ans + + +def dic_cg_tch(x, b, gamma, alpha, aliases, check_cond=False): + Kinv = KernelSolve(formula, aliases, "a", axis=1, dtype='float32') + ans = Kinv.cg(x, x, b, gamma, alpha=alpha, check_cond=check_cond) + return ans + + +######################### +# Benchmarking +######################### + +functions = [(scipy_cg, "numpy"), + (keops_np, "numpy"), (keops_tch, "torch"), + (dic_cg_np, "numpy"), (dic_cg_tch, "torch")] + +sizes_d = [10, 50, 75, 100, 150] # dimension of each point +reps = [5, 5, 5, 5, 5] + + +def compute_error(func, pack, result, errors, x, b, alpha, gamma, aliases): + if str(func)[10:15] == "keops": + code = "a = func(x, b, gamma, alpha, aliases).reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + else: + code = "a = func(x, b, gamma, alpha, aliases)[0].reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + 
errors.append(err);" + + if pack == 'numpy': + K = Genred_np(formula, aliases, axis=1, dtype='float32') + else: + K = Genred_tch(formula, aliases, axis=1, dtype='float32') + + exec(code, locals()) + return errors + + +def to_bench(funcpack, d, rep): + importlib.reload(torch) + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + code = "func(x, b, gamma, alpha, aliases)" + func, pack = funcpack + + times = [] + errors = [] + + if use_cuda: + torch.cuda.synchronize() + + aliases = ['x = Vi(' + str(d) + ')', # First arg: i-variable of size d + 'y = Vj(' + str(d) + ')', # Second arg: j-variable of size d + 'a = Vj(' + str(dv) + ')', # Third arg: j-variable of size dv + 'g = Pm(1)'] + + for i in range(rep): + + x = torch.rand(n, d, device=device, dtype=torch.float32) + b = torch.randn(n, dv, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + + if i == 0: + exec(code, locals()) # Warmup run, to compile and load everything + + start = time.perf_counter() + result = func(x, b, gamma, alpha, aliases) + if use_cuda: + torch.cuda.synchronize() + + times.append(time.perf_counter() - start) + errors = compute_error(func, pack, result, errors, x, b, alpha, gamma, aliases) + + return sum(times)/rep, sum(errors)/rep + + +def global_bench(functions, sizes_d, reps): + list_times = [[] for _ in range(len(functions))] + list_errors = [[] for _ in range(len(functions))] + + for j, one_to_bench in enumerate(functions): + print("~~~~~~~~~~~~~Benchmarking {}~~~~~~~~~~~~~~.".format(one_to_bench)) + for i in range(len(sizes_d)): + try: + time, err = to_bench(one_to_bench, sizes_d[i], 
reps[i]) + list_times[j].append(time) + list_errors[j].append(err) + except: + while len(list_times[j]) != len(reps): + list_times[j].append(np.nan) + list_errors[j].append(np.nan) + break + print("Finished size {}.".format(sizes_d[i])) + + print("Finished", one_to_bench[0], "in a cumulated time of {:3.9f}s.".format( + sum(list_times[j]))) + + return list_times, list_errors + + +######################################### +# Plot the results of the benchmarking +######################################### + +list_times, list_errors = global_bench(functions, sizes_d, reps) +labels = ["scipy + keops", "keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] + +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes_d, list_times[i], label=labels[i]) +plt.xscale('log') +plt.ylim((1e-1, 1e2)) +plt.yscale('log') +plt.xlabel(r"Kernel made from {} points of size {} solving {} system.".format(n, 'd', dv)) +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes_d, list_errors[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.ylim((1e-10, 1e-7)) +plt.xlabel(r"Kernel made from {} points of size {} solving {} system.".format(n, 'd', dv)) +plt.ylabel(r"Error $||Ax_{k_{end}} - b||^2$") +plt.legend() +plt.tight_layout() +plt.show() + + +########################################## +# Changing the number of systems to solve +########################################## +# Let's consider the case where :math:`b\in\mathbb R^{dv}`. Then we need to solve multiple systems at once. + +n = 100000 +d = 4 # an image in RGBA for example + + +############################# +# Prepare the benchmarking : +# We need to modify the code for scipy's solver because scipy only solves one system at a time. 
+ + +def scipy_cg_multi(x, b, gamma, alpha, aliases, callback=None): + K_ij = (-Pm(gamma) * Vi(x).sqdist(Vj(x))).exp() + ans = np.zeros(b.shape).astype('float32') + A = aslinearoperator( + diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) + A.dtype = np.dtype('float32') + for i in range(b.shape[1]): + res = cg(A, b[:, i], callback=callback) + ans[:, i] = res[0].flatten() + return ans + +def compute_error(func, pack, result, errors, x, b, alpha, gamma, aliases): + if str(func)[10:15] == "keops" or str(func)[10:15] == "scipy": + code = "a = func(x, b, gamma, alpha, aliases).reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + else: + code = "a = func(x, b, gamma, alpha, aliases)[0].reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + + if pack == 'numpy': + K = Genred_np(formula, aliases, axis=1, dtype='float32') + else: + K = Genred_tch(formula, aliases, axis=1, dtype='float32') + + exec(code, locals()) + return errors + + +def to_bench(funcpack, dv, rep): + importlib.reload(torch) + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + code = "func(x, b, gamma, alpha, aliases)" + func, pack = funcpack + + times = [] + errors = [] + + if use_cuda: + torch.cuda.synchronize() + + aliases = ['x = Vi(' + str(d) + ')', # First arg: i-variable of size d + 'y = Vj(' + str(d) + ')', # Second arg: j-variable of size d + 'a = Vj(' + str(dv) + ')', # Third arg: j-variable of size dv + 'g = Pm(1)'] + + for i in range(rep): + + x = torch.rand(n, d, device=device, dtype=torch.float32) + b = torch.randn(n, dv, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + 
gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + + if i == 0: + exec(code, locals()) # Warmup run, to compile and load everything + + start = time.perf_counter() + result = func(x, b, gamma, alpha, aliases) + if use_cuda: + torch.cuda.synchronize() + + times.append(time.perf_counter() - start) + errors = compute_error(func, pack, result, errors, x, b, alpha, gamma, aliases) + + return sum(times)/rep, sum(errors)/rep + + +def global_bench(functions, sizes_dv, reps): + list_times = [[] for _ in range(len(functions))] + list_errors = [[] for _ in range(len(functions))] + + for j, one_to_bench in enumerate(functions): + print("~~~~~~~~~~~~~Benchmarking {}~~~~~~~~~~~~~~.".format(one_to_bench)) + for i in range(len(sizes_d)): + try: + time, err = to_bench(one_to_bench, sizes_dv[i], reps[i]) + list_times[j].append(time) + list_errors[j].append(err) + except: + while len(list_times[j]) != len(reps): + list_times[j].append(np.nan) + list_errors[j].append(np.nan) + break + print("Finished size {}.".format(sizes_dv[i])) + + print("Finished", one_to_bench[0], "in a cumulated time of {:3.9f}s.".format( + sum(list_times[j]))) + + return list_times, list_errors + +################################################################ +# Plot the results of the benchmarking for multi-system solver +################################################################## + +sizes_dv = [1, 5, 10, 50, 100, 150] +reps = [5, 5, 5, 5, 5, 5] + +functions = [(scipy_cg_multi, "numpy"), + (keops_np, "numpy"), (keops_tch, "torch"), + (dic_cg_np, "numpy"), (dic_cg_tch, "torch")] + +list_times, list_errors = global_bench(functions, sizes_dv, reps) +labels = ["scipy + keops", "keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] + +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes_d, list_times[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel made 
from {} points of size {} solving {} systems.".format(n, d, 'dv')) +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes_d, list_errors[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel made from {} points of size {} solving {} systems.".format(n, d, 'dv')) +plt.ylabel(r"Error $\sum_{i,j}\left((Ax_{k_{end}} - b)_{i,j}\right)^2$") +plt.legend() +plt.tight_layout() +plt.show() diff --git a/pykeops/benchmarks/plot_benchmark_invkernel.py b/pykeops/benchmarks/plot_benchmark_invkernel.py index f7a6edca4..f2ceed0fd 100644 --- a/pykeops/benchmarks/plot_benchmark_invkernel.py +++ b/pykeops/benchmarks/plot_benchmark_invkernel.py @@ -30,11 +30,12 @@ from scipy.sparse import diags from scipy.sparse.linalg import aslinearoperator, cg -from scipy.sparse.linalg.interface import IdentityOperator from pykeops.numpy import KernelSolve as KernelSolve_np, LazyTensor from pykeops.torch import KernelSolve from pykeops.torch.utils import squared_distances +from pykeops.numpy import Vi, Vj, Pm + use_cuda = torch.cuda.is_available() @@ -71,7 +72,7 @@ def generate_samples(N, device, lang): x = torch.rand(N, D, device=device) b = torch.randn(N, Dv, device=device) gamma = torch.ones(1, device=device) * .5 / .01 ** 2 # kernel bandwidth - alpha = torch.ones(1, device=device) * 0.8 # regularization + alpha = torch.ones(1, device=device) * 2 # regularization else: np.random.seed(1234) @@ -116,9 +117,9 @@ def Kinv_keops_numpy(x, b, gamma, alpha): return res def Kinv_scipy(x, b, gamma, alpha): - x_i, y_j = LazyTensor( gamma * x[:, None, :]), LazyTensor( gamma * x[None, :, :]) - K_ij = (- ((x_i - y_j) ** 2).sum(2)).exp() - A = aslinearoperator(diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) + K_ij = (-Pm(gamma) * Vi(x).sqdist(Vj(x))).exp() + A = aslinearoperator( + diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) A.dtype = np.dtype('float32') res = cg(A, b) return res diff --git 
a/pykeops/common/cg.py b/pykeops/common/cg.py new file mode 100644 index 000000000..0fc99ba2e --- /dev/null +++ b/pykeops/common/cg.py @@ -0,0 +1,148 @@ +import torch +from pykeops.common.utils import get_tools +from math import sqrt +import warnings + + +############################################# +# CG_revcom with Python dictionary +############################################# + +def cg(linop, b, binding, x=None, eps=None, maxiter=None, callback=None, check_cond=False): + if binding not in ("torch", "numpy", "pytorch"): + raise ValueError( + "Language not supported, please use numpy, torch or pytorch.") + + tools = get_tools(binding) + + # CUDA is only relevant for the torch bindings; the numpy backend stays on CPU + is_cuda = True if (binding == 'torch' or binding == + 'pytorch') and torch.cuda.is_available() else False + device = torch.device("cuda") if is_cuda else torch.device('cpu') + + b, x, replaced = check_dims(b, x, tools, is_cuda) + n, m = b.shape + + if eps == None: + eps = 1e-6 * sqrt((b ** 2).sum()) + + if maxiter == None: + maxiter = 10 * n + + if check_cond: + from pykeops.common.power_iteration import bootleg_inv_power_cond_big as cond_big + cond_too_big = cond_big(linop, n, binding, device) + if cond_too_big: + warnings.warn( + "Warning ----------- Condition number might be too large.") + + # define the functions needed along the iterations + if binding == "numpy": + p, q, r = tools.zeros((n, m), dtype=b.dtype), tools.zeros( + (n, m), dtype=b.dtype), tools.zeros((n, m), dtype=b.dtype) + scal1, scal2 = tools.zeros(1, dtype=b.dtype), tools.zeros( + 1, dtype=b.dtype) # init the scalar values + + else: + p, q, r = tools.zeros((n, m), dtype=b.dtype, device=device), tools.zeros( + (n, m), dtype=b.dtype, device=device), tools.zeros((n, m), dtype=b.dtype, device=device) + scal1, scal2 = tools.zeros(1, dtype=b.dtype, device=device), tools.zeros( + 1, dtype=b.dtype, device=device) # init the scalar values + + def init_iter(linop, x, r, p, q, b, scal1, scal2, eps, replaced, 
iter_): # revc -> cg + r = tools.copy(b) if replaced else (b - linop(x)) + scal1 = (r ** 2).sum() + job_cg = "check" + return job_cg, x, r, p, q, scal1, scal2, iter_ + + def check_resid(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # cg -> revc + if scal1 <= eps**2 or scal1 != scal1: + job_rev = "stop" + else: + iter_ += 1 + job_rev = "direction_next" if iter_ > 1 else "direction_first" + return job_rev, x, r, p, q, scal1, scal2, iter_ + + def first_direct(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # revc -> cg + p = tools.copy(r) + job_cg = "matvec_p" + return job_cg, x, r, p, q, scal1, scal2, iter_ + + def matvec_p(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # cg -> revc + q = linop(p) + job_rev = "update" + return job_rev, x, r, p, q, scal1, scal2, iter_ + + def update(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # revc -> cg + alpha = scal1 / (p * q).sum() + x += alpha * p + r -= alpha * q + scal2 = scal1 + job_cg = "check" + return job_cg, x, r, p, q, scal1, scal2, iter_ + + def next_direct(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # revc -> cg + scal1 = (r ** 2).sum() + p = r + (scal1 / scal2) * p + job_cg = "matvec_p" + return job_cg, x, r, p, q, scal1, scal2, iter_ + + jobs_cg = {"matvec_p": matvec_p, + "check": check_resid + } + + jobs_revcom = { + "init": init_iter, + "update": update, + "direction_first": first_direct, + "direction_next": next_direct + } + + iter_ = 0 + job_rev = "init" + job_cg = None + + while iter_ <=maxiter: + if job_cg == "check" and callback is not None: + if iter_ > 1: + callback(x) + job_cg, x, r, p, q, scal1, scal2, iter_ = jobs_revcom[job_rev]( + linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_) + job_rev, x, r, p, q, scal1, scal2, iter_ = jobs_cg[job_cg]( + linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_) + + if job_rev == "stop": + break + if (iter_ - 1) == maxiter: + warnings.warn("Warning ----------- Conjugate gradient reached 
maximum iteration !") + + return x, iter_ + + +#################################### +# Sanity checks +#################################### + +def check_dims(b, x, tools, cuda_avlb): # x is always of b's shape + try: + nrow, ncol = b.shape + except ValueError: + b = b.reshape(-1, 1) + nrow, ncol = b.shape + + x_replaced = False + + if x is None: # check x shape and initiate it if needed + x = tools.zeros((nrow, ncol), dtype=b.dtype, device=torch.device('cuda')) if cuda_avlb \ + else tools.zeros((nrow, ncol), dtype=b.dtype) + x_replaced = True + elif (nrow, ncol) != x.shape: # add sth to check if x is on the same device as b if torch is used! + if x.shape == (nrow,): + x = x.reshape((nrow, ncol)) + else: + raise ValueError("Mismatch between shapes of b {} and shape of x {}.".format( + (nrow, nrow), x.shape)) + if x.dtype != b.dtype: + raise ValueError("Type of given x {} is not compatible with type of b {}.".format(x.dtype, b.dtype)) + + return b, x, x_replaced diff --git a/pykeops/common/operations.py b/pykeops/common/operations.py index 95a302921..42b44af86 100644 --- a/pykeops/common/operations.py +++ b/pykeops/common/operations.py @@ -1,4 +1,5 @@ import numpy as np +import warnings from pykeops.common.utils import get_tools @@ -72,20 +73,25 @@ def postprocess(out, binding, reduction_op, nout, opt_arg, dtype): return out -def ConjugateGradientSolver(binding, linop, b, eps=1e-6): +def ConjugateGradientSolver(binding, linop, b, eps=1e-6, callback=None, maxiter=None): # Conjugate gradient algorithm to solve linear system of the form # Ma=b where linop is a linear operation corresponding # to a symmetric and positive definite matrix + if binding not in ("torch", "numpy", "pytorch"): + raise ValueError( + "Language not supported, please use numpy, torch or pytorch.") tools = get_tools(binding) delta = tools.size(b) * eps ** 2 + if maxiter == None: + maxiter = 10 * tools.size(b) a = 0 r = tools.copy(b) nr2 = (r ** 2).sum() if nr2 < delta: return 0 * r p = tools.copy(r) 
- k = 0 - while True: + k = 1 + while k <= maxiter: Mp = linop(p) alp = nr2 / (p * Mp).sum() a += alp * p @@ -96,6 +102,10 @@ def ConjugateGradientSolver(binding, linop, b, eps=1e-6): p = r + (nr2new / nr2) * p nr2 = nr2new k += 1 + if callback is not None: + callback(a) + if k == maxiter: + warnings.warn("Warning ----------- Conjugate gradient reached maximum iteration !") return a diff --git a/pykeops/common/power_iteration.py b/pykeops/common/power_iteration.py new file mode 100644 index 000000000..dea59be5c --- /dev/null +++ b/pykeops/common/power_iteration.py @@ -0,0 +1,80 @@ +import torch +import numpy as np +import warnings +from math import sqrt + +from pykeops.common.utils import get_tools +from pykeops.common.cg import cg + +##################### +# Power iteration +##################### + + +def random_draw_np(size, device, dtype='float32'): + return np.random.rand(size, 1).astype(dtype) + + +def random_draw_torch(size, device, dtype=torch.float32): + return torch.rand(size, 1, device=device, dtype=dtype) + + +def power_it_ray(linop, size, binding, device, eps=1e-6): + r""" Compute the eigenvalue of maximum magnitude for a linear operator. + + Args: + linop: a linear operator that, when called, computes the matrix-vector product. + size (int): dimension of the linear operator. + binding (string): torch, pytorch or numpy. + device (torch.device): use GPU or CPU, ``torch.device('cuda')`` or ``torch.device("cpu")``* + for example. + + Keyword Args: + eps (float, default=1e-6): precision for the acceptable distance between two iterates of + the eigenvalue. + + Returns: + lambd_ (float): the eigenvalue of maximum magnitude for ``linop``. + A warning is displayed if the algorithm didn't converge. 
+ """ + random = random_draw_np if binding == "numpy" else random_draw_torch + x = random(size, device) + x = x / sqrt((x ** 2).sum()) + maxiter = 10 * size + k = 0 + while k <= maxiter: + y = linop(x) + norm_y = sqrt((y ** 2).sum()) + z = y / norm_y + lambd_ = (z.T @ linop(z)) + if k > 0 and (old_lambd - lambd_) ** 2 <= eps ** 2: + break + old_lambd = lambd_ + x = z + k += 1 + if (k - 1) == maxiter: + warnings.warn( + "Warning ----------- Power iteration method did not converge !") + return lambd_ + + +def bootleg_inv_power_cond_big(linop, size, binding, device, maxcond=500, maxiter=50): + lambda_max = power_it_ray(linop, size, binding, device) + thresh = lambda_max / maxcond + k = 0 + vp = [maxcond] + random = random_draw_np if binding == "numpy" else random_draw_torch + x = random(size, device) + while k <= maxiter: + x = cg(linop, x, binding)[0] + x = x / sqrt((x ** 2).sum()) + vp.append(x.T @ linop(x)) + if vp[k] <= thresh: + cond_too_big = True + break + if k >=1 and (vp[k]-vp[k-1]) ** 2 <= 1e-10: #cv + k = maxiter #exit + k += 1 + if (k - 1) == maxiter: + cond_too_big = False + return cond_too_big diff --git a/pykeops/numpy/operations.py b/pykeops/numpy/operations.py index da68a27f7..c0b3c2d18 100644 --- a/pykeops/numpy/operations.py +++ b/pykeops/numpy/operations.py @@ -7,6 +7,7 @@ from pykeops.common.utils import axis2cat from pykeops.numpy import default_dtype +from pykeops.common.cg import cg class KernelSolve: r""" @@ -135,7 +136,7 @@ def __init__(self, formula, aliases, varinvalias, axis=0, dtype=default_dtype, o varinvpos = tmp.index(varinvalias) self.varinvpos = varinvpos - def __call__(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=1e-6, ranges=None): + def __call__(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=1e-6, ranges=None, callback=None): r""" To apply the routine on arbitrary NumPy arrays. 
@@ -181,6 +182,9 @@ def __call__(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=1e-6, r as we loop over all indices :math:`i\in[0,M)` and :math:`j\in[0,N)`. + callback (function, default=None): function of x called at the end of + each iteration of the conjugate gradient. + Returns: (M,D) or (N,D) array: @@ -203,4 +207,33 @@ def linop(var): res += alpha * var return res - return ConjugateGradientSolver('numpy', linop, varinv, eps=eps) + return ConjugateGradientSolver('numpy', linop, varinv, eps=eps, callback=callback) + + def cg(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=None, ranges=None, check_cond=False, callback=None): + r""" + Another version of the conjugate gradient. Args and keywords args are the same + as calling ``KernelSolve``. Only one keyword is added. + + Keyword Args: + check_cond (boolean, False by default): Indicates if the condition number + might be greater than 500. *Warning: setting it to True will + result in a more time-consuming method.* + + Returns: + A tuple containing the (M,D) or (N,D) array being the approximated + solution of the problem and the iteration number the algorithm stopped. 
+ + """ + tagCpuGpu, tag1D2D, _ = get_tag_backend(backend, args) + varinv = args[self.varinvpos] + + if ranges is None: ranges = () # ranges should be encoded as a tuple + + def linop(var): + newargs = args[:self.varinvpos] + (var,) + args[self.varinvpos + 1:] + res = self.myconv.genred_numpy(tagCpuGpu, tag1D2D, 0, device_id, ranges, *newargs) + if alpha: + res += alpha * var + return res + + return cg(linop, varinv, 'numpy', eps=eps, callback=callback, check_cond=check_cond) diff --git a/pykeops/test/unit_tests_numpy.py b/pykeops/test/unit_tests_numpy.py index 8db03f8aa..d88a90910 100644 --- a/pykeops/test/unit_tests_numpy.py +++ b/pykeops/test/unit_tests_numpy.py @@ -302,6 +302,46 @@ def test_LazyTensor_sum(self): self.assertTrue(res_keops.shape == res_numpy.shape) self.assertTrue(np.allclose(res_keops, res_numpy, atol=1e-3)) + ############################################################ + def test_cg_dic(self): + ############################################################ + from pykeops.numpy import KernelSolve, Genred + formula = 'Exp(- g * SqDist(x,y)) * a' + aliases = ['x = Vi(3)', # First arg: i-variable of size D + 'y = Vj(3)', # Second arg: j-variable of size D + 'a = Vj(1)', # Third arg: j-variable of size Dv + 'g = Pm(1)'] + + K = Genred(formula, aliases, axis=1, dtype=self.type_to_test[1]) + Kinv = KernelSolve(formula, aliases, "a", axis=1, + dtype=self.type_to_test[1]) + + ans = Kinv.cg(self.x, self.x, self.f, + self.sigma, alpha=self.sigma)[0] + err = ((self.sigma * ans + K(self.x, self.x, + ans, self.sigma) - self.f) ** 2).sum() + self.assertTrue(np.allclose(err, np.zeros(err.shape))) + + ############################################################ + def test_cg_call(self): + ############################################################ + from pykeops.numpy import KernelSolve, Genred + formula = 'Exp(- g * SqDist(x,y)) * a' + aliases = ['x = Vi(3)', # First arg: i-variable of size D + 'y = Vj(3)', # Second arg: j-variable of size D + 'a = Vj(1)', # 
Third arg: j-variable of size Dv + 'g = Pm(1)'] + + K = Genred(formula, aliases, axis=1, dtype=self.type_to_test[1]) + Kinv = KernelSolve(formula, aliases, "a", axis=1, + dtype=self.type_to_test[1]) + + ans = Kinv(self.x, self.x, self.f, + self.sigma, alpha=self.sigma) + err = ((self.sigma * ans + K(self.x, self.x, + ans, self.sigma) - self.f) ** 2).sum() + self.assertTrue(np.allclose(err, np.zeros(err.shape))) + if __name__ == '__main__': unittest.main() diff --git a/pykeops/test/unit_tests_pytorch.py b/pykeops/test/unit_tests_pytorch.py index 3fc669f06..23fb14ed7 100644 --- a/pykeops/test/unit_tests_pytorch.py +++ b/pykeops/test/unit_tests_pytorch.py @@ -556,7 +556,41 @@ def invert_permutation_numpy(permutation): grad_torch = torch.autograd.grad(sum_f_torch2, y, e)[0] self.assertTrue(torch.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4)) - + ############################################################ + def test_cg_dic(self): + ############################################################ + from pykeops.torch import KernelSolve, Genred + formula = 'Exp(- g * SqDist(x,y)) * a' + aliases = ['x = Vi(3)', # First arg: i-variable of size D + 'y = Vj(3)', # Second arg: j-variable of size D + 'a = Vj(1)', # Third arg: j-variable of size Dv + 'g = Pm(1)'] + K = Genred(formula, aliases, axis=1, dtype="float32") + Kinv = KernelSolve(formula, aliases, "a", axis=1, + dtype="float32") + ans = Kinv.cg(self.xc, self.xc, self.fc, + self.sigmac, alpha=self.sigmac)[0] + err = ((self.sigmac * ans + K(self.xc, self.xc, + ans, self.sigmac) - self.fc) ** 2).sum() + self.assertTrue(np.allclose(err.cpu().data.numpy(), np.zeros(err.shape))) + + ############################################################# + def test_cg(self): + ############################################################ + from pykeops.torch import KernelSolve, Genred + formula = 'Exp(- g * SqDist(x,y)) * a' + aliases = ['x = Vi(3)', # First arg: i-variable of size D + 'y = Vj(3)', # Second arg: j-variable 
of size D + 'a = Vj(1)', # Third arg: j-variable of size Dv + 'g = Pm(1)'] + K = Genred(formula, aliases, axis=1, dtype="float32") + Kinv = KernelSolve(formula, aliases, "a", axis=1, + dtype="float32") + ans = Kinv(self.xc, self.xc, self.fc, + self.sigmac, alpha=self.sigmac) + err = ((self.sigmac * ans + K(self.xc, self.xc, + ans, self.sigmac) - self.fc) ** 2).sum() + self.assertTrue(np.allclose(err.cpu().data.numpy(), np.zeros(err.shape))) if __name__ == '__main__': """ diff --git a/pykeops/torch/__init__.py b/pykeops/torch/__init__.py index 506671ce4..4a760a5a6 100644 --- a/pykeops/torch/__init__.py +++ b/pykeops/torch/__init__.py @@ -33,7 +33,6 @@ from .generic.generic_ops import generic_sum, generic_logsumexp, generic_argmin, generic_argkmin from .kernel_product.formula import Formula from pykeops.common.lazy_tensor import LazyTensor, Vi, Vj, Pm - # N.B.: If "from pykeops.numpy import LazyTensor" has already been run, # the line above will *not* import "torchtools" and we'll end up with an error... 
# So even though it may be a bit ugly, we re-load the lazy_tensor file @@ -43,5 +42,5 @@ importlib.reload(pykeops.common.lazy_tensor) __all__ = sorted( - ["Genred", "generic_sum", "generic_logsumexp", "generic_argmin", "generic_argkmin", "Kernel", "kernel_product", + ["cg", "Genred", "generic_sum", "generic_logsumexp", "generic_argmin", "generic_argkmin", "Kernel", "kernel_product", "KernelSolve", "kernel_formulas", "Formula", "LazyTensor", "Vi", "Vj", "Pm"]) diff --git a/pykeops/torch/operations.py b/pykeops/torch/operations.py index b5931552c..a9acf36b6 100644 --- a/pykeops/torch/operations.py +++ b/pykeops/torch/operations.py @@ -8,6 +8,7 @@ from pykeops.torch import default_dtype from pykeops.torch import include_dirs from pykeops.torch.generic.generic_red import GenredAutograd +from pykeops.common.cg import cg class KernelSolveAutograd(torch.autograd.Function): @@ -47,7 +48,6 @@ def forward(ctx, formula, aliases, varinvpos, alpha, backend, dtype, device_id, for i in range(1,len(args)): if args[i].device.index != device_id: raise ValueError("[KeOps] Input arrays must be all located on the same device.") - def linop(var): newargs = args[:varinvpos] + (var,) + args[varinvpos+1:] res = myconv.genred_pytorch(tagCPUGPU, tag1D2D, tagHostDevice, device_id, ranges, *newargs) @@ -61,11 +61,11 @@ def linop(var): # relying on the 'ctx.saved_variables' attribute is necessary if you want to be able to differentiate the output # of the backward once again. It helps pytorch to keep track of 'who is who'. 
ctx.save_for_backward(*args, result) - return result @staticmethod def backward(ctx, G): + formula = ctx.formula aliases = ctx.aliases varinvpos = ctx.varinvpos @@ -77,7 +77,6 @@ def backward(ctx, G): myconv = ctx.myconv ranges = ctx.ranges accuracy_flags = ctx.accuracy_flags - args = ctx.saved_tensors[:-1] # Unwrap the saved variables nargs = len(args) result = ctx.saved_tensors[-1] @@ -105,7 +104,7 @@ def backward(ctx, G): if var_ind == varinvpos: grads.append(KinvG) else: - # adding new aliases is way too dangerous if we want to compute + #adding new aliases is way too dangerous if we want to compute # second derivatives, etc. So we make explicit references to Var instead. # New here (Joan) : we still add the new variables to the list of "aliases" (without giving new aliases for them) # these will not be used in the C++ code, @@ -136,8 +135,6 @@ def backward(ctx, G): # Grads wrt. formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags, *args return (None, None, None, None, None, None, None, None, None, None, *grads) - - class KernelSolve(): r""" Creates a new conjugate gradient solver. @@ -325,8 +322,120 @@ def __call__(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=1e-6, r that is inferred from the **formula**. """ - return KernelSolveAutograd.apply(self.formula, self.aliases, self.varinvpos, alpha, backend, self.dtype, device_id, eps, ranges, self.accuracy_flags, *args) + def cg(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=None, check_cond=False, callback=None, ranges=None): + r""" + Same as calling ``KernelSolve``. The keyword argument `check_cond` being added. + + Keyword Args: + check_cond (boolean, default=False): Indicates if the condition number + **might be** greater than 500. 
*Warning: setting it to True will + result in a more time-consuming method.* + + Returns: + A tuple of tensors containing the (M,D) or (N,D) tensor being the approximated + solution of the problem and the iteration number the algorithm stopped. + + """ + return dic_KernelSolveAutograd.apply(self.formula, self.aliases, self.varinvpos, alpha, backend, self.dtype, device_id, eps, ranges, self.accuracy_flags, check_cond, callback, *args) + +class dic_KernelSolveAutograd(torch.autograd.Function): + + @staticmethod + def forward(ctx, formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags, check_cond, callback, *args): + + optional_flags = include_dirs + accuracy_flags + + myconv = LoadKeOps(formula, aliases, dtype, 'torch', + optional_flags).import_module() + + # Context variables: save everything to compute the gradient: + ctx.formula = formula + ctx.aliases = aliases + ctx.varinvpos = varinvpos + ctx.alpha = alpha + ctx.backend = backend + ctx.dtype = dtype + ctx.device_id = device_id + ctx.check_cond = check_cond + ctx.eps = eps + ctx.myconv = myconv + ctx.ranges = ranges + ctx.callback = callback + ctx.accuracy_flags = accuracy_flags + if ranges is None: ranges = () # To keep the same type + + varinv = args[varinvpos] + ctx.varinvpos = varinvpos + + tagCPUGPU, tag1D2D, tagHostDevice = get_tag_backend(backend, args) + + if tagCPUGPU==1 & tagHostDevice==1: + device_id = args[0].device.index + for i in range(1,len(args)): + if args[i].device.index != device_id: + raise ValueError("[KeOps] Input arrays must be all located on the same device.") + + def linop(var): + newargs = args[:varinvpos] + (var,) + args[varinvpos+1:] + res = myconv.genred_pytorch(tagCPUGPU, tag1D2D, tagHostDevice, device_id, ranges, *newargs) + if alpha: + res += alpha*var + return res + global copy + + result, iter_ = cg(linop, varinv.data, 'torch', eps=eps, check_cond=check_cond, callback=callback) + ctx.save_for_backward(*args, result) + + return result, 
torch.as_tensor(iter_) + + @staticmethod + def backward(ctx, G, G2): + + formula = ctx.formula + aliases = ctx.aliases + varinvpos = ctx.varinvpos + backend = ctx.backend + alpha = ctx.alpha + dtype = ctx.dtype + device_id = ctx.device_id + eps = ctx.eps + myconv = ctx.myconv + ranges = ctx.ranges + accuracy_flags = ctx.accuracy_flags + check_cond = ctx.check_cond + callback = ctx.callback + args = ctx.saved_tensors[:-1] # Unwrap the saved variables + nargs = len(args) + result = ctx.saved_tensors[-1] + eta = 'Var(' + str(nargs) + ',' + str(myconv.dimout) + ',' + str(myconv.tagIJ) + ')' + + # there is also a new variable for the formula's output + resvar = 'Var(' + str(nargs+1) + ',' + str(myconv.dimout) + ',' + str(myconv.tagIJ) + ')' + newargs = args[:varinvpos] + (G,) + args[varinvpos+1:] + KinvG = dic_KernelSolveAutograd.apply(formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags, check_cond, callback, *newargs) + grads = [] # list of gradients wrt. args; + for (var_ind, sig) in enumerate(aliases): + if not ctx.needs_input_grad[var_ind + 12]: # because of (formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags) + grads.append(None) # Don't waste time computing it. 
+ else: + if var_ind == varinvpos: + grads.append(KinvG) + else: + _, cat, dim, pos = get_type(sig, position_in_list=var_ind) + var = 'Var(' + str(pos) + ',' + str(dim) + ',' + str(cat) + ')' # V + formula_g = 'Grad_WithSavedForward(' + formula + ', ' + var + ', ' + eta + ', ' + resvar + ')' # Grad + aliases_g = aliases + [eta, resvar] + args_g = args[:varinvpos] + (result,) + args[varinvpos+1:] + (-KinvG[0],) + (result,) + genconv = GenredAutograd().apply + if cat == 2: + grad = genconv(formula_g, aliases_g, backend, dtype, device_id, ranges, accuracy_flags, *args_g) + grad = torch.ones(1, grad.shape[0]).type_as(grad.data) @ grad + grad = grad.view(-1) + else: + grad = genconv(formula_g, aliases_g, backend, dtype, device_id, ranges, accuracy_flags, *args_g) + grads.append(grad) + return (None, None, None, None, None, None, None, None, None, None, None, None, *grads) \ No newline at end of file