aws-neuron/aws-neuron-sdk

BERT model implemented using TransformerEncoder returns all NaNs when running it on torch==1.13.1

sgaseretto opened this issue · 3 comments

I have some scripts to train a BERT model on WikiText2. When testing them on both torch 2.1.2 and 1.13.1, the scripts compile and execute successfully. The issue is that when running bert_vanilla.py on torch==1.13.1, every element of the output returned by the TransformerEncoder is NaN, so the loss comes out as NaN too. Here are the scripts and the output when training:

bert_utils.py

from typing import Tuple, Any, List, Dict, OrderedDict

import math
import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import os

# PositionalEncoding class from transformer_tutorial.py
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# TransformerModel class from transformer_tutorial.py
class TransformerModel(nn.Module):
    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(src.device)
        has_nans = torch.isnan(src_mask).any().item()
        # if has_nans:
        #     print("src_mask contains NaNs")
        # print("ALL GOOD")
        output = self.transformer_encoder(src, src_mask) # returns NaN on Neuron
        # print("OUTPUT BEFORE LINEAR:", output)
        output = self.linear(output)
        # print("RETURNED OUTPUT AFTER LINEAR:", output)
        return output

# Function to generate input and target sequence
def get_batch(source: Tensor, i: int, bptt: int) -> Tuple[Tensor, Tensor]:
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')

# Other constants and hyperparameters
bptt = 35

bert_xla.py

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import BertTokenizer, get_scheduler
from datasets import load_dataset
from tqdm.auto import tqdm
from bert_utils import TransformerModel, get_batch, bptt
import math
import time
import os

# XLA imports
import torch_xla.core.xla_model as xm

# Constants
EPOCHS = 6
BATCH_SIZE = 8

# Load the dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
train_dataset = dataset["train"]

# Tokenize and encode the dataset using HF tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

def tokenize_and_encode(batch):
    return tokenizer(batch['text'], return_tensors='pt', truncation=True, padding='max_length', max_length=512)

tokenized_dataset = train_dataset.map(tokenize_and_encode, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type='torch', columns=['input_ids'])

# Downsample the dataset
sample_percentage = 0.01
num_samples = int(len(tokenized_dataset) * sample_percentage)
tokenized_dataset = tokenized_dataset.shuffle(seed=42).select(range(num_samples))

# Create DataLoader
train_dl = DataLoader(tokenized_dataset, shuffle=True, batch_size=BATCH_SIZE)

# XLA: Specify XLA device (defaults to a NeuronCore on Trn1 instance)
device = 'xla'
print(f"DEVICE={device}")

# Model configuration
ntokens = tokenizer.vocab_size
emsize = 200
d_hid = 200
nlayers = 2
nhead = 2
dropout = 0.2
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)

# Training setup
criterion = torch.nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1.45e-4)

num_training_steps = EPOCHS * len(train_dl)
progress_bar = tqdm(range(num_training_steps))
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

# Training loop
model.train()
for epoch in range(EPOCHS):
    for batch in train_dl:
        input_ids = batch['input_ids'].squeeze(1).to(device)
        targets = input_ids.clone().detach()
        outputs = model(input_ids)
        outputs_flat = outputs.view(-1, ntokens)
        loss = criterion(outputs_flat, targets.view(-1))
        loss.backward()
        optimizer.step()
        xm.mark_step()  # XLA: collect ops and run them in XLA runtime
        optimizer.zero_grad()
        lr_scheduler.step()
        progress_bar.update(1)
        

    print(f"Epoch {epoch}, Loss: {loss.detach().to('cpu')}")

print("Training complete")

# Save checkpoint for evaluation
os.makedirs("checkpoints", exist_ok=True)
checkpoint = {'state_dict': model.state_dict()}
xm.save(checkpoint, 'checkpoints/checkpoint.pt')

You can see the loss being computed as NaN in the training log here:

DEVICE=xla

  0%|          | 0/276 [00:00<?, ?it/s]2024-06-10 22:55:46.000069:  7023  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-06-10 22:55:46.000070:  7023  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.13.72.0+78a426937/MODULE_14596765104133435291+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
2024-06-10 22:55:46.000379:  7049  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-06-10 22:55:46.000473:  7049  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.13.72.0+78a426937/MODULE_1300752231203242943+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
  0%|          | 1/276 [00:00<03:07,  1.47it/s]2024-06-10 22:55:48.000863:  7078  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-06-10 22:55:48.000956:  7078  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.13.72.0+78a426937/MODULE_15046888644701568097+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
 16%|█▋        | 45/276 [00:10<00:32,  7.21it/s]2024-06-10 22:55:57.000027:  7979  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-06-10 22:55:57.000028:  7979  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.13.72.0+78a426937/MODULE_11440727015853899450+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
2024-06-10 22:55:57.000340:  8001  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-06-10 22:55:57.000449:  8001  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.13.72.0+78a426937/MODULE_65065912177798273+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
 17%|█▋        | 46/276 [00:11<01:10,  3.28it/s]Epoch 0, Loss: nan
 33%|███▎      | 92/276 [00:20<00:25,  7.19it/s]Epoch 1, Loss: nan
 50%|█████     | 138/276 [00:26<00:19,  7.22it/s]Epoch 2, Loss: nan
 67%|██████▋   | 184/276 [00:32<00:12,  7.18it/s]Epoch 3, Loss: nan
 83%|████████▎ | 230/276 [00:39<00:06,  7.17it/s]Epoch 4, Loss: nan
100%|██████████| 276/276 [00:45<00:00,  7.20it/s]Epoch 5, Loss: nan
Training complete

100%|██████████| 276/276 [00:46<00:00,  5.97it/s]

When running it on torch==2.1.2, I don't encounter this issue. However, I'm porting code implemented in PyTorch Lightning using neuronx-distributed, and torch==1.13.1 is the recommended version for using NeuronLTModule (see this issue), so I run into the same problem there. Here is also the code and the log for you to reproduce the issue:

import os
from pathlib import Path
from typing import Tuple, Dict
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, LightningDataModule
from torch.utils.data import DataLoader, Dataset, Subset
from torch.optim import AdamW
from transformers import get_scheduler
from bert_utils import TransformerModel
import random

# Neuron imports
from neuronx_distributed.lightning import (
    NeuronLTModule,
    NeuronTensorBoardLogger,
    NeuronXLAPrecisionPlugin,
    NeuronXLAStrategy,
    NeuronTQDMProgressBar
)
import neuronx_distributed.parallel_layers.parallel_state as parallel_state
import neuronx_distributed as nxd
from neuronx_distributed import initialize_parallel_optimizer
import torch_xla.core.xla_model as xm

class WikiText2Dataset(Dataset):
    def __init__(self, data_dir: Path, block_size: int = 35) -> None:
        self.path = data_dir / "wikitext-2.txt"
        self.data, self.dictionary = self.tokenize(self.path)
        self.block_size = block_size
        # Ensure the dataset is not empty
        if len(self.data) == 0:
            raise RuntimeError("Tokenization resulted in an empty dataset.")

    @property
    def vocab_size(self) -> int:
        return len(self.dictionary)

    def __len__(self) -> int:
        return len(self.data) // self.block_size - 1

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        start = index * self.block_size
        end = start + self.block_size
        inputs = self.data[start:end]
        target = self.data[(start + 1): (end + 1)]
        return inputs, target

    @staticmethod
    def download(destination: Path) -> None:
        import requests
        os.makedirs(destination.parent, exist_ok=True)
        url = "https://raw.githubusercontent.com/pytorch/examples/main/word_language_model/data/wikitext-2/train.txt"
        if not os.path.exists(destination):
            with open(destination, "w") as f:
                f.write(requests.get(url).text)

    @staticmethod
    def tokenize(path: Path) -> Tuple[torch.Tensor, Dict[str, int]]:
        dictionary = {}
        idx2word = []

        def add_word(word: str) -> int:
            if word not in dictionary:
                idx2word.append(word)
                dictionary[word] = len(idx2word) - 1
            return dictionary[word]

        data = []
        with open(path, encoding="utf8") as f:
            for line in f:
                words = line.split() + ["<eos>"]
                for word in words:
                    add_word(word)

        with open(path, encoding="utf8") as f:
            for line in f:
                words = line.split() + ["<eos>"]
                ids = [dictionary[word] for word in words]
                data.append(torch.tensor(ids, dtype=torch.long))

        # Ensure the data list is not empty
        if len(data) == 0:
            raise RuntimeError(f"No data found in the tokenization of file: {path}")

        return torch.cat(data), dictionary

class WikiText2DataModule(LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int = 8, block_size: int = 35, sample_percentage: float = 1.0):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.block_size = block_size
        self.sample_percentage = sample_percentage

    def prepare_data(self):
        # Download the dataset if it doesn't exist
        WikiText2Dataset.download(self.data_dir / "wikitext-2.txt")

    def setup(self, stage=None):
        dataset = WikiText2Dataset(self.data_dir, block_size=self.block_size)
        if self.sample_percentage < 1.0:
            num_samples = int(len(dataset) * self.sample_percentage)
            indices = list(range(len(dataset)))
            random.seed(42)
            random.shuffle(indices)
            sampled_indices = indices[:num_samples]
            self.train_dataset = Subset(dataset, sampled_indices)
        else:
            self.train_dataset = dataset

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

class BERTNeuronLTModule(NeuronLTModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = kwargs['model_fn'](kwargs["model_args"][0])
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters()

        # Manual optimization
        self.automatic_optimization = False

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model(input_ids)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        xm.mark_step()  # Isolate forward+backward graph
        optimizer = self.optimizers()
        scheduler = self.lr_schedulers()

        input_ids, targets = batch
        outputs = self(input_ids)
        outputs_flat = outputs.view(-1, outputs.size(-1))
        loss = self.criterion(outputs_flat, targets.view(-1))
        
        # Manual optimization
        optimizer.zero_grad()
        self.manual_backward(loss)
        xm.mark_step()  # Isolate forward+backward graph
        optimizer.step()
        scheduler.step()

        # self.log("train_loss", loss)
        # self.log_dict({"train_loss":loss}, on_step=False, on_epoch=True, prog_bar=True)
        self.log_dict({"train_loss":loss.detach().to('cpu')}, on_step=False, on_epoch=True, prog_bar=True)
        xm.mark_step()  # Isolate optimization step graph
        return loss

    def configure_optimizers(self):
        optimizer = initialize_parallel_optimizer(
            nxd_config=self.nxd_config, 
            optimizer_class=self.opt_cls,
            parameters=self.model.parameters(),
            **self.opt_kwargs
        )

        scheduler = self.scheduler_cls(
            optimizer=optimizer,
            **self.scheduler_kwargs
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step"
            }
        }

def get_model(model_config: dict):
    return TransformerModel(
        ntoken=model_config["ntoken"],
        d_model=model_config["d_model"],
        nhead=model_config["nhead"],
        d_hid=model_config["d_hid"],
        nlayers=model_config["nlayers"],
        dropout=model_config["dropout"]
    )

def main():
    data_dir = "./data"
    batch_size = 8
    block_size = 35
    sample_percentage = 0.01  # Adjust this value to change the sample size
    model_config = {
        "ntoken": 33278,  # Vocab size of WikiText2
        "d_model": 200,
        "nhead": 2,
        "d_hid": 200,
        "nlayers": 2,
        "dropout": 0.2
    }

    data_module = WikiText2DataModule(
        data_dir=data_dir, 
        batch_size=batch_size, 
        block_size=block_size,
        sample_percentage=sample_percentage
    )
    
    # Ensure the dataset is downloaded
    data_module.prepare_data()
    # Call setup explicitly to initialize the dataset
    data_module.setup('fit')

    nxd_config = nxd.neuronx_distributed_config(
        optimizer_config={"zero_one_enabled": False}
    )
    opt_cls = AdamW
    scheduler_cls = get_scheduler

    # Calculate the number of training steps
    num_training_steps = len(data_module.train_dataloader()) * 6  # Adjust this if necessary

    model = BERTNeuronLTModule(
        model_fn=get_model,
        model_args=(model_config,),
        nxd_config=nxd_config,
        opt_cls=opt_cls,
        scheduler_cls=scheduler_cls,
        opt_kwargs={"lr": 1.45e-4},
        scheduler_kwargs={
            "name":"linear",
            "num_warmup_steps": 0,
            "num_training_steps": num_training_steps
        }
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        filename='checkpoint',
        # monitor="train_loss", 
        # mode="min"
        monitor="global_step", 
        mode="max",
        verbose=True,
        save_top_k=1, 
    )

    strategy = NeuronXLAStrategy(nxd_config=nxd_config)
    callbacks = [NeuronTQDMProgressBar(), checkpoint_callback]

    trainer = Trainer(
        max_epochs=6,
        strategy=strategy,
        callbacks=callbacks,
        enable_checkpointing=True,
        enable_progress_bar=True,
    )
    trainer.fit(model, data_module)

if __name__ == "__main__":
    main()

I also tested on GPUs with both torch==2.1.2 and torch==1.13.1, and the issue only occurs when running on Neuron.
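To narrow down which submodule first emits NaNs on a given backend, one option is to attach forward hooks and check each submodule's output during a single forward pass. This is a minimal debugging sketch, not something from the thread: the helper name `find_nan_modules` is hypothetical, and it only uses the standard `nn.Module.register_forward_hook` API. On an XLA device, evaluating the NaN check inside an `if` will likely force graph execution at every hooked module, so it is meant for a one-off diagnostic run, not for training.

```python
import torch
from torch import nn


def find_nan_modules(model: nn.Module, *inputs) -> list:
    """Hypothetical helper: run one forward pass and return, in call order,
    the names of submodules whose Tensor output contains any NaN."""
    bad = []
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            # Only plain Tensor outputs are checked; tuple outputs are skipped
            if isinstance(output, torch.Tensor) and torch.isnan(output).any():
                bad.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module (empty name)
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Always detach the hooks so the model is left unchanged
        for handle in handles:
            handle.remove()
    return bad


# Example with the TransformerModel from bert_utils.py (not run here):
# print(find_nan_modules(model, input_ids))
```

The first name in the returned list is the earliest submodule whose output contained a NaN; later entries usually just inherit it.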

To run the code on Neuron, I'm using a trn1.2xlarge instance with this Hugging Face AMI. I build images for both torch 1.13.1 and 2.1.2 with this Dockerfile:

ARG NEURON_VER=2.18.2
ARG PY_VERSION=310
ARG PYTORCH_VER=2.1.2

# Use arguments to specify the base image
FROM public.ecr.aws/neuron/pytorch-training-neuronx:${PYTORCH_VER}-neuronx-py${PY_VERSION}-sdk${NEURON_VER}-ubuntu20.04

# Install the additional Python dependencies listed in requirements.txt
COPY requirements.txt requirements.txt
RUN python -m pip install -r requirements.txt

WORKDIR /opt/app/trainium

With this requirements.txt:

jsonargparse[signatures]>=4.17.0
wandb
pandas
s3fs
sagemaker
pytorch-lightning==2.1.0

I launch the containers in interactive mode; for example, for torch 1.13.1 I named the image nxd_torch1:

docker run -it --device=/dev/neuron0 --name neuron_torch_v1 -v <folder path from host>:/opt/app/trainium/src nxd_torch1

And then I'm running the script with:

python bert_xla.py

Thank you for the great example! We were able to reproduce the problem.

The most likely cause of this issue is the exact set of operations that get lowered, which differs across torch versions. The compiler receives a different graph, which may trigger different instructions to be executed. Some of these operations may cause very large (infinite) values to appear in one version of torch but not in another. One way to fix this is to tell the compiler to optimize for this type of model:

  • --model-type=transformer - This option enables compiler optimizations for transformer models. It can avoid NaN values by using more specialized implementations that are less prone to numerical issues.
  • (Optional) --enable-saturate-infinity - This option ensures that large values do not go to infinity. Infinity values can cause NaNs to be produced due to the underlying hardware operations, so this can avoid problems with large values.
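The point about infinities turning into NaNs can be illustrated on CPU with plain PyTorch. The sketch below (the helper name `demo_inf_arithmetic` is hypothetical) shows IEEE-754 cases where infinities produce NaN, and that the causal mask built in `TransformerModel.forward` via `nn.Transformer.generate_square_subsequent_mask` already contains `-inf` entries. It does not show that this mask is what produced the NaNs on Neuron; how the NeuronCore hardware handles these values is outside what this example can demonstrate.

```python
import math

import torch
from torch import nn


def demo_inf_arithmetic():
    inf = float("inf")
    # IEEE-754: these operations on infinities have no defined value and give NaN
    nan_cases = [inf - inf, 0.0 * inf, inf / inf]

    # Same causal mask as in TransformerModel.forward:
    # 0.0 on and below the diagonal, -inf above it
    mask = nn.Transformer.generate_square_subsequent_mask(3)

    # Softmax over masked scores stays finite while each row keeps a finite entry...
    ok = torch.softmax(torch.zeros(3, 3) + mask, dim=-1)
    # ...but a row that is entirely -inf comes out as NaN
    bad = torch.softmax(torch.full((1, 3), -inf), dim=-1)
    return nan_cases, mask, ok, bad


nan_cases, mask, ok, bad = demo_inf_arithmetic()
print([math.isnan(v) for v in nan_cases])  # [True, True, True]
print(mask)
print(ok)
print(bad)
```

On CPU the last line prints a row of NaNs, since softmax has no well-defined result when every input is `-inf`.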

I was able to produce non-NaN values by configuring the --model-type flag prior to compilation:

import os
os.environ['NEURON_CC_FLAGS'] = '--model-type=transformer'
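If both compiler options from the list above are wanted, or other flags are already set in the environment, a small helper can append to `NEURON_CC_FLAGS` instead of overwriting it. This is a sketch: the helper `add_neuron_cc_flags` is hypothetical, and only the two flag names come from the answer. As with the snippet above, it has to run before the first graph is compiled, for example at the top of bert_xla.py before the training loop.

```python
import os
import shlex


def add_neuron_cc_flags(*flags: str) -> str:
    """Hypothetical helper: append compiler flags to NEURON_CC_FLAGS,
    keeping any flags already set and skipping duplicates."""
    current = shlex.split(os.environ.get("NEURON_CC_FLAGS", ""))
    for flag in flags:
        if flag not in current:
            current.append(flag)
    os.environ["NEURON_CC_FLAGS"] = " ".join(current)
    return os.environ["NEURON_CC_FLAGS"]


# Must run before any XLA graph is compiled (i.e. before the first xm.mark_step())
add_neuron_cc_flags("--model-type=transformer", "--enable-saturate-infinity")
```

For the Lightning script, the same call would presumably go at the top of the file, before `main()` builds the Trainer.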

@sgaseretto, does the above recommendation solve the issue?

Hi @aws-rhsoln, yes, it solved the issue! Sorry for the late response; closing this issue now.