multiple values for unet_number
Shikamaru5 opened this issue · 1 comment
It's a bit odd: when my program reaches the loss computation, it fails with an error saying that multiple values were passed for `unet_number`:
55 │
│ 56 for epoch in range(gpc.config.NUM_EPOCHS): │
│ 57 │ │
│ ❱ 58 │ loss = model.train_step(images, unet_number = 1, text_embeds = text_embeds) │
│ 59 │
│ 60 print(f'loss: {loss}') │
│ 61 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: ImagenTrainer.train_step() got multiple values for argument 'unet_number'
I'm not sure why this happens, because as far as I can tell my script isn't very different from the examples. I've tried changing the keyword to `num_unets = 1` and to `unets = 1`. Neither of these raises the multiple-values error. Instead, the program stalls and never progresses, without printing any error. When I press Ctrl+C, the traceback shows that it is stuck in `torch/utils/data/dataloader.py:669`, in `_next_data`:
55 │
│ 56 for epoch in range(gpc.config.NUM_EPOCHS): │
│ 57 │ │
│ ❱ 58 │ loss = model.train_step(images, unets = 1, text_embeds = text_embeds) │
│ 59 │
│ 60 print(f'loss: {loss}') │
│ 61 │
│ │
│ /mnt/f/genaitor/majel/imagen/imagen_p/imagen_pytorch/trainer.py:611 in train_step │
│ │
│ 608 │ │ if not self.prepared: │
│ 609 │ │ │ self.prepare() │
│ 610 │ │ self.create_train_iter() │
│ ❱ 611 │ │ loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **k │
│ 612 │ │ self.update(unet_number = unet_number) │
│ 613 │ │ return loss │
│ 614 │
│ │
│ /mnt/f/genaitor/majel/imagen/imagen_p/imagen_pytorch/trainer.py:627 in step_with_dl_iter │
│ │
│ 624 │ │ return loss │
│ 625 │ │
│ 626 │ def step_with_dl_iter(self, dl_iter, **kwargs): │
│ ❱ 627 │ │ dl_tuple_output = cast_tuple(next(dl_iter)) │
│ 628 │ │ model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output │
│ 629 │ │ loss = self.forward(**{**kwargs, **model_input}) │
│ 630 │ │ return loss │
│ │
│ /mnt/f/genaitor/majel/imagen/imagen_p/imagen_pytorch/data.py:27 in cycle │
│ │
│ 24 │
│ 25 def cycle(dl): │
│ 26 │ while True: │
│ ❱ 27 │ │ for data in dl: │
│ 28 │ │ │ yield data │
│ 29 │
│ 30 def convert_image_to(img_type, image): │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:628 in __next__ │
│ │
│ 625 │ │ │ if self._sampler_iter is None: │
│ 626 │ │ │ │ # TODO(https://github.com/pytorch/pytorch/issues/76750) │
│ 627 │ │ │ │ self._reset() # type: ignore[call-arg] │
│ ❱ 628 │ │ │ data = self._next_data() │
│ 629 │ │ │ self._num_yielded += 1 │
│ 630 │ │ │ if self._dataset_kind == _DatasetKind.Iterable and \ │
│ 631 │ │ │ │ │ self._IterableDataset_len_called is not None and \ │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:669 in _next_data │
│ │
│ 666 │ │ self._dataset_fetcher = _DatasetKind.create_fetcher( │
│ 667 │ │ │ self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, s │
│ 668 │ │
│ ❱ 669 │ def _next_data(self): │
│ 670 │ │ index = self._next_index() # may raise StopIteration │
│ 671 │ │ data = self._dataset_fetcher.fetch(index) # may raise StopIteration │
│ 672 │ │ if self._pin_memory:
If anyone has thoughts on what the issue could be, I'd greatly appreciate it.
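`train_step` doesn't accept `images`: https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/trainer.py#L607

Passing `images` positionally binds it to `unet_number`, so the explicit `unet_number = 1` supplies a second value for the same parameter. That is why Python raises "got multiple values for argument 'unet_number'". With `unets = 1`, `images` is still bound to `unet_number`, and `train_step` then ignores your batch and pulls data from the trainer's own dataloader (`self.train_dl_iter`). That is where the program hangs. To train on a batch you already have in memory, call the trainer directly, e.g. `loss = trainer(images, text_embeds = text_embeds, unet_number = 1)` followed by `trainer.update(unet_number = 1)`, as in the README examples.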