This is a minimal JAX/Flax port of lpips, as implemented in:

Only the essential features have been implemented. Our motivation is to support VQGAN training for DALL•E Mini.

It currently supports the vgg16 backend, leveraging the implementation in flaxmodels.
Pre-trained weights for the network and the linear layers are downloaded from the 🤗 Hugging Face hub.
- Install JAX for CUDA or TPU following the instructions at https://github.com/google/jax#installation.
- Install this package:

```shell
pip install lpips-j
```
Inputs must be in the range [-1, 1], and must not already be normalized with ImageNet stats. (They are internally converted to [0, 1] and then normalized by the underlying flax model.)
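As a quick, standalone illustration of the expected input scaling (using NumPy only, independent of the library):

```python
import numpy as np

# A dummy 8-bit image: values in [0, 255]
img = np.array([0, 64, 128, 255], dtype=np.uint8)

# Scale to [-1, 1], the range LPIPS expects
x = 2 * (img / 255.0) - 1
print(x.min(), x.max())  # -1.0 1.0
```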
```python
import jax
import jax.numpy as jnp
from PIL import Image
from torchvision.transforms import PILToTensor
from lpips_j.lpips import LPIPS

# Load images as uint8 torch tensors with a batch dimension (NCHW)
x = PILToTensor()(Image.open("img8.jpg")).unsqueeze(0)
y = PILToTensor()(Image.open("img8_edited.jpg")).unsqueeze(0)

# Scale from [0, 255] to [-1, 1]
x = 2 * (x / 255.) - 1
y = 2 * (y / 255.) - 1

# Convert NCHW torch tensors to NHWC JAX arrays
x = jnp.array(x).transpose(0, 2, 3, 1)
y = jnp.array(y).transpose(0, 2, 3, 1)

key = jax.random.PRNGKey(0)
lpips = LPIPS()
params = lpips.init(key, x, x)
loss = lpips.apply(params, x, y)
```
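Since `lpips.apply` is a pure function of its parameters and inputs, it can be wrapped in `jax.jit` like any other Flax `apply` call. The sketch below uses a plain mean-squared-error stand-in for the perceptual distance (a hypothetical placeholder, shown only to illustrate the jit pattern without requiring pretrained weights):

```python
import jax
import jax.numpy as jnp

# Stand-in for lpips.apply: plain MSE between images.
# In practice you would jit a function that calls lpips.apply(params, x, y).
def perceptual_loss(params, x, y):
    return jnp.mean((x - y) ** 2)

loss_jit = jax.jit(perceptual_loss)

x = jnp.zeros((1, 32, 32, 3))
y = jnp.ones((1, 32, 32, 3))
print(float(loss_jit(None, x, y)))  # 1.0
```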