hidet-org/hidet

Is `hidet_launch` called by any other runtimes for inference?

VincentXWD opened this issue · 1 comment

Hello, I would like to ask a question about generated code: graph_module/source.cc.
I saw that this code defines a method called `hidet_launch`. I wonder whether it is used by a runtime library such as PyTorch or ONNX Runtime?
For example, a `hidet_launch` method looks like this:

DLL void hidet_launch(void * __restrict__ inputs, void * __restrict__ inputs_1, void * __restrict__ inputs_2, void * __restrict__ outputs, void * __restrict__ outputs_1, void* * __restrict__ p_kernels) {
  memory_planner_init(0);
  memory_planner_init(1);
  int64_t fusedmultiplyscalar = memory_planner_allocate(1, int64_t(12ll));
  void (*hidet_k0_FusedMultiplyScalar)(void*, void*) = ((void (*)( void*, void*))(p_kernels[0]));
  hidet_k0_FusedMultiplyScalar(inputs_1, (cuda_workspace + fusedmultiplyscalar));
  int64_t fusedadd = memory_planner_allocate(1, int64_t(9216ll));
  void (*hidet_k1_FusedAdd)(void*, void*, void*, void*, void*, void*) = ((void (*)( void*, void*, void*, void*, void*, void*))(p_kernels[1]));
  hidet_k1_FusedAdd(weights[24], weights[21], inputs, weights[139], inputs_2, (cuda_workspace + fusedadd));
  int64_t fusednormalize = memory_planner_allocate(1, int64_t(9216ll));
  void (*hidet_k2_FusedNormalize)(void*, void*, void*) = ((void (*)( void*, void*, void*))(p_kernels[2]));
  hidet_k2_FusedNormalize((cuda_workspace + fusedadd), weights[37], (cuda_workspace + fusednormalize));
  memory_planner_free(1, fusedadd);
  int64_t fusedbatchmatmul = memory_planner_allocate(1, int64_t(221184ll));
  void (*hidet_k3_FusedBatchMatmul)(void*, void*, void*, void*) = ((void (*)( void*, void*, void*, void*))(p_kernels[3]));
  hidet_k3_FusedBatchMatmul(weights[36], (cuda_workspace + fusednormalize), weights[9], (cuda_workspace + fusedbatchmatmul));
  int64_t reducesum = memory_planner_allocate(1, int64_t(27648ll));
  void (*hidet_k4_ReduceSum)(void*, void*) = ((void (*)( void*, void*))(p_kernels[4]));
  hidet_k4_ReduceSum((cuda_workspace + fusedbatchmatmul), (cuda_workspace + reducesum));
  memory_planner_free(1, fusedbatchmatmul);
  int64_t add = memory_planner_allocate(1, int64_t(27648ll));
  void (*hidet_k5_Add)(void*, void*, void*) = ((void (*)( void*, void*, void*))(p_kernels[5]));
  hidet_k5_Add((cuda_workspace + reducesum), weights[91], (cuda_workspace + add));
  memory_planner_free(1, reducesum);
  int64_t fusedbatchmatmul_1 = memory_planner_allocate(1, int64_t(3456ll));
  void (*hidet_k6_FusedBatchMatmul)(void*, void*) = ((void (*)( void*, void*))(p_kernels[6]));
  hidet_k6_FusedBatchMatmul((cuda_workspace + add), (cuda_workspace + fusedbatchmatmul_1));
  int64_t reducesum_1 = memory_planner_allocate(1, int64_t(432ll));
  void (*hidet_k7_ReduceSum)(void*, void*) = ((void (*)( void*, void*))(p_kernels[7]));
  hidet_k7_ReduceSum((cuda_workspace + fusedbatchmatmul_1), (cuda_workspace + reducesum_1));
  memory_planner_free(1, fusedbatchmatmul_1);
  int64_t fusedsoftmax = memory_planner_allocate(1, int64_t(432ll));
  void (*hidet_k8_FusedSoftmax)(void*, void*, void*) = ((void (*)( void*, void*, void*))(p_kernels[8]));
  hidet_k8_FusedSoftmax((cuda_workspace + fusedmultiplyscalar), (cuda_workspace + reducesum_1), (cuda_workspace + fusedsoftmax));
  memory_planner_free(1, reducesum_1);
  int64_t fusedbatchmatmul_2 = memory_planner_allocate(1, int64_t(9216ll));
  void (*hidet_k9_FusedBatchMatmul)(void*, void*, void*) = ((void (*)( void*, void*, void*))(p_kernels[9]));
  hidet_k9_FusedBatchMatmul((cuda_workspace + fusedsoftmax), (cuda_workspace + add), (cuda_workspace + fusedbatchmatmul_2));
  memory_planner_free(1, fusedsoftmax);
  memory_planner_free(1, add);
  int64_t fusedbatchmatmul_3 = memory_planner_allocate(1, int64_t(73728ll));
  void (*hidet_k10_FusedBatchMatmul)(void*, void*, void*) = ((void (*)( void*, void*, void*))(p_kernels[10]));
  hidet_k10_FusedBatchMatmul((cuda_workspace + fusedbatchmatmul_2), weights[11], (cuda_workspace + fusedbatchmatmul_3));
  memory_planner_free(1, fusedbatchmatmul_2);
  int64_t reducesum_2 = memory_planner_allocate(1, int64_t(9216ll));
  void (*hidet_k11_ReduceSum)(void*, void*) = ((void (*)( void*, void*))(p_kernels[11]));
  hidet_k11_ReduceSum((cuda_workspace + fusedbatchmatmul_3), (cuda_workspace + reducesum_2));
  memory_planner_free(1, fusedbatchmatmul_3);
  int64_t fusedadd_1 = memory_planner_allocate(1, int64_t(9216ll));
  void (*hidet_k12_FusedAdd)(void*, void*, void*, void*) = ((void (*)( void*, void*, void*, void*))(p_kernels[12]));
...
}

If it is, does that mean the kernels in this function are executed sequentially?

Hi @VincentXWD,

This function will be loaded into the CompiledGraph in hidet's Python package (see https://github.com/hidet-org/hidet/blob/main/python/hidet/runtime/compiled_graph.py#L85) and is used when we run a compiled hidet model.