lanl/scico

Tuning the penalty parameter in ProximalADMM for better deconvolution results


I'm exploring ProximalADMM as a last-ditch effort to get my space-variant deconvolution problem to work with scico. I seem to have reached a dead end with ADMM, PDHG, and PGM, as discussed in other issues. I tried ProximalADMM with my forward operator, a sum of weighted convolutions, but the solver rapidly diverges to infinity, so I'm trying to understand how to set the penalty parameter rho for better convergence. As a test, I tried a simpler space-invariant deconvolution problem with a Gaussian kernel applied via CircularConvolve. However, despite trying a range of values of rho with ProximalADMM, I can't get results that match those of ADMM.
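For the divergence with the space-variant operator, one thing worth checking is the step-size condition. My reading of the scico documentation is that ProximalADMM needs mu to be at least ||A||_2^2 for the operator A in the constraint (and nu at least ||B||_2^2). Treat that as an assumption. For a sum of weighted convolutions, ||A|| is not obvious in advance, but power iteration on A^T A estimates it. The sketch below uses only numpy and assumes a hypothetical model A x = sum_k h_k * (w_k x) with circular boundaries. The actual operator from this issue is not shown, and `weighted_conv_operator` and `estimate_norm2` are illustrative names, not scico functions.

```python
import numpy as np

def gaussian_psf(shape, sigma):
    """Normalized 2D Gaussian PSF centred on the grid (sums to 1)."""
    yy, xx = np.mgrid[0:shape[0], 0:shape[1]]
    cy, cx = shape[0] // 2, shape[1] // 2
    h = np.exp(-((yy - cy) ** 2 + (xx - cx) ** 2) / (2.0 * sigma**2))
    return h / h.sum()

def weighted_conv_operator(psfs, weights):
    """Hypothetical space-variant model A x = sum_k h_k * (w_k x), circular boundaries.

    Returns the forward operator and its adjoint as plain functions."""
    # Move each PSF centre to index (0, 0) before taking its DFT.
    Hs = [np.fft.fft2(np.fft.ifftshift(h)) for h in psfs]

    def forward(x):
        return sum(np.real(np.fft.ifft2(H * np.fft.fft2(w * x))) for H, w in zip(Hs, weights))

    def adjoint(y):
        # Adjoint of "weight, then convolve" is "correlate (conjugate DFT), then weight".
        return sum(w * np.real(np.fft.ifft2(np.conj(H) * np.fft.fft2(y))) for H, w in zip(Hs, weights))

    return forward, adjoint

def estimate_norm2(forward, adjoint, shape, iters=200, seed=0):
    """Power iteration on A^T A; returns an estimate of ||A||_2^2."""
    v = np.random.default_rng(seed).standard_normal(shape)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(iters):
        w = adjoint(forward(v))
        lam = float(np.vdot(v, w))  # Rayleigh quotient, since ||v|| = 1
        v = w / np.linalg.norm(w)
    return lam

shape = (64, 64)
ramp = np.linspace(0.0, 1.0, shape[1])[None, :] * np.ones(shape)  # left-to-right blend
fwd, adj = weighted_conv_operator([gaussian_psf(shape, 1.0), gaussian_psf(shape, 3.0)], [1.0 - ramp, ramp])
norm2 = estimate_norm2(fwd, adj, shape)
mu = 1.01 * norm2  # mu just above the estimated ||A||_2^2
print(f"estimated ||A||^2 = {norm2:.4f}, mu = {mu:.4f}")
```

Power iteration approaches ||A||_2^2 from below. If it has not fully converged, the 1.01 factor may leave too small a margin.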

Below is my code using ProximalADMM, with the resulting deblurred image, followed by the ADMM code and results for reference. Is ProximalADMM not expected to produce results as good as ADMM? Is there a heuristic for finding a good estimate of the penalty parameter rho?

import numpy as np
import matplotlib.pyplot as plt
import jax
from scico import linop, loss, functional
from scico.optimize import ProximalADMM
from scico.optimize.admm import ADMM, CircularConvolveSolver
from scico.util import device_info

# SYNTHETIC DATA
plt.figure(figsize=(20, 10))
# Create Synthetic Horizontal Stripes Pattern Image
im_s = np.zeros((2748, 3840)).astype(float)
stripe_width, stripe_gap, stripe_start, stripe_end = 50, 50, 500, 500
for y in range(0, im_s.shape[0]-(stripe_start+stripe_end), stripe_width + stripe_gap):
    im_s[stripe_start+y : stripe_start+y + stripe_width, :] = .8

xx, yy = np.mgrid[0:im_s.shape[0], 0:im_s.shape[1]]
im_ctr = (np.array(im_s.shape)/2).astype(int)
r = np.sqrt((xx - im_ctr[0])**2 + (yy - im_ctr[1])**2)
mask = np.zeros_like(im_s).astype(float)
mask[r < 1000] = 1
im_s *= mask
plt.subplot(131); plt.imshow(im_s); plt.title('Ground Truth');

# Create Gaussian Kernel
from scipy.stats import multivariate_normal
x, y = np.mgrid[0:im_s.shape[0], 0:im_s.shape[1]]
pos = np.dstack((x, y))
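# The covariance must be symmetric; the original [[200, 0], [300, 500]] was not, and scipy
# appears to use only its lower triangle, i.e. the matrix below
rv = multivariate_normal.pdf(pos, im_ctr, [[200, 300], [300, 500]])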
psf2 = rv/np.sum(rv)
psf2_cropped = psf2[im_ctr[0]-100:im_ctr[0]+101, im_ctr[1]-100:im_ctr[1]+101]
plt.subplot(132); plt.imshow(psf2_cropped); plt.title('PSF zoomed in');

# Convolve and Create Blurred Image
C = linop.CircularConvolve(h=jax.device_put(psf2_cropped), 
                           input_shape=im_s.shape, 
                           h_center=[psf2_cropped.shape[0] // 2, psf2_cropped.shape[1] // 2])
Cx = C(jax.device_put(im_s))

plt.subplot(133); plt.imshow(Cx); plt.title('Blurred Image')
(Figure: Ground Truth, PSF zoomed in, Blurred Image)
# PROXIMAL ADMM
f = functional.ZeroFunctional()
g0 = loss.SquaredL2Loss(y=Cx)
lbd = 5e-1  # TV (L21 norm) regularization parameter
g1 = lbd * functional.L21Norm()
g = functional.SeparableFunctional((g0, g1))

D = linop.FiniteDifference(input_shape=im_s.shape, circular=True)
A = linop.VerticalStack((C, D))

rho = 5e-4  # ProximalADMM penalty parameter
maxiter = 20  # number of iterations
# mu and nu should be estimated for the full operator A = [C; D] used in the
# constraint, not for D alone
mu, nu = ProximalADMM.estimate_parameters(A)

solver = ProximalADMM(
    f=f,
    g=g,
    A=A,
    B=None,
    rho=rho,
    mu=mu,
    nu=nu,
    x0=Cx,
    maxiter=maxiter,
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
x = solver.solve()
hist = solver.itstat_object.history(transpose=True)

plt.figure(figsize=(20, 10))
plt.subplot(121); plt.imshow(Cx); plt.title('Blurred image')
plt.subplot(122); plt.imshow(x, vmin=0, vmax=1); plt.title(f'Recovered Image; ProximalADMM \nrho: {rho}, lambda: {lbd}, iter: {maxiter}'); #plt.colorbar()
(Figure: Blurred image; Recovered Image, ProximalADMM)
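About the `estimate_parameters` call: the constraint in this formulation involves the stacked operator A = [C; D], so mu presumably has to cover ||A||_2^2 rather than ||D||_2^2. That is an assumption based on the ProximalADMM documentation and has not been checked against the scico source. Because C and D are both circulant here, A^T A = C^T C + D^T D is diagonal in the DFT domain, and its norm is simply a maximum over frequencies. The numpy sketch below assumes that scico's FiniteDifference with circular=True is the circular forward difference along each axis. `stacked_norm2` is an illustrative name.

```python
import numpy as np

def gaussian_psf_dft(shape, sigma):
    """DFT of a normalized Gaussian PSF centred on the grid and shifted to index (0, 0)."""
    yy, xx = np.mgrid[0:shape[0], 0:shape[1]]
    h = np.exp(-((yy - shape[0] // 2) ** 2 + (xx - shape[1] // 2) ** 2) / (2.0 * sigma**2))
    return np.fft.fft2(np.fft.ifftshift(h / h.sum()))

def stacked_norm2(H):
    """Return (||A||_2^2, ||D||_2^2) for A = [C; D], where C is circular convolution with
    DFT H and D stacks circular forward differences along both axes."""
    M, N = H.shape
    # |DFT of x[n+1] - x[n]|^2 along each axis: |exp(2j*pi*k/M) - 1|^2
    dv2 = (np.abs(np.exp(2j * np.pi * np.fft.fftfreq(M)) - 1.0) ** 2)[:, None]
    dh2 = (np.abs(np.exp(2j * np.pi * np.fft.fftfreq(N)) - 1.0) ** 2)[None, :]
    # A^T A is diagonal in the DFT domain, so its largest eigenvalue is a max over frequencies.
    norm2_A = float(np.max(np.abs(H) ** 2 + dv2 + dh2))
    norm2_D = float(np.max(dv2 + dh2))
    return norm2_A, norm2_D

shape = (64, 64)
for label, H in [("no blur (C = I)", np.ones(shape)), ("Gaussian, sigma = 2", gaussian_psf_dft(shape, 2.0))]:
    nA, nD = stacked_norm2(H)
    print(f"{label}: ||A||^2 = {nA:.4f}, ||D||^2 = {nD:.4f}")
```

With no blur (C = I), this gives ||A||^2 = 9 against ||D||^2 = 8. For a Gaussian as wide as the one above, the two are nearly equal because the blur passes almost nothing near the Nyquist frequency, where D peaks. So in this particular test the mu estimated from D alone is probably not the main problem. For sharper or space-variant kernels, the gap can matter.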
# ADMM
f = loss.SquaredL2Loss(y=Cx, A=C)
lbd = 5e-1  # TV (L21 norm) regularization parameter
g = lbd * functional.L21Norm()
D = linop.FiniteDifference(input_shape=im_s.shape, circular=True)

rho = 5e0  # ADMM penalty parameter
maxiter = 20  # number of ADMM iterations

solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[D],
    rho_list=[rho],
    x0=Cx,
    maxiter=maxiter,
    subproblem_solver=CircularConvolveSolver(),
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
x = solver.solve()
hist = solver.itstat_object.history(transpose=True)

plt.figure(figsize=(12, 6))
plt.subplot(121); plt.imshow(Cx); plt.title('Blurred image')
plt.subplot(122); plt.imshow(x); plt.title(f'Recovered Image; ADMM \nrho: {rho}, lambda: {lbd}, iter: {maxiter}'); #plt.colorbar()
(Figure: Blurred image; Recovered Image, ADMM)
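As far as I can tell, the gap in quality at 20 iterations has a concrete cause. With CircularConvolveSolver, each ADMM x-update appears to solve its quadratic subproblem exactly. That is possible because C and D are both circulant and therefore diagonalized by the DFT. ProximalADMM instead takes a single linearized step of size 1/mu. Below is a numpy sketch of the exact solve. It is written from the math, not from the scico source. `v` stands in for whatever combination of z and the dual variable scico's ADMM uses, and `admm_x_step` is an illustrative name.

```python
import numpy as np

def circ_conv(x, H):
    """Circular convolution of x with the kernel whose DFT is H."""
    return np.real(np.fft.ifft2(H * np.fft.fft2(x)))

def grad(x):
    """Circular forward differences along axis 0 and axis 1, shape (2, M, N)."""
    return np.stack([np.roll(x, -1, axis=0) - x, np.roll(x, -1, axis=1) - x])

def grad_adj(g):
    """Adjoint of grad."""
    return (np.roll(g[0], 1, axis=0) - g[0]) + (np.roll(g[1], 1, axis=1) - g[1])

def admm_x_step(y, H, v, rho):
    """Exact minimizer of 0.5 ||C x - y||^2 + (rho / 2) ||D x - v||^2, i.e. the solution of
    (C^T C + rho D^T D) x = C^T y + rho D^T v, computed with DFTs since C and D are circulant."""
    M, N = y.shape
    Dv = (np.exp(2j * np.pi * np.fft.fftfreq(M)) - 1.0)[:, None]  # DFT of x[n+1] - x[n], axis 0
    Dh = (np.exp(2j * np.pi * np.fft.fftfreq(N)) - 1.0)[None, :]  # same along axis 1
    rhs = np.conj(H) * np.fft.fft2(y) + rho * (np.conj(Dv) * np.fft.fft2(v[0]) + np.conj(Dh) * np.fft.fft2(v[1]))
    denom = np.abs(H) ** 2 + rho * (np.abs(Dv) ** 2 + np.abs(Dh) ** 2)  # > 0 when H[0, 0] != 0
    return np.real(np.fft.ifft2(rhs / denom))

# Check the solve against the normal equations built from the spatial-domain operators
rng = np.random.default_rng(0)
M, N = 32, 48
yy, xx = np.mgrid[0:M, 0:N]
psf = np.exp(-((yy - M // 2) ** 2 + (xx - N // 2) ** 2) / (2 * 2.0**2))
H = np.fft.fft2(np.fft.ifftshift(psf / psf.sum()))
y = rng.standard_normal((M, N))
v = rng.standard_normal((2, M, N))  # plays the role of z - u in the x-update
rho = 5.0
x = admm_x_step(y, H, v, rho)
lhs = circ_conv(circ_conv(x, H), np.conj(H)) + rho * grad_adj(grad(x))
rhs = circ_conv(y, np.conj(H)) + rho * grad_adj(v)
print("max normal-equation residual:", np.max(np.abs(lhs - rhs)))
```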

You should be able to get a similar-quality reconstruction with Proximal ADMM, but note that it may require more iterations than ADMM. Choosing a good $\rho$ can be difficult. Your best option here is perhaps to scale the problem down substantially and then experiment to find a good value. (Note, though, that the best value for the scaled-down problem may not be the same as for the original problem, but it should at least give you some idea of whether you're way off.)
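Below is a rough version of that experiment: a numpy stand-in for ProximalADMM on a 64 x 64 version of the synthetic problem, sweeping $\rho$. It rests on several assumptions. It is a hand-written linearized ADMM (f = 0, A = [C; D], B = -I) that I believe corresponds to ProximalADMM's updates with nu = 1, but it is not scico's code. The isotropic PSF, image size, stripe width, and lambda are made up for the scaled-down problem. mu uses the bound ||A||^2 <= ||C||^2 + ||D||^2 <= 1 + 8.

```python
import numpy as np

def circ_conv(x, H):
    """Circular convolution of x with the kernel whose DFT is H."""
    return np.real(np.fft.ifft2(H * np.fft.fft2(x)))

def grad(x):
    """Circular forward differences along both axes, shape (2, M, N)."""
    return np.stack([np.roll(x, -1, axis=0) - x, np.roll(x, -1, axis=1) - x])

def grad_adj(g):
    """Adjoint of grad."""
    return (np.roll(g[0], 1, axis=0) - g[0]) + (np.roll(g[1], 1, axis=1) - g[1])

def l21_prox(v, t):
    """Prox of t * sum over pixels of ||v[:, i, j]||_2 (isotropic group soft-thresholding)."""
    nrm = np.sqrt(np.sum(v**2, axis=0, keepdims=True))
    return v * np.maximum(0.0, 1.0 - t / np.maximum(nrm, 1e-12))

def objective(x, y, H, lam):
    """0.5 ||C x - y||^2 + lam ||D x||_{2,1}, the same problem as in the issue."""
    return 0.5 * np.sum((circ_conv(x, H) - y) ** 2) + lam * np.sum(np.sqrt(np.sum(grad(x) ** 2, axis=0)))

def linearized_admm(y, H, lam, rho, maxiter=200):
    """Linearized ADMM for min g(z) s.t. A x - z = 0, with A = [C; D] and
    g(z0, z1) = 0.5 ||z0 - y||^2 + lam ||z1||_{2,1}."""
    mu = 1.01 * 9.0  # bound: ||A||^2 <= ||C||^2 + ||D||^2 <= 1 + 8 for a normalized PSF
    x = y.copy()
    z0, z1 = circ_conv(x, H), grad(x)
    u0, u1 = np.zeros_like(z0), np.zeros_like(z1)
    for _ in range(maxiter):
        # x-step: prox of f = 0 is the identity, so this is one step of size 1/mu
        r0, r1 = circ_conv(x, H) - z0 + u0, grad(x) - z1 + u1
        x = x - (circ_conv(r0, np.conj(H)) + grad_adj(r1)) / mu
        # z-step: prox of g / rho, which separates into the two blocks
        Ax0, Ax1 = circ_conv(x, H), grad(x)
        z0 = (y + rho * (Ax0 + u0)) / (1.0 + rho)
        z1 = l21_prox(Ax1 + u1, lam / rho)
        # scaled dual update
        u0, u1 = u0 + Ax0 - z0, u1 + Ax1 - z1
    return x

# Scaled-down stand-in for the synthetic problem: 64 x 64 stripes inside a disc
n = 64
yy, xx = np.mgrid[0:n, 0:n]
x_true = np.where((yy // 8) % 2 == 0, 0.8, 0.0) * (np.hypot(yy - n / 2, xx - n / 2) < 24)
psf = np.exp(-((yy - n // 2) ** 2 + (xx - n // 2) ** 2) / (2 * 1.5**2))
H = np.fft.fft2(np.fft.ifftshift(psf / psf.sum()))
y = circ_conv(x_true, H)

lam = 1e-3
results = {}
for rho in [1e-3, 1e-2, 1e-1, 1.0, 10.0]:
    x_rec = linearized_admm(y, H, lam, rho)
    rel_err = np.linalg.norm(x_rec - x_true) / np.linalg.norm(x_true)
    results[rho] = (objective(x_rec, y, H, lam), rel_err)
    print(f"rho = {rho:g}: objective = {results[rho][0]:.4g}, relative error = {rel_err:.3f}")
best_rho = min(results, key=lambda r: results[r][0])
print("lowest objective at rho =", best_rho)
```

The $\rho$ with the lowest objective after a fixed number of iterations is only a starting point for the full-size problem, for the reason given above.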

Closing as answered. Feel free to reopen if this is not resolved.