TS2K polymorphic over tensor rank
awf opened this issue · 2 comments
awf commented
User writes:

```python
vrelu3 = knossos.vmap(relu3)
```
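To make the goal concrete: the wrapped `vrelu3` should accept a tensor of any supported rank and dispatch to a rank-specialized compilation underneath. Purely illustrative; it assumes `relu3` is the scalar example function already in scope:

```python
import torch

vrelu3 = knossos.vmap(relu3)

# The same wrapped function should work at several ranks:
y1 = vrelu3(torch.randn(5))        # Tensor 1 Float
y2 = vrelu3(torch.randn(3, 4))     # Tensor 2 Float
y3 = vrelu3(torch.randn(2, 3, 4))  # Tensor 3 Float
```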
The `knossos.vmap` call goes to this stub:
```python
class KnossosVMap:
    # May need to already be a subclass of torch.autograd.Function
    def __init__(self, f):
        self.f = f


def knossos_vmap(function, generate_lm=True):
    return KnossosVMap(function)
```
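On the `torch.autograd.Function` comment: if `KnossosVMap` does end up needing to plug into autograd directly, the standard custom-Function pattern would look roughly like the sketch below. `compiled_fwd` and `compiled_rev` are placeholders for whatever entry points the compiled module ends up exposing, not existing API:

```python
import torch


class KnossosVMapFunction(torch.autograd.Function):
    # Standard torch.autograd.Function shape: static forward/backward that
    # call into (placeholder) compiled Knossos entry points.
    @staticmethod
    def forward(ctx, t):
        ctx.save_for_backward(t)
        return compiled_fwd(t)  # placeholder: compiled primal

    @staticmethod
    def backward(ctx, grad_out):
        (t,) = ctx.saved_tensors
        return compiled_rev(t, grad_out)  # placeholder: compiled reverse derivative
```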
And later, when we compile, this happens:

```python
def ts2mod(function, example_inputs, generate_lm=True):
    if isinstance(function, KnossosVMap):
        function = function.f
        # special stuff... knossos_vmap_do_it
    else:
        fn = torch.jit.script(function)
        ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
        derivatives_to_generate = ["fwd", "rev"] if generate_lm else ["sufrev"]
        return ksc_defs_to_module([ksc_def], ksc_def, derivatives_to_generate)
```
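The `# special stuff` branch presumably reduces to something like this (a sketch only; it assumes `knossos_vmap_do_it` takes the example inputs as its second argument, matching the definition below):

```python
def ts2mod(function, example_inputs, generate_lm=True):
    if isinstance(function, KnossosVMap):
        # Build rank-specialized maps around the wrapped scalar function
        return knossos_vmap_do_it(function.f, example_inputs, generate_lm=generate_lm)
    fn = torch.jit.script(function)
    ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
    derivatives_to_generate = ["fwd", "rev"] if generate_lm else ["sufrev"]
    return ksc_defs_to_module([ksc_def], ksc_def, derivatives_to_generate)
```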
```python
def knossos_vmap_do_it(function, example_inputs, generate_lm=True):
    """
    Given a Python function, emit rank-specialized map wrappers around it
    and a rank-dispatching entry point.
    """
    # Extract TorchScript graph
    fn = torch.jit.script(function)

    # Convert to KSC Expr
    ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
    # ksc_def = (def relu3 None ((_x$o1 : ARG_TYPE)) ... )

    # Now insert definitions for all the maps
    if True:
        assert len(example_inputs) == 1
        arg_type = ksc_def.args[0].Type
        vardecl, var = make_new_var(arg_type)
        map_lambda = Lam(vardecl, Call(ksc_def.name, [var]))
        # Now build
        """
        (def vrelu3_Tensor_1_Float (vs : Tensor 1 Float)
          (map (lam (v : Float) (relu3 v)) vs))
        (def vrelu3_Tensor_2_Float (vs : Tensor 2 Float)
          (map (lam (v : Float) (relu3 v)) vs))
        (def vrelu3_Tensor_3_Float (vs : Tensor 3 Float)
          (map (lam (v : Float) (relu3 v)) vs))
        (def vrelu3_Tensor_4_Float (vs : Tensor 4 Float)
          (map (lam (v : Float) (relu3 v)) vs))
        """
        # return
        #     Call(
        #         "map", [map_lambda] + [var_or_constant(i) for i in tail(node.inputs())]
        #     ),
        # )

    def wrapper(t: Tensor) -> Tensor:
        rank = len(t.size())
        if rank == 1:
            return vrelu3_Tensor_1_Float(t)
        if rank == 2:
            return vrelu3_Tensor_2_Float(t)
        if rank == 3:
            return vrelu3_Tensor_3_Float(t)
        if rank == 4:
            return vrelu3_Tensor_4_Float(t)
        # Unsupported rank: presumably compile a specialization on demand
        return KnossosVMapCompiler(rank)

    derivatives_to_generate = ["fwd", "rev"] if generate_lm else ["sufrev"]
    return ksc_defs_to_module([ksc_def], ksc_def, derivatives_to_generate)
```
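The hand-written block under `# Now build` could instead be generated in a loop over ranks. A sketch only: it assumes `Def(name, return_type, args, body)` and `Type.Tensor(rank, elem_type)` constructors in the KSC Expr/Type API, which may not match the real signatures:

```python
def make_rank_specialized_defs(ksc_def, max_rank=4):
    # Build (def v<name>_Tensor_N_<T> (vs : Tensor N <T>)
    #          (map (lam (v : <T>) (<name> v)) vs))
    # for N = 1..max_rank, wrapping the scalar ksc_def.
    elem_type = ksc_def.args[0].Type
    defs = []
    for rank in range(1, max_rank + 1):
        tensor_type = Type.Tensor(rank, elem_type)  # assumed Type helper
        vs_decl, vs = make_new_var(tensor_type)
        v_decl, v = make_new_var(elem_type)
        body = Call("map", [Lam(v_decl, Call(ksc_def.name, [v])), vs])
        name = f"v{ksc_def.name}_Tensor_{rank}_{elem_type}"  # e.g. vrelu3_Tensor_1_Float
        defs.append(Def(name, tensor_type, [vs_decl], body))  # assumed Def constructor
    return defs
```

These generated defs would then be passed to `ksc_defs_to_module` alongside `ksc_def`, so that each rank gets its own compiled entry point for `wrapper` to dispatch to.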
awf commented
And we may want to incorporate ideas from #849, e.g.
```cpp
// One ks::Tensor member per supported rank (1..4), plus a handle back to
// the owning Python object.
template <class T>
struct ks_tensor {
  ks::Tensor<1, T> t1;
  ks::Tensor<2, T> t2;
  ks::Tensor<3, T> t3;
  ks::Tensor<4, T> t4;
  PyObject* py_handle;
};
```
or messy and tricksy
```cpp
template <class T>
struct ks_tensor {
  ks::Tensor<7, T> t;
  PyObject* py_handle;

  // View the stored rank-7 tensor at rank n by repeatedly taking the leading
  // slice; the base case (n == 7) is the stored tensor itself.
  template <int n>
  decltype(auto) get_t() {
    if constexpr (n == 7)
      return (t);
    else
      return get_t<n + 1>()[0];
  }
};
```