sbrunk/storch

SIGSEGV fatal error when re-assigning a tensor that has been previously split


Hello!

I stumbled upon a fatal error while using torch.split followed by reassigning one of the resulting tensors. I'm not sure how to even start debugging this, but I'm documenting it here in case someone knows how to investigate it further.

Here is a way to replicate the error.

@ val data = torch.arange(0L, 1_000_000L)
data: Tensor[Int64] = tensor dtype=int64, shape=[1000000], device=CPU
[0, 1, 2, ..., 999997, 999998, 999999]

@ val Seq(a,b) = torch.split(data, 600_000)

@ a
res2: Tensor[Int64] = tensor dtype=int64, shape=[600000], device=CPU
[0, 1, 2, ..., 599997, 599998, 599999]

@ b
res3: Tensor[Int64] = tensor dtype=int64, shape=[400000], device=CPU
[600000, 600001, 600002, ..., 999997, 999998, 999999]

@ val x = a
#
# A fatal error has been detected by the Java Runtime Environment:
#
#  SIGSEGV (0xb) at pc=0x00000001b984535a, pid=19939, tid=9731
#
# JRE version: OpenJDK Runtime Environment Zulu19.30+11-CA (19.0.1+10) (build 19.0.1+10)
# Java VM: OpenJDK 64-Bit Server VM Zulu19.30+11-CA (19.0.1+10, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, bsd-amd64)
# Problematic frame:
# C  [libjnitorch.dylib+0x4ff35a]  Java_org_bytedeco_pytorch_TensorBase_sizes+0x5a
#
# No core dump will be written. Core dumps have been disabled. To enable core dumping, try "ulimit -c unlimited" before starting Java again
#
# An error report file with more information is saved as:
# /experiments/hs_err_pid19939.log
#
# If you would like to submit a bug report, please visit:
#   http://www.azul.com/support/
# The crash happened outside the Java Virtual Machine in native code.
# See problematic frame for where to report the bug.
#

After trying other variations: any operation on either of the tensors returned by torch.split causes this crash, even a + 1.

Quick update:

Adding .clone() seems to fix the issue. I wonder which other operations might require the same fix. I'll open a PR soon.

I also wonder what the implications of calling .clone() are in terms of memory usage or other performance costs.
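As a rough back-of-the-envelope sketch of the memory side, assuming each .clone() allocates a new, densely packed buffer of the same dtype for its chunk, the extra allocation for the example above can be estimated as follows. The helper `cloneCostBytes` is a hypothetical name, not part of storch:

```scala
// Hypothetical helper: estimates the extra bytes allocated when every chunk
// returned by split is cloned. Assumes each clone copies its chunk into a new,
// densely packed buffer with the same element type.
object CloneCost {
  val Int64Bytes = 8L // size of one int64 element in bytes

  def cloneCostBytes(chunkSizes: Seq[Long], bytesPerElement: Long): Long =
    chunkSizes.map(_ * bytesPerElement).sum

  def main(args: Array[String]): Unit = {
    // Chunk sizes from the example: split(arange(0, 1_000_000), 600_000)
    val chunks = Seq(600_000L, 400_000L)
    val bytes = cloneCostBytes(chunks, Int64Bytes)
    println(s"extra bytes: $bytes") // extra bytes: 8000000
  }
}
```

Under that assumption, cloning all chunks costs one extra copy of the data (about 8 MB here) for as long as the original tensor is also alive, plus one pass over the data to copy it.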

--- a/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
+++ b/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
@@ -1014,7 +1014,7 @@ private[torch] trait IndexingSlicingJoiningOps {
         case i: Int      => torchNative.split(input.native, i.toLong, dim.toLong)
         case s: Seq[Int] => torchNative.split(input.native, s.map(_.toLong).toArray, dim.toLong)
       }
-    (0L until result.size()).map(i => Tensor(result.get(i)))
+    (0L until result.size()).map(i => Tensor(result.get(i)).clone())
   }

   /** Returns a tensor with all specified dimensions of `input` of size 1 removed.
sbrunk commented

It might have something to do with the fact that split returns views.

https://pytorch.org/docs/stable/generated/torch.split.html:

Splits the tensor into chunks. Each chunk is a view of the original tensor.

It's just a guess for now, but it would explain why clone() makes a difference.
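To illustrate what "view" means in that quote, here is a toy model in plain Scala (not the storch or LibTorch API; `View`, `Views.split` and `cloneView` are hypothetical names). Each chunk is just an offset and a length into a shared buffer, so changes to the buffer are visible through the view, while a clone owns an independent copy:

```scala
// Toy model of tensor views; not the storch/LibTorch implementation.
// A View does not own its data: it points into a shared storage array.
final case class View(storage: Array[Long], offset: Int, length: Int) {
  def apply(i: Int): Long = {
    require(i >= 0 && i < length, s"index $i out of bounds for length $length")
    storage(offset + i)
  }

  // Like clone(): copy this view's elements into a new, independently owned array.
  def cloneView(): View =
    View(storage.slice(offset, offset + length), 0, length)
}

object Views {
  // Like split(input, size) on a 1-D tensor: consecutive chunks of `size`
  // elements (the last one may be shorter), all sharing the same storage.
  def split(storage: Array[Long], size: Int): Seq[View] =
    (0 until storage.length by size).map { start =>
      View(storage, start, math.min(size, storage.length - start))
    }

  def main(args: Array[String]): Unit = {
    val data = Array.tabulate(10)(_.toLong) // like arange(0, 10)
    val chunks = Views.split(data, 6)
    val a = chunks(0)
    val b = chunks(1)
    val aCopy = a.cloneView()
    data(0) = 42L     // mutate the shared storage
    println(a(0))     // 42: the view sees the change
    println(aCopy(0)) // 0: the clone has its own copy
    println(b(0))     // 6
  }
}
```

In this model a clone no longer depends on the original storage at all, which is consistent with the guess that being a view is what makes the difference.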

I ran into something similar while implementing tensor printing, which needs to convert tensor values to buffers: it crashed on non-contiguous tensors, because the memory layout of a view can be non-contiguous. That is why the printing code makes the tensor contiguous before creating the buffer:

val buf = tensor.native.contiguous.createBuffer[B]

In this case the views should be contiguous (each chunk is a consecutive range of a 1-D tensor), so it's not exactly the same issue, but it could still be related to them being views.
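A similar toy sketch (plain Scala with hypothetical names, not the storch API) shows why the split chunks here should be contiguous while other views need not be. A 1-D view is described by an offset, a length and a stride; it is contiguous when its stride is 1 (or it has at most one element). Splitting along dim 0 only changes offset and length, whereas taking every second element gives stride 2, and `contiguous` copies only in that case:

```scala
// Toy model of a strided 1-D view; hypothetical, not storch/LibTorch code.
final case class Strided(storage: Array[Long], offset: Int, length: Int, stride: Int) {
  def apply(i: Int): Long = storage(offset + i * stride)

  // A 1-D layout is contiguous when consecutive elements are adjacent in
  // memory (stride 1); a view with at most one element is trivially contiguous.
  def isContiguous: Boolean = length <= 1 || stride == 1

  // Like contiguous(): return this view unchanged if it is already contiguous,
  // otherwise copy its elements into a new densely packed array.
  def contiguous: Strided =
    if (isContiguous) this
    else Strided(Array.tabulate(length)(i => apply(i)), 0, length, 1)
}

object Contiguity {
  def main(args: Array[String]): Unit = {
    val data = Array.tabulate(10)(_.toLong)
    val chunk = Strided(data, 6, 4, 1)      // second chunk of a split at 6
    val everyOther = Strided(data, 0, 5, 2) // elements 0, 2, 4, 6, 8
    println(chunk.isContiguous)      // true
    println(everyOther.isContiguous) // false
    val packed = everyOther.contiguous
    println((0 until packed.length).map(packed(_)).mkString(", ")) // 0, 2, 4, 6, 8
  }
}
```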

Interestingly, your example works on my machine (I tried it in Ammonite too):

object Split extends App {
  val data = torch.arange(0L, 1_000_000L)
  val Seq(a, b) = torch.split(data, 600_000)
  println(a)
  println(b)
  val x = a
  println(x)
}
sbrunk commented

Perhaps we also need to understand why the Python implementation calls split_with_sizes in certain cases (when split_size is a sequence of sizes rather than a single int). We might need to do something similar.

https://github.com/pytorch/pytorch/blob/9adfaf880784ec0cf5f085fc3f282cf53650050f/torch/_tensor.py#L770

    def split(self, split_size, dim=0):
        r"""See :func:`torch.split`"""
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.split, (self,), self, split_size, dim=dim
            )
        if isinstance(split_size, Tensor):
            try:
                split_size = int(split_size)
            except ValueError:
                pass


        if isinstance(split_size, (int, torch.SymInt)):
            return torch._VF.split(self, split_size, dim)  # type: ignore[attr-defined]
        else:
            return torch._VF.split_with_sizes(self, split_size, dim)
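The Python code above dispatches on the type of split_size: a single int goes to split, anything else to split_with_sizes. The storch diff earlier also matches on Int vs Seq[Int], but calls torchNative.split in both branches. The following sketch only models which chunk lengths each variant produces along one dimension, following the documented torch.split semantics (equal chunks with a possibly smaller last one, or explicit sizes that must add up to the dimension's length). `SplitSpec`, `BySize`, `BySizes` and `chunkSizes` are hypothetical names; this is not storch or LibTorch code:

```scala
// Hypothetical sketch of how chunk lengths follow from the two ways of
// specifying a split; it models the semantics only, not storch or LibTorch.
sealed trait SplitSpec
final case class BySize(size: Int) extends SplitSpec        // like split(t, 600_000)
final case class BySizes(sizes: Seq[Int]) extends SplitSpec // like split_with_sizes(t, [600_000, 400_000])

object SplitSizes {
  def chunkSizes(dimLength: Int, spec: SplitSpec): Seq[Int] = {
    // Simplification: empty dimensions and zero split sizes are not modelled.
    require(dimLength > 0, "dimension length must be positive")
    spec match {
      case BySize(size) =>
        require(size > 0, "split size must be positive")
        // Equal chunks of `size`; the last chunk holds whatever remains.
        (0 until dimLength by size).map(start => math.min(size, dimLength - start))
      case BySizes(sizes) =>
        require(sizes.forall(_ >= 0), "sizes must be non-negative")
        require(sizes.sum == dimLength, s"sizes must sum to $dimLength, got ${sizes.sum}")
        sizes
    }
  }

  def main(args: Array[String]): Unit = {
    println(chunkSizes(1_000_000, BySize(600_000)).mkString(", "))               // 600000, 400000
    println(chunkSizes(1_000_000, BySizes(Seq(600_000, 400_000))).mkString(", ")) // 600000, 400000
  }
}
```

Both calls describe the same split of the example tensor; the difference is only whether the sizes are derived from a single chunk size or given explicitly.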