iree-org/iree-turbine

Adding 0d tensors to parameter archive is unreliable


In general, getting a raw buffer representing a tensor is done like this (to add to a parameter archive):

flat_array = tensor.detach().flatten().contiguous().cpu().view(torch.uint8)
host_array = flat_array.numpy()
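A minimal sketch of this pattern on a normal (non-0d) tensor (the example tensor and the size check are my own, not from the issue): the uint8 view reinterprets the storage byte-for-byte, so the resulting numpy array has one entry per byte of the original tensor.

```python
import torch

# Assumed example input, chosen for illustration.
tensor = torch.arange(6, dtype=torch.float32)

# Reinterpret the flattened storage as raw bytes, then expose it to numpy.
flat_array = tensor.detach().flatten().contiguous().cpu().view(torch.uint8)
host_array = flat_array.numpy()

# One uint8 entry per byte: 6 elements * 4 bytes each = 24 bytes.
assert host_array.nbytes == tensor.numel() * tensor.element_size()
```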

Scalar 0d tensors, however, appear to take a different internal path in torch. Taking a view like the one above does not always point to a persistent copy of the backing data. These tensors also behave a little differently in other ways, suggesting some special-case optimization for them (e.g. when viewed as uint8, their contents can be accessed directly by casting to bytes, without first going through numpy()).

Through some trial and error, I have found that removing the view and directly capturing a copy as bytes is reliable. I am not satisfied without having run down the root cause, but this approach seems robust.

Example:

        if len(tensor.shape) == 0:
            # 0d scalar: skip the uint8 view and copy the contents out as bytes.
            flat_array = tensor.detach().cpu().numpy()
            host_array = bytes(flat_array)
        else:
            # General case: reinterpret the flattened storage as raw bytes.
            flat_array = tensor.detach().flatten().contiguous().cpu().view(torch.uint8)
            host_array = flat_array.numpy()