[NotImplementedError] operator.lshift
HanGuo97 opened this issue · 7 comments
Describe the Problem
BackendCompilerFailed: hidet_backend raised NotImplementedError: The following modules/functions are not supported by hidet yet:
operator.lshift
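A minimal snippet that reproduces this (a simplified sketch; my actual model is more involved):

import hidet  # registers the 'hidet' backend for torch.compile
import torch

def shift(x):
    return x << 3  # traces to operator.lshift in the FX graph

compiled = torch.compile(shift, backend='hidet')
# Fails with the BackendCompilerFailed / NotImplementedError above
compiled(torch.arange(16, device='cuda'))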
Thanks for the quick fix!
Does that mean I could simply build from source, and this would work via the torch.compile interface?
I am not sure about the particulars of your model, but this script worked for me (when built from source):
import hidet
import torch

def test(a):
    return a << 3

t = torch.compile(test, backend='hidet')
# randn produces float32, but << is only defined for integer dtypes, hence the cast
t(torch.randn(3, 5, device='cuda').to(torch.int64))
Hi @Aalanli,
Thank you so much for your help! The patch you provided worked perfectly for me.
I wanted to follow up with a more detailed example. This particular issue caused several errors for me. Some of these were fixable following the patch you provided, such as pow, bitwise_and, and torch.Tensor.max.
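As a quick check (a sketch I'm including for reference, not from my actual model), something like this exercises all three of those ops in one compiled function:

import hidet
import torch

@torch.compile(backend='hidet')
def smoke(a, b):
    # a ** 2 -> operator.pow, b & 3 -> bitwise_and, a.max() -> torch.Tensor.max
    return a ** 2 + (b & 3) + a.max()

smoke(torch.arange(8, device='cuda'),
      torch.arange(8, device='cuda'))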
However, there is one error that I’m having difficulty fixing without additional knowledge of the codebase. Here’s the error message:
ValueError: Unknown data type: uint8x4, candidates...
I’ve attached simplified code that shows this error, and possibly a few others. I would appreciate any help you can offer in resolving this issue. Thank you in advance for your time!
Code
import math
import torch
import jaxtyping
from typing import Tuple

DEFAULT_CONTAINER_NUM_BITS = 8

FloatTensorType = jaxtyping.Float[torch.Tensor, "..."]
UInt8TensorType = jaxtyping.UInt8[torch.Tensor, "..."]
BinaryTensorType = jaxtyping.Bool[torch.Tensor, "..."]
PackedBinaryTensorType = jaxtyping.UInt8[torch.Tensor, "..."]


def from_binary(tensor: BinaryTensorType, num_bits: int) -> UInt8TensorType:
    if tensor.dtype != torch.bool:
        raise TypeError
    if tensor.shape[-1] != num_bits:
        raise ValueError
    if num_bits > 8:
        raise NotImplementedError

    mask = torch.tensor([2], dtype=torch.float32, device=tensor.device) ** torch.arange(
        num_bits - 1, -1, -1,
        dtype=torch.float32,
        device=tensor.device)
    mask = mask.to(dtype=torch.uint8)
    tensor = tensor.to(dtype=torch.uint8)
    output = torch.sum(mask * tensor, dim=-1)
    output = output.to(dtype=torch.uint8)
    return output


def unpack_uint8_into_bool(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
) -> BinaryTensorType:
    if packed_tensor.ndim != 1:
        raise ValueError
    if packed_tensor.dtype != torch.uint8:
        raise TypeError

    # Some constants
    packed_dtype = torch.uint8
    packed_num_bits = torch.iinfo(packed_dtype).bits

    # [1, packed_num_bits]
    bits = torch.tensor(
        1,
        dtype=packed_dtype,
        device=packed_tensor.device)
    bits = bits << torch.arange(
        packed_num_bits,
        dtype=packed_dtype,
        device=packed_tensor.device)
    bits = torch.unsqueeze(
        bits,
        dim=0)

    unpacked_tensor = torch.unsqueeze(
        packed_tensor,
        dim=-1)
    unpacked_tensor = unpacked_tensor & bits
    unpacked_tensor = unpacked_tensor > 0
    unpacked_tensor = unpacked_tensor.to(dtype=torch.bool)
    unpacked_tensor = unpacked_tensor.view(-1)
    if padding_length > 0:
        unpacked_tensor = unpacked_tensor[:-padding_length]
    return unpacked_tensor


@torch.compile(fullgraph=True, backend="hidet")
def unpack_integer_tensors(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
    num_bits: int,
    shape: Tuple[int, ...],
) -> UInt8TensorType:
    packed_size = (
        (math.prod(shape) * num_bits + padding_length) /
        DEFAULT_CONTAINER_NUM_BITS)
    if packed_tensor.shape != (packed_size,):
        raise ValueError

    # [tensor.numel() x num_bits / 8]
    packed_tensor = packed_tensor.contiguous()
    # [tensor.numel() x num_bits]
    binary_tensor = unpack_uint8_into_bool(
        packed_tensor=packed_tensor,
        padding_length=padding_length)
    # [*tensor.shape, num_bits]
    binary_tensor = binary_tensor.view(
        *shape, num_bits)
    return from_binary(
        tensor=binary_tensor,
        num_bits=num_bits)


num_bits = 8
shape = torch.Size([1024, 256, 1])
unpack_integer_tensors(
    torch.randint(
        2 ** 8,
        size=(shape.numel(),),
        dtype=torch.uint8,
        device="cuda"),
    padding_length=0,
    num_bits=num_bits,
    shape=shape,
)
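As a side note, the pack/unpack pair can be sanity-checked in eager mode (a check I'm adding for reference, not part of the repro). from_binary packs MSB-first while unpack_uint8_into_bool unpacks LSB-first, so the round trip reverses the bit order within each byte:

bits = torch.randint(2, size=(1024, 8), device="cuda").to(torch.bool)
packed = from_binary(bits, num_bits=8)
unpacked = unpack_uint8_into_bool(packed.view(-1), padding_length=0)
# MSB-first packing vs. LSB-first unpacking: bits come back byte-reversed
assert torch.equal(unpacked.view(1024, 8), bits.flip(-1))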
Amazing, thanks a ton! Will give this a try soon.