`TransformerDataModule.setup()` run more than once unnecessarily
RR-28023 opened this issue · 0 comments
🐛 Bug
`TransformerDataModule.setup()` is run more than once unnecessarily. For example, when running the code included below, `setup()` runs once when calling `dm.num_classes` and then again when calling `trainer.fit(model, dm)`.
`setup()` then calls `self.load_dataset()`, `self.split_dataset(dataset)` and `self.process_data(dataset, stage=stage)`. Calling `self.load_dataset()` several times is not a big deal because the dataset is loaded from the cache, but the other two methods are expensive and I think it does not make sense to run them again, since they just overwrite whatever `self.ds` was there before.
To Reproduce
Take the example below from the docs and check the console output, or run it in debug mode with a breakpoint. It can be seen that `TransformerDataModule.setup()` and the subsequent methods `load_dataset()`, `split_dataset()` and `process_data()` are run more than once.
import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")
dm = TextClassificationDataModule(
    batch_size=1,
    dataset_name="glue",
    dataset_config_name="sst2",
    max_length=512,
    tokenizer=tokenizer,
)
model = TextClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", num_labels=dm.num_classes)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
trainer.fit(model, dm)
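As an alternative to a breakpoint, the repeated calls can be counted by wrapping the methods with `unittest.mock.patch.object(..., wraps=...)`. The sketch below does this on a hypothetical stand-in class, `ToyDataModule`, not on the real `TextClassificationDataModule`. It assumes that `num_classes` falls back to calling `setup("fit")` itself, and `fake_fit` is only a placeholder for the `setup` call triggered by `trainer.fit(model, dm)`.

```python
from typing import Optional
from unittest import mock


class ToyDataModule:
    """Hypothetical stand-in that mimics the call pattern described in this issue."""

    def __init__(self) -> None:
        self.ds = None

    def load_dataset(self) -> dict:
        return {"train": [0, 1, 1, 0]}

    def split_dataset(self, dataset: dict) -> dict:
        return dataset

    def process_data(self, dataset: dict, stage: Optional[str] = None) -> dict:
        return dataset

    def setup(self, stage: Optional[str] = None) -> None:
        dataset = self.load_dataset()
        dataset = self.split_dataset(dataset)
        dataset = self.process_data(dataset, stage=stage)
        self.ds = dataset

    @property
    def num_classes(self) -> int:
        # Assumption: the property runs setup() itself to discover the labels.
        self.setup("fit")
        return len(set(self.ds["train"]))


def fake_fit(dm: ToyDataModule) -> None:
    # Placeholder for the setup("fit") call made when trainer.fit(model, dm) runs.
    dm.setup("fit")


dm = ToyDataModule()
# wraps= keeps the original behaviour while recording every call.
with mock.patch.object(dm, "setup", wraps=dm.setup) as setup_spy, \
     mock.patch.object(dm, "process_data", wraps=dm.process_data) as process_spy:
    _ = dm.num_classes
    fake_fit(dm)
    setup_calls = setup_spy.call_count
    process_calls = process_spy.call_count

print(setup_calls, process_calls)  # prints "2 2"
```

With the real classes, the same wrapping would presumably be applied to `dm.setup`, `dm.split_dataset` and `dm.process_data` before accessing `dm.num_classes` and calling `trainer.fit(model, dm)`.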
Expected behavior
Given that `TransformerDataModule.setup()` currently does the following:
def setup(self, stage: Optional[str] = None):
    dataset = self.load_dataset()
    dataset = self.split_dataset(dataset)
    dataset = self.process_data(dataset, stage=stage)
    self.ds = dataset
Perhaps a way to avoid running it again would be to create the instance attribute `self.setup_stages_run = []` when the class is initialized and then define the `setup` method as:
def setup(self, stage: Optional[str] = None):
    # Load and split dataset only if setup has not been run before
    if len(self.setup_stages_run) == 0:
        dataset = self.load_dataset()
        dataset = self.split_dataset(dataset)
    else:
        dataset = self.ds
    # Process dataset only if setup has not been run before for this stage
    if stage not in self.setup_stages_run:
        self.ds = self.process_data(dataset, stage=stage)
        self.setup_stages_run.append(stage)
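To check the behaviour of this guard in isolation, here is a minimal runnable sketch. `GuardedToyDataModule`, its toy data and the `calls` counter are hypothetical and exist only for illustration; only the body of `setup` follows the proposal above.

```python
from typing import Dict, List, Optional


class GuardedToyDataModule:
    """Hypothetical stand-in; only the guard logic in setup() mirrors the proposal."""

    def __init__(self) -> None:
        self.ds: Optional[dict] = None
        self.setup_stages_run: List[Optional[str]] = []
        # Counters that exist only to make the number of calls visible.
        self.calls: Dict[str, int] = {"load": 0, "split": 0, "process": 0}

    def load_dataset(self) -> dict:
        self.calls["load"] += 1
        return {"train": ["a", "b"], "validation": ["c"]}

    def split_dataset(self, dataset: dict) -> dict:
        self.calls["split"] += 1
        return dataset

    def process_data(self, dataset: dict, stage: Optional[str] = None) -> dict:
        self.calls["process"] += 1
        return dataset

    def setup(self, stage: Optional[str] = None) -> None:
        # Load and split the dataset only if setup has not been run before.
        if len(self.setup_stages_run) == 0:
            dataset = self.load_dataset()
            dataset = self.split_dataset(dataset)
        else:
            dataset = self.ds
        # Process the dataset only once per stage.
        if stage not in self.setup_stages_run:
            self.ds = self.process_data(dataset, stage=stage)
            self.setup_stages_run.append(stage)


dm = GuardedToyDataModule()
dm.setup("fit")   # e.g. triggered by dm.num_classes
dm.setup("fit")   # e.g. triggered by trainer.fit: now a no-op
dm.setup("test")  # new stage: processes again, but does not reload or re-split
print(dm.calls)   # prints {'load': 1, 'split': 1, 'process': 2}
```

One thing to keep in mind with this approach: on a later stage, `process_data` receives the already processed `self.ds` rather than the freshly split dataset, so it presumably only works if processing is safe to apply to its own output.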
I can create a PR if you think this makes sense.
Thanks!