
An Extensible Framework for Retrieval-Augmented LLM Applications: Learning Relevance Beyond Simple Similarity.

Primary language: Python | License: MIT

Vsearch

License: MIT | Python 3.9

Vsearch: representing data in the LM vocabulary space for search.

🗺 Overview

This repository includes:

  1. Preparation

    • Set Up Environment
    • Download Data
  2. Quick Start

    • Embedding and Computing Relevance
    • Building an Index for Large-scale Retrieval
    • Building a Bag-of-Token Index for Faster Retrieval Setup
    • Inspecting Retrieval Insights from Representations
    • Semi-parametric Retrieval
    • Cross-modal Retrieval
  3. Training

  4. Inference

    • Build index
    • Search
    • Scoring

💻 Preparation

Set Up Environment via Poetry

# install poetry first
# curl -sSL https://install.python-poetry.org | python3 -
poetry install
poetry shell

Download Data

Download data using the identifiers defined in the YAML configuration files under conf/data_stores/*.yaml.

# Download a single dataset file
python download.py nq_train
# Download multiple dataset files:
python download.py nq_train trivia_train
# Download all dataset files:
python download.py all

🚀 Quick Start

Embedding and Computing Relevance

import torch
from src.ir import Retriever

# Define a query and a list of passages
query = "Who first proposed the theory of relativity?"
passages = [
    "Albert Einstein (14 March 1879 – 18 April 1955) was a German-born theoretical physicist who is widely held to be one of the greatest and most influential scientists of all time. He is best known for developing the theory of relativity.",
    "Sir Isaac Newton FRS (25 December 1642 – 20 March 1727) was an English polymath active as a mathematician, physicist, astronomer, alchemist, theologian, and author who was described in his time as a natural philosopher.",
    "Nikola Tesla (10 July 1856 – 7 January 1943) was a Serbian-American inventor, electrical engineer, mechanical engineer, and futurist. He is known for his contributions to the design of the modern alternating current (AC) electricity supply system."
]

# Initialize the retriever
ir = Retriever.from_pretrained("vsearch/svdr-msmarco")
ir = ir.to("cuda")

# Embed the query and passages
q_emb = ir.encoder_q.embed(query)  # Shape: [1, V]
p_emb = ir.encoder_p.embed(passages)  # Shape: [3, V]

# Query-passage Relevance
scores = q_emb @ p_emb.t()
print(scores)

# Output: 
# tensor([[97.2964, 39.7844, 37.6955]], device='cuda:0')
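
Each embedding is a V-dimensional vector over the model's vocabulary, so relevance is an inner product. The sketch below reproduces that computation in plain Python, storing each vector sparsely as a token-to-weight dict. The tokens and weights are made up for illustration (they are not outputs of `vsearch/svdr-msmarco`), and `dot` is a hypothetical helper, not part of `src.ir`.

```python
from typing import Dict, List

# A vocabulary-space embedding stored sparsely: token -> weight.
# Tokens missing from the dict have weight 0, as most of the V dimensions do.
SparseVec = Dict[str, float]


def dot(q: SparseVec, p: SparseVec) -> float:
    """Inner product; only tokens active in both vectors contribute."""
    small, large = (q, p) if len(q) <= len(p) else (p, q)
    return sum((w * large[t] for t, w in small.items() if t in large), 0.0)


# Hypothetical toy weights, NOT produced by vsearch/svdr-msmarco.
q_vec = {"relativity": 7.0, "proposed": 3.5, "first": 3.0}
p_vecs: List[SparseVec] = [
    {"einstein": 6.0, "relativity": 5.0, "theory": 2.0},
    {"newton": 6.0, "physicist": 2.0, "first": 0.5},
    {"tesla": 6.0, "inventor": 3.0},
]

# Equivalent of `q_emb @ p_emb.t()` for a single query: one score per passage.
scores = [dot(q_vec, p) for p in p_vecs]
# Passage ids ordered from most to least relevant.
ranking = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
print(scores)   # [35.0, 1.5, 0.0]
print(ranking)  # [0, 1, 2]
```

Iterating over the smaller of the two dicts keeps the cost proportional to the number of active tokens rather than to V, which is the main payoff of sparse vocabulary-space vectors.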

Building an Index for Large-scale Retrieval

For large-scale retrieval, it is more efficient to build the index once and reuse it across subsequent queries.

# Build the sparse index for the passages
ir.build_index(passages, index_type="sparse")
print(ir.index)

# Output:
# Index Type      : SparseIndex
# Vector Shape    : torch.Size([3, 29523])
# Vector Dtype    : torch.float32
# Vector Layout   : torch.sparse_csr
# Number of Texts : 3
# Vector Device   : cuda:0

# Save the index to disk
index_file = "/path/to/index.npz"
ir.save_index(index_file)

# Load the index from disk
index_file = "/path/to/index.npz"
data_file = "/path/to/texts.jsonl"
ir.load_index(index_file=index_file, data_file=data_file)

You can retrieve results for queries directly from a pre-built index.

# Search top-k results for queries
queries = [query]
results = ir.retrieve(queries, k=3)
print(results)

# Output:
# SearchResults(
#   ids=tensor([[0, 1, 2]], device='cuda:0'),
#   scores=tensor([[97.2458, 39.7507, 37.6407]], device='cuda:0')
# )

query_id = 0
top1_psg_id = results.ids[query_id][0]
top1_psg = ir.index.get_sample(top1_psg_id)
print(top1_psg)
# Output:
# Albert Einstein (14 March 1879 – 18 April 1955) was a German-born theoretical physicist who is widely held to be one of the greatest and most influential scientists of all time. He is best known for developing the theory of relativity.
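
The index printout above reports a `torch.sparse_csr` layout. As a rough illustration of what such an index stores and how top-k search over it can work, here is a minimal pure-Python sketch. `ToyCSRIndex`, its methods, and the toy weights are hypothetical; this is not the framework's `SparseIndex`, which keeps the matrix on the GPU.

```python
import heapq
from typing import Dict, List, Tuple


class ToyCSRIndex:
    """Hypothetical sparse index: rows = passages, columns = vocabulary ids,
    stored in compressed sparse row (CSR) form: indptr, indices, data."""

    def __init__(self, vocab: Dict[str, int]):
        self.vocab = vocab
        self.indptr: List[int] = [0]   # row i spans indices[indptr[i]:indptr[i + 1]]
        self.indices: List[int] = []   # column (token) id of each stored weight
        self.data: List[float] = []    # the stored non-zero weights
        self.texts: List[str] = []

    def add(self, text: str, vec: Dict[str, float]) -> None:
        # Append one row with its columns sorted, as in a canonical CSR matrix.
        for col, w in sorted((self.vocab[t], w) for t, w in vec.items()):
            self.indices.append(col)
            self.data.append(w)
        self.indptr.append(len(self.indices))
        self.texts.append(text)

    def retrieve(self, q_vec: Dict[str, float], k: int) -> List[Tuple[int, float]]:
        # Map query tokens to column ids once.
        q = {self.vocab[t]: w for t, w in q_vec.items() if t in self.vocab}
        scores = []
        for row in range(len(self.texts)):
            start, end = self.indptr[row], self.indptr[row + 1]
            s = sum((self.data[j] * q.get(self.indices[j], 0.0) for j in range(start, end)), 0.0)
            scores.append((row, s))
        # Top-k by score, highest first; ties go to the lower row id.
        return heapq.nlargest(k, scores, key=lambda x: (x[1], -x[0]))

    def get_sample(self, row: int) -> str:
        return self.texts[row]


tokens = ["einstein", "relativity", "theory", "newton", "physicist",
          "first", "tesla", "inventor", "proposed"]
index = ToyCSRIndex({t: i for i, t in enumerate(tokens)})
index.add("Einstein passage", {"einstein": 6.0, "relativity": 5.0, "theory": 2.0})
index.add("Newton passage", {"newton": 6.0, "physicist": 2.0, "first": 0.5})
index.add("Tesla passage", {"tesla": 6.0, "inventor": 3.0})

results = index.retrieve({"relativity": 7.0, "proposed": 3.5, "first": 3.0}, k=2)
print(results)                          # [(0, 35.0), (1, 1.5)]
print(index.get_sample(results[0][0]))  # Einstein passage
```

CSR only stores the non-zero weights of each row, which is why a vocabulary-sized matrix such as [3, 29523] stays compact when each passage activates few tokens.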

Building a Bag-of-Token Index for Faster Retrieval Setup

Our framework supports a non-parametric index built directly from the tokenizer, known as the Bag-of-Token (BoT) index. Because it is built from tokenization alone, without running the passage encoder, it reduces both indexing time and on-disk index size by over 90%. You can build and use a BoT index as follows:

# Build the bag-of-token index for the passages
ir.build_index(passages, index_type="bag_of_token")
print(ir.index)

# Output:
# Index Type      : BoTIndex
# Vector Shape    : torch.Size([3, 29523])
# Vector Dtype    : torch.float16
# Vector Layout   : torch.sparse_csr
# Number of Texts : 3
# Vector Device   : cuda:0

# Search top-k results from bag-of-token index, and embed and rerank them on-the-fly
queries = [query]
results = ir.retrieve(queries, k=3, rerank=True)
print(results)

# Output:
# SearchResults(
#   ids=tensor([0, 1, 2], device='cuda:0'),
#   scores=tensor([97.2964, 39.7844, 37.6955], device='cuda:0')
# )
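
A bag-of-token index records only which tokens occur in each passage, so it can be built with a tokenizer alone; with `rerank=True`, the top candidates are then embedded and rescored. The sketch below illustrates that two-stage flow under stated assumptions: the regex tokenizer, `toy_encoder`, and the weights are hypothetical stand-ins for the model's tokenizer and `ir.encoder_p.embed`, and first-stage scoring is assumed to pair the weighted query with binary passage vectors, as in this README's Beta search.

```python
import re
from typing import Callable, Dict, List, Set, Tuple


def bag_of_tokens(text: str) -> Set[str]:
    """Binary bag-of-token representation: the set of tokens in the text.
    Hypothetical lowercase regex tokenizer standing in for the model's tokenizer."""
    return set(re.findall(r"[a-z]+", text.lower()))


def retrieve_and_rerank(
    q_vec: Dict[str, float],
    bot_index: List[Set[str]],
    encode_passage: Callable[[int], Dict[str, float]],
    k: int,
    n_candidates: int,
) -> List[Tuple[int, float]]:
    # Stage 1: weighted query vs. binary passage vectors (a present token counts once).
    first = [(i, sum((w for t, w in q_vec.items() if t in toks), 0.0))
             for i, toks in enumerate(bot_index)]
    first.sort(key=lambda x: (-x[1], x[0]))
    candidates = [i for i, _ in first[:n_candidates]]
    # Stage 2: embed only the candidates on the fly and rescore with full weights.
    reranked = []
    for i in candidates:
        p_vec = encode_passage(i)
        reranked.append((i, sum((w * p_vec.get(t, 0.0) for t, w in q_vec.items()), 0.0)))
    reranked.sort(key=lambda x: (-x[1], x[0]))
    return reranked[:k]


texts = [
    "Einstein is best known for the theory of relativity.",
    "Newton was described as a natural philosopher.",
    "Tesla is known for alternating current.",
]
bot_index = [bag_of_tokens(t) for t in texts]  # no encoder call at indexing time

# Hypothetical passage weights standing in for ir.encoder_p.embed(...).
toy_weights = {
    0: {"einstein": 6.0, "relativity": 5.0, "theory": 2.0},
    1: {"newton": 6.0, "philosopher": 2.0},
    2: {"tesla": 6.0, "current": 3.0},
}
calls: List[int] = []


def toy_encoder(i: int) -> Dict[str, float]:
    calls.append(i)  # record which passages actually get encoded
    return toy_weights[i]


q_vec = {"relativity": 7.0, "proposed": 3.5, "first": 3.0, "theory": 1.0}
result = retrieve_and_rerank(q_vec, bot_index, toy_encoder, k=2, n_candidates=2)
print(result)  # [(0, 37.0), (1, 0.0)]
print(calls)   # [0, 1]
```

Only the two candidates reach the encoder, which is where the indexing savings come from: encoding cost is paid per query for a handful of passages instead of once for the whole corpus.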

Inspecting Retrieval Insights from Representations

# Inspect token-level importance/weights of the query embeddings
token_weights = ir.encoder_q.dst(query, topk=768, visual=False) 
print(token_weights)

# Output: 
# {
#     'relativity': 7.262620449066162, 
#     'proposed': 3.588329792022705, 
#     'first': 2.918099880218506, 
#     ...
# }

# Inspect token-level contributions to the relevance score (i.e., retrieval results)
token_contributions = ir.explain(q=query, p=passages[0], topk=768, visual=False)
print(token_contributions)

# Output: 
# {
#     'relativity': 54.66442432370013, 
#     'whom': 13.934619790257784, 
#     'theory': 4.645142051911478, 
#     ...
# }
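
Because relevance is a dot product over vocabulary tokens, it decomposes into one term per shared token, `q[t] * p[t]`. Assuming `explain` reports these per-token terms (the README shows the values but not the formula), a minimal sketch with hypothetical toy weights follows; `top_weights` and `token_contributions` are illustrative helpers, not the library's API.

```python
from typing import Dict


def top_weights(vec: Dict[str, float], topk: int) -> Dict[str, float]:
    """Tokens sorted by weight, largest first, truncated to topk."""
    return dict(sorted(vec.items(), key=lambda kv: kv[1], reverse=True)[:topk])


def token_contributions(q: Dict[str, float], p: Dict[str, float], topk: int) -> Dict[str, float]:
    """Per-token terms q[t] * p[t] of the dot product, largest first."""
    terms = {t: w * p[t] for t, w in q.items() if t in p}
    return top_weights(terms, topk)


q = {"relativity": 7.0, "proposed": 3.5, "first": 3.0, "theory": 1.0}
p = {"relativity": 5.0, "theory": 2.0, "einstein": 6.0}

print(top_weights(q, 2))              # {'relativity': 7.0, 'proposed': 3.5}
contrib = token_contributions(q, p, topk=768)
print(contrib)                        # {'relativity': 35.0, 'theory': 2.0}
print(sum(contrib.values()))          # 37.0, the full relevance score
```

When `topk` is at least the number of shared tokens, the contributions sum to the full score; a smaller `topk` shows only the dominant terms.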

Semi-parametric Retrieval

Alpha search

# non-parametric query -> parametric passage
q_bin = ir.encoder_q.embed(query, bow=True)
p_emb = ir.encoder_p.embed(passages)
scores = q_bin @ p_emb.t()

Beta search

# parametric query -> non-parametric passage (binary token index)
q_emb = ir.encoder_q.embed(query)
p_bin = ir.encoder_p.embed(passages, bow=True)
scores = q_emb @ p_bin.t()
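
The two modes differ only in which side is replaced by a binary token vector (weight 1.0 for each token present). Below is a minimal sketch with hypothetical toy vectors, assuming `bow=True` yields such a binary vector as the Beta search comment indicates; `dot` and `binary` are illustrative helpers, not part of `src.ir`.

```python
from typing import Dict, Set


def dot(a: Dict[str, float], b: Dict[str, float]) -> float:
    """Sparse inner product over shared tokens."""
    return sum((w * b[t] for t, w in a.items() if t in b), 0.0)


def binary(tokens: Set[str]) -> Dict[str, float]:
    """Non-parametric bag-of-token vector: 1.0 for every token present."""
    return {t: 1.0 for t in tokens}


# Hypothetical learned (parametric) vectors and raw token sets for one pair.
q_emb = {"relativity": 7.0, "proposed": 3.5, "first": 3.0}
p_emb = {"relativity": 5.0, "theory": 3.0, "einstein": 6.0}
q_tokens = {"who", "first", "proposed", "the", "theory", "of", "relativity"}
p_tokens = {"einstein", "developed", "the", "theory", "of", "relativity"}

alpha = dot(binary(q_tokens), p_emb)  # non-parametric query -> parametric passage
beta = dot(q_emb, binary(p_tokens))   # parametric query -> non-parametric passage
full = dot(q_emb, p_emb)              # both sides parametric
print(alpha, beta, full)              # 8.0 7.0 35.0
```

In Alpha search the passage weights decide which query tokens matter; in Beta search the query weights do, and the passage side needs no encoder at all.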

Cross-modal Retrieval

# Note: we use `encoder_q` for text and `encoder_p` for image
vdr_cross_modal = Retriever.from_pretrained("vsearch/vdr-cross-modal") 

image_file = './examples/images/mars.png'
texts = [
  "Four thousand Martian days after setting its wheels in Gale Crater on Aug. 5, 2012, NASA’s Curiosity rover remains busy conducting exciting science. The rover recently drilled its 39th sample then dropped the pulverized rock into its belly for detailed analysis.",
  "ChatGPT is a chatbot developed by OpenAI and launched on November 30, 2022. Based on a large language model, it enables users to refine and steer a conversation towards a desired length, format, style, level of detail, and language."
]
image_emb = vdr_cross_modal.encoder_p.embed(image_file) # Shape: [1, V]
text_emb = vdr_cross_modal.encoder_q.embed(texts)  # Shape: [2, V]

# Image-text Relevance
scores = image_emb @ text_emb.t()
print(scores)

# Output: 
# tensor([[0.3209, 0.0984]])

👾 Training

We test with python==3.9 and torch==2.2.1; configuration is handled by hydra==1.3.2.

EXPERIMENT_NAME=test
python -m torch.distributed.launch --nnodes=1 --nproc_per_node=4 train_vdr.py \
hydra.run.dir=./experiments/${EXPERIMENT_NAME}/train \
train=vdr_nq \
data_stores=wiki21m \
train_datasets=[nq_train]

  • hydra.run.dir: Directory where training logs and outputs are saved.
  • train: Identifier of the training config, defined in conf/train/*.yaml.
  • data_stores: Identifier of the datastore, defined in conf/data_stores/*.yaml.
  • train_datasets: List of identifiers of the training datasets to use, defined in the selected datastore.

During training, we display an InfoCard to monitor training progress.

Tip

What is an InfoCard?

An InfoCard is an organized log generated during training that helps us visually track its progress.

An InfoCard looks like this:

[Figure: example InfoCard]

InfoCard Layout

  1. Global Variables (V(q), V(p), etc.):

    • Shape: Displays the dimensions of the variable matrix.
    • Gate: Indicates the sparsity by showing the ratio of non-zero activations.
    • Mean, Max, Min: Statistical measures of the data distribution within the variable.
  2. EXAMPLE Section:

    • Contains one sample from the training batch, including the query text (Q_TEXT), a positive passage (P_TEXT1), a negative passage (P_TEXT2), and the correct answer (ANSWER).
  3. Token Triple Sections (V(q), V(p), V(p_neg), V(q) * V(p)), which show token-level impact:

    • Token (t): The specific vocabulary token.
    • Query Rank (qrank): Rank of the token in the query representation.
    • Passage Rank (prank): Rank of the token in the passage representation.
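
To make the InfoCard columns concrete, the sketch below computes them for toy vectors, based on our reading of the descriptions above: Gate as the fraction of non-zero entries, and qrank/prank as 1-based ranks by weight. The helper names are hypothetical, and the actual InfoCard code may compute these statistics differently, for example aggregated over a batch.

```python
from typing import Dict, List, Tuple


def infocard_stats(vec: List[float]) -> dict:
    """Shape, gate (fraction of non-zero entries), and mean/max/min of one vector."""
    nonzero = sum(1 for x in vec if x != 0.0)
    return {
        "shape": len(vec),
        "gate": nonzero / len(vec),
        "mean": sum(vec) / len(vec),
        "max": max(vec),
        "min": min(vec),
    }


def token_ranks(vec: Dict[str, float]) -> Dict[str, int]:
    """Rank of each active token by weight, 1 = largest (cf. qrank / prank)."""
    ordered = sorted(vec, key=lambda t: vec[t], reverse=True)
    return {t: r for r, t in enumerate(ordered, start=1)}


def token_triples(q: Dict[str, float], p: Dict[str, float]) -> List[Tuple[str, int, int]]:
    """(token, qrank, prank) for shared tokens, ordered by their q*p contribution."""
    qrank, prank = token_ranks(q), token_ranks(p)
    shared = sorted((t for t in q if t in p), key=lambda t: q[t] * p[t], reverse=True)
    return [(t, qrank[t], prank[t]) for t in shared]


stats = infocard_stats([0.0, 2.0, 0.0, 6.0])
print(stats)  # {'shape': 4, 'gate': 0.5, 'mean': 2.0, 'max': 6.0, 'min': 0.0}

q = {"relativity": 7.0, "proposed": 3.5, "theory": 1.0}
p = {"einstein": 6.0, "relativity": 5.0, "theory": 3.0}
print(token_triples(q, p))  # [('relativity', 1, 2), ('theory', 3, 3)]
```

A low Gate means a sparse representation; a token that ranks high in both qrank and prank is one that drives the match.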

πŸ‰ Citation

If you find this repository useful, please consider giving ⭐ and citing our paper:

@inproceedings{zhou2023retrieval,
  title={Retrieval-based Disentangled Representation Learning with Natural Language Supervision},
  author={Zhou, Jiawei and Li, Xiaoguang and Shang, Lifeng and Jiang, Xin and Liu, Qun and Chen, Lei},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2023}
}
@article{zhou2024semi,
  title={Semi-Parametric Retrieval via Binary Token Index},
  author={Zhou, Jiawei and Dong, Li and Wei, Furu and Chen, Lei},
  journal={arXiv preprint arXiv:2405.01924},
  year={2024}
}