mbanani/lgssl

Functions get_dataset and subsample_dataset are missing

pouqual opened this issue · 1 comment

Hello, I found that the get_dataset and subsample_dataset functions are missing from the sample_language_nn.py file when I ran your code. Could you please provide them?

Hi @pouqual

Sorry about that. We refactored the code quite a bit before release to make it easier to read and use, and we must have forgotten this.

Below are the two functions from an older version of the code. I am currently busy with another deadline, so I haven't had the chance to test them to make sure that they aren't affected by the refactoring. I can look at them again in a week or so if you're unable to use them, so let me know if they work for you.

import json
import time
from pathlib import Path

import torch
from torch.nn.functional import normalize
from tqdm import tqdm

from lgssl.utils.faiss import faiss_knn

from .sample_language_nn import embed_captions

def get_dataset(json_name):
    """Flatten a ``data_dicts`` JSON file into a list of ``[tar_key, img_id, value[1]]``.

    The file is read from ``data_dicts/<json_name>`` next to this module and is
    expected to be a two-level mapping ``{tar_key: {img_id: value}}``, where each
    ``value`` is indexable. Element ``[1]`` of each value is kept.
    ``subsample_dataset`` treats that element as the caption text
    (NOTE(review): confirm against the data dict format).

    Args:
        json_name: File name of the data dict inside the ``data_dicts`` directory.

    Returns:
        A list of ``[tar_key, img_id, value[1]]`` lists, in the JSON's key order.
    """
    dict_path = Path(__file__).parent / "data_dicts" / json_name
    # Close the handle deterministically, and decode as UTF-8 (the JSON standard
    # encoding) rather than the platform's locale default.
    with dict_path.open(encoding="utf-8") as f:
        data_dict = json.load(f)

    dataset = []
    for tar_key, images in tqdm(data_dict.items(), total=len(data_dict)):
        for img_id, entry in images.items():
            dataset.append([tar_key, img_id, entry[1]])

    return dataset


def subsample_dataset(dataset, encoder, k, verbose=True):
    """Pair every dataset entry with its ``k`` nearest neighbours by caption embedding.

    The caption of each entry (``datum[2]``) is embedded with ``embed_captions``.
    Entries whose embedding has zero L2 norm are dropped, and a count of them is
    printed. The remaining embeddings are L2-normalised and searched against
    themselves with ``faiss_knn(..., exclude_self=True)``, presumably so that an
    entry is not returned as its own neighbour.

    Args:
        dataset: List of ``[tar_key, img_id, caption]`` entries, e.g. the output of
            ``get_dataset``.
        encoder: Text encoder forwarded to ``embed_captions``.
        k: Number of neighbours kept per entry.
        verbose: Print progress/timing messages and enable the k-NN progress bar.

    Returns:
        A list of ``(query, neighbour, similarity, ratio)`` tuples, ``k`` per kept
        entry, ordered by query and then by neighbour rank. The fields are:

        * ``query`` and ``neighbour`` are ``[tar_key, img_id]`` lists.
        * ``similarity`` is ``1 - 0.5 * dist``.
        * ``ratio`` is ``similarity`` divided by the query's first-neighbour
          similarity.

        Returns ``[]`` if ``dataset`` is empty or no caption could be embedded.
    """
    if not dataset:
        return []

    # generate captions
    captions = [datum[2] for datum in dataset]

    if verbose:
        print("---- generate embeddings ----")
    num_gpus = torch.cuda.device_count()
    curr_time = time.time()
    embeddings = embed_captions(captions, encoder, 2048, num_gpus)
    if verbose:
        print(f"Gathered embeddings in {time.time() - curr_time:.2f}sec")
        print("---- do nearest neighbors ----")

    # Drop entries whose embedding is all zeros (captions that could not be
    # embedded). Convert the mask to a Python list once, instead of indexing the
    # tensor element by element.
    embed_valid = embeddings.norm(p=2, dim=1) > 0
    keep = embed_valid.tolist()
    dataset = [datum for datum, ok in zip(dataset, keep) if ok]
    embeddings = embeddings[embed_valid]
    n_orig = len(keep)
    n_filt = len(dataset)
    print(f"{n_orig - n_filt}/{n_orig} captions could not be embedded.")
    if not dataset:
        return []

    # sample
    curr_time = time.time()
    embeddings = normalize(embeddings, p=2, dim=1)
    distnn, idnn = faiss_knn(
        embeddings, embeddings, k=k, num_gpus=num_gpus, exclude_self=True, pbar=verbose
    )
    # For unit vectors, ||a - b||^2 = 2 - 2 cos(a, b), so this is cosine
    # similarity IF faiss_knn returns squared L2 distances.
    # NOTE(review): confirm the distance convention of faiss_knn.
    sim_nn = 1 - 0.5 * distnn
    # Similarity relative to each query's first returned neighbour.
    rat_nn = sim_nn / sim_nn[:, 0:1]

    # One bulk conversion to Python lists. This avoids a per-element .item() /
    # tensor index, which is slow and syncs the device for each call on GPU.
    # NOTE(review): FAISS pads missing neighbours with id -1, which would silently
    # pick the last entry below. Confirm that faiss_knn always returns k valid ids.
    sim_rows = sim_nn.tolist()
    rat_rows = rat_nn.tolist()
    id_rows = idnn.tolist()

    data = [datum[:2] for datum in dataset]
    data_all = [
        (data[i], data[id_rows[i][_k]], sim_rows[i][_k], rat_rows[i][_k])
        for i in range(len(data))
        for _k in range(k)
    ]

    if verbose:
        print(f"Computed Nearest Neighbors in {time.time() - curr_time:.2f}sec")

    return data_all