What format does the model expect for the data(set)?
jmwoloso opened this issue · 1 comment
jmwoloso commented
❓ Questions and Help
Before asking:
- search the issues.
- search the docs.
What is your question?
This is probably due to my unfamiliarity with datasets and PyTorch (I've used the TF implementations of Transformers to date), but I'm getting the error below when trying to load custom data from CSV files.
Traceback (most recent call last):
File "train.py", line 10, in hydra_entry
main(cfg)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 69, in main
run(
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 60, in run
trainer.fit(model, datamodule=data_module)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
self._call_and_handle_interrupt(
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1275, in _dispatch
self.training_type_plugin.start_training(self)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in run_stage
return self._run_train()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1307, in _run_train
self._run_sanity_check(self.lightning_module)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1371, in _run_sanity_check
self._evaluation_loop.run()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 896, in validation_step
return self.model(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1606, in forward
loss = self.module(*inputs, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 73, in forward
return super().forward(*inputs, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
output = self.module.validation_step(*inputs, **kwargs)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 61, in validation_step
return self.common_step("val", batch)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 50, in common_step
outputs = self.model(**batch)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/transformers/models/longformer/modeling_longformer.py", line 1854, in forward
global_attention_mask = torch.zeros_like(input_ids)
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not list
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Traceback (most recent call last):
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/train.py", line 10, in hydra_entry
main(cfg)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 69, in main
run(
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 60, in run
trainer.fit(model, datamodule=data_module)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
self._call_and_handle_interrupt(
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1275, in _dispatch
self.training_type_plugin.start_training(self)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in run_stage
return self._run_train()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1307, in _run_train
self._run_sanity_check(self.lightning_module)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1371, in _run_sanity_check
self._evaluation_loop.run()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 896, in validation_step
return self.model(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1606, in forward
loss = self.module(*inputs, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 73, in forward
return super().forward(*inputs, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
output = self.module.validation_step(*inputs, **kwargs)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 61, in validation_step
return self.common_step("val", batch)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 50, in common_step
outputs = self.model(**batch)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/transformers/models/longformer/modeling_longformer.py", line 1854, in forward
global_attention_mask = torch.zeros_like(input_ids)
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not list
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Traceback (most recent call last):
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/train.py", line 10, in hydra_entry
main(cfg)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 69, in main
run(
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 60, in run
trainer.fit(model, datamodule=data_module)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
self._call_and_handle_interrupt(
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1275, in _dispatch
self.training_type_plugin.start_training(self)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in run_stage
return self._run_train()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1307, in _run_train
self._run_sanity_check(self.lightning_module)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1371, in _run_sanity_check
self._evaluation_loop.run()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 896, in validation_step
return self.model(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1606, in forward
loss = self.module(*inputs, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 73, in forward
return super().forward(*inputs, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
output = self.module.validation_step(*inputs, **kwargs)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 61, in validation_step
return self.common_step("val", batch)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 50, in common_step
outputs = self.model(**batch)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/transformers/models/longformer/modeling_longformer.py", line 1854, in forward
global_attention_mask = torch.zeros_like(input_ids)
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not list
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Traceback (most recent call last):
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/train.py", line 10, in hydra_entry
main(cfg)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 69, in main
run(
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 60, in run
trainer.fit(model, datamodule=data_module)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
self._call_and_handle_interrupt(
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1275, in _dispatch
self.training_type_plugin.start_training(self)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in run_stage
return self._run_train()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1307, in _run_train
self._run_sanity_check(self.lightning_module)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1371, in _run_sanity_check
self._evaluation_loop.run()
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 896, in validation_step
return self.model(*args, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1606, in forward
loss = self.module(*inputs, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 73, in forward
return super().forward(*inputs, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
output = self.module.validation_step(*inputs, **kwargs)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 61, in validation_step
return self.common_step("val", batch)
File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 50, in common_step
outputs = self.model(**batch)
File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda/envs/pml/lib/python3.8/site-packages/transformers/models/longformer/modeling_longformer.py", line 1854, in forward
global_attention_mask = torch.zeros_like(input_ids)
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not list
Code
What have you tried?
This is my current implementation, which subclasses `TextClassificationDataModule`:
import functools
from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizerBase
from lightning_transformers.task.nlp.text_classification import TextClassificationDataModule
class PrismTextClassificationDataModule(TextClassificationDataModule):
    """Text-classification data module backed by pre-tokenized CSV files.

    Each CSV must contain an ``InputIds`` column (token ids) and a ``Label``
    column. They are renamed to ``input_ids`` / ``labels`` because the model
    is called with the batch unpacked as keyword arguments
    (``self.model(**batch)``).
    """

    def __init__(self, cfg, tokenizer, **kwargs):
        """Store the CSV locations that :meth:`process_data` reads.

        Args:
            cfg: Data-module config, forwarded to the parent.
            tokenizer: Tokenizer, forwarded to the parent. It is not used by
                :meth:`process_data`, which reads already-tokenized data.
            **kwargs: Optional ``train_file`` / ``validation_file`` /
                ``test_file`` paths. Each one that is missing falls back to the
                original hard-coded location.
        """
        # The parent is called as (tokenizer, cfg): the reverse of this
        # method's (cfg, tokenizer) parameter order.
        super().__init__(tokenizer, cfg)
        self.train_file = kwargs.get(
            "train_file",
            "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/train.csv",
        )
        self.validation_file = kwargs.get(
            "validation_file",
            "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/val.csv",
        )
        self.test_file = kwargs.get(
            "test_file",
            "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/test.csv",
        )
        self.data_files = {
            "train": self.train_file,
            "validation": self.validation_file,
            # The test split is deliberately not loaded.
            # "test": self.test_file,
        }

    @staticmethod
    def _parse_input_ids(batch):
        """Turn serialized token-id strings in ``batch["input_ids"]`` into int lists."""
        import json

        # CSV cannot hold list values, so each row's ids come back as text such
        # as "[101, 2023, 102]". Values that are already lists pass through.
        # NOTE(review): assumes comma-separated (Python/JSON list) syntax.
        # NumPy-style "[101 2023 102]" would need a different parser.
        return {
            "input_ids": [
                json.loads(ids) if isinstance(ids, str) else ids
                for ids in batch["input_ids"]
            ]
        }

    def process_data(self, dataset, stage) -> Dataset:
        """Load the CSV splits and format them as PyTorch tensors.

        ``dataset`` and ``stage`` are ignored: the splits are always re-read
        from ``self.data_files``.

        Returns:
            The loaded splits, keyed by split name, with ``input_ids`` and
            ``labels`` formatted for PyTorch.
        """
        ds = load_dataset("csv", data_files=self.data_files)
        ds = ds.rename_column("InputIds", "input_ids")
        ds = ds.rename_column("Label", "labels")
        # Without this step the ids stay strings. The torch formatter cannot
        # turn them into tensors, and the model receives a Python list
        # ("zeros_like(): argument 'input' ... must be Tensor, not list").
        ds = ds.map(self._parse_input_ids, batched=True)
        # NOTE(review): the default collate can only stack equal-length id
        # sequences, so the rows are assumed to be pre-padded. Confirm.
        ds.set_format("pytorch", columns=["input_ids", "labels"])
        # The parent's ``num_classes`` reads ``self.labels.num_classes``.
        # NOTE(review): that attribute exists only if "labels" is a ClassLabel
        # feature. An integer CSV column loads as a plain Value, so confirm or
        # cast it.
        self.labels = ds["train"].features["labels"]
        return ds
What's your environment?
- OS: Linux Mint 19.3
- Conda (environment.yml):
name: pml
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=1_llvm
- abseil-cpp=20210324.2=h9c3ff4c_0
- absl-py=0.13.0=py38h06a4308_0
- aiohttp=3.8.1=py38h7f8727e_0
- aiosignal=1.2.0=pyhd3eb1b0_0
- arrow-cpp=3.0.0=py38h6b21186_4
- async-timeout=4.0.1=pyhd3eb1b0_0
- attrs=21.2.0=pyhd3eb1b0_0
- aws-c-common=0.4.57=he6710b0_1
- aws-c-event-stream=0.1.6=h2531618_5
- aws-checksums=0.1.9=he6710b0_0
- aws-sdk-cpp=1.8.185=hce553d0_0
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- blinker=1.4=py38h06a4308_0
- boost-cpp=1.69.0=h11c811c_1000
- boto3=1.18.21=pyhd3eb1b0_0
- botocore=1.21.41=pyhd3eb1b0_1
- brotli=1.0.9=h7f98852_6
- brotli-bin=1.0.9=h7f98852_6
- brotlipy=0.7.0=py38h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- c-ares=1.17.1=h27cfd23_0
- ca-certificates=2021.10.8=ha878542_0
- certifi=2021.10.8=py38h578d9bd_1
- cffi=1.14.6=py38h400218f_0
- cryptography=3.4.8=py38hd23ed53_0
- cudatoolkit=11.1.1=h6406543_9
- dataclasses=0.8=pyh6d0b6a4_7
- datasets=1.16.1=pyhd8ed1ab_0
- debugpy=1.5.1=py38h295c915_0
- decorator=5.1.0=pyhd3eb1b0_0
- dill=0.3.4=pyhd8ed1ab_0
- double-conversion=3.1.6=h9c3ff4c_0
- ffmpeg=4.2.2=h20bf706_0
- filelock=3.4.0=pyhd8ed1ab_0
- freetype=2.11.0=h70c0345_0
- frozenlist=1.2.0=py38h7f8727e_0
- fsspec=2021.10.1=pyhd3eb1b0_0
- future=0.18.2=py38_1
- gflags=2.2.2=he1b5a44_1004
- giflib=5.2.1=h7b6447c_0
- glog=0.5.0=h48cff8f_0
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- grpc-cpp=1.39.0=hae934f6_5
- grpcio=1.42.0=py38hce63b2e_0
- huggingface_hub=0.2.1=pyhd8ed1ab_0
- icu=58.2=hf484d3e_1000
- idna=3.3=pyhd3eb1b0_0
- importlib-metadata=4.8.2=py38h06a4308_0
- importlib_metadata=4.8.2=hd8ed1ab_0
- intel-openmp=2021.4.0=h06a4308_3561
- ipython=7.29.0=py38hb070fc8_0
- ipython_genutils=0.2.0=pyhd3eb1b0_1
- jmespath=0.10.0=pyhd3eb1b0_0
- joblib=1.1.0=pyhd3eb1b0_0
- jpeg=9d=h7f8727e_0
- jupyter_client=7.0.6=pyhd3eb1b0_0
- jupyter_core=4.9.1=py38h06a4308_0
- krb5=1.19.2=hcc1bbae_3
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libboost=1.73.0=h3ff78a5_11
- libbrotlicommon=1.0.9=h7f98852_6
- libbrotlidec=1.0.9=h7f98852_6
- libbrotlienc=1.0.9=h7f98852_6
- libcurl=7.78.0=h0b77cf5_0
- libedit=3.1.20210910=h7f8727e_0
- libev=4.33=h516909a_1
- libevent=2.1.10=h9b69904_4
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1d223b6_11
- libidn2=2.3.2=h7f8727e_0
- libnghttp2=1.43.0=h812cca2_0
- libopus=1.3.1=h7b6447c_0
- libpng=1.6.37=hbc83047_0
- libprotobuf=3.17.2=h4ff587b_1
- libsodium=1.0.18=h7b6447c_0
- libssh2=1.10.0=ha56f1ee_2
- libstdcxx-ng=11.2.0=he4da1e4_11
- libtasn1=4.16.0=h27cfd23_0
- libthrift=0.14.2=hcc01f38_0
- libtiff=4.2.0=h85742a9_0
- libunistring=0.9.10=h27cfd23_0
- libuv=1.40.0=h7b6447c_0
- libvpx=1.7.0=h439df22_0
- libwebp=1.2.0=h89dd481_0
- libwebp-base=1.2.0=h27cfd23_0
- llvm-openmp=12.0.1=h4bd325d_1
- lz4-c=1.9.3=h295c915_1
- markdown=3.3.4=py38h06a4308_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h7f8727e_0
- mkl_fft=1.3.1=py38hd3c417c_0
- mkl_random=1.2.2=py38h51133e4_0
- multidict=5.1.0=py38h27cfd23_2
- multiprocess=0.70.12.2=py38h497a2fe_1
- ncurses=6.3=h7f8727e_2
- nest-asyncio=1.5.1=pyhd3eb1b0_0
- nettle=3.7.3=hbbd107a_1
- numpy-base=1.21.2=py38h79a1101_0
- oauthlib=3.1.1=pyhd3eb1b0_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.0=hd408876_0
- openssl=1.1.1l=h7f98852_0
- orc=1.6.9=ha97a36c_3
- packaging=21.3=pyhd3eb1b0_0
- parso=0.8.2=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pip=21.2.4=py38h06a4308_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pyasn1=0.4.8=pyhd3eb1b0_0
- pycparser=2.21=pyhd3eb1b0_0
- pydeprecate=0.3.1=pyhd8ed1ab_0
- pygments=2.10.0=pyhd3eb1b0_0
- pyparsing=3.0.4=pyhd3eb1b0_0
- pysocks=1.7.1=py38h06a4308_0
- python=3.8.12=h12debd9_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python-xxhash=2.0.2=py38h497a2fe_1
- python_abi=3.8=2_cp38
- pytorch=1.10.0=py3.8_cuda11.1_cudnn8.0.5_0
- pytorch-lightning=1.5.5=pyhd8ed1ab_0
- pytorch-mutex=1.0=cuda
- pytz=2021.3=pyhd8ed1ab_0
- pyyaml=6.0=py38h7f8727e_1
- pyzmq=22.3.0=py38h295c915_2
- re2=2021.11.01=h9c3ff4c_0
- readline=8.1=h27cfd23_0
- regex=2021.8.3=py38h7f8727e_0
- requests=2.26.0=pyhd3eb1b0_0
- requests-oauthlib=1.3.0=py_0
- rsa=4.7.2=pyhd3eb1b0_1
- s3transfer=0.5.0=pyhd3eb1b0_0
- sacremoses=0.0.43=pyhd3eb1b0_0
- setuptools=58.0.4=py38h06a4308_0
- six=1.16.0=pyhd3eb1b0_0
- snappy=1.1.8=he1b5a44_3
- sqlite=3.36.0=hc218d9a_0
- tk=8.6.11=h1ccaba5_0
- tokenizers=0.10.3=py38hb63a372_1
- torchaudio=0.10.0=py38_cu111
- torchmetrics=0.6.1=pyhd8ed1ab_0
- torchvision=0.11.1=py38_cu111
- tornado=6.1=py38h27cfd23_0
- tqdm=4.62.3=pyhd3eb1b0_1
- traitlets=5.1.1=pyhd3eb1b0_0
- transformers=4.11.3=pyhd8ed1ab_0
- typing-extensions=3.10.0.2=hd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- uriparser=0.9.5=h9c3ff4c_0
- utf8proc=2.6.1=h27cfd23_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- werkzeug=2.0.2=pyhd3eb1b0_0
- wheel=0.37.0=pyhd3eb1b0_1
- x264=1!157.20191217=h7b6447c_0
- xxhash=0.8.0=h7f98852_3
- xz=5.2.5=h7b6447c_0
- yaml=0.2.5=h7b6447c_0
- yarl=1.6.3=py38h27cfd23_0
- zeromq=4.3.4=h2531618_0
- zipp=3.6.0=pyhd3eb1b0_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
- pip:
- adal==1.2.7
- antlr4-python3-runtime==4.8
- applicationinsights==0.11.10
- astunparse==1.6.3
- azure-common==1.1.27
- azure-core==1.20.1
- azure-graphrbac==0.61.1
- azure-identity==1.7.0
- azure-mgmt-authorization==0.61.0
- azure-mgmt-containerregistry==8.2.0
- azure-mgmt-core==1.3.0
- azure-mgmt-keyvault==9.3.0
- azure-mgmt-resource==13.0.0
- azure-mgmt-storage==11.2.0
- azureml-core==1.36.0.post2
- azureml-dataprep==2.24.4
- azureml-dataprep-native==38.0.0
- azureml-dataprep-rslex==2.0.3
- azureml-dataset-runtime==1.36.0
- azureml-defaults==1.36.0
- azureml-inference-server-http==0.4.2
- azureml-mlflow==1.36.0
- azureml-telemetry==1.36.0
- backports-tempfile==1.0
- backports-weakref==1.0.post1
- cachetools==4.2.4
- charset-normalizer==2.0.7
- click==8.0.3
- cloudpickle==2.0.0
- configparser==3.7.4
- contextlib2==21.6.0
- cycler==0.11.0
- databricks-cli==0.16.2
- deepspeed==0.5.8
- distro==1.6.0
- docker==5.0.3
- dotnetcore2==2.1.21
- entrypoints==0.3
- flask==1.0.3
- flatbuffers==2.0
- fonttools==4.28.1
- fusepy==3.0.1
- gast==0.4.0
- gitdb==4.0.9
- gitpython==3.1.24
- google-auth==2.3.3
- google-auth-oauthlib==0.4.6
- google-pasta==0.2.0
- gunicorn==20.1.0
- h5py==3.6.0
- hjson==3.0.2
- horovod==0.23.0
- hydra-core==1.1.0
- importlib-resources==5.4.0
- inference-schema==1.3.0
- ipykernel==6.5.1
- isodate==0.6.0
- itsdangerous==2.0.1
- jedi==0.18.1
- jeepney==0.7.1
- jinja2==3.0.3
- json-logging-py==0.2
- jsonpickle==2.0.0
- keras==2.7.0
- keras-preprocessing==1.1.2
- kiwisolver==1.3.2
- libclang==12.0.0
- lightgbm==3.3.1
- markupsafe==2.0.1
- matplotlib==3.5.0
- matplotlib-inline==0.1.3
- mlflow-skinny==1.21.0
- msal==1.16.0
- msal-extensions==0.3.0
- msrest==0.6.21
- msrestazure==0.6.4
- ndg-httpsclient==0.5.1
- ninja==1.10.2.3
- numpy==1.21.4
- omegaconf==2.1.1
- onnxruntime-gpu==1.9.0
- opt-einsum==3.3.0
- pandas==1.3.4
- pathspec==0.9.0
- pillow==8.4.0
- plotly==5.4.0
- portalocker==1.7.1
- prompt-toolkit==3.0.22
- protobuf==3.19.1
- psutil==5.8.0
- pyarrow==3.0.0
- pyasn1-modules==0.2.8
- pyjwt==2.3.0
- pyopenssl==20.0.1
- scikit-learn==1.0.1
- scipy==1.7.2
- secretstorage==3.3.1
- setuptools-scm==6.3.2
- smmap==5.0.0
- tabulate==0.8.9
- tenacity==8.0.1
- tensorboard==2.7.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.0
- tensorflow-estimator==2.7.0
- tensorflow-gpu==2.7.0
- tensorflow-io-gcs-filesystem==0.22.0
- termcolor==1.1.0
- threadpoolctl==3.0.0
- tomli==1.2.2
- torch-tb-profiler==0.3.1
- triton==1.1.1
- urllib3==1.26.7
- websocket-client==1.2.1
- wrapt==1.13.3
prefix: /anaconda/envs/pml
requirements.txt:
adal==1.2.7
antlr4-python3-runtime==4.8
applicationinsights==0.11.10
astunparse==1.6.3
azure-common==1.1.27
azure-core==1.20.1
azure-graphrbac==0.61.1
azure-identity==1.7.0
azure-mgmt-authorization==0.61.0
azure-mgmt-containerregistry==8.2.0
azure-mgmt-core==1.3.0
azure-mgmt-keyvault==9.3.0
azure-mgmt-resource==13.0.0
azure-mgmt-storage==11.2.0
azureml-core==1.36.0.post2
azureml-dataprep==2.24.4
azureml-dataprep-native==38.0.0
azureml-dataprep-rslex==2.0.3
azureml-dataset-runtime==1.36.0
azureml-defaults==1.36.0
azureml-inference-server-http==0.4.2
azureml-mlflow==1.36.0
azureml-telemetry==1.36.0
backports-tempfile==1.0
backports-weakref==1.0.post1
cachetools==4.2.4
charset-normalizer==2.0.7
click==8.0.3
cloudpickle==2.0.0
configparser==3.7.4
contextlib2==21.6.0
cycler==0.11.0
databricks-cli==0.16.2
deepspeed==0.5.8
distro==1.6.0
docker==5.0.3
dotnetcore2==2.1.21
entrypoints==0.3
flask==1.0.3
flatbuffers==2.0
fonttools==4.28.1
fusepy==3.0.1
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
google-auth==2.3.3
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
gunicorn==20.1.0
h5py==3.6.0
hjson==3.0.2
horovod==0.23.0
hydra-core==1.1.0
importlib-resources==5.4.0
inference-schema==1.3.0
ipykernel==6.5.1
isodate==0.6.0
itsdangerous==2.0.1
jedi==0.18.1
jeepney==0.7.1
jinja2==3.0.3
json-logging-py==0.2
jsonpickle==2.0.0
keras==2.7.0
keras-preprocessing==1.1.2
kiwisolver==1.3.2
libclang==12.0.0
lightgbm==3.3.1
markupsafe==2.0.1
matplotlib==3.5.0
matplotlib-inline==0.1.3
mlflow-skinny==1.21.0
msal==1.16.0
msal-extensions==0.3.0
msrest==0.6.21
msrestazure==0.6.4
ndg-httpsclient==0.5.1
ninja==1.10.2.3
numpy==1.21.4
omegaconf==2.1.1
onnxruntime-gpu==1.9.0
opt-einsum==3.3.0
pandas==1.3.4
pathspec==0.9.0
pillow==8.4.0
plotly==5.4.0
portalocker==1.7.1
prompt-toolkit==3.0.22
protobuf==3.19.1
psutil==5.8.0
pyarrow==3.0.0
pyasn1-modules==0.2.8
pyjwt==2.3.0
pyopenssl==20.0.1
scikit-learn==1.0.1
scipy==1.7.2
secretstorage==3.3.1
setuptools-scm==6.3.2
smmap==5.0.0
tabulate==0.8.9
tenacity==8.0.1
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow-estimator==2.7.0
tensorflow-gpu==2.7.0
tensorflow-io-gcs-filesystem==0.22.0
termcolor==1.1.0
threadpoolctl==3.0.0
tomli==1.2.2
torch-tb-profiler==0.3.1
triton==1.1.1
urllib3==1.26.7
websocket-client==1.2.1
wrapt==1.13.3
deepspeed==0.5.8
jmwoloso commented
After some playing around, I figured this out. I needed to implement all of the methods and properties that the normal data modules have. My text was processed ahead of time, so I just needed to return the dict with the `input_ids`.
My class implementation is below:
import functools
import json
from typing import Dict
from datasets import Dataset, load_dataset, set_caching_enabled, ClassLabel
import pandas as pd
import numpy as np
import torch.multiprocessing
from lightning_transformers.task.nlp.text_classification import TextClassificationDataModule
# Workaround for "too many open files" errors when tensors are shared between
# processes (e.g. DataLoader workers). This switches torch.multiprocessing to
# the "file_system" sharing strategy.
# NOTE: this runs at import time as a module-level side effect, so it applies
# process-wide as soon as this module is imported.
# https://github.com/pytorch/pytorch/issues/11201
# too many open files error
torch.multiprocessing.set_sharing_strategy("file_system")
class MyTextClassificationDataModule(TextClassificationDataModule):
    """Text-classification data module for already-tokenized Parquet files.

    The train/validation files named by ``self.cfg.train_file`` and
    ``self.cfg.validation_file`` must provide an ``input_ids`` column and a
    ``label`` column. No tokenization is performed here.
    """

    def process_data(self, dataset, stage):
        """Load the Parquet splits and format them for the model.

        ``dataset`` and ``stage`` are ignored: the splits are always read from
        the files named in ``self.cfg``.

        Returns:
            The splits (keyed "train" / "validation") with ``input_ids`` and
            ``labels`` formatted as PyTorch tensors.
        """
        dataset = Dataset.from_parquet(
            {"train": self.cfg.train_file, "validation": self.cfg.validation_file},
            columns=["input_ids", "label"],
        )
        dataset = self.preprocess(dataset)
        dataset.set_format("pytorch", columns=["input_ids", "labels"])
        # Read by ``num_classes`` below.
        # NOTE(review): ``.num_classes`` exists only if the Parquet "label"
        # column is a ClassLabel feature. Confirm.
        self.labels = dataset["train"].features["labels"]
        return dataset

    @property
    def num_classes(self) -> int:
        """Number of target classes, taken from the ``labels`` feature."""
        return self.labels.num_classes

    @property
    def model_data_kwargs(self) -> Dict[str, int]:
        """Data-derived keyword arguments (``num_labels``).

        Presumably forwarded to the model constructor by the training script.
        """
        return {"num_labels": self.num_classes}

    @staticmethod
    def convert_to_features(example_batch, input_feature_fields, **fn_kwargs):
        """Return the batch's pre-computed ``input_ids`` unchanged.

        :meth:`preprocess` maps this with ``with_indices=True``. As a result,
        the second positional argument actually receives the batch's row
        indices, not feature-field names. The parameter name is kept only so
        existing callers keep working. ``fn_kwargs`` is accepted and ignored.
        """
        return {"input_ids": example_batch["input_ids"]}

    @staticmethod
    def preprocess(ds, **fn_kwargs):
        """Run :meth:`convert_to_features` over ``ds`` and rename "label" to "labels".

        "labels" is the keyword name the model is called with.
        """
        ds = ds.map(
            # todo: change this to self.convert_to_features for users to override
            MyTextClassificationDataModule.convert_to_features,
            batched=True,
            with_indices=True,
            fn_kwargs=fn_kwargs,
        )
        # ``rename_column_`` (in place) is deprecated and was removed in
        # datasets 2.0. ``rename_column`` returns a new object instead.
        ds = ds.rename_column("label", "labels")
        return ds