aws-neuron/aws-neuron-sdk

Is F.interpolate supported by neuronx-cc? Compilation fails with "Unsupported CustomCall target: ResizeBilinear"

takipipo opened this issue · 4 comments

It seems like F.interpolate is the problem. Is there any workaround for this?

I followed the instructions in this tutorial and activated the environment with
source /opt/aws_neuronx_venv_pytorch_1_13/bin/activate
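
For reference, here is a minimal isolation sketch (the module name and shapes are illustrative, not taken from my actual pipeline) that traces only a bilinear F.interpolate; it is expected to hit the same ResizeBilinear CustomCall:

import torch
import torch.nn.functional as F
import torch_neuronx

# Hypothetical minimal repro: a module whose forward is only a bilinear
# F.interpolate, which should lower to the ResizeBilinear CustomCall.
class BilinearOnly(torch.nn.Module):
    def forward(self, x):
        return F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=False)

dummy = torch.rand(1, 3, 64, 64)
# Expected to fail with "Unsupported CustomCall target: ResizeBilinear"
traced = torch_neuronx.trace(BilinearOnly().eval(), dummy)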

Env

neuronx-cc --version
NeuronX Compiler version 2.13.72.0+78a426937

Python version 3.10.12
HWM version 2.13.72.0+78a426937
NumPy version 1.25.2

Running on AMI ami-06d5ee23a754fdb52
Running in region apse1-az2
pip list
Package                       Version
----------------------------- -------------------
absl-py                       2.1.0
aiohttp                       3.9.5
aiosignal                     1.3.1
amqp                          5.2.0
annotated-types               0.6.0
anyio                         4.3.0
argon2-cffi                   23.1.0
argon2-cffi-bindings          21.2.0
arrow                         1.3.0
astroid                       3.1.0
asttokens                     2.4.1
async-lru                     2.0.4
async-timeout                 4.0.3
attrs                         23.2.0
Automat                       22.10.0
aws-neuronx-runtime-discovery 2.9
awscli                        1.32.92
Babel                         2.14.0
beautifulsoup4                4.12.3
billiard                      4.2.0
bleach                        6.1.0
boto3                         1.34.92
botocore                      1.34.92
build                         1.2.1
cachetools                    5.3.3
celery                        5.4.0
certifi                       2024.2.2
cffi                          1.16.0
charset-normalizer            3.3.2
click                         8.1.7
click-didyoumean              0.3.1
click-plugins                 1.1.1
click-repl                    0.3.0
cloud-tpu-client              0.10
cloudpickle                   3.0.0
cmake                         3.29.2
colorama                      0.4.4
comm                          0.2.2
constantly                    23.10.4
contourpy                     1.2.1
cryptography                  42.0.5
cssselect                     1.2.0
cycler                        0.12.1
dask                          2024.4.2
debugpy                       1.8.1
decorator                     5.1.1
defusedxml                    0.7.1
dill                          0.3.8
distlib                       0.3.8
docutils                      0.16
dparse                        0.6.3
ec2-metadata                  2.10.0
entrypoints                   0.4
environment-kernels           1.2.0
exceptiongroup                1.2.1
executing                     2.0.1
fastapi                       0.110.2
fastjsonschema                2.19.1
filelock                      3.13.4
fonttools                     4.51.0
fqdn                          1.5.1
frozenlist                    1.4.1
fsspec                        2024.3.1
google-api-core               1.34.1
google-api-python-client      1.8.0
google-auth                   2.29.0
google-auth-httplib2          0.2.0
googleapis-common-protos      1.63.0
h11                           0.14.0
httpcore                      1.0.5
httpie                        3.2.2
httplib2                      0.22.0
httpx                         0.27.0
hyperlink                     21.0.0
idna                          3.7
imageio                       2.34.1
importlib_metadata            7.1.0
incremental                   22.10.0
iniconfig                     2.0.0
ipykernel                     6.29.4
ipython                       8.23.0
ipywidgets                    8.1.2
islpy                         2023.1
isoduration                   20.11.0
isort                         5.13.2
itemadapter                   0.8.0
itemloaders                   1.2.0
jedi                          0.19.1
Jinja2                        3.1.3
jmespath                      1.0.1
joblib                        1.4.0
json5                         0.9.25
jsonpointer                   2.4
jsonschema                    4.21.1
jsonschema-specifications     2023.12.1
jupyter                       1.0.0
jupyter_client                8.6.1
jupyter-console               6.6.3
jupyter_core                  5.7.2
jupyter-events                0.10.0
jupyter-lsp                   2.2.5
jupyter_server                2.14.0
jupyter_server_terminals      0.5.3
jupyterlab                    4.1.6
jupyterlab_pygments           0.3.0
jupyterlab_server             2.27.1
jupyterlab_widgets            3.0.10
kiwisolver                    1.4.5
kombu                         5.3.7
libneuronxla                  0.5.971
llvmlite                      0.42.0
locket                        1.0.0
lockfile                      0.12.2
lxml                          5.2.1
markdown-it-py                3.0.0
MarkupSafe                    2.1.5
matplotlib                    3.8.4
matplotlib-inline             0.1.7
mccabe                        0.7.0
mdurl                         0.1.2
mistune                       3.0.2
multidict                     6.0.5
nbclient                      0.10.0
nbconvert                     7.16.3
nbformat                      5.10.4
nest-asyncio                  1.6.0
networkx                      2.6.3
neuronx-cc                    2.13.72.0+78a426937
neuronx-distributed           0.7.0
notebook                      7.1.3
notebook_shim                 0.2.4
numba                         0.59.1
numpy                         1.25.2
nvidia-cublas-cu11            11.10.3.66
nvidia-cuda-nvrtc-cu11        11.7.99
nvidia-cuda-runtime-cu11      11.7.99
nvidia-cudnn-cu11             8.5.0.96
oauth2client                  4.1.3
opencv-python                 4.9.0.80
overrides                     7.7.0
packaging                     21.3
pandas                        2.2.2
pandocfilters                 1.5.1
papermill                     2.5.0
parsel                        1.9.1
parso                         0.8.4
partd                         1.4.1
pexpect                       4.9.0
pgzip                         0.3.5
pillow                        10.3.0
pip                           24.0
pip-tools                     7.4.1
pipenv                        2023.12.1
platformdirs                  4.2.1
plotly                        5.21.0
pluggy                        1.5.0
prometheus_client             0.20.0
prompt-toolkit                3.0.43
Protego                       0.3.1
protobuf                      3.19.6
psutil                        5.9.8
ptyprocess                    0.7.0
pure-eval                     0.2.2
pyasn1                        0.6.0
pyasn1_modules                0.4.0
pycparser                     2.22
pydantic                      2.7.1
pydantic_core                 2.18.2
PyDispatcher                  2.0.7
Pygments                      2.17.2
pylint                        3.1.0
pyOpenSSL                     24.1.0
pyparsing                     3.1.2
pyproject_hooks               1.0.0
PySocks                       1.7.1
pytest                        8.1.1
python-daemon                 3.0.1
python-dateutil               2.9.0.post0
python-json-logger            2.0.7
pytz                          2024.1
PyYAML                        6.0.1
pyzmq                         26.0.2
qtconsole                     5.5.1
QtPy                          2.4.1
queuelib                      1.6.2
referencing                   0.35.0
requests                      2.31.0
requests-file                 2.0.0
requests-toolbelt             1.0.0
requests-unixsocket           0.3.0
rfc3339-validator             0.1.4
rfc3986-validator             0.1.1
rich                          13.7.1
rpds-py                       0.18.0
rsa                           4.7.2
ruamel.yaml                   0.18.6
ruamel.yaml.clib              0.2.8
s3transfer                    0.10.1
safety                        2.3.5
scikit-learn                  1.4.2
scipy                         1.11.2
Scrapy                        2.11.1
seaborn                       0.13.2
Send2Trash                    1.8.3
service-identity              24.1.0
setuptools                    69.5.1
shap                          0.45.0
six                           1.16.0
slicer                        0.0.7
sniffio                       1.3.1
soupsieve                     2.5
stack-data                    0.6.3
starlette                     0.37.2
tenacity                      8.2.3
terminado                     0.18.1
threadpoolctl                 3.4.0
tinycss2                      1.3.0
tldextract                    5.1.2
tomli                         2.0.1
tomlkit                       0.12.4
toolz                         0.12.1
torch                         1.13.1
torch-neuronx                 1.13.1.1.14.0
torch-xla                     1.13.1+torchneurone
torchvision                   0.14.1
tornado                       6.4
tqdm                          4.66.2
traitlets                     5.14.3
Twisted                       24.3.0
types-python-dateutil         2.9.0.20240316
typing_extensions             4.11.0
tzdata                        2024.1
uri-template                  1.3.0
uritemplate                   3.0.1
urllib3                       2.2.1
vine                          5.1.0
virtualenv                    20.26.0
w3lib                         2.1.2
wcwidth                       0.2.13
webcolors                     1.13
webencodings                  0.5.1
websocket-client              1.8.0
wget                          3.2
wheel                         0.43.0
widgetsnbextension            4.0.10
yarl                          1.9.4
zipp                          3.18.1
zope.interface                6.3

Model architecture

craft.py

"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from scripts.model_architectures.detector.vgg16 import VGG16BN


class DoubleConv(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        output = self.conv(input)
        return output


class CRAFT(nn.Module):
    def __init__(self, pretrained=True, freeze=False, amp=False):
        super(CRAFT, self).__init__()

        self.amp = amp

        """ Base network """
        self.basenet = VGG16BN(pretrained, freeze)

        """ U network """
        self.upconv1 = DoubleConv(1024, 512, 256)
        self.upconv2 = DoubleConv(512, 256, 128)
        self.upconv3 = DoubleConv(256, 128, 64)
        self.upconv4 = DoubleConv(128, 64, 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        self._init_weights(self.upconv1.modules())
        self._init_weights(self.upconv2.modules())
        self._init_weights(self.upconv3.modules())
        self._init_weights(self.upconv4.modules())
        self._init_weights(self.conv_cls.modules())

    def forward(self, input):
        """Base network"""
        sources = self.basenet(input)

        """ U network """
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        y = F.interpolate(
            y, size=sources[2].size()[2:], mode="bilinear", align_corners=False
        )
        y = torch.cat([y, sources[2]], dim=1)
        y = self.upconv2(y)

        y = F.interpolate(
            y, size=sources[3].size()[2:], mode="bilinear", align_corners=False
        )
        y = torch.cat([y, sources[3]], dim=1)
        y = self.upconv3(y)

        y = F.interpolate(
            y, size=sources[4].size()[2:], mode="bilinear", align_corners=False
        )
        y = torch.cat([y, sources[4]], dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        return y.permute(0, 2, 3, 1), feature

    def _init_weights(self, modules):
        for m in modules:
            if isinstance(m, nn.Conv2d):
                init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

vgg16.py

import torch
import torchvision
from packaging import version
from torch import nn
from torchvision import models


class VGG16BN(torch.nn.Module):
    def __init__(self, pretrained=True, freeze=True):
        super(VGG16BN, self).__init__()
        if version.parse(torchvision.__version__) >= version.parse("0.13"):
            vgg_pretrained_features = models.vgg16_bn(
                weights=models.VGG16_BN_Weights.DEFAULT if pretrained else None
            ).features
        else:  # torchvision.__version__ < 0.13
            models.vgg.model_urls["vgg16_bn"] = models.vgg.model_urls[
                "vgg16_bn"
            ].replace("https://", "http://")
            vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features

        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for index in range(12):  # conv2_2
            self.slice1.add_module(str(index), vgg_pretrained_features[index])
        for index in range(12, 19):  # conv3_3
            self.slice2.add_module(str(index), vgg_pretrained_features[index])
        for index in range(19, 29):  # conv4_3
            self.slice3.add_module(str(index), vgg_pretrained_features[index])
        for index in range(29, 39):  # conv5_3
            self.slice4.add_module(str(index), vgg_pretrained_features[index])

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1),
        )

        if not pretrained:
            self._init_weights(self.slice1.modules())
            self._init_weights(self.slice2.modules())
            self._init_weights(self.slice3.modules())
            self._init_weights(self.slice4.modules())

        self._init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7

        if freeze:
            for param in self.slice1.parameters():  # only first conv
                param.requires_grad = False

    def _init_weights(self, modules):
        for m in modules:
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def forward(self, input):
        h = self.slice1(input)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        return h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2

Compilation script

import torch
import torch_neuronx
from torchvision import models
from torchvision.transforms import functional
from scripts.model_architectures.detector.craft import CRAFT

model = CRAFT()
detector_dummy_input = torch.rand((1, 3, 512, 512))
model.eval()

model_neuron = torch_neuronx.trace(model, detector_dummy_input)
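
One hypothetical workaround (not an official fix, and it changes the upsampling numerics) is to force the four upsampling calls to nearest-neighbor while tracing, purely as a diagnostic to check whether ResizeBilinear is the only operator blocking compilation. A sketch, assuming craft.py looks up F.interpolate on the torch.nn.functional module at call time:

import torch
import torch.nn.functional as F
import torch_neuronx
from scripts.model_architectures.detector.craft import CRAFT

# Keep the original function so it can be restored after tracing.
_orig_interpolate = F.interpolate

def _nearest_interpolate(x, *args, **kwargs):
    # Swap bilinear for nearest-neighbor; nearest mode rejects align_corners.
    kwargs["mode"] = "nearest"
    kwargs.pop("align_corners", None)
    return _orig_interpolate(x, *args, **kwargs)

model = CRAFT().eval()
detector_dummy_input = torch.rand((1, 3, 512, 512))

F.interpolate = _nearest_interpolate
try:
    model_neuron = torch_neuronx.trace(model, detector_dummy_input)
finally:
    F.interpolate = _orig_interpolate

Whether nearest-neighbor upsampling avoids a CustomCall depends on how torch-xla lowers it, so treat this only as a way to localize the failure, not as a drop-in replacement.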

Compiler output

model_neuron = torch_neuronx.trace(model, detector_dummy_input)
Process Process-1:
Traceback (most recent call last):
  File "neuronxcc/driver/CommandDriver.py", line 343, in neuronxcc.driver.CommandDriver.CommandDriver.run_subcommand
  File "neuronxcc/driver/commands/CompileCommand.py", line 1240, in neuronxcc.driver.commands.CompileCommand.CompileCommand.run
  File "neuronxcc/driver/commands/CompileCommand.py", line 1199, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
  File "neuronxcc/driver/commands/CompileCommand.py", line 1216, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
  File "neuronxcc/driver/commands/CompileCommand.py", line 1219, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
  File "neuronxcc/driver/Job.py", line 346, in neuronxcc.driver.Job.SingleInputJob.run
  File "neuronxcc/driver/Job.py", line 372, in neuronxcc.driver.Job.SingleInputJob.runOnState
  File "neuronxcc/driver/Pipeline.py", line 30, in neuronxcc.driver.Pipeline.Pipeline.runSingleInput
  File "neuronxcc/driver/Job.py", line 346, in neuronxcc.driver.Job.SingleInputJob.run
  File "neuronxcc/driver/Job.py", line 372, in neuronxcc.driver.Job.SingleInputJob.runOnState
  File "neuronxcc/driver/jobs/Frontend.py", line 425, in neuronxcc.driver.jobs.Frontend.Frontend.runSingleInput
  File "neuronxcc/driver/jobs/Frontend.py", line 207, in neuronxcc.driver.jobs.Frontend.Frontend.runXLAFrontend
  File "neuronxcc/driver/jobs/Frontend.py", line 183, in neuronxcc.driver.jobs.Frontend.Frontend.runHlo2Tensorizer
neuronxcc.driver.Exceptions.CompilerInvalidInputException: ERROR: Failed command  /opt/aws_neuronx_venv_pytorch_1_13/lib/python3.10/site-packages/neuronxcc/starfish/bin/hlo2penguin --input /tmp/tmp1jc7ca4j/model --out-dir ./ --output penguin.py --layers-per-module=1 --coalesce-all-gathers=false --coalesce-reduce-scatters=false --coalesce-all-reduces=false --emit-tensor-level-dropout-ops --emit-tensor-level-rng-ops
------------
Reported stdout:
INFO: Found compute bound graph
Replaced 0 dropout sequences with OffloadedDropout
INFO: HloMacCount has found 90963217152
INFO: Traffic has found 95114636
INFO: AIF 1912.71
HLO Ops used in computation: add broadcast concatenate constant convolution custom-call maximum multiply parameter reduce-window transpose tuple
loc("_custom-call.530{hlo_id=530}"): error: 'mhlo.custom_call' op Error: Unsupported CustomCall target: ResizeBilinear
munmap_chunk(): invalid pointer

------------
Reported stderr:
None
------------
Import of the HLO graph into the Neuron Compiler has failed.
This may be caused by unsupported operators or an internal compiler error.
More details can be found in the error message(s) above.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "neuronxcc/driver/CommandDriver.py", line 350, in neuronxcc.driver.CommandDriver.CommandDriver.run_subcommand_in_process
  File "neuronxcc/driver/CommandDriver.py", line 345, in neuronxcc.driver.CommandDriver.CommandDriver.run_subcommand
  File "neuronxcc/driver/CommandDriver.py", line 111, in neuronxcc.driver.CommandDriver.handleError
  File "neuronxcc/driver/GlobalState.py", line 102, in neuronxcc.driver.GlobalState.FinalizeGlobalState
  File "neuronxcc/driver/GlobalState.py", line 82, in neuronxcc.driver.GlobalState._GlobalStateImpl.shutdown
  File "/usr/lib/python3.10/shutil.py", line 715, in rmtree
    onerror(os.lstat, path, sys.exc_info())
  File "/usr/lib/python3.10/shutil.py", line 713, in rmtree
    orig_st = os.lstat(path)
FileNotFoundError: [Errno 2] No such file or directory: '/home/ubuntu/kirin-ocr/neuronxcc-yic9fnqx'
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 model_neuron = torch_neuronx.trace(model, detector_dummy_input)

File /opt/aws_neuronx_venv_pytorch_1_13/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py:556, in trace(func, example_inputs, input_output_aliases, compiler_workdir, compiler_args, partitioner_config, inline_weights_to_neff, *_, **kwargs)
    551     return torch_neuronx.partition(
    552         func, example_inputs, **(partitioner_config.__dict__)
    553     )
    555 with context:
--> 556     neff_filename, metaneff, flattener, packer, weights = _trace(
    557         func,
    558         example_inputs,
    559         states,
    560         input_output_aliases,
    561         compiler_workdir,
    562         compiler_args,
    563         inline_weights_to_neff,
    564     )
    565     return create_neuron_model(
    566         neff_filename,
    567         metaneff,
   (...)
    572         weights,
    573     )

File /opt/aws_neuronx_venv_pytorch_1_13/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py:623, in _trace(func, example_inputs, states, input_output_aliases, compiler_workdir, compiler_args, inline_weights_to_neff)
    607 (
    608     hlo,
    609     constant_parameter_tensors,
   (...)
    619     inline_weights_to_neff=inline_weights_to_neff
    620 )
    622 # Call neuronx-cc to generate neff
--> 623 neff_filename = generate_neff(
    624     hlo,
    625     constant_parameter_tensors,
    626     compiler_workdir=compiler_workdir,
    627     compiler_args=compiler_args,
    628     inline_weights_to_neff=inline_weights_to_neff,
    629 )
    631 return neff_filename, metaneff, flattener, packer, weights

File /opt/aws_neuronx_venv_pytorch_1_13/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py:446, in generate_neff(hlo, constant_parameter_tensors, compiler_workdir, compiler_args, inline_weights_to_neff)
    441 compiler_target = setup_compiler_dirs(
    442     hlo, compiler_workdir, constant_parameter_tensors, inline_weights_to_neff
    443 )
    445 # Compile HLO to NEFF
--> 446 neff_filename = hlo_compile(
    447     compiler_target,
    448     compiler_workdir,
    449     compiler_args,
    450 )
    452 return neff_filename

File /opt/aws_neuronx_venv_pytorch_1_13/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py:384, in hlo_compile(filename, compiler_workdir, compiler_args)
    377     elif status == -11:
    378         logger.warning(
    379             "The neuronx-cc (neuron compiler) crashed (SEGFAULT). "
    380             "This is likely due to a bug in the compiler.  "
    381             "Please lodge an issue at 'https://github.com/aws/aws-neuron-sdk/issues'"
    382         )
--> 384     raise RuntimeError(f"neuronx-cc failed with {status}")
    386 return neff_filename

RuntimeError: neuronx-cc failed with 1

Thanks @takipipo, we're taking a look.

I'm hitting a similar error when trying to compile the DeepLabV3 model:

https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/
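
For reference, a hedged reproduction sketch (the hub call follows the linked page; the wrapper returning only the 'out' tensor is my own assumption, added so the traced output is a plain tensor rather than the OrderedDict torchvision's segmentation models return). The segmentation head upsamples its logits with bilinear F.interpolate, which appears to lower to the same ResizeBilinear CustomCall shown in the log below:

import torch
import torch_neuronx

class DeepLabOut(torch.nn.Module):
    # Wrapper returning only the 'out' logits so tracing sees a single tensor.
    def __init__(self):
        super().__init__()
        self.model = torch.hub.load(
            "pytorch/vision:v0.10.0", "deeplabv3_resnet101", pretrained=True
        )

    def forward(self, x):
        return self.model(x)["out"]

dummy = torch.rand(1, 3, 224, 224)
traced = torch_neuronx.trace(DeepLabOut().eval(), dummy)  # fails as in the log below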

neuronxcc.driver.Exceptions.CompilerInvalidInputException: ERROR: Failed command  /home/ubuntu/aws_neuron_venv_pytorch/lib/python3.8/site-packages/neuronxcc/starfish/bin/hlo2penguin --input /tmp/tmpbydm5xwu/model --out-dir ./ --output penguin.py --layers-per-module=1 --coalesce-all-gathers=false --coalesce-reduce-scatters=false --coalesce-all-reduces=false --emit-tensor-level-dropout-ops --emit-tensor-level-rng-ops
------------ 
Reported stdout: 
INFO: Found compute bound graph
Replaced 0 dropout sequences with OffloadedDropout
INFO: HloMacCount has found 1199915289280
INFO: Traffic has found 481182340
INFO: AIF 4987.36
HLO Ops used in computation: add broadcast concatenate constant convolution custom-call maximum multiply parameter reduce reduce-window reshape subtract transpose tuple 
Invoking RemoveOptimizationBarriers pass
loc("_custom-call.4242{hlo_id=4242}"): error: 'mhlo.custom_call' op Error: Unsupported CustomCall target: ResizeBilinear

Thanks for reporting, and sorry for the delay. We are looking into the issue.