
[P0] Enable a `use_fast` option in the alignable to greatly speed up training when intervention locations (for position+subspace) are fixed within a batch

frankaging opened this issue.

Currently, the library prioritizes flexible inputs and assumes a small training batch size when the intervention is trainable. For instance, each example in a batch may have its own intervention locations and its own intervention subspaces, which allows more flexible configurations.

This is not ideal when the batch is large and the intervention location does not change within it. Suppose we want to localize (a+b) in a simple NN that solves (a+b)*c, using DAS with a fixed dimensionality of 16: the intervention location is then the same for every example. However, the current code still performs the intervention at the example level rather than at the batch level. For example, the position-based intervention is applied one example at a time:

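```python
for batch_i, locations in enumerate(unit_locations):
    # `tensor_input` (name assumed) is the tensor being intervened on;
    # each example may use different locations, so write one example at a time
    tensor_input[
        batch_i, locations, start_index:end_index
    ] = replacing_tensor_input[batch_i]
```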

When the locations are shared by the whole batch, this can be a single batched assignment:

    :, locations, start_index:end_index
] = replacing_tensor_input[:]
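A minimal sketch of why the two forms agree when locations are static, using NumPy arrays as a stand-in for the torch tensors the library operates on (mixed slice/list indexing behaves the same way here). The function names `scatter_per_example` and `scatter_batched` are hypothetical, and the batched version assumes, like `use_fast`, that every example shares `unit_locations[0]` without checking it.

```python
import numpy as np


def scatter_per_example(base, source, unit_locations, start_index, end_index):
    """Slow path: every example may target its own positions."""
    out = base.copy()
    for batch_i, locations in enumerate(unit_locations):
        out[batch_i, locations, start_index:end_index] = source[batch_i]
    return out


def scatter_batched(base, source, unit_locations, start_index, end_index):
    """Fast path: assumes all examples share unit_locations[0] (not checked)."""
    out = base.copy()
    out[:, unit_locations[0], start_index:end_index] = source
    return out


rng = np.random.default_rng(0)
batch, seq_len, hidden = 4, 6, 8
locations = [1, 3]                    # same positions for every example
unit_locations = [locations] * batch
base = rng.standard_normal((batch, seq_len, hidden))
# replacement values: (batch, num_positions, end_index - start_index)
source = rng.standard_normal((batch, len(locations), 2))

slow = scatter_per_example(base, source, unit_locations, 0, 2)
fast = scatter_batched(base, source, unit_locations, 0, 2)
print(np.allclose(slow, fast))
```

The batched form replaces `batch` small indexed writes with one. The flip side is that if the examples actually differed, it would silently apply example 0's positions to the whole batch.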

Similarly, the subspace intervention loops over the batch:

```python
if subspaces is not None:
    for example_i in range(len(subspaces)):
        # render subspace as column indices; assumes subspace_partition
        # maps each subspace id to a [start, end) column range
        sel_subspace_indices = []
        for subspace in subspaces[example_i]:
            sel_subspace_indices.extend(
                range(
                    subspace_partition[subspace][0],
                    subspace_partition[subspace][1],
                )
            )
        if mode == "interchange":
            base[example_i, ..., sel_subspace_indices] = \
                source[example_i, ..., sel_subspace_indices]
        elif mode == "add":
            base[example_i, ..., sel_subspace_indices] += \
                source[example_i, ..., sel_subspace_indices]
        elif mode == "subtract":
            base[example_i, ..., sel_subspace_indices] -= \
                source[example_i, ..., sel_subspace_indices]
```

can be reduced to a single batched update that uses the first example's subspaces for the whole batch:

```python
if subspaces is not None:
    # use_fast: every example shares the first example's subspaces
    if subspace_partition is None:
        sel_subspace_indices = subspaces[0]
    else:
        # render subspace ids as column indices, once for the whole batch
        sel_subspace_indices = []
        for subspace in subspaces[0]:
            sel_subspace_indices.extend(
                range(
                    subspace_partition[subspace][0],
                    subspace_partition[subspace][1],
                )
            )
    if mode == "interchange":
        base[..., sel_subspace_indices] = \
            source[..., sel_subspace_indices]
    elif mode == "add":
        base[..., sel_subspace_indices] += \
            source[..., sel_subspace_indices]
    elif mode == "subtract":
        base[..., sel_subspace_indices] -= \
            source[..., sel_subspace_indices]
else:
    base[..., :interchange_dim] = source[..., :interchange_dim]
```
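The same equivalence can be checked for the subspace path. The sketch below, again with NumPy standing in for torch, mirrors the per-example loop and the `use_fast` branch; `render_subspace_indices`, `subspace_swap_slow` and `subspace_swap_fast` are hypothetical names, and `subspace_partition` is assumed to map each subspace id to a `[start, end)` column range.

```python
import numpy as np


def render_subspace_indices(subspace_ids, subspace_partition):
    """Turn subspace ids into column indices (ids used as-is without a partition)."""
    if subspace_partition is None:
        return list(subspace_ids)
    indices = []
    for subspace in subspace_ids:
        start, end = subspace_partition[subspace]
        indices.extend(range(start, end))
    return indices


def _combine(mode, base_part, source_part):
    """Apply one of the three intervention modes to the selected columns."""
    if mode == "interchange":
        return source_part
    if mode == "add":
        return base_part + source_part
    if mode == "subtract":
        return base_part - source_part
    raise ValueError(f"unknown mode: {mode}")


def subspace_swap_slow(base, source, subspaces, subspace_partition, mode):
    """Per-example loop: each example renders and applies its own columns."""
    out = base.copy()
    for example_i, subspace_ids in enumerate(subspaces):
        cols = render_subspace_indices(subspace_ids, subspace_partition)
        out[example_i, ..., cols] = _combine(
            mode, out[example_i, ..., cols], source[example_i, ..., cols]
        )
    return out


def subspace_swap_fast(base, source, subspaces, subspace_partition, mode):
    """use_fast-style: render columns once from example 0, apply to the batch."""
    out = base.copy()
    cols = render_subspace_indices(subspaces[0], subspace_partition)
    out[..., cols] = _combine(mode, out[..., cols], source[..., cols])
    return out


rng = np.random.default_rng(0)
base = rng.standard_normal((4, 3, 16))    # (batch, positions, hidden)
source = rng.standard_normal((4, 3, 16))
partition = [[0, 8], [8, 16]]             # two 8-dim subspaces (illustrative)
subspaces = [[1]] * 4                     # every example uses subspace 1

for mode in ("interchange", "add", "subtract"):
    slow = subspace_swap_slow(base, source, subspaces, partition, mode)
    fast = subspace_swap_fast(base, source, subspaces, partition, mode)
    print(mode, np.allclose(slow, fast))
```

Besides removing the Python loop, the fast path renders the column indices once per batch instead of once per example.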

We should add a `use_fast` flag to the alignable config, together with a validation check that fails fast during the forward call.
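One way the proposed flag and fail-fast check could look, as a hedged sketch: `InterventionConfigSketch` and `check_static_batch` are hypothetical names, not the alignable config's actual fields or methods. The check compares every example against the first one and raises on the first mismatch.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class InterventionConfigSketch:
    """Hypothetical config carrying the proposed flag."""
    use_fast: bool = False


def check_static_batch(
    config: InterventionConfigSketch,
    unit_locations: Sequence[Sequence[int]],
    subspaces: Optional[Sequence[Sequence[int]]] = None,
) -> None:
    """Raise as soon as an example differs from example 0 under use_fast."""
    if not config.use_fast:
        return  # the per-example path handles heterogeneous batches
    for name, per_example in (("unit_locations", unit_locations),
                              ("subspaces", subspaces)):
        if per_example is None:
            continue
        first = list(per_example[0])
        for example_i, value in enumerate(per_example[1:], start=1):
            if list(value) != first:
                raise ValueError(
                    f"use_fast=True needs identical {name} across the batch, "
                    f"but example {example_i} has {list(value)} vs {first}"
                )


config = InterventionConfigSketch(use_fast=True)
check_static_batch(config, [[1, 3]] * 4, [[0]] * 4)  # passes silently
try:
    check_static_batch(config, [[1, 3], [2, 3]])
except ValueError as err:
    print(err)
```

Note that this check walks the whole batch on every forward call, so its cost grows with batch size.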

This PR tracks the `use_fast` effort for position-based and subspace-based interventions. It does not yet cover head-based or head+position-based interventions; those will be handled in a separate PR.

Testing Done:

  • wrote four additional integration tests
  • log:
```
In case multiple location tags are passed only the first one will be considered
testing stream: value_output with a single position
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.

In case multiple location tags are passed only the first one will be considered
Ran 18 tests in 30.117s
```


Instead of a validation check, we will emit a warning: validating the inputs on every forward call would itself take time, which defeats the purpose of the fast path.
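A minimal sketch of the warning-only alternative, assuming a hypothetical `FastPathResolver` helper: instead of scanning the batch, it logs once and trusts the caller, taking the first example's locations for the whole batch. The message mirrors the one in the test log; sending it through the root logger with `logging.warning` yields the `WARNING:root:` prefix when no other logging configuration is set.

```python
import logging
from typing import List, Sequence


class FastPathResolver:
    """Hypothetical helper: decides which locations a forward call uses."""

    def __init__(self, use_fast: bool = False):
        self.use_fast = use_fast
        self._warned = False

    def resolve(self, unit_locations: Sequence[List[int]]):
        if not self.use_fast:
            # Flexible path: keep one location list per example.
            return [list(locs) for locs in unit_locations]
        if not self._warned:
            # Warn once instead of validating every batch (no O(batch) scan).
            logging.warning(
                "Detected use_fast=True means the intervention location "
                "will be static within a batch."
            )
            self._warned = True
        # Trust the caller: example 0's locations apply to the whole batch.
        return list(unit_locations[0])


resolver = FastPathResolver(use_fast=True)
shared = resolver.resolve([[1, 3], [1, 3], [1, 3]])
print(shared)
```

The trade-off is that a batch violating the assumption is not rejected: `resolver.resolve([[1, 3], [2, 4]])` still returns `[1, 3]`.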