BasisResearch/chirho

Potential slowdown of vmapping on the outside in `influence_fn`

agrawalraj opened this issue

In `ops.py`:

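```python
def influence_fn(...):
    ...
    @functools.wraps(target)
    def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> S:
        # Parameter-space EIF for this single datapoint
        param_eif = linearized(point, *args, **kwargs)
        # Push the EIF through func_target with a forward-mode JVP at target_params;
        # func_target is differentiated (and evaluated) on every call to _fn
        return torch.func.jvp(
            lambda p: func_target(p, *args, **kwargs), (target_params,), (param_eif,)
        )[1]

    return _fn
```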

If we `vmap` over `_fn`, the JVP through `func_target` is set up once per datapoint, so `func_target` keeps getting re-evaluated at the same `target_params`. This might be a source of slowdown. One way to fix this would be to have `linearized` batch over multiple datapoints, so that `_fn` takes a batch of points as input instead of a single datapoint.
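A minimal sketch of the batched alternative, assuming `func_target` returns a scalar and that the parameter EIFs for n datapoints can be stacked into an `(n, d)` tensor. `func_target`, `influence_per_point`, and `influence_batched` below are hypothetical stand-ins, not chirho's API, and this is not necessarily what #464 implemented. Because a JVP is linear in its tangent, every per-point value ∇ψ(θ)ᵀφᵢ can be obtained from a single gradient of `func_target` at `target_params`, contracted with each row of the EIF matrix.

```python
import torch
from torch.func import jacrev, jvp, vmap


# Hypothetical stand-in for func_target: a scalar functional of the parameters.
def func_target(params: torch.Tensor) -> torch.Tensor:
    return torch.sin(params).sum() + (params ** 2).sum()


def influence_per_point(target_params: torch.Tensor, param_eif: torch.Tensor) -> torch.Tensor:
    # Mirrors _fn: one forward-mode JVP per datapoint's parameter EIF.
    return jvp(func_target, (target_params,), (param_eif,))[1]


def influence_batched(target_params: torch.Tensor, param_eifs: torch.Tensor) -> torch.Tensor:
    # Differentiate func_target once at target_params (gradient has shape (d,)) ...
    grad = jacrev(func_target)(target_params)
    # ... then contract with every EIF row at once: (n, d) @ (d,) -> (n,)
    return param_eifs @ grad
```

Usage: the per-point version is `vmap(influence_per_point, in_dims=(None, 0))(target_params, param_eifs)`, which should agree with `influence_batched(target_params, param_eifs)`. For a non-scalar `func_target`, `jacrev` returns a Jacobian of shape `(*out_shape, d)`, and the final `@` would become a `tensordot` over the parameter dimension.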

Closed by #464