Bug in tref in the flat-tensors implementation
DarshalShetty opened this issue
DarshalShetty commented
Here is an example expression:
(tref (tref (tref (tensor
                    (tensor
                      (tensor (tensor -1) (tensor -2) (tensor -3))
                      (tensor (tensor -4) (tensor 1.0) (tensor -5))
                      (tensor (tensor -6) (tensor -7) (tensor -8))))
                  0)
            1)
      1)
Expected result:
(tensor 1.0)
Actual result:
(tensor -2)
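The fix appears to be simple. In the higher-rank (else) branch, the current definition of tref ignores the offset of the input tensor when computing the offset of the output tensor; only the rank-1 branch adds (flat-offset t). The bug is invisible as long as every intermediate tensor starts at offset 0, which is why it only shows up after nested trefs.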
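Tracing the offsets by hand for the example makes this concrete. Assuming row-major strides computed from the shape (an assumption about how flat builds its strides), the outer tensor has shape (1 3 3 1), strides (9 3 1 1), offset 0, and a store holding -1, -2, -3, -4, 1.0, -5, -6, -7, -8 in order. Writing o_k for the offset after the k-th tref:

```latex
\begin{aligned}
\text{correct:}\quad & o_1 = 0 + 0\cdot 9 = 0,\; o_2 = o_1 + 1\cdot 3 = 3,\; o_3 = o_2 + 1\cdot 1 = 4
  && \Rightarrow \text{store}[4] = 1.0 \\
\text{buggy:}\quad & o_1 = 0\cdot 9 = 0,\; o_2 = 1\cdot 3 = 3,\; o_3 = 1\cdot 1 = 1
  && \Rightarrow \text{store}[1] = -2
\end{aligned}
```

The second step happens to agree only because o_1 = 0; the third step drops o_2 = 3 and lands on the wrong element.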
Current definition:
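(define tref
  (λ (t i)
    (cond
      ((= 1 (flat-rank t))
       ;; rank 1: index directly into the store, relative to t's offset
       (vref (flat-store t) (+ (flat-offset t) i)))
      (else
       ;; higher rank: build the i-th sub-tensor sharing the same store
       (flat (cdr (flat-shape t))
             (flat-store t)
             ;; BUG: (flat-offset t) is not added here
             (* i (car (flat-strides t))))))))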
Suggested fix:
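(define tref
  (λ (t i)
    (cond
      ((= 1 (flat-rank t))
       ;; rank 1: index directly into the store, relative to t's offset
       (vref (flat-store t) (+ (flat-offset t) i)))
      (else
       ;; higher rank: the sub-tensor starts at t's own offset
       ;; plus i times the stride of the leading axis
       (flat (cdr (flat-shape t))
             (flat-store t)
             (+ (flat-offset t) (* i (car (flat-strides t)))))))))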
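For reproducing the difference outside the library, here is a minimal self-contained sketch. It assumes a hypothetical struct ft (shape, store, offset) with row-major strides recomputed from the shape; this is not malt's actual flat representation, and ft, strides, tref-buggy, tref-fixed and ref-example are illustrative names only. The two tref variants differ only in whether the input offset is added in the higher-rank branch.

```racket
#lang racket/base
;; Hypothetical minimal flat tensor (NOT malt's real `flat` struct):
;; a shape list, a flat vector store, and an offset into that store.
(struct ft (shape store offset))

;; Row-major strides: each axis's stride is the product of the sizes
;; of all later axes, e.g. (1 3 3 1) -> (9 3 1 1).
(define (strides shape)
  (if (null? shape)
      '()
      (cons (apply * (cdr shape)) (strides (cdr shape)))))

;; Mirrors the current definition: the input offset is dropped
;; in the higher-rank branch.
(define (tref-buggy t i)
  (if (= 1 (length (ft-shape t)))
      (vector-ref (ft-store t) (+ (ft-offset t) i))
      (ft (cdr (ft-shape t))
          (ft-store t)
          (* i (car (strides (ft-shape t)))))))

;; Mirrors the suggested fix: the input offset is carried forward.
(define (tref-fixed t i)
  (if (= 1 (length (ft-shape t)))
      (vector-ref (ft-store t) (+ (ft-offset t) i))
      (ft (cdr (ft-shape t))
          (ft-store t)
          (+ (ft-offset t) (* i (car (strides (ft-shape t))))))))

;; The tensor from the issue: shape (1 3 3 1), offset 0.
(define example
  (ft '(1 3 3 1) (vector -1 -2 -3 -4 1.0 -5 -6 -7 -8) 0))

;; Apply the issue's nested trefs with the given variant and return
;; (list final-offset element-at-that-offset).
(define (ref-example ref)
  (let ([r (ref (ref (ref example 0) 1) 1)])
    (list (ft-offset r) (ref r 0))))

(ref-example tref-buggy) ; → '(1 -2)
(ref-example tref-fixed) ; → '(4 1.0)
```

The buggy variant ends at offset 1 and reads -2, matching the actual result reported above; the fixed variant ends at offset 4 and reads 1.0, matching the expected result.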
DarshalShetty commented
I will submit a PR shortly.