aws-neuron/aws-neuron-sdk

Internal tensorizer error when trying to compile and train a simple CNN

sgaseretto opened this issue · 3 comments

I replaced the MLP from this example with a CNN and I'm getting an Internal tensorizer error when trying to run it. Here are the scripts:

model.py:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Declare 3-layer MLP for MNIST dataset
class MLP(nn.Module):
    def __init__(self, input_size=28 * 28, output_size=10, layers=[120, 84]):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, layers[0])
        self.fc2 = nn.Linear(layers[0], layers[1])
        self.fc3 = nn.Linear(layers[1], output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)


# PyTorch models inherit from torch.nn.Module
class CnnClassifier(nn.Module):
    def __init__(self):
        super(CnnClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
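
For reference, the 16 * 4 * 4 flatten size comes from the 28x28 MNIST inputs (28 -> 24 after conv1, 12 after the first pool, 8 after conv2, 4 after the second pool). A quick CPU-only sanity check, just for illustration and not part of the scripts above:

import torch
from model import CnnClassifier

# random batch shaped like MNIST: (batch, channels, height, width)
x = torch.randn(2, 1, 28, 28)
model = CnnClassifier()
print(model(x).shape)  # torch.Size([2, 10]), so the shapes themselves work fine on CPU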

train_xmp.py:

import os
import time
import torch
from model import MLP, CnnClassifier

from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

# XLA imports
import torch_xla.core.xla_model as xm
# XLA imports for parallel loader and multi-processing
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data.distributed import DistributedSampler

# Global constants
EPOCHS = 4
WARMUP_STEPS = 2
BATCH_SIZE = 32
# MODEL_ARCHITECTURE = 'mlp' 
MODEL_ARCHITECTURE = 'cnn' 

# Load MNIST train dataset
train_dataset = mnist.MNIST(root='./MNIST_DATA_train',
                            train=True, download=True, transform=ToTensor())

def main(index):
    # torch.set_default_tensor_type('torch.FloatTensor')
    # XLA MP: get world size
    world_size = xm.xrt_world_size()
    # multi-processing: ensure each worker has same initial weights
    torch.manual_seed(0)
    # Move model to device and declare optimizer and loss function
    device = 'xla'
    model = MLP().to(device) if MODEL_ARCHITECTURE == 'mlp' else CnnClassifier().to(device)
    # For multiprocessing, scale up learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * world_size)
    loss_fn = torch.nn.NLLLoss() if MODEL_ARCHITECTURE == 'mlp' else torch.nn.CrossEntropyLoss()

    # Prepare data loader
    train_sampler = None
    if world_size > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=world_size,
                                           rank=xm.get_ordinal(),
                                           shuffle=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=train_sampler,
                              shuffle=False if train_sampler else True)
    # XLA MP: use MpDeviceLoader from torch_xla.distributed
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    # Run the training loop
    print(f'----------Training {MODEL_ARCHITECTURE}---------------')
    model.train()
    for epoch in range(EPOCHS):
        start = time.time()
        for idx, (train_x, train_label) in enumerate(train_device_loader):
            optimizer.zero_grad()
            if MODEL_ARCHITECTURE == 'mlp':
                train_x = train_x.view(train_x.size(0), -1)
            train_x = train_x.to(device)
            output = model(train_x)
            loss = loss_fn(output, train_label)
            loss.backward()
            xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step
            if idx < WARMUP_STEPS: # skip warmup iterations
                start = time.time()

    # Compute statistics for the last epoch
    interval = idx - WARMUP_STEPS # skip warmup iterations
    throughput = interval / (time.time() - start)
    print("Train throughput (iter/sec): {}".format(throughput))
    print("Final loss is {:0.4f}".format(loss.detach().to('cpu')))

    # Save checkpoint for evaluation (xm.save ensures only one process save)
    os.makedirs("checkpoints", exist_ok=True)
    checkpoint = {'state_dict': model.state_dict()}
    xm.save(checkpoint,'checkpoints/checkpoint.pt')

    print('----------End Training ---------------')

if __name__ == '__main__':
    xmp.spawn(main)

When executing it with python train_xmp.py, I get this error log:

WARNING:root:MASTER_ADDR not setting, defaulting to localhost
----------Training cnn---------------
2024-04-30 21:07:46.000985:  136004  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-04-30 21:07:46.000986:  136004  ERROR ||NEURON_CC_WRAPPER||: Got a cached failed neff at /var/tmp/neuron-compile-cache/neuronxcc-2.13.72.0+78a426937/MODULE_13170552691779535425+d41d8cd9/model.neff. Will skip compilation, please set --retry_failed_compilation for recompilation: 
 Failed compilation with ['neuronx-cc', 'compile', '--target=trn1', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/0a25a9a8-cf7f-4aa9-a289-7feed2d422d2/model.MODULE_13170552691779535425+d41d8cd9.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/0a25a9a8-cf7f-4aa9-a289-7feed2d422d2/model.MODULE_13170552691779535425+d41d8cd9.neff', '--verbose=35']: 2024-04-30T20:50:00Z [TEN404] Internal tensorizer error - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new
.
2024-04-30 21:07:47.013728: F ./torch_xla/csrc/runtime/debug_macros.h:20] Non-OK-status: status.status() status: INTERNAL: RunNeuronCCImpl: error condition error != 0: <class 'subprocess.CalledProcessError'>: Command '' died with <Signals.SIGHUP: 1>.
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > ConsumeValue<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >(absl::lts_20230125::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >&&)
	torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
	torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, absl::lts_20230125::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
	torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230125::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
	torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230125::Span<std::string const>, bool, bool, bool)
	torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::string>, bool)
	
	
	
	PyCFunction_Call
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	PyObject_Call
	
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	PyObject_Call
	
	
	
	clone
*** End stack trace ***

Traceback (most recent call last):
  File "train_xmp.py", line 86, in <module>
    xmp.spawn(main)
  File "/home/ubuntu/aws_neuronx_venv_pytorch/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/aws_neuronx_venv_pytorch/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/ubuntu/aws_neuronx_venv_pytorch/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 202, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/ubuntu/aws_neuronx_venv_pytorch/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/aws_neuronx_venv_pytorch/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 159, in run_multiprocess
    replica_results = list(
  File "/home/ubuntu/aws_neuronx_venv_pytorch/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 160, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

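The log also shows the compiler reusing the cached failed NEFF on subsequent runs. Assuming NEURON_CC_FLAGS is the right way to pass the --retry_failed_compilation flag the log suggests, recompilation can be forced with something like this at the top of train_xmp.py (before the torch_xla imports):

import os

# force a fresh compile instead of reusing the cached failed NEFF
os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --retry_failed_compilation"
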
These are my package versions:

aws-neuronx-runtime-discovery==2.9
libneuronxla==2.0.965
neuronx-cc==2.13.72.0+78a426937
torch==2.1.2
torch-neuronx==2.1.2.2.1.0
torch-xla==2.1.2
torchvision==0.16.2

I don't think I'm using anything that is not supported, since this is basically the CNN used in PyTorch's FashionMNIST example.

Thank you for reporting the issue. We are able to reproduce the issue and have started looking into it.

I also hit the same issue when modifying train_xmp.py to train the same CNN.

Ran into the same issue - it seems like it only compiles for batch_size = 25 (and multiples of 25), which somehow corresponds to the 5x5 conv region.
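
For reference, reproducing that observation only takes changing the batch size constant near the top of train_xmp.py, e.g.:

BATCH_SIZE = 25  # reported to compile; the default of 32 hits the TEN404 internal tensorizer error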