themetaschemer/malt

Tensor multiplications not giving right results

dm701 opened this issue · 3 comments

Hello,

Unless I'm mistaken, the following does not return the correct result:

(* (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))

returns:

(tensor (tensor 5 10) (tensor 18 24))

whereas I would expect it to return:

(tensor (tensor 5 12) (tensor 15 24))


However, the following multiplication:

(* (tensor (tensor 3 4 5) (tensor 7 8 9)) (tensor 2 4 3))

returns:

(tensor (tensor 6 16 15) (tensor 14 32 27))

which seems to be correct.


I was unsure whether I was doing something wrong, so I checked with ChatGPT, which agreed that something seems to be incorrect somewhere.

I would be most grateful for a reply on this.

Thank you

@dm701 Sorry for the late reply.

Based on the definitions of ext2 and desc in the book:

(define ext2
  (λ (f n m)
    (λ (t u)
      (cond
        ((of-ranks? n t m u) (f t u))
        (else
         (desc (ext2 f n m) n t m u))))))

(define desc
  (λ (g n t m u)
    (cond
      ((of-rank? n t) (desc-u g t u))
      ((of-rank? m u) (desc-t g t u))
      ((= (tlen t) (tlen u)) (tmap g t u))
      ((rank> t u) (desc-t g t u))
      (else (desc-u g t u)))))
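To experiment with these definitions outside malt, they can be transcribed into a self-contained Racket sketch. This is an illustration under stated assumptions, not malt's implementation: tensors are modeled as nested Racket vectors with plain numbers as scalars, and the helpers rank, tlen, tmap, desc-t and desc-u are written by hand here to mirror the book's descriptions. Running it reproduces both results from the original report:

```racket
#lang racket
;; Sketch of the book's ext2/desc for experimenting outside malt.
;; ASSUMPTION: tensors are nested Racket vectors and scalars are numbers;
;; this is not malt's representation. Assumes non-empty, regular tensors.

;; rank: nesting depth (0 for a scalar)
(define (rank t)
  (if (number? t) 0 (add1 (rank (vector-ref t 0)))))

;; tlen: number of top-level elements
(define (tlen t) (vector-length t))

(define (of-rank? n t) (= (rank t) n))
(define (of-ranks? n t m u) (and (of-rank? n t) (of-rank? m u)))
(define (rank> t u) (> (rank t) (rank u)))

;; Descend into t and u simultaneously, pairing their elements.
(define (tmap g t u)
  (unless (= (tlen t) (tlen u))
    (error 'tmap "tlen mismatch: ~a vs ~a" (tlen t) (tlen u)))
  (for/vector ([a (in-vector t)] [b (in-vector u)]) (g a b)))

;; Descend into t only, keeping u whole.
(define (desc-t g t u)
  (for/vector ([a (in-vector t)]) (g a u)))

;; Descend into u only, keeping t whole.
(define (desc-u g t u)
  (for/vector ([b (in-vector u)]) (g t b)))

(define (desc g n t m u)
  (cond
    [(of-rank? n t) (desc-u g t u)]
    [(of-rank? m u) (desc-t g t u)]
    [(= (tlen t) (tlen u)) (tmap g t u)]   ; the clause discussed in this issue
    [(rank> t u) (desc-t g t u)]
    [else (desc-u g t u)]))

(define (ext2 f n m)
  (λ (t u)
    (cond
      [(of-ranks? n t m u) (f t u)]
      [else (desc (ext2 f n m) n t m u)])))

;; Pointwise multiplication, extended from scalars (ranks 0 and 0).
(define mul (ext2 * 0 0))

(mul #(#(1 2) #(3 4)) #(5 6))        ; => #(#(5 10) #(18 24))
(mul #(#(3 4 5) #(7 8 9)) #(2 4 3))  ; => #(#(6 16 15) #(14 32 27))
```

The first call takes the tlen clause because both tensors have two elements. The second falls through to rank> because the tlens (2 and 3) differ, so desc-t pairs each row with the whole rank-1 tensor.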

Your example is behaving exactly as the book defines it:

(* (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))
=> (tensor (tensor 5 10) (tensor 18 24))

This is because of the third clause of desc. Quoting the book:

Our next clause deals with when
t and u have the same number of elements,
which we determine using (tlen t) and (tlen u).
This means we descend into both simultaneously.

Let's see the same-as chart:

1. | (* (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))
2. | (tmap * (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))
3. | (tensor (* (tensor 1 2) 5) (* (tensor 3 4) 6))
4. | (tensor (tensor 5 10) (tensor 18 24))

To get the behavior you expected:

(* (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))
=> (tensor (tensor 5 12) (tensor 15 24))

we would need to change one line in the definition of desc,
from

((= (tlen t) (tlen u)) (tmap g t u))

to

((= (rank t) (rank u)) (tmap g t u))

which is also a reasonable definition of desc and ext2,
since every other clause in the cond of desc depends on the rank of a tensor rather than on its tlen (which is part of the tensor's shape).
That inconsistency may be why the result surprised you.
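One way to compare the two candidate clauses side by side is a sketch in which the test for descending into both tensors at once is a parameter. This is an illustration, not malt's code: make-ext2 is a hypothetical helper, tensors are modeled as nested Racket vectors (an assumption about representation), and non-empty, regular tensors are assumed. The tlen rule reproduces the reported result, while the rank rule gives the expected one and leaves the second example unchanged:

```racket
#lang racket
;; Sketch comparing the book's tlen-based third clause of desc with the
;; proposed rank-based one. ASSUMPTION: tensors are nested Racket vectors,
;; not malt's representation. make-ext2 is a hypothetical helper.

(define (rank t)
  (if (number? t) 0 (add1 (rank (vector-ref t 0)))))
(define (tlen t) (vector-length t))

;; Simultaneous descent; refuses to pair tensors of different tlen.
(define (tmap g t u)
  (unless (= (tlen t) (tlen u))
    (error 'tmap "tlen mismatch: ~a vs ~a" (tlen t) (tlen u)))
  (for/vector ([a (in-vector t)] [b (in-vector u)]) (g a b)))
(define (desc-t g t u) (for/vector ([a (in-vector t)]) (g a u)))
(define (desc-u g t u) (for/vector ([b (in-vector u)]) (g t b)))

;; same? decides when desc descends into t and u simultaneously.
(define (make-ext2 same?)
  (define (ext2 f n m)
    (λ (t u)
      (if (and (= (rank t) n) (= (rank u) m))
          (f t u)
          (desc (ext2 f n m) n t m u))))
  (define (desc g n t m u)
    (cond
      [(= (rank t) n) (desc-u g t u)]
      [(= (rank u) m) (desc-t g t u)]
      [(same? t u) (tmap g t u)]
      [(> (rank t) (rank u)) (desc-t g t u)]
      [else (desc-u g t u)]))
  ext2)

;; Book: compare tlen. Proposal: compare rank.
(define ext2/tlen (make-ext2 (λ (t u) (= (tlen t) (tlen u)))))
(define ext2/rank (make-ext2 (λ (t u) (= (rank t) (rank u)))))

(define mul/tlen (ext2/tlen * 0 0))
(define mul/rank (ext2/rank * 0 0))

(mul/tlen #(#(1 2) #(3 4)) #(5 6))        ; => #(#(5 10) #(18 24))
(mul/rank #(#(1 2) #(3 4)) #(5 6))        ; => #(#(5 12) #(15 24))
(mul/rank #(#(3 4 5) #(7 8 9)) #(2 4 3))  ; => #(#(6 16 15) #(14 32 27))
```

The change is not free of trade-offs, at least in this sketch: #(1 2) and #(1 2 3) have equal rank but different tlen, so under the rank rule they reach tmap and raise an error, whereas the tlen rule sends the same pair to desc-u and returns #(#(1 2) #(2 4) #(3 6)).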

@themetaschemer Maybe using rank is a better way to extend binary functions here?

Given the definition of plane, we have:

((plane [1 3]) (list [1 2] 3)) => 10
((plane [2 4]) (list [1 2] 3)) => 13
((plane [[1 3]]) (list [1 2] 3)) => [10]
((plane [[2 4]]) (list [1 2] 3)) => [13]

So we would naturally expect:

((plane [[1 3] [2 4]]) (list [1 2] 3)) => [10 13]

But if we use tlen instead of rank, we get:

((plane [[1 3] [2 4]]) (list [1 2] 3)) => [7 15]

Runnable code snippet (in the Racket REPL):

(require malt)

((plane (tensor 1 3)) (list (tensor 1 2) 3)) ;=> 10.0
((plane (tensor 2 4)) (list (tensor 1 2) 3)) ;=> 13.0
((plane (tensor (tensor 1 3))) (list (tensor 1 2) 3)) ;=> (tensor 10.0)
((plane (tensor (tensor 2 4))) (list (tensor 1 2) 3)) ;=> (tensor 13.0)
((plane (tensor (tensor 1 3) (tensor 2 4))) (list (tensor 1 2) 3)) ;=> (tensor 7.0 15.0)
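The plane results can also be reproduced without malt in a small sketch, assuming plane is defined as in the book: the dot product of (ref θ 0) with its argument (a sum of pointwise products), plus (ref θ 1). Tensors are modeled as nested Racket vectors, make-ext2, ext1 and make-plane are hypothetical helpers written for this illustration, and the arithmetic stays exact, so the sketch produces 10 where malt prints 10.0:

```racket
#lang racket
;; Sketch of the plane example under both rules for desc's third clause.
;; ASSUMPTIONS: tensors are nested Racket vectors (not malt's representation);
;; plane = dot product of (ref θ 0) with t, plus (ref θ 1), as in the book.
;; make-ext2, ext1 and make-plane are hypothetical helpers.

(define (rank t)
  (if (number? t) 0 (add1 (rank (vector-ref t 0)))))
(define (tlen t) (vector-length t))

(define (tmap g t u)
  (unless (= (tlen t) (tlen u))
    (error 'tmap "tlen mismatch: ~a vs ~a" (tlen t) (tlen u)))
  (for/vector ([a (in-vector t)] [b (in-vector u)]) (g a b)))
(define (desc-t g t u) (for/vector ([a (in-vector t)]) (g a u)))
(define (desc-u g t u) (for/vector ([b (in-vector u)]) (g t b)))

;; same? decides when desc descends into t and u simultaneously.
(define (make-ext2 same?)
  (define (ext2 f n m)
    (λ (t u)
      (if (and (= (rank t) n) (= (rank u) m))
          (f t u)
          (desc (ext2 f n m) n t m u))))
  (define (desc g n t m u)
    (cond
      [(= (rank t) n) (desc-u g t u)]
      [(= (rank u) m) (desc-t g t u)]
      [(same? t u) (tmap g t u)]
      [(> (rank t) (rank u)) (desc-t g t u)]
      [else (desc-u g t u)]))
  ext2)

;; Unary extension: apply f at rank n, mapping over higher ranks.
(define (ext1 f n)
  (λ (t)
    (if (= (rank t) n)
        (f t)
        (for/vector ([a (in-vector t)]) ((ext1 f n) a)))))

;; sum adds up the elements of each rank-1 tensor.
(define sum (ext1 (λ (t) (for/sum ([x (in-vector t)]) x)) 1))

(define (make-plane ext2)
  (define mul (ext2 * 0 0))
  (define add (ext2 + 0 0))
  (λ (t)
    (λ (θ)
      (add (sum (mul (list-ref θ 0) t)) (list-ref θ 1)))))

(define plane/tlen (make-plane (make-ext2 (λ (t u) (= (tlen t) (tlen u))))))
(define plane/rank (make-plane (make-ext2 (λ (t u) (= (rank t) (rank u))))))

(define θ (list #(1 2) 3))
((plane/tlen #(1 3)) θ)              ; => 10
((plane/tlen #(#(1 3) #(2 4))) θ)    ; => #(7 15)
((plane/rank #(#(1 3) #(2 4))) θ)    ; => #(10 13)
```

With the tlen rule, (tensor 1 2) and the two-row tensor both have tlen 2, so each weight is multiplied by a whole row before summing, giving [7 15]. With the rank rule, desc-u instead pairs the full weight tensor with each row, giving the expected [10 13].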