Classify entities into clusters with a zero-shot approach, using embedding vectors and a given list of category names.
- use an embedding model to encode entity names as vectors
- use the same model to encode the category names
- for each entity vector, find the category with the nearest vector
- the entities can then be classified and presented in logical groups
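The steps above can be sketched as follows. This is a minimal sketch using toy vectors in place of real SBERT embeddings; the function name `classify`, its signature, and the threshold handling are illustrative assumptions, not the actual `main.py`:

```python
import numpy as np

def classify(entity_vecs, category_vecs, category_names, threshold=0.5):
    """Assign each entity vector to the category with the nearest
    (smallest cosine-distance) category vector.
    Entities whose best distance exceeds `threshold` go to '(unknown)'."""
    def normalise(m):
        # Normalise rows so the dot product equals the cosine of the angle.
        return m / np.linalg.norm(m, axis=1, keepdims=True)

    e = normalise(np.asarray(entity_vecs, dtype=float))
    c = normalise(np.asarray(category_vecs, dtype=float))
    distances = 1.0 - e @ c.T  # cosine distance, one row per entity
    best = distances.argmin(axis=1)
    results = []
    for row, col in enumerate(best):
        if distances[row, col] > threshold:
            results.append("(unknown)")  # too far from every category
        else:
            results.append(category_names[col])
    return results

# Toy 2-D vectors standing in for real embeddings:
labels = classify(
    entity_vecs=[[0.9, 0.1], [0.1, 0.9], [0.5, 0.5]],
    category_vecs=[[1.0, 0.0], [0.0, 1.0]],
    category_names=["animal", "country"],
    threshold=0.2,
)
# → ["animal", "country", "(unknown)"]
```

In a real run the vectors would come from an SBERT model's `encode` call rather than being written by hand.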
Compare words (labels) by examining how close their encoded vectors are:
- the dot product of two normalised vectors equals the cosine of the angle between them
- cosine distance = 1 - v.w
- smaller means closer
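A small sketch of that distance measure (plain numpy, not tied to any embedding library):

```python
import numpy as np

def cosine_distance(v, w):
    """1 - cos(angle between v and w): 0 = same direction, 2 = opposite."""
    v = np.asarray(v, dtype=float)
    w = np.asarray(w, dtype=float)
    # Normalise first, so the dot product is exactly the cosine of the angle.
    v = v / np.linalg.norm(v)
    w = w / np.linalg.norm(w)
    return 1.0 - float(v @ w)

cosine_distance([1, 0], [1, 0])  # → 0.0 (identical direction)
cosine_distance([1, 0], [0, 1])  # → 1.0 (orthogonal)
```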
- Python 3.11
- pyenv (on Windows, use pyenv-win)
Switch to Python 3.11.6:
pyenv install 3.11.6
pyenv local 3.11.6
Set up a virtual environment:
./create_env.sh
Install SBERT and cornsnake via this pip command:
pip install -U sentence-transformers==2.2.2 cornsnake==0.0.26
python main.py <path to category list file> <path to entity names file> [threshold (number between 0 and 1)]
To test:
./test.sh
OUTPUT:
CATEGORY: (unknown)
entity ['Aardvark', 'Alpaca', 'Anaconda']
CATEGORY: animal
entity ['Albatross', 'Alligator', 'Ant', 'Zebu']
CATEGORY: country
entity ['Albania', 'Andorra', 'Angola', 'Austria', 'Bangladesh', 'Belgium']
The results are not perfect, but not bad considering this is a simple 'out of the box' solution.
Hierarchy of labels:
- first, classify against a top-level list of labels
- then, for each label, classify against that label's list of sub-labels
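That two-level scheme can be sketched like this. The hierarchy dict, `nearest_label`, and `classify_hierarchical` are illustrative names with toy vectors standing in for real embeddings:

```python
import numpy as np

def nearest_label(vec, labels, label_vecs):
    """Return the label whose (normalised) vector is closest by cosine distance."""
    v = np.asarray(vec, dtype=float)
    v = v / np.linalg.norm(v)
    m = np.asarray(label_vecs, dtype=float)
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    # Largest cosine similarity == smallest cosine distance.
    return labels[int(np.argmax(m @ v))]

def classify_hierarchical(vec, hierarchy, vectors):
    """hierarchy: top-level label -> list of sub-labels (may be empty);
    vectors: every label name -> its embedding vector."""
    top_labels = list(hierarchy)
    top = nearest_label(vec, top_labels, [vectors[t] for t in top_labels])
    subs = hierarchy[top]
    if not subs:
        return (top, None)
    # Second pass: only compare against the chosen label's sub-labels.
    sub = nearest_label(vec, subs, [vectors[s] for s in subs])
    return (top, sub)

hierarchy = {"animal": ["bird", "reptile"], "country": []}
vectors = {
    "animal": [1.0, 0.0, 0.0],
    "country": [0.0, 0.0, 1.0],
    "bird": [1.0, 0.2, 0.0],
    "reptile": [1.0, -0.2, 0.0],
}
classify_hierarchical([1.0, 0.1, 0.0], hierarchy, vectors)
# → ("animal", "bird")
```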
Increase accuracy:
- take several embeddings per class and use their average for that class
- try different embedding models, which can yield better results
- try different distance measures from your library
- consider tuning the embedding (for example, for the domain vocabulary of a particular industry or problem space)
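The first suggestion, averaging several embeddings per class, can be sketched like this (the function name and the re-normalisation step are assumptions; in practice the input rows would be SBERT encodings of several phrases describing the class, e.g. "animal", "wild animal", "farm animal"):

```python
import numpy as np

def class_centroid(example_vecs):
    """Average several example vectors for one class, then re-normalise,
    giving a single prototype vector to compare entities against."""
    m = np.asarray(example_vecs, dtype=float)
    m = m / np.linalg.norm(m, axis=1, keepdims=True)  # normalise each example
    centroid = m.mean(axis=0)
    return centroid / np.linalg.norm(centroid)  # unit-length prototype
```

The resulting centroid is used in place of the single category vector when computing cosine distances, which tends to smooth out quirks of any one phrase.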
Conference notes from ML Con Berlin 2023
SBERT: How to Use Sentence Embeddings to Solve Real-World Problems