MedCPT is a first-of-its-kind Contrastive Pre-trained Transformer model, trained on PubMed search logs at an unprecedented scale, for zero-shot biomedical information retrieval. MedCPT consists of:
- A first-stage dense retriever (MedCPT retriever)
  - Contains a query encoder (QEnc) and an article encoder (DEnc), both initialized from PubMedBERT.
  - Trained on 255M query-article pairs from PubMed search logs, using in-batch negatives.
- A second-stage re-ranker (MedCPT re-ranker)
  - A transformer cross-encoder (CrossEnc) initialized from PubMedBERT.
  - Trained on 18M semantic query-article pairs, using localized negatives from the pre-trained MedCPT retriever.
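To illustrate what "in-batch negatives" means for the retriever, here is a minimal sketch of a generic in-batch contrastive (InfoNCE-style) loss. It assumes inner-product similarity and plain Python lists standing in for encoder outputs. It is not MedCPT's training code, and details from the paper such as temperature scaling are not modeled. Query `i` is paired with article `i`, and every other article in the same batch acts as a negative.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def in_batch_contrastive_loss(query_embs, article_embs):
    """Mean cross-entropy where query i's positive is article i and all other
    articles in the batch are negatives (generic sketch, not MedCPT's exact loss)."""
    assert len(query_embs) == len(article_embs), "batch sizes must match"
    total = 0.0
    for i, q in enumerate(query_embs):
        # Score query i against every article in the batch.
        scores = [dot(q, a) for a in article_embs]
        # Numerically stable log-sum-exp over all in-batch candidates.
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        # Negative log-probability assigned to the paired (positive) article.
        total += log_z - scores[i]
    return total / len(query_embs)


queries = [[1.0, 0.0], [0.0, 1.0]]
articles = [[1.0, 0.0], [0.0, 1.0]]
print(in_batch_contrastive_loss(queries, articles))
```

The appeal of in-batch negatives is that the other positives in a batch come for free as negatives, so no separate negative sampling pass is needed.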
This directory contains:
- Code for training the MedCPT retriever.
- Code for training the MedCPT re-ranker.
- Code for evaluating the pre-trained model.
MedCPT model weights are publicly available on Hugging Face. For example, queries can be encoded with the MedCPT query encoder:
```python
import torch
from transformers import AutoTokenizer, AutoModel

model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")

queries = [
    "diabetes treatment",
    "How to treat diabetes?",
    "A 45-year-old man presents with increased thirst and frequent urination over the past 3 months.",
]

with torch.no_grad():
    # tokenize the queries
    encoded = tokenizer(
        queries,
        truncation=True,
        padding=True,
        return_tensors='pt',
        max_length=64,
    )

    # encode the queries (use the [CLS] last hidden states as the representations)
    embeds = model(**encoded).last_hidden_state[:, 0, :]

    print(embeds)
    print(embeds.size())
```
The output will be:
```
tensor([[ 0.0413,  0.0084, -0.0491,  ..., -0.4963, -0.3830, -0.3593],
        [ 0.0801,  0.1193, -0.0905,  ..., -0.5380, -0.5059, -0.2944],
        [-0.3412,  0.1521, -0.0946,  ...,  0.0952,  0.1660, -0.0902]])
torch.Size([3, 768])
```
These embeddings are also in the same space as those generated by the MedCPT article encoder.
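Because query and article embeddings share one space, a query can be matched against articles by comparing vectors directly. The sketch below assumes inner-product (dot-product) similarity and uses plain lists (for example, what `embeds.tolist()` would return). `rank_articles` is a hypothetical helper, not part of this repo.

```python
def rank_articles(query_emb, article_embs, top_k=None):
    """Return (article_index, score) pairs sorted by descending inner product.

    Hypothetical helper: assumes relevance is the dot product of a query
    embedding and an article embedding from the same vector space.
    """
    scored = [
        (idx, sum(q * a for q, a in zip(query_emb, art)))
        for idx, art in enumerate(article_embs)
    ]
    # Highest similarity first.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored if top_k is None else scored[:top_k]


print(rank_articles([1.0, 0.0], [[0.2, 0.9], [0.8, 0.1], [0.5, 0.5]], top_k=2))
```

With PyTorch tensors, the equivalent score matrix for a batch of queries against a batch of articles is `query_embeds @ article_embeds.T`. For a large corpus, a vector index would typically replace the exhaustive loop.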
```python
import torch
from transformers import AutoTokenizer, AutoModel

model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder")
tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")

# each article contains a list of two texts (usually a title and an abstract)
articles = [
    [
        "Diagnosis and Management of Central Diabetes Insipidus in Adults",
        "Central diabetes insipidus (CDI) is a clinical syndrome which results from loss or impaired function of vasopressinergic neurons in the hypothalamus/posterior pituitary, resulting in impaired synthesis and/or secretion of arginine vasopressin (AVP). [...]",
    ],
    [
        "Adipsic diabetes insipidus",
        "Adipsic diabetes insipidus (ADI) is a rare but devastating disorder of water balance with significant associated morbidity and mortality. Most patients develop the disease as a result of hypothalamic destruction from a variety of underlying etiologies. [...]",
    ],
    [
        "Nephrogenic diabetes insipidus: a comprehensive overview",
        "Nephrogenic diabetes insipidus (NDI) is characterized by the inability to concentrate urine that results in polyuria and polydipsia, despite having normal or elevated plasma concentrations of arginine vasopressin (AVP). [...]",
    ],
]

with torch.no_grad():
    # tokenize the articles (each title/abstract pair is encoded as one sequence)
    encoded = tokenizer(
        articles,
        truncation=True,
        padding=True,
        return_tensors='pt',
        max_length=512,
    )

    # encode the articles (use the [CLS] last hidden states as the representations)
    embeds = model(**encoded).last_hidden_state[:, 0, :]

    print(embeds)
    print(embeds.size())
```
The output will be:
```
tensor([[-0.0189,  0.0115,  0.0988,  ..., -0.0655,  0.3155, -0.0357],
        [-0.3402, -0.3064, -0.0749,  ..., -0.0799,  0.3332,  0.1263],
        [-0.2764, -0.0506, -0.0608,  ...,  0.0389,  0.2532,  0.1580]])
torch.Size([3, 768])
```
These embeddings are also in the same space as those generated by the MedCPT query encoder.
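The two stages described at the top of this README can be chained: the retriever narrows the corpus to a few candidates using cheap embedding comparisons, and the re-ranker re-scores only those candidates. The following is a minimal sketch under stated assumptions. Stage one uses inner-product similarity. `cross_score` is a hypothetical callable standing in for the MedCPT cross-encoder, whose wiring is not shown. `word_overlap` is a toy scorer used only so the example runs without downloading a model.

```python
from typing import Callable, List, Sequence, Tuple


def retrieve_then_rerank(
    query_text: str,
    query_emb: Sequence[float],
    articles: Sequence[str],
    article_embs: Sequence[Sequence[float]],
    cross_score: Callable[[str, str], float],
    k: int = 10,
) -> List[Tuple[int, float]]:
    """Two-stage ranking sketch: dense retrieval, then re-ranking of the top k."""
    # Stage 1: score all articles by inner product with the query embedding
    # and keep the k highest-scoring candidates.
    dense = [
        (idx, sum(q * a for q, a in zip(query_emb, emb)))
        for idx, emb in enumerate(article_embs)
    ]
    candidates = sorted(dense, key=lambda p: p[1], reverse=True)[:k]
    # Stage 2: re-score only the candidates with the (more expensive)
    # query-article scorer and sort by the new scores.
    reranked = [(idx, cross_score(query_text, articles[idx])) for idx, _ in candidates]
    return sorted(reranked, key=lambda p: p[1], reverse=True)


def word_overlap(query: str, article: str) -> float:
    """Hypothetical stand-in for a cross-encoder: counts shared lowercase words."""
    return len(set(query.lower().split()) & set(article.lower().split()))


docs = ["central diabetes insipidus", "adipsic diabetes insipidus", "nephrogenic diabetes insipidus"]
doc_embs = [[0.9, 0.1], [0.1, 0.9], [0.8, 0.3]]
print(retrieve_then_rerank("nephrogenic diabetes insipidus treatment", [1.0, 0.0],
                           docs, doc_embs, word_overlap, k=2))
```

The design keeps the expensive joint query-article scoring off the full corpus: only `k` candidates ever reach stage two.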
Due to privacy concerns, we are not able to release the PubMed search logs. As a surrogate, this repo provides question-article pairs from BioASQ as example training data. To train MedCPT on your own data, convert it to the same format as the example data.
This work was supported by the Intramural Research Programs of the National Institutes of Health, National Library of Medicine.
This tool shows the results of research conducted in the Computational Biology Branch, NCBI/NLM. The information produced on this website is not intended for direct diagnostic use or medical decision-making without review and oversight by a clinical professional. Individuals should not change their health behavior solely on the basis of information produced on this website. NIH does not independently verify the validity or utility of the information produced by this tool. If you have questions about the information produced on this website, please see a health care professional. More information about NCBI's disclaimer policy is available.
If you find this repo helpful, please cite MedCPT by:
```bibtex
@misc{jin2023MedCPT,
      title={MedCPT: Contrastive Pre-trained Transformers with Large-scale PubMed Search Logs for Zero-shot Biomedical Information Retrieval},
      author={Qiao Jin and Won Kim and Qingyu Chen and Donald C. Comeau and Lana Yeganova and John Wilbur and Zhiyong Lu},
      year={2023},
      eprint={2307.00589},
      archivePrefix={arXiv},
      primaryClass={cs.IR}
}
```