
Official code for 'Golden Retriever' project of AIKU


🦮 Golden Retriever

This is the repository for 'Golden Retriever', an AIKU team project! Golden Retriever is a service that finds the images most similar to a user's text query. If you have ever struggled to find a particular image on your laptop or phone, Golden Retriever can fetch it quickly and accurately. Woof!


First, place your own images in a folder. Then, (1) generate captions for the images, which gives you image-caption pairs, and (2) fine-tune the text-to-image retrieval model on those pairs so that it aligns well with your images. Finally, (3) feed a description of the images you are looking for into the trained model, and you will get the images you want!

Basic Setting

# create conda env
conda create -n gr python=3.8

# activate conda env
conda activate gr

Our framework is based on the BLIP model, a pretrained vision-language model. Please refer to the official repository or paper for more details.

git clone https://github.com/salesforce/BLIP.git

Then create an images directory and place your own images in it.

mkdir images
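Before captioning, it can help to confirm which files in the directory will be treated as images. The following is a minimal sketch, not part of this repository: the helper name `list_images` and the set of accepted extensions are assumptions made for illustration.

```python
import tempfile
from pathlib import Path

# Assumed set of image formats; adjust to match your own collection.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def list_images(image_dir):
    """Return the sorted image files located directly inside image_dir."""
    return sorted(
        path
        for path in Path(image_dir).iterdir()
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
    )


# Demonstration on a throwaway directory instead of the real images folder.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "dog.JPG").write_bytes(b"")
    (Path(tmp) / "notes.txt").write_text("not an image")
    found = [path.name for path in list_images(tmp)]
```

To check the real directory, call `list_images("images")` after copying your files into it.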

Out of the box, BLIP supports text-to-image and image-to-text retrieval only on COCO and Flickr30k. To run retrieval on a custom dataset, please add the following files.

# 1. Move 'gr_config.yaml' into BLIP/configs
mv gr_config.yaml BLIP/configs/gr_config.yaml

# 2. Move 'gr_dataset.py' into BLIP/data
mv gr_dataset.py BLIP/data/gr_dataset.py

# 3. Move 'gradio_demo.py' into BLIP/
mv gradio_demo.py BLIP/gradio_demo.py

# 4. Move 'translation.py' into BLIP/
mv translation.py BLIP/translation.py

Modify the BLIP/data/__init__.py file: import the new dataset classes and add a 'retrieval_gr' branch to create_dataset.

from data.gr_dataset import gr_train, gr_retrieval_eval

def create_dataset(dataset, config, min_scale=0.5):
    # ... existing transforms and dataset branches stay unchanged ...
    elif dataset == 'retrieval_gr':
        # custom dataset: train split plus val/test splits for retrieval evaluation
        train_dataset = gr_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = gr_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = gr_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

Modify the BLIP/data/gr_dataset.py file so that both dataset classes open your caption file.

class gr_train(Dataset):
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
        '''
        image_root (string): Root directory of images (e.g. flickr30k/)
        ann_root (string): directory to store the annotation file
        '''
        with open(..., 'r') as f:  # caption file path
            ...  # remainder of __init__ unchanged

class gr_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):
        '''
        image_root (string): Root directory of images (e.g. flickr30k/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        with open(..., 'r') as f:  # caption file path
            ...  # remainder of __init__ unchanged
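To make the role of the caption file more concrete, here is a minimal sketch of how a dataset class could turn it into image-caption pairs. It is not the code in gr_dataset.py: the class name `CaptionPairs`, the JSON layout (a list of entries with `image` and `caption` keys, loosely modeled on BLIP's annotation files), and the simple lowercase-and-truncate preprocessing are all assumptions. The sketch avoids torch so that it runs anywhere.

```python
import json
import os
import tempfile


class CaptionPairs:
    """Minimal stand-in for a retrieval training dataset (no torch required)."""

    def __init__(self, image_root, ann_file, max_words=30, prompt=""):
        # Assumed format: [{"image": "<file name>", "caption": "<text>"}, ...]
        with open(ann_file, "r") as f:
            self.annotation = json.load(f)
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]
        image_path = os.path.join(self.image_root, ann["image"])
        # Lowercase the caption and truncate it to max_words words.
        words = ann["caption"].lower().split()[: self.max_words]
        return image_path, self.prompt + " ".join(words)


# Write a tiny hypothetical annotation file and read it back.
with tempfile.TemporaryDirectory() as tmp:
    ann_file = os.path.join(tmp, "captions.json")
    with open(ann_file, "w") as f:
        json.dump([{"image": "dog.jpg", "caption": "A dog running on grass"}], f)
    pairs = CaptionPairs("images", ann_file)
    num_pairs = len(pairs)
    first_pair = pairs[0]
```

A real dataset class would additionally load the image, apply `transform`, and return tensors; those steps are omitted here.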


We also provide demo code. Please download our model checkpoint, the pre-extracted image embeddings, and the images, then replace the paths in gradio_demo.py as in the code below.


def get_image(text):    
    image_embeds = np.load(...) # pre-extracted embeddings path
    image_embeds = torch.from_numpy(image_embeds).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    distributed = True
    image_path = ... # image path
    config = yaml.load(open(..., 'r'), Loader=yaml.Loader) # gr_config.yaml path
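The pre-extracted image embeddings exist so that a query only has to be embedded once and then compared against every stored image vector. The sketch below illustrates that ranking step with cosine similarity. It is a simplified illustration under stated assumptions, not the code in gradio_demo.py: the function names are hypothetical, plain Python lists stand in for the NumPy/torch tensors, and the toy 2-D vectors replace real model embeddings. The actual demo may also rerank candidates with the model.

```python
import heapq
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_images(text_embed, image_embeds, image_paths, k=3):
    """Return the k (path, score) pairs most similar to the text embedding."""
    scored = [
        (path, cosine_similarity(text_embed, embed))
        for path, embed in zip(image_paths, image_embeds)
    ]
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


# Toy 2-D embeddings; real ones would come from the trained model.
image_paths = ["dog.jpg", "cat.jpg", "car.jpg"]
image_embeds = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
query_embed = [1.0, 0.1]

results = top_k_images(query_embed, image_embeds, image_paths, k=2)
```

Because the image side is computed ahead of time, the cost per query is one text encoding plus a similarity scan over the stored vectors.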

All the preparations are done! Now, let's find the images you want using Golden Retriever!

python gradio_demo.py


If you want to train Golden Retriever or run inference with it on your own images, please follow the steps below.

1. Captioning

CUDA_VISIBLE_DEVICES=0 python caption.py \
--sample True \
--image_path {your image directory} \
--output_path {output directory}

sample selects the decoding strategy used for captioning: nucleus sampling or beam search. image_path is the directory containing the images that you want to caption. output_path is the output directory.
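For readers unfamiliar with the two decoding strategies: beam search deterministically keeps the highest-scoring partial captions, while nucleus (top-p) sampling draws each next token at random from the smallest set of tokens whose cumulative probability reaches a threshold. The sketch below illustrates a single nucleus sampling step on a toy distribution. It is illustrative only and is not how caption.py or BLIP implement decoding; the function names, the `top_p` value, and the probabilities are assumptions.

```python
import random


def nucleus_filter(probs, top_p=0.9):
    """Keep the most likely tokens whose cumulative probability reaches top_p."""
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    kept, cumulative = [], 0.0
    for token, prob in ranked:
        kept.append((token, prob))
        cumulative += prob
        if cumulative >= top_p:
            break
    # Renormalize so the kept probabilities sum to 1.
    return {token: prob / cumulative for token, prob in kept}


def sample_token(probs, top_p=0.9, rng=random):
    """Draw one token at random from the renormalized nucleus."""
    nucleus = nucleus_filter(probs, top_p)
    tokens = list(nucleus)
    weights = list(nucleus.values())
    return rng.choices(tokens, weights=weights, k=1)[0]


# Toy next-token distribution for a caption that starts with "a photo of a ...".
next_token_probs = {"dog": 0.5, "cat": 0.3, "car": 0.15, "sky": 0.05}
nucleus = nucleus_filter(next_token_probs, top_p=0.9)
token = sample_token(next_token_probs, top_p=0.9, rng=random.Random(0))
```

Sampling tends to give more varied captions across runs, whereas beam search returns the same caption for the same image every time.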

2. Training the retriever

Train the 🦮 retriever!


python -m torch.distributed.run --nproc_per_node=2 train_retrieval.py \
--config ./configs/retrieval_gr.yaml \
--output_dir output/gr_retrieval

3. Test your 🦮 retriever!

Code will be updated soon!


@sylee0520 @ONground-Korea @subin9 @JeonSeongHu @KorBrodStat


If you have any questions, please contact me! sy-lee@korea.ac.kr


This project runs on servers provided by Naver D2.