aws-neuron/aws-neuron-sdk

Input tensor is not an XLA tensor: CPUFloatType while using crf.decode function

PrateekAg1511 opened this issue · 4 comments

Hi,

I am trying to trace the crf.decode function on an ml.inf2.8xlarge instance.

Here is the code I am trying to run.

def tags(output, mask):
    a = model.crf.decode(output, mask)
    a = torch.FloatTensor(a)
    return a

inputs = output, mask

model_crf = torch_neuronx.trace(tags, inputs)

Error:


RuntimeError Traceback (most recent call last)
Cell In[38], line 1
----> 1 model_crf = torch_neuronx.trace(tags, inputs)

File ~/aws_neuron_venv_pytorch/lib64/python3.8/site-packages/torch_neuronx/xla_impl/trace.py:556, in trace(func, example_inputs, input_output_aliases, compiler_workdir, compiler_args, partitioner_config, inline_weights_to_neff, *_, **kwargs)
551 return torch_neuronx.partition(
552 func, example_inputs, **(partitioner_config.dict)
553 )
555 with context:
--> 556 neff_filename, metaneff, flattener, packer, weights = _trace(
557 func,
558 example_inputs,
559 states,
560 input_output_aliases,
561 compiler_workdir,
562 compiler_args,
563 inline_weights_to_neff,
564 )
565 return create_neuron_model(
566 neff_filename,
567 metaneff,
(...)
572 weights,
573 )

File ~/aws_neuron_venv_pytorch/lib64/python3.8/site-packages/torch_neuronx/xla_impl/trace.py:614, in _trace(func, example_inputs, states, input_output_aliases, compiler_workdir, compiler_args, inline_weights_to_neff)
597 def _trace(
598 func: Union[Callable, torch.nn.Module],
599 example_inputs: Any,
(...)
605 ) -> Union[str, str, structure.Flattener, structure.Packer]:
606 # Convert the function to a HloProto message
607 (
608 hlo,
609 constant_parameter_tensors,
610 flattener,
611 packer,
612 metaneff,
613 weights,
--> 614 ) = generate_hlo(
615 func,
616 example_inputs,
617 states=states,
618 input_output_aliases=input_output_aliases,
619 inline_weights_to_neff=inline_weights_to_neff
620 )
622 # Call neuronx-cc to generate neff
623 neff_filename = generate_neff(
624 hlo,
625 constant_parameter_tensors,
(...)
628 inline_weights_to_neff=inline_weights_to_neff,
629 )

File ~/aws_neuron_venv_pytorch/lib64/python3.8/site-packages/torch_neuronx/xla_impl/trace.py:404, in generate_hlo(func, example_inputs, states, input_output_aliases, inline_weights_to_neff)
389 def generate_hlo(
390 func: Union[Callable, torch.nn.Module],
391 example_inputs: Any,
(...)
394 inline_weights_to_neff: bool = True
395 ):
396 with torch_neuronx.contexts.mock_neuron_cores(), revert_device_placement(func):
397 (
398 hlo,
399 input_parameter_names,
400 constant_parameter_tensors,
401 flattener,
402 packer,
403 updated_input_output_aliases,
--> 404 ) = xla_trace(
405 func,
406 example_inputs,
407 states,
408 input_output_aliases,
409 )
411 # make sure that hlo dtype and torch dtype match
412 coerce_parameter_dtypes(hlo, constant_parameter_tensors)

File ~/aws_neuron_venv_pytorch/lib64/python3.8/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:133, in xla_trace(func, example_inputs, states, input_output_aliases)
131 # Lower the HLO graph
132 context = torch_xla._XLAC.lowering.LoweringContext()
--> 133 context.build(tensors)
135 # Determine which HloModule parameters should be inlined (ie. constants,
136 # parameters, buffers). This should NOT include the example inputs.
137 parameters = context.parameter_id_tensor_mapping()

RuntimeError: torch_xla/csrc/aten_xla_bridge.cpp:73 : Check failed: xtensor
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::bridge::GetXlaTensor(at::Tensor const&)

PyCFunction_Call
_PyObject_MakeTpCall

_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
PyEval_EvalCodeEx
PyEval_EvalCode


_PyEval_EvalFrameDefault
_PyGen_Send
_PyEval_EvalFrameDefault
_PyGen_Send
_PyEval_EvalFrameDefault
_PyGen_Send

_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall

PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName

_PyEval_EvalFrameDefault
_PyGen_Send
_PyEval_EvalFrameDefault
_PyGen_Send
_PyEval_EvalFrameDefault
_PyGen_Send
_PyEval_EvalFrameDefault
_PyGen_Send
_PyEval_EvalFrameDefault
_PyGen_Send
_PyEval_EvalFrameDefault
_PyGen_Send


_PyObject_MakeTpCall


PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName

_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
PyEval_EvalCodeEx
PyEval_EvalCode


_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyVectorcall_Call

Py_RunMain
Py_BytesMain
__libc_start_main
_start

*** End stack trace ***
Input tensor is not an XLA tensor: CPUFloatType

I have been trying to find a solution, but nothing has worked so far. Please help!

We had a similar problem when running torch_neuronx.trace(): the output of our model was on the CPU rather than on an XLA device. Our fix was to set the device when creating the tensor, FloatTensor(a, device=torch_xla.core.xla_model.xla_device()). Alternatively, you can call .to(torch_xla.core.xla_model.xla_device()) on the output:

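import torch
import torch_xla.core.xla_model as xm

def tags(output, mask):
    # decode() returns Python lists, so build a tensor first, then move it to the XLA device
    return torch.tensor(model.crf.decode(output, mask), dtype=torch.float32).to(xm.xla_device())

...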

Presumably you could set the device on the input tensors instead, but we had issues with that approach on our model.
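If decode returns a list of tag lists (as pytorch-crf's CRF.decode does), the conversion step can be wrapped in a small helper. The sketch below is a hypothetical helper, not part of torch_neuronx or torchcrf: it pads sequences of different lengths (torch.tensor rejects ragged nested lists) and creates the tensor directly on the requested device. The device defaults to "cpu" only so the sketch runs without torch_xla; on an inf2 instance you would pass xm.xla_device(). Note that this only addresses the "not an XLA tensor" error: the tensor is still built from Python ints, so its values do not depend on the traced inputs.

```python
import torch


def decode_to_tensor(tag_lists, device="cpu", pad_value=-1):
    """Hypothetical helper: turn List[List[int]] into a padded 2-D float tensor.

    Pass device=xm.xla_device() on Neuron; "cpu" is the default here only so the
    sketch runs without torch_xla installed.
    """
    max_len = max(len(tags) for tags in tag_lists)
    # Pad shorter sequences so torch.tensor receives a rectangular nested list
    padded = [tags + [pad_value] * (max_len - len(tags)) for tags in tag_lists]
    # Create the tensor directly on the target device with the desired dtype
    return torch.tensor(padded, dtype=torch.float32, device=device)


out = decode_to_tensor([[3, 1, 2], [0, 4]])
print(out)
```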

Hopefully this gets you a little bit further towards solving the problem.

@tombettany Thanks!

I tried this but then got the following warning:

/usr/local/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:143: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=0, shape=torch.Size([1, 60, 184]), dtype=torch.float32)
  warnings.warn(
/usr/local/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:143: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 60]), dtype=torch.uint8)
  warnings.warn(

Now the traced model returns the same output for every input.

Would you be able to give more details about the model you are trying to trace? If there is a minimal open source reproduction of the error you are encountering, we can try to help you solve the problem.

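The warning you are running into indicates that the output of the model does not appear to depend on the inputs. This can happen when the output is calculated entirely from tensors that are newly constructed within the forward function. This is likely caused by the implementation of the model.crf.decode method: if it returns plain Python values (pytorch-crf's CRF.decode, for example, returns List[List[int]] and extracts each tag with .item()), the tensor created from that list afterwards is a constant in the traced graph. Every input is then reported as unused, and the compiled model returns the same result each time.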
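To make that mechanism concrete, here is a toy model of a tracer written in plain Python. It is only an analogy, not how torch_xla is implemented, and every name in it (Traced, toy_trace, in_graph, via_python) is hypothetical. Each traced value remembers which inputs it depends on and how to recompute itself; calling .item() leaves the graph, and a value rebuilt from plain Python data (the analogue of torch.Tensor(a) after decode) becomes a constant.

```python
class Traced:
    """Toy lazy tensor: a value, a replay function over graph inputs, and its dependencies."""

    def __init__(self, value, replay, deps=frozenset()):
        self.value = value            # value computed from the example inputs
        self.replay = replay          # recomputes the value from new input values
        self.deps = frozenset(deps)   # indices of graph inputs this value depends on

    @staticmethod
    def constant(value):
        # Built from plain Python data: baked into the graph as a fixed value
        return Traced(value, lambda inputs: value)

    def __add__(self, other):
        if not isinstance(other, Traced):
            other = Traced.constant(other)
        return Traced(
            self.value + other.value,
            lambda inputs: self.replay(inputs) + other.replay(inputs),
            self.deps | other.deps,
        )

    def item(self):
        # Leaves the graph: the caller gets a plain number with no history
        return self.value


def toy_trace(fn, *example_inputs):
    """Run fn once on placeholders; return a replayable 'compiled' fn and unused input indices."""
    placeholders = [
        Traced(v, lambda inputs, i=i: inputs[i], {i}) for i, v in enumerate(example_inputs)
    ]
    out = fn(*placeholders)
    unused = [i for i in range(len(example_inputs)) if i not in out.deps]
    return (lambda *inputs: out.replay(inputs)), unused


def in_graph(x, y):
    return x + y  # stays in the graph, so it depends on both inputs


def via_python(x, y):
    s = x.item() + y.item()    # like decode() building Python lists
    return Traced.constant(s)  # like torch.Tensor(a): a brand-new constant


compiled, unused = toy_trace(in_graph, 1, 2)
print(compiled(10, 20), unused)    # 30 []
compiled2, unused2 = toy_trace(via_python, 1, 2)
print(compiled2(10, 20), unused2)  # 3 [0, 1]
```

The second function reproduces both symptoms from this thread in miniature: both inputs are reported as unused, and the "compiled" function returns the value from the example inputs no matter what it is called with.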

@jluntamazon Here is a minimal open-source reproduction of the error:

import torch
import torch_neuronx
import torch_xla.core.xla_model as xm
from torchcrf import CRF

num_tags = 184
model = CRF(num_tags)

emissions = torch.rand([1, 60, 184])
mask = torch.ones([1, 60], dtype=torch.uint8)

def decode_fn(emissions, mask):
    a = model.decode(emissions, mask)
    a = torch.Tensor(a)
    a = a.to(xm.xla_device())
    return a

inputs_crf = emissions, mask

trace_crf = torch_neuronx.trace(decode_fn, inputs_crf)

After running trace, I get the warning message for both inputs:

/aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:144: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=0, shape=torch.Size([1, 60, 184]), dtype=torch.float32)
  warnings.warn(
/aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:144: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 60]), dtype=torch.uint8)
  warnings.warn(
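One possible workaround, sketched here under stated assumptions and not tested on Neuron, is to do the Viterbi decode with tensor ops only, so that the output stays connected to the inputs instead of being rebuilt from Python ints. viterbi_decode_tensor is a hypothetical helper, not part of torchcrf. It assumes pytorch-crf's default (seq_len, batch, num_tags) layout and an all-ones mask (padding is not handled), and it replaces the per-sample .item() backtrace with gather over the whole batch.

```python
import torch


def viterbi_decode_tensor(emissions, start_transitions, end_transitions, transitions):
    """Hypothetical tensor-only Viterbi decode (not part of torchcrf).

    emissions: (seq_len, batch, num_tags); assumes every position is valid (all-ones mask).
    Returns a (batch, seq_len) int64 tensor of best tag indices.
    """
    seq_len = emissions.shape[0]
    score = start_transitions + emissions[0]  # (batch, num_tags)
    history = []
    for i in range(1, seq_len):  # fixed-length Python loop, unrolled when traced
        # next_score[b, prev, cur] = score[b, prev] + transitions[prev, cur] + emissions[i, b, cur]
        next_score = score.unsqueeze(2) + transitions + emissions[i].unsqueeze(1)
        score, best_prev = next_score.max(dim=1)  # best previous tag for each current tag
        history.append(best_prev)
    score = score + end_transitions

    # Backtrack with gather instead of .item(), so the result remains a graph tensor
    best_tag = score.argmax(dim=1)  # (batch,)
    path = [best_tag]
    for best_prev in reversed(history):
        best_tag = best_prev.gather(1, best_tag.unsqueeze(1)).squeeze(1)
        path.append(best_tag)
    path.reverse()
    return torch.stack(path, dim=1)
```

With the reproduction's model, a traced wrapper could call viterbi_decode_tensor(emissions, model.start_transitions, model.end_transitions, model.transitions). Because this sketch ignores the mask, the tracer may still warn that input index 1 is unused, and whether every op here lowers through neuronx-cc has not been checked.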