lanl/scico

Re-open issue "Proximal operator for a custom prior #437"

matinaz opened this issue · 1 comment

I tried to re-open issue #437, but it was not possible, so I am opening a new one, since the proposed solution didn't work. We tried to use `jaxopt.LBFGS`, but it does not seem to work with scico operators. A second idea was to use `scico.optimize` to minimize the whole function h = f + λg, instead of evaluating only the prior, but the problem appears to be too large: it kept running until the process was killed. Is there any suggestion for solving this? To help, here are some details on the function and how we implement it:
f = loss.SquaredL2Loss(y=Y, A=X), where Y is a JAX array and X is a linear operator implemented with `scico.linop`.
g is our prior, implemented as follows:

```python
import jax
import jax.numpy as jnp
import scico
import scico.numpy as snp

block_size = 9
target_value = 5.0
n = 7
lam = 1000


def mse_similarity(block1, block2):
    # Calculate the Mean Squared Error (MSE) similarity between two blocks.
    diff = block1.astype(jnp.float32) - block2.astype(jnp.float32)
    squared_diff = diff ** 2
    mse = jnp.mean(squared_diff)
    return mse


def calculate_block_similarity(image, x1, y1, x2, y2, block_size):
    # Calculate the similarity norm between two blocks in an image.
    # Extract the two blocks from the image (rows are y, columns are x).
    block1 = image[y1:y1 + block_size, x1:x1 + block_size]
    block2 = image[y2:y2 + block_size, x2:x2 + block_size]

    # Calculate the similarity using the MSE metric
    similarity = mse_similarity(block1, block2)
    return similarity


def calculate_sumsimilarity(image, xi, yi):
    # Average (1 + lam * MSE)^(-(n + 1) / 2) over 20 comparison blocks:
    # 4 directions x 5 steps of block_size pixels (hence the division by 20).
    Ax = 0
    x1 = xi
    y1 = yi
    for i in range(1, 6):
        # Diagonal, down-right
        x2 = x1 + i * block_size
        y2 = y1 + i * block_size
        similarity_norm = calculate_block_similarity(image, x1, y1, x2, y2, block_size)
        Ax = Ax + (1 + lam * similarity_norm) ** (-(n + 1) / 2) / 20
        # Vertical, down
        x2 = x1
        y2 = y1 + i * block_size
        similarity_norm = calculate_block_similarity(image, x1, y1, x2, y2, block_size)
        Ax = Ax + (1 + lam * similarity_norm) ** (-(n + 1) / 2) / 20
        # Horizontal, right
        x2 = x1 + i * block_size
        y2 = y1
        similarity_norm = calculate_block_similarity(image, x1, y1, x2, y2, block_size)
        Ax = Ax + (1 + lam * similarity_norm) ** (-(n + 1) / 2) / 20
        # Diagonal, up-left
        x2 = x1 - i * block_size
        y2 = y1 - i * block_size
        similarity_norm = calculate_block_similarity(image, x1, y1, x2, y2, block_size)
        Ax = Ax + (1 + lam * similarity_norm) ** (-(n + 1) / 2) / 20

    return Ax


def solve_for_x(image):
    # Evaluate the prior: |B| / (height * width), where B sums
    # log(sumsimilarity) over a grid of reference blocks.
    height, width = image.shape
    B = 0.0
    # The 60-pixel margin keeps every compared block (up to 5 * block_size
    # away from the reference block) inside the image.
    for y in range(60, height - 60, block_size):
        for x in range(60, width - 60, block_size):
            sumsimilarity = calculate_sumsimilarity(image, x, y)
            B += jnp.log(sumsimilarity)
    res = B / (height * width)
    # Debug output (prints tracers, not values, under jax.jit / jax.grad)
    print(B)
    print(res)
    return abs(res)
```
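The nested Python loops above run op by op in eager mode. Under `jax.jit` or `jax.grad` they are unrolled into the traced graph, giving one slice-and-compare per block pair. For a large image this is one plausible reason the full optimization ran until it was killed. Below is a minimal sketch that computes the same quantity with `jax.vmap` and `jax.lax.dynamic_slice`. It has not been tested against the poster's data. The name `block_prior` is hypothetical, and its defaults (`block_size=9`, `margin=60`, `lam=1000`, `n=7`) are copied from the snippet. It assumes a 2-D grayscale image.

```python
import jax
import jax.numpy as jnp


def block_prior(image, block_size=9, margin=60, lam=1000.0, n=7):
    """Hypothetical vectorized version of the loop-based prior (same formula).

    For every reference block on a stride-`block_size` grid inside the
    margin, average (1 + lam * MSE)^(-(n + 1) / 2) over 20 comparison
    blocks (4 directions x 5 steps), then return |sum(log(.)) / (H * W)|.
    """
    image = jnp.asarray(image, dtype=jnp.float32)
    height, width = image.shape  # static Python ints, so this also works under jit

    # Top-left corners of the reference blocks (same grid as the loops).
    ys = jnp.arange(margin, height - margin, block_size)
    xs = jnp.arange(margin, width - margin, block_size)
    yy, xx = jnp.meshgrid(ys, xs, indexing="ij")

    # (dy, dx) directions from calculate_sumsimilarity, times steps 1..5,
    # times block_size: 20 offsets with shape (20, 2).
    directions = jnp.array([(1, 1), (1, 0), (0, 1), (-1, -1)])
    steps = jnp.arange(1, 6)
    offsets = (steps[:, None, None] * directions[None] * block_size).reshape(-1, 2)

    def get_block(y, x):
        # dynamic_slice accepts traced start indices, which vmap requires.
        return jax.lax.dynamic_slice(image, (y, x), (block_size, block_size))

    def block_score(y, x):
        ref = get_block(y, x)

        def term(offset):
            other = get_block(y + offset[0], x + offset[1])
            mse = jnp.mean((ref - other) ** 2)
            return (1.0 + lam * mse) ** (-(n + 1) / 2)

        # Mean over the 20 terms equals the loop's sum of term / 20.
        return jnp.mean(jax.vmap(term)(offsets))

    scores = jax.vmap(block_score)(yy.ravel(), xx.ravel())
    return jnp.abs(jnp.sum(jnp.log(scores)) / (height * width))
```

Because this is a single traced expression, `jax.grad(block_prior)(image)` returns the gradient of the prior. A gradient-based solver or a numerical prox would need that gradient. Note that `jax.lax.dynamic_slice` clamps out-of-range start indices instead of raising an error. The margin must therefore keep every compared block inside the image. A margin of 60 is enough for `block_size=9`.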

Originally posted by @matinaz in #437 (comment)
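On the original question (a proximal operator for a custom prior): when no closed form exists, one generic option is to approximate prox_{σg}(v) = argmin_x ½‖x − v‖² + σ g(x) with an inner iterative solver. The sketch below uses plain gradient descent in pure JAX rather than jaxopt, so it does not rely on operator support. It assumes g is differentiable and maps a plain JAX array to a scalar. `numerical_prox`, `step` and `num_iter` are hypothetical names and defaults, not part of the scico API.

```python
import jax
import jax.numpy as jnp


def numerical_prox(g, v, sigma, step=0.1, num_iter=200):
    """Approximate prox_{sigma*g}(v) by plain gradient descent.

    Minimizes 0.5 * ||x - v||^2 + sigma * g(x), starting from x = v.
    Assumes g is differentiable; `step` and `num_iter` are hypothetical
    defaults to be tuned (step < 2 / L for an L-smooth objective).
    """

    def objective(x):
        return 0.5 * jnp.sum((x - v) ** 2) + sigma * g(x)

    grad_obj = jax.grad(objective)

    def body(_, x):
        # One gradient-descent step on the prox objective.
        return x - step * grad_obj(x)

    return jax.lax.fori_loop(0, num_iter, body, v)


# Usage: for g(x) = 0.5 * ||x||^2 the exact prox is v / (1 + sigma).
v = jnp.array([1.0, -2.0, 3.0])
x = numerical_prox(lambda z: 0.5 * jnp.sum(z ** 2), v, sigma=1.0)
```

Each outer iteration of a proximal solver then pays for `num_iter` gradient evaluations of g, and the resulting prox is only approximate, so a fast, vectorized g matters here. Wrapping this function in a custom `scico.functional.Functional` should be possible. Check the scico documentation for the exact `prox` signature the solvers expect.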

Original issue re-opened, closing this duplicate.