Function get_dataset and subsample_dataset
pouqual opened this issue · 1 comment
pouqual commented
Hello, when I ran your code, I found that the get_dataset and subsample_dataset functions are missing from the sample_language_nn.py file. Could you please provide them?
mbanani commented
Hi @pouqual
Sorry about that. We refactored the code quite a bit before release to make it easier to read and use, and we must have left these two functions out.
Below are the two functions from an older version of the code. I'm working toward another deadline right now, so I haven't had a chance to test whether the refactoring broke them. Let me know whether they work for you; if they don't, I can take another look in a week or so.
```python
import json
import time
from pathlib import Path

import torch
from torch.nn.functional import normalize
from tqdm import tqdm

# If these functions are pasted into sample_language_nn.py itself, drop this import
from .sample_language_nn import embed_captions
from lgssl.utils.faiss import faiss_knn


def get_dataset(json_name):
    # Each JSON file maps tar_key -> {img_id: entry}; the caption is entry[1]
    dict_path = Path(__file__).parent / "data_dicts" / json_name
    with dict_path.open() as f:
        data_dict = json.load(f)

    # Flatten into a list of [tar_key, img_id, caption] rows
    dataset = []
    for tar_key, images in tqdm(data_dict.items(), total=len(data_dict)):
        for img_id, entry in images.items():
            dataset.append([tar_key, img_id, entry[1]])
    return dataset
```
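To make the expected input of `get_dataset` concrete, here is a minimal, dependency-free sketch of its flattening step. The helper name `flatten_data_dict` and the toy dictionary are hypothetical; the only layout assumption, taken from the code above, is that each JSON file maps a tar key to a dict of image IDs whose entries hold the caption at index 1.

```python
from typing import Any


def flatten_data_dict(data_dict: dict[str, dict[str, list[Any]]]) -> list[list[Any]]:
    """Flatten {tar_key: {img_id: entry}} into [tar_key, img_id, entry[1]] rows.

    Hypothetical helper mirroring the loop in get_dataset; entry[1] is assumed
    to be the caption.
    """
    dataset = []
    for tar_key, images in data_dict.items():
        for img_id, entry in images.items():
            dataset.append([tar_key, img_id, entry[1]])
    return dataset


# Hypothetical toy dictionary; index 0 of each entry is a placeholder, not the real field.
toy = {
    "00000.tar": {"img_0": ["meta", "a dog on a beach"], "img_1": ["meta", "a red car"]},
    "00001.tar": {"img_7": ["meta", "two cats sleeping"]},
}
rows = flatten_data_dict(toy)
for row in rows:
    print(row)
```

Each row is what `subsample_dataset` later reads: `datum[2]` is the caption it embeds, and `datum[:2]` is the (tar_key, img_id) pair it returns for each neighbor.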
```python
def subsample_dataset(dataset, encoder, k, verbose=True):
    # Embed all captions with the language encoder
    captions = [datum[2] for datum in dataset]
    if verbose:
        print("---- generate embeddings ----")
    num_gpus = torch.cuda.device_count()
    curr_time = time.time()
    embeddings = embed_captions(captions, encoder, 2048, num_gpus)
    if verbose:
        print(f"Gathered embeddings in {time.time() - curr_time:.2f}sec")
        print("---- do nearest neighbors ----")

    # Drop captions whose embedding has zero norm (they could not be embedded)
    embed_valid = embeddings.norm(p=2, dim=1) > 0
    dataset = [datum for datum, valid in zip(dataset, embed_valid.tolist()) if valid]
    embeddings = embeddings[embed_valid]
    n_orig = embed_valid.shape[0]
    n_filt = int(embed_valid.sum())
    print(f"{n_orig - n_filt}/{n_orig} captions could not be embedded.")

    # k-nearest-neighbor search on L2-normalized embeddings, excluding self-matches
    curr_time = time.time()
    embeddings = normalize(embeddings, p=2, dim=1)
    distnn, idnn = faiss_knn(
        embeddings, embeddings, k=k, num_gpus=num_gpus, exclude_self=True, pbar=verbose
    )
    # For unit vectors, squared L2 distance d = 2 - 2 * cos, so cos = 1 - 0.5 * d
    # (assumes faiss_knn returns squared L2 distances, as FAISS L2 indexes do)
    sim_nn = 1 - 0.5 * distnn
    # Similarity of each neighbor relative to the closest neighbor
    rat_nn = sim_nn / sim_nn[:, 0:1]

    # Pair each (tar_key, img_id) with its k neighbors, their similarity, and ratio
    data = [datum[:2] for datum in dataset]
    data_nn = [[data[int(_ij)] for _ij in idnn_k] for idnn_k in idnn]
    data_all = [
        (data[i], data_nn[i][_k], sim_nn[i, _k].item(), rat_nn[i, _k].item())
        for i in range(len(dataset))
        for _k in range(k)
    ]
    if verbose:
        print(f"Computed Nearest Neighbors in {time.time() - curr_time:.2f}sec")
    return data_all
```
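The similarity math in `subsample_dataset` relies on the identity ||a - b||^2 = 2 - 2 cos(a, b) for unit vectors, which is why the embeddings are normalized before the search. The sketch below checks this with a brute-force, pure-Python k-NN. `knn_cosine` is a hypothetical stand-in for `lgssl.utils.faiss.faiss_knn`, not its actual implementation, and it assumes squared L2 distances, which is what FAISS L2 indexes report.

```python
import math


def unit(v: list[float]) -> list[float]:
    """Scale a vector to unit L2 norm (zero vectors must be filtered out first)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def sq_l2(a: list[float], b: list[float]) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def knn_cosine(embs: list[list[float]], k: int) -> list[list[tuple[int, float, float]]]:
    """Brute-force k-NN excluding each point itself (hypothetical faiss_knn stand-in).

    Returns, per row, a list of (neighbor_index, cosine_sim, ratio_to_nearest).
    """
    embs = [unit(e) for e in embs]
    results = []
    for i, q in enumerate(embs):
        # Squared distances to every other point, closest first
        dists = sorted((sq_l2(q, e), j) for j, e in enumerate(embs) if j != i)[:k]
        # For unit vectors ||a - b||^2 = 2 - 2 cos(a, b), so cos = 1 - 0.5 * d
        sims = [(j, 1 - 0.5 * d) for d, j in dists]
        nearest = sims[0][1]
        results.append([(j, s, s / nearest) for j, s in sims])
    return results


embs = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]]
for row in knn_cosine(embs, k=2):
    print(row)
```

For each point, the closest neighbor gets a ratio of 1.0 and every later neighbor's similarity is divided by that nearest similarity, which is what `rat_nn = sim_nn / sim_nn[:, 0:1]` computes row by row in the code above.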