Shape mismatch issue in Fine-tuning of Idefics2 Tutorial
chang-changiti opened this issue · 1 comment
Referenced file:
https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Idefics2/Fine_tune_Idefics2_for_JSON_extraction_use_cases_(PyTorch_Lightning).ipynb
Hi! I run into the following issue when I try to replicate the above-mentioned tutorial. It seems to happen during the evaluation step on the validation dataset, but I can't figure out the root cause.
RuntimeError: shape mismatch: value tensor of shape [128, 4096] cannot be broadcast to indexing result of shape [0, 4096]
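For context, the failure is raised at the bottom of the stack trace below, in Idefics2Model.inputs_merger, where image features are scattered into the text embeddings via a boolean mask. The indexing result of shape [0, 4096] suggests the image-token mask matched zero positions while the vision tower still produced 128 feature rows. A minimal PyTorch sketch (hypothetical shapes, not code from the tutorial) that reproduces the same message:

```python
import torch

# Hypothetical shapes chosen to mirror the error message above.
inputs_embeds = torch.zeros(1, 50, 4096)                         # (batch, seq_len, hidden)
special_image_token_mask = torch.zeros(1, 50, dtype=torch.bool)  # mask matches zero positions
image_hidden_states = torch.zeros(128, 4096)                     # 128 image-feature rows

# Raises: RuntimeError: shape mismatch: value tensor of shape [128, 4096]
# cannot be broadcast to indexing result of shape [0, 4096]
inputs_embeds[special_image_token_mask] = image_hidden_states
```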
Full stack trace:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
542 self.state.status = TrainerStatus.RUNNING
543 self.training = True
--> 544 call._call_and_handle_interrupt(
545 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
546 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
42 if trainer.strategy.launcher is not None:
43 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)
46 except _TunerExitException:
47 _call_teardown_hook(trainer)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
573 assert self.state.fn is not None
574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
575 self.state.fn,
576 ckpt_path,
577 model_provided=True,
578 model_connected=self.lightning_module is not None,
579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
582 assert self.state.stopped
583 self.training = False
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:987, in Trainer._run(self, model, ckpt_path)
982 self._signal_connector.register_signal_handlers()
984 # ----------------------------
985 # RUN THE TRAINER
986 # ----------------------------
--> 987 results = self._run_stage()
989 # ----------------------------
990 # POST-Training CLEAN UP
991 # ----------------------------
992 log.debug(f"{self.class.name}: trainer tearing down")
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1033, in Trainer._run_stage(self)
1031 self._run_sanity_check()
1032 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1033 self.fit_loop.run()
1034 return None
1035 raise RuntimeError(f"Unexpected state {self.state}")
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:205, in _FitLoop.run(self)
203 try:
204 self.on_advance_start()
--> 205 self.advance()
206 self.on_advance_end()
207 self._restarting = False
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:363, in _FitLoop.advance(self)
361 with self.trainer.profiler.profile("run_training_epoch"):
362 assert self._data_fetcher is not None
--> 363 self.epoch_loop.run(self._data_fetcher)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:141, in _TrainingEpochLoop.run(self, data_fetcher)
139 try:
140 self.advance(data_fetcher)
--> 141 self.on_advance_end(data_fetcher)
142 self._restarting = False
143 except StopIteration:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:295, in _TrainingEpochLoop.on_advance_end(self, data_fetcher)
291 if not self._should_accumulate():
292 # clear gradients to not leave any unused memory during validation
293 call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")
--> 295 self.val_loop.run()
296 self.trainer.training = True
297 self.trainer._logger_connector._first_loop_iter = first_loop_iter
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:182, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
180 context_manager = torch.no_grad
181 with context_manager():
--> 182 return loop_run(self, *args, **kwargs)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:135, in _EvaluationLoop.run(self)
133 self.batch_progress.is_last_batch = data_fetcher.done
134 # run step hooks
--> 135 self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
136 except StopIteration:
137 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
138 break
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:396, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
390 hook_name = "test_step" if trainer.testing else "validation_step"
391 step_args = (
392 self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
393 if not using_dataloader_iter
394 else (dataloader_iter,)
395 )
--> 396 output = call._call_strategy_hook(trainer, hook_name, *step_args)
398 self.batch_progress.increment_processed()
400 if using_dataloader_iter:
401 # update the hook kwargs now that the step method might have consumed the iterator
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
306 return None
308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.class.name}.{hook_name}"):
--> 309 output = fn(*args, **kwargs)
311 # restore current_fx when nested context
312 pl_module._current_fx_name = prev_fx_name
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:412, in Strategy.validation_step(self, *args, **kwargs)
410 if self.model != self.lightning_module:
411 return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 412 return self.lightning_module.validation_step(*args, **kwargs)
File , line 36, in Idefics2ModelPLModule.validation_step(self, batch, batch_idx, dataset_idx)
33 input_ids, attention_mask, pixel_values, pixel_attention_mask, answers = batch
35 # autoregressively generate token IDs
---> 36 generated_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask,
37 pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask,
38 max_new_tokens=768)
39 # turn them back into text, chopping of the prompt
40 # important: we don't skip special tokens here, because we want to see them in the output
41 predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1):], skip_special_tokens=True)
File /databricks/python/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/generation/utils.py:1896, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
1888 input_ids, model_kwargs = self._expand_inputs_for_generation(
1889 input_ids=input_ids,
1890 expand_size=generation_config.num_return_sequences,
1891 is_encoder_decoder=self.config.is_encoder_decoder,
1892 **model_kwargs,
1893 )
1895 # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1896 result = self._sample(
1897 input_ids,
1898 logits_processor=prepared_logits_processor,
1899 logits_warper=prepared_logits_warper,
1900 stopping_criteria=prepared_stopping_criteria,
1901 generation_config=generation_config,
1902 synced_gpus=synced_gpus,
1903 streamer=streamer,
1904 **model_kwargs,
1905 )
1907 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
1908 # 11. prepare logits warper
1909 prepared_logits_warper = (
1910 self._get_logits_warper(generation_config, device=input_ids.device)
1911 if generation_config.do_sample
1912 else None
1913 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/generation/utils.py:2633, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
2630 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2632 # forward pass to get next token
-> 2633 outputs = self(
2634 **model_inputs,
2635 return_dict=True,
2636 output_attentions=output_attentions,
2637 output_hidden_states=output_hidden_states,
2638 )
2640 if synced_gpus and this_peer_finished:
2641 continue # don't waste resources running the code we don't need
File /databricks/python/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /databricks/python/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
163 output = module._old_forward(*args, **kwargs)
164 else:
--> 165 output = module._old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/models/idefics2/modeling_idefics2.py:1829, in Idefics2ForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1826 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1828 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1829 outputs = self.model(
1830 input_ids=input_ids,
1831 attention_mask=attention_mask,
1832 position_ids=position_ids,
1833 past_key_values=past_key_values,
1834 inputs_embeds=inputs_embeds,
1835 pixel_values=pixel_values,
1836 pixel_attention_mask=pixel_attention_mask,
1837 image_hidden_states=image_hidden_states,
1838 use_cache=use_cache,
1839 output_attentions=output_attentions,
1840 output_hidden_states=output_hidden_states,
1841 return_dict=return_dict,
1842 )
1844 hidden_states = outputs[0]
1845 logits = self.lm_head(hidden_states)
File /databricks/python/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /databricks/python/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
163 output = module._old_forward(*args, **kwargs)
164 else:
--> 165 output = module._old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/models/idefics2/modeling_idefics2.py:1656, in Idefics2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, use_cache, output_attentions, output_hidden_states, return_dict)
1651 image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
1653 if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
1654 # When we generate, we don't want to replace the potential image_token_id that we generated by images
1655 # that simply don't exist
-> 1656 inputs_embeds = self.inputs_merger(
1657 input_ids=input_ids,
1658 inputs_embeds=inputs_embeds,
1659 image_hidden_states=image_hidden_states,
1660 )
1662 outputs = self.text_model(
1663 inputs_embeds=inputs_embeds,
1664 attention_mask=attention_mask,
(...)
1669 return_dict=return_dict,
1670 )
1672 if return_legacy_cache:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/models/idefics2/modeling_idefics2.py:1542, in Idefics2Model.inputs_merger(self, input_ids, inputs_embeds, image_hidden_states)
1540 new_inputs_embeds = inputs_embeds.clone()
1541 reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
-> 1542 new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
1543 return new_inputs_embeds
Found the fix for this: you need to install the latest version of the transformers package (v4.41.2 as of the time of commenting).
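For reference, a minimal sanity check of the installed version (a sketch, assuming a pip-managed notebook environment; restart the kernel after upgrading):

```python
# Verify that the installed transformers is new enough; if not, upgrade with
# `pip install -U "transformers>=4.41.2"` and restart the kernel.
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse("4.41.2"), (
    f"transformers {transformers.__version__} is too old for this notebook"
)
```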