MadryLab/trak

Support the LDS metric

Closed this issue · 1 comment

Section 2 of the TRAK paper proposes LDS as the main metric for evaluating data attribution methods; it reduces the reliance on manual inspection and provides a more automated, objective assessment.
Would you like to support the LDS metric in this repository?
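For reference, my understanding of the metric from Section 2 of the paper: given attribution scores $\tau(z)$ for a target example $z$ and $m$ random training subsets $S_1, \dots, S_m$, the LDS is the Spearman rank correlation between the actual model outputs and the attribution-based predictions,

$$\mathrm{LDS}(\tau, z) = \rho_{\mathrm{spearman}}\Big(\{f(z; S_j)\}_{j=1}^m,\ \Big\{\sum_{i \in S_j} \tau(z)_i\Big\}_{j=1}^m\Big),$$

averaged over target examples $z$, where $f(z; S_j)$ is the output (e.g., the margin) of a model trained only on $S_j$.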

Aside from the scripts for training models and recording margins, the only code needed to compute the LDS metric is in the quickstart notebook and in the tests:

trak/tests/utils.py

Lines 177 to 205 in 39bf22a

# imports this snippet relies on (imported elsewhere in tests/utils.py)
from pathlib import Path

import numpy as np
import torch as ch
import wget
from scipy.stats import spearmanr
from tqdm import tqdm


def eval_correlations(infls, tmp_path, ds="cifar10"):
    if ds == "cifar10":
        masks_url = "https://www.dropbox.com/s/x76uyen8ffkjfke/mask.npy?dl=1"
        margins_url = "https://www.dropbox.com/s/q1dxoxw78ct7c27/val_margins.npy?dl=1"
    else:
        masks_url = "https://www.dropbox.com/s/2nmcjaftdavyg0m/mask.npy?dl=1"
        margins_url = "https://www.dropbox.com/s/tc3r3c3kgna2h27/val_margins.npy?dl=1"

    masks_path = Path(tmp_path).joinpath("mask.npy")
    wget.download(masks_url, out=str(masks_path), bar=None)
    # shape: (num masks, num train samples)
    masks = ch.as_tensor(np.load(masks_path, mmap_mode="r")).float()

    margins_path = Path(tmp_path).joinpath("val_margins.npy")
    wget.download(margins_url, out=str(margins_path), bar=None)
    # shape: (num masks, num val samples)
    margins = ch.as_tensor(np.load(margins_path, mmap_mode="r"))

    val_inds = np.arange(2000)
    preds = masks @ infls
    rs = []
    ps = []
    for ind, j in tqdm(enumerate(val_inds)):
        r, p = spearmanr(preds[:, ind], margins[:, j])
        rs.append(r)
        ps.append(p)
    rs, ps = np.array(rs), np.array(ps)
    print(f"Correlation: {rs.mean()} (avg p value {ps.mean()})")
    return rs.mean()
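In case it helps the discussion, here is a minimal, self-contained sketch of the same computation that takes the arrays directly instead of downloading them from Dropbox. The function name, argument names, and shapes are my assumptions, chosen to mirror eval_correlations above; only numpy and scipy are needed.

import numpy as np
from scipy.stats import spearmanr


def linear_datamodeling_score(scores, masks, margins):
    """Sketch of an LDS computation mirroring eval_correlations above.

    scores:  (num train samples, num val samples) attribution matrix,
             e.g. TRAK scores for each val example.
    masks:   (num subsets, num train samples) binary matrix; row j marks
             which training examples were included in subset S_j.
    margins: (num subsets, num val samples) outputs (margins) of the models
             trained on each subset, evaluated on the val examples.

    Returns the mean Spearman correlation over val examples.
    """
    # Attribution-based prediction of each subset-model's output: for subset j
    # and val example i, sum the scores of the train examples in S_j.
    preds = masks @ scores  # (num subsets, num val samples)

    rs = []
    for i in range(scores.shape[1]):
        r, _ = spearmanr(preds[:, i], margins[:, i])
        rs.append(r)
    return float(np.mean(rs))

The masks and margins used by the CIFAR tests are exactly what eval_correlations downloads; for a new setup, one would have to retrain models on random subsets of the training set to produce them.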

In case there's community interest in more explicit support for LDS (and possibly counterfactuals as well), I'm leaving this open as a discussion.