From 3dfa020627d81bdf493b5baaccfffd08d156ac6e Mon Sep 17 00:00:00 2001 From: Tanguy Lefort Date: Tue, 28 Jul 2020 14:56:06 +0200 Subject: [PATCH 1/4] Conjugate Gradient Features --- doc/conf.py | 5 + pykeops/benchmarks/plot_benchmark_cg.py | 523 ++++++++++++++++++ .../plot_benchmark_cg_dimensions.py | 400 ++++++++++++++ .../benchmarks/plot_benchmark_invkernel.py | 11 +- pykeops/common/cg.py | 148 +++++ pykeops/common/operations.py | 16 +- pykeops/common/power_iteration.py | 80 +++ pykeops/numpy/operations.py | 37 +- pykeops/test/unit_tests_numpy.py | 40 ++ pykeops/test/unit_tests_pytorch.py | 36 +- pykeops/torch/__init__.py | 3 +- pykeops/torch/operations.py | 126 ++++- 12 files changed, 1405 insertions(+), 20 deletions(-) create mode 100644 pykeops/benchmarks/plot_benchmark_cg.py create mode 100644 pykeops/benchmarks/plot_benchmark_cg_dimensions.py create mode 100644 pykeops/common/cg.py create mode 100644 pykeops/common/power_iteration.py diff --git a/doc/conf.py b/doc/conf.py index 29b5fd19b..72b41e7a3 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -114,6 +114,11 @@ def skip(app, what, name, obj, would_skip, options): def setup(app): app.connect("autodoc-skip-member", skip) +import warnings + +warnings.filterwarnings("ignore", category=UserWarning, + message='Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.') + # Include the example source for plots in API docs # plot_include_source = True # plot_formats = [("png", 90)] diff --git a/pykeops/benchmarks/plot_benchmark_cg.py b/pykeops/benchmarks/plot_benchmark_cg.py new file mode 100644 index 000000000..d11544a63 --- /dev/null +++ b/pykeops/benchmarks/plot_benchmark_cg.py @@ -0,0 +1,523 @@ +""" +Comparison of conjugate gradient methods +========================================== + +Different implementations of the conjugate gradient (CG) exist. Here, we compare the CG implemented in scipy which uses Fortran +against it's pythonized version and the older version of the algorithm available in pykeops. + +We want to solve the positive definite linear system :math:`(K_{x,x} + \\alpha Id)a = b` for :math:`a, b, x \in \mathbb R^N`. + +Let the Gaussian RBF kernel be defined as + +.. math:: + + K_{x,x}=\left[\exp(-\gamma \|x_i - x_j\|^2)\\right]_{i,j=1}^N. + + +Choosing :math:`x` such that :math:`x_i = i/N,\ i=1,\dots, N` makes :math:`K_{x,x}` be a highly unwell-conditioned matrix for :math:`N\geq 10`. + +""" + +############################# +# Setup +# ---------- +# Imports needed + +import importlib +import os +import time +import inspect + +import numpy as np +import torch +import matplotlib.pyplot as plt + +from scipy.sparse import diags +from scipy.sparse.linalg import aslinearoperator, cg + +from pykeops.numpy import KernelSolve as KernelSolve_np, LazyTensor +from pykeops.torch import KernelSolve +from pykeops.torch.utils import squared_distances +from pykeops.torch import Genred as Genred_tch +from pykeops.numpy import Vi, Vj, Pm +from pykeops.numpy import Genred as Genred_np + +use_cuda = torch.cuda.is_available() + +device = torch.device("cuda") if use_cuda else torch.device("cpu") +print("The device used is {}.".format(device)) + +######################################## +# Gaussian radial basis function kernel +######################################## + +formula = 'Exp(- g * SqDist(x,y)) * a' # linear w.r.t a +aliases = ['x = Vi(1)', # First arg: i-variable of size 1 + 'y = Vj(1)', # Second arg: j-variable of size 1 + 'a = Vj(1)', # Third arg: j-variable of size 1 + 'g = Pm(1)'] + + +############################ +# Functions to benchmark +########################### +# +# All systems are regularized with a ridge parameter ``alpha``. +# +# The originals : +# + + +def keops_tch(x, b, gamma, alpha): + Kinv = KernelSolve(formula, aliases, "a", axis=1, dtype='float32') + res = Kinv(x, x, b, gamma, alpha=alpha) + return res + + +def keops_np(x, b, gamma, alpha, callback=None): + Kinv = KernelSolve_np(formula, aliases, "a", axis=1, dtype='float32') + res = Kinv(x, x, b, gamma, alpha=alpha, callback=callback) + return res + + +#################################### +# Scipy : +# +# + + +def scipy_cg(x, b, gamma, alpha, callback=None): + K_ij = (-Pm(gamma) * Vi(x).sqdist(Vj(x))).exp() + A = aslinearoperator( + diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) + A.dtype = np.dtype('float32') + res = cg(A, b, callback=callback) + return res + + +#################################### +# Pythonized scipy : +# + + +def dic_cg_np(x, b, gamma, alpha, callback=None, check_cond=False): + Kinv = KernelSolve_np(formula, aliases, "a", axis=1, dtype='float32') + ans = Kinv.cg(x, x, b, gamma, alpha=alpha, + callback=callback, check_cond=check_cond) + return ans + + +def dic_cg_tch(x, b, gamma, alpha, check_cond=False): + Kinv = KernelSolve(formula, aliases, "a", axis=1, dtype='float32') + ans = Kinv.cg(x, x, b, gamma, alpha=alpha, check_cond=check_cond) + return ans + + +######################### +# Benchmarking +######################### + +functions = [(scipy_cg, "numpy"), + (keops_np, "numpy"), (keops_tch, "torch"), + (dic_cg_np, "numpy"), (dic_cg_tch, "torch")] + +sizes = [50, 100, 500, 1000, 5000, 20000, 40000] +reps = [50 , 50 , 50, 10, 10, 5, 5] + + +def compute_error(func, pack, result, errors, x, b, alpha, gamma): + if str(func)[10:15] == "keops": + code = "a = func(x, b, gamma, alpha).reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + else: + code = "a = func(x, b, gamma, alpha)[0].reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + + if pack == 'numpy': + K = Genred_np(formula, aliases, axis=1, dtype='float32') + else: + K = Genred_tch(formula, aliases, axis=1, dtype='float32') + + exec(code, locals()) + return errors + + +def to_bench(funcpack, size, rep): + global use_cuda + importlib.reload(torch) + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + code = "func(x, b, gamma, alpha)" + func, pack = funcpack + + times = [] + errors = [] + + if use_cuda: + torch.cuda.synchronize() + for i in range(rep): + + x = torch.linspace(1/size, 1, size, dtype=torch.float32, + device=device).reshape(size, 1) + b = torch.randn(size, 1, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + + if i == 0: + exec(code, locals()) # Warmup run, to compile and load everything + + start = time.perf_counter() + result = func(x, b, gamma, alpha) + if use_cuda: + torch.cuda.synchronize() + + times.append(time.perf_counter() - start) + errors = compute_error(func, pack, result, errors, x, b, alpha, gamma) + + return sum(times)/rep, sum(errors)/rep + + +def global_bench(functions, sizes, reps): + list_times = [[] for _ in range(len(functions))] + list_errors = [[] for _ in range(len(functions))] + + for j, one_to_bench in enumerate(functions): + print("~~~~~~~~~~~~~Benchmarking {}~~~~~~~~~~~~~~.".format(one_to_bench)) + for i in range(len(sizes)): + try: + time, err = to_bench(one_to_bench, sizes[i], reps[i]) + list_times[j].append(time) + list_errors[j].append(err) + except: + while len(list_times[j]) != len(reps): + list_times[j].append(np.nan) + list_errors[j].append(np.nan) + break + print("Finished size {}.".format(sizes[i])) + + print("Finished", one_to_bench[0], "in a cumulated time of {:3.9f}s.".format( + sum(list_times[j]))) + + return list_times, list_errors + + +######################################### +# Plot the results of the benchmarking +######################################### + +list_times, list_errors = global_bench(functions, sizes, reps) +labels = ["scipy + keops", "keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] + +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes, list_times[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes, list_errors[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel(r"Error $||Ax_{k_{end}} -b||^2$") +plt.legend() +plt.tight_layout() +plt.show() + + +############################################## +# Stability +# ------------ +# +# Stability of the errors and norm of the iterated approximations of the answer + + +def norm_stability(size, funcpack): + errk_scipy, iter_scipy, x_scipy = [], [], [] + errk_dic, iter_dic, x_dic = [], [], [] + errk_keops, iter_keops, x_keops = [], [], [] + + def callback_sci(xk): + env = inspect.currentframe().f_back + iter_scipy.append(env.f_locals['iter_']) + x_scipy.append(env.f_locals['x']) + err = ( ( alpha * xk.reshape(-1, 1) + K(x, x, xk.reshape(-1, 1), gamma) - b) ** 2).sum() + errk_scipy.append(err) + + def callback_kinv_keops(xk): + env = inspect.currentframe().f_back + err = ( ( alpha * xk + K(x, x, xk, gamma) - b) ** 2).sum() + errk_keops.append(err) + iter_keops.append(env.f_locals['k']) + x_keops.append(env.f_locals['a']) + + def callback_dic(xk): + env = inspect.currentframe().f_back + err = ( ( alpha * xk + K(x, x, xk, gamma) - b) ** 2).sum() + errk_dic.append(err) + iter_dic.append(env.f_locals['iter_']) + x_dic.append(env.f_locals['x']) + + callback_list = [callback_sci, callback_kinv_keops, callback_dic] + + for i, funcpack in enumerate(funcpack): + fun, pack = funcpack + + global x, b, gamma, alpha, K + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + + x = torch.linspace(1/size, 1, size, dtype=torch.float32, + device=device).reshape(size, 1) + b = torch.randn(size, 1, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + K = Genred_np(formula, aliases, axis=1, dtype='float32') + else: + K = Genred_tch(formula, aliases, axis=1, dtype='float32') + + fun(x, b, gamma, alpha, callback=callback_list[i]) + del x, b, gamma, alpha, K + return errk_scipy, iter_scipy, x_scipy, errk_dic, iter_dic, x_keops, errk_keops, iter_keops, x_dic + + +######################################### +# Plot the results of the stability +######################################### + +onlynum = [(scipy_cg, "numpy"), (keops_np, "numpy"), (dic_cg_np, "numpy")] +errk_scipy, iter_scipy, x_scipy, errk_dic, iter_dic,\ + x_keops, errk_keops, iter_keops, x_dic = norm_stability( + 1000, onlynum) + +scal_dic, scal_keops, scal_scipy = [], [], [] +for i in range(1,len(iter_dic)): + scal_dic.append((x_dic[i-1].T @ x_dic[i]).flatten()) +for i in range(1, len(iter_keops)): + scal_keops.append((x_keops[i-1].T @ x_keops[i]).flatten()) +for i in range(1, len(iter_scipy)): + scal_scipy.append((x_scipy[i-1].T @ x_scipy[i]).flatten()) + +plt.figure(figsize=(20,10)) +plt.subplot(121) +plt.plot(iter_keops, errk_keops, 'o-', label=labels[1]) +plt.plot(iter_scipy, errk_scipy, '^-', label=labels[0]) +plt.plot(iter_dic, errk_dic, 'x-', label=labels[3]) +plt.yscale('log') +plt.xlabel(r"Iteration k") +plt.ylabel(r"$||(\alpha\ Id + K_{x,x})x_k - b||^2$") +plt.legend() + +plt.subplot(122) +plt.plot(iter_keops[1:], scal_keops, 'o-', label=labels[1]) +plt.plot(iter_scipy[1:], scal_scipy, '^-', label=labels[0]) +plt.plot(iter_dic[1:], scal_dic, 'x-', label=labels[3]) +plt.yscale('log') +plt.xlabel(r"Iteration k") +plt.ylabel(r"$\langle x_{k-1}|x_k\rangle $") +plt.legend() + +plt.tight_layout() +plt.show() + +####################################################### +# Condition number check +# ------------------------------------- +# +# Scipy's algorithm can't be used practically for large kernels in this case. The condition number can be why. +# +# +# The argument ``check_cond`` in Keops lets the user have an idea of the conditioning number of the matrix :math:`A=(K_{x,x} + \alpha Id)`. A warning appears +# if :math:`\mathrm{cond}(A)>500`. The user is also warned if the CG algorithm reached its maximum number of iterations *ie* did not converge. The idea here +# is not to estimate the condition number and let the user have another sanity check at disposal. +# +# To test the condition number :math:`\mathrm{cond}(A)=\frac{\lambda_{\max}}{\lambda_{\min}}`, we first use the +# power iteration to have a good estimation of :math:`\lambda_{\max}`. Then, wee apply the inverse power iteration +# to obtain the iterations :math:`\mu_k` of the estimated :math:`\lambda_{\min}` using the Rayleigh's quotient after having the iterations :math:`u_k` +# of the estimated eigen vector :math:`u_1`. The distance between the vectors :math:`v_k` and :math:`u_1` decreasing over the iterations at a rate of +# :math:`\mathcal{O}\left(\left|\frac{\lambda_{\min}}{\lambda_{submin}}\right|^k\right)`, if we don't want +# :math:`\frac{\lambda_{\max}}{\lambda_{\min}}>500` then :math:`\mu_k` must not be below the threshold :math:`\frac{\lambda_{\max}}{500}` +# If so, the system warns the user that the condition number might be too high. +# +# In practice only a few iterations are necessary to go below this threshold. Thus we fixed a maximum number of iterations for the inverse +# power method to ``50`` so that for large matrices it doesn't take too much time. + + +def test_cond(device, size, pack, alpha): + if device == 'cuda': + torch.cuda.manual_seed_all(1234) + else: + torch.manual_seed(1234) + + x = torch.linspace(1/size, 1, size, dtype=torch.float32, + device=device).reshape(size, 1) + b = torch.randn(size, 1, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones(1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + alpha = torch.ones(1, device=device, dtype=torch.float32) * alpha # regularization + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + ans = dic_cg_np(x, b, gamma, alpha, check_cond=True) + else: + ans = dic_cg_tch(x, b, gamma, alpha, check_cond=True) + return ans + + +print("Condition number warnings tests") +print("Small matrix well conditioned (nothing should appear)") +ans = test_cond(device, 20, 'numpy', alpha=1) +print("Large matrix unwell conditioned (a warning should appear)") +ans2 = test_cond(device, 1000, 'numpy', alpha=1e-6) +print("Large matrix unwell conditioned but with a large regularization (nothing should appear)") +ans3 = test_cond(device, 1000, 'numpy', alpha=100) + + +########################## +# Zoom in on Keops times +############################ +# +# Let's consider the Keops conjugate gradients for large kernels. Scipy's algorithm explodes in time for +# :math:`n\geq 50000` so we only consider the keops implementations here. +# + + +functions = functions[1:] +sizes = [30000, 50000, 100000, 200000] +reps = [5, 5, 5, 2] +list_times, list_errors = global_bench(functions, sizes, reps) +labels = ["keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes, list_times[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes, list_errors[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel(r"Error $||Ax_{k_{end}} -b||^2$") +plt.legend() +plt.show() + +########################### +# Random points +########################### +# +# Let's now use random values for :math:`x_i`. + +def to_bench(funcpack, size, rep): + global use_cuda + importlib.reload(torch) + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + code = "func(x, b, gamma, alpha)" + func, pack = funcpack + + times = [] + errors = [] + + if use_cuda: + torch.cuda.synchronize() + for i in range(rep): + + x = torch.linspace(1/size, 1, size, dtype=torch.float32, + device=device).reshape(size, 1) + b = torch.randn(size, 1, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + + if i == 0: + exec(code, locals()) # Warmup run, to compile and load everything + + start = time.perf_counter() + result = func(x, b, gamma, alpha) + if use_cuda: + torch.cuda.synchronize() + + times.append(time.perf_counter() - start) + errors = compute_error(func, pack, result, errors, x, b, alpha, gamma) + + return sum(times)/rep, sum(errors)/rep + + +functions = [(scipy_cg, "numpy"), + (keops_np, "numpy"), (keops_tch, "torch"), + (dic_cg_np, "numpy"), (dic_cg_tch, "torch")] + + +sizes = [50, 100, 500, 1000, 5000, 20000, 40000] +reps = [50 , 50 , 50, 10, 10, 5, 5] + +list_times, list_errors = global_bench(functions, sizes, reps) +labels = ["scipy + keops", "keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] + +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes, list_times[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes, list_errors[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel of size $n\times n$") +plt.ylabel(r"Error $||Ax_{k_{end}} -b||^2$") +plt.legend() +plt.tight_layout() +plt.show() \ No newline at end of file diff --git a/pykeops/benchmarks/plot_benchmark_cg_dimensions.py b/pykeops/benchmarks/plot_benchmark_cg_dimensions.py new file mode 100644 index 000000000..ae890fd32 --- /dev/null +++ b/pykeops/benchmarks/plot_benchmark_cg_dimensions.py @@ -0,0 +1,400 @@ +""" +Conjugate gradient method in arbitrary dimension +==================================================== + +Different implementations of the conjugate gradient (CG) exist. Here, we compare the CG implemented in scipy which uses Fortran +against it's pythonized version and the older version of the algorithm available in pykeops. + +We want to solve the positive definite linear system :math:`(K_{x,x} + \\alpha Id)a = b` for :math:`a, b\in \mathbb R^N` and :math:`x\in\mathbb R^{N\times d}`. +We will use :math:`N=100000` points. + +Let the Gaussian RBF kernel be defined as + +.. math:: + + K_{x,x}=\left[\exp(-\gamma \|x_i - x_j\|^2)\\right]_{i,j=1}^N. + + +The case where :math:`d=1` is already benchmarked in a very ill-conditioned situation, now let's compare when :math:`d` increases. +""" + +############################# +# Setup +# ---------- +# Imports needed + +import importlib +import os +import time +import inspect + +import numpy as np +import torch +import matplotlib.pyplot as plt + +from scipy.sparse import diags +from scipy.sparse.linalg import aslinearoperator, cg + +from pykeops.numpy import KernelSolve as KernelSolve_np, LazyTensor +from pykeops.torch import KernelSolve +from pykeops.torch.utils import squared_distances +from pykeops.torch import Genred as Genred_tch +from pykeops.numpy import Vi, Vj, Pm +from pykeops.numpy import Genred as Genred_np + +use_cuda = torch.cuda.is_available() + +device = torch.device("cuda") if use_cuda else torch.device("cpu") +print("The device used is {}.".format(device)) + +######################################## +# Gaussian radial basis function kernel +######################################## + +n = 100000 +dv = 1 # number of systems to solve +formula = 'Exp(- g * SqDist(x,y)) * a' # linear w.r.t a + + +############################ +# Functions to benchmark +########################### +# +# All systems are regularized with a ridge parameter ``alpha``. +# +# The originals : +# + + +def keops_tch(x, b, gamma, alpha, aliases, callback=None): + Kinv = KernelSolve(formula, aliases, "a", axis=1, dtype='float32') + res = Kinv(x, x, b, gamma, alpha=alpha) + return res + + +def keops_np(x, b, gamma, alpha, aliases, callback=None): + Kinv = KernelSolve_np(formula, aliases, "a", axis=1, dtype='float32') + res = Kinv(x, x, b, gamma, alpha=alpha, callback=callback) + return res + + +#################################### +# Scipy : +# +# + + +def scipy_cg(x, b, gamma, alpha, aliases, callback=None): + K_ij = (-Pm(gamma) * Vi(x).sqdist(Vj(x))).exp() + A = aslinearoperator( + diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) + A.dtype = np.dtype('float32') + res = cg(A, b, callback=callback) + return res + + +#################################### +# Pythonized scipy : +# + + +def dic_cg_np(x, b, gamma, alpha, aliases, callback=None, check_cond=False): + Kinv = KernelSolve_np(formula, aliases, "a", axis=1, dtype='float32') + ans = Kinv.cg(x, x, b, gamma, alpha=alpha, + callback=callback, check_cond=check_cond) + return ans + + +def dic_cg_tch(x, b, gamma, alpha, aliases, check_cond=False): + Kinv = KernelSolve(formula, aliases, "a", axis=1, dtype='float32') + ans = Kinv.cg(x, x, b, gamma, alpha=alpha, check_cond=check_cond) + return ans + + +######################### +# Benchmarking +######################### + +functions = [(scipy_cg, "numpy"), + (keops_np, "numpy"), (keops_tch, "torch"), + (dic_cg_np, "numpy"), (dic_cg_tch, "torch")] + +sizes_d = [10, 50, 75, 100, 150] # dimension of each point +reps = [5, 5, 5, 5, 5] + + +def compute_error(func, pack, result, errors, x, b, alpha, gamma, aliases): + if str(func)[10:15] == "keops": + code = "a = func(x, b, gamma, alpha, aliases).reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + else: + code = "a = func(x, b, gamma, alpha, aliases)[0].reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + + if pack == 'numpy': + K = Genred_np(formula, aliases, axis=1, dtype='float32') + else: + K = Genred_tch(formula, aliases, axis=1, dtype='float32') + + exec(code, locals()) + return errors + + +def to_bench(funcpack, d, rep): + importlib.reload(torch) + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + code = "func(x, b, gamma, alpha, aliases)" + func, pack = funcpack + + times = [] + errors = [] + + if use_cuda: + torch.cuda.synchronize() + + aliases = ['x = Vi(' + str(d) + ')', # First arg: i-variable of size d + 'y = Vj(' + str(d) + ')', # Second arg: j-variable of size d + 'a = Vj(' + str(dv) + ')', # Third arg: j-variable of size dv + 'g = Pm(1)'] + + for i in range(rep): + + x = torch.rand(n, d, device=device, dtype=torch.float32) + b = torch.randn(n, dv, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + + if i == 0: + exec(code, locals()) # Warmup run, to compile and load everything + + start = time.perf_counter() + result = func(x, b, gamma, alpha, aliases) + if use_cuda: + torch.cuda.synchronize() + + times.append(time.perf_counter() - start) + errors = compute_error(func, pack, result, errors, x, b, alpha, gamma, aliases) + + return sum(times)/rep, sum(errors)/rep + + +def global_bench(functions, sizes_d, reps): + list_times = [[] for _ in range(len(functions))] + list_errors = [[] for _ in range(len(functions))] + + for j, one_to_bench in enumerate(functions): + print("~~~~~~~~~~~~~Benchmarking {}~~~~~~~~~~~~~~.".format(one_to_bench)) + for i in range(len(sizes_d)): + try: + time, err = to_bench(one_to_bench, sizes_d[i], reps[i]) + list_times[j].append(time) + list_errors[j].append(err) + except: + while len(list_times[j]) != len(reps): + list_times[j].append(np.nan) + list_errors[j].append(np.nan) + break + print("Finished size {}.".format(sizes_d[i])) + + print("Finished", one_to_bench[0], "in a cumulated time of {:3.9f}s.".format( + sum(list_times[j]))) + + return list_times, list_errors + + +######################################### +# Plot the results of the benchmarking +######################################### + +list_times, list_errors = global_bench(functions, sizes_d, reps) +labels = ["scipy + keops", "keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] + +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes_d, list_times[i], label=labels[i]) +plt.xscale('log') +plt.ylim((1e-1, 1e2)) +plt.yscale('log') +plt.xlabel(r"Kernel made from {} points of size {} solving {} system.".format(n, 'd', dv)) +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes_d, list_errors[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.ylim((1e-10, 1e-7)) +plt.xlabel(r"Kernel made from {} points of size {} solving {} system.".format(n, 'd', dv)) +plt.ylabel(r"Error $||Ax_{k_{end}} - b||^2$") +plt.legend() +plt.tight_layout() +plt.show() + + +########################################## +# Changing the number of systems to solve +########################################## +# Let's consider the case where :math:`b\in\mathbb R^{dv}`. Then we need to solve multiple systems at once. + +n = 100000 +d = 4 # an image in RGBA for example + + +############################# +# Prepare the benchmarking : +# We need to modify the code for scipy's solver because scipy only solves one system at a time. + + +def scipy_cg_multi(x, b, gamma, alpha, aliases, callback=None): + K_ij = (-Pm(gamma) * Vi(x).sqdist(Vj(x))).exp() + ans = np.zeros(b.shape).astype('float32') + A = aslinearoperator( + diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) + A.dtype = np.dtype('float32') + for i in range(b.shape[1]): + res = cg(A, b[:, i], callback=callback) + ans[:, i] = res[0].flatten() + return ans + +def compute_error(func, pack, result, errors, x, b, alpha, gamma, aliases): + if str(func)[10:15] == "keops" or str(func)[10:15] == "scipy": + code = "a = func(x, b, gamma, alpha, aliases).reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + else: + code = "a = func(x, b, gamma, alpha, aliases)[0].reshape(b.shape);\ + err = ( (alpha * a + K(x, x, a, gamma) - b) ** 2).sum();\ + errors.append(err);" + + if pack == 'numpy': + K = Genred_np(formula, aliases, axis=1, dtype='float32') + else: + K = Genred_tch(formula, aliases, axis=1, dtype='float32') + + exec(code, locals()) + return errors + + +def to_bench(funcpack, dv, rep): + importlib.reload(torch) + if device == 'cuda': + torch.cuda.manual_seed_all(112358) + else: + torch.manual_seed(112358) + code = "func(x, b, gamma, alpha, aliases)" + func, pack = funcpack + + times = [] + errors = [] + + if use_cuda: + torch.cuda.synchronize() + + aliases = ['x = Vi(' + str(d) + ')', # First arg: i-variable of size d + 'y = Vj(' + str(d) + ')', # Second arg: j-variable of size d + 'a = Vj(' + str(dv) + ')', # Third arg: j-variable of size dv + 'g = Pm(1)'] + + for i in range(rep): + + x = torch.rand(n, d, device=device, dtype=torch.float32) + b = torch.randn(n, dv, device=device, dtype=torch.float32) + # kernel bandwidth + gamma = torch.ones( + 1, device=device, dtype=torch.float32) * .5 / .01 ** 2 + # regularization + alpha = torch.ones(1, device=device, dtype=torch.float32) * 2 + + if pack == 'numpy': + x, b = x.cpu().numpy().astype("float32"), b.cpu().numpy().astype("float32") + gamma, alpha = gamma.cpu().numpy().astype( + "float32"), alpha.cpu().numpy().astype("float32") + + if i == 0: + exec(code, locals()) # Warmup run, to compile and load everything + + start = time.perf_counter() + result = func(x, b, gamma, alpha, aliases) + if use_cuda: + torch.cuda.synchronize() + + times.append(time.perf_counter() - start) + errors = compute_error(func, pack, result, errors, x, b, alpha, gamma, aliases) + + return sum(times)/rep, sum(errors)/rep + + +def global_bench(functions, sizes_dv, reps): + list_times = [[] for _ in range(len(functions))] + list_errors = [[] for _ in range(len(functions))] + + for j, one_to_bench in enumerate(functions): + print("~~~~~~~~~~~~~Benchmarking {}~~~~~~~~~~~~~~.".format(one_to_bench)) + for i in range(len(sizes_d)): + try: + time, err = to_bench(one_to_bench, sizes_dv[i], reps[i]) + list_times[j].append(time) + list_errors[j].append(err) + except: + while len(list_times[j]) != len(reps): + list_times[j].append(np.nan) + list_errors[j].append(np.nan) + break + print("Finished size {}.".format(sizes_dv[i])) + + print("Finished", one_to_bench[0], "in a cumulated time of {:3.9f}s.".format( + sum(list_times[j]))) + + return list_times, list_errors + +################################################################ +# Plot the results of the benchmarking for multi-system solver +################################################################## + +sizes_dv = [1, 5, 10, 50, 100, 150] +reps = [5, 5, 5, 5, 5, 5] + +functions = [(scipy_cg_multi, "numpy"), + (keops_np, "numpy"), (keops_tch, "torch"), + (dic_cg_np, "numpy"), (dic_cg_tch, "torch")] + +list_times, list_errors = global_bench(functions, sizes_dv, reps) +labels = ["scipy + keops", "keops_np", "keops_tch", + "dico + keops_np", "dico + keops_tch"] + +plt.style.use('ggplot') +plt.figure(figsize=(20,10)) +plt.subplot(121) +for i in range(len(functions)): + plt.plot(sizes_d, list_times[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel made from {} points of size {} solving {} systems.".format(n, d, 'dv')) +plt.ylabel("Computational time (s)") +plt.legend() +plt.subplot(122) +for i in range(len(functions)): + plt.plot(sizes_d, list_errors[i], label=labels[i]) +plt.xscale('log') +plt.yscale('log') +plt.xlabel(r"Kernel made from {} points of size {} solving {} systems.".format(n, d, 'dv')) +plt.ylabel(r"Error $\sum_{i,j}\left((Ax_{k_{end}} - b)_{i,j}\right)^2$") +plt.legend() +plt.tight_layout() +plt.show() diff --git a/pykeops/benchmarks/plot_benchmark_invkernel.py b/pykeops/benchmarks/plot_benchmark_invkernel.py index f7a6edca4..f2ceed0fd 100644 --- a/pykeops/benchmarks/plot_benchmark_invkernel.py +++ b/pykeops/benchmarks/plot_benchmark_invkernel.py @@ -30,11 +30,12 @@ from scipy.sparse import diags from scipy.sparse.linalg import aslinearoperator, cg -from scipy.sparse.linalg.interface import IdentityOperator from pykeops.numpy import KernelSolve as KernelSolve_np, LazyTensor from pykeops.torch import KernelSolve from pykeops.torch.utils import squared_distances +from pykeops.numpy import Vi, Vj, Pm + use_cuda = torch.cuda.is_available() @@ -71,7 +72,7 @@ def generate_samples(N, device, lang): x = torch.rand(N, D, device=device) b = torch.randn(N, Dv, device=device) gamma = torch.ones(1, device=device) * .5 / .01 ** 2 # kernel bandwidth - alpha = torch.ones(1, device=device) * 0.8 # regularization + alpha = torch.ones(1, device=device) * 2 # regularization else: np.random.seed(1234) @@ -116,9 +117,9 @@ def Kinv_keops_numpy(x, b, gamma, alpha): return res def Kinv_scipy(x, b, gamma, alpha): - x_i, y_j = LazyTensor( gamma * x[:, None, :]), LazyTensor( gamma * x[None, :, :]) - K_ij = (- ((x_i - y_j) ** 2).sum(2)).exp() - A = aslinearoperator(diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) + K_ij = (-Pm(gamma) * Vi(x).sqdist(Vj(x))).exp() + A = aslinearoperator( + diags(alpha * np.ones(x.shape[0]))) + aslinearoperator(K_ij) A.dtype = np.dtype('float32') res = cg(A, b) return res diff --git a/pykeops/common/cg.py b/pykeops/common/cg.py new file mode 100644 index 000000000..0fc99ba2e --- /dev/null +++ b/pykeops/common/cg.py @@ -0,0 +1,148 @@ +import torch +from pykeops.common.utils import get_tools +from math import sqrt +import warnings + + +############################################# +# CG_revcom with Python dictionnary +############################################# + +def cg(linop, b, binding, x=None, eps=None, maxiter=None, callback=None, check_cond=False): + if binding not in ("torch", "numpy", "pytorch"): + raise ValueError( + "Language not supported, please use numpy, torch or pytorch.") + + tools = get_tools(binding) + + # we don't need cuda with numpy (at least i think so) + is_cuda = True if (binding == 'torch' or binding == + 'pytorch') and torch.cuda.is_available() else False + device = torch.device("cuda") if is_cuda else torch.device('cpu') + + b, x, replaced = check_dims(b, x, tools, is_cuda) + n, m = b.shape + + if eps == None: + eps = 1e-6 * sqrt((b ** 2).sum()) + + if maxiter == None: + maxiter = 10 * n + + if check_cond: + from pykeops.common.power_iteration import bootleg_inv_power_cond_big as cond_big + cond_too_big = cond_big(linop, n, binding, device) + if cond_too_big: + warnings.warn( + "Warning ----------- Condition number might be too large.") + + # define the functions needed along the iterations + if binding == "numpy": + p, q, r = tools.zeros((n, m), dtype=b.dtype), tools.zeros( + (n, m), dtype=b.dtype), tools.zeros((n, m), dtype=b.dtype) + scal1, scal2 = tools.zeros(1, dtype=b.dtype), tools.zeros( + 1, dtype=b.dtype) # init the scala values + + else: + p, q, r = tools.zeros((n, m), dtype=b.dtype, device=device), tools.zeros( + (n, m), dtype=b.dtype, device=device), tools.zeros((n, m), dtype=b.dtype, device=device) + scal1, scal2 = tools.zeros(1, dtype=b.dtype, device=device), tools.zeros( + 1, dtype=b.dtype, device=device) # init the scala values + + def init_iter(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # revc -> cg + r = tools.copy(b) if replaced else (b - linop(x)) + scal1 = (r ** 2).sum() + job_cg = "check" + return job_cg, x, r, p, q, scal1, scal2, iter_ + + def check_resid(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # cg -> revc + if scal1 <= eps**2 or scal1 != scal1: + job_rev = "stop" + else: + iter_ += 1 + job_rev = "direction_next" if iter_ > 1 else "direction_first" + return job_rev, x, r, p, q, scal1, scal2, iter_ + + def first_direct(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # revc -> cg + p = tools.copy(r) + job_cg = "matvec_p" + return job_cg, x, r, p, q, scal1, scal2, iter_ + + def matvec_p(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # cg -> revc + q = linop(p) + job_rev = "update" + return job_rev, x, r, p, q, scal1, scal2, iter_ + + def update(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # revc -> cg + alpha = scal1 / (p * q).sum() + x += alpha * p + r -= alpha * q + scal2 = scal1 + job_cg = "check" + return job_cg, x, r, p, q, scal1, scal2, iter_ + + def next_direct(linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_): # revc -> cg + scal1 = (r ** 2).sum() + p = r + (scal1 / scal2) * p + job_cg = "matvec_p" + return job_cg, x, r, p, q, scal1, scal2, iter_ + + jobs_cg = {"matvec_p": matvec_p, + "check": check_resid + } + + jobs_revcom = { + "init": init_iter, + "update": update, + "direction_first": first_direct, + "direction_next": next_direct + } + + iter_ = 0 + job_rev = "init" + job_cg = None + + while iter_ <=maxiter: + if job_cg == "check" and callback is not None: + if iter_ > 1: + callback(x) + job_cg, x, r, p, q, scal1, scal2, iter_ = jobs_revcom[job_rev]( + linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_) + job_rev, x, r, p, q, scal1, scal2, iter_ = jobs_cg[job_cg]( + linop, x, r, p, q, b, scal1, scal2, eps, replaced, iter_) + + if job_rev == "stop": + break + if (iter_ - 1) == maxiter: + warnings.warn("Warning ----------- Conjugate gradient reached maximum iteration !") + + return x, iter_ + + +#################################### +# Sanity checks +#################################### + +def check_dims(b, x, tools, cuda_avlb): # x is always of b's shape + try: + nrow, ncol = b.shape + except ValueError: + b = b.reshape(-1, 1) + nrow, ncol = b.shape + + x_replaced = False + + if x is None: # check x shape and initiate it if needed + x = tools.zeros((nrow, ncol), dtype=b.dtype, device=torch.device('cuda')) if cuda_avlb \ + else tools.zeros((nrow, ncol), dtype=b.dtype) + x_replaced = True + elif (nrow, ncol) != x.shape: # add sth to check if x is on the same device as b if torch is used! + if x.shape == (nrow,): + x = x.reshape((nrow, ncol)) + else: + raise ValueError("Mismatch between shapes of b {} and shape of x {}.".format( + (nrow, nrow), x.shape)) + if x.dtype != b.dtype: + raise ValueError("Type of given x {} is not compatible with type of b {}.".format(x.dtype, b.dtype)) + + return b, x, x_replaced diff --git a/pykeops/common/operations.py b/pykeops/common/operations.py index 95a302921..42b44af86 100644 --- a/pykeops/common/operations.py +++ b/pykeops/common/operations.py @@ -1,4 +1,5 @@ import numpy as np +import warnings from pykeops.common.utils import get_tools @@ -72,20 +73,25 @@ def postprocess(out, binding, reduction_op, nout, opt_arg, dtype): return out -def ConjugateGradientSolver(binding, linop, b, eps=1e-6): +def ConjugateGradientSolver(binding, linop, b, eps=1e-6, callback=None, maxiter=None): # Conjugate gradient algorithm to solve linear system of the form # Ma=b where linop is a linear operation corresponding # to a symmetric and positive definite matrix + if binding not in ("torch", "numpy", "pytorch"): + raise ValueError( + "Language not supported, please use numpy, torch or pytorch.") tools = get_tools(binding) delta = tools.size(b) * eps ** 2 + if maxiter == None: + maxiter = 10 * tools.size(b) a = 0 r = tools.copy(b) nr2 = (r ** 2).sum() if nr2 < delta: return 0 * r p = tools.copy(r) - k = 0 - while True: + k = 1 + while k <= maxiter: Mp = linop(p) alp = nr2 / (p * Mp).sum() a += alp * p @@ -96,6 +102,10 @@ def ConjugateGradientSolver(binding, linop, b, eps=1e-6): p = r + (nr2new / nr2) * p nr2 = nr2new k += 1 + if callback is not None: + callback(a) + if k == maxiter: + warnings.warn("Warning ----------- Conjugate gradient reached maximum iteration !") return a diff --git a/pykeops/common/power_iteration.py b/pykeops/common/power_iteration.py new file mode 100644 index 000000000..dea59be5c --- /dev/null +++ b/pykeops/common/power_iteration.py @@ -0,0 +1,80 @@ +import torch +import numpy as np +import warnings +from math import sqrt + +from pykeops.common.utils import get_tools +from pykeops.common.cg import cg + +##################### +# Power iteration +##################### + + +def random_draw_np(size, device, dtype='float32'): + return np.random.rand(size, 1).astype(dtype) + + +def random_draw_torch(size, device, dtype=torch.float32): + return torch.rand(size, 1, device=device, dtype=dtype) + + +def power_it_ray(linop, size, binding, device, eps=1e-6): + r""" Compute the eigenvalue of maximum magnitude for a linear operator. + + Args: + linop: a linear operator that, when called, computes the matrix-vector product. + size (int): dimension of the linear operator. + binding (string): torch, pytorch or numpy. + device (torch.device): use GPU or CPU, ``torch.device('cuda')`` or ``torch.device("cpu")``* + for example. + + Keyword Args: + eps (float, default=1e-6): precision for the acceptable distance between two iterates of + the eigenvalue. + + Returns: + lambd_ (float): the eigenvalue of maximum magnitude for ``linop``. + A warning is displayed if the algorithm didn't converge. + """ + random = random_draw_np if binding == "numpy" else random_draw_torch + x = random(size, device) + x = x / sqrt((x ** 2).sum()) + maxiter = 10 * size + k = 0 + while k <= maxiter: + y = linop(x) + norm_y = sqrt((y ** 2).sum()) + z = y / norm_y + lambd_ = (z.T @ linop(z)) + if k > 0 and (old_lambd - lambd_) ** 2 <= eps ** 2: + break + old_lambd = lambd_ + x = z + k += 1 + if (k - 1) == maxiter: + warnings.warn( + "Warning ----------- Power iteration method did not converge !") + return lambd_ + + +def bootleg_inv_power_cond_big(linop, size, binding, device, maxcond=500, maxiter=50): + lambda_max = power_it_ray(linop, size, binding, device) + thresh = lambda_max / maxcond + k = 0 + vp = [maxcond] + random = random_draw_np if binding == "numpy" else random_draw_torch + x = random(size, device) + while k <= maxiter: + x = cg(linop, x, binding)[0] + x = x / sqrt((x ** 2).sum()) + vp.append(x.T @ linop(x)) + if vp[k] <= thresh: + cond_too_big = True + break + if k >=1 and (vp[k]-vp[k-1]) ** 2 <= 1e-10: #cv + k = maxiter #exit + k += 1 + if (k - 1) == maxiter: + cond_too_big = False + return cond_too_big diff --git a/pykeops/numpy/operations.py b/pykeops/numpy/operations.py index da68a27f7..c0b3c2d18 100644 --- a/pykeops/numpy/operations.py +++ b/pykeops/numpy/operations.py @@ -7,6 +7,7 @@ from pykeops.common.utils import axis2cat from pykeops.numpy import default_dtype +from pykeops.common.cg import cg class KernelSolve: r""" @@ -135,7 +136,7 @@ def __init__(self, formula, aliases, varinvalias, axis=0, dtype=default_dtype, o varinvpos = tmp.index(varinvalias) self.varinvpos = varinvpos - def __call__(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=1e-6, ranges=None): + def __call__(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=1e-6, ranges=None, callback=None): r""" To apply the routine on arbitrary NumPy arrays. @@ -181,6 +182,9 @@ def __call__(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=1e-6, r as we loop over all indices :math:`i\in[0,M)` and :math:`j\in[0,N)`. + callback (function, default=None): function of x called at the end of + each iteration of the conjugate gradient. + Returns: (M,D) or (N,D) array: @@ -203,4 +207,33 @@ def linop(var): res += alpha * var return res - return ConjugateGradientSolver('numpy', linop, varinv, eps=eps) + return ConjugateGradientSolver('numpy', linop, varinv, eps=eps, callback=callback) + + def cg(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=None, ranges=None, check_cond=False, callback=None): + r""" + Another version of the conjugate gradient. Args and keywords args are the same + as calling ``KernelSolve``. Only one keyword is added. + + Keyword Args: + check_cond (boolean, False by default): Indicates if the condition number + might be greater than 500. *Warning: setting it to True will + result in a more time-consuming method.* + + Returns: + A tuple containing the (M,D) or (N,D) array being the approximated + solution of the problem and the iteration number the algorithm stopped. + + """ + tagCpuGpu, tag1D2D, _ = get_tag_backend(backend, args) + varinv = args[self.varinvpos] + + if ranges is None: ranges = () # ranges should be encoded as a tuple + + def linop(var): + newargs = args[:self.varinvpos] + (var,) + args[self.varinvpos + 1:] + res = self.myconv.genred_numpy(tagCpuGpu, tag1D2D, 0, device_id, ranges, *newargs) + if alpha: + res += alpha * var + return res + + return cg(linop, varinv, 'numpy', eps=eps, callback=callback, check_cond=check_cond) diff --git a/pykeops/test/unit_tests_numpy.py b/pykeops/test/unit_tests_numpy.py index 8db03f8aa..d88a90910 100644 --- a/pykeops/test/unit_tests_numpy.py +++ b/pykeops/test/unit_tests_numpy.py @@ -302,6 +302,46 @@ def test_LazyTensor_sum(self): self.assertTrue(res_keops.shape == res_numpy.shape) self.assertTrue(np.allclose(res_keops, res_numpy, atol=1e-3)) + ############################################################ + def test_cg_dic(self): + ############################################################ + from pykeops.numpy import KernelSolve, Genred + formula = 'Exp(- g * SqDist(x,y)) * a' + aliases = ['x = Vi(3)', # First arg: i-variable of size D + 'y = Vj(3)', # Second arg: j-variable of size D + 'a = Vj(1)', # Third arg: j-variable of size Dv + 'g = Pm(1)'] + + K = Genred(formula, aliases, axis=1, dtype=self.type_to_test[1]) + Kinv = KernelSolve(formula, aliases, "a", axis=1, + dtype=self.type_to_test[1]) + + ans = Kinv.cg(self.x, self.x, self.f, + self.sigma, alpha=self.sigma)[0] + err = ((self.sigma * ans + K(self.x, self.x, + ans, self.sigma) - self.f) ** 2).sum() + self.assertTrue(np.allclose(err, np.zeros(err.shape))) + + ############################################################ + def test_cg_call(self): + ############################################################ + from pykeops.numpy import KernelSolve, Genred + formula = 'Exp(- g * SqDist(x,y)) * a' + aliases = ['x = Vi(3)', # First arg: i-variable of size D + 'y = Vj(3)', # Second arg: j-variable of size D + 'a = Vj(1)', # Third arg: j-variable of size Dv + 'g = Pm(1)'] + + K = Genred(formula, aliases, axis=1, dtype=self.type_to_test[1]) + Kinv = KernelSolve(formula, aliases, "a", axis=1, + dtype=self.type_to_test[1]) + + ans = Kinv(self.x, self.x, self.f, + self.sigma, alpha=self.sigma) + err = ((self.sigma * ans + K(self.x, self.x, + ans, self.sigma) - self.f) ** 2).sum() + self.assertTrue(np.allclose(err, np.zeros(err.shape))) + if __name__ == '__main__': unittest.main() diff --git a/pykeops/test/unit_tests_pytorch.py b/pykeops/test/unit_tests_pytorch.py index 3fc669f06..58dcd26a9 100644 --- a/pykeops/test/unit_tests_pytorch.py +++ b/pykeops/test/unit_tests_pytorch.py @@ -556,7 +556,41 @@ def invert_permutation_numpy(permutation): grad_torch = torch.autograd.grad(sum_f_torch2, y, e)[0] self.assertTrue(torch.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4)) - + ############################################################ + def test_cg_dic(self): + ############################################################ + from pykeops.torch import KernelSolve, Genred + formula = 'Exp(- g * SqDist(x,y)) * a' + aliases = ['x = Vi(3)', # First arg: i-variable of size D + 'y = Vj(3)', # Second arg: j-variable of size D + 'a = Vj(1)', # Third arg: j-variable of size Dv + 'g = Pm(1)'] + K = Genred(formula, aliases, axis=1, dtype="float32") + Kinv = KernelSolve(formula, aliases, "a", axis=1, + dtype="float32") + ans = Kinv.cg(self.xc, self.xc, self.fc, + self.sigmac, alpha=self.sigmac)[0] + err = ((self.sigmac * ans + K(self.xc, self.xc, + ans, self.sigmac) - self.fc) ** 2).sum() + self.assertTrue(np.allclose(err.cpu().data.numpy(), np.zeros(err.shape))) + + ############################################################# + def test_cg_dic(self): + ############################################################ + from pykeops.torch import KernelSolve, Genred + formula = 'Exp(- g * SqDist(x,y)) * a' + aliases = ['x = Vi(3)', # First arg: i-variable of size D + 'y = Vj(3)', # Second arg: j-variable of size D + 'a = Vj(1)', # Third arg: j-variable of size Dv + 'g = Pm(1)'] + K = Genred(formula, aliases, axis=1, dtype="float32") + Kinv = KernelSolve(formula, aliases, "a", axis=1, + dtype="float32") + ans = Kinv(self.xc, self.xc, self.fc, + self.sigmac, alpha=self.sigmac) + err = ((self.sigmac * ans + K(self.xc, self.xc, + ans, self.sigmac) - self.fc) ** 2).sum() + self.assertTrue(np.allclose(err.cpu().data.numpy(), np.zeros(err.shape))) if __name__ == '__main__': """ diff --git a/pykeops/torch/__init__.py b/pykeops/torch/__init__.py index 506671ce4..4a760a5a6 100644 --- a/pykeops/torch/__init__.py +++ b/pykeops/torch/__init__.py @@ -33,7 +33,6 @@ from .generic.generic_ops import generic_sum, generic_logsumexp, generic_argmin, generic_argkmin from .kernel_product.formula import Formula from pykeops.common.lazy_tensor import LazyTensor, Vi, Vj, Pm - # N.B.: If "from pykeops.numpy import LazyTensor" has already been run, # the line above will *not* import "torchtools" and we'll end up with an error... # So even though it may be a bit ugly, we re-load the lazy_tensor file @@ -43,5 +42,5 @@ importlib.reload(pykeops.common.lazy_tensor) __all__ = sorted( - ["Genred", "generic_sum", "generic_logsumexp", "generic_argmin", "generic_argkmin", "Kernel", "kernel_product", + ["cg", "Genred", "generic_sum", "generic_logsumexp", "generic_argmin", "generic_argkmin", "Kernel", "kernel_product", "KernelSolve", "kernel_formulas", "Formula", "LazyTensor", "Vi", "Vj", "Pm"]) diff --git a/pykeops/torch/operations.py b/pykeops/torch/operations.py index b5931552c..d64c40f41 100644 --- a/pykeops/torch/operations.py +++ b/pykeops/torch/operations.py @@ -8,6 +8,7 @@ from pykeops.torch import default_dtype from pykeops.torch import include_dirs from pykeops.torch.generic.generic_red import GenredAutograd +from pykeops.common.cg import cg class KernelSolveAutograd(torch.autograd.Function): @@ -47,7 +48,6 @@ def forward(ctx, formula, aliases, varinvpos, alpha, backend, dtype, device_id, for i in range(1,len(args)): if args[i].device.index != device_id: raise ValueError("[KeOps] Input arrays must be all located on the same device.") - def linop(var): newargs = args[:varinvpos] + (var,) + args[varinvpos+1:] res = myconv.genred_pytorch(tagCPUGPU, tag1D2D, tagHostDevice, device_id, ranges, *newargs) @@ -61,11 +61,11 @@ def linop(var): # relying on the 'ctx.saved_variables' attribute is necessary if you want to be able to differentiate the output # of the backward once again. It helps pytorch to keep track of 'who is who'. ctx.save_for_backward(*args, result) - return result @staticmethod def backward(ctx, G): + formula = ctx.formula aliases = ctx.aliases varinvpos = ctx.varinvpos @@ -77,7 +77,6 @@ def backward(ctx, G): myconv = ctx.myconv ranges = ctx.ranges accuracy_flags = ctx.accuracy_flags - args = ctx.saved_tensors[:-1] # Unwrap the saved variables nargs = len(args) result = ctx.saved_tensors[-1] @@ -105,7 +104,7 @@ def backward(ctx, G): if var_ind == varinvpos: grads.append(KinvG) else: - # adding new aliases is way too dangerous if we want to compute + #adding new aliases is way too dangerous if we want to compute # second derivatives, etc. So we make explicit references to Var instead. # New here (Joan) : we still add the new variables to the list of "aliases" (without giving new aliases for them) # these will not be used in the C++ code, @@ -136,8 +135,6 @@ def backward(ctx, G): # Grads wrt. formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags, *args return (None, None, None, None, None, None, None, None, None, None, *grads) - - class KernelSolve(): r""" Creates a new conjugate gradient solver. @@ -325,8 +322,123 @@ def __call__(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=1e-6, r that is inferred from the **formula**. """ - return KernelSolveAutograd.apply(self.formula, self.aliases, self.varinvpos, alpha, backend, self.dtype, device_id, eps, ranges, self.accuracy_flags, *args) + def cg(self, *args, backend='auto', device_id=-1, alpha=1e-10, eps=None, check_cond=False, callback=None, ranges=None): + r""" + Same as calling ``KernelSolve``. The keyword argument `check_cond` being added. + + Keyword Args: + check_cond (boolean, default=False): Indicates if the condition number + **might be** greater than 500. *Warning: setting it to True will + result in a more time-consuming method.* + + Returns: + A tuple of tensors containing the (M,D) or (N,D) tensor being the approximated + solution of the problem and the iteration number the algorithm stopped. + + """ + return dic_KernelSolveAutograd.apply(self.formula, self.aliases, self.varinvpos, alpha, backend, self.dtype, device_id, eps, ranges, self.accuracy_flags, check_cond, callback, *args) +class dic_KernelSolveAutograd(torch.autograd.Function): + @staticmethod + def forward(ctx, formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags, check_cond, callback, *args): + + optional_flags = include_dirs + accuracy_flags + + myconv = LoadKeOps(formula, aliases, dtype, 'torch', + optional_flags).import_module() + + # Context variables: save everything to compute the gradient: + ctx.formula = formula + ctx.aliases = aliases + ctx.varinvpos = varinvpos + ctx.alpha = alpha + ctx.backend = backend + ctx.dtype = dtype + ctx.device_id = device_id + ctx.check_cond = check_cond + ctx.eps = eps + ctx.myconv = myconv + ctx.ranges = ranges + ctx.callback = callback + ctx.accuracy_flags = accuracy_flags + if ranges is None: ranges = () # To keep the same type + + varinv = args[varinvpos] + ctx.varinvpos = varinvpos + + tagCPUGPU, tag1D2D, tagHostDevice = get_tag_backend(backend, args) + + if tagCPUGPU==1 & tagHostDevice==1: + device_id = args[0].device.index + for i in range(1,len(args)): + if args[i].device.index != device_id: + raise ValueError("[KeOps] Input arrays must be all located on the same device.") + + def linop(var): + newargs = args[:varinvpos] + (var,) + args[varinvpos+1:] + res = myconv.genred_pytorch(tagCPUGPU, tag1D2D, tagHostDevice, device_id, ranges, *newargs) + if alpha: + res += alpha*var + return res + + result, iter_ = cg(linop, varinv.data, 'torch', eps=eps, check_cond=check_cond, callback=callback) + ctx.save_for_backward(*args, result) + return result, torch.as_tensor(iter_) + + @staticmethod + def backward(ctx, G): + + formula = ctx.formula + aliases = ctx.aliases + varinvpos = ctx.varinvpos + backend = ctx.backend + alpha = ctx.alpha + dtype = ctx.dtype + device_id = ctx.device_id + eps = ctx.eps + check_cond = ctx.check_cond + myconv = ctx.myconv + ranges = ctx.ranges + callback = ctx.callback + accuracy_flags = ctx.accuracy_flags + args = ctx.saved_tensors[:-1] # Unwrap the saved variables + nargs = len(args) + result = ctx.saved_tensors[-1] + + eta = 'Var(' + str(nargs) + ',' + str(myconv.dimout) + ',' + str(myconv.tagIJ) + ')' + + # there is also a new variable for the formula's output + resvar = 'Var(' + str(nargs+1) + ',' + str(myconv.dimout) + ',' + str(myconv.tagIJ) + ')' + + newargs = args[:varinvpos] + (G,) + args[varinvpos+1:] + KinvG = KernelSolveAutograd.apply(formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags, check_cond, callback, *newargs) + + grads = [] # list of gradients wrt. args; + + for (var_ind, sig) in enumerate(aliases): + if not ctx.needs_input_grad[var_ind + 10]: # because of (formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags) + grads.append(None) # Don't waste time computing it. + + else: + if var_ind == varinvpos: + grads.append(KinvG) + else: + _, cat, dim, pos = get_type(sig, position_in_list=var_ind) + var = 'Var(' + str(pos) + ',' + str(dim) + ',' + str(cat) + ')' # V + formula_g = 'Grad_WithSavedForward(' + formula + ', ' + var + ', ' + eta + ', ' + resvar + ')' # Grad + aliases_g = aliases + [eta, resvar] + args_g = args[:varinvpos] + (result,) + args[varinvpos+1:] + (-KinvG,) + (result,) + genconv = GenredAutograd().apply + + if cat == 2: + grad = genconv(formula_g, aliases_g, backend, dtype, device_id, ranges, accuracy_flags, *args_g) + grad = torch.ones(1, grad.shape[0]).type_as(grad.data) @ grad + grad = grad.view(-1) + else: + grad = genconv(formula_g, aliases_g, backend, dtype, device_id, ranges, accuracy_flags, *args_g) + grads.append(grad) + + return (None, None, None, None, None, None, None, None, None, None, None, None, None, *grads) From ac5a1fa6bec46681feba26265a4c7ead89cb04a3 Mon Sep 17 00:00:00 2001 From: Tanguy Lefort Date: Tue, 28 Jul 2020 14:57:58 +0200 Subject: [PATCH 2/4] unitest fix typo --- pykeops/test/unit_tests_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykeops/test/unit_tests_pytorch.py b/pykeops/test/unit_tests_pytorch.py index 58dcd26a9..03b59158e 100644 --- a/pykeops/test/unit_tests_pytorch.py +++ b/pykeops/test/unit_tests_pytorch.py @@ -575,7 +575,7 @@ def test_cg_dic(self): self.assertTrue(np.allclose(err.cpu().data.numpy(), np.zeros(err.shape))) ############################################################# - def test_cg_dic(self): + def test_cg(self): ############################################################ from pykeops.torch import KernelSolve, Genred formula = 'Exp(- g * SqDist(x,y)) * a' From 88cac04bedbb59bcb770712d0a3fd7ff9aa3b317 Mon Sep 17 00:00:00 2001 From: Tanguy Lefort Date: Thu, 30 Jul 2020 19:04:58 +0200 Subject: [PATCH 3/4] autograd compatibility with dic-cg --- pykeops/torch/operations.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/pykeops/torch/operations.py b/pykeops/torch/operations.py index d64c40f41..a9acf36b6 100644 --- a/pykeops/torch/operations.py +++ b/pykeops/torch/operations.py @@ -383,13 +383,15 @@ def linop(var): if alpha: res += alpha*var return res + global copy result, iter_ = cg(linop, varinv.data, 'torch', eps=eps, check_cond=check_cond, callback=callback) ctx.save_for_backward(*args, result) + return result, torch.as_tensor(iter_) @staticmethod - def backward(ctx, G): + def backward(ctx, G, G2): formula = ctx.formula aliases = ctx.aliases @@ -399,11 +401,11 @@ def backward(ctx, G): dtype = ctx.dtype device_id = ctx.device_id eps = ctx.eps - check_cond = ctx.check_cond myconv = ctx.myconv ranges = ctx.ranges - callback = ctx.callback accuracy_flags = ctx.accuracy_flags + check_cond = ctx.check_cond + callback = ctx.callback args = ctx.saved_tensors[:-1] # Unwrap the saved variables nargs = len(args) result = ctx.saved_tensors[-1] @@ -412,14 +414,11 @@ def backward(ctx, G): # there is also a new variable for the formula's output resvar = 'Var(' + str(nargs+1) + ',' + str(myconv.dimout) + ',' + str(myconv.tagIJ) + ')' - newargs = args[:varinvpos] + (G,) + args[varinvpos+1:] - KinvG = KernelSolveAutograd.apply(formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags, check_cond, callback, *newargs) - + KinvG = dic_KernelSolveAutograd.apply(formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags, check_cond, callback, *newargs) grads = [] # list of gradients wrt. args; - for (var_ind, sig) in enumerate(aliases): - if not ctx.needs_input_grad[var_ind + 10]: # because of (formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags) + if not ctx.needs_input_grad[var_ind + 12]: # because of (formula, aliases, varinvpos, alpha, backend, dtype, device_id, eps, ranges, accuracy_flags) grads.append(None) # Don't waste time computing it. else: @@ -430,9 +429,8 @@ def backward(ctx, G): var = 'Var(' + str(pos) + ',' + str(dim) + ',' + str(cat) + ')' # V formula_g = 'Grad_WithSavedForward(' + formula + ', ' + var + ', ' + eta + ', ' + resvar + ')' # Grad aliases_g = aliases + [eta, resvar] - args_g = args[:varinvpos] + (result,) + args[varinvpos+1:] + (-KinvG,) + (result,) + args_g = args[:varinvpos] + (result,) + args[varinvpos+1:] + (-KinvG[0],) + (result,) genconv = GenredAutograd().apply - if cat == 2: grad = genconv(formula_g, aliases_g, backend, dtype, device_id, ranges, accuracy_flags, *args_g) grad = torch.ones(1, grad.shape[0]).type_as(grad.data) @ grad @@ -440,5 +438,4 @@ def backward(ctx, G): else: grad = genconv(formula_g, aliases_g, backend, dtype, device_id, ranges, accuracy_flags, *args_g) grads.append(grad) - - return (None, None, None, None, None, None, None, None, None, None, None, None, None, *grads) + return (None, None, None, None, None, None, None, None, None, None, None, None, *grads) \ No newline at end of file From 008dff6536d1c68768df9e053f939167fbfb37a8 Mon Sep 17 00:00:00 2001 From: Tanguy Lefort Date: Fri, 7 Aug 2020 12:34:54 +0200 Subject: [PATCH 4/4] type tex --- pykeops/benchmarks/plot_benchmark_cg_dimensions.py | 2 +- pykeops/test/unit_tests_pytorch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pykeops/benchmarks/plot_benchmark_cg_dimensions.py b/pykeops/benchmarks/plot_benchmark_cg_dimensions.py index ae890fd32..aeb609346 100644 --- a/pykeops/benchmarks/plot_benchmark_cg_dimensions.py +++ b/pykeops/benchmarks/plot_benchmark_cg_dimensions.py @@ -5,7 +5,7 @@ Different implementations of the conjugate gradient (CG) exist. Here, we compare the CG implemented in scipy which uses Fortran against it's pythonized version and the older version of the algorithm available in pykeops. -We want to solve the positive definite linear system :math:`(K_{x,x} + \\alpha Id)a = b` for :math:`a, b\in \mathbb R^N` and :math:`x\in\mathbb R^{N\times d}`. +We want to solve the positive definite linear system :math:`(K_{x,x} + \\alpha Id)a = b` for :math:`a, b\in \mathbb R^N` and :math:`x\in\mathbb R^{N\\times d}`. We will use :math:`N=100000` points. Let the Gaussian RBF kernel be defined as diff --git a/pykeops/test/unit_tests_pytorch.py b/pykeops/test/unit_tests_pytorch.py index 03b59158e..23fb14ed7 100644 --- a/pykeops/test/unit_tests_pytorch.py +++ b/pykeops/test/unit_tests_pytorch.py @@ -573,7 +573,7 @@ def test_cg_dic(self): err = ((self.sigmac * ans + K(self.xc, self.xc, ans, self.sigmac) - self.fc) ** 2).sum() self.assertTrue(np.allclose(err.cpu().data.numpy(), np.zeros(err.shape))) - + ############################################################# def test_cg(self): ############################################################