Async Knossos compiler
awf opened this issue · 0 comments
awf commented
[Writing in progress]
class KnossosVMapCompiler:
    """Callable that runs ``vmap(f)``, compiling it with Knossos in the background.

    Compilation never blocks a call.  Each call computes the argument's
    recompilation-relevant characteristics; if no compiled module exists for
    them and none is currently being built for them, a new background
    compilation is submitted.  Until a module compiled for the current
    characteristics is available, the call is served by the fallback
    (``torch.vmap(f)`` by default).

    Only one compiled module is kept: when the characteristics change, the
    previous module keeps serving arguments that still match it until the
    new compilation finishes, and is then replaced.

    No locking is performed, so an instance should be called from one
    thread at a time.

    Args:
        f: Function to vmap and compile.
        compile_fn: ``compile_fn(f, example_arg)`` returning a callable
            module, or ``None`` if no module could be produced.  Defaults to
            ``knossos_vmap_do_it`` -- NOTE(review): that argument order is
            assumed; confirm against its definition.
        characteristics_fn: Maps an argument to the value whose change
            requires recompilation; values must support ``==`` returning a
            plain bool.  Defaults to
            ``arg_characteristics_requiring_recompilation``.
        fallback: Callable used while no matching compiled module exists.
            Defaults to ``torch.vmap(f)``.
        example_arg: Argument used for the eager compilation started by the
            constructor.  Defaults to ``rand_of_standard_shape(f.arg.Type)``.
        executor: ``concurrent.futures.Executor`` that runs compilations.
            Defaults to a private single-worker thread pool.
    """

    # Characteristics the installed compiled module was built for; ``False``
    # until the first compilation finishes.
    compiled = False
    # Background compilation currently in flight, or ``None``.
    compilation_future = None

    def __init__(
        self,
        f,
        *,
        compile_fn=None,
        characteristics_fn=None,
        fallback=None,
        example_arg=None,
        executor=None,
    ):
        import concurrent.futures

        if compile_fn is None:
            compile_fn = knossos_vmap_do_it
        if characteristics_fn is None:
            characteristics_fn = arg_characteristics_requiring_recompilation
        if fallback is None:
            fallback = torch.vmap(f)
        if example_arg is None:
            example_arg = rand_of_standard_shape(f.arg.Type)
        if executor is None:
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        self.f = f
        self.torch_fallback = fallback
        self.generic_example_arg = example_arg
        self._compile_fn = compile_fn
        self._characteristics_fn = characteristics_fn
        self._executor = executor

        self.compiled = False
        self.compilation_future = None
        self.py_mod = None
        # Distinguishes "nothing compiled yet" from characteristics that
        # happen to equal the initial ``compiled`` value.
        self._has_compiled = False
        # Characteristics the in-flight compilation was submitted for.
        self._pending = None

        # Compile eagerly for a representative argument so a module may
        # already be ready by the time the first real call arrives.
        self._start_compilation(characteristics_fn(example_arg), example_arg)

    def _start_compilation(self, key, example_arg):
        """Submit a background compilation for characteristics ``key``."""
        if self.compilation_future is not None:
            # Superseded: drop it if the executor has not started it yet.
            self.compilation_future.cancel()
        self._pending = key
        self.compilation_future = self._executor.submit(
            self._compile_fn, self.f, example_arg
        )

    def _collect_finished_compilation(self):
        """Install the in-flight compilation's result if it has finished."""
        future = self.compilation_future
        if future is None or not future.done():
            return
        self.compilation_future = None
        if future.cancelled():
            return
        try:
            py_mod = future.result()
        except Exception as exc:
            import warnings

            warnings.warn(
                f"Knossos compilation failed, using fallback: {exc!r}"
            )
            py_mod = None
        # Record the characteristics even on failure so the same failing
        # compilation is not resubmitted on every call.
        self.compiled = self._pending
        self._has_compiled = True
        self.py_mod = py_mod

    def __call__(self, arg):
        """Evaluate the vmapped function on ``arg``.

        Uses the compiled module when one built for ``arg``'s characteristics
        is available, otherwise the fallback; schedules a recompilation when
        needed without waiting for it.
        """
        key = self._characteristics_fn(arg)
        self._collect_finished_compilation()

        up_to_date = self._has_compiled and key == self.compiled
        in_flight = self.compilation_future is not None and key == self._pending
        if not up_to_date and not in_flight:
            self._start_compilation(key, arg)

        if up_to_date and self.py_mod is not None:
            return self.py_mod(arg)
        return self.torch_fallback(arg)
def knossos_vmap(function, generate_lm=True):
    """Return a ``KnossosVMapCompiler`` wrapping ``function``.

    NOTE(review): ``generate_lm`` is accepted for API compatibility but is
    not currently used or forwarded anywhere.
    """
    compiler = KnossosVMapCompiler(function)
    return compiler