Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ This library must be installed *before* PyLops is installed.
``cupy`` and ``jax`` backends. This can be also used if a previous version of ``cupy``
or ``jax`` is installed in your system, otherwise you will get an error when importing PyLops.


Apart from a few exceptions, all operators and solvers in PyLops can
seamlessly work with ``numpy`` arrays on CPU as well as with ``cupy/jax`` arrays
on GPU. For CuPy, users simply need to consistently create operators and
Expand All @@ -32,6 +31,7 @@ be also wrapped into a :class:`pylops.JaxOperator`.
See below for a comphrensive list of supported operators and additional functionalities for both the
``cupy`` and ``jax`` backends.


Install dependencies
--------------------
GPU-enabled development environments can created using ``conda``
Expand Down
24 changes: 12 additions & 12 deletions pylops/optimization/cls_leastsquares.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from typing import TYPE_CHECKING, Optional, Sequence, Tuple

import numpy as np
from scipy.sparse.linalg import cg as sp_cg
from scipy.sparse.linalg import lsqr as sp_lsqr

from pylops.basicoperators import Diagonal, VStack
from pylops.optimization.basesolver import Solver, _units
from pylops.optimization.basic import cg, cgls
from pylops.utils.backend import get_array_module
from pylops.utils.backend import get_array_module, get_cg, get_lsqr
from pylops.utils.typing import NDArray, Tmemunit, Tsolverengine

if TYPE_CHECKING:
Expand Down Expand Up @@ -294,12 +292,14 @@ def run(
``<0``: illegal input or breakdown

"""
if engine == "scipy" and self.ncp == np:
if engine == "scipy":
if "tol" in kwargs_solver:
kwargs_solver["atol"] = kwargs_solver["tol"]
kwargs_solver.pop("tol")
xinv, istop = sp_cg(self.Op_normal, self.y_normal, x0=x, **kwargs_solver)
elif engine == "pylops" or self.ncp != np:
xinv, istop = get_cg(x)(
self.Op_normal, self.y_normal, x0=x, **kwargs_solver
)
elif engine == "pylops":
if show:
kwargs_solver["show"] = True
xinv = cg(
Expand Down Expand Up @@ -703,13 +703,13 @@ def run(
Equal to ``r1norm`` if :math:`\epsilon=0`

"""
if engine == "scipy" and self.ncp == np:
if engine == "scipy":
if show:
kwargs_solver["show"] = 1
xinv, istop, itn, r1norm, r2norm = sp_lsqr(
xinv, istop, itn, r1norm, r2norm = get_lsqr(x)(
self.RegOp, self.datatot, x0=x, **kwargs_solver
)[0:5]
elif engine == "pylops" or self.ncp != np:
elif engine == "pylops":
if show:
kwargs_solver["show"] = True
xinv, istop, itn, r1norm, r2norm = cgls(
Expand Down Expand Up @@ -974,16 +974,16 @@ def run(
Equal to ``r1norm`` if :math:`\epsilon=0`

"""
if engine == "scipy" and self.ncp == np:
if engine == "scipy":
if show:
kwargs_solver["show"] = 1
pinv, istop, itn, r1norm, r2norm = sp_lsqr(
pinv, istop, itn, r1norm, r2norm = get_lsqr(x)(
self.POp,
self.y,
x0=x,
**kwargs_solver,
)[0:5]
elif engine == "pylops" or self.ncp != np:
elif engine == "pylops":
if show:
kwargs_solver["show"] = True
pinv, istop, itn, r1norm, r2norm = cgls(
Expand Down
68 changes: 35 additions & 33 deletions pylops/optimization/cls_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np
from scipy.sparse.linalg import lsqr

from pylops import LinearOperator
from pylops.basicoperators import Diagonal, Identity, VStack
Expand All @@ -25,6 +24,7 @@
from pylops.utils import deps
from pylops.utils.backend import (
get_array_module,
get_lsqr,
get_module_name,
get_real_dtype,
inplace_set,
Expand Down Expand Up @@ -570,15 +570,15 @@ def _step_model(
kwargs_solver["preallocate"] = True
if self.iiter == 0:
# first iteration (unweighted least-squares)
if engine == "scipy" and self.ncp == np:
if engine == "scipy":
x = self.Op.rmatvec(
lsqr(
get_lsqr(x)(
self.Op @ self.Op.H + (self.epsI**2) * self.Iop,
self.y,
**kwargs_solver,
)[0]
)
elif engine == "pylops" or self.ncp != np:
elif engine == "pylops":
x = self.Op.rmatvec(
cgls(
self.Op @ self.Op.H + (self.epsI**2) * self.Iop,
Expand All @@ -599,17 +599,17 @@ def _step_model(
self.ncp.divide(self.rw, self.rw.max(), out=self.rw)

R = Diagonal(self.rw, dtype=self.rw.dtype)
if engine == "scipy" and self.ncp == np:
if engine == "scipy":
x = R.matvec(
self.Op.rmatvec(
lsqr(
get_lsqr(x)(
self.Op @ R @ self.Op.H + self.epsI**2 * self.Iop,
self.y,
**kwargs_solver,
)[0]
)
)
elif engine == "pylops" or self.ncp != np:
elif engine == "pylops":
x = R.matvec(
self.Op.rmatvec(
cgls(
Expand Down Expand Up @@ -643,11 +643,8 @@ def step(
Display iteration log
**kwargs_solver
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.cg` solver for data IRLS and
:py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using
numpy data and ``engine='scipy'`` (or
:py:func:`pylops.optimization.solver.cg` and
:py:func:`pylops.optimization.solver.cgls` when using cupy data or
:py:func:`scipy.sparse.linalg.lsqr` when ``engine='scipy'`` (or
:py:func:`pylops.optimization.solver.cgls` when using
``engine='pylops'``)

Returns
Expand Down Expand Up @@ -702,11 +699,8 @@ def run(
three element of the list.
**kwargs_solver
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.cg` solver for data IRLS and
:py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using
numpy data and ``engine='scipy'`` (or
:py:func:`pylops.optimization.solver.cg` and
:py:func:`pylops.optimization.solver.cgls` when using cupy data or
:py:func:`scipy.sparse.linalg.lsqr` when ``engine='scipy'`` (or
:py:func:`pylops.optimization.solver.cgls` when using
``engine='pylops'``)

Returns
Expand Down Expand Up @@ -822,10 +816,9 @@ def solve(
three element of the list.
**kwargs_solver
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.cg` solver for data IRLS and
:py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using
numpy data(or :py:func:`pylops.optimization.solver.cg` and
:py:func:`pylops.optimization.solver.cgls` when using cupy data)
:py:func:`scipy.sparse.linalg.lsqr` when ``engine='scipy'`` (or
:py:func:`pylops.optimization.solver.cgls` when using
``engine='pylops'``)

Returns
-------
Expand Down Expand Up @@ -1192,9 +1185,11 @@ def step(
else:
# OMP update
Opcol = self.Op.apply_columns(cols)
if engine == "scipy" and self.ncp == np:
x = lsqr(Opcol, self.y, iter_lim=self.niter_inner, **kwargs_solver)[0]
elif engine == "pylops" or self.ncp != np:
if engine == "scipy":
x = get_lsqr(x)(
Opcol, self.y, iter_lim=self.niter_inner, **kwargs_solver
)[0]
elif engine == "pylops":
x = cgls(
Opcol,
self.y,
Expand Down Expand Up @@ -3038,9 +3033,10 @@ def step(
**kwargs_solver
Arbitrary keyword arguments for chosen solver
used to solve the first subproblem in the first step of the
Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` and
:py:func:`pylops.optimization.solver.cgls` are used as default
for numpy and cupy `data`, respectively).
Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` is
used if ``engine='scipy'`` and
:py:func:`pylops.optimization.solver.cgls` is used if engine is
``engine='pylops'``).

Returns
-------
Expand Down Expand Up @@ -3149,9 +3145,12 @@ def run(
show_inner : :obj:`bool`, optional
Display inner iteration logs of lsqr
**kwargs_lsqr
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first
subproblem in the first step of the Split Bregman algorithm.
Arbitrary keyword arguments for chosen solver
used to solve the first subproblem in the first step of the
Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` is
used if ``engine='scipy'`` and
:py:func:`pylops.optimization.solver.cgls` is used if engine is
``engine='pylops'``).

Returns
-------
Expand Down Expand Up @@ -3285,9 +3284,12 @@ def solve(
show_inner : :obj:`bool`, optional
Display inner iteration logs of lsqr
**kwargs_lsqr
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first
subproblem in the first step of the Split Bregman algorithm.
Arbitrary keyword arguments for chosen solver
used to solve the first subproblem in the first step of the
Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` is
used if ``engine='scipy'`` and
:py:func:`pylops.optimization.solver.cgls` is used if engine is
``engine='pylops'``).

Returns
-------
Expand Down
24 changes: 13 additions & 11 deletions pylops/optimization/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,9 @@ def irls(
pre-allocated since JAX does not support in-place operations.
**kwargs_solver
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.cg` solver for data IRLS and
:py:func:`scipy.sparse.linalg.lsqr` solver for model IRLS when using
numpy data(or :py:func:`pylops.optimization.solver.cg` and
:py:func:`pylops.optimization.solver.cgls` when using cupy data)
:py:func:`scipy.sparse.linalg.lsqr` when ``engine='scipy'`` (or
:py:func:`pylops.optimization.solver.cgls` when using
``engine='pylops'``)

Returns
-------
Expand Down Expand Up @@ -790,15 +789,18 @@ def splitbregman(
Function with signature (``callback(x)``) to call after each iteration
where ``x`` is the current model vector
preallocate : :obj:`bool`, optional
.. versionadded:: 2.6.0
.. versionadded:: 2.6.0

Pre-allocate all variables used by the solver. Note that if ``y``
is a JAX array, this option is ignored and variables are not
pre-allocated since JAX does not support in-place operations.
Pre-allocate all variables used by the solver. Note that if ``y``
is a JAX array, this option is ignored and variables are not
pre-allocated since JAX does not support in-place operations.
**kwargs_lsqr
Arbitrary keyword arguments for
:py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first
subproblem in the first step of the Split Bregman algorithm.
Arbitrary keyword arguments for chosen solver
used to solve the first subproblem in the first step of the
Split Bregman algorithm (:py:func:`scipy.sparse.linalg.lsqr` is
used if ``engine='scipy'`` and
:py:func:`pylops.optimization.solver.cgls` is used if engine is
``engine='pylops'``).

Returns
-------
Expand Down
51 changes: 51 additions & 0 deletions pylops/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"get_csc_matrix",
"get_sparse_eye",
"get_lstsq",
"get_cg",
"get_lsqr",
"get_sp_fft",
"get_complex_dtype",
"get_real_dtype",
Expand All @@ -36,6 +38,7 @@
from scipy.linalg import block_diag, lstsq, toeplitz
from scipy.signal import convolve, correlate, fftconvolve, oaconvolve
from scipy.sparse import csc_matrix, eye
from scipy.sparse.linalg import cg, lsqr

from pylops.utils import deps
from pylops.utils.typing import ArrayLike, DTypeLike, NDArray, Tfftengine_ncj
Expand All @@ -52,6 +55,8 @@
from cupyx.scipy.signal import oaconvolve as cp_oaconvolve
from cupyx.scipy.sparse import csc_matrix as cp_csc_matrix
from cupyx.scipy.sparse import eye as cp_eye
from cupyx.scipy.sparse.linalg import cg as cp_cg
from cupyx.scipy.sparse.linalg import lsqr as cp_lsqr

if deps.jax_enabled:
import jax
Expand Down Expand Up @@ -429,6 +434,52 @@ def get_lstsq(x: ArrayLike) -> Callable:
return cp.linalg.lstsq


def get_cg(x: ArrayLike) -> Callable:
"""Returns correct cg module based on input

Parameters
----------
x : :obj:`numpy.ndarray`
Array

Returns
-------
f : :obj:`callable`
Function to be used to process array

"""
if not deps.cupy_enabled:
return cg

if cp.get_array_module(x) == np:
return cg
else:
return cp_cg


def get_lsqr(x: ArrayLike) -> Callable:
"""Returns correct lsqr module based on input

Parameters
----------
x : :obj:`numpy.ndarray`
Array

Returns
-------
f : :obj:`callable`
Function to be used to process array

"""
if not deps.cupy_enabled:
return lsqr

if cp.get_array_module(x) == np:
return lsqr
else:
return cp_lsqr


def get_sp_fft(x: ArrayLike) -> Callable:
"""Returns correct scipy.fft module based on input

Expand Down
Loading
Loading