AlexanderLutsenko/nobuco

[Question] Are dynamic axes supported?

dazzle-me opened this issue · 2 comments

Consider the following model: a simple multi-head self-attention (MHSA) layer.
Its computation, however, depends on the length of the input sequence.

This code produces a Keras model that runs fine on inputs with the fixed sequence length L used during conversion.
However, when I run it with a different sequence length, the forward pass fails with the following error (shown after the code):

import torch
import torch.nn as nn
import torch.nn.functional as F
from nobuco.convert.converter import pytorch_to_keras
import numpy as np

class MHSA(nn.Module):
    def __init__(self,
            embed_dim,
            out_dim,
            qk_dim,
            v_dim,
            num_head,
        ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_head  = num_head
        self.qk_dim = qk_dim
        self.v_dim  = v_dim

        self.q = nn.Linear(embed_dim, qk_dim*num_head)
        self.k = nn.Linear(embed_dim, qk_dim*num_head)
        self.v = nn.Linear(embed_dim, v_dim*num_head)
        
        self.out = nn.Linear(v_dim*num_head, out_dim)
        self.scale = 1/(qk_dim**0.5)

    def forward(self, x):
        L,dim = x.shape
        num_head = self.num_head
        qk_dim = self.qk_dim
        v_dim = self.v_dim
        
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        q = q.reshape(L, num_head, qk_dim).permute(1,0,2).contiguous()
        k = k.reshape(L, num_head, qk_dim).permute(1,2,0).contiguous()
        v = v.reshape(L, num_head, v_dim ).permute(1,0,2).contiguous()

        dot = q * self.scale @ k     # H L L
        attn = F.softmax(dot, -1)    # H L L

        v = torch.matmul(attn, v)    # H L v_dim
        v = v.permute(1,0,2).reshape(L, v_dim*num_head).contiguous()
        out = self.out(v)

        return out


if __name__ == "__main__":
    emb_dim = 128
    out_dim = 200
    num_heads = 4
    qk_dim = emb_dim // num_heads
    L = 33
    model = MHSA(embed_dim=emb_dim, out_dim=out_dim, qk_dim=qk_dim, v_dim=qk_dim, num_head=num_heads)
    keras_model = pytorch_to_keras(
        model, args=[torch.rand(L, emb_dim)],
    )
    inp = np.random.rand(33, 128)
    print(keras_model(inp)) ## runs fine

    inp = np.random.rand(15, 128)
    keras_model(inp) ## fails

Error:

InvalidArgumentError: Exception encountered when calling layer 'tf.reshape' (type TFOpLambda).

{{function_node __wrapped__Reshape_device_/job:localhost/replica:0/task:0/device:GPU:0}} Input to reshape is a tensor with 1920 values, but the requested shape has 4224 [Op:Reshape]

Call arguments received by layer 'tf.reshape' (type TFOpLambda):
  • tensor=tf.Tensor(shape=(15, 128), dtype=float32)
  • shape=('33', '4', '32')
  • name=None
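For reference, the two numbers in the message follow directly from the shapes involved. Below is a quick check in plain Python. The shapes are copied from the message, with num_head=4 and qk_dim=32 taken from the script; the variable names are only illustrative.

```python
from math import prod

# Shapes taken from the error message.
new_input = (15, 128)    # the shorter sequence passed at inference time
baked_in = (33, 4, 32)   # (L, num_head, qk_dim) with L=33 frozen at conversion time

print(prod(new_input))   # 1920 values actually provided
print(prod(baked_in))    # 4224 values the recorded reshape expects

# If the sequence axis were left for reshape to infer (-1),
# it would be computed from the remaining dimensions instead:
print(prod(new_input) // (4 * 32))  # 15
```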

It looks like the model recorded the static shape of the input and doesn't support variable-length input. I'm new to Keras; is there any possible solution?

Yes, there is more than one solution. In your particular case, it suffices to replace L with -1 in the reshape calls, so the sequence length is inferred at run time. Also, pass input_shapes with the dynamic dimensions set to None (supported since v0.2.0):

import nobuco

inp = torch.rand(L, emb_dim)
keras_model = nobuco.pytorch_to_keras(
    model, args=[inp],
    input_shapes={inp: (None, emb_dim)}  # None marks the sequence axis as dynamic
)
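A minimal sketch of what "replace L with -1" looks like on the PyTorch side, assuming the same layer sizes as in the question (embed_dim=128, 4 heads, 32 dims per head, out_dim=200). It only exercises the PyTorch module, not the nobuco conversion. The class is a condensed rewrite of the MHSA layer from the question, not code from the nobuco repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MHSA(nn.Module):
    # Same layer as in the question, but every reshape uses -1 for the
    # sequence axis instead of the Python int L read from x.shape.
    def __init__(self, embed_dim, out_dim, qk_dim, v_dim, num_head):
        super().__init__()
        self.num_head, self.qk_dim, self.v_dim = num_head, qk_dim, v_dim
        self.q = nn.Linear(embed_dim, qk_dim * num_head)
        self.k = nn.Linear(embed_dim, qk_dim * num_head)
        self.v = nn.Linear(embed_dim, v_dim * num_head)
        self.out = nn.Linear(v_dim * num_head, out_dim)
        self.scale = 1 / (qk_dim ** 0.5)

    def forward(self, x):
        H, dk, dv = self.num_head, self.qk_dim, self.v_dim
        q = self.q(x).reshape(-1, H, dk).permute(1, 0, 2)    # H L dk
        k = self.k(x).reshape(-1, H, dk).permute(1, 2, 0)    # H dk L
        v = self.v(x).reshape(-1, H, dv).permute(1, 0, 2)    # H L dv
        attn = F.softmax((q * self.scale) @ k, dim=-1)       # H L L
        v = (attn @ v).permute(1, 0, 2).reshape(-1, H * dv)  # L H*dv
        return self.out(v)


model = MHSA(embed_dim=128, out_dim=200, qk_dim=32, v_dim=32, num_head=4)
for L in (33, 15):
    # The same module now accepts any sequence length.
    print(model(torch.rand(L, 128)).shape)
```

Because no reshape refers to the traced value of L any more, there is no constant 33 for the converter to record. The input_shapes argument then tells nobuco which input axis is allowed to vary.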

But suppose the solution above doesn't fit your case. Then it gets much more interesting. In PyTorch, a tensor's shape is a tuple of plain Python integers, not scalar tensors, so operations on those values are quite difficult to track. There's a workaround, though. You can do this:

L, dim = nobuco.shape(x)

This function returns tensors, much like tf.shape does. Its definition is simply:

@traceable
def shape(x):
    return tuple(torch.tensor(d, dtype=torch.int32) for d in x.shape)

Take a look at examples/dynamic_shape to see how it works.
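The difference between the two kinds of shape values can be seen in plain PyTorch. The helper below, shape_as_tensors, is a hypothetical stand-in that copies the body of the function above but deliberately omits nobuco's @traceable decorator. It is therefore not recorded by the converter and only illustrates what kind of values each approach returns.

```python
import torch


# Hypothetical copy of the helper above, WITHOUT @traceable:
# it returns the same kind of values but is invisible to nobuco's tracer.
def shape_as_tensors(x):
    return tuple(torch.tensor(d, dtype=torch.int32) for d in x.shape)


x = torch.rand(15, 128)

# x.shape yields plain Python ints; once extracted, nothing links them back to x.
L, dim = x.shape
print(type(L), L)

# The helper yields 0-d int32 tensors, which an op-level tracer can follow.
L_t, dim_t = shape_as_tensors(x)
print(L_t, L_t.dtype, L_t.dim())
```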

Works like a charm, thank you!