diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index 957f5413..88db2323 100755 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -19,7 +19,6 @@ This library must be installed *before* PyLops is installed. ``cupy`` and ``jax`` backends. This can be also used if a previous version of ``cupy`` or ``jax`` is installed in your system, otherwise you will get an error when importing PyLops. - Apart from a few exceptions, all operators and solvers in PyLops can seamlessly work with ``numpy`` arrays on CPU as well as with ``cupy/jax`` arrays on GPU. For CuPy, users simply need to consistently create operators and @@ -32,6 +31,7 @@ be also wrapped into a :class:`pylops.JaxOperator`. See below for a comphrensive list of supported operators and additional functionalities for both the ``cupy`` and ``jax`` backends. + Install dependencies -------------------- GPU-enabled development environments can created using ``conda`` diff --git a/pylops/optimization/cls_leastsquares.py b/pylops/optimization/cls_leastsquares.py index 415d3609..d2abc758 100644 --- a/pylops/optimization/cls_leastsquares.py +++ b/pylops/optimization/cls_leastsquares.py @@ -8,13 +8,11 @@ from typing import TYPE_CHECKING, Optional, Sequence, Tuple import numpy as np -from scipy.sparse.linalg import cg as sp_cg -from scipy.sparse.linalg import lsqr as sp_lsqr from pylops.basicoperators import Diagonal, VStack from pylops.optimization.basesolver import Solver, _units from pylops.optimization.basic import cg, cgls -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, get_cg, get_lsqr from pylops.utils.typing import NDArray, Tmemunit, Tsolverengine if TYPE_CHECKING: @@ -294,12 +292,14 @@ def run( ``<0``: illegal input or breakdown """ - if engine == "scipy" and self.ncp == np: + if engine == "scipy": if "tol" in kwargs_solver: kwargs_solver["atol"] = kwargs_solver["tol"] kwargs_solver.pop("tol") - xinv, istop = sp_cg(self.Op_normal, self.y_normal, x0=x, **kwargs_solver) - elif engine == "pylops" or self.ncp != np: + xinv, istop = get_cg(x)( + self.Op_normal, self.y_normal, x0=x, **kwargs_solver + ) + elif engine == "pylops": if show: kwargs_solver["show"] = True xinv = cg( @@ -703,13 +703,13 @@ def run( Equal to ``r1norm`` if :math:`\epsilon=0` """ - if engine == "scipy" and self.ncp == np: + if engine == "scipy": if show: kwargs_solver["show"] = 1 - xinv, istop, itn, r1norm, r2norm = sp_lsqr( + xinv, istop, itn, r1norm, r2norm = get_lsqr(x)( self.RegOp, self.datatot, x0=x, **kwargs_solver )[0:5] - elif engine == "pylops" or self.ncp != np: + elif engine == "pylops": if show: kwargs_solver["show"] = True xinv, istop, itn, r1norm, r2norm = cgls( @@ -974,16 +974,16 @@ def run( Equal to ``r1norm`` if :math:`\epsilon=0` """ - if engine == "scipy" and self.ncp == np: + if engine == "scipy": if show: kwargs_solver["show"] = 1 - pinv, istop, itn, r1norm, r2norm = sp_lsqr( + pinv, istop, itn, r1norm, r2norm = get_lsqr(x)( self.POp, self.y, x0=x, **kwargs_solver, )[0:5] - elif engine == "pylops" or self.ncp != np: + elif engine == "pylops": if show: kwargs_solver["show"] = True pinv, istop, itn, r1norm, r2norm = cgls( diff --git a/pylops/optimization/cls_sparsity.py b/pylops/optimization/cls_sparsity.py index 55fbad1c..e4051772 100644 --- a/pylops/optimization/cls_sparsity.py +++ b/pylops/optimization/cls_sparsity.py @@ -13,7 +13,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import numpy as np -from scipy.sparse.linalg import lsqr from pylops import LinearOperator from pylops.basicoperators import Diagonal, Identity, VStack @@ -25,6 +24,7 @@ from pylops.utils import deps from pylops.utils.backend import ( get_array_module, + get_lsqr, get_module_name, get_real_dtype, inplace_set, @@ -570,15 +570,15 @@ def _step_model( kwargs_solver["preallocate"] = True if self.iiter == 0: # first iteration (unweighted least-squares) - if engine == "scipy" and self.ncp == np: + if engine == "scipy": x = self.Op.rmatvec( - lsqr( + get_lsqr(x)( self.Op @ self.Op.H + (self.epsI**2) * self.Iop, self.y, **kwargs_solver, )[0] ) - elif engine == "pylops" or self.ncp != np: + elif engine == "pylops": x = self.Op.rmatvec( cgls( self.Op @ self.Op.H + (self.epsI**2) * self.Iop, @@ -599,17 +599,17 @@ def _step_model( self.ncp.divide(self.rw, self.rw.max(), out=self.rw) R = Diagonal(self.rw, dtype=self.rw.dtype) - if engine == "scipy" and self.ncp == np: + if engine == "scipy": x = R.matvec( self.Op.rmatvec( - lsqr( + get_lsqr(x)( self.Op @ R @ self.Op.H + self.epsI**2 * self.Iop, self.y, **kwargs_solver, )[0] ) ) - elif engine == "pylops" or self.ncp != np: + elif engine == "pylops": x = R.matvec( self.Op.rmatvec( cgls( @@ -643,11 +643,8 @@ def step( Display iteration log **kwargs_solver Arbitrary keyword arguments for - :py:func:`scipy.sparse.linalg.cg` solver for data IRLS and - :py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using - numpy data and ``engine='scipy'`` (or - :py:func:`pylops.optimization.solver.cg` and - :py:func:`pylops.optimization.solver.cgls` when using cupy data or + :py:func:`scipy.sparse.linalg.lsqr` when ``engine='scipy'`` (or + :py:func:`pylops.optimization.solver.cgls` when using ``engine='pylops'``) Returns @@ -702,11 +699,8 @@ def run( three element of the list. **kwargs_solver Arbitrary keyword arguments for - :py:func:`scipy.sparse.linalg.cg` solver for data IRLS and - :py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using - numpy data and ``engine='scipy'`` (or - :py:func:`pylops.optimization.solver.cg` and - :py:func:`pylops.optimization.solver.cgls` when using cupy data or + :py:func:`scipy.sparse.linalg.lsqr` when ``engine='scipy'`` (or + :py:func:`pylops.optimization.solver.cgls` when using ``engine='pylops'``) Returns @@ -822,10 +816,9 @@ def solve( three element of the list. **kwargs_solver Arbitrary keyword arguments for - :py:func:`scipy.sparse.linalg.cg` solver for data IRLS and - :py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using - numpy data(or :py:func:`pylops.optimization.solver.cg` and - :py:func:`pylops.optimization.solver.cgls` when using cupy data) + :py:func:`scipy.sparse.linalg.lsqr` when ``engine='scipy'`` (or + :py:func:`pylops.optimization.solver.cgls` when using + ``engine='pylops'``) Returns ------- @@ -1192,9 +1185,11 @@ def step( else: # OMP update Opcol = self.Op.apply_columns(cols) - if engine == "scipy" and self.ncp == np: - x = lsqr(Opcol, self.y, iter_lim=self.niter_inner, **kwargs_solver)[0] - elif engine == "pylops" or self.ncp != np: + if engine == "scipy": + x = get_lsqr(x)( + Opcol, self.y, iter_lim=self.niter_inner, **kwargs_solver + )[0] + elif engine == "pylops": x = cgls( Opcol, self.y, @@ -3038,9 +3033,10 @@ def step( **kwargs_solver Arbitrary keyword arguments for chosen solver used to solve the first subproblem in the first step of the - Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` and - :py:func:`pylops.optimization.solver.cgls` are used as default - for numpy and cupy `data`, respectively). + Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` is + used if ``engine='scipy'`` and + :py:func:`pylops.optimization.solver.cgls` is used if engine is + ``engine='pylops'``). Returns ------- @@ -3149,9 +3145,12 @@ def run( show_inner : :obj:`bool`, optional Display inner iteration logs of lsqr **kwargs_lsqr - Arbitrary keyword arguments for - :py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first - subproblem in the first step of the Split Bregman algorithm. + Arbitrary keyword arguments for chosen solver + used to solve the first subproblem in the first step of the + Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` is + used if ``engine='scipy'`` and + :py:func:`pylops.optimization.solver.cgls` is used if engine is + ``engine='pylops'``). Returns ------- @@ -3285,9 +3284,12 @@ def solve( show_inner : :obj:`bool`, optional Display inner iteration logs of lsqr **kwargs_lsqr - Arbitrary keyword arguments for - :py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first - subproblem in the first step of the Split Bregman algorithm. + Arbitrary keyword arguments for chosen solver + used to solve the first subproblem in the first step of the + Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` is + used if ``engine='scipy'`` and + :py:func:`pylops.optimization.solver.cgls` is used if engine is + ``engine='pylops'``). Returns ------- diff --git a/pylops/optimization/sparsity.py b/pylops/optimization/sparsity.py index 36ed636f..95b59241 100644 --- a/pylops/optimization/sparsity.py +++ b/pylops/optimization/sparsity.py @@ -106,10 +106,9 @@ def irls( pre-allocated since JAX does not support in-place operations. **kwargs_solver Arbitrary keyword arguments for - :py:func:`scipy.sparse.linalg.cg` solver for data IRLS and - :py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using - numpy data(or :py:func:`pylops.optimization.solver.cg` and - :py:func:`pylops.optimization.solver.cgls` when using cupy data) + :py:func:`scipy.sparse.linalg.lsqr` when ``engine='scipy'`` (or + :py:func:`pylops.optimization.solver.cgls` when using + ``engine='pylops'``) Returns ------- @@ -790,15 +789,18 @@ def splitbregman( Function with signature (``callback(x)``) to call after each iteration where ``x`` is the current model vector preallocate : :obj:`bool`, optional - .. versionadded:: 2.6.0 + .. versionadded:: 2.6.0 - Pre-allocate all variables used by the solver. Note that if ``y`` - is a JAX array, this option is ignored and variables are not - pre-allocated since JAX does not support in-place operations. + Pre-allocate all variables used by the solver. Note that if ``y`` + is a JAX array, this option is ignored and variables are not + pre-allocated since JAX does not support in-place operations. **kwargs_lsqr - Arbitrary keyword arguments for - :py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first - subproblem in the first step of the Split Bregman algorithm. + Arbitrary keyword arguments for chosen solver + used to solve the first subproblem in the first step of the + Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` is + used if ``engine='scipy'`` and + :py:func:`pylops.optimization.solver.cgls` is used if engine is + ``engine='pylops'``). Returns ------- diff --git a/pylops/utils/backend.py b/pylops/utils/backend.py index 2cc125df..5e3f65cb 100644 --- a/pylops/utils/backend.py +++ b/pylops/utils/backend.py @@ -14,6 +14,8 @@ "get_csc_matrix", "get_sparse_eye", "get_lstsq", + "get_cg", + "get_lsqr", "get_sp_fft", "get_complex_dtype", "get_real_dtype", @@ -36,6 +38,7 @@ from scipy.linalg import block_diag, lstsq, toeplitz from scipy.signal import convolve, correlate, fftconvolve, oaconvolve from scipy.sparse import csc_matrix, eye +from scipy.sparse.linalg import cg, lsqr from pylops.utils import deps from pylops.utils.typing import ArrayLike, DTypeLike, NDArray, Tfftengine_ncj @@ -52,6 +55,8 @@ from cupyx.scipy.signal import oaconvolve as cp_oaconvolve from cupyx.scipy.sparse import csc_matrix as cp_csc_matrix from cupyx.scipy.sparse import eye as cp_eye + from cupyx.scipy.sparse.linalg import cg as cp_cg + from cupyx.scipy.sparse.linalg import lsqr as cp_lsqr if deps.jax_enabled: import jax @@ -429,6 +434,52 @@ def get_lstsq(x: ArrayLike) -> Callable: return cp.linalg.lstsq +def get_cg(x: ArrayLike) -> Callable: + """Returns correct cg module based on input + + Parameters + ---------- + x : :obj:`numpy.ndarray` + Array + + Returns + ------- + f : :obj:`callable` + Function to be used to process array + + """ + if not deps.cupy_enabled: + return cg + + if cp.get_array_module(x) == np: + return cg + else: + return cp_cg + + +def get_lsqr(x: ArrayLike) -> Callable: + """Returns correct lsqr module based on input + + Parameters + ---------- + x : :obj:`numpy.ndarray` + Array + + Returns + ------- + f : :obj:`callable` + Function to be used to process array + + """ + if not deps.cupy_enabled: + return lsqr + + if cp.get_array_module(x) == np: + return lsqr + else: + return cp_lsqr + + def get_sp_fft(x: ArrayLike) -> Callable: """Returns correct scipy.fft module based on input diff --git a/pytests/test_leastsquares.py b/pytests/test_leastsquares.py index ccdcbe36..a4ed3ebb 100644 --- a/pytests/test_leastsquares.py +++ b/pytests/test_leastsquares.py @@ -101,58 +101,64 @@ def test_NormalEquationsInversion(par): ) y = Gop * x - for preallocate in [False, True]: - # normal equations with regularization - xinv = normal_equations_inversion( - Gop, - y, - [Reg], - epsI=1e-5, - epsRs=[1e-8], - x0=x0, - engine="pylops", - **dict(niter=200, tol=1e-10, preallocate=preallocate) - )[0] - assert_array_almost_equal(x, xinv, decimal=3) - # normal equations with weight - xinv = normal_equations_inversion( - Gop, - y, - None, - Weight=Weigth, - epsI=1e-5, - x0=x0, - engine="pylops", - **dict(niter=200, tol=1e-10, preallocate=preallocate) - )[0] - assert_array_almost_equal(x, xinv, decimal=3) - # normal equations with weight and small regularization - xinv = normal_equations_inversion( - Gop, - y, - [Reg], - Weight=Weigth, - epsI=1e-5, - epsRs=[1e-8], - x0=x0, - engine="pylops", - **dict(niter=200, tol=1e-10, preallocate=preallocate) - )[0] - assert_array_almost_equal(x, xinv, decimal=3) - # normal equations with weight and small normal regularization - xinv = normal_equations_inversion( - Gop, - y, - [], - NRegs=[NReg], - Weight=Weigth, - epsI=1e-5, - epsNRs=[1e-8], - x0=x0, - engine="pylops", - **dict(niter=200, tol=1e-10, preallocate=preallocate) - )[0] - assert_array_almost_equal(x, xinv, decimal=3) + for engine in ["scipy", "pylops"]: + for preallocate in [False, True]: + # define solver parameters + if engine == "scipy": + dict_solver = dict(maxiter=200, tol=1e-10) + else: + dict_solver = dict(niter=200, tol=1e-10, preallocate=preallocate) + # normal equations with regularization + xinv = normal_equations_inversion( + Gop, + y, + [Reg], + epsI=1e-5, + epsRs=[1e-8], + x0=x0, + engine=engine, + **dict_solver + )[0] + assert_array_almost_equal(x, xinv, decimal=3) + # normal equations with weight + xinv = normal_equations_inversion( + Gop, + y, + None, + Weight=Weigth, + epsI=1e-5, + x0=x0, + engine=engine, + **dict_solver + )[0] + assert_array_almost_equal(x, xinv, decimal=3) + # normal equations with weight and small regularization + xinv = normal_equations_inversion( + Gop, + y, + [Reg], + Weight=Weigth, + epsI=1e-5, + epsRs=[1e-8], + x0=x0, + engine=engine, + **dict_solver + )[0] + assert_array_almost_equal(x, xinv, decimal=3) + # normal equations with weight and small normal regularization + xinv = normal_equations_inversion( + Gop, + y, + [], + NRegs=[NReg], + Weight=Weigth, + epsI=1e-5, + epsNRs=[1e-8], + x0=x0, + engine=engine, + **dict_solver + )[0] + assert_array_almost_equal(x, xinv, decimal=3) @pytest.mark.parametrize(