themetaschemer/malt

Bug in tref for flat-tensors implementation

DarshalShetty opened this issue · 1 comments

Here is an example expression:

(tref (tref (tref (tensor
                   (tensor
                    (tensor (tensor -1) (tensor -2) (tensor -3))
                    (tensor (tensor -4) (tensor 1.0) (tensor -5))
                    (tensor (tensor -6) (tensor -7) (tensor -8)))) 0) 1) 1)

Expected result:

(tensor 1.0)

Actual result:

(tensor -2)

The fix seems simple. The current definition of tref ignores the offset of the input tensor when computing the offset of the output tensor, so a tref applied to a sub-tensor that already has a nonzero offset indexes from the start of the store instead.
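To see where the -2 comes from, the offsets can be traced step by step. The sketch below is only arithmetic, and it assumes row-major strides derived from the shape (so shape (1 3 3 1) gives strides (9 3 1 1)); `strides-of` and `offset-after` are hypothetical helpers written for this illustration, not malt functions.

```racket
#lang racket/base
;; Hypothetical helper: row-major strides for a shape (assumed layout).
(define (strides-of shape)
  (if (null? shape)
      '()
      (cons (apply * (cdr shape)) (strides-of (cdr shape)))))

;; Offset reached after a chain of tref indices. Each step strips the
;; leading dimension. When keep-offset? is #f the incoming offset is
;; dropped, which mirrors the current definition.
(define (offset-after shape indices keep-offset?)
  (let loop ((shape shape) (indices indices) (offset 0))
    (if (null? indices)
        offset
        (loop (cdr shape)
              (cdr indices)
              (+ (if keep-offset? offset 0)
                 (* (car indices) (car (strides-of shape))))))))

(offset-after '(1 3 3 1) '(0 1 1) #f) ; 1, the slot holding -2
(offset-after '(1 3 3 1) '(0 1 1) #t) ; 4, the slot holding 1.0
```

The second tref lands on offset 3 either way; the third one is where the dropped offset shows, giving 0 + 1 instead of 3 + 1.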

Current definition:

(define tref
  (λ (t i)
    (cond
     ((= 1 (flat-rank t))
      (vref (flat-store t) (+ (flat-offset t) i)))
     (else
      (flat (cdr (flat-shape t))
            (flat-store t)
            (* i (car (flat-strides t))))))))

Suggested fix:

(define tref
  (λ (t i)
    (cond
     ((= 1 (flat-rank t))
      (vref (flat-store t) (+ (flat-offset t) i)))
     (else
      (flat (cdr (flat-shape t))
            (flat-store t)
            (+ (flat-offset t) (* i (car (flat-strides t)))))))))
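Both definitions can be compared outside of malt with a small stand-in for the flat representation. This is a sketch under assumptions: the `flat` struct, `flat-strides`, `flat-rank` and `vref` below are simplified stand-ins rather than malt's actual definitions, strides are assumed to be row-major and recomputed from the shape, and `make-tref`, `tref-current`, `tref-fixed`, `example` and `ref3` are names invented for this example.

```racket
#lang racket/base
;; Stand-in for the flat-tensor record: NOT malt's actual definition.
(struct flat (shape store offset) #:transparent)

;; Assumed row-major strides, recomputed from the shape.
(define (flat-strides t)
  (let loop ((shape (flat-shape t)))
    (if (null? shape)
        '()
        (cons (apply * (cdr shape)) (loop (cdr shape))))))

(define (flat-rank t) (length (flat-shape t)))
(define vref vector-ref)

;; make-tref builds a tref from a rule for the sub-tensor's offset,
;; so the two definitions from the issue differ in exactly one place.
(define (make-tref sub-offset)
  (λ (t i)
    (cond
      ((= 1 (flat-rank t))
       (vref (flat-store t) (+ (flat-offset t) i)))
      (else
       (flat (cdr (flat-shape t))
             (flat-store t)
             (sub-offset t i))))))

(define tref-current
  (make-tref (λ (t i) (* i (car (flat-strides t))))))
(define tref-fixed
  (make-tref (λ (t i) (+ (flat-offset t) (* i (car (flat-strides t)))))))

;; The tensor from the issue: shape (1 3 3 1) over a nine-element store.
(define example
  (flat '(1 3 3 1) (vector -1 -2 -3 -4 1.0 -5 -6 -7 -8) 0))

;; Apply the same chain of indices as the issue: 0, then 1, then 1.
(define (ref3 tref) (tref (tref (tref example 0) 1) 1))

(flat-offset (ref3 tref-current)) ; 1, so the result is (tensor -2)
(flat-offset (ref3 tref-fixed))   ; 4, so the result is (tensor 1.0)
```

In this model the only difference between the two results is the stored offset; the shape and the store are identical.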

I will submit a PR shortly.