Implementing a custom Sentence Transformer neural network for multi-task learning using PyTorch and Python.
An encoder-only architecture embeds sentences, and two independent classifier heads sit on top of the embeddings for multi-task predictions.
Goal 1:
Encode input sentences into fixed length embeddings.
Goal 2:
Expand architecture to support multi-task learning.
Task A:
Text Classification (6 classes): 'sports', 'health', 'tech', 'finance', 'education', 'other'
Task B:
Sentiment Analysis (3 classes): 'negative', 'neutral', 'positive'
- Use Conda to create a virtual environment and install the requirements.
- Use Bash to run conda_env_setup.sh (example below).
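For example:

> bash conda_env_setup.sh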
- A SentenceTransformer ('all-MiniLM-L6-v2') backbone computes contextualized sentence embeddings; it offers a good balance of performance and efficiency.
- Two instances of a simple multi-layer perceptron (MLP) are added, one for each classification head.
- Each classifier MLP is kept simple while still able to model nonlinear relationships between embeddings and classes (a minimal sketch follows this list):
- A fully connected layer projects the embedding into a smaller hidden representation (half the embedding size, 384 -> 192)
- GELU activation function
- The nonlinear GELU was chosen to match the activation used in the MiniLM backbone
- A final fully connected layer maps to the class logits (192 -> 6 for task A, 192 -> 3 for task B)
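A minimal sketch of how the backbone and the two heads could fit together (class and attribute names here are illustrative, not necessarily the repository's actual code):

```python
import torch
from torch import nn
from sentence_transformers import SentenceTransformer


class MultiTaskSentenceClassifier(nn.Module):
    """MiniLM backbone with two independent MLP classification heads."""

    def __init__(self, embedding_dim: int = 384, hidden_dim: int = 192):
        super().__init__()
        # Pretrained encoder that maps sentences to fixed-length embeddings.
        self.backbone = SentenceTransformer("all-MiniLM-L6-v2")
        # Task A head: text classification over 6 topic classes.
        self.topic_head = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),  # 384 -> 192
            nn.GELU(),
            nn.Linear(hidden_dim, 6),              # 192 -> 6
        )
        # Task B head: sentiment analysis over 3 classes.
        self.sentiment_head = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),  # 384 -> 192
            nn.GELU(),
            nn.Linear(hidden_dim, 3),              # 192 -> 3
        )

    def forward(self, sentences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
        # encode() runs under no_grad, which is fine for inference; end-to-end
        # fine-tuning would call the underlying transformer's forward instead.
        embeddings = self.backbone.encode(sentences, convert_to_tensor=True)
        return self.topic_head(embeddings), self.sentiment_head(embeddings)
```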
Embedding a few sentences and computing their pairwise cosine similarities (relevance):
> python embedding_example.py
...
Embeddings size: torch.Size([3, 384])
cosine_sim sentence_1 sentence_2
1 0.376084 The dog ran across the grass. The cat jumped into the weeds.
0 0.069512 The dog ran across the grass. Why is the sky blue?
2 0.021868 Why is the sky blue? The cat jumped into the weeds.
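As a rough sketch, the pairwise similarities above could be reproduced along these lines (embedding_example.py itself may be organized differently):

```python
from itertools import combinations

from sentence_transformers import SentenceTransformer, util

sentences = [
    "The dog ran across the grass.",
    "Why is the sky blue?",
    "The cat jumped into the weeds.",
]

model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(sentences, convert_to_tensor=True)
print("Embeddings size:", embeddings.shape)  # torch.Size([3, 384])

# Cosine similarity for every pair of sentence embeddings.
for i, j in combinations(range(len(sentences)), 2):
    score = util.cos_sim(embeddings[i], embeddings[j]).item()
    print(f"{score:.6f}  {sentences[i]}  {sentences[j]}")
```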
Example text classification and sentiment predictions for a few sentences:
Note:
Because the classifier heads are initialized with random weights, each run produces different predictions.
> python prediction_example.py
Multi-task predictions using randomized classifier weights...
sentence text_class sentiment
0 The dog ran across the grass. tech positive
1 Why is the sky blue? finance neutral
2 The cat jumped into the weeds. health negative
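A minimal sketch of how such predictions could be produced, reusing the MultiTaskSentenceClassifier sketch above (the label lists mirror the classes listed earlier; prediction_example.py may differ):

```python
import torch

TEXT_CLASSES = ["sports", "health", "tech", "finance", "education", "other"]
SENTIMENTS = ["negative", "neutral", "positive"]

model = MultiTaskSentenceClassifier()  # classifier heads start with random weights
sentences = [
    "The dog ran across the grass.",
    "Why is the sky blue?",
    "The cat jumped into the weeds.",
]

with torch.no_grad():
    topic_logits, sentiment_logits = model(sentences)

# argmax over the logits selects the predicted class for each task.
for sent, t_idx, s_idx in zip(
    sentences, topic_logits.argmax(dim=1), sentiment_logits.argmax(dim=1)
):
    print(sent, TEXT_CLASSES[int(t_idx)], SENTIMENTS[int(s_idx)])
```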
Assume task B has limited data but is still similar to task A, which has ample data.
- This scenario can be addressed with sequential task-specific fine-tuning.
- First fine-tune the model on task A, then freeze those weights and fine-tune on task B (a minimal sketch of the freezing step follows this list).
- Text data augmentation (synthetic data generation) could also be explored to supplement the training data for task B.
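A minimal sketch of the second stage, assuming the MultiTaskSentenceClassifier sketch above (data loading and the training loop are omitted):

```python
import torch
from torch import nn

model = MultiTaskSentenceClassifier()

# Stage 1 (not shown): fine-tune the backbone and topic_head on task A data.

# Stage 2: freeze what was learned for task A, train only the sentiment head.
for param in model.backbone.parameters():
    param.requires_grad = False
for param in model.topic_head.parameters():
    param.requires_grad = False

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
criterion = nn.CrossEntropyLoss()
# ... training loop over the limited task B data goes here ...
```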