srush/annotated-s4

Assertion failure for differing kernel K and input u

briancheung opened this issue · 2 comments

When using the following run command:
python -m s4.train --dataset mnist-classification --model s4 --epochs 10 --bsz 128 --d_model 128 --ssm_n 64
I'm finding that this assertion is failing:

assert K.shape[0] == u.shape[0]

for the MNIST classification task. I'm not sure whether the failure is intentional or whether there's an off-by-one bug somewhere between model initialization and training. The following lines pad out any discrepancy between the kernel length and the length of u, so the code runs fine if you simply remove the assertion. But that might not be the intended behavior.

        ud = np.fft.rfft(np.pad(u, (0, K.shape[0])))
        Kd = np.fft.rfft(np.pad(K, (0, u.shape[0])))
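To see why removing the assertion still "runs fine", here is a minimal NumPy sketch of an FFT-based causal convolution built around the same two padding lines. The function name `causal_convolution`, the explicit `n=` argument to `np.fft.irfft`, and the final slicing are assumptions for illustration, not necessarily the repository's exact code.

```python
import numpy as np

def causal_convolution(u, K):
    """Causal convolution y[t] = sum_{s<=t} K[s] * u[t - s], computed with FFTs.

    Sketch only: K and u are allowed to have different lengths.
    """
    # Pad both signals to len(u) + len(K) so the circular convolution computed
    # by the FFT equals the linear convolution (no wrap-around).
    n = u.shape[0] + K.shape[0]
    ud = np.fft.rfft(np.pad(u, (0, K.shape[0])))
    Kd = np.fft.rfft(np.pad(K, (0, u.shape[0])))
    # Pass n explicitly: if n is odd (lengths differ by an odd amount),
    # irfft's default output length of n - 1 would not invert rfft exactly.
    y = np.fft.irfft(ud * Kd, n=n)
    # Keep only the first len(u) outputs, aligned with the input sequence.
    return y[: u.shape[0]]

rng = np.random.default_rng(0)
u = rng.standard_normal(784)  # full 28x28 MNIST sequence
K = rng.standard_normal(783)  # kernel built one step too short
y = causal_convolution(u, K)
assert np.allclose(y, np.convolve(u, K)[:784])
```

The padding hides the mismatch rather than resolving it: the result is a valid causal convolution with the shorter kernel, so the last output y[783] is missing the K[783] * u[0] term it would have with a full-length kernel.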

srush commented

Good catch. For MNIST classification you can remove the assert and it should be fine.

The real issue is that we drop the last pixel ([:-1]) for generation but leave it in for classification, so the two sequence lengths disagree. We should not drop the last pixel for classification.
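A minimal sketch of the off-by-one described above, assuming 28x28 MNIST images flattened to 784 pixels. `prepare_batch` and `kernel_length` are hypothetical helpers, not functions from the repository; the point is that the kernel length has to come from the same sequence the model actually sees for each task.

```python
import numpy as np

SEQ_LEN = 28 * 28  # 784 pixels per flattened MNIST image

def prepare_batch(images, task):
    """Hypothetical helper: split images into (inputs, targets) for a task."""
    flat = images.reshape(images.shape[0], -1)
    if task == "generation":
        # Autoregressive: predict pixel t + 1 from pixels <= t,
        # so the last pixel is dropped from the inputs.
        return flat[:, :-1], flat[:, 1:]
    if task == "classification":
        # Classification consumes every pixel; nothing is dropped.
        return flat, None
    raise ValueError(f"unknown task: {task}")

def kernel_length(task):
    """Hypothetical helper: kernel length L matching prepare_batch's inputs."""
    return SEQ_LEN - 1 if task == "generation" else SEQ_LEN

images = np.zeros((2, 28, 28))
for task in ("generation", "classification"):
    x, _ = prepare_batch(images, task)
    # Mirrors the assertion from the issue: kernel and input lengths must agree.
    assert x.shape[1] == kernel_length(task)
```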

srush commented

Fixed now.