microsoft/knossos-ksc

TS2K polymorphic over tensor rank

awf opened this issue · 2 comments

awf commented

AB#19174

The user writes:

# Wrap the per-element function relu3 so that vrelu3 can be applied to whole
# tensors, whatever their rank.
vrelu3 = knossos.vmap(relu3)

which calls this stub:

class KnossosVMap:  # NOTE: may eventually need to subclass torch.autograd.Function
    """Marker wrapper recording that ``f`` should be compiled as a vmap.

    ``ts2mod`` checks for instances of this class and handles the wrapped
    function specially instead of scripting the wrapper itself.

    Attributes:
        f: the Python function to be mapped.
    """

    def __init__(self, f):
        self.f = f

    def __repr__(self):
        # Show the wrapped function so a KnossosVMap is identifiable in logs
        # and debugger output.
        return f"{type(self).__name__}({self.f!r})"

def knossos_vmap(function, generate_lm=True):
    """Mark ``function`` for vectorized (vmap) compilation.

    Args:
        function: the Python function to be mapped.
        generate_lm: currently ignored here (presumably kept for parity with
            the ``generate_lm`` flag of ``ts2mod``).

    Returns:
        A ``KnossosVMap`` wrapping ``function``.
    """
    wrapped = KnossosVMap(function)
    return wrapped

Later, at compile time, `ts2mod` does this:

def ts2mod(function, example_inputs, generate_lm=True):
    """Compile ``function`` into a Knossos module.

    Args:
        function: a plain Python function, or a ``KnossosVMap`` produced by
            ``knossos_vmap``.
        example_inputs: example arguments forwarded to ``ts2ks_fromgraph``
            (presumably used to infer argument types).
        generate_lm: if true, request "fwd" and "rev" derivatives;
            otherwise request "sufrev".

    Returns:
        The module built by ``ksc_defs_to_module`` (via ``knossos_vmap_do_it``
        for ``KnossosVMap`` inputs).
    """
    if isinstance(function, KnossosVMap):
        # Vectorized functions take the dedicated vmap path; without this
        # return the branch would fall through and yield None.
        return knossos_vmap_do_it(function.f, example_inputs, generate_lm=generate_lm)

    fn = torch.jit.script(function)
    ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)

    derivatives_to_generate = ["fwd", "rev"] if generate_lm else ["sufrev"]
    return ksc_defs_to_module([ksc_def], ksc_def, derivatives_to_generate)


def knossos_vmap_do_it(function, args, generate_lm=True):
    """Compile ``function`` for use through ``knossos_vmap``.

    Scripts ``function`` with TorchScript, converts the resulting graph to a
    KSC definition, prepares the lambda to be mapped over tensor elements,
    and builds a module from the definition.

    NOTE(review): this is still a sketch. The rank-specific map definitions
    and the rank-dispatching wrapper (see the TODO below) are not generated
    yet, so the returned module currently holds only the unmapped
    definition of ``function``.

    Args:
        function: the Python function to be mapped, e.g. ``relu3``.
        args: example inputs forwarded to ``ts2ks_fromgraph``; exactly one
            is currently supported.
        generate_lm: if true, request "fwd" and "rev" derivatives;
            otherwise request "sufrev".

    Returns:
        The module built by ``ksc_defs_to_module``.

    Raises:
        ValueError: if ``args`` does not hold exactly one input.
    """
    # Validate with an exception rather than ``assert``, which is stripped
    # under ``python -O``.
    if len(args) != 1:
        raise ValueError(
            f"knossos_vmap_do_it expects exactly one argument, got {len(args)}"
        )

    # Extract the TorchScript graph.
    fn = torch.jit.script(function)
    # Convert to a KSC Expr; ``args`` serve as the example inputs, e.g.
    #   (def relu3 None ((_x$o1 : ARG_TYPE)) ... )
    ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, args)

    # Lambda applying the definition to a single element:
    #   (lam (v : ARG_TYPE) (relu3 v))
    arg_type = ksc_def.args[0].Type
    vardecl, var = make_new_var(arg_type)
    map_lambda = Lam(vardecl, Call(ksc_def.name, [var]))  # noqa: F841 -- consumed once the TODO lands

    # TODO: use ``map_lambda`` to emit one mapped definition per supported
    # rank, e.g.
    #   (def vrelu3_Tensor_1_Float (vs : Tensor 1 Float)
    #       (map (lam (v : Float) (relu3 v)) vs))
    #   ... and likewise for Tensor 2, 3 and 4 ...
    # built along the lines of
    #   Call("map", [map_lambda] + [var_or_constant(i) for i in tail(node.inputs())])
    # plus a Python-side wrapper that dispatches on the rank of its input:
    #   def wrapper(t: Tensor) -> Tensor:
    #       rank = len(size(t))
    #       if rank == 1: return vrelu3_Tensor_1_Float(t)
    #       ... ranks 2, 3 and 4 likewise ...
    #       return KnossosVMapCompiler(rank)  # presumably compiles an unseen rank

    derivatives_to_generate = ["fwd", "rev"] if generate_lm else ["sufrev"]
    return ksc_defs_to_module([ksc_def], ksc_def, derivatives_to_generate)
awf commented

We may also want to incorporate ideas from #849, e.g.:

// Rank-polymorphic tensor: one ks::Tensor per supported rank (1-4), plus a
// handle to the corresponding Python object.
// (Declared as a primary template; `struct ks_tensor<T>` here would be a
// partial specialization with no primary template, which is ill-formed.)
template <class T>
struct ks_tensor {
   ks::Tensor<1, T> t1;
   ks::Tensor<2, T> t2;
   ks::Tensor<3, T> t3;
   ks::Tensor<4, T> t4;
   // NOTE(review): there is no CPython type named Py_Object; presumably
   // PyObject* (CPython) or py::object (pybind11) is meant -- confirm.
   Py_Object py_handle;
};

or, messier and trickier (a single rank-7 tensor with lower-rank views):

// Alternative layout: store a single rank-7 tensor and obtain lower ranks by
// repeatedly indexing element 0 of the next-higher rank.
// NOTE(review): assumes ks::Tensor<r + 1, T>::operator[] yields a
// ks::Tensor<r, T>& -- confirm against the ks runtime headers.
template <class T>
struct ks_tensor {
   ks::Tensor<7, T> t;
   // NOTE(review): presumably PyObject* or py::object -- confirm.
   Py_Object py_handle;

   // Return the rank-n view of `t`, for 1 <= n <= 7.
   template <int n>
   ks::Tensor<n, T>& get_t() {
      // Without this bound, get_t<8>, get_t<9>, ... would recurse forever.
      static_assert(n >= 1 && n <= 7, "ks_tensor supports ranks 1..7");
      return get_t_impl(rank_tag<n>());
   }

private:
   // Empty tag type selecting a get_t_impl overload by rank. A member
   // template of an unspecialized class template cannot be explicitly
   // specialized (`template<> ... get_t()` in class scope is ill-formed),
   // so the recursion is terminated by overloading instead.
   template <int n>
   struct rank_tag {};

   template <int n>
   ks::Tensor<n, T>& get_t_impl(rank_tag<n>) {
      return get_t_impl(rank_tag<n + 1>())[0];
   }

   // Base case: for rank 7 this non-template overload is preferred over
   // the template above.
   ks::Tensor<7, T>& get_t_impl(rank_tag<7>) { return t; }
};