awslabs/slapo

[Feature] Type Inference of primitives and module selection API change

zarzen opened this issue · 7 comments

Currently, we have the following scheduling logic for GPT-2.

# code snippet from slapo/model_schedule/gpt2.py
...
    attn_op = []
    for idx in range(model_config.num_hidden_layers):
        sub_sch = sch[attn_path.replace("N", str(idx))]
        with init_empty_weights(enable=delay_init):
            new_mod = Attention(**init_config)
            attn_op.append(new_mod.module.attn_op_name)
        sub_sch.replace(new_mod)
        cnt += 1
...

Issues with this code snippet, from my point of view:

  1. The primitive function replace cannot provide type-inference features: it is difficult to know the available options for a primitive, and it is not possible to get the doc string for the primitive.
  2. Selection of a subgraph is not intuitive due to the concept of the sub-schedule. Treating the schedule as a dictionary/hash table is not intuitive to me. For a single model, it is natural to me that we have a single schedule for that model. The schedule may only affect part of the model, and I would rather represent it as a list of tuples, e.g., (module_part_id, "replace_with", new_module_obj); see the sketch after this list. This would also facilitate debugging the schedule, e.g., removing entries from the schedule to disable parts of it.
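For example, such a flat schedule could be written down as plain data (a minimal sketch of the idea only; the paths, the "replace_with" primitive name, and the module objects are illustrative, not an existing Slapo API):

    # Hypothetical flat representation of a schedule: an ordered list of
    # (module_part_id, primitive_name, argument) tuples.
    new_attn_mod = Attention(**init_config)  # as in the snippet above
    schedule = [
        ("transformer.h.0.attn", "replace_with", new_attn_mod),
        ("transformer.h.1.attn", "replace_with", new_attn_mod),
        # Removing or commenting out an entry disables that part of the
        # schedule, which makes debugging easier.
        # ("transformer.h.2.attn", "replace_with", new_attn_mod),
    ]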

I recommend changing the APIs to the following:

    ...
    for idx in range(model_config.num_hidden_layers):
        sub_module = slapo.select(model, "transformer.h."+str(idx))
        with init_empty_weights(enable=delay_init):
            new_mod = Attention(**init_config)
        cur_schedule = slapo.replace(cur_schedule, sub_module, new_mod)
        cnt += 1
    ...

And the select method can be further improved to support fuzzy matching:

    ...
    sub_modules = slapo.select(model, "transformer.h.*")
    with init_empty_weights(enable=delay_init):
        new_mod = Attention(**init_config)
    cur_schedule = slapo.replace(cur_schedule, sub_modules, new_mod)
    cnt = len(sub_modules)
    ...

My two cents.

For the first point, it does not seem related to the API. You still need to register each primitive anyway. The current registration may not be intuitive to new developers, but I believe a good tutorial could largely solve this issue. After all, developers don't need to know how the registration works; they only need to know where to find the primitive implementations and their doc strings.

For the second point, I'm not sure I got it. To me, sch["transformer.h."+str(idx)] is more intuitive:

  1. This is similar to the schedule language of Halide and TVM.
  2. This is similar to the way submodules are accessed in a PyTorch model.

On the other hand, slapo.select("transformer.h."+str(idx)) confuses me because I cannot tell which model/schedule I'm working on just by looking at this statement. I can think of two ways to maintain the current working schedule:

  1. A global variable. This is unsafe and could be a mess.
  2. A context manager, so you may need to use something like with slapo.schedule(model): to wrap the entire scheduling logic (see the sketch after this list). This might work, but it does not seem worth the engineering effort of such a refactoring.
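To make option 2 concrete, here is a rough sketch of what that context-manager plumbing might look like (everything here is hypothetical; slapo.create_schedule is assumed to be the schedule constructor, and the functional slapo.select is the proposed API, not an existing one):

    import contextlib

    import slapo

    _CURRENT_SCHEDULE = None  # module-level state, managed only by the context manager

    @contextlib.contextmanager
    def schedule(model):
        """Hypothetical context manager that tracks the current working schedule."""
        global _CURRENT_SCHEDULE
        prev = _CURRENT_SCHEDULE
        _CURRENT_SCHEDULE = slapo.create_schedule(model)  # assumed constructor
        try:
            yield _CURRENT_SCHEDULE
        finally:
            _CURRENT_SCHEDULE = prev

    # Functional primitives such as the proposed slapo.select("transformer.h.0")
    # would then look up _CURRENT_SCHEDULE instead of taking the schedule explicitly:
    # with schedule(model):
    #     sub = slapo.select("transformer.h.0")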

For fuzzy matching, there's nothing preventing us from implementing sub_schs = sch["transformer.h.*"].
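For instance, wildcard expansion could plausibly be layered on top of the existing indexing (a sketch only; it assumes the schedule exposes the wrapped module as sch.mod and supports plain-path indexing via sch[name], which should be verified against the actual Schedule class):

    import fnmatch

    def select_subschedules(sch, pattern):
        """Expand a wildcard path such as "transformer.h.*" into the list of
        matching sub-schedules; a plain path returns a single-element list."""
        if not any(ch in pattern for ch in "*?["):
            return [sch[pattern]]
        return [
            sch[name]
            for name, _ in sch.mod.named_modules()
            if name and fnmatch.fnmatch(name, pattern)
        ]

    # sub_schs = select_subschedules(sch, "transformer.h.*")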

Sorry, I did not intend to introduce a global variable, but rather to have all primitives in a functional style.
And as for the context manager, I did not intend to change it.

Basically, what I was recommending is to make the operations as explicit as possible. By the way, I think what Slapo is trying to do is more like CSS for HTML, rather than a scheduling language like TVM's, which optimizes performance by tiling loops, etc.

I'm still trying to learn how to use Slapo. For me, the primitives are a bit counter-intuitive and hard to maintain and debug (e.g., I can't use step-in/step-out in VS Code to check/modify this API). I would expect a more stateless interface for applying the schedule primitives, just like @zarzen has suggested:

new_sch = slapo.replace(old_sch, pattern, new_fn)
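One way to get that shape without rewriting the core would be a thin functional wrapper around the existing indexing primitive (a sketch; it assumes sch[pattern].replace(...) mutates the schedule in place, as the current API appears to do):

    def replace(sch, pattern, new_mod):
        """Functional-style wrapper over the existing primitive: apply
        sch[pattern].replace(new_mod) and hand the schedule back so calls can be
        written in the suggested stateless style (the schedule is still mutated
        in place under the hood)."""
        sch[pattern].replace(new_mod)
        return sch

    # new_sch = replace(old_sch, "transformer.h.0.attn", new_fn)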

I think the current API is fine for users who are familiar with the coding style of Halide/TVM, so we do not need to change the interface. VS Code is able to capture the data type of __getitem__, so writing sch[op].replace(...) can still prompt the correct hints for programmers. To achieve that, what we need to do is change the implementation of those primitives. I would suggest explicitly exposing the primitives to programmers, but using separate files to implement different primitives, just like what PyTorch does -- having a top-level interface and then dispatching each nn.Module op to its nn.functional implementation.
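A sketch of that split (the names and file layout are illustrative, mirroring the nn.Module-to-nn.functional pattern rather than Slapo's actual code):

    # --- functional layer: one implementation per file, with the full doc string ---
    def replace_module(parent_mod, child_name, new_mod):
        """Swap the named child module on its parent and return the new module."""
        setattr(parent_mod, child_name, new_mod)
        return new_mod

    # --- top-level interface: thin, typed methods that IDEs can resolve ---
    class Schedule:
        def __init__(self, mod, parent=None, name=""):
            self.mod = mod
            self._parent = parent
            self._name = name

        def replace(self, new_mod):
            """Dispatch to the functional implementation."""
            self.mod = replace_module(self._parent, self._name, new_mod)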

This approach couples the concepts of schedule and module/graph, and I don't think that is a good way to manage schedules.
This kind of hierarchical organization of the schedule also increases the difficulty of debugging to some extent.
And limiting the user base to the Halide/TVM community is not good in my view. We should expect more ML engineers (MLEs) to be able to use Slapo to adopt different optimized kernels and distributed strategies.

> And limiting the user base to the Halide/TVM community is not good in my view. We should expect more ML engineers (MLEs) to be able to use Slapo to adopt different optimized kernels and distributed strategies.

I think the point is that we need fine-grained control over different optimizations. That is why we need to design the schedule in a hierarchical way and apply it to specific submodules. For MLEs, I don't expect them to use these low-level primitives. The best way for them is to directly generate an optimized model, so we need to add more automation. Then they only need to call APIs like .autoshard() or .fuse_all() to accomplish complicated optimizations. Our decoupled primitives actually specify the design space, which provides a good interface for compilers to optimize further.
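As an illustration of that layering, a high-level call like .fuse_all() could simply compose the low-level primitives internally (a sketch; the helper name, the pattern matching, and the build_fused_module callback are all hypothetical, and sch.mod / sch[name].replace are assumed from the examples above):

    import fnmatch

    def fuse_all(sch, pattern, build_fused_module):
        """Hypothetical automation layer: apply the low-level replace primitive to
        every submodule matching `pattern`, so an MLE never has to touch the
        primitives directly."""
        matched = [
            name for name, _ in sch.mod.named_modules()
            if name and fnmatch.fnmatch(name, pattern)
        ]
        for name in matched:
            sub_sch = sch[name]
            sub_sch.replace(build_fused_module(sub_sch.mod))
        return sch

    # fuse_all(sch, "transformer.h.*.mlp", build_fused_mlp)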