Tensor multiplications not giving right results
dm701 opened this issue · 3 comments
Hello,
unless I'm mistaken, the following does not return the correct results:
(* (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))
returns:
(tensor (tensor 5 10) (tensor 18 24))
whereas I would expect it to return:
(tensor (tensor 5 12) (tensor 15 24))
However the following multiplication:
(* (tensor (tensor 3 4 5) (tensor 7 8 9)) (tensor 2 4 3))
returns:
(tensor (tensor 6 16 15) (tensor 14 32 27))
which seems to be correct.
I was unsure whether I was doing it wrong, so I checked with ChatGPT, which confirmed that something is incorrect somewhere...
I would be most grateful to receive a reply on this, please.
Thank you
@dm701 Sorry for the late reply.
Based on the definitions of ext2 and desc in the book:
(define ext2
  (λ (f n m)
    (λ (t u)
      (cond
        ((of-ranks? n t m u) (f t u))
        (else
         (desc (ext2 f n m) n t m u))))))

(define desc
  (λ (g n t m u)
    (cond
      ((of-rank? n t) (desc-u g t u))
      ((of-rank? m u) (desc-t g t u))
      ((= (tlen t) (tlen u)) (tmap g t u))
      ((rank> t u) (desc-t g t u))
      (else (desc-u g t u)))))
The example you provided shows the expected behavior:
(* (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))
=> (tensor (tensor 5 10) (tensor 18 24))
Because (quoting the book):
Our next clause deals with when t and u have the same number of elements, which we determine using (tlen t) and (tlen u). This means we descend into both simultaneously.
Let's see the same-as chart:
1. | (* (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))
2. | (tmap * (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))
3. | (tensor (* (tensor 1 2) 5) (* (tensor 3 4) 6))
4. | (tensor (tensor 5 10) (tensor 18 24))
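To check the chart without malt installed, here is a minimal sketch that models tensors as plain Racket vectors. The helpers (rank, tlen, desc-t, desc-u, tmap2, t*) are stand-ins written for this illustration, not malt's own code, so treat it as a model of the clause order rather than the library's implementation. It also suggests why your second example looked correct: there (tlen t) is 2 and (tlen u) is 3, so the tlen clause is skipped and the rank> clause descends into t only.

```racket
#lang racket
;; Hypothetical stand-in model: tensors are Racket vectors, scalars are numbers.
;; None of this is malt's own code.

(define (rank t) (if (vector? t) (add1 (rank (vector-ref t 0))) 0))
(define (tlen t) (vector-length t))
(define (of-rank? n t) (= (rank t) n))
(define (of-ranks? n t m u) (and (of-rank? n t) (of-rank? m u)))
(define (rank> t u) (> (rank t) (rank u)))

;; Three ways to descend: into t only, into u only, or into both in step.
(define (desc-t g t u) (for/vector ([te (in-vector t)]) (g te u)))
(define (desc-u g t u) (for/vector ([ue (in-vector u)]) (g t ue)))
(define (tmap2 g t u)
  (for/vector ([te (in-vector t)] [ue (in-vector u)]) (g te ue)))

(define (desc g n t m u)
  (cond
    [(of-rank? n t) (desc-u g t u)]
    [(of-rank? m u) (desc-t g t u)]
    [(= (tlen t) (tlen u)) (tmap2 g t u)] ; the clause under discussion
    [(rank> t u) (desc-t g t u)]
    [else (desc-u g t u)]))

(define (ext2 f n m)
  (λ (t u)
    (cond
      [(of-ranks? n t m u) (f t u)]
      [else (desc (ext2 f n m) n t m u)])))

(define t* (ext2 * 0 0)) ; named t* so Racket's own * is left alone

;; tlen is 2 on both sides, so each row is paired with one of 5 and 6.
(define first-result
  (t* (vector (vector 1 2) (vector 3 4)) (vector 5 6)))

;; tlen is 2 versus 3, so the tlen clause is skipped and desc-t runs instead.
(define second-result
  (t* (vector (vector 3 4 5) (vector 7 8 9)) (vector 2 4 3)))
```

In this model first-result equals (vector (vector 5 10) (vector 18 24)) and second-result equals (vector (vector 6 16 15) (vector 14 32 27)), the same values malt returned in the original report.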
If we want your expected behavior:
(* (tensor (tensor 1 2) (tensor 3 4)) (tensor 5 6))
=> (tensor (tensor 5 12) (tensor 15 24))
We need to change one line of the definition of desc, from
((= (tlen t) (tlen u)) (tmap g t u))
to
((= (rank t) (rank u)) (tmap g t u))
which is also a reasonable definition of desc and ext2, since all the other clauses in the cond of desc are based on the rank of a tensor rather than its tlen (which is part of the shape of a tensor). Maybe this is why you were surprised.
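As a rough check of that one-line change, the sketch below repeats the descent with the rank test in place of the tlen test. It again models tensors as plain Racket vectors; rank-desc, rank-ext2 and the other helpers are hypothetical names for this illustration and are not part of malt. The model does no shape checking, so equal-rank tensors of different lengths would be silently truncated by the parallel loop.

```racket
#lang racket
;; Hypothetical stand-in model: tensors are Racket vectors, scalars are numbers.
;; None of this is malt's own code.

(define (rank t) (if (vector? t) (add1 (rank (vector-ref t 0))) 0))
(define (of-rank? n t) (= (rank t) n))

(define (desc-t g t u) (for/vector ([te (in-vector t)]) (g te u)))
(define (desc-u g t u) (for/vector ([ue (in-vector u)]) (g t ue)))
(define (tmap2 g t u)
  (for/vector ([te (in-vector t)] [ue (in-vector u)]) (g te ue)))

;; Same clause order as desc, with only the pairing test changed.
(define (rank-desc g n t m u)
  (cond
    [(of-rank? n t) (desc-u g t u)]
    [(of-rank? m u) (desc-t g t u)]
    [(= (rank t) (rank u)) (tmap2 g t u)] ; was (= (tlen t) (tlen u))
    [(> (rank t) (rank u)) (desc-t g t u)]
    [else (desc-u g t u)]))

(define (rank-ext2 f n m)
  (λ (t u)
    (cond
      [(and (of-rank? n t) (of-rank? m u)) (f t u)]
      [else (rank-desc (rank-ext2 f n m) n t m u)])))

(define t* (rank-ext2 * 0 0)) ; named t* so Racket's own * is left alone

;; Ranks are 2 and 1, so we descend into t first; each row then meets (5 6).
(define first-result
  (t* (vector (vector 1 2) (vector 3 4)) (vector 5 6)))

;; The 2x3 example takes the same path as before and is unchanged.
(define second-result
  (t* (vector (vector 3 4 5) (vector 7 8 9)) (vector 2 4 3)))
```

Under the rank test first-result equals (vector (vector 5 12) (vector 15 24)), which is the result you expected, while second-result is still (vector (vector 6 16 15) (vector 14 32 27)).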
@themetaschemer Maybe using rank is a better way to extend binary functions here?
Because given the definition of plane, we have:
((plane [1 3]) (list [1 2] 3)) => 10
((plane [2 4]) (list [1 2] 3)) => 13
((plane [[1 3]]) (list [1 2] 3)) => [10]
((plane [[2 4]]) (list [1 2] 3)) => [13]
We naturally want the following result:
((plane [[1 3] [2 4]]) (list [1 2] 3)) => [10 13]
But if we use tlen instead of rank, we get:
((plane [[1 3] [2 4]]) (list [1 2] 3)) => [7 15]
Runnable code snippet (in the Racket REPL):
(require malt)
((plane (tensor 1 3)) (list (tensor 1 2) 3)) ;=> 10.0
((plane (tensor 2 4)) (list (tensor 1 2) 3)) ;=> 13.0
((plane (tensor (tensor 1 3))) (list (tensor 1 2) 3)) ;=> (tensor 10.0)
((plane (tensor (tensor 2 4))) (list (tensor 1 2) 3)) ;=> (tensor 13.0)
((plane (tensor (tensor 1 3) (tensor 2 4))) (list (tensor 1 2) 3)) ;=> (tensor 7.0 15.0)
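For a side-by-side comparison, here is a small model that builds plane twice, once with the tlen test and once with the rank test. It assumes plane is the book's (+ (dot-product (ref θ 0) t) (ref θ 1)) with dot-product defined as (sum (* w t)). Tensors are modeled as plain Racket vectors, and make-ext2, make-plane, by-tlen and by-rank are hypothetical names for this sketch only. The model uses exact integers, so it gives 10 where malt prints 10.0.

```racket
#lang racket
;; Hypothetical stand-in model: tensors are Racket vectors, scalars are numbers.
;; None of this is malt's own code.

(define (rank t) (if (vector? t) (add1 (rank (vector-ref t 0))) 0))
(define (of-rank? n t) (= (rank t) n))

;; Builds an ext2 whose "descend into both" clause uses the given test.
(define (make-ext2 same-level?)
  (define (ext2 f n m)
    (define (g t u)
      (cond
        [(and (of-rank? n t) (of-rank? m u)) (f t u)]
        [(of-rank? n t) (for/vector ([ue (in-vector u)]) (g t ue))]
        [(of-rank? m u) (for/vector ([te (in-vector t)]) (g te u))]
        [(same-level? t u)
         (for/vector ([te (in-vector t)] [ue (in-vector u)]) (g te ue))]
        [(> (rank t) (rank u)) (for/vector ([te (in-vector t)]) (g te u))]
        [else (for/vector ([ue (in-vector u)]) (g t ue))]))
    g)
  ext2)

;; Unary extension, used here for sum over rank-1 tensors.
(define (ext1 f n)
  (λ (t)
    (if (of-rank? n t)
        (f t)
        (for/vector ([te (in-vector t)]) ((ext1 f n) te)))))

(define (sum-1 t) (for/sum ([x (in-vector t)]) x))

;; Assumed shape of plane: (+ (sum (* (ref θ 0) t)) (ref θ 1)).
(define (make-plane ext2)
  (define t+ (ext2 + 0 0))
  (define t* (ext2 * 0 0))
  (define sum (ext1 sum-1 1))
  (λ (t) (λ (θ) (t+ (sum (t* (list-ref θ 0) t)) (list-ref θ 1)))))

(define (by-tlen t u) (= (vector-length t) (vector-length u)))
(define (by-rank t u) (= (rank t) (rank u)))

(define plane/tlen (make-plane (make-ext2 by-tlen)))
(define plane/rank (make-plane (make-ext2 by-rank)))

(define xs (vector (vector 1 3) (vector 2 4)))
(define theta (list (vector 1 2) 3))

(define tlen-result ((plane/tlen xs) theta))
(define rank-result ((plane/rank xs) theta))
```

In this model tlen-result equals (vector 7 15) and rank-result equals (vector 10 13). With a single row such as (vector (vector 1 3)) both variants agree on (vector 10), because the lengths 2 and 1 differ and the tlen clause never fires.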