Potential slowdown of vmapping on the outside in `influence_fn`
agrawalraj opened this issue · 1 comment
agrawalraj commented
In `ops.py`:

```python
def influence_fn(...):
    ...

    @functools.wraps(target)
    def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> S:
        param_eif = linearized(point, *args, **kwargs)
        return torch.func.jvp(
            lambda p: func_target(p, *args, **kwargs), (target_params,), (param_eif,)
        )[1]

    return _fn
```
Here, if we vmap over `_fn`, then `func_target` keeps having to be evaluated inside `torch.func.jvp` for each point, even though `target_params` is the same for every point. This might be a source of slowdown. One way to fix this would be to instead assume that `linearized` batches over multiple datapoints and that `_fn` takes multiple points as input instead of a single datapoint.
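A minimal sketch of the batched idea, under loudly stated assumptions: `target_params` is a single 1-D tensor, `func_target` returns a scalar, and a hypothetical `linearized_batch` returns one param EIF per row. The names `make_batched_influence_fn`, `linearized_batch`, `per_point_fn`, the toy `func_target`, and `W` are illustrative only, not from `ops.py`. Because the JVP is linear in its tangent, the sketch swaps the per-point `torch.func.jvp` for a single `torch.func.jacrev` call at `target_params` and reuses that gradient for every point. It then checks the result against the per-point pattern from the issue (`vmap` over a `_fn`-like function).

```python
import torch
from torch.func import jacrev, jvp, vmap


def make_batched_influence_fn(func_target, target_params, linearized_batch):
    """Hypothetical batched variant of `_fn`, for illustration only.

    Assumptions for this sketch:
      * `target_params` is a single 1-D tensor of shape (d,),
      * `func_target` maps it to a scalar tensor,
      * `linearized_batch(points)` returns param EIFs of shape (N, d).
    """
    # Differentiate func_target once. The JVP is linear in its tangent,
    # so this gradient is reused for every point in the batch.
    grad = jacrev(func_target)(target_params)  # shape (d,)

    def _fn(points: torch.Tensor) -> torch.Tensor:
        param_eifs = linearized_batch(points)  # shape (N, d)
        # Row i equals jvp(func_target, (target_params,), (param_eifs[i],))[1].
        return param_eifs @ grad  # shape (N,)

    return _fn


# Toy check against the per-point pattern from the issue (vmap over `_fn`).
torch.manual_seed(0)
d, n = 3, 5
target_params = torch.randn(d)
func_target = lambda p: (p**2).sum()  # hypothetical toy scalar target
W = torch.randn(d, d)
linearized_batch = lambda pts: pts @ W  # hypothetical batched `linearized`


def per_point_fn(point):
    # Mirrors `_fn` from ops.py: one JVP per single point.
    param_eif = point @ W
    return jvp(func_target, (target_params,), (param_eif,))[1]


points = torch.randn(n, d)
fast = make_batched_influence_fn(func_target, target_params, linearized_batch)(points)
slow = vmap(per_point_fn)(points)
print(torch.allclose(fast, slow))
```

Precomputing the gradient is only cheap when `func_target` has a low-dimensional output. For a high-dimensional target, keeping `torch.func.jvp` but vmapping it only over the batch of tangents, with `target_params` shared and unbatched, may be the better trade-off.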
agrawalraj commented
Closed by #464