hidet-org/hidet

[Bug] Outputs of torch.flatten abnormally mismatch on GPU when adding an intermediate result as output

Azyka opened this issue · 1 comment

Describe the bug
When the intermediate result (abs_1) that feeds the original output is added as an extra output of this model:

class Model0():
    # Baseline model: returns only the flattened result.
    def forward(self, *args):
        abs_1 = torch.abs(args[0])  # element-wise absolute value of the first positional input
        flatten = abs_1.flatten()  # collapse the intermediate result to one dimension
        return (flatten)  # the parentheses do not create a tuple; the tensor itself is returned

New:

class Model1():
    # Variant model: same computation as Model0, but the intermediate
    # result abs_1 is returned alongside the flattened tensor.
    def forward(self, *args):
        abs_1 = torch.abs(args[0])  # element-wise absolute value of the first positional input
        flatten = abs_1.flatten()  # collapse the intermediate result to one dimension
        return (abs_1, flatten)  # 2-tuple: intermediate first, flattened result second

The output of torch.flatten is expected to be the same for the same input. However, it differs between the two models.
The mismatch is seen only on CUDA.

To Reproduce
Repro script:

import numpy as np
import pickle
from numpy import testing
import torch

DEVICE='cuda'

# Baseline model: abs followed by flatten, returning only the flattened tensor.
class Model0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        abs_1 = torch.abs(args[0])  # element-wise absolute value of the first positional input
        flatten = abs_1.flatten()  # collapse the intermediate result to one dimension
        return (flatten)  # the parentheses do not create a tuple; a single tensor is returned

model_0 = Model0()
# Name for the single returned tensor; used later as the key of the output dict.
output_names_0 = ['v0_0']

# Variant model: same computation as Model0, but the intermediate abs result
# is also returned as an extra output.
class Model1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        abs_1 = torch.abs(args[0])  # element-wise absolute value of the first positional input
        flatten = abs_1.flatten()  # collapse the intermediate result to one dimension
        return (abs_1, flatten)  # 2-tuple: intermediate first, flattened result second

model_1 = Model1()
# Names in the same order as the returned tuple: 'v5_0' -> abs_1, 'v0_0' -> flatten.
output_names_1 = ['v5_0', 'v0_0']

# One-dimensional int8 input shared by both models. Every value is positive,
# so abs() leaves the values unchanged and flatten() of a 1-D array keeps the
# same element order.
data = np.array([6, 3, 4, 5, 4, 7, 5, 5, 3, 3, 4, 4, 6, 3, 7, 5, 4, 3, 5, 6, 3, 7,
       7, 5, 6, 6, 5, 4, 5, 6, 5, 3, 3, 5, 4, 5, 3, 7, 6, 6, 6, 4, 5, 3,
       7, 4, 4, 6, 5, 3, 7], dtype=np.int8)
input_data_0 = [data]

# Compile Model0 with the hidet backend and run it on the input moved to DEVICE.
optmodel_0 = torch.compile(model_0, fullgraph=True, backend='hidet', mode=None)
model_out_0 = optmodel_0(*[torch.from_numpy(v).to(DEVICE) for v in input_data_0])
# Normalize the result to a list of detached tensors, whether the model
# returned a tuple or a single tensor.
model_out_0 = [v.to(DEVICE).detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.to(DEVICE).detach()]
# Copy each tensor to the CPU and convert to NumPy, resolving conjugation first when is_conj() is set.
model_out_0 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_0]
# Map output name -> NumPy array.
output_0 = dict(zip(output_names_0, model_out_0))

# Model1 receives the same array as Model0.
input_data_1 = [data]

# Compile Model1 with the hidet backend and run it on the input moved to DEVICE.
optmodel_1 = torch.compile(model_1, fullgraph=True, backend='hidet', mode=None)
model_out_1 = optmodel_1(*[torch.from_numpy(v).to(DEVICE) for v in input_data_1])
# Normalize the result to a list of detached tensors, whether the model
# returned a tuple or a single tensor.
model_out_1 = [v.to(DEVICE).detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.to(DEVICE).detach()]
# Copy each tensor to the CPU and convert to NumPy, resolving conjugation first when is_conj() is set.
model_out_1 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_1]
# Map output name -> NumPy array.
output_1 = dict(zip(output_names_1, model_out_1))
# Pairs of (Model0 output name, Model1 output name) to compare; only the
# flattened output 'v0_0' is checked.
output_name_dict = {'v0_0': 'v0_0'}

# Compare the hidet-compiled outputs of the two models for each name pair in
# output_name_dict and report whether the assertion fires.
print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        # rtol=1 is passed explicitly; atol is left at its default.
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], rtol=1, err_msg=f'at {tensor_name_0}, {tensor_name_1}')
    print("hidet does not trigger assertion")
except AssertionError as e:
    print("hidet triggers assertion")
    print(e)
print('=========================')

# Reference run: call both models directly (no torch.compile) with the same
# inputs, and rebuild output_0 / output_1 using the same post-processing.
model_out_0 = model_0(*[torch.from_numpy(v).to(DEVICE) for v in input_data_0])
model_out_0 = [v.to(DEVICE).detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.to(DEVICE).detach()]
model_out_0 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_0]
output_0 = dict(zip(output_names_0, model_out_0))

model_out_1 = model_1(*[torch.from_numpy(v).to(DEVICE) for v in input_data_1])
model_out_1 = [v.to(DEVICE).detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.to(DEVICE).detach()]
model_out_1 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_1]
output_1 = dict(zip(output_names_1, model_out_1))

# Compare the outputs currently held in output_0 / output_1 (the uncompiled
# runs at this point in the script) for each name pair in output_name_dict.
print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        # rtol=1 is passed explicitly; atol is left at its default.
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], rtol=1, err_msg=f'at {tensor_name_0}, {tensor_name_1}')
    print("torch_eager does not trigger assertion")
except AssertionError as e:
    print("torch_eager triggers assertion")
    print(e)
print('=========================')

Output:

=========================
hidet triggers assertion

Not equal to tolerance rtol=1, atol=0
at v0_0, v0_0
Mismatched elements: 51 / 51 (100%)
Max absolute difference: 7
Max relative difference: inf
 x: array([6, 3, 4, 5, 4, 7, 5, 5, 3, 3, 4, 4, 6, 3, 7, 5, 4, 3, 5, 6, 3, 7,
       7, 5, 6, 6, 5, 4, 5, 6, 5, 3, 3, 5, 4, 5, 3, 7, 6, 6, 6, 4, 5, 3,
       7, 4, 4, 6, 5, 3, 7], dtype=int8)
 y: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0], dtype=int8)
=========================
=========================
torch_eager does not trigger assertion
=========================

Expected behavior
The output of torch.flatten is expected to be the same for the same input.

Environment

  • OS: Ubuntu 22.04.3 LTS (x86_64)
  • GPU: RTX 1660
  • NVIDIA GPU Driver: 525.147.05
  • Hidet Version: 0.3.0
  • PyTorch Version: 2.1.0+cu118

Fixed in #384. Thanks for your efforts on it, @Aalanli and @yaoyaoding!