hikettei/Caten

Opt: remove unused allocations with JIT=1 autodiff

Opened this issue · 0 comments

with JIT=1:

CATEN/TEST-SUITE>
(let ((a (make-tensor `(3 3) :requires-grad t :initial-element 2.0))
	  (b (make-tensor `(3 3) :requires-grad t :initial-element 3.0)))
      (let ((m (caten (!mul a b)))) (print m)
	(forward m)
	(backward m nil)
        (print (grad a)) (grad b)))
[graph-schedule] Schedule Graph:

FastGraph[seen=NIL, outputs=(val_5 val_10 val_12)] {
    { Allocate } : [ val_2 <- (3 3) where lowered-p=nil ]
    { Allocate } : [ val_0 <- (3 3) where lowered-p=nil ]
    {  KERNEL  } : [ val_3, val_1, val_4 <- val_0, val_2 where lowered-p=nil :name=FUSED_MUL_LOAD_LOAD277235]
    {   VMOP   } : [ val_5 <- val_4 where lowered-p=nil :name=FUSED_BACKWARD277233]
    { Allocate } : [ GRAD277067 <- (3 3) where lowered-p=nil ]
    { Allocate } : [ PREV-GRAD <- (3 3) where lowered-p=nil ]
    { Allocate } : [ GRAD277070 <- (3 3) where lowered-p=nil ]
    {  KERNEL  } : [ val_12, val_10 <- val_3, val_6, GRAD277070, PREV-GRAD, val_1, GRAD277067 where lowered-p=nil :name=FUSED_MOVE_ADD_MUL_LOAD_MOVE_ADD_MUL277254]
}

[19:15:50, 11/10/2024 (GMT+9)] : JIT Compilation Start (AVM=MAIN277073)

* (1/2) FUSED_MOVE_ADD_MUL_LOAD_MOVE_ADD_MUL277254
=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<9);_gid0+=1) {
    val_6 = 1.0;
    val_7 = (val_3[_gid0]*val_6);
    val_12[_gid0] = val_7;
    val_8 = (val_1[_gid0]*val_6);
    val_10[_gid0] = val_8;
  } // _gid0
}
Compilation Time : 0.004925(sec)
* (2/2) FUSED_MUL_LOAD_LOAD277235 

=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<9);_gid0+=1) {
    val_3[_gid0] = 2.0;
    val_1[_gid0] = 3.0;
    val_4[_gid0] = (val_3[_gid0]*val_1[_gid0]);
  } // _gid0
}
Compilation Time : 0.001634(sec)
[19:15:50, 11/10/2024 (GMT+9)] : Running the memory planner...
[19:15:50, 11/10/2024 (GMT+9)] :  | number of allocations: 7 -> 5
[19:15:50, 11/10/2024 (GMT+9)] :  | total allocation size: 2.52e-7 GB -> 1.8e-7 GB
[19:15:50, 11/10/2024 (GMT+9)] :  | Compressing rate(GB):  28.571%
[19:15:50, 11/10/2024 (GMT+9)] : Rendering ...
[19:15:50, 11/10/2024 (GMT+9)] : Compiling ...
[Final Code]:

#include <math.h>
#include <stdint.h>

#define boolean _Bool
#define _infinity INFINITY
#define _negative_infinity -INFINITY
#define _nan NAN
#define min(a, b) ((a) < (b) ? (a) : (b))
#define max(a, b) ((a) > (b) ? (a) : (b))
void fused_mul_load_load277235(float* val_4, float* val_1, float* val_3);
void fused_mul_load_load277235(float* val_4, float* val_1, float* val_3) {
  for (int _gid0=0; (_gid0<9); _gid0+=1) {
    val_3[_gid0] = 2.0;
    val_1[_gid0] = 3.0;
    val_4[_gid0] = (val_3[_gid0]*val_1[_gid0]);
  }
}
void fused_move_add_mul_load_move_add_mul277254(float* val_1, float* val_3, const float* restrict val_4);
void fused_move_add_mul_load_move_add_mul277254(float* val_1, float* val_3, const float* restrict val_4) {
  for (int _gid0=0; (_gid0<9); _gid0+=1) {
    float val_6 = 1.0;
    float val_7 = (val_3[_gid0]*val_6);
    val_3[_gid0] = val_7;
    float val_8 = (val_1[_gid0]*val_6);
    val_1[_gid0] = val_8;
  }
}


#S(AVM
   :GRAPH 
Graph[seen=NIL, outputs=NIL] {
    <ALLOCATE : val_4 <- (shape=(3, 3), stride=(3, 1)) where :nrank=2 :dtype=FLOAT32 :_read_views=NIL :_output_type=NIL>
    <ALLOCATE : val_1 <- (shape=(3, 3), stride=(3, 1)) where :nrank=2 :dtype=FLOAT32 :_read_views=NIL :_output_type=NIL>
    <ALLOCATE : val_3 <- (shape=(3, 3), stride=(3, 1)) where :nrank=2 :dtype=FLOAT32 :_read_views=NIL :_output_type=NIL>
    <Node[JIT] JIT_KERNEL(NID280728) : val_4, val_1, val_3 <- (val_4, val_1, val_3) where :output-buffer-n=3 :kernel-info=<CLANG[FUSED_MUL_LOAD_LOAD277235]> :dtypes=(FLOAT32
                                                                                   FLOAT32
                                                                                   FLOAT32)>
    <Node[SPECIAL/VM] PAUSE/BACKWARD(NID277214) : val_5 <- (val_4)>
    <ALLOCATE : GRAD277070 <- (shape=(3, 3), stride=(3, 1)) where :nrank=2 :dtype=FLOAT32 :_type_relay=NIL :_read_views=NIL :_output_type=NIL>
    <ALLOCATE : GRAD277067 <- (shape=(3, 3), stride=(3, 1)) where :nrank=2 :dtype=FLOAT32 :_type_relay=NIL :_read_views=NIL :_output_type=NIL>
    <Node[JIT] JIT_KERNEL(NID280729) : val_10, val_12 <- (val_1, val_3, val_4) where :output-buffer-n=2 :kernel-info=<CLANG[FUSED_MOVE_ADD_MUL_LOAD_MOVE_ADD_MUL277254]> :dtypes=(FLOAT32
                                                                                                    FLOAT32
                                                                                                    FLOAT32)>
}

   :NAME :MAIN277073
   :FW-OUTPUTS (|val_5|)
   :BW-OUTPUTS (|val_10| |val_12|)
   :ID2TENSOR #<HASH-TABLE :TEST EQL :COUNT 3 {703F09B743}>
   :TAPE-LENGTH 8
   :PC 0
   :VARIABLES #<HASH-TABLE :TEST EQL :COUNT 0 {703F9C23D3}>
   :DUMPED NIL) 
{Tensor[float32] :shape (3 3) :id GRAD277067
   ((3.0 3.0 3.0)
    (3.0 3.0 3.0)
    (3.0 3.0 3.0))
  :op #<ALLOCATE {703F05E883}>
  :requires-grad NIL
  :variables NIL
  :tracker #<TRACKER :order={row(0 1)} :shape=(3 3) :contiguous-p=T>} 
{Tensor[float32] :shape (3 3) :id GRAD277070
   ((2.0 2.0 2.0)
    (2.0 2.0 2.0)
    (2.0 2.0 2.0))
  :op #<ALLOCATE {703F05ED83}>
  :requires-grad NIL
  :variables NIL
  :tracker #<TRACKER :order={row(0 1)} :shape=(3 3) :contiguous-p=T>}
CATEN/TEST-SUITE> 
  • GRAD277070 and GRAD277067 are never read by any JIT kernel (the kernels only consume val_1, val_3, and val_4), so their allocations should be purged from the graph.