A minimal DIET classifier implementation in PyTorch.

DIETClassifier - Pytorch

DIETClassifier stands for Dual Intent and Entity Transformer, a model that performs intent classification and entity recognition at the same time.

  • Built on the BERT architecture from Hugging Face Transformers
  • Wrapped in Python, with helper functions that read the dataset from .yml files, build and train the model, and return predictions as a dictionary

Requirements

  • [transformers] - Library for using transformer models in NLP tasks
  • [pytorch] - Framework for deep learning in Python
  • [fastapi] - Framework for building the backend API

You can install all required packages with:

git clone https://github.com/WeiNyn/DIETClassifier-pytorch.git
cd DIETClassifier-pytorch/
pip install -r requirements.txt

Demo

The demo server receives a text message and predicts its intent and entities. To start it:

  • Download the pretrained model from this link
  • Extract "latest_model" to "DIETClassifier-pytorch/"
  • Run:
uvicorn demo.server:app

Configuration

All project configuration is stored in the [config.yml] file:

model:
    model: latest_model
    tokenizer: latest_model
    dataset_folder: dataset
    exclude_file: null
    entities:
        - working_type
        - shift_type
    intents:
        - WorkTimesBreaches
        - WorkingTimeBreachDiscipline
        - HolidaysOff
        - AnnualLeaveApplicationProcess
        - SetWorkingType
        - TemporarySetWorkingType
        - WorkingHours
        - WorkingDay
        - BreakTime
        - Pregnant
        - AttendanceRecord
        - SelectShiftType
        - LaborContract
        - Recruitment
        - SickLeave
        - UnpaidLeave
        - PaidLeaveForFamilyEvent
        - UnusedAnnualLeave
        - RegulatedAnnualLeave
        - rating
    device: cuda
training:
    train_range: 0.95
    num_train_epochs: 100
    per_device_train_batch_size: 4
    per_device_eval_batch_size: 4
    warmup_steps: 500
    weight_decay: 0.01
    logging_dir: logs/
    early_stopping_patience: 10
    early_stopping_threshold: 0.0001
    output_dir: results/
util:
    intent_threshold: 0.7
    entities_threshold: 0.5
    ambiguous_threshold: 0.2

| Attribute | Description |
|---|---|
| model | Name of a transformers pretrained model, or path to a local model |
| tokenizer | Name of a transformers pretrained tokenizer, or path to a local tokenizer |
| dataset_folder | Folder that contains the dataset files, in Rasa NLU format |
| exclude_file | Files in the dataset folder that will not be used for training |
| entities | List of entities |
| intents | List of intents |
| synonym | Synonym list for synonym entities |
| device | Device to use ("cpu", "cuda", "cuda:0", etc.) |
| train_range | Ratio used to split the dataset into training and validation sets |
| num_train_epochs | Number of training epochs |
| per_device_train_batch_size / per_device_eval_batch_size | Batch size for training / evaluation |
| logging_dir | Directory for log files (TensorBoard supported) |
| early_stopping_patience / early_stopping_threshold | Hyperparameters for early stopping |
| output_dir | Directory where the model is saved during training |
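
To make the train_range setting concrete, here is a minimal sketch of a ratio-based split. It assumes train_range is the fraction of examples that goes to the training set (0.95 keeps 95% for training and 5% for validation). The function name split_dataset and the seed parameter are hypothetical and are not taken from this repository, whose own splitting code may differ.

```python
import random


def split_dataset(examples, train_range=0.95, seed=0):
    """Shuffle examples and split them into (train, valid) lists.

    Hypothetical helper: assumes train_range is the training fraction.
    """
    if not 0.0 < train_range < 1.0:
        raise ValueError("train_range must be strictly between 0 and 1")
    shuffled = list(examples)
    # A seeded generator keeps the split reproducible between runs.
    random.Random(seed).shuffle(shuffled)
    cut = round(len(shuffled) * train_range)
    return shuffled[:cut], shuffled[cut:]


examples = [f"example {i}" for i in range(200)]
train, valid = split_dataset(examples, train_range=0.95)
print(len(train), len(valid))  # 190 10
```

With a small dataset, a ratio this high leaves very few validation examples, so early stopping decisions may be noisy.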

Usage

You can use DIETClassifierWrapper to load, train, and run predictions from Python code:

from src.models.wrapper import DIETClassifierWrapper

config_file = "src/config.yml"
wrapper = DIETClassifierWrapper(config=config_file)

# Predict the intent and entities for a list of messages
wrapper.predict(["How to check attendance?"])

# Train the model; after training, the wrapper loads the best model automatically
wrapper.train_model(save_folder="test_model")

You can also use DIETClassifier from src.models.classifier as a Hugging Face Transformers model:

from src.models.classifier import DIETClassifier, DIETClassifierConfig

config = DIETClassifierConfig(model="bert-base-uncased",
                              intents=[str(i) for i in range(10)],
                              entities=[str(i) for i in range(5)])

model = DIETClassifier(config=config)

Notice

  • This DIETClassifier uses BERT base as its underlying architecture. If you want to switch to RoBERTa, ALBERT, etc., you need to modify the DIETClassifier class.
  • You can also start from any BERT-base pretrained model from Hugging Face Transformers and fine-tune it yourself.
  • Please read the source code to understand how the dataset is created if you want to supply the dataset in another file format.
  • If you get the error AttributeError: 'NoneType' object has no attribute 'detach', please check issue #5.
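
If you plan to convert a dataset from another format, it helps to see what the Rasa NLU inline annotation encodes. The sketch below is not this repository's dataset reader. It is a hypothetical helper (parse_annotated) that handles only the simple [text](entity_name) form and turns one annotated example into plain text plus character spans. The example sentence is made up; the entity names come from the configuration above.

```python
import re

# Rasa-style inline annotation: [entity text](entity_name)
ANNOTATION = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")


def parse_annotated(example):
    """Return (plain_text, entities) for one annotated example.

    Hypothetical helper, not part of this repository.
    """
    parts = []     # pieces of the plain text, in order
    entities = []  # spans measured against the plain text
    cursor = 0     # position in the annotated string
    length = 0     # length of the plain text built so far
    for match in ANNOTATION.finditer(example):
        before = example[cursor:match.start()]
        parts.append(before)
        length += len(before)
        value, name = match.group(1), match.group(2)
        entities.append({"entity": name, "value": value,
                         "start": length, "end": length + len(value)})
        parts.append(value)
        length += len(value)
        cursor = match.end()
    parts.append(example[cursor:])
    return "".join(parts), entities


text, entities = parse_annotated(
    "I want to work [remotely](working_type) on the [night](shift_type) shift"
)
print(text)  # I want to work remotely on the night shift
for entity in entities:
    print(entity["entity"], text[entity["start"]:entity["end"]])
```

Any other source format can be mapped to the same (text, entity spans) shape first and then written out in whatever form the dataset reader in the source code expects.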