Puzzled by the mask operation
DeepTecher opened this issue · 5 comments
Thank you for your good work. However, I have some doubts about the following code (source: ant_torch.py#L144, which I hit when running the Ant model):
- What is the main logic of this part? I did not get it.
- When inferencing, context is all set to True and span is all set to 0 in _convert_to_tensors, so it seems the mask is all 1 after the following code. What does this code actually do? (A toy check of my reading follows the snippet below.)
with torch.no_grad():
    device = input.device
    # lower-triangular (causal) mask: position i may attend to positions j <= i
    directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(
        seqlen, device=device
    ).view(-1, 1)
    # context keys are visible to every query; target (non-context) queries
    # can additionally attend causally to earlier positions
    attention_mask = context[:, None, :] | (
        context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
    )
    # tokens may only attend within their own span
    attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
    # mask out padding positions beyond each sample's length
    mask_1d = (
        torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
    )
    attention_mask = (
        mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
    )
Hi,
- As we use a unified Transformer architecture, attention over the target part is unidirectional (i.e. causal), while attention over the context part is bidirectional. That is the main logic of this code snippet; see the toy illustration below.
- When inferencing, everything is context.
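To make the pattern concrete, here is a minimal sketch (made-up lengths, same formula as in the snippet above): the first 3 positions are context, the last 3 are target.

import torch

seqlen = 6
context = torch.tensor([[True, True, True, False, False, False]])
directional = torch.arange(seqlen) <= torch.arange(seqlen).view(-1, 1)
mask = context[:, None, :] | (context[:, :, None].logical_not() & directional)
print(mask[0].int())
# Rows 0-2 (context queries) attend to every context key (bidirectional);
# rows 3-5 (target queries) attend to the context plus earlier target keys
# only, i.e. the causal pattern:
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1]])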
You can refer to our blog for more details about the model.
Smart design of the MASK operation, as I found when I dived into your code. However, there is still another problem with the pad step of _process_text when the batch size is greater than 1.
Take this example with two inputs:
我们在假期去了法国的埃菲尔铁塔 (We went to the Eiffel Tower in France during the holidays)
今天天气真好 (The weather is really nice today)
After _process_text, the input ids are as follows:
tensor([[ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73,
74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
94, 95, 6, 18039, 638, 10547, 454, 267, 21744, 1405,
2064, 2452, 536, 2182, 2760, 245],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
85, 86, 87, 88, 89, 90, 91, 92, 93, 94,
95, 6, 9802, 14962, 2082, 831]], device='cuda:0',
dtype=torch.int32)
In the forward of ant_torch, I think the values of prompt_states and hidden_states are wrong, especially because the two embedding tables are not the same. (The problem exists in both forward and inference.)
The code is listed below (original here):
input_prompt = input[:, : self.prompt_length].contiguous()  # assumed to be prompt tokens
input_ids = input[:, self.prompt_length :].contiguous()     # assumed to be text tokens
prompt_states = self.prompt_embedding(input_prompt)         # prompt embedding table
hidden_states = self.input_embedding(input_ids)             # word embedding table
input_prompt:
tensor([[64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 65, 66, 67, 68, 69, 70, 71, 72,
73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86]],
device='cuda:0', dtype=torch.int32)
input_ids:
tensor([[ 6, 18039, 638, 10547, 454, 267, 21744, 1405, 2064, 2452,
536, 2182, 2760, 245],
[ 87, 88, 89, 90, 91, 92, 93, 94, 95, 6,
9802, 14962, 2082, 831]], device='cuda:0', dtype=torch.int32)
Obviously, this data is incorrect, especially the input_ids and input_prompt rows that come from the shorter text; in this case, input_ids[1] and input_prompt[1]. (I guess the prompt row should begin with 64, not 0, and the input row should begin with 6, not one of [64-95].)
I hope my understanding is correct, and I look forward to your answer.
You are right, there are some bugs related to prompt tokens in CPM-Ant. We have unified the prompt and input embeddings in CPM-Ant+, see #148, which fixed these kinds of issues.
okay...
It would be great if we could fix it in the Ant model as well, since it affects the results across cases (batch=1 vs. batch>1).
One way I can think of is the following: in model.inference and model.forward, do some data transformations (at a slight time cost) that use the length variable to align the left-padded rows, run the model, and then flip the result back. A rough sketch is below.
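Rough sketch of what I mean (function names are mine, not from the repo; I assume length is the per-sample number of non-pad tokens as produced by _convert_to_tensors):

import torch

def left_pad_to_right_pad(ids: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    # Move each row's real tokens (which sit at the right end because of
    # left padding) to the front, and the pad tokens to the back.
    seqlen = ids.size(1)
    aligned = torch.empty_like(ids)
    for i, n in enumerate(length.tolist()):
        pad = seqlen - n
        aligned[i, :n] = ids[i, pad:]
        aligned[i, n:] = ids[i, :pad]
    return aligned

def right_pad_to_left_pad(states: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    # Inverse transform: flip the outputs back to the original left-padded layout.
    # Works for (batch, seqlen) ids as well as (batch, seqlen, dim) hidden states.
    seqlen = states.size(1)
    restored = torch.empty_like(states)
    for i, n in enumerate(length.tolist()):
        pad = seqlen - n
        restored[i, pad:] = states[i, :n]
        restored[i, :pad] = states[i, n:]
    return restored

With the rows aligned this way, input[:, :prompt_length] would again contain only prompt tokens for every sample in the batch, since each sample's prompt ids come before its text ids.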
Actually, left padding is only applied during inference.
It would be great if you could create a PR to fix the issue in the inference stage.