Creating an app with multi-regularizers (wavelet and total variation)
joeyplum opened this issue · 2 comments
Hi SigPy group,
First, thank you for creating this powerful software - I've found it easy to use and very helpful for understanding CS in MRI.
I've been working on a problem for a few weeks now: I want to build an app that is capable of solving a multi-regularizer problem. Specifically, I want to solve the following optimization problem:
where lambdaW and lambdaTV are penalties for the wavelet and total variation components, respectively. This optimization problem can also be found in Miki Lustig's 2007 paper.
I came here to see if you have solved this issue in SigPy before, or if you have any advice? So far, I have tried reformulating the problem to use temporary variables for x in each of the regularizer components. However, I haven't been able to get my head around getting to the final solution (i.e. how to recombine the temporary variables). Perhaps it is my inexperience with Python, but I can't quite figure out where to go next.
Thanks for your help, and I look forward to hearing back from you!
Joey
Copied below is the most recent version of an app I tried to create using the current apps as a template: (please be critical)
class TVWaveletRecon(sp.app.LinearLeastSquares):
r"""L1 Wavelet and total variation regularized reconstruction.
Wavelet is good at preserving edges and low contrast information while TV
is efficient at suppressing noise and streaking artifacts.
Considers the problem
.. math::
\min_x \frac{1}{2} \| A x - y \|_2^2 + \lambdaW \| W x \|_1 + \lambdaTV \| G x \|_1
where A is the sampling operator,
W is the wavelet operator,
x is the image, and y is the k-space measurements.
Args:
y (array): k-space measurements.
mps (array): sensitivity maps.
lamdaW (float): regularization parameter for the wavelet component.
lamdaTV (float): regularization parameter for the finite difference component.
weights (float or array): weights for data consistency.
coord (None or array): coordinates.
wave_name (str): wavelet name.
device (Device): device to perform reconstruction.
coil_batch_size (int): batch size to process coils.
Only affects memory usage.
comm (Communicator): communicator for distributed computing.
**kwargs: Other optional arguments.
References:
Lustig, M., Donoho, D., & Pauly, J. M. (2007).
Sparse MRI: The application of compressed sensing for rapid MR imaging.
Magnetic Resonance in Medicine, 58(6), 1082-1195.
Zangen, Z., Khan, W., Babyn, P., Cooper, D., Pratt, I., Carter, Y. (2013)
Improved Compressed Sensing-Based Algorithm for Sparse-View CT Image Reconstruction.
Computational and Mathematical Methods in Medicine.
10.1155/2013/185750
"""
def __init__(self, y, mps, lamdaW, lamdaTV,
weights=None, coord=None,
wave_name='db4', device=sp.cpu_device,
coil_batch_size=None, comm=None, show_pbar=True,
transp_nufft=False, **kwargs):
weights = _estimate_weights(y, weights, coord)
if weights is not None:
y = sp.to_device(y * weights**0.5, device=device)
else:
y = sp.to_device(y, device=device)
A = linop.Sense(mps, coord=coord, weights=weights,
comm=comm, coil_batch_size=coil_batch_size,
transp_nufft=transp_nufft)
img_shape = mps.shape[1:]
# Wavelet
W = sp.linop.Wavelet(img_shape, wave_name=wave_name)
# Finite difference
G = sp.linop.FiniteDifference(A.ishape)
proxg1 = sp.prox.UnitaryTransform(sp.prox.L1Reg(W.oshape, lamdaW), W)
proxg2 = sp.prox.L1Reg(G.oshape, lamdaTV)
def g(input):
device = sp.get_device(input)
xp = device.xp
with device:
return lamdaW * xp.sum(xp.abs(W(input))).item() + lamdaTV * xp.sum(xp.abs(input)).item()
if comm is not None:
show_pbar = show_pbar and comm.rank == 0
# Call super().__init(...) to call the __init(...) of the parent class,
# sp.app.LinearLeastSquares
super().__init__(A, y, proxg=proxg1, g=g, show_pbar=show_pbar, **kwargs)
def h(input):
device = sp.get_device(input)
xp = device.xp
with device:
return lamdaW * xp.sum(xp.abs(W(input))).item() + lamdaTV * xp.sum(xp.abs(input)).item()
if comm is not None:
show_pbar = show_pbar and comm.rank == 0
super().__init__(A, y, proxg=proxg2, g=h, G=G, show_pbar=show_pbar, **kwargs)
Hi Joey,
I do not think I will have the time to write the method. Once you understand the math, it's not too difficult to implement this using prox
functions. Please see Equation 5.3 of the following: https://web.stanford.edu/~boyd/papers/pdf/prox_algs.pdf
If I have time, I will re-open this request and try to address it. I am happy to help clarify any questions as well.