Skip to content

Commit de71003

Browse files
committed
feat: added solvers parameter to L2
1 parent ca6fb10 commit de71003

File tree

2 files changed

+162
-55
lines changed

2 files changed

+162
-55
lines changed

pyproximal/proximal/L2.py

Lines changed: 85 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import numpy as np
44
from pylops import Identity, MatrixMult
5-
from pylops.optimization.basic import lsqr
5+
from pylops.optimization.basic import cg, lsqr
6+
from pylops.optimization.leastsquares import regularized_inversion
67
from pylops.utils.backend import get_array_module, get_module_name
78
from pylops.utils.typing import NDArray, ShapeLike
89
from scipy.linalg import cho_factor, cho_solve
@@ -47,6 +48,15 @@ class L2(ProxOperator):
4748
warm : :obj:`bool`, optional
4849
Warm start (``True``) or not (``False``). Uses estimate from previous
4950
call of ``prox`` method.
51+
solver : :obj:`str`, optional
52+
.. versionadded:: 0.11.0
53+
54+
Name of solver to use with non-explicit operators:
55+
- ``legacy``: enforces the legacy behaviour where :py:func:`scipy.sparse.linalg.lsqr` is used with numpy data and :py:func:`pylops.optimization.basic.lsqr` is used with cupy data (both are applied to the normal equations);
56+
- ``cg`` to use :py:func:`pylops.optimization.basic.cg` on the
57+
normal equations;
58+
- ``cgls`` to use :py:func:`pylops.optimization.basic.cgls` on the
59+
regularized system of equations;
5060
densesolver : :obj:`str`, optional
5161
Use ``numpy``, ``scipy``, or ``factorize`` when dealing with explicit
5262
operators. The former two rely on dense solvers from either library,
@@ -55,10 +65,8 @@ class L2(ProxOperator):
5565
have changed. Choose ``densesolver=None`` when using PyLops versions
5666
earlier than v1.18.1 or v2.0.0
5767
**kwargs_solver : :obj:`dict`, optional
58-
Dictionary containing extra arguments for
59-
:py:func:`scipy.sparse.linalg.lsqr` solver when using
60-
numpy data (or :py:func:`pylops.optimization.solver.lsqr` and
61-
when using cupy data)
68+
Dictionary containing extra arguments for the solver selected
69+
via the ``solver`` parameter.
6270
6371
Notes
6472
-----
@@ -78,8 +86,26 @@ class L2(ProxOperator):
7886
iterative solver is employed. In this case it is possible to provide a warm
7987
start via the ``x0`` input parameter.
8088
81-
When only ``b`` is provided, ``Op`` is assumed to be an Identity operator
82-
and the proximal operator reduces to:
89+
Note that alternatively the proximal operator can be computed solving the following
90+
augmented system of equations (instead of its normal equations as shown above):
91+
92+
.. math::
93+
94+
\begin{bmatrix}
95+
\sqrt{\tau \sigma} \mathbf{Op} \\
96+
\mathbf{I}
97+
\end{bmatrix}
98+
prox_{\tau f_\alpha}(\mathbf{x}) =
99+
\begin{bmatrix}
100+
\sqrt{\tau \sigma} \mathbf{b} \\
101+
\mathbf{x} - \tau \alpha \mathbf{q}
102+
\end{bmatrix}
103+
104+
The choice of the parameter ``solver`` determines which of the two
105+
approaches is used.
106+
107+
Alternatively, when only ``b`` is provided, ``Op`` is assumed to be an
108+
Identity operator and the proximal operator reduces to:
83109
84110
.. math::
85111
@@ -111,6 +137,7 @@ def __init__(
111137
niter: Union[int, Callable[[int], int]] = 10,
112138
x0: Optional[NDArray] = None,
113139
warm: bool = True,
140+
solver: Optional[str] = "legacy",
114141
densesolver: Optional[str] = None,
115142
kwargs_solver: Optional[Dict[str, Any]] = None,
116143
) -> None:
@@ -123,17 +150,33 @@ def __init__(
123150
self.niter = niter
124151
self.x0 = x0
125152
self.warm = warm
153+
self.solver = solver
126154
self.densesolver = densesolver
127155
self.count = 0
128156
self.kwargs_solver = {} if kwargs_solver is None else kwargs_solver
129157

158+
# define whether the normal equations or the regularized system
159+
# of equations are solved
160+
if self.solver in ("legacy", "cg"):
161+
self.normaleqs = True
162+
elif self.solver == "cgls":
163+
self.normaleqs = False
164+
else:
165+
raise ValueError(
166+
f"Provided solver={self.solver}. "
167+
"Available options are: 'legacy', 'cg', 'cgls'."
168+
)
130169
# when using factorize, store the first tau*sigma=0 so that the
131170
# first time it will be recomputed (as tau cannot be 0)
132171
if self.densesolver == "factorize":
133172
self.tausigma = 0.0
134173

135174
# create data term
136-
if self.Op is not None and self.b is not None:
175+
if (
176+
self.Op is not None
177+
and self.b is not None
178+
and (self.Op.explicit or self.normaleqs)
179+
):
137180
self.OpTb = self.sigma * self.Op.H @ self.b
138181
# create A.T A upfront for explicit operators
139182
if self.Op.explicit:
@@ -170,9 +213,10 @@ def prox(self, x: NDArray, tau: float) -> NDArray:
170213

171214
# solve proximal optimization
172215
if self.Op is not None and self.b is not None:
173-
y = x + tau * self.OpTb
174-
if self.q is not None:
175-
y -= tau * self.alpha * self.q
216+
if self.normaleqs or self.Op.explicit:
217+
y = x + tau * self.OpTb
218+
if self.q is not None:
219+
y -= tau * self.alpha * self.q
176220
if self.Op.explicit:
177221
if self.densesolver != "factorize":
178222
Op1 = MatrixMult(
@@ -192,18 +236,41 @@ def prox(self, x: NDArray, tau: float) -> NDArray:
192236
ATA = np.eye(self.Op.shape[1]) + self.tausigma * self.ATA
193237
self.cl = cho_factor(ATA)
194238
x = cho_solve(self.cl, y)
195-
else:
239+
elif self.normaleqs:
196240
Op1 = Identity(self.Op.shape[1], dtype=self.Op.dtype) + float(
197241
tau * self.sigma
198242
) * (self.Op.H * self.Op)
199-
if get_module_name(get_array_module(x)) == "numpy":
200-
x = sp_lsqr(
201-
Op1, y, iter_lim=niter, x0=self.x0, **self.kwargs_solver
202-
)[0]
203-
else:
204-
x = lsqr(Op1, y, niter=niter, x0=self.x0, **self.kwargs_solver)[
243+
if self.solver == "legacy":
244+
if get_module_name(get_array_module(x)) == "numpy":
245+
x = sp_lsqr(
246+
Op1, y, iter_lim=niter, x0=self.x0, **self.kwargs_solver
247+
)[0]
248+
else:
249+
x = lsqr(Op1, y, niter=niter, x0=self.x0, **self.kwargs_solver)[
250+
0
251+
].ravel()
252+
elif self.solver == "cg":
253+
x = cg(Op1, y, niter=niter, x0=self.x0, **self.kwargs_solver)[
205254
0
206255
].ravel()
256+
else:
257+
y = x
258+
if self.q is not None:
259+
y -= tau * self.alpha * self.q
260+
x = regularized_inversion(
261+
np.sqrt(tau * self.sigma) * self.Op,
262+
np.sqrt(tau * self.sigma) * self.b,
263+
[
264+
Identity(self.Op.shape[1], dtype=self.Op.dtype),
265+
],
266+
x0=self.x0,
267+
dataregs=[
268+
y,
269+
],
270+
niter=niter,
271+
engine="pylops",
272+
**self.kwargs_solver,
273+
)[0].ravel()
207274
if self.warm:
208275
self.x0 = x
209276
elif self.b is not None:

pytests/test_norms.py

Lines changed: 77 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -42,60 +42,58 @@ def test_Euclidean(par):
4242

4343

4444
@pytest.mark.parametrize("par", [(par1), (par2)])
45-
def test_L2(par):
46-
"""L2 norm of Op*x-b and proximal/dual proximal"""
47-
l2 = L2(
48-
Op=Identity(par["nx"], dtype=par["dtype"]),
49-
b=np.zeros(par["nx"], dtype=par["dtype"]),
50-
sigma=par["sigma"],
51-
)
45+
def test_L2_op(par):
46+
"""L2 norm of Op*x and proximal (since Op is a Diagonal
47+
operator the denominator becomes 1 + sigma*tau*d[i]^2
48+
for every i)"""
49+
b = np.zeros(par["nx"], dtype=par["dtype"])
50+
d = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
51+
l2 = L2(Op=Diagonal(d, dtype=par["dtype"]), b=b, sigma=par["sigma"], niter=500)
52+
5253
# norm
5354
x = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
54-
assert l2(x) == (par["sigma"] / 2.0) * np.linalg.norm(x) ** 2
55+
assert l2(x) == (par["sigma"] / 2.0) * np.linalg.norm(d * x) ** 2
5556

56-
# prox
57+
# prox: since Op is a Diagonal operator the denominator becomes
58+
# 1 + sigma*tau*d[i]^2 for every i
5759
tau = 2.0
58-
assert_array_almost_equal(l2.prox(x, tau), x / (1.0 + par["sigma"] * tau))
60+
den = 1.0 + par["sigma"] * tau * d**2
61+
assert_array_almost_equal(l2.prox(x, tau), x / den, decimal=4)
5962

6063

6164
@pytest.mark.parametrize("par", [(par1), (par2)])
62-
def test_L2_diff(par):
63-
"""L2 norm of difference (x-b) and proximal/dual proximal"""
65+
def test_L2_op_solver(par):
66+
"""L2 norm of Op*x-b and proximal, the first compared to closed-form
67+
solution and the second with different choices of solver."""
68+
Op = MatrixMult(
69+
np.random.normal(0, 1, (par["nx"], par["nx"])).astype(dtype=par["dtype"]),
70+
dtype=par["dtype"],
71+
)
6472
b = np.ones(par["nx"], dtype=par["dtype"])
65-
l2 = L2(b=b, sigma=par["sigma"])
73+
l2_leg = L2(Op=Op, b=b, sigma=par["sigma"], solver="legacy", niter=par["nx"])
74+
l2_cg = L2(Op=Op, b=b, sigma=par["sigma"], solver="cg", niter=par["nx"])
75+
l2_cgls = L2(Op=Op, b=b, sigma=par["sigma"], solver="cgls", niter=par["nx"])
6676

6777
# norm
6878
x = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
69-
assert l2(x) == (par["sigma"] / 2.0) * np.linalg.norm(x - b) ** 2
79+
assert l2_leg(x) == (par["sigma"] / 2.0) * np.linalg.norm(Op * x - b) ** 2
7080

7181
# prox
7282
tau = 2.0
73-
assert_array_almost_equal(
74-
l2.prox(x, tau), (x + par["sigma"] * tau * b) / (1.0 + par["sigma"] * tau)
75-
)
83+
prox_leg = l2_leg.prox(x, tau)
84+
prox_cg = l2_cg.prox(x, tau)
85+
prox_cgls = l2_cgls.prox(x, tau)
7686

77-
78-
@pytest.mark.parametrize("par", [(par1), (par2)])
79-
def test_L2_op(par):
80-
"""L2 norm of Op*x and proximal/dual proximal"""
81-
b = np.zeros(par["nx"], dtype=par["dtype"])
82-
d = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
83-
l2 = L2(Op=Diagonal(d, dtype=par["dtype"]), b=b, sigma=par["sigma"], niter=500)
84-
85-
# norm
86-
x = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
87-
assert l2(x) == (par["sigma"] / 2.0) * np.linalg.norm(d * x) ** 2
88-
89-
# prox: since Op is a Diagonal operator the denominator becomes
90-
# 1 + sigma*tau*d[i] for every i
91-
tau = 2.0
92-
den = 1.0 + par["sigma"] * tau * d**2
93-
assert_array_almost_equal(l2.prox(x, tau), x / den, decimal=4)
87+
assert_array_almost_equal(prox_leg, prox_cg, decimal=4)
88+
assert_array_almost_equal(prox_leg, prox_cgls, decimal=4)
9489

9590

9691
@pytest.mark.parametrize("par", [(par1), (par2)])
9792
def test_L2_dense(par):
98-
"""L2 norm of Op*x with dense Op and proximal/dual proximal"""
93+
"""L2 norm of Op*x with dense Op and proximal
94+
compared to closed-form solution (since Op is a Diagonal
95+
operator the denominator becomes 1 + sigma*tau*d[i]^2 for
96+
every i)"""
9997
for densesolver in ("numpy", "scipy", "factorize"):
10098
b = np.zeros(par["nx"], dtype=par["dtype"])
10199
d = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
@@ -110,13 +108,55 @@ def test_L2_dense(par):
110108
x = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
111109
assert l2(x) == (par["sigma"] / 2.0) * np.linalg.norm(d * x) ** 2
112110

113-
# prox: since Op is a Diagonal operator the denominator becomes
114-
# 1 + sigma*tau*d[i] for every i
111+
# prox
115112
tau = 2.0
116113
den = 1.0 + par["sigma"] * tau * d**2
117114
assert_array_almost_equal(l2.prox(x, tau), x / den, decimal=4)
118115

119116

117+
@pytest.mark.parametrize("par", [(par1), (par2)])
118+
def test_L2_diff(par):
119+
"""L2 norm of difference (x-b) and proximal
120+
compared to closed-form solution"""
121+
b = np.ones(par["nx"], dtype=par["dtype"])
122+
l2 = L2(b=b, sigma=par["sigma"])
123+
124+
# norm
125+
x = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
126+
assert l2(x) == (par["sigma"] / 2.0) * np.linalg.norm(x - b) ** 2
127+
128+
# prox
129+
tau = 2.0
130+
assert_array_almost_equal(
131+
l2.prox(x, tau), (x + par["sigma"] * tau * b) / (1.0 + par["sigma"] * tau)
132+
)
133+
134+
135+
@pytest.mark.parametrize("par", [(par1), (par2)])
136+
def test_L2_x(par):
137+
"""L2 norm of x and proximal (implemented both directly and
138+
with identity operator and zero b and compared to closed-form
139+
solution)"""
140+
l2 = L2(
141+
Op=Identity(par["nx"], dtype=par["dtype"]),
142+
b=np.zeros(par["nx"], dtype=par["dtype"]),
143+
sigma=par["sigma"],
144+
)
145+
l2direct = L2(
146+
sigma=par["sigma"],
147+
)
148+
149+
# norm
150+
x = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
151+
assert l2(x) == (par["sigma"] / 2.0) * np.linalg.norm(x) ** 2
152+
assert l2direct(x) == (par["sigma"] / 2.0) * np.linalg.norm(x) ** 2
153+
154+
# prox
155+
tau = 2.0
156+
assert_array_almost_equal(l2.prox(x, tau), x / (1.0 + par["sigma"] * tau))
157+
assert_array_almost_equal(l2direct.prox(x, tau), x / (1.0 + par["sigma"] * tau))
158+
159+
120160
@pytest.mark.parametrize("par", [(par1), (par2)])
121161
def test_L1(par):
122162
"""L1 norm and proximal/dual proximal"""

0 commit comments

Comments
 (0)