keras-team/tf-keras

Adam (and other) optimizers miscalculate the momentum update for complex variables


System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): CentOS Linux 7.9.2009
  • TensorFlow installed from (source or binary): Binary
  • TensorFlow version (use command below): 2.12.1
  • Python version: 3.11.4

Describe the problem.

When the Adam optimizer is used to minimize a variable of dtype complex64 or complex128, the momenta it calculates (in particular the second-moment estimate) are incorrect, causing slower or incorrect updates.

For example, suppose we want to find the root of the complex polynomial $f(z)=(z-(1+3j))^2$ by minimizing the absolute-square loss $|f(z)|^2$. $f$ has a (double) root at $z=1+3j$, so we expect the minimization to converge to this point.
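To make the symmetry of this example explicit, write $z = a + bj$ and $c = 1 + 3j$. The loss and its gradient with respect to the real parameters are then:

$$
\begin{aligned}
L(a, b) &= |f(z)|^2 = |z - c|^4 = \bigl((a-1)^2 + (b-3)^2\bigr)^2, \\
\frac{\partial L}{\partial a} &= 4\bigl((a-1)^2 + (b-3)^2\bigr)(a-1), \qquad
\frac{\partial L}{\partial b} = 4\bigl((a-1)^2 + (b-3)^2\bigr)(b-3), \\
\frac{\partial L}{\partial a} + j\,\frac{\partial L}{\partial b} &= 4\,|z - c|^2\,(z - c).
\end{aligned}
$$

The loss depends only on $|z - c|$, so it is rotationally symmetric about the root. Packing the real gradient into one complex number (an illustrative convention here, not necessarily TensorFlow's definition of a complex gradient) shows that it always points along $z - c$, so steepest descent heads straight for the root.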

There are two ways to set this up. In the first approach, we take the real scalars $a$ and $b$ as variables and combine them into $z = a+bj$ inside the function to minimize. In the second approach, we directly define a variable $z$ of complex dtype as the input to the function to minimize. In the first approach the optimizer keeps separate momenta for $a$ and $b$, so we expect the optimization to behave slightly differently, but both approaches should converge in a similar manner and at roughly the same rate.

This issue has been raised before (#38541) and also in PyTorch, where it was eventually decided to adjust the computations (see pytorch/pytorch#65711). I understand that a fix might not be desired because of its implications for real-valued variables; however, in that case I would expect using complex variables to at least raise a warning that they currently behave unexpectedly.

Describe the current behavior.

As the attached code shows, this is not the case: the first approach converges much faster, while the second tends to oscillate around the minimum before settling. Furthermore, the first approach does not actually follow the gradient in the complex loss landscape, because $a$ and $b$ have independent momenta.

Describe the expected behavior.

Both approaches should converge in the same way and follow the gradient of the loss landscape. Plotting the updates of the variables in the complex plane together with the loss landscape gives the following figure (generated in the Colab notebook):

[Figure: optimizer updates in the complex plane, plotted over the loss landscape]

The proposed fix produces the expected behavior: the optimized variable does not wander around the minimum but respects the symmetry of the loss function.
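The difference in trajectories can be illustrated without TensorFlow. The sketch below is a hypothetical pure-Python reimplementation (the names `grad`, `run_adam` and `off_line` are made up for illustration). It uses the analytic gradient of $|z-c|^4$, packed as $\partial L/\partial a + j\,\partial L/\partial b$, and the textbook bias-corrected Adam update with Keras's default `beta_1`, `beta_2` and `epsilon`, which may differ in detail (for example where epsilon enters) from the Keras implementation. With `complex_aware=False` it keeps separate moments for $a$ and $b$, as in the first approach; with `complex_aware=True` it keeps one complex first moment and a real second moment built from $g\bar g = |g|^2$, as in the proposed fix.

```python
# Hypothetical pure-Python sketch, not the Keras implementation.
C = 1 + 3j  # root of f(z) = (z - C)**2


def grad(z):
    """Gradient of L = |z - C|**4, packed as dL/da + 1j * dL/db."""
    return 4 * abs(z - C) ** 2 * (z - C)


def run_adam(steps, complex_aware, z0=0.5 + 1.0j,
             lr=0.1, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """Run `steps` textbook Adam updates on L starting from z0 and return z."""
    z = z0
    if complex_aware:
        m, v = 0j, 0.0  # one complex first moment, one real second moment
    else:
        m, v = [0.0, 0.0], [0.0, 0.0]  # independent moments for a and b
    for t in range(1, steps + 1):
        g = grad(z)
        if complex_aware:
            m = beta_1 * m + (1 - beta_1) * g
            v = beta_2 * v + (1 - beta_2) * abs(g) ** 2  # |g|**2, i.e. g * conj(g)
            m_hat = m / (1 - beta_1 ** t)
            v_hat = v / (1 - beta_2 ** t)
            z = z - lr * m_hat / (v_hat ** 0.5 + eps)
        else:
            step = [0.0, 0.0]
            for i, g_i in enumerate((g.real, g.imag)):
                m[i] = beta_1 * m[i] + (1 - beta_1) * g_i
                v[i] = beta_2 * v[i] + (1 - beta_2) * g_i ** 2
                m_hat = m[i] / (1 - beta_1 ** t)
                v_hat = v[i] / (1 - beta_2 ** t)
                step[i] = lr * m_hat / (v_hat ** 0.5 + eps)
            z = complex(z.real - step[0], z.imag - step[1])
    return z


def off_line(z, z0=0.5 + 1.0j):
    """|Im((z - C) / (z0 - C))|: zero when z lies on the line through z0 and C."""
    return abs(((z - C) / (z0 - C)).imag)
```

For example, `run_adam(1, complex_aware=False)` moves both $a$ and $b$ by about `lr` = 0.1, i.e. diagonally to about 0.6 + 1.1j, whereas `run_adam(1, complex_aware=True)` moves a distance of about 0.1 straight towards the root. With the complex-aware moment every later step stays on that line as well, which is the symmetry described above.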

Contributing.

  • Do you want to contribute a PR? (yes/no): Yes
  • Briefly describe your candidate solution (if contributing):

The update function of the tf.keras.optimizers.Adam class currently computes the second moment using tf.square(gradient) * (1 - self.beta_2). For complex values this should be gradient * tf.math.conj(gradient) * (1 - self.beta_2), since tf.square of a complex value $g$ gives $g^2$ rather than $|g|^2$. This should still work for real-valued variables, but I don't know whether it has any performance implications. An alternative would be to emit a warning when the optimizer is used on complex variables.
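The difference between the two expressions in the proposal can be checked with plain Python complex numbers. This is a minimal sketch with hypothetical function names; `g * g` stands in for what tf.square computes elementwise, and taking `.real` of the fixed term is an extra choice made here so that the second moment stays real-valued.

```python
def second_moment_term_naive(g):
    """What tf.square(gradient) computes: g * g, which is complex for complex g."""
    return g * g


def second_moment_term_fixed(g):
    """Proposed fix: g * conj(g) = |g|**2, returned as a real number."""
    return (g * g.conjugate()).real


# For a purely imaginary gradient the naive term is negative,
# so sqrt(v) no longer behaves like a gradient magnitude.
print(second_moment_term_naive(1j))  # (-1+0j)
print(second_moment_term_fixed(1j))  # 1.0
```

For real inputs both functions return the same value (e.g. 4.0 for `g = 2.0`), which is why the change should not affect real-valued variables.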

Standalone code to reproduce the issue.

```python
import tensorflow as tf

# Polynomial with a (double) root at z = 1 + 3j.
f = lambda z: (z - (1 + 3j)) ** 2

# Approach 1: two real variables a and b, combined into z = a + bj inside the loss.
optimizer1 = tf.keras.optimizers.Adam(learning_rate=0.1)
a = tf.Variable(0.5, dtype=tf.float32)
b = tf.Variable(1.0, dtype=tf.float32)

def loss1():
    z2 = tf.complex(a, b)
    return tf.abs(f(z2)) ** 2

# Approach 2: a single variable z of complex dtype.
optimizer2 = tf.keras.optimizers.Adam(learning_rate=0.1)
z = tf.Variable(0.5 + 1.0j, dtype=tf.complex64)

def loss2():
    return tf.abs(f(z)) ** 2

# Take 50 optimization steps with each approach.
for _ in range(50):
    optimizer1.minimize(loss1, [a, b])
    optimizer2.minimize(loss2, [z])

# .numpy() prints plain numbers instead of tf.Tensor reprs.
print(f"[{a.numpy()}, {b.numpy()}], loss: {loss1().numpy()}")
# Outputs: [0.015215082094073296, 0.01610667072236538], loss: 0.004847094416618347
print(f"[{tf.math.real(z).numpy()}, {tf.math.imag(z).numpy()}], loss {loss2().numpy()}")
# Outputs: [1.1107741594314575, 0.8542921543121338], loss 9.064789772033691
```

Colab link: https://colab.research.google.com/drive/1DKsktAM7MOUQhFxbr2kElDtDJE1AunK-?usp=sharing

Source code / logs.

N/A

@sachinprasadhs,
I was able to reproduce the issue on TensorFlow v2.12, v2.13 and tf-nightly. Kindly find the gist of it here.

@Hihaatje, the framework currently does not support complex variables. It is recommended that users write their own optimizers for this case.

@Hihaatje, if you are still interested in contributing, the ideal way to handle this case is to add a warning or error message where TensorFlow performs dtype casting in the code. Thanks!

Hey @sachinprasadhs, I agree it is a good idea to emit this warning; however, no type casting is performed anywhere here (I believe a warning is already produced when an implicit cast happens). So maybe we can just add a warning whenever the update function is called with a complex dtype?
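A warning along the lines suggested could look like the sketch below. It is not taken from Keras: `warn_if_complex` is a hypothetical helper that would be called at the start of an optimizer's update step, and it relies only on the variable's `dtype` exposing an `is_complex` attribute, as tf.DType does.

```python
import warnings


def warn_if_complex(variable, optimizer_name="Adam"):
    """Warn if `variable` has a complex dtype; return True if a warning was issued.

    Hypothetical helper: assumes `variable.dtype.is_complex` exists,
    as it does for tf.Variable (whose dtype is a tf.DType).
    """
    dtype = getattr(variable, "dtype", None)
    if getattr(dtype, "is_complex", False):
        warnings.warn(
            f"{optimizer_name} received a variable of complex dtype {dtype}. "
            "Its second-moment estimate uses square(gradient), which is not "
            "|gradient|**2 for complex values, so updates may be incorrect.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False
```

With TensorFlow installed, `warn_if_complex(tf.Variable(0.5 + 1.0j, dtype=tf.complex64))` would warn, while a float32 variable would not.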

@Hihaatje, sounds good, feel free to file a PR. Thanks!

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.