Conformer Transducer inference
Opened this issue · 2 comments
Transducer inference할 때 audio encoder의 output을 time step마다 한개씩 넣는 이유가 있을까요?
Real time을 대비해서 그렇게 inference를 하는 것 같은데 non-real time일때는 어떻게 inference가 되는건지
제가 이해하고 있는게 맞는건지 잘 모르겠습니다.
어디쪽 코드를 말씀하시는 걸까요?
@torch.no_grad()
def decode(self, encoder_output: Tensor, max_length: int) -> Tensor:
"""
Decode encoder_outputs
.
Args:
encoder_output (torch.FloatTensor): A output sequence of encoder. FloatTensor
of size
(seq_length, dimension)
max_length (int): max decoding time step
Returns:
* predicted_log_probs (torch.FloatTensor): Log probability of model predictions.
"""
pred_tokens, hidden_state = list(), None
decoder_input = encoder_output.new_tensor([[self.decoder.sos_id]], dtype=torch.long)
for t in range(max_length):
decoder_output, hidden_state = self.decoder(decoder_input, hidden_states=hidden_state)
step_output = self.joint(encoder_output[t].view(-1), decoder_output.view(-1))
step_output = step_output.softmax(dim=0)
pred_token = step_output.argmax(dim=0)
pred_token = int(pred_token.item())
pred_tokens.append(pred_token)
decoder_input = step_output.new_tensor([[pred_token]], dtype=torch.long)
여기에서 for문에서 self.joint에 입력으로 받는 encoder output이 time step 한개씩 들어가는게 real time으로 진행되는 것 같아서 이렇게 inference하는 방법밖에 없나 싶어서 질문해봅니다.