Using PPL on LAMM, discovered 2 bugs, and the result is much worse than with the Direct inferencer type
AlexWang1900 opened this issue · 2 comments
I set the scenario recipe like this:
```yaml
scenario_cfg:
  dataset_name: STL10_LAMM
  base_data_path: ./data/LAMM/2D_Benchmark
  ppl: True
eval_cfg:
  instruction_cfg:
    query_type: standard_query
  inferencer_cfg:
    inferencer_type: PPL  # Direct
    batch_size: 1  # 32
  metric_cfg:
    metric_type: LAMM
```
I ran this with lamm186k_llama2chat7b_lora32 on the STL10 dataset, which I added myself; tested with the Direct inferencer (without ppl), accuracy was 0.56.
But when I changed the recipe as above, I hit two bugs:
1. In src/ChEF/instruction/__init__.py, line 87:
```python
for i in range(batch_size):
    if isinstance(batch_options[0], list):
        answers += [answer_template.format(option) for option in batch_options[i]]
        new_len = len(batch_options[i])
        questions += [prompts[i] for _ in range(new_len)]
        options += batch_options[i]
```
Here answer_template is None, so it has no format method and the call crashes. The cause is in src/ChEF/instruction/query.py, line 261:
```python
def build_template(**kwargs):
    # LAMM-style inference does not require template
    if kwargs['task_name'].endswith('lamm'):
        # return None  # fix by wang for testing
        pass
    return ppl_template(**kwargs)
```
It returned None whenever a LAMM model was detected; I commented that out so it falls through and returns a ppl_template.
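For context, here is a hypothetical, self-contained illustration of what that loop in __init__.py does once answer_template is a real string (the template text and option values below are invented for demonstration and are not ChEF's actual defaults):

```python
answer_template = "The answer is {}."           # assumed PPL answer template
prompts = ["What is in the image?"]             # one prompt in the batch
batch_options = [["airplane", "bird", "car"]]   # STL10-style class options

questions, answers, options = [], [], []
for i in range(len(prompts)):
    if isinstance(batch_options[0], list):
        # one (question, candidate answer) pair per option; the PPL
        # inferencer later scores each pair and keeps the candidate
        # with the lowest perplexity
        answers += [answer_template.format(opt) for opt in batch_options[i]]
        questions += [prompts[i]] * len(batch_options[i])
        options += batch_options[i]

print(answers)
# ['The answer is airplane.', 'The answer is bird.', 'The answer is car.']
```

With answer_template = None, the first format call here is exactly where the crash happens.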
2. In src/ChEF/models/test_lamm.py, line 203, logits comes back as a scalar (the loss), so indexing it fails:
```python
logits, target_ids = self.model(dict(
    vision_type = 'image',
    task_type = self.task_type,
    vision_paths = images,  # fix by wang for Image type image
    output_texts = conversations,
    icl_sysmsg = icl_sysmsg
))
logits = logits[:, :-1]
```
The cause is in src/model/LAMM/openlamm.py, lines 616-626; here is the code after my fix:
```python
loss = outputs.loss
# calculate the token accuracy
return outputs.logits[:, 1:, :], targets[:, 1:]
# chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1]  # [B, S-1]
# labels = targets[:, 2:]
# gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(
#     torch.long
# )  # [B*S]
# valid_mask = (labels != -100).reshape(-1)
# valid_tokens = gen_acc & valid_mask  # [B*S]
# gen_acc = valid_tokens.sum().item() / valid_mask.sum().item()
# return loss, gen_acc
```
Originally it returned loss and gen_acc, which are not logits and target_ids, so I changed it to return logits and targets as shown above.
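For reference, here is a minimal sketch (my assumption, not the actual ChEF code) of how the returned logits and target_ids could then be turned into per-sequence perplexities. It assumes logits is [B, S, V], target_ids is [B, S] already shifted to align with the logits, and ignored positions are marked with -100:

```python
import torch
import torch.nn.functional as F

def sequence_ppl(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    """Per-sequence perplexity; lower is better when ranking candidate answers."""
    vocab = logits.size(-1)
    # token-level negative log-likelihood; -100 positions contribute zero loss
    nll = F.cross_entropy(
        logits.reshape(-1, vocab),
        target_ids.reshape(-1),
        ignore_index=-100,
        reduction='none',
    ).view(target_ids.shape)                       # [B, S]
    mask = (target_ids != -100).float()
    mean_nll = (nll * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return mean_nll.exp()                          # [B]
```

A PPL inferencer would then pick, for each question, the candidate answer with the lowest perplexity.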
But when I run the test with these fixes, the result is only 0.19125, far below the 0.56 I get with the Direct inferencer.
The ChEF paper reports 80.7% in Table 1, with the standard query type and PPL inferencer according to Table 6.
The LAMM paper reports 37.9% in Table 1, and I guess that uses the standard query type and Direct inferencer.
So something must be wrong, either with the numbers in the papers or with the code.
Thanks for your feedback. We found that there are some conflicts between the code bases of LAMM and ChEF. We will fix this. Sorry for the mistake.
We have fixed the bugs and checked the results. Thanks again for your feedback.