LeMei/UniMSE

Reproduction in 2023

pretbc opened this issue · 0 comments


Hello

I tried to run the code in this repo and ran into several issues.

How to start

  1. Datasets
    I downloaded the MOSI files from the links provided in the README.
    In MOSI.zip I found the already-created U-labels file new_MOSI-label-v3.csv.

I changed the paths in the config to point to my paths/to/data (csv, mosi.pkl, etc.).

  2. T5
    I renamed every ../t5-base path to t5-base, because the model can then be fetched automatically from the Hugging Face Hub.

For the PyTorch model.bin, I opened the corresponding T5 repo on Hugging Face, downloaded the .bin file, and changed path/to/bin accordingly.

  • You can also clone the Hugging Face repo instead.

After this, I ran:
python main.py --dataset=mosi --multi=False

And...

In data_loader.py, line 423:

encoding = tokenizer( [task_prefix + sequence for sequence in inputs_seq], return_tensors="pt", padding=True )

this code throws an error because task_prefix + sequence tries to compute str + list[str]. To fix this, I changed it to:
task_prefix + ' '.join(sequence)
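The same fix can be written as a small helper that accepts either token lists or plain strings. This is a minimal sketch, assuming each element of inputs_seq is a list of strings (the case that triggers the TypeError); build_prompts and the example prefix are hypothetical names, not part of the repo.

```python
from typing import List, Sequence, Union


def build_prompts(task_prefix: str, inputs_seq: Sequence[Union[str, List[str]]]) -> List[str]:
    """Prepend task_prefix to every sample, joining token lists into one string."""
    prompts = []
    for sequence in inputs_seq:
        # Plain strings pass through; token lists are joined with spaces.
        text = sequence if isinstance(sequence, str) else " ".join(sequence)
        prompts.append(task_prefix + text)
    return prompts


if __name__ == "__main__":
    prefix = "sentiment: "  # hypothetical prefix, for illustration only
    batch = [["the", "movie", "was", "great"], ["not", "good"]]
    print(build_prompts(prefix, batch))
    # ['sentiment: the movie was great', 'sentiment: not good']
    try:
        prefix + batch[0]  # str + list: the original failure
    except TypeError as exc:
        print("TypeError:", exc)
```

The resulting list of strings is what the tokenizer call expects, e.g. tokenizer(build_prompts(task_prefix, inputs_seq), return_tensors="pt", padding=True).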

Next:

In modeling_t5_prefix.py, line 1848:

else:
    input_ids = self._prepare_decoder_input_ids_for_generation(
        input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
    )

An error occurs here. I assume this is because we pass a tensor (the batch), while the function expects an integer batch size and returns torch.ones((batch_size, 1), ...); here input_ids is passed as batch_size even though it is a tensor. I tried to fix this by passing input_ids.shape[1] instead,
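Assuming the helper expects an integer batch size and builds torch.ones((batch_size, 1), ...) * decoder_start_token_id, as described above, the integer to pass is the size of the batch dimension. For input_ids shaped (batch_size, sequence_length), the usual Hugging Face layout, that is input_ids.shape[0]; input_ids.shape[1] is the sequence length. The sketch below mimics those shapes with nested lists so it runs without PyTorch; batch_size_of and decoder_start_ids are hypothetical names.

```python
from typing import List


def batch_size_of(input_ids: List[List[int]]) -> int:
    """Size of dimension 0 of a (batch_size, sequence_length) batch."""
    return len(input_ids)


def decoder_start_ids(batch_size: int, decoder_start_token_id: int) -> List[List[int]]:
    """Pure-Python stand-in for
    torch.ones((batch_size, 1), dtype=torch.long) * decoder_start_token_id:
    one start token per sample, shape (batch_size, 1)."""
    return [[decoder_start_token_id] for _ in range(batch_size)]


if __name__ == "__main__":
    # Hypothetical batch: 2 samples, each padded to 4 tokens.
    input_ids = [[13, 8, 1, 0], [27, 5, 9, 1]]
    batch_size = batch_size_of(input_ids)  # plays the role of input_ids.shape[0]
    seq_len = len(input_ids[0])            # plays the role of input_ids.shape[1]
    print(batch_size, seq_len)             # 2 4
    print(decoder_start_ids(batch_size, decoder_start_token_id=0))  # [[0], [0]]
```

Under that assumption, the patched call would pass input_ids.shape[0] as its first argument.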

but then the following lines started to fail.

At line 1900:

logits_processor = self._get_logits_processor() expects an additional Optional argument, logits_processor, so I passed it as None.
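The logits_processor failure is a signature mismatch: the installed transformers helper takes a parameter the repo's call does not pass. One hedged workaround for such API drift is to pass a keyword only when the installed function declares it. The sketch below does that with the standard inspect module; old_helper and new_helper are hypothetical stand-ins, not the real transformers functions.

```python
import inspect
from typing import Any, Callable, Dict


def supported_kwargs(func: Callable[..., Any], **kwargs: Any) -> Dict[str, Any]:
    """Return only the keyword arguments that func can accept by name."""
    params = inspect.signature(func).parameters
    # A function with **kwargs accepts any keyword, so keep everything.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {
        name: value
        for name, value in kwargs.items()
        if name in params and params[name].kind is not inspect.Parameter.POSITIONAL_ONLY
    }


def call_compat(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call func, dropping keywords its signature does not declare."""
    return func(*args, **supported_kwargs(func, **kwargs))


# Hypothetical stand-ins for two releases of the same helper.
def old_helper(repetition_penalty=None):
    return ("old", repetition_penalty)


def new_helper(repetition_penalty=None, *, logits_processor):
    # Newer release: logits_processor is required, although None is allowed.
    return ("new", repetition_penalty, logits_processor)


if __name__ == "__main__":
    print(call_compat(old_helper, repetition_penalty=1.2, logits_processor=None))  # ('old', 1.2)
    print(call_compat(new_helper, repetition_penalty=1.2, logits_processor=None))  # ('new', 1.2, None)
```

This only helps for calls made with keyword arguments; a change in what a positional argument means (as with the batch-size helper above) still needs a manual fix.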

and I finally ended up with this failure:

File "modules/modeling_t5_prefix.py", line 524, in forward
    scores += position_bias
RuntimeError: The size of tensor a (32) must match the size of tensor b (55) at non-singleton dimension 3
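To read this error: PyTorch aligns the two shapes from the right and requires each pair of sizes to be equal or 1, and the message names a dimension where that fails. Assuming modeling_t5_prefix.py keeps the upstream Hugging Face T5 attention layout (batch, n_heads, query_length, key_length) for scores, dimension 3 is the key length, so scores and position_bias disagree on key length (32 vs 55). The sketch below reproduces the check in pure Python; first_broadcast_mismatch is a hypothetical helper, and the batch size and head count in the example are made up (only 32, 55 and dimension 3 come from the traceback).

```python
from typing import List, Optional, Sequence


def first_broadcast_mismatch(shape_a: Sequence[int], shape_b: Sequence[int]) -> Optional[int]:
    """Return the right-aligned dimension where two shapes cannot broadcast, else None."""
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both are aligned from the right.
    a: List[int] = [1] * (ndim - len(shape_a)) + list(shape_a)
    b: List[int] = [1] * (ndim - len(shape_b)) + list(shape_b)
    # Scan from the last dimension backward; sizes must be equal or one of them 1.
    for dim in range(ndim - 1, -1, -1):
        if a[dim] != b[dim] and a[dim] != 1 and b[dim] != 1:
            return dim
    return None


if __name__ == "__main__":
    # Hypothetical shapes in (batch, n_heads, query_len, key_len) order.
    scores = (8, 12, 32, 32)
    position_bias = (1, 12, 32, 55)
    dim = first_broadcast_mismatch(scores, position_bias)
    print(dim, scores[dim], position_bias[dim])  # 3 32 55
```

Printing the real scores.shape and position_bias.shape just before line 524 would show which of the two has the unexpected key length.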

I stopped here because I don't want to keep patching this; I assume the code needs to be refactored by the author.

Note: I cannot install transformers==4.14.5 because that version does not exist, so I'm using transformers==4.16.0.
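Since the failures above appear tied to the installed transformers version, anyone reproducing this may want to log the exact version at startup. A minimal sketch using the standard library's importlib.metadata (Python 3.8+); parse_release and installed_transformers are hypothetical helpers, and the 4.16 comparison only reflects the version used in this report, not a version the repo is known to require.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_release(ver: str) -> Tuple[int, ...]:
    """Return the numeric release part, e.g. '4.16.0.dev0' -> (4, 16, 0)."""
    parts = []
    for piece in ver.split("."):
        match = re.match(r"\d+", piece)
        if match is None:  # non-numeric segment such as 'dev0': stop here
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # e.g. '0rc1': keep 0, drop the suffix
            break
    return tuple(parts)


def installed_transformers() -> Optional[str]:
    """Version string of the installed transformers package, or None."""
    try:
        return version("transformers")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_transformers()
    print("transformers:", found or "not installed")
    if found is not None and parse_release(found)[:2] != (4, 16):
        print("note: the errors in this issue were observed with transformers 4.16.0")
```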