Lightning-Universe/lightning-transformers

What format does the model expect for the data(set)?

jmwoloso opened this issue · 1 comment

❓ Questions and Help

Before asking:

  1. search the issues.
  2. search the docs.

What is your question?

This is probably due to my unfamiliarity with datasets and PyTorch (to date I've used the TF implementations of Transformers), but I'm getting the error below when trying to load custom data from CSV files.

Traceback (most recent call last):
  File "train.py", line 10, in hydra_entry
    main(cfg)
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 69, in main
    run(
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 60, in run
    trainer.fit(model, datamodule=data_module)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
    self._call_and_handle_interrupt(
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
    self._dispatch()
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1275, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in run_stage
    return self._run_train()
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1307, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1371, in _run_sanity_check
    self._evaluation_loop.run()
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
    dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
    output = self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
    output = self.trainer.accelerator.validation_step(step_kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
    return self.training_type_plugin.validation_step(*step_kwargs.values())
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 896, in validation_step
    return self.model(*args, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1606, in forward
    loss = self.module(*inputs, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 73, in forward
    return super().forward(*inputs, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 61, in validation_step
    return self.common_step("val", batch)
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 50, in common_step
    outputs = self.model(**batch)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/transformers/models/longformer/modeling_longformer.py", line 1854, in forward
    global_attention_mask = torch.zeros_like(input_ids)
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not list

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
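The last frame of the traceback shows the immediate cause: when no global_attention_mask is passed, Longformer's forward builds one with torch.zeros_like(input_ids), and torch.zeros_like only accepts a tensor, so input_ids arrived as a plain Python list. The minimal sketch below reproduces the error and shows the batch shape the model can consume: a dict of torch.long tensors of shape (batch_size, seq_len). The token ids and labels are placeholders for illustration, not real tokenizer output.

```python
import torch

# Hypothetical pre-tokenized batch of two examples (ids are placeholders).
batch_as_lists = {
    "input_ids": [[0, 31414, 232, 2], [0, 20920, 2, 1]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 0]],
    "labels": [1, 0],
}

# A Python list where a tensor is expected fails the same way as in the traceback.
error_message = ""
try:
    torch.zeros_like(batch_as_lists["input_ids"])
except TypeError as err:
    error_message = str(err)

# Converting every field to a LongTensor gives a batch that forward(**batch) can accept.
batch = {key: torch.tensor(value, dtype=torch.long) for key, value in batch_as_lists.items()}

# Mirrors the traceback's line: a zero mask with global attention on the first token.
global_attention_mask = torch.zeros_like(batch["input_ids"])
global_attention_mask[:, 0] = 1
```

Real ids must come from the checkpoint's own tokenizer; the point is only that each field reaches the model as a tensor, not a list.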

Code

What have you tried?

This is my current implementation, which subclasses TextClassificationDataModule:

import functools

from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizerBase


from lightning_transformers.task.nlp.text_classification import TextClassificationDataModule


class PrismTextClassificationDataModule(TextClassificationDataModule):
    def __init__(self, cfg, tokenizer, **kwargs):
        # tokenizer_cfg = cfg.copy()
        # tokenizer_cfg.pop("train_file")
        # tokenizer_cfg.pop("validation_file")
        super().__init__(tokenizer, cfg)
        # self.__dict__.update({"num_classes": 2})
        # self.train_file = kwargs["train_file"]
        # self.validation_file = kwargs["validation_file"]
        self.train_file = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/train.csv"
        self.validation_file = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/val.csv"
        self.test_file = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/test.csv"
        self.data_files = {
            "train": self.train_file,
            "validation": self.validation_file,
            # "test": self.test_file
        }

    def process_data(self, dataset, stage) -> Dataset:
        ds = load_dataset("csv", data_files=self.data_files)
        columns = [
            "InputIds", 
            # "AttentionMask", 
            "Label"]
        ds = ds.rename_column("InputIds", "input_ids")
        # ds = ds.rename_column("AttentionMask", "attention_mask")
        ds = ds.rename_column("Label", "labels")
        columns = [
            "input_ids", 
            # "attention_mask", 
            "labels"]
        ds.set_format("pytorch", columns=columns)
        print(ds.__dict__)
        # train = pd.read_csv(self.train_file)
        # train = train.rename(columns={"InputIds": "input_ids", 
        #                               "AttentionMask": "attention_mask",
        #                               "Label": "labels"})
        # val = pd.read_csv(self.validation_file)
        # val = val.rename(columns={"InputIds": "input_ids", 
        #                               "AttentionMask": "attention_mask",
        #                               "Label": "labels"})
        # ds = {}
        return ds
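One thing worth checking with the CSV approach, assuming the InputIds column was written by pandas from a column of Python lists: CSV has no list type, so each cell is stored as text such as "[0, 31414, 2]" and is read back as a string rather than a list of integers. A minimal standard-library sketch of parsing such a column back into integer lists; the CSV contents and the read_pretokenized_csv helper are hypothetical.

```python
import ast
import csv
import io

# Hypothetical CSV with pre-tokenized ids serialized as strings (as pandas' to_csv does for list cells).
csv_text = 'InputIds,Label\n"[0, 31414, 232, 2]",1\n"[0, 20920, 2]",0\n'


def read_pretokenized_csv(handle):
    """Parse rows into {"input_ids": list[int], "labels": int} dicts."""
    rows = []
    for row in csv.DictReader(handle):
        # literal_eval safely turns the string "[0, 31414, 2]" back into a list.
        input_ids = ast.literal_eval(row["InputIds"])
        rows.append({"input_ids": [int(i) for i in input_ids], "labels": int(row["Label"])})
    return rows


examples = read_pretokenized_csv(io.StringIO(csv_text))
```

Storing the ids in a format with a native list type, such as Parquet, avoids this string round trip entirely.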

What's your environment?

  • OS: Linux Mint 19.3
  • Conda (environment.yml):
name: pml
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=1_llvm
  - abseil-cpp=20210324.2=h9c3ff4c_0
  - absl-py=0.13.0=py38h06a4308_0
  - aiohttp=3.8.1=py38h7f8727e_0
  - aiosignal=1.2.0=pyhd3eb1b0_0
  - arrow-cpp=3.0.0=py38h6b21186_4
  - async-timeout=4.0.1=pyhd3eb1b0_0
  - attrs=21.2.0=pyhd3eb1b0_0
  - aws-c-common=0.4.57=he6710b0_1
  - aws-c-event-stream=0.1.6=h2531618_5
  - aws-checksums=0.1.9=he6710b0_0
  - aws-sdk-cpp=1.8.185=hce553d0_0
  - backcall=0.2.0=pyhd3eb1b0_0
  - blas=1.0=mkl
  - blinker=1.4=py38h06a4308_0
  - boost-cpp=1.69.0=h11c811c_1000
  - boto3=1.18.21=pyhd3eb1b0_0
  - botocore=1.21.41=pyhd3eb1b0_1
  - brotli=1.0.9=h7f98852_6
  - brotli-bin=1.0.9=h7f98852_6
  - brotlipy=0.7.0=py38h27cfd23_1003
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.17.1=h27cfd23_0
  - ca-certificates=2021.10.8=ha878542_0
  - certifi=2021.10.8=py38h578d9bd_1
  - cffi=1.14.6=py38h400218f_0
  - cryptography=3.4.8=py38hd23ed53_0
  - cudatoolkit=11.1.1=h6406543_9
  - dataclasses=0.8=pyh6d0b6a4_7
  - datasets=1.16.1=pyhd8ed1ab_0
  - debugpy=1.5.1=py38h295c915_0
  - decorator=5.1.0=pyhd3eb1b0_0
  - dill=0.3.4=pyhd8ed1ab_0
  - double-conversion=3.1.6=h9c3ff4c_0
  - ffmpeg=4.2.2=h20bf706_0
  - filelock=3.4.0=pyhd8ed1ab_0
  - freetype=2.11.0=h70c0345_0
  - frozenlist=1.2.0=py38h7f8727e_0
  - fsspec=2021.10.1=pyhd3eb1b0_0
  - future=0.18.2=py38_1
  - gflags=2.2.2=he1b5a44_1004
  - giflib=5.2.1=h7b6447c_0
  - glog=0.5.0=h48cff8f_0
  - gmp=6.2.1=h2531618_2
  - gnutls=3.6.15=he1e5248_0
  - grpc-cpp=1.39.0=hae934f6_5
  - grpcio=1.42.0=py38hce63b2e_0
  - huggingface_hub=0.2.1=pyhd8ed1ab_0
  - icu=58.2=hf484d3e_1000
  - idna=3.3=pyhd3eb1b0_0
  - importlib-metadata=4.8.2=py38h06a4308_0
  - importlib_metadata=4.8.2=hd8ed1ab_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - ipython=7.29.0=py38hb070fc8_0
  - ipython_genutils=0.2.0=pyhd3eb1b0_1
  - jmespath=0.10.0=pyhd3eb1b0_0
  - joblib=1.1.0=pyhd3eb1b0_0
  - jpeg=9d=h7f8727e_0
  - jupyter_client=7.0.6=pyhd3eb1b0_0
  - jupyter_core=4.9.1=py38h06a4308_0
  - krb5=1.19.2=hcc1bbae_3
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.35.1=h7274673_9
  - libboost=1.73.0=h3ff78a5_11
  - libbrotlicommon=1.0.9=h7f98852_6
  - libbrotlidec=1.0.9=h7f98852_6
  - libbrotlienc=1.0.9=h7f98852_6
  - libcurl=7.78.0=h0b77cf5_0
  - libedit=3.1.20210910=h7f8727e_0
  - libev=4.33=h516909a_1
  - libevent=2.1.10=h9b69904_4
  - libffi=3.3=he6710b0_2
  - libgcc-ng=11.2.0=h1d223b6_11
  - libidn2=2.3.2=h7f8727e_0
  - libnghttp2=1.43.0=h812cca2_0
  - libopus=1.3.1=h7b6447c_0
  - libpng=1.6.37=hbc83047_0
  - libprotobuf=3.17.2=h4ff587b_1
  - libsodium=1.0.18=h7b6447c_0
  - libssh2=1.10.0=ha56f1ee_2
  - libstdcxx-ng=11.2.0=he4da1e4_11
  - libtasn1=4.16.0=h27cfd23_0
  - libthrift=0.14.2=hcc01f38_0
  - libtiff=4.2.0=h85742a9_0
  - libunistring=0.9.10=h27cfd23_0
  - libuv=1.40.0=h7b6447c_0
  - libvpx=1.7.0=h439df22_0
  - libwebp=1.2.0=h89dd481_0
  - libwebp-base=1.2.0=h27cfd23_0
  - llvm-openmp=12.0.1=h4bd325d_1
  - lz4-c=1.9.3=h295c915_1
  - markdown=3.3.4=py38h06a4308_0
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py38h7f8727e_0
  - mkl_fft=1.3.1=py38hd3c417c_0
  - mkl_random=1.2.2=py38h51133e4_0
  - multidict=5.1.0=py38h27cfd23_2
  - multiprocess=0.70.12.2=py38h497a2fe_1
  - ncurses=6.3=h7f8727e_2
  - nest-asyncio=1.5.1=pyhd3eb1b0_0
  - nettle=3.7.3=hbbd107a_1
  - numpy-base=1.21.2=py38h79a1101_0
  - oauthlib=3.1.1=pyhd3eb1b0_0
  - olefile=0.46=pyhd3eb1b0_0
  - openh264=2.1.0=hd408876_0
  - openssl=1.1.1l=h7f98852_0
  - orc=1.6.9=ha97a36c_3
  - packaging=21.3=pyhd3eb1b0_0
  - parso=0.8.2=pyhd3eb1b0_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pip=21.2.4=py38h06a4308_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pyasn1=0.4.8=pyhd3eb1b0_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pydeprecate=0.3.1=pyhd8ed1ab_0
  - pygments=2.10.0=pyhd3eb1b0_0
  - pyparsing=3.0.4=pyhd3eb1b0_0
  - pysocks=1.7.1=py38h06a4308_0
  - python=3.8.12=h12debd9_0
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - python-xxhash=2.0.2=py38h497a2fe_1
  - python_abi=3.8=2_cp38
  - pytorch=1.10.0=py3.8_cuda11.1_cudnn8.0.5_0
  - pytorch-lightning=1.5.5=pyhd8ed1ab_0
  - pytorch-mutex=1.0=cuda
  - pytz=2021.3=pyhd8ed1ab_0
  - pyyaml=6.0=py38h7f8727e_1
  - pyzmq=22.3.0=py38h295c915_2
  - re2=2021.11.01=h9c3ff4c_0
  - readline=8.1=h27cfd23_0
  - regex=2021.8.3=py38h7f8727e_0
  - requests=2.26.0=pyhd3eb1b0_0
  - requests-oauthlib=1.3.0=py_0
  - rsa=4.7.2=pyhd3eb1b0_1
  - s3transfer=0.5.0=pyhd3eb1b0_0
  - sacremoses=0.0.43=pyhd3eb1b0_0
  - setuptools=58.0.4=py38h06a4308_0
  - six=1.16.0=pyhd3eb1b0_0
  - snappy=1.1.8=he1b5a44_3
  - sqlite=3.36.0=hc218d9a_0
  - tk=8.6.11=h1ccaba5_0
  - tokenizers=0.10.3=py38hb63a372_1
  - torchaudio=0.10.0=py38_cu111
  - torchmetrics=0.6.1=pyhd8ed1ab_0
  - torchvision=0.11.1=py38_cu111
  - tornado=6.1=py38h27cfd23_0
  - tqdm=4.62.3=pyhd3eb1b0_1
  - traitlets=5.1.1=pyhd3eb1b0_0
  - transformers=4.11.3=pyhd8ed1ab_0
  - typing-extensions=3.10.0.2=hd3eb1b0_0
  - typing_extensions=3.10.0.2=pyh06a4308_0
  - uriparser=0.9.5=h9c3ff4c_0
  - utf8proc=2.6.1=h27cfd23_0
  - wcwidth=0.2.5=pyhd3eb1b0_0
  - werkzeug=2.0.2=pyhd3eb1b0_0
  - wheel=0.37.0=pyhd3eb1b0_1
  - x264=1!157.20191217=h7b6447c_0
  - xxhash=0.8.0=h7f98852_3
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h7b6447c_0
  - yarl=1.6.3=py38h27cfd23_0
  - zeromq=4.3.4=h2531618_0
  - zipp=3.6.0=pyhd3eb1b0_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.9=haebb681_0
  - pip:
    - adal==1.2.7
    - antlr4-python3-runtime==4.8
    - applicationinsights==0.11.10
    - astunparse==1.6.3
    - azure-common==1.1.27
    - azure-core==1.20.1
    - azure-graphrbac==0.61.1
    - azure-identity==1.7.0
    - azure-mgmt-authorization==0.61.0
    - azure-mgmt-containerregistry==8.2.0
    - azure-mgmt-core==1.3.0
    - azure-mgmt-keyvault==9.3.0
    - azure-mgmt-resource==13.0.0
    - azure-mgmt-storage==11.2.0
    - azureml-core==1.36.0.post2
    - azureml-dataprep==2.24.4
    - azureml-dataprep-native==38.0.0
    - azureml-dataprep-rslex==2.0.3
    - azureml-dataset-runtime==1.36.0
    - azureml-defaults==1.36.0
    - azureml-inference-server-http==0.4.2
    - azureml-mlflow==1.36.0
    - azureml-telemetry==1.36.0
    - backports-tempfile==1.0
    - backports-weakref==1.0.post1
    - cachetools==4.2.4
    - charset-normalizer==2.0.7
    - click==8.0.3
    - cloudpickle==2.0.0
    - configparser==3.7.4
    - contextlib2==21.6.0
    - cycler==0.11.0
    - databricks-cli==0.16.2
    - deepspeed==0.5.8
    - distro==1.6.0
    - docker==5.0.3
    - dotnetcore2==2.1.21
    - entrypoints==0.3
    - flask==1.0.3
    - flatbuffers==2.0
    - fonttools==4.28.1
    - fusepy==3.0.1
    - gast==0.4.0
    - gitdb==4.0.9
    - gitpython==3.1.24
    - google-auth==2.3.3
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - gunicorn==20.1.0
    - h5py==3.6.0
    - hjson==3.0.2
    - horovod==0.23.0
    - hydra-core==1.1.0
    - importlib-resources==5.4.0
    - inference-schema==1.3.0
    - ipykernel==6.5.1
    - isodate==0.6.0
    - itsdangerous==2.0.1
    - jedi==0.18.1
    - jeepney==0.7.1
    - jinja2==3.0.3
    - json-logging-py==0.2
    - jsonpickle==2.0.0
    - keras==2.7.0
    - keras-preprocessing==1.1.2
    - kiwisolver==1.3.2
    - libclang==12.0.0
    - lightgbm==3.3.1
    - markupsafe==2.0.1
    - matplotlib==3.5.0
    - matplotlib-inline==0.1.3
    - mlflow-skinny==1.21.0
    - msal==1.16.0
    - msal-extensions==0.3.0
    - msrest==0.6.21
    - msrestazure==0.6.4
    - ndg-httpsclient==0.5.1
    - ninja==1.10.2.3
    - numpy==1.21.4
    - omegaconf==2.1.1
    - onnxruntime-gpu==1.9.0
    - opt-einsum==3.3.0
    - pandas==1.3.4
    - pathspec==0.9.0
    - pillow==8.4.0
    - plotly==5.4.0
    - portalocker==1.7.1
    - prompt-toolkit==3.0.22
    - protobuf==3.19.1
    - psutil==5.8.0
    - pyarrow==3.0.0
    - pyasn1-modules==0.2.8
    - pyjwt==2.3.0
    - pyopenssl==20.0.1
    - scikit-learn==1.0.1
    - scipy==1.7.2
    - secretstorage==3.3.1
    - setuptools-scm==6.3.2
    - smmap==5.0.0
    - tabulate==0.8.9
    - tenacity==8.0.1
    - tensorboard==2.7.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.0
    - tensorflow-estimator==2.7.0
    - tensorflow-gpu==2.7.0
    - tensorflow-io-gcs-filesystem==0.22.0
    - termcolor==1.1.0
    - threadpoolctl==3.0.0
    - tomli==1.2.2
    - torch-tb-profiler==0.3.1
    - triton==1.1.1
    - urllib3==1.26.7
    - websocket-client==1.2.1
    - wrapt==1.13.3
prefix: /anaconda/envs/pml

requirements.txt:

adal==1.2.7
antlr4-python3-runtime==4.8
applicationinsights==0.11.10
astunparse==1.6.3
azure-common==1.1.27
azure-core==1.20.1
azure-graphrbac==0.61.1
azure-identity==1.7.0
azure-mgmt-authorization==0.61.0
azure-mgmt-containerregistry==8.2.0
azure-mgmt-core==1.3.0
azure-mgmt-keyvault==9.3.0
azure-mgmt-resource==13.0.0
azure-mgmt-storage==11.2.0
azureml-core==1.36.0.post2
azureml-dataprep==2.24.4
azureml-dataprep-native==38.0.0
azureml-dataprep-rslex==2.0.3
azureml-dataset-runtime==1.36.0
azureml-defaults==1.36.0
azureml-inference-server-http==0.4.2
azureml-mlflow==1.36.0
azureml-telemetry==1.36.0
backports-tempfile==1.0
backports-weakref==1.0.post1
cachetools==4.2.4
charset-normalizer==2.0.7
click==8.0.3
cloudpickle==2.0.0
configparser==3.7.4
contextlib2==21.6.0
cycler==0.11.0
databricks-cli==0.16.2
deepspeed==0.5.8
distro==1.6.0
docker==5.0.3
dotnetcore2==2.1.21
entrypoints==0.3
flask==1.0.3
flatbuffers==2.0
fonttools==4.28.1
fusepy==3.0.1
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
google-auth==2.3.3
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
gunicorn==20.1.0
h5py==3.6.0
hjson==3.0.2
horovod==0.23.0
hydra-core==1.1.0
importlib-resources==5.4.0
inference-schema==1.3.0
ipykernel==6.5.1
isodate==0.6.0
itsdangerous==2.0.1
jedi==0.18.1
jeepney==0.7.1
jinja2==3.0.3
json-logging-py==0.2
jsonpickle==2.0.0
keras==2.7.0
keras-preprocessing==1.1.2
kiwisolver==1.3.2
libclang==12.0.0
lightgbm==3.3.1
markupsafe==2.0.1
matplotlib==3.5.0
matplotlib-inline==0.1.3
mlflow-skinny==1.21.0
msal==1.16.0
msal-extensions==0.3.0
msrest==0.6.21
msrestazure==0.6.4
ndg-httpsclient==0.5.1
ninja==1.10.2.3
numpy==1.21.4
omegaconf==2.1.1
onnxruntime-gpu==1.9.0
opt-einsum==3.3.0
pandas==1.3.4
pathspec==0.9.0
pillow==8.4.0
plotly==5.4.0
portalocker==1.7.1
prompt-toolkit==3.0.22
protobuf==3.19.1
psutil==5.8.0
pyarrow==3.0.0
pyasn1-modules==0.2.8
pyjwt==2.3.0
pyopenssl==20.0.1
scikit-learn==1.0.1
scipy==1.7.2
secretstorage==3.3.1
setuptools-scm==6.3.2
smmap==5.0.0
tabulate==0.8.9
tenacity==8.0.1
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow-estimator==2.7.0
tensorflow-gpu==2.7.0
tensorflow-io-gcs-filesystem==0.22.0
termcolor==1.1.0
threadpoolctl==3.0.0
tomli==1.2.2
torch-tb-profiler==0.3.1
triton==1.1.1
urllib3==1.26.7
websocket-client==1.2.1
wrapt==1.13.3

After some experimenting, I figured this out. I needed to implement all of the methods and properties that the standard data modules have. My text was processed ahead of time, so I only needed to return a dict with the input_ids. My class implementation is below:

from typing import Dict

import torch.multiprocessing
from datasets import Dataset

from lightning_transformers.task.nlp.text_classification import TextClassificationDataModule

# https://github.com/pytorch/pytorch/issues/11201
# Works around the "too many open files" error raised by DataLoader workers.
torch.multiprocessing.set_sharing_strategy("file_system")


class MyTextClassificationDataModule(TextClassificationDataModule):
    def process_data(self, dataset, stage):
        # Load the pre-tokenized splits directly; the `dataset` argument is not used.
        # Passing a dict of paths yields one split per key (train/validation).
        dataset = Dataset.from_parquet(
            {"train": self.cfg.train_file,
             "validation": self.cfg.validation_file},
            columns=["input_ids", "label"],
        )

        dataset = self.preprocess(dataset)

        # Return input_ids and labels as PyTorch tensors instead of Python lists.
        dataset.set_format("pytorch", columns=["input_ids", "labels"])
        # Assumes "labels" is a ClassLabel feature; a plain integer Value has no num_classes.
        self.labels = dataset["train"].features["labels"]

        return dataset

    @property
    def num_classes(self) -> int:
        return self.labels.num_classes

    @property
    def model_data_kwargs(self) -> Dict[str, int]:
        # Forwarded to the model constructor so the classification head has the right size.
        return {"num_labels": self.num_classes}

    @staticmethod
    def convert_to_features(example_batch, indices, **fn_kwargs):
        # The text is already tokenized, so the ids are passed through unchanged.
        # `indices` is supplied by map(..., with_indices=True) and is not needed here.
        return {"input_ids": example_batch["input_ids"]}

    @staticmethod
    def preprocess(ds, **fn_kwargs):
        ds = ds.map(
            # todo: change this to self.convert_to_features for users to override
            MyTextClassificationDataModule.convert_to_features,
            batched=True,
            with_indices=True,
            fn_kwargs=fn_kwargs,
        )
        # rename_column returns a new object; the in-place rename_column_ is deprecated.
        ds = ds.rename_column("label", "labels")
        return ds
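Because the ids were tokenized ahead of time, examples may have different lengths, and set_format("pytorch") converts each example separately without padding, so the default collate step cannot stack ragged input_ids into one tensor. If that happens, one option is a padding collate function. A minimal sketch using torch.nn.utils.rnn.pad_sequence; pad_collate is a hypothetical helper (not part of lightning-transformers), and pad_token_id=1 is an assumption that should be replaced by tokenizer.pad_token_id.

```python
from typing import Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(examples: List[Dict[str, torch.Tensor]], pad_token_id: int = 1) -> Dict[str, torch.Tensor]:
    """Right-pad variable-length input_ids and build the matching attention_mask."""
    input_ids = [ex["input_ids"] for ex in examples]
    # pad_sequence pads every sequence to the longest one in the batch.
    padded = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    # 1 for real tokens, 0 for padding, computed from the original lengths
    # (so a real token that happens to equal pad_token_id is not masked).
    lengths = torch.tensor([len(ids) for ids in input_ids])
    attention_mask = (torch.arange(padded.size(1))[None, :] < lengths[:, None]).long()
    labels = torch.stack([ex["labels"] for ex in examples])
    return {"input_ids": padded, "attention_mask": attention_mask, "labels": labels}


# Two formatted examples of different lengths, as set_format("pytorch") would yield them.
examples = [
    {"input_ids": torch.tensor([0, 31414, 232, 2]), "labels": torch.tensor(1)},
    {"input_ids": torch.tensor([0, 20920, 2]), "labels": torch.tensor(0)},
]
batch = pad_collate(examples)
```

It could be passed to a plain DataLoader as collate_fn=functools.partial(pad_collate, pad_token_id=tokenizer.pad_token_id); how a custom collate function is wired into the data module depends on the lightning-transformers version in use.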