Demo of ColBERT full retrieval in Astra DB and RAGStack
Inspired by @jbellis's colbert-astra project.
Uses Astra DB's Data API and a LangChain-style interface for data loading and semantic search.
Initial Setup
- Download the pretrained BERT checkpoint and untar it into this folder.
- Add any PDFs you want to load to the files folder.
- pip install -r requirements.txt
- python3 load.py
- Create a .env file with the following settings:
OPENAI_API_KEY=
ASTRA_DB_API_ENDPOINT=
ASTRA_DB_APPLICATION_TOKEN=
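Before running the loader, it can help to confirm that all three settings are actually visible to Python. The following is a minimal sketch, not part of this project: it assumes the values from .env have already been exported into the process environment by whatever loader you use, and the helper name missing_env_vars is hypothetical.

```python
import os

# Names taken from the .env template above.
REQUIRED_VARS = (
    "OPENAI_API_KEY",
    "ASTRA_DB_API_ENDPOINT",
    "ASTRA_DB_APPLICATION_TOKEN",
)


def missing_env_vars(env=None, required=REQUIRED_VARS):
    """Return the required variable names that are unset or empty.

    `env` defaults to os.environ; pass a plain dict to check other mappings.
    """
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]
```

Calling missing_env_vars() with no arguments checks the current environment; an empty list means every setting is present.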
Query Interface
streamlit run app.py
For ColBERT-style retrieval, data is stored in two Astra DB collections:
interactions - stores the content of individual chunks
{
  "content": "text data",
  "part": "chunk_id",
  "metadata": {
    "source": "filename",
    "page": 1
  }
}
interactions_bert - stores the contextualized BERT embeddings for the individual tokens in a chunk
{
  "_id": "..",
  "part": "15",
  "token": "15",
  "$vector": [...],
  "metadata": {
    "source": "filename",
    "page": 1
  }
}
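To make the relationship between the two collections concrete, here is a rough sketch of how one chunk could be turned into documents of both shapes. It is an illustration under assumptions, not the loader's actual code: the function names are hypothetical, the per-token vectors are assumed to be computed elsewhere, "token" is assumed to be the token's position within the chunk, and "_id" is left out on the assumption that the database assigns it.

```python
def chunk_document(part, content, source, page):
    """Build one document for the `interactions` collection (one per chunk)."""
    return {
        "content": content,
        "part": part,
        "metadata": {"source": source, "page": page},
    }


def token_documents(part, token_vectors, source, page):
    """Build documents for `interactions_bert` (one per token embedding).

    Assumption: "token" holds the token's position in the chunk, as a string.
    """
    return [
        {
            "part": part,
            "token": str(position),
            "$vector": list(vector),
            "metadata": {"source": source, "page": page},
        }
        for position, vector in enumerate(token_vectors)
    ]
```

Both kinds of document share the same part value, which is what lets a token-level match be traced back to its chunk.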
Query time:
- for an input query q, get the query vectors qv
- for every query vector v in qv, find its approximate nearest neighbors (ANN) in interactions_bert
- compute score for each part: the sum, over all v in qv, of the maximum similarity between v and the retrieved $vector values for that part
def maxsim(qv, document_embeddings):
    return max(qv @ dv for dv in document_embeddings)

def score(query_embeddings, document_embeddings):
    return sum(maxsim(qv, document_embeddings) for qv in query_embeddings)
- sort the parts by score, then retrieve the content (the actual chunk data) from the interactions collection based on part, ordered by score. These are the most relevant chunks for the given query.
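The query-time steps above can be sketched end to end in plain Python. This is only an illustration under stated assumptions: the ANN hits are taken as already fetched and passed in as (part, vector) pairs, similarity is a plain dot product, and the names dot and rank_parts are hypothetical rather than part of this project.

```python
from collections import defaultdict


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def rank_parts(query_vectors, hits):
    """Rank parts by summed maximum similarity.

    hits: iterable of (part, vector) pairs gathered from the ANN searches.
    Returns a list of (part, score) pairs, highest score first.
    """
    # Group the retrieved token vectors by the chunk they belong to.
    vectors_by_part = defaultdict(list)
    for part, vector in hits:
        vectors_by_part[part].append(vector)

    # For each part: sum over query vectors of the best-matching token vector.
    scores = {
        part: sum(max(dot(qv, dv) for dv in vectors) for qv in query_vectors)
        for part, vectors in vectors_by_part.items()
    }
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)
```

The part values at the top of the returned list are the ones to look up in the interactions collection to get the chunk content.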
The LangChain-style interface lets developers get started without worrying about the internal implementation details of ColBERT.
from colbert_vectorstore import Astra_ColBERT_VectorStore
import os
colbert_vstore = Astra_ColBERT_VectorStore(
collection_name="interactions",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
)
texts = ["hello world", "cat is sitting on a wall", "dog is running"]
colbert_vstore.add_texts(texts)
results = colbert_vstore.similarity_search('is there a cat?')
Apache License 2.0