Questions about the format and shape of the data
funasshi opened this issue · 17 comments
Hello. I am currently using this package. This may be a basic question, but I would like to ask a few things.
1. Is the input a spectrogram or raw audio data?
2. When I run model(x, x_len, target, target_len), I get a four-dimensional output (batch, join_len, target_len, class_num) because of how the loss function is calculated.
I wanted to see the recognition result, so I used model.recognize(x, x_len), but the shape of the output was (batch, join_len). I would like to see it as (batch, target_len). How does recognition work?
- Spectrogram
- Ignore everything behind .
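As a rough illustration of the "Spectrogram" answer, the sketch below turns raw waveforms into log-magnitude spectrogram features shaped (batch, seq_len, dim), the layout the Conformer examples later in this thread pass in. The STFT settings (n_fft=400, hop_length=160, i.e. 25 ms windows with a 10 ms hop at an assumed 16 kHz sample rate) and the helper name log_spectrogram are assumptions, not values taken from the package. With n_fft=400 each frame has 201 frequency bins; whatever features you use, the model's input_dim must equal the last dimension.

```python
import torch


def log_spectrogram(waveform: torch.Tensor, n_fft: int = 400, hop_length: int = 160) -> torch.Tensor:
    """Hypothetical helper: (batch, samples) -> (batch, frames, n_fft // 2 + 1)."""
    window = torch.hann_window(n_fft, device=waveform.device)
    spec = torch.stft(
        waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,
    )  # complex tensor shaped (batch, freq_bins, frames)
    log_mag = torch.log1p(spec.abs())  # compress the dynamic range
    # The Conformer examples use (batch, seq_len, dim), so put time before frequency.
    return log_mag.transpose(1, 2)


waveform = torch.randn(2, 16000)  # two made-up 1-second clips at an assumed 16 kHz
features = log_spectrogram(waveform)
print(features.shape)  # torch.Size([2, 101, 201])
```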
Thank you very much.
Sorry, I have one more question related to point 2. When I look at the results of model.recognize, the last category label appears unusually often.
For example, in a model with 40 categories:
[40,40,40,40,40,40,40,3,40,40,40,40,40,40,5,8,40,40,....]
Does this mean that the last category is being used as the category that should be ignored?
Hi, I am working on an ASR-related project using the Conformer.
The four-dimensional output has confused me when calculating the loss to train the ASR model.
Would you please provide an example of how to calculate the loss?
Kind Regards
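A 4-D (batch, time, target positions, classes) output is the shape a transducer joint network produces, and such outputs are normally trained with the RNN-Transducer loss. The sketch below computes that loss with the forward (alpha) recursion in plain PyTorch so the role of each axis is visible. It assumes the third axis has U + 1 positions for a target of U labels (the prediction network also produces an output before the first label) and that blank has id 0; check both against how your model builds its decoder inputs. The function names are hypothetical, and the loops are far too slow for real training, where a fused implementation such as torchaudio.functional.rnnt_loss is the usual choice.

```python
import torch


def rnnt_loss_single(log_probs: torch.Tensor, target: torch.Tensor, blank: int) -> torch.Tensor:
    """Negative log-likelihood of one utterance.

    log_probs: (T, U + 1, V) log-softmax output of the joint network.
    target:    (U,) label ids, without blanks.
    """
    T, U_plus_1, _ = log_probs.shape
    U = U_plus_1 - 1
    labels = target.tolist()
    # alpha[t][u]: log-prob of having emitted the first u labels by frame t.
    alpha = [[None] * (U + 1) for _ in range(T)]
    alpha[0][0] = log_probs.new_zeros(())
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            paths = []
            if t > 0:  # arrive by emitting blank at frame t - 1
                paths.append(alpha[t - 1][u] + log_probs[t - 1, u, blank])
            if u > 0:  # arrive by emitting label u at frame t
                paths.append(alpha[t][u - 1] + log_probs[t, u - 1, labels[u - 1]])
            alpha[t][u] = torch.logsumexp(torch.stack(paths), dim=0)
    # Every valid path ends with a final blank at the last frame.
    return -(alpha[T - 1][U] + log_probs[T - 1, U, blank])


def rnnt_loss(log_probs, targets, logit_lengths, target_lengths, blank=0):
    """Mean RNN-T loss over a padded batch shaped (B, T, U + 1, V)."""
    losses = []
    for b in range(log_probs.size(0)):
        T, U = int(logit_lengths[b]), int(target_lengths[b])
        losses.append(rnnt_loss_single(log_probs[b, :T, :U + 1], targets[b, :U], blank))
    return torch.stack(losses).mean()


# Made-up example: batch of 2, up to 5 frames, up to 3 labels, 6 classes.
logits = torch.randn(2, 5, 4, 6, requires_grad=True)
log_probs = logits.log_softmax(dim=-1)
targets = torch.tensor([[1, 2, 3], [4, 5, 0]])
loss = rnnt_loss(log_probs, targets, torch.tensor([5, 4]), torch.tensor([3, 2]))
loss.backward()
```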
Does this mean that the last category is being used as the category that should be ignored?
It should be recognised as the blank symbol.
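A minimal sketch of reading such output, assuming (as the reply says) that 40 is the blank symbol and that recognize emits one label per encoder frame, as a simple transducer greedy search does: drop the blanks and keep everything else, including repeated labels. If your decoder is CTC-based instead, repeated labels also need collapsing before blanks are removed. The function name strip_blanks is hypothetical.

```python
from typing import List, Sequence


def strip_blanks(frame_labels: Sequence[int], blank_id: int) -> List[int]:
    """Drop blank predictions from a frame-level recognize() output."""
    return [label for label in frame_labels if label != blank_id]


# The visible part of the example output from the question.
frame_labels = [40, 40, 40, 40, 40, 40, 40, 3, 40, 40, 40, 40, 40, 40, 5, 8, 40, 40]
print(strip_blanks(frame_labels, blank_id=40))  # [3, 5, 8]
```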
Would you please provide an example of how to calculate the loss?
Do you have any idea? I am confused about that too.
@sooftware, can you please answer @zwan074? Many of us are confused about how to use a loss function to train the Conformer, since the outputs are 4-dimensional log probabilities of the model's predictions.
Sorry for the late response. I recommend checking this project.
I have another question about the Conformer.
I am using a vocabulary of 6030 classes. My input data has shape (batch, dim, seq_len) = (32, 201, 1162), where 1162 is the maximum length because the inputs are padded, and my targets have shape (32, 20), where 20 is the maximum length because the targets are padded.
I forward propagate and then call the recognize function, which returns a tensor of shape (32, 289). I am trying to understand what that 289 is, as I was expecting a (32, 20) tensor that I could then convert to text. @sooftware
Show me the code.
@sooftware When I execute the following code, the recognize_sp variable has shape [32, 289]:
```python
import torch
import time
import torch.nn as nn
from conformer import Conformer

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
print(device)

# conformer model init
model = nn.DataParallel(Conformer(num_classes=6030, input_dim=201, encoder_dim=32, num_encoder_layers=3, decoder_dim=32)).to(device)

for i, (audio, audio_len, translations, translation_len) in enumerate(train_loader):
    # sorting inputs and targets to have targets in descending order based on len
    sorted_list, sorted_indices = torch.sort(translation_len, descending=True)
    sorted_audio = torch.zeros((32, 201, 1162), dtype=torch.float)
    sorted_audio_len = torch.zeros(32, dtype=torch.int)
    sorted_translations = torch.zeros((32, 20), dtype=torch.int)
    sorted_translation_len = sorted_list
    for index, contentof in enumerate(translation_len):
        sorted_audio[index] = audio[sorted_indices[index]]
        sorted_audio_len[index] = audio_len[sorted_indices[index]]
        sorted_translations[index] = translations[sorted_indices[index]]

    # transpose inputs from 32, 201, 1162 (batch, dim, seq_len) to 32, 1162, 201 (batch, seq_len, dim)
    inputs = sorted_audio.to(device)
    inputs = torch.transpose(inputs, 1, 2)
    input_lengths = sorted_audio_len
    targets = sorted_translations.to(device)
    target_lengths = sorted_translation_len

    # shapes:
    # inputs: [32, 1162, 201]
    # input_len: [32]
    # targets: [32, 20]
    # target_len: [32]
    preds = model(inputs, input_lengths, targets, target_lengths)
    recognize_sp = model.module.recognize(inputs, input_lengths)
    print(recognize_sp.shape)
    break
```
According to the https://github.com/openspeech-team/openspeech project, when training the Conformer model, the output of the Conformer encoder blocks is used to compute a CTC loss; the LSTM decoder layer is unused.
The code is as follows:
```python
def training_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
    inputs, targets, input_lengths, target_lengths = batch
    encoder_outputs, encoder_logits, output_lengths = self.encoder(inputs, input_lengths)
    logits = self.fc(encoder_outputs).log_softmax(dim=-1)
    return self.collect_outputs(
        stage='train',
        logits=logits,
        output_lengths=output_lengths,
        targets=targets,
        target_lengths=target_lengths,
    )
```
@jcgeo9 289 is almost a quarter of 1162. This happens because of Conv2dSubampling, the convolutional subsampling layer the Conformer applies to its input, which shortens the time axis.
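The arithmetic behind "almost a quarter" can be sketched as follows, assuming Conv2dSubampling stacks two unpadded convolutions with kernel size 3 and stride 2 along the time axis (an assumption about the module; check the package source). Under that assumption, 1162 input frames become 580 and then 289, so each of the 289 columns returned by recognize corresponds to one encoder frame, not one target token. The helper names are hypothetical.

```python
def conv_out_len(length: int, kernel_size: int = 3, stride: int = 2) -> int:
    """Output length of an unpadded convolution along one axis."""
    return (length - kernel_size) // stride + 1


def subsampled_len(length: int) -> int:
    # Assumption: two stacked kernel-3, stride-2 convolutions without padding.
    return conv_out_len(conv_out_len(length))


print(conv_out_len(1162))    # 580
print(subsampled_len(1162))  # 289
```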
@sooftware Hmm, OK, but what do I do with that? I mean, how do I convert it into what I actually want? Isn't it supposed to return a [32, 20] tensor of integers corresponding to words in my vocabulary, which I would then convert with itos in order to check the loss?
I updated the code and README because many people seemed to have a hard time calculating losses.
Below is an example of calculating CTC Loss.
```python
import torch
import torch.nn as nn
from conformer import Conformer

batch_size, sequence_length, dim = 3, 12345, 80

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')

criterion = nn.CTCLoss()

inputs = torch.rand(batch_size, sequence_length, dim).to(device)
input_lengths = torch.IntTensor([12345, 12300, 12000])
targets = torch.LongTensor([[1, 3, 3, 3, 3, 3, 4, 5, 6, 2],
                            [1, 3, 3, 3, 3, 3, 4, 5, 2, 0],
                            [1, 3, 3, 3, 3, 3, 4, 2, 0, 0]]).to(device)
target_lengths = torch.LongTensor([9, 8, 7])

# Move the model to the same device as the inputs.
model = Conformer(num_classes=10,
                  input_dim=dim,
                  encoder_dim=32,
                  num_encoder_layers=3).to(device)

# Forward propagate
outputs, output_lengths = model(inputs, input_lengths)

# Calculate CTC Loss
# nn.CTCLoss expects log-probabilities shaped (time, batch, classes), hence the transpose.
loss = criterion(outputs.transpose(0, 1), targets, output_lengths, target_lengths)
```
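The loss itself never needs a decoded [batch, target_len] tensor; decoding is only for inspecting results. For a CTC-trained model like the example above, a common way to get token ids is best-path (greedy) decoding: take the argmax class per frame, collapse runs of the same label, then drop blanks. The sketch assumes blank id 0, which is the nn.CTCLoss default; the function name ctc_greedy_decode is hypothetical. Each hypothesis comes out with its own length, so pad the lists yourself if you need a rectangular tensor.

```python
from typing import List

import torch


def ctc_greedy_decode(log_probs: torch.Tensor, output_lengths: torch.Tensor, blank: int = 0) -> List[List[int]]:
    """Best-path CTC decoding.

    log_probs: (batch, time, num_classes) model outputs.
    output_lengths: (batch,) number of valid frames per utterance.
    """
    best = log_probs.argmax(dim=-1)  # (batch, time)
    hypotheses = []
    for frames, length in zip(best.tolist(), output_lengths.tolist()):
        tokens, previous = [], None
        for label in frames[:length]:  # ignore padded frames
            # Keep a label only when it starts a new run and is not blank.
            if label != previous and label != blank:
                tokens.append(label)
            previous = label
        hypotheses.append(tokens)
    return hypotheses


# Made-up frame-level predictions: 1 1 _ 1 2 2 _  (with _ = blank 0).
frames = torch.tensor([[1, 1, 0, 1, 2, 2, 0]])
log_probs = torch.nn.functional.one_hot(frames, num_classes=3).float().log_softmax(dim=-1)
print(ctc_greedy_decode(log_probs, torch.tensor([7])))  # [[1, 1, 2]]
```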
I have a question. input_lengths is not passed along to compute the mask for multi-head attention. Does it still work?
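For reference, a common way to keep attention from looking at padded frames is a key padding mask built from the sequence lengths. The sketch below uses the key_padding_mask argument of torch.nn.MultiheadAttention directly; it does not show how (or whether) the conformer package wires lengths into its own attention module, and the helper name padding_mask is hypothetical. When the input is subsampled, the mask must be built from the post-subsampling output lengths, not the raw input lengths.

```python
import torch
import torch.nn as nn


def padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """(batch, max_len) bool mask, True where a position is padding."""
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)


batch, time, dim = 2, 5, 8
x = torch.randn(batch, time, dim)
lengths = torch.tensor([5, 3])  # the second sequence has 2 padded frames
mask = padding_mask(lengths, time)

attention = nn.MultiheadAttention(embed_dim=dim, num_heads=2, batch_first=True)
# True entries in key_padding_mask are excluded from the attention softmax.
out, weights = attention(x, x, x, key_padding_mask=mask)
```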