Tractables/pyjuice

Triton Illegal Memory Access

Issue closed · 3 comments

I am running the code from example 01_train_pc.py to learn an HCLT on the MNIST dataset. The only thing I have added is a marginal query after the parameter learning. All the code works until the marginal query, which throws an error.

query:
data = torch.rand((28,28)).long().to(device)
lls = juice.queries.marginal(pc, data)

error:
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

The sampling query works, but marginal and conditional do not work.

Hi mjojic,

The PC in that script assumes categorical input variables with values in the range 0-255, so data must have dtype torch.long and contain integers in that range. Its shape also needs to be [batch size, # variables].
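As a rough illustration of the shape and dtype requirement, here is a minimal sketch that builds a random query batch for a PC over 28x28 MNIST pixels. The constants and the check_query_batch helper are hypothetical names introduced for this example and are not part of pyjuice; the marginal call is left as a comment because it needs the trained pc and device from the script.

```python
import torch

NUM_VARS = 28 * 28   # assumption: one variable per MNIST pixel
NUM_CATS = 256       # categorical values 0-255
BATCH_SIZE = 16

# torch.rand(...).long() truncates values in [0, 1) to 0, so draw integers directly.
data = torch.randint(0, NUM_CATS, (BATCH_SIZE, NUM_VARS), dtype=torch.long)


def check_query_batch(x, num_vars=NUM_VARS, num_cats=NUM_CATS):
    """Hypothetical helper: validate a batch before passing it to a query."""
    if x.dim() != 2 or x.size(1) != num_vars:
        raise ValueError(
            f"expected shape [batch size, {num_vars}], got {tuple(x.shape)}"
        )
    if x.dtype != torch.long:
        raise ValueError(f"expected dtype torch.long, got {x.dtype}")
    if x.numel() == 0:
        raise ValueError("batch is empty")
    if x.min().item() < 0 or x.max().item() >= num_cats:
        raise ValueError(f"values must lie in 0-{num_cats - 1}")
    return x


check_query_batch(data)

# A single 28x28 image has to be flattened into one row of 784 variables.
image = torch.randint(0, NUM_CATS, (28, 28), dtype=torch.long)
single = check_query_batch(image.reshape(1, -1))

# With the trained pc and device from 01_train_pc.py:
# lls = juice.queries.marginal(pc, data.to(device))
```

Running a check like this on the CPU before the query turns a bad shape into a readable Python error.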

You can refer to this tutorial for more information: https://tractables.github.io/pyjuice/getting-started/tutorials/04_query_pc.html#sphx-glr-getting-started-tutorials-04-query-pc-py

Best,
Anji

Yes, the illegal memory access turned out to be a shape issue.

With the correct shape, I ran into an error about TILE_SIZE < 4 (coming from the block-sparse kernel in sum_layer.py). After setting block_size=1 in the HCLT definition, I was able to run a marginal query without the tile size error.

You may increase the batch size to e.g. 16 to avoid the TILE_SIZE error for now. I will fix it in the near future. Thanks!
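
If only a few samples need to be queried, one way to reach a batch size of 16 is to pad the batch and discard the extra results afterwards. The sketch below is only an illustration of that idea: pad_batch is a hypothetical helper, not a pyjuice function, and the commented-out slicing assumes the query result is indexed by batch along its first dimension.

```python
import torch


def pad_batch(x, min_batch=16):
    """Hypothetical helper: repeat the last row until x has min_batch rows.

    Returns the padded tensor and the original number of rows.
    """
    n = x.size(0)
    if n == 0:
        raise ValueError("cannot pad an empty batch")
    if n >= min_batch:
        return x, n
    filler = x[-1:].expand(min_batch - n, -1)  # copies of the last sample
    return torch.cat([x, filler], dim=0), n


single = torch.randint(0, 256, (1, 28 * 28), dtype=torch.long)
padded, n = pad_batch(single)

# With the trained pc and device from 01_train_pc.py, keep only the real rows
# (assumes the result is batch-indexed along dim 0):
# lls = juice.queries.marginal(pc, padded.to(device))[:n]
```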