lucidrains/imagen-pytorch

multiple values for unet_number

Shikamaru5 opened this issue · 1 comments

It's a bit odd, but I'm running into an issue where the program reaches the loss computation and then reports that unet_number received multiple values:

  │   55                                                                                             │
  │   56 for epoch in range(gpc.config.NUM_EPOCHS):                                                  │
  │   57 │                                                                                           │
  │ ❱ 58 │   loss = model.train_step(images, unet_number = 1, text_embeds = text_embeds)             │
  │   59                                                                                             │
  │   60 print(f'loss: {loss}')                                                                      │
  │   61                                                                                             │
  ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
  TypeError: ImagenTrainer.train_step() got multiple values for argument 'unet_number'
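
For context, Python raises this particular TypeError when one parameter is filled twice: once by a positional argument and once by a keyword. The frame at trainer.py:611 in the longer traceback suggests that train_step takes unet_number followed by **kwargs, so a positional images would land in unet_number. Below is a minimal sketch of that mechanism. StandInTrainer and its signature are assumptions inferred from the traceback, not the actual imagen-pytorch code.

```python
class StandInTrainer:
    """Hypothetical stand-in; the signature is inferred from the traceback,
    not copied from imagen-pytorch."""

    def train_step(self, unet_number=None, **kwargs):
        return unet_number, kwargs


trainer = StandInTrainer()
images = "images-placeholder"  # placeholder for a real image tensor

message = None
try:
    # `images` is positional, so it binds to `unet_number`;
    # the explicit keyword then tries to fill the same parameter again.
    trainer.train_step(images, unet_number=1)
except TypeError as err:
    message = str(err)
    print(message)
```

If the real signature matches this assumption, the error would come from the call site rather than from anything inside the loss computation.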

I'm not sure why that happens, because as far as I can tell my script isn't all that different from the examples. So far I've tried changing the argument to num_unets = 1 and to unets = 1. Neither of these triggers the multiple-values error, but the program then stalls and goes no further. When I hit Ctrl-C, the traceback shows that the program got as far as torch/utils/data/dataloader.py:669 in _next_data, and that's where it stops:

  │   55                                                                                             │
  │   56 for epoch in range(gpc.config.NUM_EPOCHS):                                                  │
  │   57 │                                                                                           │
  │ ❱ 58 │   loss = model.train_step(images, unets = 1, text_embeds = text_embeds)                   │
  │   59                                                                                             │
  │   60 print(f'loss: {loss}')                                                                      │
  │   61                                                                                             │
  │                                                                                                  │
  │ /mnt/f/genaitor/majel/imagen/imagen_p/imagen_pytorch/trainer.py:611 in train_step                │
  │                                                                                                  │
  │   608 │   │   if not self.prepared:                                                              │
  │   609 │   │   │   self.prepare()                                                                 │
  │   610 │   │   self.create_train_iter()                                                           │
  │ ❱ 611 │   │   loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **k   │
  │   612 │   │   self.update(unet_number = unet_number)                                             │
  │   613 │   │   return loss                                                                        │
  │   614                                                                                            │
  │                                                                                                  │
  │ /mnt/f/genaitor/majel/imagen/imagen_p/imagen_pytorch/trainer.py:627 in step_with_dl_iter         │
  │                                                                                                  │
  │   624 │   │   return loss                                                                        │
  │   625 │                                                                                          │
  │   626 │   def step_with_dl_iter(self, dl_iter, **kwargs):                                        │
  │ ❱ 627 │   │   dl_tuple_output = cast_tuple(next(dl_iter))                                        │
  │   628 │   │   model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output   │
  │   629 │   │   loss = self.forward(**{**kwargs, **model_input})                                   │
  │   630 │   │   return loss                                                                        │
  │                                                                                                  │
  │ /mnt/f/genaitor/majel/imagen/imagen_p/imagen_pytorch/data.py:27 in cycle                         │
  │                                                                                                  │
  │    24                                                                                            │
  │    25 def cycle(dl):                                                                             │
  │    26 │   while True:                                                                            │
  │ ❱  27 │   │   for data in dl:                                                                    │
  │    28 │   │   │   yield data                                                                     │
  │    29                                                                                            │
  │    30 def convert_image_to(img_type, image):                                                     │
  │                                                                                                  │
  │ /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:628 in __next__           │
  │                                                                                                  │
  │    625 │   │   │   if self._sampler_iter is None:                                                │
  │    626 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   │
  │    627 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   │
  │ ❱  628 │   │   │   data = self._next_data()                                                      │
  │    629 │   │   │   self._num_yielded += 1                                                        │
  │    630 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and \                          │
  │    631 │   │   │   │   │   self._IterableDataset_len_called is not None and \                    │
  │                                                                                                  │
  │ /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:669 in _next_data         │
  │                                                                                                  │
  │    666 │   │   self._dataset_fetcher = _DatasetKind.create_fetcher(                              │
  │    667 │   │   │   self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, s  │
  │    668 │                                                                                         │
  │ ❱  669 │   def _next_data(self):                                                                 │
  │    670 │   │   index = self._next_index()  # may raise StopIteration                             │
  │    671 │   │   data = self._dataset_fetcher.fetch(index)  # may raise StopIteration              │
  │    672 │   │   if self._pin_memory:
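
One possible reading of why renaming the keyword makes the TypeError go away: a **kwargs parameter accepts any unrecognized keyword name without complaint, so unets = 1 would not be rejected, while the positional images would still be bound to unet_number. The sketch below illustrates only that binding behaviour. StandInTrainer is a hypothetical stand-in whose signature is assumed from the frame at trainer.py:611, and the string values are placeholders for real tensors.

```python
class StandInTrainer:
    """Hypothetical stand-in; the signature is assumed from the traceback,
    not copied from imagen-pytorch."""

    def train_step(self, unet_number=None, **kwargs):
        # Return what each parameter received so the binding can be inspected.
        return unet_number, kwargs


bound_unet_number, extra_kwargs = StandInTrainer().train_step(
    "images-placeholder",            # positional: binds to `unet_number`
    unets=1,                         # unknown name: collected into **kwargs
    text_embeds="embeds-placeholder",
)

print(bound_unet_number)
print(extra_kwargs)
```

Judging from the frames shown, step_with_dl_iter then pulls its batch from the trainer's own self.train_dl_iter via next(dl_iter), so the stall occurs while fetching from that dataloader. Why the fetch itself does not return is not visible from this trace.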

If anyone has any thoughts on what the issue could be, that would be greatly appreciated.