SIGSEGV fatal error when re-assigning a tensor that has been previously split
Opened this issue · 3 comments
Hello!
I stumbled upon a fatal error while using torch.split followed by reassignment of a tensor. I'm not sure how to even start debugging this, but I'm documenting it here in case someone knows how to investigate further.
Here is a way to replicate the error.
@ val data = torch.arange(0L, 1_000_000L)
data: Tensor[Int64] = tensor dtype=int64, shape=[1000000], device=CPU
[0, 1, 2, ..., 999997, 999998, 999999]
@ val Seq(a,b) = torch.split(data, 600_000)
@ a
res2: Tensor[Int64] = tensor dtype=int64, shape=[600000], device=CPU
[0, 1, 2, ..., 599997, 599998, 599999]
@ b
res3: Tensor[Int64] = tensor dtype=int64, shape=[400000], device=CPU
[600000, 600001, 600002, ..., 999997, 999998, 999999]
@ val x = a
#
# A fatal error has been detected by the Java Runtime Environment:
#
# SIGSEGV (0xb) at pc=0x00000001b984535a, pid=19939, tid=9731
#
# JRE version: OpenJDK Runtime Environment Zulu19.30+11-CA (19.0.1+10) (build 19.0.1+10)
# Java VM: OpenJDK 64-Bit Server VM Zulu19.30+11-CA (19.0.1+10, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, bsd-amd64)
# Problematic frame:
# C [libjnitorch.dylib+0x4ff35a] Java_org_bytedeco_pytorch_TensorBase_sizes+0x5a
#
# No core dump will be written. Core dumps have been disabled. To enable core dumping, try "ulimit -c unlimited" before starting Java again
#
# An error report file with more information is saved as:
# /experiments/hs_err_pid19939.log
#
# If you would like to submit a bug report, please visit:
# http://www.azul.com/support/
# The crash happened outside the Java Virtual Machine in native code.
# See problematic frame for where to report the bug.
#
Trying other variations: any operation performed on any of the portions returned by torch.split causes this crash, even a simple + 1.
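For example, continuing the same session as above (a sketch of the variant just mentioned):
@ val y = a + 1  // also aborts the JVM with a SIGSEGV as soon as the chunk is touched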
Quick update:
Adding .clone() seems to fix the issue; I wonder what other operations might require the same fix. Will do a PR soon.
I also wonder what the implications of calling .clone() are in terms of memory usage or any other performance cost.
--- a/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
+++ b/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
@@ -1014,7 +1014,7 @@ private[torch] trait IndexingSlicingJoiningOps {
case i: Int => torchNative.split(input.native, i.toLong, dim.toLong)
case s: Seq[Int] => torchNative.split(input.native, s.map(_.toLong).toArray, dim.toLong)
}
- (0L until result.size()).map(i => Tensor(result.get(i)))
+ (0L until result.size()).map(i => Tensor(result.get(i)).clone())
}
/** Returns a tensor with all specified dimensions of `input` of size 1 removed.
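For reference, until that lands the same workaround can be applied at the call site. A minimal sketch, based on the repro above and on the Tensor.clone() used in the diff:

// Sketch only: clone each chunk so it owns its own storage instead of
// staying a view over the native result of torch.split.
val data = torch.arange(0L, 1_000_000L)
val Seq(a, b) = torch.split(data, 600_000).map(_.clone())
val x = a  // no crash once the chunks are clones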
It might have something to do with the fact that split returns a view. From https://pytorch.org/docs/stable/generated/torch.split.html:
Splits the tensor into chunks. Each chunk is a view of the original tensor.
It's just a guess for now, but it would explain why clone() makes a difference.
I ran into something related while implementing tensor printing, which needs to convert tensor values to buffers and crashed on non-contiguous tensors, since the memory layout of a view can be non-contiguous. See storch/core/src/main/scala/torch/Tensor.scala, line 558 at 19abf3e.
In this case the view should be contiguous, so it's not exactly the same issue, but it could still be related to split returning a view.
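For the printing path, the guard is roughly the following. This is a sketch only; it assumes the JavaCPP bindings expose is_contiguous()/contiguous() mirroring at::Tensor, which I have not verified here:

// Sketch: read values through a contiguous tensor before converting them to
// a buffer; contiguous() only copies when the layout is not already contiguous.
private def readableNative(t: Tensor[?]): org.bytedeco.pytorch.Tensor = {
  if (t.native.is_contiguous()) t.native
  else t.native.contiguous()
}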
Interestingly, your example works on my machine (I tried in ammonite too):
object Split extends App {
val data = torch.arange(0L, 1_000_000L)
val Seq(a, b) = torch.split(data, 600_000)
println(a)
println(b)
val x = a
println(x)
}
Perhaps we also need to understand why the Python implementation calls split_with_sizes in certain cases; we might need to do something similar:
def split(self, split_size, dim=0):
    r"""See :func:`torch.split`"""
    if has_torch_function_unary(self):
        return handle_torch_function(
            Tensor.split, (self,), self, split_size, dim=dim
        )
    if isinstance(split_size, Tensor):
        try:
            split_size = int(split_size)
        except ValueError:
            pass
    if isinstance(split_size, (int, torch.SymInt)):
        return torch._VF.split(self, split_size, dim)  # type: ignore[attr-defined]
    else:
        return torch._VF.split_with_sizes(self, split_size, dim)
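A rough sketch of what mirroring that dispatch could look like on the storch side (the match scrutinee name is illustrative, and torchNative.split_with_sizes is assumed to be exposed by the bindings with a signature analogous to split):

// Sketch only: dispatch like the Python implementation. An Int split size keeps
// using torchNative.split; an explicit list of sizes would go through the
// (assumed) torchNative.split_with_sizes binding.
val result = splitSizeOrSections match { // illustrative parameter name
  case i: Int      => torchNative.split(input.native, i.toLong, dim.toLong)
  case s: Seq[Int] => torchNative.split_with_sizes(input.native, s.map(_.toLong).toArray, dim.toLong)
}
(0L until result.size()).map(i => Tensor(result.get(i)).clone())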