[Bug] operator.gt fails when one operand is a Python int: "got an object with type <class 'int'>"
ruofan-wu opened this issue · 8 comments
Hi @yaoyaoding, I encountered a bug when running T5Model with hidet:
DEBUG:hidet.graph.frontend.torch.interpreter:interpreting node 34: %gt : [#users=1] = call_function[target=operator.gt](args = (%sub_1, 0), kwargs = {})
Traceback (most recent call last):
File "hidet-path/python/hidet/graph/frontend/torch/interpreter.py", line 260, in forward
hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
File "hidet-path/python/hidet/graph/frontend/torch/register_functions.py", line 694, in gt
return ops.greater(a, b)
File "hidet-path/python/hidet/graph/ops/definitions/compare.py", line 85, in greater
return GreaterOp(x, y).get_output(0)
File "hidet-path/python/hidet/graph/ops/definitions/compare.py", line 34, in __init__
super().__init__(x, y, lambda a, b: a > b, name='gt')
File "hidet-path/python/hidet/graph/ops/definitions/arithmetic.py", line 125, in __init__
task=BinaryElementwiseTask(name, input_like(x, 'x'), input_like(y, 'y'), op=op),
File "hidet-path/python/hidet/graph/ops/definitions/utils/tensor_utils.py", line 26, in input_like
raise TypeError('Expect a hidet.Tensor, but got an object with type {}'.format(type(tensor)))
TypeError: Expect a hidet.Tensor, but got an object with type <class 'int'>
Could you please help me fix it?
Hi @GisellWu,
Could you share a minimal example that reproduces the error?
# Minimal reproduction: compile Hugging Face T5Model ('t5-base', float32) with the
# hidet torch.compile backend and run a single forward pass on CUDA.
import torch
from transformers import T5Tokenizer, T5Model
import hidet  # noqa: F401 -- presumably needed so the 'hidet' backend is available; confirm

model_name = 't5-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Model.from_pretrained(model_name).to(device="cuda:0")
model = torch.compile(model, backend='hidet')

input_text = ["translate English to French: Hello, how are you?"]
# Call the tokenizer directly so the list is encoded as a batch of sentences.
# encode_plus(list_of_str) without is_split_into_words treats the list as
# already-split tokens, so the whole sentence would map to a single unknown token.
tokens = tokenizer(input_text, add_special_tokens=True, return_tensors='pt',
                   padding='max_length', truncation=True, max_length=128).to(device="cuda:0")
outputs = model(input_ids=tokens.input_ids, decoder_input_ids=tokens.input_ids)
logits = outputs.last_hidden_state
print("Logits Shape:", logits.shape)
In addition, I added the following functions to register_functions.py so that T5Model could run end to end:
@register_function(torch.abs)
def abs(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    """Hidet lowering of ``torch.abs``: element-wise absolute value of ``x``.

    The ``out=`` keyword is not supported; any non-None value raises
    NotImplementedError.
    """
    if out is None:
        return ops.abs(x)
    raise NotImplementedError("hidet: does not support torch.abs(..., out=...)")
@register_function(torch.log)
def log(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    """Hidet lowering of ``torch.log``: element-wise logarithm of ``x``.

    The ``out=`` keyword is not supported; any non-None value raises
    NotImplementedError.
    """
    if out is None:
        return ops.log(x)
    raise NotImplementedError("hidet: does not support torch.log(..., out=...)")
@register_function(torch.full_like)
def full_like(input, fill_value, *, dtype=None, layout=torch.strided, device=None, requires_grad=False,
              memory_format=torch.preserve_format):
    """Hidet lowering of ``torch.full_like``.

    Returns a tensor with the same shape as ``input`` filled with ``fill_value``.
    Matching PyTorch, ``dtype`` and ``device`` default to those of ``input`` when
    they are not given. ``memory_format`` is accepted for signature compatibility
    but ignored.

    Raises:
        NotImplementedError: if a layout other than None / torch.strided is requested.
    """
    if layout not in [None, torch.strided]:
        raise NotImplementedError("hidet: does not support torch.full_like(..., layout=..., ...)")
    if requires_grad and torch.is_grad_enabled():
        warnings.warn_once("hidet: requires_grad=True when torch.is_grad_enabled(), treating as requires_grad=False")
    # Inherit dtype/device from the input instead of letting ops.full pick defaults.
    # NOTE(review): assumes `input` is a hidet Tensor here, so .dtype/.device are
    # already hidet DataType/Device objects -- confirm against the interpreter.
    hidet_device: Device = device_from_torch(torch_device=device) if device is not None else input.device
    hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype) if dtype is not None else input.dtype
    # .shape is used instead of .size() as the portable way to get the dimensions.
    return ops.full(input.shape, fill_value, dtype=hidet_dtype, device=hidet_device)
@register_function(torch.zeros_like)
def zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False,
               memory_format=torch.preserve_format):
    """Hidet lowering of ``torch.zeros_like``.

    Returns a zero-filled tensor with the same shape as ``input``. Matching
    PyTorch, ``dtype`` and ``device`` default to those of ``input`` (not to
    ``torch.get_default_dtype()`` or the default device). ``requires_grad`` and
    ``memory_format`` are accepted for signature compatibility but ignored.

    Raises:
        NotImplementedError: if a layout other than None / torch.strided is requested.
    """
    import hidet

    if layout not in [None, torch.strided]:
        raise NotImplementedError("hidet: does not support torch.zeros_like(..., layout=..., ...)")
    _ = requires_grad  # gradients are not tracked by this lowering
    shape = [int(v) for v in input.shape]
    # NOTE(review): assumes `input` is a hidet Tensor here, so .dtype/.device are
    # already hidet DataType/Device objects -- confirm against the interpreter.
    hidet_dtype = dtype_from_torch(dtype) if dtype is not None else input.dtype
    hidet_device = device_from_torch(device) if device is not None else input.device
    return hidet.zeros(shape, dtype=hidet_dtype, device=hidet_device)
Hi @GisellWu,
I added the missing operators and fixed some bugs for the T5 model in #322. Could you give it another try?
Thanks for your help! It runs successfully now, so I'm closing the issue :)
Hi @yaoyaoding ,
Sorry to bother you again. I'm now trying T5Model with float16, and it hits some new unsupported functions. Could you please help me fix it?
The error is:
NotImplementedError: The following modules/functions are not supported by hidet yet:
torch.clamp
torch.isinf
And the example code is:
# Minimal reproduction: compile Hugging Face T5Model ('t5-base', float16 weights)
# with the hidet torch.compile backend and run a single forward pass on CUDA.
import torch
from transformers import T5Tokenizer, T5Model
import hidet  # noqa: F401 -- presumably needed so the 'hidet' backend is available; confirm

model_name = 't5-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Model.from_pretrained(model_name, torch_dtype=torch.float16).to(device="cuda:0")
model = torch.compile(model, backend='hidet')

input_text = ["translate English to French: Hello, how are you?"]
# Call the tokenizer directly so the list is encoded as a batch of sentences.
# encode_plus(list_of_str) without is_split_into_words treats the list as
# already-split tokens, so the whole sentence would map to a single unknown token.
tokens = tokenizer(input_text, add_special_tokens=True, return_tensors='pt',
                   padding='max_length', truncation=True, max_length=128).to(device="cuda:0")
outputs = model(input_ids=tokens.input_ids, decoder_input_ids=tokens.input_ids)
logits = outputs.last_hidden_state
print("Logits Shape:", logits.shape)
Hi @GisellWu,
I added the missing operators in #343, could you give it a try? Thanks!
Hi @yaoyaoding,
It works now. I appreciate your help again!