[Question] Are dynamic axes supported?
dazzle-me opened this issue · 2 comments
Consider the following model graph - a simple MHSA layer.
It is, however, dependent on the length of the passed sequence.
This code produces a Keras model that is capable of running fixed-shape input with a fixed length L.
However, when I try to run it with a new sequence length, the model's forward pass fails with the error below.
import torch
import torch.nn as nn
import torch.nn.functional as F
from nobuco.convert.converter import pytorch_to_keras
import numpy as np


class MHSA(nn.Module):
    def __init__(self,
                 embed_dim,
                 out_dim,
                 qk_dim,
                 v_dim,
                 num_head,
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_head = num_head
        self.qk_dim = qk_dim
        self.v_dim = v_dim
        self.q = nn.Linear(embed_dim, qk_dim * num_head)
        self.k = nn.Linear(embed_dim, qk_dim * num_head)
        self.v = nn.Linear(embed_dim, v_dim * num_head)
        self.out = nn.Linear(v_dim * num_head, out_dim)
        self.scale = 1 / (qk_dim ** 0.5)

    def forward(self, x):
        L, dim = x.shape
        num_head = self.num_head
        qk_dim = self.qk_dim
        v_dim = self.v_dim
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        q = q.reshape(L, num_head, qk_dim).permute(1, 0, 2).contiguous()
        k = k.reshape(L, num_head, qk_dim).permute(1, 2, 0).contiguous()
        v = v.reshape(L, num_head, v_dim).permute(1, 0, 2).contiguous()
        dot = q * self.scale @ k  # H L L
        attn = F.softmax(dot, -1)  # L L
        v = torch.matmul(attn, v)  # L H dim
        v = v.permute(1, 0, 2).reshape(L, v_dim * num_head).contiguous()
        out = self.out(v)
        return out


if __name__ == "__main__":
    emb_dim = 128
    out_dim = 200
    num_heads = 4
    qk_dim = emb_dim // num_heads
    L = 33
    model = MHSA(embed_dim=emb_dim, out_dim=out_dim, qk_dim=qk_dim, v_dim=qk_dim, num_head=num_heads)
    keras_model = pytorch_to_keras(
        model, args=[torch.rand(L, emb_dim)],
    )
    inp = np.random.rand(33, 128)
    print(keras_model(inp))  ## runs fine
    inp = np.random.rand(15, 128)
    keras_model(inp)  ## fails
Error:
InvalidArgumentError: Exception encountered when calling layer 'tf.reshape' (type TFOpLambda).
{{function_node __wrapped__Reshape_device_/job:localhost/replica:0/task:0/device:GPU:0}} Input to reshape is a tensor with 1920 values, but the requested shape has 4224 [Op:Reshape]
Call arguments received by layer 'tf.reshape' (type TFOpLambda):
• tensor=tf.Tensor(shape=(15, 128), dtype=float32)
• shape=('33', '4', '32')
• name=None
It looks like the model recorded the static shape of the input and doesn't support variable-length input. I'm new to Keras and I want to ask if there is any possible solution?
Yes, there is more than one solution. In your particular situation, it suffices to just replace L with -1 in the reshape calls. Also, specify input_shapes
with dynamic dimensions set to None (since v0.2.0):
inp = torch.rand(L, emb_dim)
keras_model = nobuco.pytorch_to_keras(
    model, args=[inp],
    input_shapes={inp: (None, emb_dim)}
)
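For reference, the forward pass might then look roughly like this (a sketch of your module with L replaced by -1 in every reshape; everything else is unchanged):

    def forward(self, x):
        num_head, qk_dim, v_dim = self.num_head, self.qk_dim, self.v_dim
        # -1 lets the sequence dimension be inferred at run time instead of
        # baking in the trace-time value of L
        q = self.q(x).reshape(-1, num_head, qk_dim).permute(1, 0, 2).contiguous()
        k = self.k(x).reshape(-1, num_head, qk_dim).permute(1, 2, 0).contiguous()
        v = self.v(x).reshape(-1, num_head, v_dim).permute(1, 0, 2).contiguous()
        dot = q * self.scale @ k  # H L L
        attn = F.softmax(dot, -1)
        v = torch.matmul(attn, v)
        v = v.permute(1, 0, 2).reshape(-1, v_dim * num_head).contiguous()
        return self.out(v)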
But suppose the solution above doesn't fit you. Then it gets much more interesting. In PyTorch, a tensor's shape is a tuple of regular integers, not scalar tensors, and it's quite difficult to track them. There's a workaround, though. You can do this:
L, dim = nobuco.shape(x)
This function returns tensors, much like tf.shape does.
@traceable
def shape(x):
    return tuple(torch.tensor(d, dtype=torch.int32) for d in x.shape)
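So the forward pass could be written roughly like this (a sketch, assuming nobuco.shape is used in place of x.shape; the rest of the layer stays as in your module):

    def forward(self, x):
        # L is now a scalar int32 tensor that can be traced, not a Python int
        L, dim = nobuco.shape(x)
        num_head, qk_dim, v_dim = self.num_head, self.qk_dim, self.v_dim
        q = self.q(x).reshape(L, num_head, qk_dim).permute(1, 0, 2)
        k = self.k(x).reshape(L, num_head, qk_dim).permute(1, 2, 0)
        v = self.v(x).reshape(L, num_head, v_dim).permute(1, 0, 2)
        attn = F.softmax(q * self.scale @ k, -1)
        v = torch.matmul(attn, v).permute(1, 0, 2).reshape(L, v_dim * num_head)
        return self.out(v)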
Take a look at examples/dynamic_shape to see how it works.
Works like a charm, thank you!