I would favour the use of a `NamedTuple` for the entries of `qa_data` rather than an index-based tuple in `ClassificationVQADataset.py`:
```python
def prepare_split(self, split: str) -> list:
    """
    This method should return a list of tuples, where each tuple contains
    the following elements:

    - The key of the image at index 0
    - The question at index 1
    - The answer at index 2
    - additional information at index 3 and higher

    :param split: The name of the split to prepare

    :return: A list of tuples, each tuple containing the elements described
    """
    raise NotImplementedError("This method should be implemented by the subclass")
```

Although the docstring does specify the ordering, that information is easily lost when implementing a subclass or debugging the application. The following code, for example, relies on "magic numbers" (`qa_pair[0]`, `qa_pair[1]`, `qa_pair[2]`, `qa_pair[3:]`):

```python
def __getitem__(self, idx):
    qa_pair = self.qa_data[idx]

    # get image
    img = self.load_image(qa_pair[0]).to(torch.float32)
    if self.transform is not None:
        img = self.transform(img)
    assert img.shape == self.img_size, f"Image shape is {img.shape}, expected {self.img_size}"

    # tokenize question
    question_ids = huggingface_tokenize_and_pad(
        tokenizer=self.tokenizer,
        string=qa_pair[1],
        seq_length=self.seq_length,
    )
    assert (
        len(question_ids) == self.seq_length
    ), f"Question length is {len(question_ids)}, expected {self.seq_length}"

    # convert answer to tensor
    # note: this assumes that the qa_pair is a tuple of length 3, where the
    # answer is at index 2
    # the answer is a list of strings, where each string is an answer, therefore we
    # need wrap the single answer in a list
    tmp_answer = qa_pair[2]
    if not isinstance(tmp_answer, list):
        # answers are sometimes provided as a single string, sometimes as a list
        # of strings
        # if it is a single string, we need to wrap it in a list
        tmp_answer = [tmp_answer]
    answer = self._answers_to_tensor(tmp_answer)

    if self.return_extras:
        return img, question_ids, answer, *qa_pair[3:]
    else:
        return img, question_ids, answer
```

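I would just use a `NamedTuple` with the fields `image_key`, `question`, `answer`, and `extra`, where `extra` may be either an iterable or a single object. An iterable would require fewer changes, though, since only the index-based accesses and this return statement would need to be touched: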

```python
if self.return_extras:
    return img, question_ids, answer, *qa_pair[3:]
else:
    return img, question_ids, answer
```
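
To make the suggestion concrete, here is a minimal sketch of what I have in mind. The names `QAEntry`, `normalize_answer`, and `build_item` are hypothetical and do not exist in the repository; `build_item` only stands in for the tail of `__getitem__` (image loading and tokenization are left out), and the field types are my assumptions.

```python
from typing import Any, NamedTuple, Tuple


class QAEntry(NamedTuple):
    """One entry of qa_data (hypothetical name, fields as proposed above)."""

    image_key: Any  # assumption: whatever load_image() accepts as a key
    question: str
    answer: Any  # a single answer string or a list of answer strings
    extra: Tuple[Any, ...] = ()  # optional additional information


def normalize_answer(entry: QAEntry) -> list:
    """Wrap a single answer in a list, mirroring the isinstance check in __getitem__."""
    return entry.answer if isinstance(entry.answer, list) else [entry.answer]


def build_item(entry: QAEntry, return_extras: bool) -> tuple:
    """Stand-in for the end of __getitem__, using field names instead of indices."""
    answers = normalize_answer(entry)
    if return_extras:
        return (entry.image_key, entry.question, answers, *entry.extra)
    return (entry.image_key, entry.question, answers)


entry = QAEntry("img_001", "Is there water?", "yes", extra=("presence",))
print(build_item(entry, return_extras=True))
print(entry[0] == entry.image_key)  # index access keeps working on a NamedTuple
```

Since a `NamedTuple` is still a tuple, existing subclasses and any code that indexes with `qa_pair[0]` should keep working while call sites are migrated to the field names one at a time.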