
FR: Change tuple to NamedTuple in `ClassificationVQADataset.py`

kai-tub opened this issue · 0 comments

I would favour using a NamedTuple for the entries of qa_data rather than an index-based tuple in `ClassificationVQADataset.py`.

    def prepare_split(self, split: str) -> list:
        """
        This method should return a list of tuples, where each tuple contains
        the following elements:
        - The key of the image at index 0
        - The question at index 1
        - The answer at index 2
        - additional information at index 3 and higher
        :param split: The name of the split to prepare
        :return: A list of tuples, each tuple containing the elements described
        """
        raise NotImplementedError("This method should be implemented by the subclass")

Although the docstring documents the required ordering, that information is quickly lost when implementing a subclass or debugging the application. For example, the following code contains "magic numbers":

    def __getitem__(self, idx):
        qa_pair = self.qa_data[idx]
        # get image
        img = self.load_image(qa_pair[0]).to(torch.float32)
        if self.transform is not None:
            img = self.transform(img)
        assert img.shape == self.img_size, f"Image shape is {img.shape}, expected {self.img_size}"
        # tokenize question
        question_ids = huggingface_tokenize_and_pad(...)  # arguments omitted in this excerpt
        assert (
            len(question_ids) == self.seq_length
        ), f"Question length is {len(question_ids)}, expected {self.seq_length}"
        # convert answer to tensor
        # note: this assumes that the qa_pair is a tuple of length 3, where the
        # answer is at index 2
        # the answer is a list of strings, where each string is an answer, therefore we
        # need to wrap the single answer in a list
        tmp_answer = qa_pair[2]
        if not isinstance(tmp_answer, list):
            # answers are sometimes provided as a single string, sometimes as a list
            # of strings
            # if it is a single string, we need to wrap it in a list
            tmp_answer = [tmp_answer]
        answer = self._answers_to_tensor(tmp_answer)
        if self.return_extras:
            return img, question_ids, answer, *qa_pair[3:]
        return img, question_ids, answer
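
As a minimal illustration of why the positional contract is fragile, the sketch below uses a hypothetical helper `describe` and made-up sample values (neither comes from the repository). A subclass that accidentally swaps question and answer still returns a well-formed tuple, so nothing fails until someone inspects the values.

```python
# Hypothetical stand-ins for illustration, not code from the repository.
def describe(qa_pair: tuple) -> str:
    # Index-based access, mirroring qa_pair[0], qa_pair[1], qa_pair[2] above.
    return f"image={qa_pair[0]} question={qa_pair[1]} answer={qa_pair[2]}"


# Ordering required by the docstring: (image key, question, answer).
correct = ("img_001", "Is there water?", "yes")
# Accidental swap of question and answer in a subclass implementation.
swapped = ("img_001", "yes", "Is there water?")

# Both calls succeed; the second one silently mislabels the fields.
print(describe(correct))
print(describe(swapped))
```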

I would just use a NamedTuple with the fields `image_key`, `question`, `answer`, and `extra`, where `extra` may be either an iterable or a single object. An iterable would require fewer changes, though:

    if self.return_extras:
        return img, question_ids, answer, *qa_pair[3:]
    return img, question_ids, answer
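
A minimal sketch of the proposal could look like the following. The class name `QAPair` and the sample values are assumptions made for illustration; only the field names `image_key`, `question`, `answer`, and `extra` come from the proposal above. Here `extra` is modelled as a tuple, i.e. the iterable variant.

```python
from typing import Any, List, NamedTuple, Tuple, Union


class QAPair(NamedTuple):
    """Hypothetical record type for one entry of qa_data."""

    image_key: str
    question: str
    # Answers may be a single string or a list of strings.
    answer: Union[str, List[str]]
    # Additional information kept in one iterable field; empty by default.
    extra: Tuple[Any, ...] = ()


pair = QAPair("img_001", "Is there water?", "yes", extra=("metadata",))

# Named access replaces the magic indices.
assert pair.question == "Is there water?"
# A NamedTuple is still a tuple, so existing index-based code keeps working.
assert pair[0] == pair.image_key
# With a single `extra` field, the extras are unpacked by name.
# The string placeholders stand in for the real tensors.
result = ("img", "question_ids", "answer", *pair.extra)
```

Because `extra` is a single field in this sketch, `*qa_pair[3:]` would yield the tuple itself rather than its elements, so the return statement would become `*qa_pair.extra`. Everything else in `__getitem__` could migrate to named access gradually, since index-based access remains valid.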