prefix tuning with t5-3b
zluw1117 opened this issue · 11 comments
I am trying to run prefix tuning with t5-3b, but I got a strange error:
File "/home/ubuntu/anaconda3/envs/py3.7pytorch1.8new/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ubuntu/code/UnifiedSKG/models/prompt/modeling_t5.py", line 486, in forward
key_states = torch.cat([prefix["prev_key"], key_states], dim=2)
RuntimeError: Sizes of tensors must match except in dimension 3. Got 128 and 32 (The offending index is 0)
This error does not occur for t5-base or t5-large; I only get it with t5-3b. Any tips?
I am also having an OOM issue with the t5-3b model; it crashes even with a mini-batch size of 1 on a 40GB GPU. Does anyone have the same issue? Thanks.
Hi,
Thanks for pointing this out! I just double-checked the prefix-tuning code and found we over-cleaned a small part, which causes a dimension mismatch.
We will fix it asap.
Thanks!
The key is the different design logic in the HuggingFace BART and T5-series config files, which requires us to treat them separately. (By the way, if you look through the change we made, you will see that the previous code and the current code are equivalent except for t5-3b, whose key/value dimension differs; it won't affect anything else.)
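For illustration, here is a minimal sketch (not from the repo) that inspects the configs with transformers' AutoConfig and shows where t5-3b differs:
from transformers import AutoConfig

for name in ["t5-large", "t5-3b"]:
    cfg = AutoConfig.from_pretrained(name)
    # T5 configs carry an explicit per-head key/value size (d_kv); BART has no such
    # field, so its per-head size is always d_model // decoder_attention_heads.
    print(name, cfg.d_model, cfg.num_heads, cfg.d_kv, cfg.d_model // cfg.num_heads)
# t5-large: d_kv == d_model // num_heads (64 == 64)
# t5-3b:    d_kv = 128 while d_model // num_heads = 32 -- the "128 and 32" in the traceback above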
Hi,
We updated the code and added the missing part. Could you try again and tell us if everything works smoothly?
Hope the information above is helpful!
Thanks!
Thank you for your input!
Actually, I got a new error when running the code with t5-3b (t5-large works fine):
File "/home/ubuntu/code/UnifiedSKG/models/unified/prefixtuning.py", line 265, in forward
bsz=bsz, description=description_representation, knowledge=knowledge_representation,
File "/home/ubuntu/code/UnifiedSKG/models/unified/prefixtuning.py", line 125, in get_prompt
bsz, seqlen, self.match_n_layer * 2, self.match_n_head, self.match_n_embd
RuntimeError: shape '[1, 10, 48, 32, 128]' is invalid for input of size 491520
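For what it's worth, a quick arithmetic check on those numbers (my own, not from the repo):
expected = 1 * 10 * 48 * 32 * 128  # bsz * seqlen * (2 * n_layer) * n_head * d_kv = 1966080
actual   = 1 * 10 * 48 * 32 * 32   # = 491520, i.e. 32 per head instead of d_kv = 128
So the projection feeding this view still seems to produce d_model // num_heads (= 32) per head rather than d_kv (= 128).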
Thanks.
Okay, then could you give us the command and config file you used?
Yeah. I just changed location in UnifiedSKG/configure/Salesforce/T5_large_prefix_spider_with_cell_value.cfg to t5-3b and then ran
deepspeed train.py --deepspeed deepspeed/ds_config_zero2.json --seed 2 --cfg /home/ubuntu/code/UnifiedSKG/configure/Salesforce/T5_large_prefix_spider_with_cell_value.cfg --run_name nl2sql-prefix_tuning --logging_strategy steps --logging_first_step true --logging_steps 4 --evaluation_strategy steps --eval_steps 64 --metric_for_best_model avr --greater_is_better true --save_strategy steps --save_steps 64 --save_total_limit 3 --load_best_model_at_end --gradient_accumulation_steps 250 --num_train_epochs 650 --adafactor false --learning_rate 1e-4 --do_train --do_eval --predict_with_generate --output_dir output/spider_prompt --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --generation_num_beams 1 --generation_max_length 512 --input_max_length 1024 --ddp_find_unused_parameters true
Copy that, we are trying to reproduce the problem, hold tight pls.
For UnifiedSKG/configure/Salesforce/T5_large_prefix_spider_with_cell_value.cfg:
[model]
name = unified.prefixtuning
use_description = False
concatenate_description = False
map_description = False
# Should be one of (separate, concatenate)
knowledge_usage = concatenate
freeze_plm = True
freeze_prefix = False
[dataset]
data_store_path = ./data
description_max_length = 64
upsample_temp = 1
[seq2seq]
constructor = seq2seq_construction.meta_tuning
patience = 50
[arg_paths]
spider = META_TUNING/spider_with_cell.cfg
[evaluate]
tool = metrics.meta_tuning.evaluator
[prefix_tuning]
# 10 previously.
prefix_sequence_length = 10
mid_dim = 512
prefix_dropout = 0.0
[special_tokens]
less = ' <'
less_or_equal = ' <='
[bert]
location = t5-3b
Hey, could you try this prefixtuning code on your side? Currently we don't have deepspeed or a machine handy.
# -*- coding: utf-8 -*-
import torch
from torch import nn
from transformers import AutoTokenizer
from .base import PushToHubFriendlyModel
from ..prompt.modeling_auto import AutoModelForSeq2SeqLM
class Model(PushToHubFriendlyModel):
def __init__(self, args):
super().__init__()
self.args = args
"""The prefix-tuning code"""
self.preseqlen = args.prefix_tuning.prefix_sequence_length
self.mid_dim = args.prefix_tuning.mid_dim
print("prefix-tuning sequence length is {}.".format(self.preseqlen))
# Load tokenizer and model.
self.tokenizer = AutoTokenizer.from_pretrained(args.bert.location, use_fast=False)
self.pretrain_model = AutoModelForSeq2SeqLM.from_pretrained(
args.bert.location
)
self.config = self.pretrain_model.config
from ..prompt.modeling_bart import BartForConditionalGeneration
from ..prompt.modeling_t5 import T5ForConditionalGeneration
if isinstance(self.pretrain_model, BartForConditionalGeneration):
self.match_n_layer = self.config.decoder_layers
self.match_n_head = self.config.decoder_attention_heads
self.n_embd = self.config.d_model
assert self.n_embd % self.match_n_head == 0
self.match_n_embd = self.n_embd // self.match_n_head # huggingface BART's dim of kv need to be calculated
elif isinstance(self.pretrain_model, (T5ForConditionalGeneration)):
self.match_n_layer = self.config.num_decoder_layers
self.match_n_head = self.config.num_heads
self.n_embd = self.config.d_model
self.match_n_embd = self.config.d_kv
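            # Unlike BART, T5 specifies d_kv explicitly; for t5-3b it is 128, not d_model // num_heads (= 32).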
else:
raise ValueError("Other models are not supported yet!")
if args.special_tokens:
self.tokenizer.add_tokens([v for k, v in args.special_tokens])
self.pretrain_model.resize_token_embeddings(len(self.tokenizer))
# Prefix related.
self.register_buffer('input_tokens', torch.arange(self.preseqlen).long())
self.wte = nn.Embedding(self.preseqlen, self.n_embd)
self.control_trans = nn.Sequential(
nn.Linear(self.n_embd, self.mid_dim),
nn.Tanh(),
nn.Linear(self.mid_dim, self.match_n_layer * 2 * self.match_n_head * self.match_n_embd),
)
if self.args.model.knowledge_usage == 'separate':
self.knowledge_trans = nn.Sequential(
nn.Linear(self.n_embd, self.mid_dim),
nn.Tanh(),
nn.Linear(self.mid_dim, self.match_n_layer * 2 * self.match_n_head * self.match_n_embd),
)
self.wte_enc = nn.Embedding(self.preseqlen, self.n_embd)
self.control_trans_enc = nn.Sequential(
nn.Linear(self.n_embd, self.mid_dim),
nn.Tanh(),
nn.Linear(self.mid_dim, self.match_n_layer * 2 * self.match_n_head * self.match_n_embd),
)
if self.args.model.knowledge_usage == 'separate':
self.knowledge_trans_enc = nn.Sequential(
nn.Linear(self.n_embd, self.mid_dim),
nn.Tanh(),
nn.Linear(self.mid_dim, self.match_n_layer * 2 * self.match_n_head * self.match_n_embd),
)
self.wte_dec = nn.Embedding(self.preseqlen, self.n_embd)
self.control_trans_dec = nn.Sequential(
nn.Linear(self.n_embd, self.mid_dim),
nn.Tanh(),
nn.Linear(self.mid_dim, self.match_n_layer * 2 * self.match_n_head * self.match_n_embd),
)
# Knowledge prompt.
if self.args.model.knowledge_usage == 'separate':
self.knowledge_trans_dec = nn.Sequential(
nn.Linear(self.n_embd, self.mid_dim),
nn.Tanh(),
nn.Linear(self.mid_dim, self.match_n_layer * 2 * self.match_n_head * self.match_n_embd),
)
self.dropout = nn.Dropout(args.prefix_tuning.prefix_dropout)
if self.args.model.freeze_plm:
for param in self.pretrain_model.parameters():
param.requires_grad = False
if self.args.model.freeze_prefix:
for param in self.wte.parameters():
param.requires_grad = False
for param in self.control_trans.parameters():
param.requires_grad = False
for param in self.wte_dec.parameters():
param.requires_grad = False
for param in self.control_trans_dec.parameters():
param.requires_grad = False
for param in self.wte_enc.parameters():
param.requires_grad = False
for param in self.control_trans_enc.parameters():
param.requires_grad = False
def get_prompt(self, bsz=None, sample_size=1, description=None, knowledge=None):
old_bsz = bsz
bsz = bsz * sample_size
input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1)
temp_control = self.wte(input_tokens)
if description is not None:
temp_control = temp_control + description.repeat_interleave(sample_size, dim=0).unsqueeze(1)
past_key_values = self.control_trans(temp_control) # bsz, seqlen, layer*emb
if knowledge is not None:
past_key_values = torch.cat([past_key_values, self.knowledge_trans(knowledge.repeat_interleave(sample_size, dim=0))], dim=1)
bsz, seqlen, _ = past_key_values.shape
past_key_values = past_key_values.view(
bsz, seqlen, self.match_n_layer * 2, self.match_n_head, self.match_n_embd
)
past_key_values = self.dropout(past_key_values)
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
# Cross prefix
temp_control_dec = self.wte_dec(input_tokens)
if description is not None:
temp_control_dec = temp_control_dec + description.repeat_interleave(sample_size, dim=0).unsqueeze(1)
past_key_values_dec = self.control_trans_dec(
temp_control_dec
) # bsz, seqlen, layer*emb
if knowledge is not None:
past_key_values_dec = torch.cat([past_key_values_dec, self.knowledge_trans_dec(knowledge.repeat_interleave(sample_size, dim=0))], dim=1)
bsz, seqlen, _ = past_key_values_dec.shape
past_key_values_dec = past_key_values_dec.view(
bsz, seqlen, self.match_n_layer * 2, self.match_n_head, self.match_n_embd
)
past_key_values_dec = self.dropout(past_key_values_dec)
past_key_values_dec = past_key_values_dec.permute([2, 0, 3, 1, 4]).split(2)
# Encoder prefix
input_tokens_enc = (
self.input_tokens.unsqueeze(0).expand(old_bsz, -1)
)
temp_control_enc = self.wte_enc(input_tokens_enc)
if description is not None:
temp_control_enc = temp_control_enc + description.unsqueeze(1)
past_key_values_enc = self.control_trans_enc(
temp_control_enc
) # bsz, seqlen, layer*emb
if knowledge is not None:
past_key_values_enc = torch.cat([past_key_values_enc, self.knowledge_trans_enc(knowledge)], dim=1)
bsz_enc, seqlen, _ = past_key_values_enc.shape
past_key_values_enc = past_key_values_enc.view(
bsz_enc,
seqlen,
self.match_n_layer * 2,
self.match_n_head,
self.match_n_embd,
)
past_key_values_enc = self.dropout(past_key_values_enc)
past_key_values_enc = past_key_values_enc.permute([2, 0, 3, 1, 4]).split(2)
result = []
for i, key_val in enumerate(past_key_values):
temp = dict()
temp["decoder_prompt"] = {
"prev_key": key_val[0].contiguous(),
"prev_value": key_val[1].contiguous(),
"prev_key_padding_mask": torch.zeros(bsz, seqlen)
.to(key_val.device)
.bool()
# bsz, preseqlen
}
key_val_dec = past_key_values_dec[i]
temp["cross_attention_prompt"] = {
"prev_key": key_val_dec[0].contiguous(),
"prev_value": key_val_dec[1].contiguous(),
"prev_key_padding_mask": torch.zeros(bsz, seqlen)
.to(key_val_dec.device)
.bool(),
}
key_val_enc = past_key_values_enc[i]
temp["encoder_prompt"] = {
"prev_key": key_val_enc[0].contiguous(),
"prev_value": key_val_enc[1].contiguous(),
"prev_key_padding_mask": torch.zeros(bsz_enc, seqlen)
.to(key_val_enc.device)
.bool(),
}
result.append(temp)
return result
def get_description_representation(self, kwargs):
if self.args.model.use_description and self.args.model.map_description:
description_input_ids = kwargs.pop("description_input_ids")
description_attention_mask = kwargs.pop("description_attention_mask")
if self.args.bert.location in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
description_outputs = self.pretrain_model.encoder(
input_ids=description_input_ids,
attention_mask=description_attention_mask,
)
description = description_outputs.last_hidden_state[:, 0] # TODO: the first token from the encoder.
elif self.args.bert.location in ["facebook/bart-base", "facebook/bart-large"]:
description_outputs = self.pretrain_model.model.encoder(
input_ids=description_input_ids,
attention_mask=description_attention_mask,
)
description = description_outputs.last_hidden_state[:, 0] # TODO: the first token from the encoder.
else:
raise ValueError()
else:
description = None
return description
def get_knowledge_representation(self, kwargs):
if self.args.model.knowledge_usage == 'separate':
knowledge_input_ids = kwargs.pop("knowledge_input_ids", None)
knowledge_attention_mask = kwargs.pop("knowledge_attention_mask", None)
if self.args.bert.location in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
knowledge_outputs = self.pretrain_model.encoder(
input_ids=knowledge_input_ids,
attention_mask=knowledge_attention_mask,
)
knowledge = knowledge_outputs.last_hidden_state
elif self.args.bert.location in ["facebook/bart-base", "facebook/bart-large"]:
knowledge_outputs = self.pretrain_model.model.encoder(
input_ids=knowledge_input_ids,
attention_mask=knowledge_attention_mask,
)
knowledge = knowledge_outputs.last_hidden_state
else:
raise ValueError()
elif self.args.model.knowledge_usage == 'concatenate':
knowledge = None
else:
raise ValueError()
return knowledge
def forward(self,
input_ids,
attention_mask,
labels,
**kwargs,
):
bsz = input_ids.shape[0]
# Encode description.
description_representation = self.get_description_representation(kwargs)
# Encode knowledge.
knowledge_representation = self.get_knowledge_representation(kwargs)
past_prompt = self.get_prompt(
bsz=bsz, description=description_representation, knowledge=knowledge_representation,
)
loss = self.pretrain_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
past_prompt=past_prompt,
).loss
return {'loss': loss}
def generate(self,
input_ids,
attention_mask,
**kwargs):
bsz = input_ids.shape[0]
# Encode description.
description_representation = self.get_description_representation(kwargs)
# Encode knowledge.
knowledge_representation = self.get_knowledge_representation(kwargs)
past_prompt = self.get_prompt(
bsz=bsz, sample_size=kwargs['num_beams'], description=description_representation, knowledge=knowledge_representation,
)
generated_ids = self.pretrain_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
past_prompt=past_prompt,
use_cache=True,
**kwargs,
)
return generated_ids
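If running the full model is too heavy, a standalone CPU shape check like the following (my own sketch, with t5-3b's dimensions hard-coded on the assumption of d_model = 1024, num_heads = 32, d_kv = 128, and 24 decoder layers) should at least confirm that the view/permute/split in get_prompt no longer fails:
import torch
from torch import nn

# Assumed t5-3b dimensions: d_model = 1024, num_heads = 32, d_kv = 128, 24 decoder layers.
bsz, seqlen, n_layer, n_head, d_kv, d_model, mid_dim = 1, 10, 24, 32, 128, 1024, 512

control_trans = nn.Sequential(
    nn.Linear(d_model, mid_dim),
    nn.Tanh(),
    nn.Linear(mid_dim, n_layer * 2 * n_head * d_kv),
)
out = control_trans(torch.randn(bsz, seqlen, d_model))   # (bsz, seqlen, n_layer*2*n_head*d_kv)
pkv = out.view(bsz, seqlen, n_layer * 2, n_head, d_kv)   # the view that used to raise the RuntimeError
pkv = pkv.permute([2, 0, 3, 1, 4]).split(2)              # 24 (key, value) pairs
print(pkv[0][0].shape)  # torch.Size([1, 32, 10, 128]) -- last dim is d_kv, matching what modeling_t5 expects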
Looks good! Now I only have an OOM issue (even though I ran with mini-batch size = 1 on a 40GB A100):
RuntimeError: CUDA out of memory. Tried to allocate 66.00 MiB (GPU 4; 39.59 GiB total capacity; 35.44 GiB already allocated; 21.19 MiB free; 36.39 GiB reserved in total by PyTorch)
So the code itself seems to work. 👍
Ok, thanks!
I think using multiple GPUs will fix that issue then!
Thanks again for pointing out the bug!