
CookBERT: A Domain Adapted BERT Model for the Cooking Domain

See the official CookBERT paper.

What is CookBERT?

CookBERT is a domain-specific BERT model created by domain-adaptive pretraining on the instructions of the RecipeNLG corpus and by extending BERT's default vocabulary with a total of 1229 cooking-specific words (a sketch of such a vocabulary extension follows the table below). As a result, CookBERT is geared more towards the cooking domain than the default model:

| Input | Model | Top 5 predictions for [MASK] token |
|---|---|---|
| "Do I have to [MASK] the apple?" | BERTbase | eat, take, have, touch, get |
| | CookBERT | peel, slice, use, dice, chop |
| "[MASK] the water." | BERTbase | in, drink, under, into, on |
| | CookBERT | boil, heat, add, scald, chill |
| "Cut the [MASK] into small pieces." | BERTbase | wood, paper, leaves, meat, bark |
| | CookBERT | chicken, cheese, fruit, cabbage, sausage |
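
The vocabulary extension mentioned above is not spelled out in this README; the snippet below is only a minimal sketch of how BERT's vocabulary can be extended with the Transformers library. The word list is a small placeholder, not the actual 1229 terms used for CookBERT.

```python
from transformers import BertTokenizerFast, BertForMaskedLM

# Start from the default uncased BERT model
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Placeholder list of cooking-specific words (CookBERT added 1229 of them)
cooking_terms = ["julienne", "deglaze", "parboil", "chiffonade"]
num_added = tokenizer.add_tokens(cooking_terms)
print(f"Added {num_added} tokens to the vocabulary")

# Resize the embedding matrix so the new tokens get (randomly initialised) embeddings
model.resize_token_embeddings(len(tokenizer))
```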

CookBERT's domain specificity has proven to make it superior to BERTbase and FoodBERT in text classification and named entity recognition when dealing with data related to the cooking domain.

Training specs

To obtain CookBERT, BERTbase (uncased version) was used as the starting point and further pretrained for three additional epochs on the MLM task on the RecipeNLG instructions, with 5% of the data serving as validation data. Training was performed with a learning rate of 2e-5, an effective batch size of 32, and a maximum sequence length of 256. The training took approximately five full days on a single NVIDIA Tesla P100 GPU provided by Google Colab Pro.
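
A training setup along these lines can be reproduced with the Transformers Trainer. The sketch below only illustrates the stated hyperparameters; the toy instruction list, the output directory name, and the 8 x 4 split of the effective batch size are assumptions, not the original training script.

```python
from datasets import Dataset
from transformers import (
    BertTokenizerFast,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Placeholder for the RecipeNLG instructions (95%/5% train/validation split in the original setup)
texts = ["Preheat the oven to 180 degrees.", "Chop the onions and fry them until golden."]
train_dataset = Dataset.from_dict(dict(tokenizer(texts, truncation=True, max_length=256)))

# Standard MLM collator: randomly masks 15% of the input tokens
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(
    output_dir="CookBERT-checkpoint",
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,  # one way to reach an effective batch size of 32
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
trainer.train()
```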

Performance

CookBERT was fine-tuned and evaluated on three different tasks: information need classification, food entity tagging, and question answering. In addition, BERTbase (uncased version) and FoodBERT were applied to the same tasks so that CookBERT's performance could be compared and ranked.

Text classification

Results of the classification of user information needs that arise during cooking, based on the Cookversational dataset.

| Model | Condition | Precision | Recall | F-Measure | 95%-CI |
|---|---|---|---|---|---|
| BERTbase | no context | 47.94% | 48.68% | 46.15% | [41.15%; 51.16%] |
| BERTbase | 1 prev turn | 46.29% | 49.84% | 45.38% | [40.06%; 50.70%] |
| CookBERT | no context | 48.58% | 55.65% | 50.72% | [45.54%; 55.90%] |
| CookBERT | 1 prev turn | **52.26%** | **59.30%** | **54.05%** | [48.93%; 59.16%] |
| FoodBERT | no context | 42.41% | 49.81% | 44.32% | [38.92%; 49.73%] |
| FoodBERT | 1 prev turn | 36.89% | 44.49% | 38.09% | [32.64%; 43.55%] |

Best performances printed in bold
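
For the classification task, the MLM checkpoint has to be equipped with a sequence classification head before fine-tuning. Below is a minimal sketch of this step; the number of labels and the example sentence are placeholders, not the Cookversational setup.

```python
from transformers import BertTokenizerFast, BertForSequenceClassification

tokenizer = BertTokenizerFast.from_pretrained("CookBERT-checkpoint")
# num_labels is a placeholder; set it to the number of information need classes
model = BertForSequenceClassification.from_pretrained("CookBERT-checkpoint", num_labels=5)

inputs = tokenizer("How long should I knead the dough?", return_tensors="pt")
logits = model(**inputs).logits  # classification head is untrained until fine-tuning
```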

Named entity recognition

Results of the food entity tagging task using the curated version of the FoodBase corpus, as well as the labels provided by Stojanov et al. (2021) for five different tagging schemes.

| Model | Tagging task | Precision | Recall | F-Measure | 95%-CI |
|---|---|---|---|---|---|
| BERTbase | Food-classification | 90.68% | 96.06% | 93.29% | [92.87%; 93.71%] |
| BERTbase | FoodOn | 65.24% | 73.10% | 68.94% | [67.04%; 70.83%] |
| BERTbase | Hansard-parent | 80.35% | 88.68% | 84.31% | [83.54%; 85.08%] |
| BERTbase | Hansard-closest | 70.79% | 79.98% | 75.10% | [73.87%; 76.34%] |
| BERTbase | SNOMED CT | 63.04% | 70.65% | 66.62% | [64.49%; 68.75%] |
| CookBERT | Food-classification | **92.25%** | **96.52%** | **94.47%** | [94.17%; 94.76%] |
| CookBERT | FoodOn | **69.75%** | **77.51%** | **73.42%** | [71.91%; 74.93%] |
| CookBERT | Hansard-parent | **82.72%** | **89.18%** | **85.83%** | [84.69%; 86.97%] |
| CookBERT | Hansard-closest | **72.21%** | **80.41%** | **76.08%** | [74.60%; 77.56%] |
| CookBERT | SNOMED CT | **68.58%** | **75.51%** | **71.87%** | [69.99%; 73.75%] |
| FoodBERT | Food-classification | 85.28% | 94.24% | 89.53% | [88.90%; 90.17%] |
| FoodBERT | FoodOn | 58.73% | 61.03% | 59.85% | [56.56%; 63.13%] |
| FoodBERT | Hansard-parent | 68.41% | 80.62% | 74.01% | [72.13%; 75.90%] |
| FoodBERT | Hansard-closest | 59.55% | 67.52% | 63.28% | [60.43%; 66.13%] |
| FoodBERT | SNOMED CT | 53.63% | 51.84% | 52.67% | [49.17%; 56.17%] |

Best performances printed in bold
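
Food entity tagging corresponds to token classification in the Transformers library. The following is a minimal sketch of loading the checkpoint with a token classification head; the label count and example sentence are placeholders that depend on the chosen tagging scheme.

```python
from transformers import BertTokenizerFast, BertForTokenClassification

tokenizer = BertTokenizerFast.from_pretrained("CookBERT-checkpoint")
# num_labels is a placeholder; it depends on the tagging scheme (e.g. Hansard-parent vs. SNOMED CT)
model = BertForTokenClassification.from_pretrained("CookBERT-checkpoint", num_labels=9)

inputs = tokenizer("Cut the cabbage into small pieces.", return_tensors="pt")
logits = model(**inputs).logits      # shape: (1, sequence_length, num_labels)
predictions = logits.argmax(dim=-1)  # per-token label ids; meaningful only after fine-tuning
```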

Question answering

Results of the question answering task, framed as answer span extraction, based on the cooking subset of the DoQA dataset.

| Model | Exact match | F-measure | 95%-CI |
|---|---|---|---|
| BERTbase | **14.06%** | **32.39%** | [31.25%; 33.54%] |
| CookBERT | 12.51% | 30.64% | [29.50%; 31.78%] |
| FoodBERT | 10.81% | 27.51% | [26.51%; 28.50%] |

Best performances printed in bold
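
Answer span extraction places a question answering head on top of the checkpoint. The sketch below shows the general pattern; the question and context are made-up examples, and the predicted span is only meaningful after fine-tuning on DoQA.

```python
import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering

tokenizer = BertTokenizerFast.from_pretrained("CookBERT-checkpoint")
model = BertForQuestionAnswering.from_pretrained("CookBERT-checkpoint")  # QA head is freshly initialised

question = "How long do I boil the eggs?"
context = "Place the eggs in boiling water and cook them for about ten minutes."
inputs = tokenizer(question, context, return_tensors="pt")

outputs = model(**inputs)
start = torch.argmax(outputs.start_logits)  # most likely start position of the answer span
end = torch.argmax(outputs.end_logits) + 1  # most likely end position (exclusive)
print(tokenizer.decode(inputs["input_ids"][0][start:end]))
```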

Using CookBERT

The CookBERT PyTorch model checkpoint can be downloaded from this Google Drive folder. The Hugging Face Transformers library makes it easy to set up the model:

```python
from transformers import (
    BertTokenizerFast,
    BertForMaskedLM,
    pipeline
)

CookBERT_tokenizer = BertTokenizerFast.from_pretrained("CookBERT-checkpoint")
CookBERT = BertForMaskedLM.from_pretrained("CookBERT-checkpoint")
CookBERT_pipeline = pipeline("fill-mask", model=CookBERT, tokenizer=CookBERT_tokenizer)

masked_text = "Cut the [MASK] into small pieces."
print("Predictions: ", CookBERT_pipeline(masked_text, top_k=5))
```

Note that the Google Drive folder only contains the CookBERT checkpoint trained on the MLM task. To apply CookBERT to other tasks (NER, QA, ...), the fine-tuning scripts from Hugging Face can be used.